URL
TL;DR
这篇论文提出了去噪扩散隐式模型 Denoising Diffusion Implicit Models (DDIM)
模型,可以看作是对 Denoising Diffusion Probabilistic Models (DDPM)
模型的改进。
DDPM
的采样过程是一个 Markov
过程,Markov
过程只有知道 t
时刻的状态才能计算第 t-1
时刻,而 DDIM
的采样过程是一个非 Markov
过程。
DDIM
的优势是:
去噪速度更快。可以在采样(去噪)时使用更少的时间步数,从而加快采样速度,并且用 DDPM
训练的模型可以直接用于 DDIM
的采样,二者可以无缝衔接。
确定性。在 DDIM
中,给定一个模型和一个噪声图像和时间步数,可以确定性地生成一个图像(运行再多次也是同一个图)。在 DDPM
中,给定一个模型和一个噪声图像和时间步数,生成的图像是随机的(每次跑都不一样)。
Algorithm
DDIM
的公式是从 DDPM
公式通过复杂推导得到的,推导过程比较复杂,这里不做详细介绍,重点讲二者逆向过程(去噪)公式区别和使用上的区别。
DDPM 逆向过程公式
x t − 1 = 1 α t ( x t − β t 1 − α ˉ t ⋅ ϵ θ ( x t , t ) ) x_{t-1}=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot \epsilon_\theta(x_t,t))
x t − 1 = α t 1 ( x t − 1 − α ˉ t β t ⋅ ϵ θ ( x t , t ) )
其中:
ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) ϵ θ ( x t , t ) 是模型预测的噪声,即 model(x_t, t)
α t = 1 − β t \alpha_t=1-\beta_t α t = 1 − β t
α ˉ t = ∏ s = 1 t α s \bar\alpha_t=\prod_{s=1}^t\alpha_s α ˉ t = ∏ s = 1 t α s
t
取值是 999, 998,..., 1, 0
,即从 T-1
到 0
逐步去噪
DDIM 逆向过程公式
x t − 1 = α ˉ t − 1 ( x t − α ˉ t − 1 ⋅ ϵ θ ( x t , t ) α ˉ t ) + 1 − α ˉ t − 1 ⋅ ϵ θ ( x t , t ) x_{t-1}=\sqrt {\bar\alpha_t-1}(\frac{x_t-\sqrt {\bar\alpha_t-1}\cdot\epsilon_\theta(x_t,t)}{\sqrt {\bar\alpha_t}})+\sqrt{1-\bar\alpha_{t-1}}\cdot\epsilon_\theta(x_t,t)
x t − 1 = α ˉ t − 1 ( α ˉ t x t − α ˉ t − 1 ⋅ ϵ θ ( x t , t ) ) + 1 − α ˉ t − 1 ⋅ ϵ θ ( x t , t )
其中:
大多数符号含义都和 DDPM
一样
只有 t
取值是 980, 960,..., 20, 0
这种非连续的取值
DDIM
公式拆解:
1. 预测原输入
p r e d x 0 = x t − α ˉ t − 1 ⋅ ϵ θ ( x t , t ) α ˉ t pred_{x0}=\frac{x_t-\sqrt {\bar\alpha_t-1}\cdot\epsilon_\theta(x_t,t)}{\sqrt{\bar\alpha_t}}
p r e d x 0 = α ˉ t x t − α ˉ t − 1 ⋅ ϵ θ ( x t , t )
通过当前噪声隐变量 x t x_t x t 和模型预测的噪声 ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) ϵ θ ( x t , t ) 估计原始输入 x 0 x_0 x 0 ,即去噪
2. 计算调整方向
d i r e c t i o n _ p o i n t = 1 − α ˉ t − 1 ⋅ ϵ θ ( x t , t ) direction\_point=\sqrt{1-\bar\alpha_{t-1}}\cdot\epsilon_\theta(x_t,t)
d i r e c t i o n _ p o i n t = 1 − α ˉ t − 1 ⋅ ϵ θ ( x t , t )
根据噪声预测结果,计算从当前时间步 t
到前一个时间步 t-1
的调整方向,这一方向结合了噪声预测和噪声调度参数,用于引导隐变量的更新。
3. 更新隐变量
x t − 1 = α ˉ t − 1 ⋅ p r e d x 0 + d i r e c t i o n _ p o i n t x_{t-1}=\sqrt {\bar\alpha_t-1}\cdot pred_{x0}+direction\_point
x t − 1 = α ˉ t − 1 ⋅ p r e d x 0 + d i r e c t i o n _ p o i n t
将预测的原始输入 p r e d x 0 pred_{x0} p r e d x 0 与调整方向结合,生成前一时刻的隐变量 x t − 1 x_{t-1} x t − 1 。此步骤通过线性组合逐步去噪,最终逼近目标数据 x 0 x_0 x 0 。
DDPM 和 DDIM 使用上的区别
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256" ).to(device) scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256" ) alphas = scheduler.alphas alphas_cumprod = scheduler.alphas_cumprod sample = torch.load("random_noise.pt" ).to(device) for t in tqdm.tqdm(scheduler.timesteps): with torch.no_grad(): residual = model(sample, t).sample sample = ( sample - (1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t]) * residual ) / torch.sqrt(alphas[t]) if t > 1 : noise = torch.randn_like(sample).to(device) sample += torch.sqrt(1 - alphas[t]) * noise
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256" ).to(device) scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256" ) scheduler.set_timesteps(num_inference_steps=50 ) sample = torch.load("random_noise.pt" ).to(device) for i, t in enumerate (tqdm.tqdm(scheduler.timesteps)): t = t.to(device).long() alpha_cumprod_t = scheduler.alphas_cumprod[t] alpha_cumprod_prev = ( scheduler.alphas_cumprod[scheduler.timesteps[i + 1 ]] if i + 1 < len (scheduler.timesteps) else torch.tensor(1.0 ) ) alpha_cumprod_t = alpha_cumprod_t.to(device) alpha_cumprod_prev = alpha_cumprod_prev.to(device) with torch.no_grad(): residual = model(sample, t).sample pred_x0 = (sample - torch.sqrt(1.0 - alpha_cumprod_t) * residual) / torch.sqrt( alpha_cumprod_t ) direction_xt = torch.sqrt(1.0 - alpha_cumprod_prev) * residual sample = torch.sqrt(alpha_cumprod_prev) * pred_x0 + direction_xt
二者对比分析
DDIM
只需要 50
次迭代就能生成高质量的图像,而 DDPM
需要 1000
次迭代。
生成的图像质量相似,DDIM
生成的图像质量略高。
上面的代码中加载的噪声图是静态的,DDIM
跑多次生成的图像是一样的,而 DDPM
跑多次生成的图像是不一样的。
二者去噪结果对比,左侧是 DDIM
,右侧是 DDPM
:
Thoughts
DDIM
解决了 DDPM
的两大痛点,算是一个很好的改进。
为后续的 LDM
等模型打下了基础。