Vision Transformer

发布时间 2023-07-27 11:17:39作者: dctwan

Vision Transformer

本文关注ViT论文4.5 Inspecting Vision Transformer可视化的原理及实现,此外还对ViT pytorch源码实现进行理解

Introduction

论文地址

Title

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

一张图像就相当于一些16x16的单词:将transformer应用到大规模图像识别任务上

transformer是2017年提出基于自注意力机制解决翻译任务的方法,一经推出就在nlp(neutral language process)领域爆火,image recognition是cv(computer vision)领域的基本任务之一。本文将nlp领域的transformer方法应用到cv领域的图像识别任务上,故将其称为vision transformer,简称ViT

Usage

ViT自2020年提出之后,迅速在cv领域爆火,论文引用量已超18k(2023.7) 图片来源

image-20230726105723431

ViT甚至在2022年超越了经典的ResNet架构的CNN模型,成为图像分类任务中的主流方法 图片来源

image-20230726105858527

目前ViT主要应用在图像分类、语义分割、目标检测这三个主流的cv任务上 图片来源

image-20230726110410857

目前ImageNet图像分类数据集上表现最好的3个方法中,都是基于transformer的结构,而ViT是将transformer应用到cv领域的开山之作,所以有必要认真掌握,就像掌握AlexNet和ResNet那样 图片来源

image-20230726110527107

Method

Overview

ViT 模型动态图 展示了ViT模型的整体结构和forward过程,学习过transformer之后,再学习这篇文章,相对就会感觉比较轻松。

ViT的主要思想就是将transformer模型中Encoder部分直接搬过来,想办法将其应用到图片上。就ViT注意力机制部分来说,比transformer还要简单,只有Encoder,没有Decoder,以及Encoder-Decoder Cross Attention。

transformer用来处理seq2seq的任务,也就是输入和输出都是序列,Encoder的输入和输出也都是序列。那么要将其应用到图像分类任务上,关键是解决2个问题

  1. 如何将图片转换成Encoder能够接受的输入,也就是将图片转换为序列(tokens)
  2. 如何利用Encoder的输出来做分类任务,最终需要输出的图像的类别分布

基于上述关键问题,我将模型拆解为3个部分进行理解,其中Encoder部分在transformer中已经学习过,所有重点是InputOutput部分,随着学习的深入会发现,Output部分非常简单,就是接上一个MLP的分类头即可,所以重中之重是理解Input部分。

  • Input

    将原始图像转换为Encoder能够接受的输入

  • Encoder

    transformer原封不动搬过来

  • Output

    将经过Encoder处理的输出用来做分类

image-20230726112340098

论文中4个公式分别与3个部分的对应关系如下

image-20230726153225184

Input

基本思想是将原始图片拆分成多个小图(对应了论文的标题),再把这些小图变换成Encoder的输入

这部分由可拆解成4部分来理解

  1. 将原始图像拆分成固定大小的小图片,论文中称之为patch,并展成向量
  2. 对patches向量做线性映射,映射到特定的长度,也就是Encoder输入token的维度
  3. 拼接上class token,此方法参考了BERT
  4. 对每个token(包括class token)加上可学习的位置编码

Encoder

参见transformer学习:

  1. https://jalammar.github.io/illustrated-transformer/
  2. transformer论文《Attention is all you need》,arXiv:1706.03762
  3. 台大李宏毅视频课程:https://www.bilibili.com/video/BV1Wv411h7kN

Output

使用class token经过Encoder后的输出向量,接上一个MLP分类头,映射到类别数量即可

Experiments

ViT论文根据提出了大小不同的3个模型

image-20230726153307330

实验结果如下图

image-20230726153414612

通过比较可以得出结论

  1. 更大的ViT模型有更好的表现
  2. ViT模型只有在大规模数据集上预训练后,准确度才能超过传统的ResNet架构模型
  3. ViT模型的训练开销要比传统模型小

Explainability

可视化部分共4个图,是老板提出要重点理解和实现的内容,分别是

  1. Linear Embedding过滤器学习的内容
  2. Position Embedding
  3. Attention Distance随网络深度增加的变化
  4. ViT模型关注图像中的哪些内容
image-20230726153750722 image-20230726154116212

Visualize Filters of Linear Embedding

这一部分是花最长时间搞懂的部分,论文中并没有提到如何实现,在探索的过程中走了很多弯路,原因有2个

  1. 论文中模型描述和实现代码的差异
  2. 自身知识的局限

先描述一下我在探索的过程中遇到的问题

论文中提到,要可视化Linear Embedding层中Filters学习到的东西,我当时就陷入了矛盾,可视化filters是可视化卷积层的filters,而我们现在是线性层啊,怎么可视化线性层呢?于是就想去找可视化线性层的相关资料,找来找去都是可视化卷积层的。后来又去找可视化卷积层的相关资料,发现入了一个大坑,有很多方法,什么可视化特征图、反卷积,而且一开始只想到用在python中使用matlab的库绘图,本身对这些工具的使用就不熟悉,导致很郁闷。

后来又去仔细看了一下代码,发现:Linear Embedding层是用卷积核大小和步长都为patch_size,输出通道数为768的卷积层实现的!此时,虽然还不知道如何实现,但最起码有了方向,去查如何可视化卷积层就行了

查相关资料的时候以结果为导向,直接看博客中有没有和论文中相似的可视化图,很快就查到这篇博客,看到其中贴的图和论文中的图特别相似,并且有代码实现,使用tensorboard实现的可视化,于是把代码拿过来改了一下,实现了和论文中类似的可视化结果

image-20230726120628951

修改后的代码可视化Linear Embedding层卷积层的权重,代码如下

import os
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import torchvision.utils as vutils
from modeling import VisionTransformer, CONFIGS
from sklearn.decomposition import PCA
from einops import rearrange

# 环境设置
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
log_dir = os.path.join(BASE_DIR, "results")
writer = SummaryWriter(log_dir=log_dir, filename_suffix="_kernel")
imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))

# 模型选择:vit Large patch_size=32
config = CONFIGS["ViT-L_32"]
# 创建模型 并加载预训练权重
model = VisionTransformer(config, num_classes=21843, zero_head=False, img_size=224, vis=True)
model.load_from(np.load("attention_data/ViT-L_32.npz"))
model.eval()

# 取卷积层权重 X:(1024, 3, 32, 32)
X = model.transformer.embeddings.patch_embeddings.weight.detach()
# 将后3个维度合并,使用主成分分析工具要求输入维度为2
X = rearrange(X, 'n c h w -> n (c h w)')
# 与ViT论文一致,做28个主成分; pca components: (28, 3072)
pca = PCA(n_components=28)
pca.fit(X)
# 再将主成分转换成tensorboard 能够展示的RGB的形式
# filters: (28, 3, 32, 32)
filters = torch.tensor(pca.components_.reshape((-1, 3, 32, 32)))
# 使用tensorboard绘制filters的RGB图像
grid = vutils.make_grid(filters, normalize=True, scale_each=True, nrow=7)
writer.add_image("RGB embedding filters\nfirst 28 principal components", grid, global_step=620)

Visualize Positon Embedding

对于196个position embedding,计算每一个position embedding和其他position embedding之间的余弦相似度,会得到一个矩阵,矩阵的第 i 行表示,第 i 个position embedding和其他position embedding的相关性,将每一行转换成(14,14)小图,然后绘制出来即可,总共有14x14个这样的小图,每一个小图表示该位置上的patch和其他patch的相关性

image-20230726162310205

实现代码如下

import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from modeling import VisionTransformer, CONFIGS
from einops import rearrange

# 环境设置
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
imagenet_labels = dict(enumerate(open('attention_data/ilsvrc2012_wordnet_lemmas.txt')))


# 定义计算两个向量余弦相似度函数
def cosine_similarity(x, y, norm=False):
    """ 计算两个向量x和y的余弦相似度 """
    assert len(x) == len(y), "len(x) != len(y)"

    xy = x.dot(y)
    x2y2 = np.linalg.norm(x, ord=2) * np.linalg.norm(x, ord=2)
    sim = xy / x2y2
    return sim


# 模型选择:vit base
config = CONFIGS["ViT-B_16"]
# 创建模型 并加载预训练模型权重
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
model.load_from(np.load("attention_data/ViT-B_16-224.npz"))
model.eval()

# 获取可学习的位置编码
pos = model.transformer.embeddings.position_embeddings  # pos: (1, 197, 768)
# 移除维度大小为1的维度,且忽略[cls] token,并转换为numpy类型
pos = torch.squeeze(pos)[1:, :].detach().numpy()  # pos: numpy (196, 768)

# 计算similarity matrix
result = np.zeros((196, 196))
for i in tqdm(range(196)):
    for j in range(196):
        result[i, j] = cosine_similarity(pos[i, :], pos[j, :])
# result: (196, 196) => (196, 14, 14)
result = rearrange(result, 'n (h w) -> n h w', h=14, w=14)

# 绘制图像
fig, axs = plt.subplots(nrows=14, ncols=14, figsize=(14, 14),
                        subplot_kw={'xticks': [], 'yticks': []})
i = 0
for ax in axs.flat:
    ax.imshow(result[i, :, :], cmap='viridis')
    i += 1
plt.tight_layout()
plt.show()

Visualize Attention Distance with Network Depth

还没有看对应的实现,只是理解了一下。attention distance就相当于是卷积神经网络中的卷积层的感受野

在网络的较低层,由于注意力机制的存在,模型既关注一些小的区域,也关注一些大的区域,随着网络深度的增加,模型越来越关注更高的抽象层次,也就是图片的语义部分,感受野就比较大(个人感觉我这段理解写的很垃圾)

image-20230726162521398

Visualize Attention Map

ViT在大规模数据集上与训练之后,在图像分类任务上取得了非常好的表现。想要深入了解模型,探索模型的可解释性,就要搞清模型根据图像的哪些部分做出的决策,或者说模型在这个任务上关注了图像的哪些部分。将模型关注的图像的部分可视化,也就是绘制特征图(Attention Map)

论文中明确提到了使用Attention Rollout方法绘制特征图,这是另外一篇论文提出的可视化transformer注意力流的其中一个方法

Samira Abnar and Willem Zuidema. Quantifying attention flow in transformers. In ACL, 2020

arXiv:2005.00928

在附录中作者也简要介绍了这个方法的思想,将ViT模型每一个注意力层的多个头的权重平均一下,然后把所有注意力层的多头平均权重累乘,得到的结果就能够表示整个模型的注意力

Briefly, we averaged attention weights of ViTL/16 across all heads and then recursively multiplied the weight matrices of all layers. This accounts for the mixing of attention across tokens through all layers.

为了更好的理解这个方法,首先要明确一个概念,注意力权重(Attention Weights)或称注意力矩阵 图片来源

image-20230727083244635

熟悉transformer的可以非常好的理解这个概念,也就是每个token和其他token计算的attention score所构成的矩阵

Attention Rollout方法的说明如下,简单理解就是将所有Attention Weight累乘,但是这里面要考虑多头和残差连接的处理

image-20230727083447505

具体实现如下图所示,这里的Attention Maxtrix是运行过attention rollout算法之后的结果,通过reshape,再上采样之后得到一个类似于掩码功能的图像mask,再将mask作用到原图像上可视化出来即可

image-20230727083652186 image-20230727091643092

至于HeatMap的生成方式,同样也是使用mask作用在原始图像上,只不过使用python相关包对mask进行以下处理即可

实现代码如下 代码参考

import os

import torch
import numpy as np
import cv2
from urllib.request import urlretrieve
from PIL import Image
from torchvision import transforms
from modeling import VisionTransformer, CONFIGS

# 环境设置
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
imagenet_labels = dict(enumerate(open('ilsvrc2012_wordnet_lemmas.txt')))

# Prepare Model
# 模型选择:vit base
config = CONFIGS["ViT-B_16"]
# 创建模型
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
# 加载预训练模型权重
model.load_from(np.load("ViT-B_16-224.npz"))
model.eval()

# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("attention_data/plane2.png")
x = transform(im)
print(x.size())

# inference
logits, att_mat = model(x.unsqueeze(0))  # att_mat (list:12) elem:(1, 12, 197, 197)
att_mat = torch.stack(att_mat)  # att_mat (12, 1, 12, 197, 197)
att_mat = att_mat.squeeze(1)  # att_mat (12, 12, 197, 197)

# Average the attention weights across all heads.
att_mat = torch.mean(att_mat, dim=1)  # att_mat (12, 12, 197, 197) => (12, 197, 197)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Recursively multiply the weight matrices	
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

# Attention from the output token to the input space.
v = joint_attentions[-1]  # v: (197, 197)
grid_size = int(np.sqrt(aug_att_mat.size(-1)))  # grid_size: 14
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()  # mask: (14, 14)
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]  # mask: (224, 224, 1)
# 将mask作用到原始图片上
result = (mask * im).astype("uint8")
# 以下为绘图操作
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
ax1.set_title('Original')
ax2.set_title('Attention Map')
ax1.imshow(im)
ax2.imshow(result)
plt.show()


# 以下为实现heatmap的代码 

# def show_mask_on_image(img, mask):
#     img = np.float32(img) / 255
#     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
#     heatmap = np.float32(heatmap) / 255
#     cam = heatmap + np.float32(img)
#     cam = cam / np.max(cam)
#     return np.uint8(255 * cam)
# 
# 
# np_img = np.array(im)[:, :, ::-1]
# mask = show_mask_on_image(np_img, mask)
# cv2.imshow("Input Image", np_img)
# cv2.imshow("Output Image", mask)
# cv2.waitKey(-1)

Code

代码参考

代码结构框架如下图所示

image-20230727093546375

Model Variants

以ViT Base-16模型作为例子理解

# Prepare Model
# 模型选择:vit base
config = CONFIGS["ViT-B_16"]
# 创建模型
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=224, vis=True)
# 加载预训练模型权重
model.load_from(np.load("attention_data/ViT-B_16-224.npz"))
model.eval()

配置参数如下

def get_b16_config():
    """Returns the ViT-B/16 configuration."""
    config = ml_collections.ConfigDict()
    config.patches = ml_collections.ConfigDict({'size': (16, 16)})
    config.hidden_size = 768
    config.transformer = ml_collections.ConfigDict()
    config.transformer.mlp_dim = 3072
    config.transformer.num_heads = 12
    config.transformer.num_layers = 12
    config.transformer.attention_dropout_rate = 0.0
    config.transformer.dropout_rate = 0.1
    config.classifier = 'token'
    config.representation_size = None
    return config

阅读类代码时,主要关注forward函数,根据需要查看__init__函数中的定义,搞清类别之间的调用关系,对照模型架构图进行理解,补充学习代码细节实现(pytorch)

VisionTransformer

VisionTransformer类,可直接用来实例化一个模型

  • 输入

    一批原始图像:(B,224,224,3)

  • 调用

    Transformer:将原始图片做embedding转成Encoder输入tokens,经过Encoder处理产生输出

    Linear:对Encoder的输出,取class token对应的输出,接上分类头,产生类别分布

  • 输出

    类别分布logits和12个注意力层的attention weights

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.head = Linear(config.hidden_size, num_classes)

    def forward(self, x, labels=None):
        # 输入x: (B, 224, 224, 3)
        x, attn_weights = self.transformer(x)
        # 输出x:(B, 197, 768) Transformer Encoder的输出,取[cls]对应的输出作为分类头的输入
        # attn_weights list:12 (1, 12, 197, 197)

        logits = self.head(x[:, 0])  # Linear: 768 => 1000

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights  # 最终返回1000分类的概率分布和attn_weights

Transformer

Transformer类,VisionTransformer类的主要构成部分,由Embedding层和Encoder组成

  • 输入

    一批原始图像:(B,224,224,3)

  • 调用

    Embeddings:将原始图像转成Encoder的输入tokens

    Encoder:输入tokens经过12个注意力层处理产生对应输出

  • 输出

    原始图像经过Encoder之后的输出,以及各注意力层的attention weights

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)  # 通过embedding层产生输入tokens
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights

Embeddings

Embeddings类,由Transformer类调用,将原始图片转换为Encoder的输入tokens,由patch embedding、class token和positional embedding 三部分构成

  • 输入

    一批原始图片:(B,224,224,3)

  • 处理过程

    1. 在原始图像上使用卷积核、步长均为patch size的二维卷积操作,将原始图像转换成196个patch embedding
    2. 在patch embeddings上拼接class token,形成197个输入token
    3. 对每个输入token加上可训练的位置编码position embedding
  • 输出

    包含class token和位置编码的197个输入token

其中需要重点理解的是:在原始图像上使用卷积核尺寸、步长均为patch size的二维卷积操作,将原始图像转换成196个patch embedding

这一卷积层相当于完成了3个动作:

  1. 将原始图像拆分成16x16x3的196个patch

    这一个动作是通过卷积核尺寸、步长都为patch size来实现的,随着过滤器(卷积核)在原始图像(224x224x3)上滑动,那么每次覆盖的就是一个16x16x3的小图嘛,整个过程做下来,不就相当于是分了196个patch

  2. 将patch展成向量16x16x3=768维向量

    仔细想一下卷积操作,把卷积核和原图对应位置像素值加权求和,不就相当于是将这个patch展成向量,然后把卷积核当作另一个向量,求向量内积嘛

  3. 将patch展成的向量做线性映射,映射到特定维度(根据模型参数而定,不一定必须要768,但ViTB16就是768)

    只用二维卷积,没有用非线性激活函数,相当于只做了线性映射,也就相当于是一个Linear层,那么这次卷积的输出通道数就相当于是映射到的输出向量的长度

感觉文字解释的不是很清楚,做了一个动图再表达一下

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches + 1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        B = x.shape[0]  # batch
        cls_tokens = self.cls_token.expand(B, -1, -1)

        if self.hybrid:
            x = self.hybrid_model(x)
        # 用f=16, s=16,out_channels=768的卷积做拆分patche和将patch做线性映射的工作
        x = self.patch_embeddings(x)  # (B, 768, 14, 14)
        x = x.flatten(2)  # 从第2维开始展平 => (B, 768, 196)
        x = x.transpose(-1, -2)  # 将后两个维度交换位置 => (B, 196, 768)
        x = torch.cat((cls_tokens, x), dim=1)  # 在第1维度上拼接cls token => (B, 197, 768)
        # 加上位置编码(可训练)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings

Encoder

Encoder类,由Transformer类调用,堆叠12个注意力层

  • 输入

    Embeddings层的输出,输入tokens:(B,197,768)

  • 处理过程

    调用Block,堆叠12个注意力层,创建Encoder,然后将输入依次经过12个注意力层产生输出

  • 输出

    tokens经过Encoder之后的输出,已经每一个注意力层的attention weight

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

Block

Block类,由Encoder类调用,定义一个注意力层

  • 输入

    tokens的中间状态,前一层的输出

  • 处理过程

    调用Attention

    如图

    image-20230727110021063
  • 输出

    中间状态经过本层后的结果,作为后一层的输入

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

Attention

好累...

这里的关键是理解多头注意力机制,学习Transformer时应该要理解。基本思想是:对QKV,分别使用一个大矩阵,一次性做12个头的线性变换处理,然后再从结果中拆分出12个头,这样做的原因是并行性更好且更容易编码

class Attention(nn.Module):
    """
        Multi-Head Attention
    """

    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]  # heads=12
        # head_size=768/12=64
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        # all_head_size=12*64=768
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear: 768 -> 768(768=12*64 相当于一次性进行12个头的映射)
        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        # Linear: 768 -> 768
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # slice左闭右开,`+`直接在后面拼接维度,实现效果如下
        # x:(B, 197, 768) -> new_x_shape: (B, 197, 12, 64)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        # torch.view()用于改变张量的形状,参数使用*new_x_shape
        # 参数前的`*`号是解包操作,将new_x_shape拆开成一个个的参数,相当于x.view(B, 197, 12, 64)
        x = x.view(*new_x_shape)
        # torch.permute()用于交换张量的维度,参数表示交换后维度的下标顺序,实现效果如下
        # x:(B, 197, 12, 64) -> return:(B, 12, 197, 64)
        return x.permute(0, 2, 1, 3)
        # transpose_for_scores函数的作用是,将合并的QKV,按照头的维度,拆分成12个头各自的qkv,增加的第1维是头的维度

    def forward(self, hidden_states):
        # hidden_states: (B, 197, 768)
        mixed_query_layer = self.query(hidden_states)   # (B, 197, 768)
        mixed_key_layer = self.key(hidden_states)       # (B, 197, 768)
        mixed_value_layer = self.value(hidden_states)   # (B, 197, 768)

        # 从mixed中拆分出12个头
        query_layer = self.transpose_for_scores(mixed_query_layer)  # (B, 12, 197, 64)
        key_layer = self.transpose_for_scores(mixed_key_layer)      # (B, 12, 197, 64)
        value_layer = self.transpose_for_scores(mixed_value_layer)  # (B, 12, 197, 64)

        # 以下3行代码计算Attention Matrix
        # Q:(B, 12, 197, 64)*K.T:(B, 12, 64, 197) => (B, 12, 197, 197)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # 缩放
        attention_probs = self.softmax(attention_scores)  # softmax处理

        # 保存Attention Matrix
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # 对V加权求输出   A:(B, 12, 197, 197)*V:(B, 12, 197, 64) => (B, 12, 197, 64)
        context_layer = torch.matmul(attention_probs, value_layer)
        # (B, 12, 197, 64) => (B, 197, 12, 64) 使用 .contiguous()深拷贝
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # new_context_layer_shape: (B, 197, 768) 也就是将12个头concat
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # 合并后再接一次线性映射得到最终的输出
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights

Conclusion

Achievement

提出了将transformer应用于图像分类任务的ViT模型,在大规模数据集上预训练之后的取得了SOTA,并且训练代价比传统ResNet模型小

Prospect

  1. 将ViT模型应用到cv领域其他任务上,目标检测、分割...
  2. 探索自监督预训练方法

Reference

下面是在学习ViT及可视化过程中参考的一些文章

ViT讲解

  1. https://theaisummer.com/vision-transformer/
  2. https://blog.csdn.net/qq_36560894/article/details/119706064
  3. https://www.bilibili.com/video/BV15P4y137jb

可视化

  1. visualize posotion embedding:https://blog.csdn.net/weixin_41978699/article/details/122404192
  2. visualize attention map 及 ViT pytorch实现:https://github.com/jeonsworld/ViT-pytorch/
  3. visualize filters of linear embedding:https://blog.csdn.net/hjkdh/article/details/125357950

博观约取,厚积薄发

能力和知识有限,如果错误,请不吝赐教。 —— dctwan