
MLP-Mixer: An all-MLP Architecture for Vision

URL

https://arxiv.org/pdf/2105.01601.pdf

TL;DR

  • The authors argue that CNNs and Transformers are beneficial for vision tasks but not necessary; a pure MLP can also perform well on vision tasks
  • Accuracy is slightly worse than ViT, but the model is faster

Algorithm

The author's view of the essence of CV

  • Local feature extraction + modeling of global feature correlations
  • CNNs are naturally good at local feature extraction; modeling global correlations has to rely on downsampling + stacking layers to enlarge the receptive field
  • Transformers are the opposite of CNNs: they are very good at modeling global feature correlations (the attention matrix), but weak at local feature extraction, so they usually need pre-training on very large datasets to acquire local priors
  • MLPs are likewise good at modeling global feature correlations; for the local feature extraction they are weak at, a dim shuffle swaps h*w and c, turning global feature extraction into per-channel local feature extraction (see the sketch after this list)
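
A minimal sketch of the dim-shuffle idea described above (my own illustration, not the paper's code): tokens of shape (B, N, C) are transposed so that an ordinary Linear layer mixes information across the N patch positions instead of across channels. The sizes below are illustrative assumptions.

import torch
from torch import nn

B, N, C = 2, 196, 512          # batch, num_patches, channels (illustrative sizes)
x = torch.randn(B, N, C)

token_mlp = nn.Linear(N, N)    # mixes information across patch positions ("token mixing")
channel_mlp = nn.Linear(C, C)  # mixes information across channels ("channel mixing")

# dim shuffle: swap the patch and channel axes so the Linear acts on the patch axis
y = token_mlp(x.transpose(1, 2)).transpose(1, 2)  # (B, N, C) -> (B, C, N) -> (B, N, C)
y = channel_mlp(y)                                # ordinary per-token MLP over channels
print(y.shape)                                    # torch.Size([2, 196, 512])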

Network architecture

(Figure: MLP-Mixer network architecture)

  • LeCun's mockery

(Image: LeCun's mocking remark about MLP-Mixer)

Thoughts

  • Although LeCun is not wrong, I think the authors' point is that the MLP-Mixer architecture does not depend on the traditional 3*3 Conv structure; matmul + dim shuffle can equally achieve local feature extraction + global dependency modeling
  • LeCun's mockery, translated, is roughly: "an MLP is essentially a Conv, and a matmul is essentially a Conv1d with a 1*1 kernel", which is clearly nitpicking… (a small check of this equivalence follows)
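
A quick numerical check of the equivalence LeCun is pointing at (my own sketch, not from the post or the paper): an nn.Linear applied to the channel dimension gives the same result as an nn.Conv1d with kernel_size=1 once the weights are shared. Sizes are illustrative assumptions.

import torch
from torch import nn

B, N, C = 2, 196, 512                      # illustrative sizes
x = torch.randn(B, N, C)

linear = nn.Linear(C, C)
conv1x1 = nn.Conv1d(C, C, kernel_size=1)

# copy the Linear weights into the 1*1 conv so both compute the same map
conv1x1.weight.data = linear.weight.data.unsqueeze(-1)   # (C, C) -> (C, C, 1)
conv1x1.bias.data = linear.bias.data

out_linear = linear(x)                                    # acts on the last (channel) dim
out_conv = conv1x1(x.transpose(1, 2)).transpose(1, 2)     # Conv1d expects (B, C, N)
print(torch.allclose(out_linear, out_conv, atol=1e-5))    # True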

Implementation code (not pseudocode)

from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce


class PreNormResidual(nn.Module):
    # pre-norm residual wrapper: LayerNorm -> fn -> skip connection
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    # `dense` decides which axis gets mixed:
    # nn.Linear mixes the channel dim, Conv1d(kernel_size=1) mixes the patch (token) dim
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout),
    )


def MLPMixer(
    *,
    image_size,
    channels,
    patch_size,
    dim,
    depth,
    num_classes,
    expansion_factor=4,
    dropout=0.0
):
    assert (image_size % patch_size) == 0, "image must be divisible by patch size"
    num_patches = (image_size // patch_size) ** 2
    # token mixing uses a 1*1 Conv1d (acts on the patch axis), channel mixing uses Linear
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

    return nn.Sequential(
        # split the image into patches and flatten each patch: (b, c, H, W) -> (b, h*w, p1*p2*c)
        Rearrange(
            "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size
        ),
        # per-patch linear embedding
        nn.Linear((patch_size ** 2) * channels, dim),
        *[
            nn.Sequential(
                # token-mixing MLP: mixes information across the num_patches axis
                PreNormResidual(
                    dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)
                ),
                # channel-mixing MLP: mixes information across the dim axis
                PreNormResidual(
                    dim, FeedForward(dim, expansion_factor, dropout, chan_last)
                ),
            )
            for _ in range(depth)
        ],
        nn.LayerNorm(dim),
        # global average pooling over patches, then the classification head
        Reduce("b n c -> b c", "mean"),
        nn.Linear(dim, num_classes),
    )
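
A quick smoke test of the model above (my own usage example; the hyperparameters are arbitrary and not the paper's settings):

import torch

model = MLPMixer(
    image_size=224,
    channels=3,
    patch_size=16,
    dim=512,
    depth=8,
    num_classes=1000,
)

img = torch.randn(1, 3, 224, 224)
logits = model(img)
print(logits.shape)  # torch.Size([1, 1000])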