《Read Like Humans》(ABINet) 是中国科学技术大学 2021 年发表在 CVPR 上的论文。
视觉模型
class BaseVision(Model):
    """Vision model: backbone features -> character attention -> per-step classifier.

    Pipeline: backbone (resnet45 or ResTranformer) produces a (N, E, H, W)
    feature map; an attention module turns it into T character vectors; a
    linear head classifies each of them.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): these attributes were read below but never set in the
        # notes excerpt; defaults mirror the upstream ABINet code — confirm.
        self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
        self.out_channels = ifnone(config.model_vision_d_model, 512)
        # BUG FIX: `mode` was referenced below but never defined in the notes;
        # upstream reads the upsampling mode from the config (default 'nearest').
        mode = ifnone(config.model_vision_attention_mode, 'nearest')
        if config.model_vision_backbone == 'transformer':
            # ResTranformer = resnet45 + transformer encoder
            # (the class name really is spelled "ResTranformer" upstream)
            self.backbone = ResTranformer(config)
        else:
            self.backbone = resnet45()
        if config.model_vision_attention == 'position':
            self.attention = PositionAttention(
                max_length=config.dataset_max_length + 1,  # additional stop token
                mode=mode,
            )
        elif config.model_vision_attention == 'attention':
            self.attention = Attention(
                max_length=config.dataset_max_length + 1,  # additional stop token
                n_feature=8 * 32,
            )
        self.cls = nn.Linear(self.out_channels, self.charset.num_classes)
        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)

    def forward(self, images, *args):
        """Run backbone + attention + classifier on a batch of images."""
        features = self.backbone(images)                   # (N, E, H, W)
        attn_vecs, attn_scores = self.attention(features)  # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs)                       # (N, T, C)
        pt_lengths = self._get_length(logits)
        # BUG FIX: the notes ended with a bare `return`, dropping every computed
        # output; return the standard result dict (same contract as BaseAlignment).
        return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
                'attn_scores': attn_scores, 'loss_weight': self.loss_weight,
                'name': 'vision'}
整体流程:
Backbone(resnet45/ResTranformer) -> Attention(PositionAttention/Attention)
- ResTranformer = resnet45 + transformer(上游代码中类名即拼作 ResTranformer)
- Attention 是加性(additive)注意力机制:
这一块代码主要用的是 SRN 设计的字符注意力模块
- e = tanh(W·x + U·j)
- α = softmax(e)
- output = α·x
def forward(self, enc_output):
#这里的输入时enc_output为公式中的X,字符阅读顺序为公式中的j.U,W分别为线性全连接层
enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
reading_order_embed = self.f0_embedding(reading_order) # b,25,512
t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
attn = self.we(t) # b,256,25
attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
g_output = torch.bmm(attn, enc_output) # b,25,512
return g_output, attn.view(*attn.shape[:2], 8, 32)
- PositionAttention:这一块是作者论文的代码,借鉴了自注意力(self-attention),实现基于位置信息查询的模块。
class PositionAttention(nn.Module):
    """ABINet positional attention: query = positional encoding, key = U-Net
    over the feature map, value = the raw feature map.

    BUG FIX: in the notes, the closing parens of both ``nn.Sequential(`` calls
    had been swallowed into the trailing comments, leaving unbalanced
    parentheses; the layer lists below restore the upstream 4-down/4-up
    structure the comments describe.
    """

    def __init__(self, max_length, in_channels=512, num_channels=64,
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        # NOTE(review): encoder_layer/decoder_layer are small conv / upsample
        # factory helpers defined elsewhere in the original repo — confirm.
        # Key path, downsampling half of the U-Net: 4 strided conv layers.
        self.k_encoder = nn.Sequential(
            encoder_layer(in_channels, num_channels, s=(1, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
            encoder_layer(num_channels, num_channels, s=(2, 2)),
        )
        # Key path, upsampling half of the U-Net: 4 layers back to (h, w).
        self.k_decoder = nn.Sequential(
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
            decoder_layer(num_channels, in_channels, size=(h, w), mode=mode),
        )
        # Sinusoidal transformer positional encoding — parameter-free.
        self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        N, E, H, W = x.size()
        k, v = x, x  # (N, E, H, W)
        # Key: run the U-Net, adding skip connections on the way up.
        features = []
        for i in range(0, len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(0, len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)
        # Query: positional encoding of zeros + FC — the character reading
        # order, implemented with the fixed sinusoidal encoding (cf. SRN).
        # TODO q=f(q,k)
        zeros = x.new_zeros((self.max_length, N, E))  # (T, N, E)
        q = self.pos_encoder(zeros)                   # (T, N, E)
        q = q.permute(1, 0, 2)                        # (N, T, E)
        q = self.project(q)                           # (N, T, E)
        # Scaled dot-product attention; value is the untouched feature map.
        attn_scores = torch.bmm(q, k.flatten(2, 3))   # (N, T, (H*W))
        attn_scores = attn_scores / (E ** 0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)
        v = v.permute(0, 2, 3, 1).view(N, -1, E)      # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)         # (N, T, E)
        return attn_vecs, attn_scores.view(N, -1, H, W)
这里在图中画得非常清晰。整体结构为 ResTranformer + PositionAttention 的结构
ResTranformer = resnet45 + transformer encoder * 2
PositionAttention = key / query / value
- key = U-Net(encoder_out)
- query = FC(Position_Encoding(new_zeros))
- value = encoder_out
语言模型
这一块正如图中所示:query 用的是字符位置编码,key/value 用的是输入 tokens 的 embedding 信息,mask 使用了屏蔽对角线的注意力 mask(位置 i 不能看到第 i 个输入字符)
class BCNLanguage(Model):
    """Bidirectional cloze network — ABINet's language model.

    A transformer decoder where the query is the character-position encoding
    and key/value are embeddings of the input token distribution; the memory
    mask blanks the diagonal so position i cannot attend to input token i.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the notes used d_model/nhead/d_inner/... without
        # defining them; upstream reads them from the config with defaults
        # (same ifnone pattern as BaseAlignment) — confirm values.
        d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model'])
        nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead'])
        d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner'])
        dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout'])
        activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation'])
        num_layers = ifnone(config.model_language_num_layers, 4)
        self.detach = ifnone(config.model_language_detach, True)
        self.use_self_attn = ifnone(config.model_language_use_self_attn, False)
        self.loss_weight = ifnone(config.model_language_loss_weight, 1.0)
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.debug = ifnone(config.global_debug, False)
        self.proj = nn.Linear(self.charset.num_classes, d_model, False)
        # Both are the parameter-free sinusoidal transformer encodings.
        self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
        decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
                                                activation, self_attn=self.use_self_attn,
                                                debug=self.debug)
        self.model = TransformerDecoder(decoder_layer, num_layers)
        self.cls = nn.Linear(d_model, self.charset.num_classes)
        if config.model_language_checkpoint is not None:
            logging.info(f'Read language model from {config.model_language_checkpoint}.')
            self.load(config.model_language_checkpoint)

    def forward(self, tokens, lengths):
        """
        Args:
            tokens: (N, T, C) where T is length, N is batch size and C is classes number
            lengths: (N,)
        """
        # Optionally cut the gradient path back into the vision model.
        if self.detach: tokens = tokens.detach()
        embed = self.proj(tokens)       # (N, T, E)
        embed = embed.permute(1, 0, 2)  # (T, N, E)
        embed = self.token_encoder(embed)  # (T, N, E)
        padding_mask = self._get_padding_mask(lengths, self.max_length)
        # Query: sinusoidal encoding of zeros, like the vision model's
        # pos_encoder(new_zeros) query. (Typo fix: was `qeury`.)
        zeros = embed.new_zeros(*embed.shape)
        query = self.pos_encoder(zeros)
        location_mask = self._get_location_mask(self.max_length, tokens.device)
        output = self.model(query, embed,
                            tgt_key_padding_mask=padding_mask,
                            memory_mask=location_mask,
                            memory_key_padding_mask=padding_mask)  # (T, N, E)
        output = output.permute(1, 0, 2)  # (N, T, E)
        logits = self.cls(output)         # (N, T, C)
        pt_lengths = self._get_length(logits)
        # BUG FIX: the notes returned an undefined `res`; build the result
        # dict explicitly, mirroring BaseAlignment's return contract.
        res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
               'loss_weight': self.loss_weight, 'name': 'language'}
        return res
融合模块
融合采用动态门控(gated)机制,与 SRN、RobustScanner 中的融合方式类似
class BaseAlignment(Model):
    """Gated fusion head that aligns language and vision features."""

    def __init__(self, config):
        super().__init__(config)
        d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
        self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
        self.max_length = config.dataset_max_length + 1  # additional stop token
        # Gate over the concatenated streams, then a shared classifier head.
        self.w_att = nn.Linear(2 * d_model, d_model)
        self.cls = nn.Linear(d_model, self.charset.num_classes)

    def forward(self, l_feature, v_feature):
        """Fuse the two streams with a learned sigmoid gate.

        Args:
            l_feature: (N, T, E) language-model features, where T is length,
                N is batch size and E is the model dimension
            v_feature: (N, T, E) vision-model features, same shape
        """
        combined = torch.cat((l_feature, v_feature), dim=2)
        gate = torch.sigmoid(self.w_att(combined))
        # gate -> 1 trusts vision, gate -> 0 trusts language, per element.
        fused = gate * v_feature + (1 - gate) * l_feature
        logits = self.cls(fused)  # (N, T, C)
        pt_lengths = self._get_length(logits)
        return {'logits': logits, 'pt_lengths': pt_lengths,
                'loss_weight': self.loss_weight, 'name': 'alignment'}