import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from typing import Tuple, Optional
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=4, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            # split the image into non-overlapping patches and embed each patch
            nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        x = self.projection(x)
        return x
class MLP(nn.Sequential):
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.):
        super(MLP, self).__init__(
            nn.Linear(dim, int(dim * mlp_ratio)),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(int(dim * mlp_ratio), dim),
            nn.Dropout(dropout)
        )
def windows_partition(x, window_size):
    # x: [b, h, w, c] -> [(b * num_windows), window_size * window_size, c]
    x = rearrange(
        x,
        "b (win_size1 h2) (win_size2 w2) c -> (b h2 w2) (win_size1 win_size2) c",
        win_size1=window_size,
        win_size2=window_size
    )
    return x
def windows_reverse(windows, window_size, h, w):
    h2 = h // window_size
    w2 = w // window_size
    # logically we should reverse to the shape "b (win_size1 h2) (win_size2 w2) c",
    # but since the result is added to the skip connection right away,
    # we reshape directly to "b (win_size1 h2 win_size2 w2) c"
    x = rearrange(
        windows,
        "(b h2 w2) (win_size1 win_size2) c -> b (win_size1 h2 win_size2 w2) c",
        win_size1=window_size,
        win_size2=window_size,
        h2=h2,
        w2=w2
    )
    return x
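As a quick sanity check of the two helpers, partitioning an 8x8 feature map with a window size of 4 for a batch of 2 should yield eight windows of 16 tokens each, and reversing them should give back the flattened map; the sizes here are illustrative only.

x = torch.randn(2, 8, 8, 96)                   # [b, h, w, c]
windows = windows_partition(x, window_size=4)  # [(2 * 2 * 2), 4 * 4, 96] = [8, 16, 96]
merged = windows_reverse(windows, window_size=4, h=8, w=8)
print(windows.shape, merged.shape)             # [8, 16, 96], [2, 64, 96]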
class WindowAttention(nn.Module):
    def __init__(self, dim, num_heads, window_size):
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size
        self.head_dim = dim // num_heads
        self.num_heads = num_heads
        self.scale = self.head_dim ** -0.5
        self.softmax = nn.Softmax(-1)
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        """ <--- Create Relative Position Index """
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords_flatten = torch.stack(torch.meshgrid([coords_h, coords_w])).flatten(1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size - 1
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1

        # record the index of the bias table entry for each pair of positions
        relative_position_index = relative_coords.sum(-1)

        # the index itself is not learnable, so store it as a buffer
        self.register_buffer("relative_position_index", relative_position_index)
        """ Create Relative Position Index --->"""

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table)

    def forward(self, x, mask=None):
        # x: [(b * num_windows), ws*ws, embed_dim]
        # mask: [(b * num_windows), 1, ws*ws, ws*ws], 1 = keep, 0 = discard
        x = self.qkv(x)
        qkv = rearrange(x, "b n (h d qkv) -> (qkv) b h n d", h=self.num_heads, qkv=3)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = torch.einsum("bhqd, bhkd -> bhqk", q, k)  # attn = Q * K^T

        relative_position_bias = self.relative_position_bias_table.index_select(
            0,
            self.relative_position_index.reshape((-1,))
        ).reshape((self.window_size**2, self.window_size**2, -1))

        # move the number of heads back to the first dimension,
        # unsqueeze in order to broadcast over batches
        relative_position_bias = relative_position_bias.permute((2, 0, 1)).unsqueeze(0)

        attn = attn + relative_position_bias

        if mask is not None:
            # turn the keep/discard mask into a large negative bias on the discarded positions
            discard_mask = 1 - mask
            discard_mask = discard_mask * -1e10
            attn = attn + discard_mask

        attn = self.softmax(attn)
        out = torch.einsum("bhai, bhid -> bhad", attn, v)  # attn * V
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.proj(out)
        return out
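Window attention operates on a batch of flattened windows, so feeding it eight windows of 16 tokens with an embedding size of 96 and 4 heads should return the same shape; the numbers below are illustrative only.

win_attn = WindowAttention(dim=96, num_heads=4, window_size=4)
windows = torch.randn(8, 4 * 4, 96)   # [(b * num_windows), ws*ws, dim]
out = win_attn(windows)
print(out.shape)                      # [8, 16, 96]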
def generate_mask(window_size=4, shift_size=2, input_resolution=(8, 8)):
    H, W = input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # keep the last dimension because we want to apply windows_partition
    h_slices = [slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None)]
    w_slices = [slice(0, -window_size),
                slice(-window_size, -shift_size),
                slice(-shift_size, None)]

    count = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = count
            count += 1

    windows_mask = windows_partition(img_mask, window_size)
    # windows_mask: [(b h2 w2), (win_size1 win_size2), 1]
    windows_mask = windows_mask.reshape((-1, window_size * window_size))
    # [n, 1, ws*ws] - [n, ws*ws, 1]
    attn_mask = windows_mask.unsqueeze(1) - windows_mask.unsqueeze(2)
    # 1 where two tokens belong to the same region (attend), 0 otherwise (discard)
    attn_mask = torch.where(attn_mask == 0, 1., 0.)
    return attn_mask
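With the default arguments, an 8x8 resolution and a window size of 4 give (8/4) * (8/4) = 4 windows, so the mask contains one ws*ws x ws*ws matrix per window:

attn_mask = generate_mask(window_size=4, shift_size=2, input_resolution=(8, 8))
print(attn_mask.shape)   # [4, 16, 16]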
class SwinBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size, shift_size=0):
        super(SwinBlock, self).__init__()
        self.dim = dim
        self.resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size

        self.attn_norm = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads, window_size)

        self.mlp_norm = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

        if self.shift_size > 0:
            attn_mask = generate_mask(window_size=self.window_size,
                                      shift_size=self.shift_size,
                                      input_resolution=self.resolution)
        else:
            attn_mask = None
        self.register_buffer('attn_mask', attn_mask)

    def forward(self, x):
        # x: [b, n, d]
        H, W = self.resolution
        B, N, C = x.shape
        h = x
        x = self.attn_norm(x)
        x = rearrange(x, "b (h w) c -> b h w c", h=H, w=W)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = windows_partition(shifted_x, self.window_size)

        if self.attn_mask is not None:
            # expand the per-window mask over the batch; keep it in a local variable
            # so the registered buffer is not overwritten between forward passes
            attn_mask = repeat(self.attn_mask[None, ...], "() num_patches h w -> b num_patches h w", b=B)
            # merge the batch and window dimensions, matching the grouping used by windows_partition
            attn_mask = rearrange(attn_mask, "b num_patches h w -> (b num_patches) () h w")
        else:
            attn_mask = None

        attn_windows = self.attn(x_windows, mask=attn_mask)
        attn_windows = windows_reverse(attn_windows, window_size=self.window_size, h=H, w=W)

        # reverse cyclic shift on the attention output
        if self.shift_size > 0:
            attn_windows = rearrange(attn_windows, "b (h w) c -> b h w c", h=H, w=W)
            attn_windows = torch.roll(attn_windows, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            attn_windows = rearrange(attn_windows, "b h w c -> b (h w) c")

        x = h + attn_windows

        h = x
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = h + x
        return x
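SwinStage and the test code at the end reference a PatchMerging layer that is not defined in this section. The sketch below is a minimal stand-in that concatenates each 2x2 neighbourhood of tokens and projects 4*dim down to 2*dim; the norm-then-linear ordering and the bias-free reduction are assumptions, not necessarily the exact layer from the original implementation.

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)  # 4*dim -> 2*dim (assumed bias-free)

    def forward(self, x):
        # x: [b, h*w, c] -> gather the features of every 2x2 patch neighbourhood
        h, w = self.input_resolution
        x = rearrange(x, "b (h s1 w s2) c -> b (h w) (s1 s2 c)", s1=2, s2=2, h=h // 2, w=w // 2)
        x = self.norm(x)
        x = self.reduction(x)
        return x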
class SwinStage(nn.Module):
    def __init__(self,
                 dim: int,
                 input_resolution: Tuple[int, int],
                 depth: int,
                 num_heads: int,
                 window_size: int,
                 patch_merging: Optional[PatchMerging] = None):
        super(SwinStage, self).__init__()
        self.blocks = nn.ModuleList()

        for i in range(depth):
            self.blocks.append(
                SwinBlock(dim=dim,
                          input_resolution=input_resolution,
                          num_heads=num_heads,
                          window_size=window_size,
                          shift_size=0 if i % 2 == 0 else window_size // 2)
            )
        if patch_merging is None:
            self.patch_merging = nn.Identity()
        else:
            self.patch_merging = patch_merging(input_resolution, dim)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        x = self.patch_merging(x)
        return x
class SwinTransformer(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=4,
                 embed_dim=96,
                 window_size=7,
                 num_heads=[3, 6, 12, 24],
                 depths=[2, 2, 6, 2],
                 num_classes=1000,
                 output_hidden_states=False
                 ):
        super(SwinTransformer, self).__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.num_stages = len(depths)
        self.num_features = int(self.embed_dim * (2 ** (self.num_stages - 1)))
        self.patch_resolution = [image_size // patch_size, image_size // patch_size]

        self.patch_embedding = PatchEmbedding(patch_size=patch_size, embed_dim=embed_dim)
        self.stages = nn.ModuleList()
        self.output_hidden_states = output_hidden_states

        for idx, (depth, n_heads) in enumerate(zip(self.depths, self.num_heads)):
            h, w = self.patch_resolution
            stage = SwinStage(dim=int(self.embed_dim * (2 ** idx)),
                              input_resolution=(h // (2 ** idx), w // (2 ** idx)),
                              depth=depth,
                              num_heads=n_heads,
                              window_size=window_size,
                              patch_merging=PatchMerging if (idx < self.num_stages - 1) else None)
            self.stages.append(stage)
        self.window_size = window_size
        self.norm = nn.LayerNorm(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)  # the last dimension will be shrunk to 1
        self.fc = nn.Linear(self.num_features, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)

        for stage in self.stages:
            x = stage(x)

        x = self.norm(x)

        if self.output_hidden_states:
            # with the default config the final feature map is window_size x window_size (7 x 7)
            x = rearrange(
                x,
                "b (win_size_h win_size_w) embed_dim -> b embed_dim win_size_h win_size_w",
                win_size_h=self.window_size,
                win_size_w=self.window_size
            )
            return x

        x = rearrange(x, "b num_tokens embed_dim -> b embed_dim num_tokens")
        x = self.avgpool(x)
        x = rearrange(x, "b embed_dim c -> b (embed_dim c)")  # c = 1 due to avgpool
        x = self.fc(x)
        return x
t = torch.randn([4, 3, 224, 224])
patch_embedding = PatchEmbedding(patch_size=4, embed_dim=96)
swin_block = SwinBlock(dim=96, input_resolution=[56, 56], num_heads=4, window_size=7)
shifted_swin_block = SwinBlock(dim=96, input_resolution=[56, 56], num_heads=4, window_size=7, shift_size=7 // 2)
patch_merging = PatchMerging(input_resolution=[56, 56], dim=96)

out = patch_embedding(t)        # result: [4, 56*56, 96], here (224/4) * (224/4) = 56*56
out = swin_block(out)           # result: [4, 56*56, 96]
out = shifted_swin_block(out)   # result: [4, 56*56, 96]
out = patch_merging(out)        # result: [4, 784, 192], here 56*56 / 4 = 784
                                # 784 = 28*28 is the new number of patch tokens
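As a final end-to-end check, with the default configuration the model should map a batch of 224x224 images to 1000-way logits, while output_hidden_states=True returns the final 7x7 feature map instead:

model = SwinTransformer()
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)     # [2, 1000]

feature_extractor = SwinTransformer(output_hidden_states=True)
features = feature_extractor(torch.randn(2, 3, 224, 224))
print(features.shape)   # [2, 768, 7, 7]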