神经网络学习小记录68——Tensorflow2版 Vision Transformer(VIT)模型的复现详解

2023-11-20

学习前言

视觉Transformer最近非常的火热,从VIT开始,我先学学看。
在这里插入图片描述

什么是Vision Transformer(VIT)

Vision Transformer是Transformer的视觉版本,Transformer基本上已经成为了自然语言处理的标配,但是在视觉中的运用还受到限制。

Vision Transformer打破了这种NLP与CV的隔离,将Transformer应用于图像图块(patch)序列上,进一步完成图像分类任务。简单来理解,Vision Transformer就是将输入进来的图片,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列,将组合后的结果传入Transformer特有的Multi-head Self-attention进行特征提取。最后利用Cls Token进行分类。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/classification-tf2
复制该路径到地址栏跳转。

Vision Transforme的实现思路

一、整体结构解析

在这里插入图片描述
与寻常的分类网络类似,整个Vision Transformer可以氛围两部分,一部分是特征提取部分,另一部分是分类部分。

在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。

在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。

二、网络结构解析

1、特征提取部分介绍

a、Patch+Position Embedding

在这里插入图片描述
Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列

该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了

在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。
请添加图片描述
下一步就是将这个特征层组合成序列,组合的方式非常简单,就是将高宽维度进行平铺,14, 14, 768在高宽维度平铺后,获得一个196, 768的特征层。平铺完成后,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,图中的这个0*就是Cls Token,我们此时获得一个197, 768的特征层在这里插入图片描述
添加完成Cls Token后,再为所有特征添加上位置信息这样网络才有区分不同区域的能力。添加方式其实也非常简单,我们生成一个197, 768的参数矩阵,这个参数矩阵是可训练的,把这个矩阵加上197, 768的特征层即可。

到这里,Patch+Position Embedding就构建完成了,构建代码如下:

#--------------------------------------------------------------------------------------------------------------------#
#   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
#   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
class ClassToken(Layer):
    def __init__(self, cls_initializer='zeros', cls_regularizer=None, cls_constraint=None, **kwargs):
        super(ClassToken, self).__init__(**kwargs)
        self.cls_initializer    = keras.initializers.get(cls_initializer)
        self.cls_regularizer    = keras.regularizers.get(cls_regularizer)
        self.cls_constraint     = keras.constraints.get(cls_constraint)

    def get_config(self):
        config = {
            'cls_initializer': keras.initializers.serialize(self.cls_initializer),
            'cls_regularizer': keras.regularizers.serialize(self.cls_regularizer),
            'cls_constraint': keras.constraints.serialize(self.cls_constraint),
        }
        base_config = super(ClassToken, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1] + 1, input_shape[2])

    def build(self, input_shape):
        self.num_features = input_shape[-1]
        self.cls = self.add_weight(
            shape       = (1, 1, self.num_features),
            initializer = self.cls_initializer,
            regularizer = self.cls_regularizer,
            constraint  = self.cls_constraint,
            name        = 'cls',
        )
        super(ClassToken, self).build(input_shape)

    def call(self, inputs):
        batch_size      = tf.shape(inputs)[0]
        cls_broadcasted = tf.cast(tf.broadcast_to(self.cls, [batch_size, 1, self.num_features]), dtype = inputs.dtype)
        return tf.concat([cls_broadcasted, inputs], 1)

#--------------------------------------------------------------------------------------------------------------------#
#   为网络提取到的特征添加上位置信息。
#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
class AddPositionEmbs(Layer):
    def __init__(self, image_shape, patch_size, pe_initializer='zeros', pe_regularizer=None, pe_constraint=None, **kwargs):
        super(AddPositionEmbs, self).__init__(**kwargs)
        self.image_shape        = image_shape
        self.patch_size         = patch_size
        self.pe_initializer     = keras.initializers.get(pe_initializer)
        self.pe_regularizer     = keras.regularizers.get(pe_regularizer)
        self.pe_constraint      = keras.constraints.get(pe_constraint)

    def get_config(self):
        config = {
            'pe_initializer': keras.initializers.serialize(self.pe_initializer),
            'pe_regularizer': keras.regularizers.serialize(self.pe_regularizer),
            'pe_constraint': keras.constraints.serialize(self.pe_constraint),
        }
        base_config = super(AddPositionEmbs, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        assert (len(input_shape) == 3), f"Number of dimensions should be 3, got {len(input_shape)}"
        length  = (224 // self.patch_size) * (224 // self.patch_size) + 1
        self.pe = self.add_weight(
            # shape       = [1, input_shape[1], input_shape[2]],
            shape       = [1, length, input_shape[2]],
            initializer = self.pe_initializer,
            regularizer = self.pe_regularizer,
            constraint  = self.pe_constraint,
            name        = 'pos_embedding',
        )
        super(AddPositionEmbs, self).build(input_shape)

    def call(self, inputs):
        num_features = tf.shape(inputs)[2]

        cls_token_pe = self.pe[:, 0:1, :]
        img_token_pe = self.pe[:, 1: , :]

        img_token_pe = tf.reshape(img_token_pe, [1, (224 // self.patch_size), (224 // self.patch_size), num_features])
        img_token_pe = tf.compat.v1.image.resize_images(img_token_pe, (self.image_shape[0] // self.patch_size, self.image_shape[1] // self.patch_size), tf.image.ResizeMethod.BICUBIC, align_corners=False)
        img_token_pe = tf.reshape(img_token_pe, [1, -1, num_features])
        
        pe = tf.concat([cls_token_pe, img_token_pe], axis = 1)

        return inputs + tf.cast(pe, dtype=inputs.dtype)

def VisionTransformer(input_shape = [224, 224], patch_size = 16, num_layers = 12, num_features = 768, num_heads = 12, mlp_dim = 3072, 
            classes = 1000, dropout = 0.1):
    #-----------------------------------------------#
    #   224, 224, 3
    #-----------------------------------------------#
    inputs      = Input(shape = (input_shape[0], input_shape[1], 3))
    
    #-----------------------------------------------#
    #   224, 224, 3 -> 14, 14, 768
    #-----------------------------------------------#
    x           = Conv2D(num_features, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs)
    #-----------------------------------------------#
    #   14, 14, 768 -> 196, 768
    #-----------------------------------------------#
    x           = Reshape(((input_shape[0] // patch_size) * (input_shape[1] // patch_size), num_features))(x)
    #-----------------------------------------------#
    #   196, 768 -> 197, 768
    #-----------------------------------------------#
    x           = ClassToken(name="cls_token")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768
    #-----------------------------------------------#
    x           = AddPositionEmbs(input_shape, patch_size, name="pos_embed")(x)

b、Transformer Encoder

在这里插入图片描述
在上一步获得shape为197, 768的序列信息后,将序列信息传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。

I、Self-attention结构解析

看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。
请添加图片描述
如果我们想要获得input-1的输出,那么我们进行如下几步:
1、利用input-1的查询向量,分别乘上input-1、input-2、input-3的键向量,此时我们获得了三个score
2、然后对这三个score取softmax,获得了input-1、input-2、input-3各自的重要程度。
3、然后将这个重要程度乘上input-1、input-2、input-3的值向量,求和。
4、此时我们获得了input-1的输出。

如图所示,我们进行如下几步:
1、input-1的查询向量为[1, 0, 2],分别乘上input-1、input-2、input-3的键向量,获得三个score为2,4,4。
2、然后对这三个score取softmax,获得了input-1、input-2、input-3各自的重要程度,获得三个重要程度为0.0,0.5,0.5。
3、然后将这个重要程度乘上input-1、input-2、input-3的值向量,求和,即
0.0 ∗ [ 1 , 2 , 3 ] + 0.5 ∗ [ 2 , 8 , 0 ] + 0.5 ∗ [ 2 , 6 , 3 ] = [ 2.0 , 7.0 , 1.5 ] 0.0 * [1, 2, 3] + 0.5 * [2, 8, 0] + 0.5 * [2, 6, 3] = [2.0, 7.0, 1.5] 0.0[1,2,3]+0.5[2,8,0]+0.5[2,6,3]=[2.0,7.0,1.5]
4、此时我们获得了input-1的输出 [2.0, 7.0, 1.5]。

上述的例子中,序列长度仅为3,每个单位序列的特征长度仅为3,在VIT的Transformer Encoder中,序列长度为197,每个单位序列的特征长度为768 // num_heads。但计算过程是一样的。在实际运算时,我们采用矩阵进行运算。

II、Self-attention的矩阵运算

实际的矩阵运算过程如下图所示。我以实际矩阵为例子给大家解析:
在这里插入图片描述
输入的Query、Key、Value如下图所示:
在这里插入图片描述
首先利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。

输出的每一行,都代表input-1、input-2、input-3,对当前input的贡献,我们对这个贡献值取一个softmax。
在这里插入图片描述
在这里插入图片描述
然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
在这里插入图片描述
这个矩阵运算的代码如下所示,各位同学可以自己试试。

import numpy as np

def soft_max(z):
    t = np.exp(z)
    a = np.exp(z) / np.expand_dims(np.sum(t, axis=1), 1)
    return a

Query = np.array([
    [1,0,2],
    [2,2,2],
    [2,1,3]
])

Key = np.array([
    [0,1,1],
    [4,4,0],
    [2,3,1]
])

Value = np.array([
    [1,2,3],
    [2,8,0],
    [2,6,3]
])

scores = Query @ Key.T
print(scores)
scores = soft_max(scores)
print(scores)
out = scores @ Value
print(out)
III、MultiHead多头注意力机制

多头注意力机制的示意图如图所示:
在这里插入图片描述
这幅图给人的感觉略显迷茫,我们跳脱出这个图,直接从矩阵的shape入手会清晰很多。

在第一步进行图像的分割后,我们获得的特征层为197, 768。

在施加多头的时候,我们直接对196, 768的最后一维度进行分割,比如我们想分割成12个头,那么矩阵的shepe就变成了196, 12, 64。

然后我们将196, 12, 64进行转置,将12放到前面去,获得的特征层为12, 196, 64。之后我们忽略这个12,把它和batch维度同等对待只对196, 64进行处理其实也就是上面的注意力机制的过程了

#--------------------------------------------------------------------------------------------------------------------#
#   Attention机制
#   将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
#   然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
#   然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(Layer):
    def __init__(self, num_features, num_heads, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.num_features   = num_features
        self.num_heads      = num_heads
        self.projection_dim = num_features // num_heads

    def get_config(self):
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2] // 3)

    def call(self, inputs):
        #-----------------------------------------------#
        #   获得batch_size
        #-----------------------------------------------#
        bs      = tf.shape(inputs)[0]

        #-----------------------------------------------#
        #   b, 197, 3 * 768 -> b, 197, 3, 12, 64
        #-----------------------------------------------#
        inputs  = tf.reshape(inputs, [bs, -1, 3, self.num_heads, self.projection_dim])
        #-----------------------------------------------#
        #   b, 197, 3, 12, 64 -> 3, b, 12, 197, 64
        #-----------------------------------------------#
        inputs  = tf.transpose(inputs, [2, 0, 3, 1, 4])
        #-----------------------------------------------#
        #   将query, key, value划分开
        #   query     b, 12, 197, 64
        #   key       b, 12, 197, 64
        #   value     b, 12, 197, 64
        #-----------------------------------------------#
        query, key, value = inputs[0], inputs[1], inputs[2]
        #-----------------------------------------------#
        #   b, 12, 197, 64 @ b, 12, 197, 64 = b, 12, 197, 197
        #-----------------------------------------------#
        score           = tf.matmul(query, key, transpose_b=True)
        #-----------------------------------------------#
        #   进行数量级的缩放
        #-----------------------------------------------#
        scaled_score    = score / tf.math.sqrt(tf.cast(self.projection_dim, score.dtype))
        #-----------------------------------------------#
        #   b, 12, 197, 197 -> b, 12, 197, 197
        #-----------------------------------------------#
        weights         = tf.nn.softmax(scaled_score, axis=-1)
        #-----------------------------------------------#
        #   b, 12, 197, 197 @ b, 12, 197, 64 = b, 12, 197, 64
        #-----------------------------------------------#
        value          = tf.matmul(weights, value)

        #-----------------------------------------------#
        #   b, 12, 197, 64 -> b, 197, 12, 64
        #-----------------------------------------------#
        value = tf.transpose(value, perm=[0, 2, 1, 3])
        #-----------------------------------------------#
        #   b, 197, 12, 64 -> b, 197, 768
        #-----------------------------------------------#
        output = tf.reshape(value, (bs, -1, self.num_features))
        return output

def MultiHeadSelfAttention(inputs, num_features, num_heads, dropout, name):
    #-----------------------------------------------#
    #   qkv   b, 197, 768 -> b, 197, 3 * 768
    #-----------------------------------------------#
    qkv = Dense(int(num_features * 3), name = name + "qkv")(inputs)
    #-----------------------------------------------#
    #   b, 197, 3 * 768 -> b, 197, 768
    #-----------------------------------------------#
    x   = Attention(num_features, num_heads)(qkv)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768
    #-----------------------------------------------#
    x   = Dense(num_features, name = name + "proj")(x)
    x   = Dropout(dropout)(x)
    return x
IV、TransformerBlock的构建。

在这里插入图片描述

在完成MultiHeadSelfAttention的构建后,我们需要在其后加上两个全连接。就构建了整个TransformerBlock。

def MLP(y, num_features, mlp_dim, dropout, name):
    y = Dense(mlp_dim, name = name + "fc1")(y)
    y = Gelu()(y)
    y = Dropout(dropout)(y)
    y = Dense(num_features, name = name + "fc2")(y)
    return y

def TransformerBlock(inputs, num_features, num_heads, mlp_dim, dropout, name):
    #-----------------------------------------------#
    #   施加层标准化
    #-----------------------------------------------#
    x = LayerNormalization(epsilon=1e-6, name = name + "norm1")(inputs)
    #-----------------------------------------------#
    #   施加多头注意力机制
    #-----------------------------------------------#
    x = MultiHeadSelfAttention(x, num_features, num_heads, dropout, name = name + "attn.")
    x = Dropout(dropout)(x)
    #-----------------------------------------------#
    #   施加残差结构
    #-----------------------------------------------#
    x = Add()([x, inputs])

    #-----------------------------------------------#
    #   施加层标准化
    #-----------------------------------------------#
    y = LayerNormalization(epsilon=1e-6, name = name + "norm2")(x)
    #-----------------------------------------------#
    #   施加两次全连接
    #-----------------------------------------------#
    y = MLP(y, num_features, mlp_dim, dropout, name = name + "mlp.")
    y = Dropout(dropout)(y)
    #-----------------------------------------------#
    #   施加残差结构
    #-----------------------------------------------#
    y = Add()([x, y])
    return y

c、整个VIT模型的构建

在这里插入图片描述
整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个。

def VisionTransformer(input_shape = [224, 224], patch_size = 16, num_layers = 12, num_features = 768, num_heads = 12, mlp_dim = 3072, 
            classes = 1000, dropout = 0.1):
    #-----------------------------------------------#
    #   224, 224, 3
    #-----------------------------------------------#
    inputs      = Input(shape = (input_shape[0], input_shape[1], 3))
    
    #-----------------------------------------------#
    #   224, 224, 3 -> 14, 14, 768
    #-----------------------------------------------#
    x           = Conv2D(num_features, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs)
    #-----------------------------------------------#
    #   14, 14, 768 -> 196, 768
    #-----------------------------------------------#
    x           = Reshape(((input_shape[0] // patch_size) * (input_shape[1] // patch_size), num_features))(x)
    #-----------------------------------------------#
    #   196, 768 -> 197, 768
    #-----------------------------------------------#
    x           = ClassToken(name="cls_token")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768
    #-----------------------------------------------#
    x           = AddPositionEmbs(input_shape, patch_size, name="pos_embed")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768  12次
    #-----------------------------------------------#
    for n in range(num_layers):
        x = TransformerBlock(
            x,
            num_features= num_features,
            num_heads   = num_heads,
            mlp_dim     = mlp_dim,
            dropout     = dropout,
            name        = "blocks." + str(n) + ".",
        )
    x = LayerNormalization(
        epsilon=1e-6, name="norm"
    )(x)

2、分类部分

在这里插入图片描述
在分类部分,VIT所做的工作是利用提取到的特征进行分类。

在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征

最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类

def VisionTransformer(input_shape = [224, 224], patch_size = 16, num_layers = 12, num_features = 768, num_heads = 12, mlp_dim = 3072, 
            classes = 1000, dropout = 0.1):
    #-----------------------------------------------#
    #   224, 224, 3
    #-----------------------------------------------#
    inputs      = Input(shape = (input_shape[0], input_shape[1], 3))
    
    #-----------------------------------------------#
    #   224, 224, 3 -> 14, 14, 768
    #-----------------------------------------------#
    x           = Conv2D(num_features, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs)
    #-----------------------------------------------#
    #   14, 14, 768 -> 196, 768
    #-----------------------------------------------#
    x           = Reshape(((input_shape[0] // patch_size) * (input_shape[1] // patch_size), num_features))(x)
    #-----------------------------------------------#
    #   196, 768 -> 197, 768
    #-----------------------------------------------#
    x           = ClassToken(name="cls_token")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768
    #-----------------------------------------------#
    x           = AddPositionEmbs(input_shape, patch_size, name="pos_embed")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768  12次
    #-----------------------------------------------#
    for n in range(num_layers):
        x = TransformerBlock(
            x,
            num_features= num_features,
            num_heads   = num_heads,
            mlp_dim     = mlp_dim,
            dropout     = dropout,
            name        = "blocks." + str(n) + ".",
        )
    x = LayerNormalization(
        epsilon=1e-6, name="norm"
    )(x)
    x = Lambda(lambda v: v[:, 0], name="ExtractToken")(x)
    x = Dense(classes, name="head")(x)
    x = Softmax()(x)
    return keras.models.Model(inputs, x)

Vision Transforme的构建代码

import math

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Add, Conv2D, Dense, Dropout, Input,
                                     Lambda, Layer, Reshape, Softmax)


#--------------------------------------#
#   LayerNormalization
#   层标准化的实现
#--------------------------------------#
class LayerNormalization(keras.layers.Layer):
    def __init__(self,
                 center=True,
                 scale=True,
                 epsilon=None,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 gamma_constraint=None,
                 beta_constraint=None,
                 **kwargs):
        """Layer normalization layer
        See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
        :param center: Add an offset parameter if it is True.
        :param scale: Add a scale parameter if it is True.
        :param epsilon: Epsilon for calculating variance.
        :param gamma_initializer: Initializer for the gamma weight.
        :param beta_initializer: Initializer for the beta weight.
        :param gamma_regularizer: Optional regularizer for the gamma weight.
        :param beta_regularizer: Optional regularizer for the beta weight.
        :param gamma_constraint: Optional constraint for the gamma weight.
        :param beta_constraint: Optional constraint for the beta weight.
        :param kwargs:
        """
        super(LayerNormalization, self).__init__(**kwargs)
        self.supports_masking = True
        self.center = center
        self.scale = scale
        if epsilon is None:
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        self.gamma, self.beta = None, None

    def get_config(self):
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
        }
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        return input_mask

    def build(self, input_shape):
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        mean = K.mean(inputs, axis=-1, keepdims=True)
        variance = K.mean(K.square(inputs - mean), axis=-1, keepdims=True)
        std = K.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        if self.scale:
            outputs *= self.gamma
        if self.center:
            outputs += self.beta
        return outputs

#--------------------------------------#
#   Gelu激活函数的实现
#   利用近似的数学公式
#--------------------------------------#
class Gelu(Layer):
    def __init__(self, **kwargs):
        super(Gelu, self).__init__(**kwargs)
        self.supports_masking = True

    def call(self, inputs):
        return 0.5 * inputs * (1 + tf.tanh(tf.sqrt(2 / math.pi) * (inputs + 0.044715 * tf.pow(inputs, 3))))

    def get_config(self):
        config = super(Gelu, self).get_config()
        return config

    def compute_output_shape(self, input_shape):
        return input_shape

#--------------------------------------------------------------------------------------------------------------------#
#   classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。
#
#   在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。
#   此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。
#   在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。
#--------------------------------------------------------------------------------------------------------------------#
class ClassToken(Layer):
    def __init__(self, cls_initializer='zeros', cls_regularizer=None, cls_constraint=None, **kwargs):
        super(ClassToken, self).__init__(**kwargs)
        self.cls_initializer    = keras.initializers.get(cls_initializer)
        self.cls_regularizer    = keras.regularizers.get(cls_regularizer)
        self.cls_constraint     = keras.constraints.get(cls_constraint)

    def get_config(self):
        config = {
            'cls_initializer': keras.initializers.serialize(self.cls_initializer),
            'cls_regularizer': keras.regularizers.serialize(self.cls_regularizer),
            'cls_constraint': keras.constraints.serialize(self.cls_constraint),
        }
        base_config = super(ClassToken, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1] + 1, input_shape[2])

    def build(self, input_shape):
        self.num_features = input_shape[-1]
        self.cls = self.add_weight(
            shape       = (1, 1, self.num_features),
            initializer = self.cls_initializer,
            regularizer = self.cls_regularizer,
            constraint  = self.cls_constraint,
            name        = 'cls',
        )
        super(ClassToken, self).build(input_shape)

    def call(self, inputs):
        batch_size      = tf.shape(inputs)[0]
        cls_broadcasted = tf.cast(tf.broadcast_to(self.cls, [batch_size, 1, self.num_features]), dtype = inputs.dtype)
        return tf.concat([cls_broadcasted, inputs], 1)

#--------------------------------------------------------------------------------------------------------------------#
#   为网络提取到的特征添加上位置信息。
#   以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768
#   此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。
#--------------------------------------------------------------------------------------------------------------------#
class AddPositionEmbs(Layer):
    def __init__(self, image_shape, patch_size, pe_initializer='zeros', pe_regularizer=None, pe_constraint=None, **kwargs):
        super(AddPositionEmbs, self).__init__(**kwargs)
        self.image_shape        = image_shape
        self.patch_size         = patch_size
        self.pe_initializer     = keras.initializers.get(pe_initializer)
        self.pe_regularizer     = keras.regularizers.get(pe_regularizer)
        self.pe_constraint      = keras.constraints.get(pe_constraint)

    def get_config(self):
        config = {
            'pe_initializer': keras.initializers.serialize(self.pe_initializer),
            'pe_regularizer': keras.regularizers.serialize(self.pe_regularizer),
            'pe_constraint': keras.constraints.serialize(self.pe_constraint),
        }
        base_config = super(AddPositionEmbs, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def compute_output_shape(self, input_shape):
        return input_shape

    def build(self, input_shape):
        assert (len(input_shape) == 3), f"Number of dimensions should be 3, got {len(input_shape)}"
        length  = (224 // self.patch_size) * (224 // self.patch_size) + 1
        self.pe = self.add_weight(
            # shape       = [1, input_shape[1], input_shape[2]],
            shape       = [1, length, input_shape[2]],
            initializer = self.pe_initializer,
            regularizer = self.pe_regularizer,
            constraint  = self.pe_constraint,
            name        = 'pos_embedding',
        )
        super(AddPositionEmbs, self).build(input_shape)

    def call(self, inputs):
        num_features = tf.shape(inputs)[2]

        cls_token_pe = self.pe[:, 0:1, :]
        img_token_pe = self.pe[:, 1: , :]

        img_token_pe = tf.reshape(img_token_pe, [1, (224 // self.patch_size), (224 // self.patch_size), num_features])
        img_token_pe = tf.compat.v1.image.resize_images(img_token_pe, (self.image_shape[0] // self.patch_size, self.image_shape[1] // self.patch_size), tf.image.ResizeMethod.BICUBIC, align_corners=False)
        img_token_pe = tf.reshape(img_token_pe, [1, -1, num_features])
        
        pe = tf.concat([cls_token_pe, img_token_pe], axis = 1)

        return inputs + tf.cast(pe, dtype=inputs.dtype)

#--------------------------------------------------------------------------------------------------------------------#
#   Attention机制
#   将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
#   然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
#   然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(Layer):
    def __init__(self, num_features, num_heads, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.num_features   = num_features
        self.num_heads      = num_heads
        self.projection_dim = num_features // num_heads

    def get_config(self):
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[1], input_shape[2] // 3)

    def call(self, inputs):
        #-----------------------------------------------#
        #   获得batch_size
        #-----------------------------------------------#
        bs      = tf.shape(inputs)[0]

        #-----------------------------------------------#
        #   b, 197, 3 * 768 -> b, 197, 3, 12, 64
        #-----------------------------------------------#
        inputs  = tf.reshape(inputs, [bs, -1, 3, self.num_heads, self.projection_dim])
        #-----------------------------------------------#
        #   b, 197, 3, 12, 64 -> 3, b, 12, 197, 64
        #-----------------------------------------------#
        inputs  = tf.transpose(inputs, [2, 0, 3, 1, 4])
        #-----------------------------------------------#
        #   将query, key, value划分开
        #   query     b, 12, 197, 64
        #   key       b, 12, 197, 64
        #   value     b, 12, 197, 64
        #-----------------------------------------------#
        query, key, value = inputs[0], inputs[1], inputs[2]
        #-----------------------------------------------#
        #   b, 12, 197, 64 @ b, 12, 197, 64 = b, 12, 197, 197
        #-----------------------------------------------#
        score           = tf.matmul(query, key, transpose_b=True)
        #-----------------------------------------------#
        #   进行数量级的缩放
        #-----------------------------------------------#
        scaled_score    = score / tf.math.sqrt(tf.cast(self.projection_dim, score.dtype))
        #-----------------------------------------------#
        #   b, 12, 197, 197 -> b, 12, 197, 197
        #-----------------------------------------------#
        weights         = tf.nn.softmax(scaled_score, axis=-1)
        #-----------------------------------------------#
        #   b, 12, 197, 197 @ b, 12, 197, 64 = b, 12, 197, 64
        #-----------------------------------------------#
        value          = tf.matmul(weights, value)

        #-----------------------------------------------#
        #   b, 12, 197, 64 -> b, 197, 12, 64
        #-----------------------------------------------#
        value = tf.transpose(value, perm=[0, 2, 1, 3])
        #-----------------------------------------------#
        #   b, 197, 12, 64 -> b, 197, 768
        #-----------------------------------------------#
        output = tf.reshape(value, (bs, -1, self.num_features))
        return output

def MultiHeadSelfAttention(inputs, num_features, num_heads, dropout, name):
    #-----------------------------------------------#
    #   qkv   b, 197, 768 -> b, 197, 3 * 768
    #-----------------------------------------------#
    qkv = Dense(int(num_features * 3), name = name + "qkv")(inputs)
    #-----------------------------------------------#
    #   b, 197, 3 * 768 -> b, 197, 768
    #-----------------------------------------------#
    x   = Attention(num_features, num_heads)(qkv)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768
    #-----------------------------------------------#
    x   = Dense(num_features, name = name + "proj")(x)
    x   = Dropout(dropout)(x)
    return x

def MLP(y, num_features, mlp_dim, dropout, name):
    y = Dense(mlp_dim, name = name + "fc1")(y)
    y = Gelu()(y)
    y = Dropout(dropout)(y)
    y = Dense(num_features, name = name + "fc2")(y)
    return y

def TransformerBlock(inputs, num_features, num_heads, mlp_dim, dropout, name):
    #-----------------------------------------------#
    #   施加层标准化
    #-----------------------------------------------#
    x = LayerNormalization(epsilon=1e-6, name = name + "norm1")(inputs)
    #-----------------------------------------------#
    #   施加多头注意力机制
    #-----------------------------------------------#
    x = MultiHeadSelfAttention(x, num_features, num_heads, dropout, name = name + "attn.")
    x = Dropout(dropout)(x)
    #-----------------------------------------------#
    #   施加残差结构
    #-----------------------------------------------#
    x = Add()([x, inputs])

    #-----------------------------------------------#
    #   施加层标准化
    #-----------------------------------------------#
    y = LayerNormalization(epsilon=1e-6, name = name + "norm2")(x)
    #-----------------------------------------------#
    #   施加两次全连接
    #-----------------------------------------------#
    y = MLP(y, num_features, mlp_dim, dropout, name = name + "mlp.")
    y = Dropout(dropout)(y)
    #-----------------------------------------------#
    #   施加残差结构
    #-----------------------------------------------#
    y = Add()([x, y])
    return y

def VisionTransformer(input_shape = [224, 224], patch_size = 16, num_layers = 12, num_features = 768, num_heads = 12, mlp_dim = 3072, 
            classes = 1000, dropout = 0.1):
    #-----------------------------------------------#
    #   224, 224, 3
    #-----------------------------------------------#
    inputs      = Input(shape = (input_shape[0], input_shape[1], 3))
    
    #-----------------------------------------------#
    #   224, 224, 3 -> 14, 14, 768
    #-----------------------------------------------#
    x           = Conv2D(num_features, patch_size, strides = patch_size, padding = "valid", name = "patch_embed.proj")(inputs)
    #-----------------------------------------------#
    #   14, 14, 768 -> 196, 768
    #-----------------------------------------------#
    x           = Reshape(((input_shape[0] // patch_size) * (input_shape[1] // patch_size), num_features))(x)
    #-----------------------------------------------#
    #   196, 768 -> 197, 768
    #-----------------------------------------------#
    x           = ClassToken(name="cls_token")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768
    #-----------------------------------------------#
    x           = AddPositionEmbs(input_shape, patch_size, name="pos_embed")(x)
    #-----------------------------------------------#
    #   197, 768 -> 197, 768  12次
    #-----------------------------------------------#
    for n in range(num_layers):
        x = TransformerBlock(
            x,
            num_features= num_features,
            num_heads   = num_heads,
            mlp_dim     = mlp_dim,
            dropout     = dropout,
            name        = "blocks." + str(n) + ".",
        )
    x = LayerNormalization(
        epsilon=1e-6, name="norm"
    )(x)
    x = Lambda(lambda v: v[:, 0], name="ExtractToken")(x)
    x = Dense(classes, name="head")(x)
    x = Softmax()(x)
    return keras.models.Model(inputs, x)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

神经网络学习小记录68——Tensorflow2版 Vision Transformer(VIT)模型的复现详解 的相关文章

  • Qt中父子widget的事件传递

    以前我一直以为 在父widget上摆一个子widget后 当click子widget时 只会进入到子widget的相关事件处理函数中 比如进入到mousePressEvent 中 而不会进入到父widget的对应事件处理函数中 毕竟 cli
  • ajax跨域post请求数据_基于Python的Post请求数据爬取

    为什么做这个 和同学聊天 他想爬取一个网站的post请求 观察 该网站的post请求参数有两种类型 1 参数体放在了query中 即url拼接参数 2 body中要加入一个空的json对象 关于为什么要加入空的json对象 猜测原因为反爬虫
  • 《OSPF和IS-IS详解》一1.7 独立且平等

    本节书摘来自异步社区 OSPF和IS IS详解 一书中的第1章 第1 7节 作者 美 Jeff Doyle 更多章节内容可以访问云栖社区 异步社区 公众号查看 1 7 独立且平等 OSPF和IS IS详解与TCP IP相比 OSI协议对各国
  • shell命令之cp复制拷贝

    1 复制文件到文件中 cp file1 file2 file1 file2 表示某一文件 在当前目录下 将file1 的文件内容复制到file2 文件中 如果第二个文件不存在 则先创建文件 然后再拷贝内容 如果存在则直接覆盖 没有警告 加
  • C++ 函数指针

    include
  • 基于SSM+JSP的宠物医院信息管理系统

    项目背景 21世纪的今天 随着社会的不断发展与进步 人们对于信息科学化的认识 已由低层次向高层次发展 由原来的感性认识向理性认识提高 管理工作的重要性已逐渐被人们所认识 科学化的管理 使信息存储达到准确 快速 完善 并能提高工作管理效率 促
  • bp利率最新消息是多少,bps利率是什么意思

    武汉房贷利率最新消息2022 3月26日起 武汉房贷利率将下调48BP 首套房贷款利率为5 2 二套房为5 4 其实武汉下调房贷利率也是在意料之内 此前的利率放在全国范围内比较 其实是比较高的 那利率降低后 每月能省多少钱呢 武汉房贷利率最

随机推荐

  • SSM框架和Spring Boot+Mybatis框架的性能比较?

    SSM框架和Spring Boot Mybatis框架的性能比较 没有一个绝对的答案 因为它们的性能受到很多因素的影响 例如项目的规模 复杂度 需求 技术栈 团队水平 测试环境 测试方法等 因此 我们不能简单地说哪个框架的性能更好 而是需要
  • qt 使用uic.exe 生成ui_xxxx.h文件的方法

    自己遇到这个问题 看了下别人的回答 总是有些不太清楚 就自己完善了下 1 制作好自己的xxxx ui文件 2 确定uic exe文件的地址 比如我的就是 D Anaconda3 pkgs qt 5 9 7 vc14h73c81de 0 Li
  • 雪糕的最大数量 排序+贪心

    雪糕的最大数量 雪糕的最大数量 题目描述 样例 数据范围 思路 代码 题目描述 夏日炎炎 小男孩 Tony 想买一些雪糕消消暑 商店中新到 n 支雪糕 用长度为 n 的数组 costs 表示雪糕的定价 其中 costs i 表示第 i 支雪
  • 于仕琪老师libfacedetection最新开源代码使用测试配置

    一 首先要感谢于老师的分享 二 此教程只是方便像我这样编程小白入门使用 若有不足之处 请原谅 网上对libfacedetection的介绍已经很多了 我在这里就不进行多余的解释 直接进入主题 下载地址 https github com Sh
  • Fsm2 Fsm2

    This is a Moore state machine with two states two inputs and one output Implement this state machine This exercise is th
  • 时序预测

    时序预测 MATLAB实现DBN深度置信网络时间序列预测 目录 时序预测 MATLAB实现DBN深度置信网络时间序列预测 预测效果 基本介绍 模型描述 程序设计 参考资料 预测效果 基本介绍 BP神经网络是1968年由Rumelhart和M
  • QMainwindow中添加的其他组件无法发送消息调用槽函数

    QMainwindow中添加的其他组件无法发送消息调用槽函数 问题所在 解决办法 问题所在 include mainwindow h include ui mainwindow h include QDebug include QMessa
  • [超实用]Java返回结果的工具类

    在做项目中 处理完各种业务数据后都需要返回值告诉前端最后的操作结果 但又不能直接返回一串错误代码信息 这个时候结果处理工具类就起了有比较好的作用 在此记录下 比较简单返回结果处理方法供大家参考学习 1 结果返回处理业务类 package r
  • python123.io---双一流高校及所在省份统计

    双一流高校及所在省份统计 类型 Python 组合数据类型 字典 d 中存储了我国 42 所双一流高校及所在省份的对应关系 请以这个列表为数据变量 完善 Python 代码 统计各省份学校的数量 d 北京大学
  • vue安装Base64转码

    第一步 项目文件路径下运行 npm install save js base64 或者 cnpm install save js base64 第二步 main js文件中引入 const Base64 require js base64
  • vue——vue-video-player插件实现rtmp直播流

    更新 flash已不可再使用 大家另寻出路吧 安装前首先需要注意几个点 vue video player插件 其实就是 video js 集成到 vue 中 所以千万不要再安装 video js 可能会出错 视频流我这个项目选择rtmp格式
  • 3559摄像头

    input aoni Webcam as devices platform soc 12310000 xhci 1 usb1 1 1 1 1 1 0 input input0 yuv转 的代码 https github com 198708
  • DC/DC闭环控制的丘克(Cuk)变换电路原理设计及实验仿真

    如果将降压 Buck 变换电路和升压 Boost 变换电路的拓扑结构进行对偶变换 即Boost变换电路和Buck变换电路串联在一起得到一种新的电路拓扑结构 丘克 CUK 变换电路 如图所示 Cuk变换电路的输入和输出均有电感 增加电感的值
  • matlab画圆并生成随机数

    A区域生成随机数 画圆 t 0 pi 100 2 pi x 10 cos t 30 3 y 10 sin t 89 8 plot x y r 生成随机数 a zeros 2 8 i 1 while i lt 8 temp1 rand 1 2
  • node中间件是什么意思?

    node中间件是什么意思 2020 09 11 16 11 17分类 常见问题 Node js答疑阅读 1757 评论 0 中间件是一种独立的系统软件或服务程序 分布式应用软件借助这种软件在不同的技术之间共享资源 中间件位于客户机 服务器的
  • Spark SQL 项目:实现各区域热门商品前N统计

    一 需求1 1 需求简介这里的热门商品是从点击量的维度来看的 计算各个区域前三大热门商品 并备注上每个商品在主要城市中的分布比例 超过两个城市用其他显示 1 2 思路分析使用 sql 来完成 碰到复杂的需求 可以使用 udf 或 udaf查
  • 四位均衡磨损格雷码

    什么是均衡磨损格雷码 均衡磨损格雷码是一种与标准格雷码具有相同的迭代后只变化一个位的特性 但每一个数位变化的次数相近的编码 为什么要均衡磨损 由于继电器输出PLC比晶体管输出PLC具有更好的可靠性 如果用继电器输出的PLC代替晶体管输出PL
  • 从0开始用shell写一个tomcat日志清理脚本

    一 目的 tomcat日志随着时间的流逝会越来越大 虽然我们可以使用cronolog对tomcat输出的日志根据日期进行切割 但是日子一长 进到logs 文件夹下都是密密麻麻的日志 不好查看也浪费了大量的空间 故本文的目的是编写一个脚本 能
  • linux 0.11 int80实现,Linux0.11内核--系统中断处理程序int 0x80实现原理

    extern int sys setup 系统启动初始化设置函数 kernel blk drv hd c 71 extern int sys exit 程序退出 kernel exit c 137 extern int sys fork 创
  • 神经网络学习小记录68——Tensorflow2版 Vision Transformer(VIT)模型的复现详解

    神经网络学习小记录68 Tensorflow2版 Vision Transformer VIT 模型的复现详解 学习前言 什么是Vision Transformer VIT 代码下载 Vision Transforme的实现思路 一 整体结