Swin Transformer代码详解

源代码地址:[https://github.com/microsoft/Swin-Transformer]

Swin Transformer 代码结构

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
  config.py
  get_started.md
  logger.py
  lr_scheduler.py
  main.py
  optimizer.py
  utils.py

├─configs
      swin_base_patch4_window12_384.yaml
      swin_base_patch4_window7_224.yaml
      swin_large_patch4_window12_384.yaml
      swin_large_patch4_window7_224.yaml
      swin_small_patch4_window7_224.yaml
      swin_tiny_patch4_window7_224.yaml

├─data
      build.py
      cached_image_folder.py
      samplers.py
      zipreader.py
      __init__.py

├─figures
      teaser.png

└─models
        build.py
        swin_transformer.py
        __init__.py

Swin Transformer执行文件是main.py,它包含了几个基本的函数:

zoom

程序入口是main,调用其它函数。Main中最关键的是这两个对象的建立

zoom

build_model

build_model位于/models/build.py文件内,build_optimizer位于optimizer.py文件内。首先先分析build_model,它的内容只是继承了一下SwinTransformer类。

zoom

SwinTransformer类位于Swin_transformer.py中,它包含了模型建立阶段的所有操作,包括将图像拆分为Patchs并内嵌embedding、建立Swin Transformer block模块等等。Swin_transformer.py文件包含这些类和函数:

zoom

zoom

SwinTransformer

SwinTransformer类调用了其它几个类和函数,实现了算法建模的主要功能。

zoom

首先先看图像的拆分和embedding:

zoom

它实现了图像嵌入PatchEmbed类的实例,这个对象定义了图片大小,patch大小,嵌入维度标准化层等等参数。而PatchEmbed类继承了pytorch的nn.Module,里面只有两个关键方法,forward和flops。

zoom

forward是重新实现的nn.Module方法,它描述了模块内对张量(BCHW型的)进行的操作。

重点是这一行

1
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C

其中self.proj是在__init__中定义的

1
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

就是用nn.Conv2d实现的逐个patch乘以相应大小的卷积核,并且步长Stride也是batch尺寸。这样就实现了每个卷积核只对每个patch线性映射一次的效果。输出通道数即为卷积核的个数。

zoom

卷积操作(线性映射)以后,还需要对其进行压平处理。x.flatten(2).transpose(1, 2)的效果如图

zoom

zoom

意思是将原本(B,C,H,W)维度的tensor转化为(B,H*W,C),即将2维的特征压平成1维的。

Flops是算浮点运算次数的方法,可以先不看。

然后一起学习一下SwinTransformer类实现的其它功能。

随机深度stochastic depth。这部分的代码是

1
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule

借鉴的文章是Deep Networks with Stochastic Depth(https://arxiv.org/pdf/1603.09382)))。原理后面再说。

下面是建立层

zoom

代码中建了一个nn.ModuleList,它可以像python的list一样被索引,所以也可以append。这样就可以做一个循环,然后把所有层都放到self.layer中。根据层数不同,每次循环改变维数。这部分功能主要用到了一个BasicLayer类,它和SwinTransformer类在同一文件中,同样继承的nn.Module。

zoom

这个BasicLayer类里面,建立了self.block,是用SwinTransformerBlock建的。

zoom

SwinTransformerBlock

SwinTransformerBlock是文章内容的核心之一,它的主要功能是进行窗口分割、改变Patch的尺寸、窗口shift后进行填充、计算多头注意力。

zoom

先是self.norm1标准化层,然后是self.attn计算注意力,然后再是标准化层self.norm2,然后是多层感知器self.mlp。后面定义了切割窗口时如果需要填充,标记了注意力掩码attn_mask,方便后面计算注意力的时候只计算自己相关的窗口,而不计算位移填充到一起的不相关的窗口。

重新构建的forward函数,包括循环填充、划分窗口、W-MSA/SW-MSA、去掉多余窗口以及前向短路FFN。

zoom

在__init__中定义了self.attn计算注意力,构建了一个WindowAttention的对象。这部分是ViT方法的核心,计算了窗口的注意力。

WindowAttention

zoom

计算多头注意力的过程是先建立了相对位置的偏差表

1
self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH

然后计算了窗口内每个tokens的成对相对位置索引relative_position_index。# 窗口WhWw, WhWw

zoom

重新构建的forward函数进行了注意力计算。首先计算query、key、value三个参数。

1
2
3
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]

然后计算注意力

$$\begin{equation} \operatorname{Attention}(Q, K, V)=\operatorname{SoftMax}\left(Q K^{T} / \sqrt{d}+B\right) V \end{equation}$$

1
2
3
4
attn = (q @ k.transpose(-2, -1))
attn = attn + relative_position_bias.unsqueeze(0)
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)

See Also