Swin Transformer原理梳理

源代码地址:https://github.com/microsoft/Swin-Transformer

Transformer 是一种和CNN不同思路的机器视觉方法,开始被应用于自然语言处理,最近几年被引入视觉方向,发展十分迅速,但一直没有超过CNN,直到Swin Transformer的出现,以它作为backbone的网络霸榜了视觉的多个领域。我们来一起学习一下它的结构和原理。

单头注意力(Single-headed attention)

首先要看的是单头注意力模型,它是transformer模型的基础。也是和CNN卷积运算最大的不同。

$$y_{i j}=\sum_{a, b \in \mathcal{N}_{k}(i, j)} \operatorname{softmax}_{a b}\left(q_{i j}^{\top} k_{a b}\right) v_{a b}$$
$ \text { queries } q_{i j}=W_{Q} x_{i j}, \text { keys } k_{a b}=W_{K} x_{a b}, \text { and values } v_{a b}=W_{V} x_{a b} $
是位置ij和其相邻像素的线性变换。如图1所示

1

图1

多个注意头(multi-headed attention)用于学习输入的多个不同表示。将$x_{ij}$沿着深度拓展到N组,使用不同的变换分别计算每个组的单头注意力参数,再将输出表示连接到最终输出中。

早期的实验表明,使用相对位置嵌入可以显着提高精度。这样每个元素有两个距离,行偏移和列偏移。如图2所示

2

图2

这种相对空间的注意力定义为

$$\begin{equation} y_{i j}=\sum_{a, b \in \mathcal{N}_{k}(i, j)} \operatorname{softmax}_{a b}\left(q_{i j}^{\top} k_{a b}+q_{i j}^{\top} r_{a-i, b-j}\right) v_{a b} \end{equation} $$

$r_{a-i, b-j}$是行列偏移量相加。

Vision Transformer (VIT)

然后看一下VIT,它是Swin Transformer借鉴最多的一篇文章。 ViT原理中是将$H×W×C$维的图像,分解成一系列$N×(P^2·C)$维的图像。HW是原图像的尺寸,C代表图像的通道数。$(P,P)$ 是分解后小图像的分辨率;$N = HW/P^2$是 分解的图像个数,即Patches个数,是Transformer的有效输入序列长度。 然后将每个Patch压平(记作tokens),这部分功能就叫patch embeddings。如图3

3

图3

每一个Patch经过嵌入(embeddings)后,在经过batch标准化LayerNormal,然后经过上述的多头注意力算法,经过类似ResNet的前项短路结构,再经过归一化,再经过多层感知机MLP,这部分称作转换编码Transformer Encoder。

Vision Transformer通常只适合于大数据集。它采用的是固定的Patch分解方法。

Swin Transformer

Swin Transformer首先通过ViT类型的方法将输入RGB图像分割成不重叠的patches。每个patch被视为一个“标记”,即tokens,它的特征被设置为原始像素RGB值的串联。文中使用4×4的patch大小,因此每个patch的特征维数为4×4 × 3(通道数) = 48。 注意力模型修改为Swin Transformer。即,将标准多头自注意(MSA)模块替换为基于移位窗口的模块SW-MSA,其它保持不变,如图4

4

图4

Swin Transformer Block是Swin Transformer的基本单元,它工作时将两个Transformer串联,一个是W-MSA(即常规尺寸的多头注意力模块),一个是SW-MSA(移动窗口后的多头注意力模块)。在原始特征上应用上面提到的线性嵌入层,使用4×4的patch大小,因此最初分解的序列patch数为(H/4×W/4),在自然语言处理中一般叫“标记”tokens。根据前面TiV原理,压平变为1D处理后其投射维度记为C。用Swin Transformer block处理这些tokens, 为了产生分层表示,随着网络加深,将部分patchs合并,来减少tokens的数量。第一个合并层 先将每组相邻的2 × 2特征相连。这样就减少了4倍的tokens的数量,相当于2倍下采样。而且此时特征维度变为4C,再应用一个线性层,将特征维度输出变为2C。再用Swin Transformer block进行特征变换,整图分辨率变为(H/8×W/8),同样再经过Swin Transformer 变为(H/16×W/16). 这样就产生了和传统CNN相同的分层表示方法。如图5所示,这和b图中ViT的方法是不同的。