
High-Resolution Image Synthesis with Latent Diffusion Models

URL

TL;DR

  • LDM is the seminal work of the Stable Diffusion family; it is what made Diffusion Models shine in the field of Image Synthesis
  • Traditional Diffusion Models have two major problems:
    1. There is no way to control what gets generated; you can only ensure the output is stylistically similar to the training data
    2. Denoising is done at the pixel level, which is computationally expensive, so only low-resolution images can be generated, and slowly
  • LDM solves both problems:
    1. Encoders turn conditioning information such as text / mask / bbox into a conditioning embedding, which is combined with the noise in latent space via cross attention, so the conditioning information can guide image generation
    2. A VAE compresses the image into latent space before denoising, which reduces computation and speeds things up (see the back-of-the-envelope sketch below)
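  • A rough way to see why working in latent space helps: a minimal sketch, assuming a 512x512 RGB input, the usual f=8 VAE giving a 64x64x4 latent, and the quadratic cost of self-attention over spatial tokens (a scaling argument, not a benchmark):
# Rough scaling argument, not a benchmark: self-attention cost grows with the
# square of the token count, so denoising a 64x64 latent instead of a 512x512
# pixel grid cuts that cost by (262144 / 4096)^2 = 4096x.
pixel_tokens = 512 * 512   # attending over raw pixels
latent_tokens = 64 * 64    # attending over the f=8 latent (64x64x4)
ratio = (pixel_tokens / latent_tokens) ** 2
print(f"{pixel_tokens} vs {latent_tokens} tokens -> ~{ratio:.0f}x cheaper attention")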

Algorithm

Overall pipeline

LDM.png

  1. Sample random noise in latent space
  2. Encode each condition with its own encoder into a conditioning embedding, e.g. text is encoded with the CLIP text encoder
  3. Feed the latent noise, conditioning embedding, and timestep embedding into the UNet for multiple rounds of iterative denoising (50 steps)
  4. Decode the denoised latent into an image with the VAE decoder
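  • The four steps map onto off-the-shelf components roughly as follows. This is a minimal sketch, assuming the diffusers / transformers libraries and the runwayml/stable-diffusion-v1-5 checkpoint are available; it omits classifier-free guidance and other details of the real pipeline:
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

repo = "runwayml/stable-diffusion-v1-5"  # assumed checkpoint
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")
vae = AutoencoderKL.from_pretrained(repo, subfolder="vae")
scheduler = DDIMScheduler.from_pretrained(repo, subfolder="scheduler")

# Step 1: random noise in latent space (64x64x4 for a 512x512 image)
latents = torch.randn(1, 4, 64, 64)

# Step 2: encode the condition (here: text) into a conditioning embedding
tokens = tokenizer(["a photo of a cat"], padding="max_length", max_length=77, return_tensors="pt")
with torch.no_grad():
    text_emb = text_encoder(tokens.input_ids)[0]  # (1, 77, 768)

# Step 3: iterative denoising with the conditioning UNet (50 steps)
scheduler.set_timesteps(50)
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = unet(latents, t, encoder_hidden_states=text_emb).sample
    latents = scheduler.step(noise_pred, t, latents).prev_sample

# Step 4: decode the denoised latent back to pixel space with the VAE decoder
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor).sample  # (1, 3, 512, 512)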

Conditioning UNet

  • The cross attention mechanism injects the text condition and the timestep into the Noised Latent inside the UNet
  • Concretely: the TimeEmbedding is added directly to the image features, while the TextEmbedding interacts with the image features through CrossAttention (the TextEmbedding provides K and V, the image features provide Q); a concrete shape walk-through follows the implementation code at the end of this section
  • Schematic of the Conditioning UNet: two downsampling blocks + a middle block + two upsampling blocks
graph TD
    Input[Noised Latent: 32x32x4] --> DownBlock1[CrossAttnDownBlock2D]
    DownBlock1 --> DownBlock2[CrossAttnDownBlock2D]
    DownBlock2 --> MidBlock[UNetMidBlock2DCrossAttn]
    MidBlock --> UpBlock1[CrossAttnUpBlock2D]
    UpBlock1 --> UpBlock2[CrossAttnUpBlock2D]
    UpBlock2 --> Output[Denoised Latent: 32x32x4]
  
    TextEncoder[Text Encoder] -->|Text Embedding| DownBlock1
    TextEncoder -->|Text Embedding| DownBlock2
    TextEncoder -->|Text Embedding| MidBlock
    TextEncoder -->|Text Embedding| UpBlock1
    TextEncoder -->|Text Embedding| UpBlock2
  
    Time[Timestep] -->|Time Embedding| DownBlock1
    Time -->|Time Embedding| DownBlock2
    Time -->|Time Embedding| MidBlock
    Time -->|Time Embedding| UpBlock1
    Time -->|Time Embedding| UpBlock2
  • Structure of CrossAttnBlock2D
graph TD
    %% Input nodes
    Input[Input feature map h_in] --> ResNet
    TimeEmb[Time embedding t_emb] --> MLP
    TextEmb[Text condition y_text] --> ProjText
  
    %% Main computation path
    ResNet[ResNet block] --> Add
    MLP[MLP time projection] --> Add
    Add[Element-wise add] --> GroupNorm
    GroupNorm[GroupNorm] --> Conv1
    Conv1[Conv2D 1x1] --> CrossAttn
  
    %% Cross-attention branch
    ProjText[Text projection W_k/W_v] --> CrossAttn
    Conv2[Conv2D 1x1] --> Merge
    CrossAttn[Cross-attention layer] --> Merge
  
    %% Residual connection
    Input --> Conv2
    Merge[Feature merge] --> LayerNorm
    LayerNorm[LayerNorm] --> Output[Output feature map h_out]
  • Structure of DecoderAttentionBlock2D
graph TD
    X[("Input x<br/>Shape: 1,512,32,32")] --> Norm["Normalize (GroupNorm)<br/>Output: 1,512,32,32"]
    Norm --> Q["Q Conv2d(1x1)<br/>Output: 1,512,32,32"]
    Norm --> K["K Conv2d(1x1)<br/>Output: 1,512,32,32"]
    Norm --> V["V Conv2d(1x1)<br/>Output: 1,512,32,32"]
    Q --> ReshapeQ["Reshape & Permute<br/>1,512,32,32 → 1,1024,512"]
    K --> ReshapeK["Reshape<br/>1,512,32,32 → 1,512,1024"]
    ReshapeQ --> MatmulQK["Matmul(Q,K)<br/>1,1024,512 × 1,512,1024 → 1,1024,1024"]
    ReshapeK --> MatmulQK
    MatmulQK --> Scale["Scale (×1/√512)<br/>1,1024,1024"]
    Scale --> Softmax["Softmax<br/>1,1024,1024"]
    V --> ReshapeV["Reshape<br/>1,512,32,32 → 1,512,1024"]
    Softmax --> PermuteSoftmax["Permute<br/>1,1024,1024 → 1,1024,1024"]
    ReshapeV --> MatmulVW["Matmul(V, Softmax)<br/>1,512,1024 × 1,1024,1024 → 1,512,1024"]
    PermuteSoftmax --> MatmulVW
    MatmulVW --> ReshapeOut["Reshape<br/>1,512,1024 → 1,512,32,32"]
    ReshapeOut --> ProjOut["Proj_out Conv2d(1x1)<br/>1,512,32,32"]
    ProjOut --> Add["Add (x + h_)<br/>1,512,32,32"]
    X --> Add
    Add --> Output[("Final Output<br/>1,512,32,32")]
  • Below is an implementation of the Conditioning UNet block; you can see it is quite elegant:
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    # Single-head attention: with `context` it is cross attention (context supplies K/V),
    # without it falls back to self attention.
    def __init__(self, in_dim, context_dim=None):
        super().__init__()
        self.to_q = nn.Linear(in_dim, in_dim, bias=False)
        self.to_k = nn.Linear(
            context_dim if context_dim else in_dim, in_dim, bias=False
        )
        self.to_v = nn.Linear(
            context_dim if context_dim else in_dim, in_dim, bias=False
        )
        self.to_out = nn.Sequential(nn.Linear(in_dim, in_dim), nn.Dropout(0.0))

    def forward(self, x, context=None):
        q = self.to_q(x)
        k = self.to_k(context if context is not None else x)
        v = self.to_v(context if context is not None else x)

        attn = torch.einsum("b i d, b j d -> b i j", q, k) * (x.shape[-1] ** -0.5)
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum("b i j, b j d -> b i d", attn, v)
        return self.to_out(out)


class GEGLU(nn.Module):
    # Gated GELU used inside the feed-forward network.
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.proj = nn.Linear(in_dim, hidden_dim * 2)

    def forward(self, x):
        x_proj = self.proj(x)
        x1, x2 = x_proj.chunk(2, dim=-1)
        return x1 * F.gelu(x2)


class FeedForward(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            GEGLU(in_dim, hidden_dim), nn.Dropout(0.0), nn.Linear(hidden_dim, in_dim)
        )

    def forward(self, x):
        return self.net(x)


class BasicTransformerBlock(nn.Module):
    # Self attention -> cross attention (text as K/V) -> feed forward, each with a residual.
    def __init__(self, dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-5)
        self.attn1 = Attention(dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-5)
        self.attn2 = Attention(dim, context_dim=768)  # 768 = CLIP text embedding width
        self.norm3 = nn.LayerNorm(dim, eps=1e-5)
        self.ff = FeedForward(dim, 1280)

    def forward(self, x, context=None):
        # Self attention
        x = self.attn1(self.norm1(x)) + x
        # Cross attention
        x = self.attn2(self.norm2(x), context=context) + x
        # Feed forward
        x = self.ff(self.norm3(x)) + x
        return x


class Transformer2DModel(nn.Module):
    # Flattens the feature map into a token sequence, runs transformer blocks, then restores it.
    def __init__(self, in_channels):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.transformer_blocks = nn.ModuleList([BasicTransformerBlock(in_channels)])
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def forward(self, x, context=None):
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)

        for block in self.transformer_blocks:
            x = block(x, context)

        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = self.proj_out(x)
        return x + x_in


class ResnetBlock2D(nn.Module):
    # Residual block; the time embedding is projected and added to the feature map.
    def __init__(self, in_channels):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_channels, eps=1e-5, affine=True)
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.time_emb_proj = nn.Linear(1280, in_channels)
        self.norm2 = nn.GroupNorm(32, in_channels, eps=1e-5, affine=True)
        self.dropout = nn.Dropout(0.0)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.nonlinearity = nn.SiLU()

    def forward(self, x, time_emb=None):
        h = x
        h = self.norm1(h)
        h = self.nonlinearity(h)
        h = self.conv1(h)

        if time_emb is not None:
            time_emb = self.nonlinearity(time_emb)
            time_emb = self.time_emb_proj(time_emb)[:, :, None, None]
            h = h + time_emb

        h = self.norm2(h)
        h = self.nonlinearity(h)
        h = self.dropout(h)
        h = self.conv2(h)
        return h + x


class Downsample2D(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class CrossAttnDownBlock2D(nn.Module):
    def __init__(self, in_channels=320):
        super().__init__()
        self.attentions = nn.ModuleList(
            [Transformer2DModel(in_channels) for _ in range(2)]
        )
        self.resnets = nn.ModuleList([ResnetBlock2D(in_channels) for _ in range(2)])
        self.downsamplers = nn.ModuleList([Downsample2D(in_channels)])

    def forward(self, x, context=None, time_emb=None):
        for attn, resnet in zip(self.attentions, self.resnets):
            x = attn(x, context)
            x = resnet(x, time_emb)

        for downsampler in self.downsamplers:
            x = downsampler(x)

        return x


# Test code
if __name__ == "__main__":
    block = CrossAttnDownBlock2D(in_channels=320)
    x = torch.randn(1, 320, 64, 64)
    context = torch.randn(1, 77, 768)  # text condition: a sequence of 77 tokens encoded by the CLIP text encoder
    time_emb = torch.randn(1, 1280)  # time embedding: one timestep (e.g. 961) mapped to an embedding

    output = block(x, context, time_emb)
    print(f"Input shape: {x.shape} -> Output shape: {output.shape}")
    # Expected output shape: torch.Size([1, 320, 32, 32])
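
  • To make the "TextEmbedding as K/V, image features as Q" point concrete, here is a standalone shape walk-through (a minimal sketch with random stand-in weights and tokens, independent of the classes above): the 32x32x320 feature map becomes 1024 query tokens, the 77 CLIP text tokens supply K and V, so the attention map is 1024x77 and the output keeps the image token shape.
import torch

b, hw, c_img, n_txt, c_txt = 1, 32 * 32, 320, 77, 768
img_tokens = torch.randn(b, hw, c_img)     # flattened 32x32 feature map -> 1024 query tokens
txt_tokens = torch.randn(b, n_txt, c_txt)  # 77 CLIP text tokens (random stand-ins)

w_q = torch.nn.Linear(c_img, c_img, bias=False)  # Q from image features
w_k = torch.nn.Linear(c_txt, c_img, bias=False)  # K from text embedding
w_v = torch.nn.Linear(c_txt, c_img, bias=False)  # V from text embedding

q, k, v = w_q(img_tokens), w_k(txt_tokens), w_v(txt_tokens)
attn = torch.softmax(q @ k.transpose(1, 2) / c_img**0.5, dim=-1)  # (1, 1024, 77)
out = attn @ v                                                    # (1, 1024, 320)
print(attn.shape, out.shape)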

Thoughts

  • The paper's line of reasoning is extremely clear and persuasive; by combining knowledge from many fields, it genuinely pushed image generation to a new level of practical usability