1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
| import functools
import torch from torch import nn import torch.nn.functional as F from torch.nn.modules.conv import _ConvNd from torch.nn.modules.utils import _pair from torch.nn.parameter import Parameter
class _routing(nn.Module):
    """Routing function used by :class:`CondConv2D` to weight its experts.

    Computes ``sigmoid(fc(dropout(flatten(x))))``, i.e. one weight in the open
    interval ``(0, 1)`` per expert.

    Args:
        in_channels (int): number of features in ``x`` after it is fully
            flattened; must equal ``x.numel()``.
        num_experts (int): number of routing weights produced.
        dropout_rate (float): dropout probability applied before the linear
            layer.
    """

    def __init__(self, in_channels, num_experts, dropout_rate):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_experts)

    def forward(self, x):
        """Return the routing weights for ``x``.

        ``torch.flatten`` is called without ``start_dim`` and therefore
        collapses *every* dimension, including any batch dimension. ``x`` must
        therefore describe a single example (for instance a pooled tensor of
        shape ``(1, in_channels, 1, 1)``); the result has shape
        ``(num_experts,)``.
        """
        x = torch.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        # torch.sigmoid is the canonical API; F.sigmoid is a legacy alias.
        return torch.sigmoid(x)
class CondConv2D(_ConvNd):
    r"""Learn specialized convolutional kernels for each example.

    As described in the paper
    `CondConv: Conditionally Parameterized Convolutions for Efficient Inference`_,
    conditionally parameterized convolutions (CondConv) challenge the paradigm
    of static convolutional kernels by computing convolutional kernels as a
    function of the input.

    For every example in the batch, the input is globally average-pooled to
    ``1x1`` and passed through the routing function, which yields one weight
    per expert. The expert kernels are summed with those weights, and the
    resulting kernel is used for an ordinary 2-D convolution of that example.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to both sides of the
            input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        num_experts (int, optional): Number of experts per layer. Default: 3
        dropout_rate (float, optional): Dropout probability used inside the
            routing function. Default: 0.2

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the learnable expert kernels, of shape
            :math:`(\text{num\_experts}, \text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
            :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
            The values are sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels).
            If :attr:`bias` is ``True``, the values are sampled from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`

    .. _`CondConv: Conditionally Parameterized Convolutions for Efficient Inference`: https://arxiv.org/abs/1904.04971
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
        padding_mode="zeros",
        num_experts=3,
        dropout_rate=0.2,
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(CondConv2D, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            False,  # transposed
            _pair(0),  # output_padding
            groups,
            bias,
            padding_mode,
        )
        self.num_experts = num_experts

        self._avg_pooling = functools.partial(
            F.adaptive_avg_pool2d, output_size=(1, 1)
        )
        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)

        # One kernel per expert; this replaces the single 4-D weight that
        # _ConvNd.__init__ registered.
        self.weight = Parameter(
            torch.empty(num_experts, out_channels, in_channels // groups, *kernel_size)
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the expert kernels and bias like a plain ``Conv2d``.

        The inherited ``_ConvNd`` initializer derives fan-in from the weight's
        shape. For the 5-D expert tensor that wrongly multiplies in
        ``out_channels``, which shrinks the init range. Here the fan-in is
        computed for a single kernel, ``in_channels / groups * kH * kW``. Both
        weight and bias are drawn from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``,
        the same bound ``kaiming_uniform_(a=sqrt(5))`` gives an ordinary
        convolution weight.

        This is also invoked by ``_ConvNd.__init__`` on the temporary 4-D
        weight. It only reads attributes ``_ConvNd`` has already set at that
        point.
        """
        fan_in = (
            (self.in_channels // self.groups)
            * self.kernel_size[0]
            * self.kernel_size[1]
        )
        bound = fan_in ** -0.5 if fan_in > 0 else 0.0
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with ``weight`` using this layer's settings.

        For padding modes other than ``'zeros'``, the input is padded
        explicitly with ``F.pad`` and the convolution itself uses zero
        padding. ``F.pad`` lists padding from the last dimension backwards, so
        the tuple is ``(left, right, top, bottom)``.
        """
        if self.padding_mode != "zeros":
            pad_h, pad_w = self.padding
            padded = F.pad(
                input, (pad_w, pad_w, pad_h, pad_h), mode=self.padding_mode
            )
            return F.conv2d(
                padded,
                weight,
                self.bias,
                self.stride,
                _pair(0),
                self.dilation,
                self.groups,
            )
        return F.conv2d(
            input,
            weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def forward(self, inputs):
        """Convolve each example with its own routed mixture of expert kernels.

        Args:
            inputs (Tensor): batch of shape ``(N, C_in, H_in, W_in)``.

        Returns:
            Tensor of shape ``(N, C_out, H_out, W_out)``.

        Raises:
            ValueError: if ``inputs`` is not 4-D.
        """
        if inputs.dim() != 4:
            raise ValueError(
                "CondConv2D expects a 4-D input (N, C, H, W), got shape %s"
                % (tuple(inputs.shape),)
            )
        if inputs.size(0) == 0:
            # torch.cat([]) would raise. An empty batch still has a
            # well-defined empty output, so convolve it with any valid kernel.
            return self._conv_forward(inputs, self.weight.sum(0))

        outputs = []
        # Each example is routed separately because its kernel depends on it.
        for sample in inputs.split(1, dim=0):
            routing_weights = self._routing_fn(self._avg_pooling(sample))
            # Broadcast the 1-D per-expert weights over the expert kernels
            # and sum out the expert dimension.
            kernel = torch.sum(
                routing_weights[:, None, None, None, None] * self.weight, 0
            )
            outputs.append(self._conv_forward(sample, kernel))
        return torch.cat(outputs, dim=0)
|