Zhangzhe's Blog

The projection of my life.

0%

Denoising Diffusion Implicit Models

URL

TL;DR

  • 这篇论文提出了去噪扩散隐式模型 Denoising Diffusion Implicit Models (DDIM) 模型,可以看作是对 Denoising Diffusion Probabilistic Models (DDPM) 模型的改进。
  • DDPM 的采样过程是一个 Markov 过程,Markov 过程只有知道 t 时刻的状态才能计算第 t-1 时刻,而 DDIM 的采样过程是一个非 Markov 过程。
  • DDIM 的优势是:
    1. 去噪速度更快。可以在采样(去噪)时使用更少的时间步数,从而加快采样速度,并且用 DDPM 训练的模型可以直接用于 DDIM 的采样,二者可以无缝衔接。
    2. 确定性。在 DDIM 中,给定一个模型和一个噪声图像和时间步数,可以确定性地生成一个图像(运行再多次也是同一个图)。在 DDPM 中,给定一个模型和一个噪声图像和时间步数,生成的图像是随机的(每次跑都不一样)。

Algorithm

  • DDIM 的公式是从 DDPM 公式通过复杂推导得到的,推导过程比较复杂,这里不做详细介绍,重点讲二者逆向过程(去噪)公式区别和使用上的区别。

DDPM 逆向过程公式

xt1=1αt(xtβt1αˉtϵθ(xt,t))x_{t-1}=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\cdot \epsilon_\theta(x_t,t))

  • 其中:
    • ϵθ(xt,t)\epsilon_\theta(x_t,t) 是模型预测的噪声,即 model(x_t, t)
    • αt=1βt\alpha_t=1-\beta_t
    • αˉt=s=1tαs\bar\alpha_t=\prod_{s=1}^t\alpha_s
    • t 取值是 999, 998,..., 1, 0,即从 T-10 逐步去噪

DDIM 逆向过程公式

xt1=αˉt1(xtαˉt1ϵθ(xt,t)αˉt)+1αˉt1ϵθ(xt,t)x_{t-1}=\sqrt {\bar\alpha_t-1}(\frac{x_t-\sqrt {\bar\alpha_t-1}\cdot\epsilon_\theta(x_t,t)}{\sqrt {\bar\alpha_t}})+\sqrt{1-\bar\alpha_{t-1}}\cdot\epsilon_\theta(x_t,t)

  • 其中:
    • 大多数符号含义都和 DDPM 一样
    • 只有 t 取值是 980, 960,..., 20, 0 这种非连续的取值

DDIM 公式拆解:

1. 预测原输入

predx0=xtαˉt1ϵθ(xt,t)αˉtpred_{x0}=\frac{x_t-\sqrt {\bar\alpha_t-1}\cdot\epsilon_\theta(x_t,t)}{\sqrt{\bar\alpha_t}}

  • 通过当前噪声隐变量 xtx_t 和模型预测的噪声 ϵθ(xt,t)\epsilon_\theta(x_t,t) 估计原始输入 x0x_0,即去噪

2. 计算调整方向

direction_point=1αˉt1ϵθ(xt,t)direction\_point=\sqrt{1-\bar\alpha_{t-1}}\cdot\epsilon_\theta(x_t,t)

  • 根据噪声预测结果,计算从当前时间步 t 到前一个时间步 t-1 的调整方向,这一方向结合了噪声预测和噪声调度参数,用于引导隐变量的更新。

3. 更新隐变量

xt1=αˉt1predx0+direction_pointx_{t-1}=\sqrt {\bar\alpha_t-1}\cdot pred_{x0}+direction\_point

  • 将预测的原始输入 predx0pred_{x0} 与调整方向结合,生成前一时刻的隐变量 xt1x_{t-1}。此步骤通过线性组合逐步去噪,最终逼近目标数据 x0x_0

DDPM 和 DDIM 使用上的区别

  • DDPM 去噪过程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to(device)
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")

# Get precalculated alphas and alpha bars from the scheduler
alphas = scheduler.alphas
alphas_cumprod = scheduler.alphas_cumprod

# Initialize sample with static random noise
sample = torch.load("random_noise.pt").to(device)

# DDPM denoising loop
# scheduler.timesteps = [999, 998, ..., 1, 0]
for t in tqdm.tqdm(scheduler.timesteps):
with torch.no_grad():
# Model prediction (noise residual)
residual = model(sample, t).sample

# DDPM denoising formula
sample = (
sample - (1 - alphas[t]) / torch.sqrt(1 - alphas_cumprod[t]) * residual
) / torch.sqrt(alphas[t])

# Add random noise only for t > 1
if t > 1:
noise = torch.randn_like(sample).to(device)
sample += torch.sqrt(1 - alphas[t]) * noise
  • DDIM 去噪过程
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
model = UNet2DModel.from_pretrained("google/ddpm-celebahq-256").to(device)
scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")
# set inference steps
scheduler.set_timesteps(num_inference_steps=50)

# Initialize sample with static random noise
sample = torch.load("random_noise.pt").to(device)

# DDIM denoising loop
# scheduler.timesteps = [980, 960,..., 20, 0]
for i, t in enumerate(tqdm.tqdm(scheduler.timesteps)):
# 将时间步转换为LongTensor并确保在正确设备上
t = t.to(device).long()

# 获取当前和上一步的alpha累积乘积
alpha_cumprod_t = scheduler.alphas_cumprod[t]
alpha_cumprod_prev = (
scheduler.alphas_cumprod[scheduler.timesteps[i + 1]]
if i + 1 < len(scheduler.timesteps)
else torch.tensor(1.0)
)

# 将alpha值转换到相同设备
alpha_cumprod_t = alpha_cumprod_t.to(device)
alpha_cumprod_prev = alpha_cumprod_prev.to(device)

with torch.no_grad():
# 1. 预测噪声残差
residual = model(sample, t).sample

# 2. 计算预测的原始图像x0(去噪后的图像)
pred_x0 = (sample - torch.sqrt(1.0 - alpha_cumprod_t) * residual) / torch.sqrt(
alpha_cumprod_t
)

# 3. 计算下一步的样本方向
direction_xt = torch.sqrt(1.0 - alpha_cumprod_prev) * residual

# 4. 组合得到新的样本
sample = torch.sqrt(alpha_cumprod_prev) * pred_x0 + direction_xt

二者对比分析

  1. DDIM 只需要 50 次迭代就能生成高质量的图像,而 DDPM 需要 1000 次迭代。
  2. 生成的图像质量相似,DDIM 生成的图像质量略高。
  3. 上面的代码中加载的噪声图是静态的,DDIM 跑多次生成的图像是一样的,而 DDPM 跑多次生成的图像是不一样的。
  4. 二者去噪结果对比,左侧是 DDIM,右侧是 DDPM
    concat.png

Thoughts

  • DDIM 解决了 DDPM 的两大痛点,算是一个很好的改进。
  • 为后续的 LDM 等模型打下了基础。