```python
from functools import partial

from torch import nn
from einops.layers.torch import Rearrange, Reduce


class PreNormResidual(nn.Module):
    """Applies LayerNorm before the wrapped module and adds a residual connection."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x


def FeedForward(dim, expansion_factor=4, dropout=0.0, dense=nn.Linear):
    # Two-layer MLP with GELU and dropout. The `dense` argument lets the caller
    # swap nn.Linear for a 1x1 Conv1d, so the same block can mix either tokens
    # (patches) or channels.
    return nn.Sequential(
        dense(dim, dim * expansion_factor),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim * expansion_factor, dim),
        nn.Dropout(dropout),
    )


def MLPMixer(
    *,
    image_size,
    channels,
    patch_size,
    dim,
    depth,
    num_classes,
    expansion_factor=4,
    dropout=0.0,
):
    assert (image_size % patch_size) == 0, "image size must be divisible by patch size"
    num_patches = (image_size // patch_size) ** 2

    # Token mixing uses a 1x1 Conv1d (acts across the patch axis);
    # channel mixing uses a plain Linear (acts across the feature axis).
    chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear

    return nn.Sequential(
        # Split the image into non-overlapping patches and flatten each patch.
        Rearrange("b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch_size, p2=patch_size),
        # Per-patch linear embedding into `dim` features.
        nn.Linear((patch_size ** 2) * channels, dim),
        *[
            nn.Sequential(
                # Token-mixing MLP: communicates across patches.
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                # Channel-mixing MLP: communicates across features.
                PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last)),
            )
            for _ in range(depth)
        ],
        nn.LayerNorm(dim),
        # Global average pooling over patches, then a linear classifier head.
        Reduce("b n c -> b c", "mean"),
        nn.Linear(dim, num_classes),
    )
```
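As a quick sanity check, the factory can be instantiated and run on a dummy batch. The sketch below uses illustrative hyperparameters (256-pixel images, 16-pixel patches, 12 blocks, 1000 classes) that are assumptions, not values prescribed by the code above; any `image_size` divisible by `patch_size` works.

```python
import torch

# Hypothetical configuration for demonstration purposes only.
model = MLPMixer(
    image_size=256,
    channels=3,
    patch_size=16,
    dim=512,
    depth=12,
    num_classes=1000,
)

imgs = torch.randn(4, 3, 256, 256)   # (batch, channels, height, width)
logits = model(imgs)                 # -> torch.Size([4, 1000])
print(logits.shape)
```

With these settings each image is cut into (256 / 16)² = 256 patches of 16·16·3 = 768 values, which the first `nn.Linear` projects to 512 features before the mixer blocks alternate token mixing and channel mixing.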