URL
TL;DR
- 这篇论文提出了去噪扩散隐式模型
Denoising Diffusion Implicit Models (DDIM)
模型,可以看作是对 Denoising Diffusion Probabilistic Models (DDPM)
模型的改进。
DDPM
的采样过程是一个 Markov
过程,Markov
过程只有知道 t
时刻的状态才能计算第 t-1
时刻,而 DDIM
的采样过程是一个非 Markov
过程。
DDIM
的优势是:
- 去噪速度更快。可以在采样(去噪)时使用更少的时间步数,从而加快采样速度,并且用
DDPM
训练的模型可以直接用于 DDIM
的采样,二者可以无缝衔接。
- 确定性。在
DDIM
中,给定一个模型和一个噪声图像和时间步数,可以确定性地生成一个图像(运行再多次也是同一个图)。在 DDPM
中,给定一个模型和一个噪声图像和时间步数,生成的图像是随机的(每次跑都不一样)。
Algorithm
DDIM
的公式是从 DDPM
公式通过复杂推导得到的,推导过程比较复杂,这里不做详细介绍,重点讲二者逆向过程(去噪)公式区别和使用上的区别。
DDPM 逆向过程公式
xt−1=αt1(xt−1−αˉtβt⋅ϵθ(xt,t))
- 其中:
- ϵθ(xt,t) 是模型预测的噪声,即
model(x_t, t)
- αt=1−βt
- αˉt=∏s=1tαs
t
取值是 999, 998,..., 1, 0
,即从 T-1
到 0
逐步去噪
DDIM 逆向过程公式
xt−1=αˉt−1(αˉtxt−αˉt−1⋅ϵθ(xt,t))+1−αˉt−1⋅ϵθ(xt,t)
- 其中:
- 大多数符号含义都和
DDPM
一样
- 只有
t
取值是 980, 960,..., 20, 0
这种非连续的取值
DDIM
公式拆解:
1. 预测原输入
predx0=αˉtxt−αˉt−1⋅ϵθ(xt,t)
- 通过当前噪声隐变量 xt 和模型预测的噪声 ϵθ(xt,t) 估计原始输入 x0,即去噪
2. 计算调整方向
direction_point=1−αˉt−1⋅ϵθ(xt,t)
- 根据噪声预测结果,计算从当前时间步
t
到前一个时间步 t-1
的调整方向,这一方向结合了噪声预测和噪声调度参数,用于引导隐变量的更新。
3. 更新隐变量
xt−1=αˉt−1⋅predx0+direction_point
- 将预测的原始输入 predx0 与调整方向结合,生成前一时刻的隐变量 xt−1。此步骤通过线性组合逐步去噪,最终逼近目标数据 x0。
DDPM 和 DDIM 使用上的区别
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to(device) scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
alphas = scheduler.alphas alphas_cumprod = scheduler.alphas_cumprod
sample = torch.load("random_noise.pt").to(device)
for t in tqdm.tqdm(scheduler.timesteps): with torch.no_grad(): residual = model(sample, t).sample sample = ( sample - (1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t]) * residual ) / torch.sqrt(alphas[t]) if t > 1: noise = torch.randn_like(sample).to(device) sample += torch.sqrt(1 - alphas[t]) * noise
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
| model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to(device) scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(num_inference_steps=50)
sample = torch.load("random_noise.pt").to(device)
for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)): t = t.to(device).long() alpha_cumprod_t = scheduler.alphas_cumprod[t] alpha_cumprod_prev = ( scheduler.alphas_cumprod[scheduler.timesteps[i + 1]] if i + 1 < len(scheduler.timesteps) else torch.tensor(1.0) ) alpha_cumprod_t = alpha_cumprod_t.to(device) alpha_cumprod_prev = alpha_cumprod_prev.to(device) with torch.no_grad(): residual = model(sample, t).sample pred_x0 = (sample - torch.sqrt(1.0 - alpha_cumprod_t) * residual) / torch.sqrt( alpha_cumprod_t ) direction_xt = torch.sqrt(1.0 - alpha_cumprod_prev) * residual sample = torch.sqrt(alpha_cumprod_prev) * pred_x0 + direction_xt
|
二者对比分析
DDIM
只需要 50
次迭代就能生成高质量的图像,而 DDPM
需要 1000
次迭代。
- 生成的图像质量相似,
DDIM
生成的图像质量略高。
- 上面的代码中加载的噪声图是静态的,
DDIM
跑多次生成的图像是一样的,而 DDPM
跑多次生成的图像是不一样的。
- 二者去噪结果对比,左侧是
DDIM
,右侧是 DDPM
:

Thoughts
DDIM
解决了 DDPM
的两大痛点,算是一个很好的改进。
- 为后续的
LDM
等模型打下了基础。