1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
import functools

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter


class _routing(nn.Module):
    """Routing head that scores every expert in ``[0, 1]`` for each example.

    Args:
        in_channels (int): Number of features per example after flattening.
        num_experts (int): Number of routing weights produced per example.
        dropout_rate (float): Dropout probability applied before the linear
            layer.
    """

    def __init__(self, in_channels, num_experts, dropout_rate):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_experts)

    def forward(self, x):
        """Map ``(N, in_channels, ...)`` features to ``(N, num_experts)``."""
        # Flatten everything except the batch axis so that a whole batch can
        # be routed with a single call.
        x = torch.flatten(x, start_dim=1)
        x = self.dropout(x)
        x = self.fc(x)
        return torch.sigmoid(x)


class CondConv2D(_ConvNd):
    r"""Learn specialized convolutional kernels for each example.

    As described in the paper
    `CondConv: Conditionally Parameterized Convolutions for Efficient Inference`_ ,
    conditionally parameterized convolutions (CondConv) challenge the paradigm
    of static convolutional kernels by computing convolutional kernels as a
    function of the input: each example is convolved with its own kernel,
    a routing-weighted sum of ``num_experts`` learned expert kernels.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        num_experts (int, optional): Number of experts per layer. Default: 3
        dropout_rate (float, optional): Dropout probability applied to the
            pooled features before the routing layer. Default: 0.2

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable expert kernels of the module, of shape
            :math:`(\text{num\_experts}, \text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values of these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape
            (out_channels). If :attr:`bias` is ``True``, then the values of
            these weights are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    .. _CondConv\: Conditionally Parameterized Convolutions for Efficient Inference:
       https://arxiv.org/abs/1904.04971
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        num_experts=3,
        dropout_rate=0.2,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,
            _pair(0),
            groups,
            bias,
            padding_mode,
        )

        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)

        # Replace the single kernel allocated by the base class with one
        # kernel per expert, then initialise the new tensor.
        self.weight = Parameter(
            torch.empty(num_experts, out_channels, in_channels // groups, *kernel_size)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Sample kernels and bias from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

        ``fan_in`` is the per-expert value
        ``in_channels / groups * kernel_size[0] * kernel_size[1]``. It is read
        from the trailing three weight dimensions so that the leading expert
        axis (and ``out_channels``) never inflates it. This also keeps the
        method valid while the base-class constructor still holds its 4-D
        weight.
        """
        in_per_group, kernel_h, kernel_w = self.weight.shape[-3:]
        fan_in = in_per_group * kernel_h * kernel_w
        if fan_in == 0:
            return
        bound = fan_in ** -0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def _pad_input(self, x):
        """Return ``(x, padding)`` ready to be handed to ``F.conv2d``.

        For ``padding_mode='zeros'`` the convolution pads by itself. Every
        other mode is applied explicitly here and the convolution padding is
        set to zero.
        """
        if self.padding_mode == "zeros":
            return x, self.padding
        pad_h, pad_w = self.padding
        # F.pad expects the last dimension first: (left, right, top, bottom).
        padded = F.pad(x, (pad_w, pad_w, pad_h, pad_h), mode=self.padding_mode)
        return padded, _pair(0)

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with one 4-D ``weight`` using the layer settings."""
        padded, padding = self._pad_input(input)
        return F.conv2d(
            padded,
            weight,
            self.bias,
            self.stride,
            padding,
            self.dilation,
            self.groups,
        )

    def forward(self, inputs):
        """Convolve every example with its own routing-weighted kernel.

        Args:
            inputs (Tensor): Batch of shape ``(N, C_in, H_in, W_in)``.

        Returns:
            Tensor: Batch of shape ``(N, C_out, H_out, W_out)``.
        """
        batch_size, in_channels, height, width = inputs.shape
        if batch_size == 0:
            # Nothing to route: any expert kernel yields the correctly shaped
            # empty output.
            return self._conv_forward(inputs, self.weight[0])

        pooled = self._avg_pooling(inputs)
        routing_weights = self._routing_fn(pooled)  # (N, num_experts)

        # Mix the experts for the whole batch with a single matmul:
        # (N, E) @ (E, C_out * C_in/groups * kH * kW).
        num_experts = self.weight.shape[0]
        kernels = torch.matmul(
            routing_weights, self.weight.reshape(num_experts, -1)
        )
        kernels = kernels.reshape(
            batch_size * self.out_channels, *self.weight.shape[2:]
        )
        bias = None if self.bias is None else self.bias.repeat(batch_size)

        # Fold the batch into the channel axis and multiply the group count by
        # N, so that one grouped convolution applies each example's kernel to
        # that example only.
        stacked, padding = self._pad_input(
            inputs.reshape(1, batch_size * in_channels, height, width)
        )
        out = F.conv2d(
            stacked,
            kernels,
            bias,
            self.stride,
            padding,
            self.dilation,
            self.groups * batch_size,
        )
        return out.reshape(batch_size, self.out_channels, *out.shape[-2:])