【OCR文本识别系列】Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Tex

2023-11-06

read like humans 是中科大在2021年发在CVPR上的论文

在这里插入图片描述

视觉模型

dd

class BaseVision(Model):
    def __init__(self, config):
        super().__init__(config)
        if config.model_vision_backbone == 'transformer':
            self.backbone = ResTranformer(config)
			#restransformer = Resnet + transformer            
        else: self.backbone = resnet45()
        
        if config.model_vision_attention == 'position':
            self.attention = PositionAttention(
                max_length=config.dataset_max_length + 1,  # additional stop token
                mode=mode,
            )
        elif config.model_vision_attention == 'attention':
            self.attention = Attention(
                max_length=config.dataset_max_length + 1,  # additional stop token
                n_feature=8*32,
            )
        self.cls = nn.Linear(self.out_channels, self.charset.num_classes)

        if config.model_vision_checkpoint is not None:
            logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
            self.load(config.model_vision_checkpoint)

    def forward(self, images, *args):
        features = self.backbone(images)  # (N, E, H, W)
        attn_vecs, attn_scores = self.attention(features)  # (N, T, E), (N, T, H, W)
        logits = self.cls(attn_vecs) # (N, T, C)
        pt_lengths = self._get_length(logits)

        return 

整体流程:
Backbone(resnet45/ResTranformer) -> Attention(PositionAttention/Attention)

  • Restransformer = resnet45 + transformer
  • Attention 是加性模型的注意力机制:
    这一块代码主要用的是SRN设计的字符注意力模块
  1. a = tanh(wx + uj)
  2. a = softmax(a)
  3. output = a*x
    def forward(self, enc_output):
    	#这里的输入时enc_output为公式中的X,字符阅读顺序为公式中的j.U,W分别为线性全连接层
        enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
        reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
        reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1)  # (S,) -> (B, S)
        reading_order_embed = self.f0_embedding(reading_order)  # b,25,512

        t = self.w0(reading_order_embed.permute(0, 2, 1))  # b,512,256
        t = self.active(t.permute(0, 2, 1) + self.wv(enc_output))  # b,256,512

        attn = self.we(t)  # b,256,25
        attn = self.softmax(attn.permute(0, 2, 1))  # b,25,256
        g_output = torch.bmm(attn, enc_output)  # b,25,512
        return g_output, attn.view(*attn.shape[:2], 8, 32)
  • PositionAttention :这一块是作者的论文代码,借鉴自注意力,做的位置信息的模块。
class PositionAttention(nn.Module):
    def __init__(self, max_length, in_channels=512, num_channels=64, 
                 h=8, w=32, mode='nearest', **kwargs):
        super().__init__()
        self.max_length = max_length
        self.k_encoder = nn.Sequential(
            #这里是U-net结构的下采样部分,一共用了4层)
        self.k_decoder = nn.Sequential(
            #这里是U-net结构的上采样部分,一共用了4层)

        self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
        #pos_encoder是transformer里的正余弦的硬位置编码,不需要额外参数
        self.project = nn.Linear(in_channels, in_channels)

    def forward(self, x):
        N, E, H, W = x.size()
        k, v = x, x  # (N, E, H, W)

        # calculate key vector U-net结构
        features = []
        for i in range(0, len(self.k_encoder)):
            k = self.k_encoder[i](k)
            features.append(k)
        for i in range(0, len(self.k_decoder) - 1):
            k = self.k_decoder[i](k)
            k = k + features[len(self.k_decoder) - 2 - i]
        k = self.k_decoder[-1](k)

        # calculate query vector 
        #模仿SRN做字符阅读顺序,但做法并不一致,这里用transformer的硬编码形式+FC层进行实现
        # TODO q=f(q,k)
        zeros = x.new_zeros((self.max_length, N, E))  # (T, N, E)
        q = self.pos_encoder(zeros)  # (T, N, E)
        q = q.permute(1, 0, 2)  # (N, T, E)
        q = self.project(q)  # (N, T, E)
        
        #value为原始的特征信息图
        
        # calculate self-attention
        attn_scores = torch.bmm(q, k.flatten(2, 3))  # (N, T, (H*W))
        attn_scores = attn_scores / (E ** 0.5)
        attn_scores = torch.softmax(attn_scores, dim=-1)

        v = v.permute(0, 2, 3, 1).view(N, -1, E)  # (N, (H*W), E)
        attn_vecs = torch.bmm(attn_scores, v)  # (N, T, E)

        return attn_vecs, attn_scores.view(N, -1, H, W)

这里在图中画的非常清晰。整体结构中为restransformer + Postion-attention的结构

Restransformer = resnet45+ transformer encoder*2
PositionAttention = key query value

  • key = U-net(encoder_out)
  • query = FC(Postion_Encoder(new_zeros))
  • value = encoder_out

语言模型

在这里插入图片描述
这一块正如图中所示,query用的是字符位置,key value用的是gt的embedding信息,mask使用了对角线的mask部分

class BCNLanguage(Model):
    def __init__(self, config):
        super().__init__(config)
  
        self.proj = nn.Linear(self.charset.num_classes, d_model, False)
        self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
        #均为transformer的正余弦硬编码
        decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout, 
                activation, self_attn=self.use_self_attn, debug=self.debug)
        self.model = TransformerDecoder(decoder_layer, num_layers)

        self.cls = nn.Linear(d_model, self.charset.num_classes)

        if config.model_language_checkpoint is not None:
            logging.info(f'Read language model from {config.model_language_checkpoint}.')
            self.load(config.model_language_checkpoint)

    def forward(self, tokens, lengths):
        """
        Args:
            tokens: (N, T, C) where T is length, N is batch size and C is classes number
            lengths: (N,)
        """
        #transformer的正余弦的硬编码
        if self.detach: tokens = tokens.detach()
        embed = self.proj(tokens)  # (N, T, E)
        embed = embed.permute(1, 0, 2)  # (T, N, E)
        embed = self.token_encoder(embed)  # (T, N, E)
        padding_mask = self._get_padding_mask(lengths, self.max_length)

        #类似视觉模型的查询硬编码pos_encoder(new_zeros)
        zeros = embed.new_zeros(*embed.shape)
        qeury = self.pos_encoder(zeros)
        location_mask = self._get_location_mask(self.max_length, tokens.device)
        output = self.model(qeury, embed,
                tgt_key_padding_mask=padding_mask,
                memory_mask=location_mask,
                memory_key_padding_mask=padding_mask)  # (T, N, E)
        output = output.permute(1, 0, 2)  # (N, T, E)

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        return res

融合模块

融合是一种动态的门控机制融合,和SRN robust scanner类似

class BaseAlignment(Model):
    def __init__(self, config):
        super().__init__(config)
        d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])

        self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
        self.max_length = config.dataset_max_length + 1  # additional stop token
        self.w_att = nn.Linear(2 * d_model, d_model)
        self.cls = nn.Linear(d_model, self.charset.num_classes)

    def forward(self, l_feature, v_feature):
        """
        Args:
            l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
            v_feature: (N, T, E) shape the same as l_feature 
            l_lengths: (N,)
            v_lengths: (N,)
        """
        f = torch.cat((l_feature, v_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature

        logits = self.cls(output)  # (N, T, C)
        pt_lengths = self._get_length(logits)

        return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight,
                'name': 'alignment'}
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【OCR文本识别系列】Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Tex 的相关文章

  • Java8 新特性使用

    文章目录 lamda接口语法 内置函数式接口 方法引用 语法 使用要求 构造器引用以及数组引用 构造器引用 数组引用 Stream流 串行流和并行流 创建Stream流的四种方式 流的形式 中间操作 筛选和切片 映射 排序 终止操作 匹配与
  • element ui中表格输入框回车跳到另一输入框

    纯输入框的回车 组件代码
  • m与n的数字运算python_M与N的数学运算

    描述 用户输入两个数M和N 其中N是整数 计算M和N的5种数学运算结果 并依次输出 结果间用空格分隔 5种数学运算分别是 M与N的和 M与N的乘积 M的N次幂 M除N的余数 M和N中较大的值
  • Linux内核中 SPI以太网W5500问题

    Linux内核中 SPI以太网W5500问题 Linux内核驱动中将W5500 W5200和W5100集成到了一起 本人只用到了W5500 问题描述 绿灯LinkLED和黄灯ACTLED交替闪烁 而正常状态应该为LinkLED常亮 ifco
  • 物联网仪表ADW300无线通讯灵活安装

    安科瑞 戈静怡 随着物联网新兴技术的发展 边缘智能 无线通讯 物联协议等越来越多的被应用 智能电表顺势而为 应用物联网技术 发展成如今的智能终端 物联网电力仪表 ADW300无线计量仪表主要用于计量低压网络的三相有功电能 具有RS485通讯
  • Redis代码示例

    RedisTemplate 如果想要在java中使用Redis相关的数据结构 要先注入RedisTemplate Autowired private RedisTemplate
  • RISC-V单周期处理器设计(基本介绍和数据通路)(一)

    一 设计步骤 1 处理器设计的基本规范 指令 包括处理器需要具有那些功能 需要注意的是处理器的功能是由指令唯一确定 2 处理器设计方案 包括数据通路和控制器 数据通路 指令执行过程中 数据所经过的路径 包括路径中的部件 它是指令的执行部件
  • 搭建cocos2d游戏引擎环境HelloWorld!

    转载自 黑米GameDev街区 原文链接 http www himigame com iphone cocos2d 415 html 本章节主要介绍cocos2D引擎的开发环境搭建 第一步 下载cocos2d iphone最新版本 地址如下
  • 清华2019最新AI发展报告出炉!400页干货,13大领域一文看懂

    2019 12 08 20 36 36 当前 人工智能正处在爆发期 我国在人工智能领域的科学技术研究和产业发展起步稍晚 但在最近十余年的时间里抓住了机遇 进入了快速发展阶段 在这个过程中 技术突破和创造性高端人才对人工智能的发展起着至关重要
  • 腾讯6大核心业务打造坚固护城河

    1998年11月 腾讯公司成立 腾讯之名取自小马哥名字与 网络通讯 这一初始业务定位 创始人为马 张 陈 许 曾五人 作为公司长期的核心决策层 分工明确 团队稳定 2000年OICQ更名为QQ 03年腾讯进入游戏领域 04年在港上市 11年
  • Qt字符编码要点

    1 首先明确几种常用的编码 UTF 8 GBK UNICODE UNICODE 明确概念0 我是汉字 是C语言中的字符串 它是char型的窄字符串 上面的例子可写为const char str 我是汉字 QString a str 或QSt
  • Vue技术—列表过滤

    div h2 人员列表 h2 div
  • mysql连接字最多查询_MySQL中应该多表连接查询一次取数据库还是多次查询取数据?...

    MySQL中应该多表连接查询一次取数据库还是多次查询取数据 具体的case在下面 三个表的字段如下 webcast cast表 id organizerId title startDate endDate number date等 webc
  • undefined reference to 问题的一种解决方法

    问题描述 今天在移植mcal中的部分代码时 然后进行编译 在链接的步骤里面 遇见了报错 undefined reference to xxx 未定义的函数被引用的问题 实际上那个函数是被定义了的 不过那个函数比较特殊 是一个被extern
  • qt day3

    1 gt 登录框实现注册功能 将注册的结果放入文件中 君子作业 2 gt 完成文本编辑器的保存工作 widget h ifndef WIDGET H define WIDGET H include
  • Vue只弹一次的弹框cookie

  • .Net 中的托管函数 Delegate

    1 什么是托管函数 托管函数是一个对类里面的某个函数的一个引用 它自己并没有具体的函数定义 只是指向某个函数实现 2 与C Delphi的横向比较 在C 和Delphi中与托管函数对应的类型是函数指针 形式如下 C typedef int
  • aaaadafdsafdashfhdskhk

    aaaadafdsafdashfhdskhksdfdsfd

随机推荐

  • DirectSound播放PCM(可播放实时采集的音频数据)

    前言 该篇整理的原始来源为http blog csdn net leixiaohua1020 article details 40540147 非常感谢该博主的无私奉献 写了不少关于不同多媒体库的博文 让我这个小白学习到不少 现在将其整理是
  • Labview+Hsl通讯(与欧姆龙NX1P2通讯)

    通过和欧姆龙客服沟通 NX1P2不支持fins tcp与OPC UA 但是支持FINS UDP通讯 没办将就用吧 这里PLC IP 192 168 250 0 1 端口 9600 下面是测试图 PC端的端口随意填就行 不要和PLC端口重复就
  • [1150]Linux服务器上使用rz命令上传文件报:Segmentation Fault

    使用rz命令上传一张几十KB的图片 一直上传不了服务器 试了sz命令却是没问题 一直在排查是否Linux服务器对上传命令有所限制 最终未果 接着想到是否是硬盘空间不足了 使用df h命令一看 果然硬盘没空间了 使用率达到了100 接着使用
  • tf.nn 激活函数

    tf nn sigmoid tf nn tanh tanh函数解决了Sigmoid函数的不是zero centered输出问题 但梯度消失 gradient vanishing 的问题和幂运算的问题仍然存在 tf nn relu tf nn
  • 访问数据库_常用的数据库访问方式是什么?

    常用的数据库访问方式是什么 ASP 访问数据库的方式有哪些 在 ASP 中可以通过三种方式访问数据库 1 IDC Internet Database Connector 方式 2 ADO ActiveX Data Objects 方式 3
  • 如何跳出ajax,让AJAX运作中跳出来Loading

    CSS部分 CSS一部分 div loadingdiv height 100 width 100 100 遮盖网页页面 防止user在loading时开展别的实际操作 position fixed z index 99999 须超过网页页面
  • flink架构

    JobManager控制应用执行的主进程 jobMaster处理单独的job ResuorseManager分配task slots Dispatcher提交应用 Web UI展示监控执行信息 TaskManager包含task slots
  • 技术人员要拿百万年薪,必须要经历这9个段位

    很多人都问 技术人员如何成长 每个阶段又是怎样的 如何才能走出当前的迷茫 实现自我的突破 所以我结合我自己10多年的从业经验 总结了技术人员成长的9个段位 希望对大家的职业生涯 有所帮助 1 刚接触编程的时候 会觉得这是个很神奇东西 平淡的
  • 认识计算机性能指标

    计算机性能指标 存储器的容量 MAR 的位数反应存储单元的数量 MDR 的位数反应每个存储单元的大小 cpu性能指标 高电平1代表1个数字脉冲 低电平0也代表1个数字脉冲 1个cpu时钟周期 1个数字脉冲信号 通常单位微秒 纳秒 cpu主频
  • Python 的简洁表达:for语句,if语句,3变量值互换

    Python 语句遵循的是简洁为美的原则 所以有很多表达方式非常简洁 同时在熟练以后也不会牺牲可读性 一 for 语句 比如我们要求 n 2 n 2 n2 的值的列表 其中 n n
  • IntelliJ Plugin-Gradle 配置

    Step 1 使用Gradle构建IntelliJ plugin工程 Step 2 调整配置信息 plugins id java id org jetbrains intellij version 0 4 8 group xxx versi
  • 模拟人脑:迄今最大规模4个实验,人工智能的救赎之路?(附PDF公号发“模拟人脑”下载)

    模拟人脑 迄今最大规模4个实验 人工智能的救赎之路 附PDF公号发 模拟人脑 下载 许铁 科学Sciences 今天 科学Sciences导读 公众最早了解模拟大脑的事件是 1997年 电脑 深蓝 击败世界象棋冠军 2011年 计算机 沃森
  • Windows PostgreSql创建服务

    一 创建服务 使用管理员cmd命令窗口在bin目录下 执行命令 pg ctl exe register N 服务名称 D 安装data数据目录 二 删除服务 执行命令 sc delete 服务名 三 启动服务 执行命令 sc start 服
  • 基于LinuxC语言实现的TCP多线程/进程服务器

    多进程并发服务器 设计流程 框架一 使用信号回收僵尸进程 void handler int sig while waitpid 1 NULL WNOHANG gt 0 int main 回收僵尸进程 siganl 17 handler 创建
  • 三阶魔方中心互换_三阶魔方入门

    一 魔方的构造 这里只讲常见的普通三阶魔方 三阶魔方一共有26个色块 分三个层 从上到下分别为顶层 中间层 底层 26个色块按位置分为中心块 角色块 棱色块 中心块6个 角色块8个 棱色块12个 中心块为每一个面最中央的色块 角色块为每一条
  • electron使用new Worker写入文件导致浏览器崩溃

    main js let data1 let data2 for let i 0 i lt 500000 i let j i 500 0 60000 0 data1 push j 200 Math random 100 data2 push
  • git下载别人的代码

    1 打开别人github上的源码地址 点击Clone or download 2 拷贝链接 3 通过git clone URL来下载 此外 还可以通过pwd来查看当前目录的路径 一般都是下载到当前目录下 注意 前提是自己的github上已添
  • 【剑指offer】数据结构——树

    目录 数据结构 树 直接解 剑指offer 07 重建二叉树 剑指offer 08 二叉树的下一个结点 剑指offer 26 树的子结构 剑指offer 27 二叉树的镜像 剑指offer 28 对称的二叉树 剑指offer 32 1 从上
  • Opencv中circle(),line(),cv2.rectangle(),cv2.putText()

    Opencv中circle line cv2 rectangle cv2 putText 一 circle 画圆 cv2 circle 方法用于在任何图像上绘制圆 用法 cv2 circle image center radius colo
  • 【OCR文本识别系列】Read Like Humans: Autonomous, Bidirectional and Iterative Language Modeling for Scene Tex

    read like humans 是中科大在2021年发在CVPR上的论文 论文链接 链接 代码链接 链接 视觉模型 class BaseVision Model def init self config super init config