Zhangzhe's Blog

The projection of my life.

ULSAM: Ultra-Lightweight Subspace Attention Module for Compact Convolutional Neural Networks

URL

https://arxiv.org/pdf/2006.15102.pdf

TL;DR

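  1. ULSAM is an ultra-lightweight subspace attention module, well suited to compact networks such as MobileNet and ShuffleNet.
  2. It works well for fine-grained image classification: it cuts FLOPs by roughly 13% and params by roughly 25%, while lowering top-1 error by 0.27% on ImageNet-1K and by 1% on three other fine-grained classification datasets.
  3. It is somewhat similar to SENet: SENet adds attention along the C dimension, whereas ULSAM adds attention along the H and W dimensions.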

Algorithm

Network structure

ul1.png

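  • Split the input tensor F into g groups along the channel dimension: C×H×W --> g×(G×H×W) with G = C / g, i.e. $F = [F_1, F_2, \ldots, F_g]$. Each group $F_n$ is called a subspace.
  • Apply the following operations to each subspace $F_n$:
    • Depth-wise Conv (kernel_size = 1)
    • MaxPool2d (kernel_size = 3, stride = 1, padding = 1); this step enlarges the receptive field while reducing variance
    • Point-wise Conv (kernel_size = 1) with a single filter (kernels = 1)
    • softmax over the H×W positions
    • out = x + x * softmax
  • Concatenate the results of all subspaces to form the output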
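
The split and concat steps above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the authors' implementation: `apply_per_subspace` and its `fn` argument are hypothetical names, and `fn` is only a placeholder for the per-subspace attention.

```python
import torch


def apply_per_subspace(x, g, fn):
    """Split x (N, C, H, W) into g channel groups, apply fn to each, concat back.

    Hypothetical helper for illustration; fn stands in for the attention block.
    """
    if x.shape[1] % g != 0:
        raise ValueError("channel count C must be divisible by g")
    subspaces = torch.chunk(x, g, dim=1)      # g tensors of shape (N, C // g, H, W)
    outputs = [fn(f_n) for f_n in subspaces]  # process every subspace independently
    return torch.cat(outputs, dim=1)          # back to (N, C, H, W)


x = torch.randn(2, 8, 5, 5)
y = apply_per_subspace(x, g=4, fn=lambda f_n: f_n)  # identity as a placeholder
print(y.shape)
```

With the identity placeholder, the output equals the input, which confirms that splitting and concatenating along the channel dimension preserves channel order.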

Formulas

$$dw_n = DW^{1 \times 1}(F_n)$$
$$maxpool_n = maxpool^{3 \times 3,\,1}(dw_n)$$
$$pw_n = PW^{1}(maxpool_n)$$
$$A_n = softmax(pw_n)$$
$$\hat F_n = (A_n \otimes F_n) \oplus F_n$$
$$\hat F = concat([\hat F_1, \hat F_2, \ldots, \hat F_g])$$
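
The per-subspace equations can be traced step by step with PyTorch's functional API. The sketch below is an assumption-laden illustration: it uses random weights, omits BatchNorm and ReLU, and the function name `subspace_attention` is hypothetical.

```python
import torch
import torch.nn.functional as F


def subspace_attention(f_n, dw_weight, pw_weight):
    """f_n: (N, G, H, W); dw_weight: (G, 1, 1, 1); pw_weight: (1, G, 1, 1)."""
    n, g_ch, h, w = f_n.shape
    dw = F.conv2d(f_n, dw_weight, groups=g_ch)                 # dw_n: depth-wise 1x1 conv
    mp = F.max_pool2d(dw, kernel_size=3, stride=1, padding=1)  # maxpool_n: size preserved
    pw = F.conv2d(mp, pw_weight)                               # pw_n: (N, 1, H, W)
    a_n = F.softmax(pw.flatten(2), dim=2).reshape(n, 1, h, w)  # A_n: softmax over H*W
    return a_n * f_n + f_n, a_n                                # (A_n * F_n) + F_n


torch.manual_seed(0)
f_n = torch.randn(2, 4, 6, 6)
out, a_n = subspace_attention(f_n, torch.randn(4, 1, 1, 1), torch.randn(1, 4, 1, 1))
print(out.shape, a_n.sum(dim=(2, 3)))
```

The single-channel map `a_n` is broadcast across all G channels of the subspace, so every channel in a subspace shares the same spatial attention.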

Source code

import torch
import torch.nn as nn


class SubSpace(nn.Module):
    """
    Subspace class.

    ...

    Attributes
    ----------
    nin : int
        number of channels in the input feature volume.

    Methods
    -------
    __init__(nin)
        initialize method.
    forward(x)
        forward pass.
    """

    def __init__(self, nin):
        super(SubSpace, self).__init__()
        # Depth-wise 1x1 conv: groups=nin gives one filter per input channel.
        self.conv_dws = nn.Conv2d(
            nin, nin, kernel_size=1, stride=1, padding=0, groups=nin
        )
        self.bn_dws = nn.BatchNorm2d(nin, momentum=0.9)
        self.relu_dws = nn.ReLU(inplace=False)

        # 3x3 max pooling with stride 1 and padding 1 keeps the spatial size.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        # Point-wise 1x1 conv with a single filter: nin channels -> 1 channel.
        self.conv_point = nn.Conv2d(
            nin, 1, kernel_size=1, stride=1, padding=0, groups=1
        )
        self.bn_point = nn.BatchNorm2d(1, momentum=0.9)
        self.relu_point = nn.ReLU(inplace=False)

        # Applied to a (N, 1, H*W) tensor, so dim=2 is the flattened spatial axis.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        out = self.conv_dws(x)
        out = self.bn_dws(out)
        out = self.relu_dws(out)

        # Pool the depth-wise output, matching maxpool_n = maxpool(dw_n).
        out = self.maxpool(out)

        out = self.conv_point(out)
        out = self.bn_point(out)
        out = self.relu_point(out)

        # Softmax over all H*W positions of the single-channel map.
        m, n, p, q = out.shape
        out = self.softmax(out.view(m, n, -1))
        out = out.view(m, n, p, q)

        # Broadcast the attention map to every channel, then apply it with a residual.
        out = out.expand(x.shape[0], x.shape[1], x.shape[2], x.shape[3])
        out = torch.mul(out, x)
        out = out + x

        return out
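
As a rough sanity check on the "ultra-lightweight" claim, the learnable parameters of one SubSpace can be counted by hand from the layers above. This sketch assumes PyTorch defaults (`bias=True` for `Conv2d`, affine `BatchNorm2d`) and one SubSpace per channel group; the helper names `subspace_params` and `ulsam_params` are hypothetical.

```python
def subspace_params(nin):
    """Learnable parameters of one SubSpace with nin input channels."""
    conv_dws = nin + nin    # depth-wise 1x1: one weight and one bias per channel
    bn_dws = 2 * nin        # BatchNorm scale and shift per channel
    conv_point = nin + 1    # point-wise 1x1 with a single filter, plus its bias
    bn_point = 2            # BatchNorm scale and shift for the single channel
    return conv_dws + bn_dws + conv_point + bn_point


def ulsam_params(channels, g):
    """Total over g subspaces, assuming one SubSpace per group of channels // g."""
    if channels % g != 0:
        raise ValueError("channels must be divisible by g")
    return g * subspace_params(channels // g)


print(subspace_params(64))     # 5 * 64 + 3 = 323
print(ulsam_params(512, g=4))  # 4 * (5 * 128 + 3) = 2572
```

Under these assumptions the cost grows as 5C + 3g for C channels, which is linear in C, whereas a standard 1x1 convolution with C input and C output channels has on the order of C² weights.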

Grad-CAM++ heatmaps

  • After adding ULSAM to MobileNet v1 and v2, the model focuses better on the relevant regions
    ul2.jpeg

Thoughts

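  • Although FLOPs and params decrease or stay almost unchanged, ULSAM introduces many element-wise operations, so I expect it to run slower in practice
  • SENet gates its weights with sigmoid, whereas ULSAM applies softmax over the H and W dimensions. The softmax weights sum to 1 across all H×W positions, so most of them are far below 1 and multiplying by them alone would shrink the features; that is why a residual structure is needed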
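
The scale difference between a sigmoid gate and a spatial softmax gate is easy to see numerically. This is a toy sketch with random tensors standing in for the point-wise conv output and one subspace; the variable names and sizes are chosen only for illustration.

```python
import torch

torch.manual_seed(0)
h, w = 14, 14
logits = torch.randn(1, 1, h, w)  # stand-in for the point-wise conv output
x = torch.randn(1, 8, h, w)       # stand-in for one subspace F_n

# SENet-style gate: every value lies in (0, 1), independently of the others.
sigmoid_gate = torch.sigmoid(logits)
# ULSAM-style gate: values sum to 1 over the H*W positions.
softmax_gate = torch.softmax(logits.flatten(2), dim=2).view(1, 1, h, w)

print(sigmoid_gate.mean().item())  # average sigmoid gate value
print(softmax_gate.mean().item())  # equals 1 / (h * w) up to rounding

gated_only = softmax_gate * x         # features scaled down heavily
with_residual = softmax_gate * x + x  # original signal kept, attention added on top
print(gated_only.abs().mean().item(), with_residual.abs().mean().item())
```

Without the residual term the output magnitude collapses as the feature map gets larger, since the average softmax weight is 1 / (H·W).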

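Network performance

  • Controlled-variable (ablation) experiments verify how the number of subspaces g and the replacement position pos affect model performance

Comparison experiments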

ul3.png
ul4.png
ul5.png
ul6.png