Zhangzhe's Blog

The projection of my life.

High-Resolution Image Synthesis with Latent Diffusion Models

TL;DR

  • LDM is the seminal work of the Stable Diffusion family; it is what made Diffusion Models flourish in the field of Image Synthesis
  • Traditional Diffusion Models have two major problems:
    1. There is no way to control the generated content; one can only ensure the output is stylistically similar to the training dataset
    2. Denoising is done at pixel scale, which is computationally expensive, so only low-resolution images can be generated, and slowly
  • LDM solves both problems:
    1. Conditioning information such as text / mask / bbox is encoded into a conditioning embedding, which is fused with the noise in latent space via a cross-attention mechanism during denoising, so the conditions can guide image generation
    2. Images are compressed into latent space by a VAE before denoising, which greatly reduces computation and speeds up generation
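Point 2 can be made concrete with a quick back-of-the-envelope calculation, assuming the common LDM configuration of an 8× spatial downsampling VAE and a 4-channel latent (consistent with the 32x32x4 latent shown in the diagrams below):

```python
# Rough illustration: per-step denoising cost scales with the number of
# tensor elements processed, so working in latent space shrinks it a lot.
# Assumes the common LDM setup: 8x spatial downsampling, 4 latent channels.

image_h, image_w, image_c = 256, 256, 3    # pixel-space image tensor
down = 8                                   # VAE downsampling factor f
latent_c = 4                               # latent channels

pixel_elems = image_h * image_w * image_c
latent_elems = (image_h // down) * (image_w // down) * latent_c

print(pixel_elems)                  # 196608
print(latent_elems)                 # 4096
print(pixel_elems / latent_elems)   # 48.0
```

So each UNet step touches roughly 48× fewer elements than a pixel-space model at this resolution, before even counting the quadratic cost of attention over spatial positions.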

Algorithm

Overall Pipeline

LDM.png

  1. Generate random noise in latent space
  2. Encode each piece of conditioning information into a conditioning embedding with its own type-specific encoder
  3. Feed the latent noise, conditioning embedding, and timestep embedding into the UNet for multiple rounds of iterative denoising (e.g., 50 steps)
  4. Decode the denoised latent into an image with the VAE decoder
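The four steps above can be sketched as a loop. This is purely structural: the text encoder, UNet, and VAE decoder below are stand-in functions, and the update rule is a simplified placeholder for a real scheduler such as DDIM, not the actual sampling math.

```python
import numpy as np

rng = np.random.default_rng(0)

def encode_condition(prompt):        # stand-in for the text encoder
    return rng.standard_normal((77, 768))

def unet(latent, t, cond):           # stand-in for the noise-predicting UNet
    return 0.1 * latent              # pretend-predicted noise

def vae_decode(latent):              # stand-in VAE decoder (f=8, 3 channels)
    h, w = latent.shape[1] * 8, latent.shape[2] * 8
    return np.zeros((3, h, w))

latent = rng.standard_normal((4, 32, 32))      # 1. random latent noise
cond = encode_condition("a photo of a cat")    # 2. conditioning embedding
for t in range(50, 0, -1):                     # 3. 50 denoising iterations
    noise_pred = unet(latent, t, cond)
    latent = latent - noise_pred               #    simplified update rule
image = vae_decode(latent)                     # 4. decode latent to image
print(image.shape)  # (3, 256, 256)
```

Note that the loop never touches pixel space: the VAE decoder runs exactly once, after all denoising steps are finished.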

Conditioning UNet

  • Text conditions and timesteps are embedded into the Noised Latent inside the UNet via a cross-attention mechanism
  • Schematic of the Conditioning UNet: two downsampling blocks + a middle block + two upsampling blocks
graph TD
    Input[Noised Latent: 32x32x4] --> DownBlock1[CrossAttnDownBlock2D]
    DownBlock1 --> DownBlock2[CrossAttnDownBlock2D]
    DownBlock2 --> MidBlock[UNetMidBlock2DCrossAttn]
    MidBlock --> UpBlock1[CrossAttnUpBlock2D]
    UpBlock1 --> UpBlock2[CrossAttnUpBlock2D]
    UpBlock2 --> Output[Denoised Latent: 32x32x4]
  
    TextEncoder[Text Encoder] -->|Text Embedding| DownBlock1
    TextEncoder -->|Text Embedding| DownBlock2
    TextEncoder -->|Text Embedding| MidBlock
    TextEncoder -->|Text Embedding| UpBlock1
    TextEncoder -->|Text Embedding| UpBlock2
  
    Time[Timestep] -->|Time Embedding| DownBlock1
    Time -->|Time Embedding| DownBlock2
    Time -->|Time Embedding| MidBlock
    Time -->|Time Embedding| UpBlock1
    Time -->|Time Embedding| UpBlock2
  • CrossAttnBlock2D structure
graph TD
    %% Input nodes
    Input[Input feature map h_in] --> ResNet
    TimeEmb[Time embedding t_emb] --> MLP
    TextEmb[Text condition y_text] --> ProjText
  
    %% Main computation path
    ResNet[ResNet block] --> Add
    MLP[MLP time projection] --> Add
    Add[Element-wise add] --> GroupNorm
    GroupNorm[GroupNorm] --> Conv1
    Conv1[Conv2D 1x1] --> CrossAttn
  
    %% Cross-attention branch
    ProjText[Text projection W_k/W_v] --> CrossAttn
    Conv2[Conv2D 1x1] --> Merge
    CrossAttn[Cross-attention layer] --> Merge
  
    %% Residual connection
    Input --> Conv2
    Merge[Feature merge] --> LayerNorm
    LayerNorm[LayerNorm] --> Output[Output feature map h_out]
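The key idea in this block is that queries come from the flattened image features while keys and values come from the text embedding. A minimal sketch of that cross-attention step, with random projection matrices and illustrative dimensions (the 320-channel feature map and 64-dim head are plausible SD-like values, not read off an actual checkpoint):

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

d = 64                                   # attention dimension
h = rng.standard_normal((1024, 320))     # flattened 32x32 feature map
y = rng.standard_normal((77, 768))       # text embedding (77 tokens)

W_q = rng.standard_normal((320, d))      # query projection (from image)
W_k = rng.standard_normal((768, d))      # key projection (from text)
W_v = rng.standard_normal((768, d))      # value projection (from text)

Q, K, V = h @ W_q, y @ W_k, y @ W_v
attn = softmax(Q @ K.T / np.sqrt(d))     # (1024, 77): each spatial position
                                         # attends over the 77 text tokens
out = attn @ V                           # (1024, 64): text-conditioned features
print(out.shape)  # (1024, 64)
```

Because K and V are derived from the text, every spatial position of the latent can pull in information from whichever tokens it attends to; this is the mechanism that lets the prompt steer generation.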
  • DecoderAttentionBlock2D structure
graph TD
    X[("Input x<br/>Shape: 1,512,32,32")] --> Norm["Normalize (GroupNorm)<br/>Output: 1,512,32,32"]
    Norm --> Q["Q Conv2d(1x1)<br/>Output: 1,512,32,32"]
    Norm --> K["K Conv2d(1x1)<br/>Output: 1,512,32,32"]
    Norm --> V["V Conv2d(1x1)<br/>Output: 1,512,32,32"]
    Q --> ReshapeQ["Reshape & Permute<br/>1,512,32,32 → 1,1024,512"]
    K --> ReshapeK["Reshape<br/>1,512,32,32 → 1,512,1024"]
    ReshapeQ --> MatmulQK["Matmul(Q,K)<br/>1,1024,512 × 1,512,1024 → 1,1024,1024"]
    ReshapeK --> MatmulQK
    MatmulQK --> Scale["Scale (×1/√512)<br/>1,1024,1024"]
    Scale --> Softmax["Softmax<br/>1,1024,1024"]
    V --> ReshapeV["Reshape<br/>1,512,32,32 → 1,512,1024"]
    Softmax --> PermuteSoftmax["Permute<br/>1,1024,1024 → 1,1024,1024"]
    ReshapeV --> MatmulVW["Matmul(V, Softmax)<br/>1,512,1024 × 1,1024,1024 → 1,512,1024"]
    PermuteSoftmax --> MatmulVW
    MatmulVW --> ReshapeOut["Reshape<br/>1,512,1024 → 1,512,32,32"]
    ReshapeOut --> ProjOut["Proj_out Conv2d(1x1)<br/>1,512,32,32"]
    ProjOut --> Add["Add (x + h_)<br/>1,512,32,32"]
    X --> Add
    Add --> Output[("Final Output<br/>1,512,32,32")]

Thoughts

  • The paper's line of reasoning is remarkably clear and persuasive; by combining ideas from several fields, it truly pushed image generation to a new level of practicality