bert结构模型的转换及[unusedxx]的不拆token

2023-11-16

前沿

业界主流的模型结构包括tensorflow和pytorch,很多时候两者的模型需要转换成中间格式,比如onnx,另外在tokenized的时候需要保留[unusedx]不被分词,但默认的是会分词的,这里记录一下处理方式。

torch格式转onnc

torch转onnx方法很多,这里介绍两种方式

方法1

# model_path为torch保存的文件,onnx_path为保存的文件路径
def lower_level(model_path, onnx_path="bert_std.onnx"):
    # load model and tokenizer
    added_token = ["[unused%s]" % i for i in range(100)]
    print("added_token:", added_token[:10])
    tokenizer = AutoTokenizer.from_pretrained(model_path, additional_special_tokens=added_token)
    dummy_model_input = tokenizer("hello bert", return_tensors="pt")
    unused_input = tokenizer("hello bert[unused17]", return_tensors="pt")

    print("dummy_model_input", dummy_model_input)
    print("unused_input:", unused_input)
    model = AutoModelForMaskedLM.from_pretrained(model_path)

    # export
    torch.onnx.export(
        model,
        tuple(dummy_model_input.values()),
        f=onnx_path,
        input_names=['input_ids', 'attention_mask'],
        output_names=['logits'],
        dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'},
                      'attention_mask': {0: 'batch_size', 1: 'sequence'},
                      'logits': {0: 'batch_size', 1: 'sequence'}},
        do_constant_folding=True,
        opset_version=13,
    )
    print("over")

方法2

def middle_level(model_path, onnx_path="bert_std.onnx"):
    from pathlib import Path
    import transformers
    from transformers.onnx import FeaturesManager
    from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

    # load model and tokenizer
    feature = "sequence-classification"
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # load config
    model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature=feature)
    onnx_config = model_onnx_config(model.config)

    # export
    onnx_inputs, onnx_outputs = transformers.onnx.export(
        preprocessor=tokenizer,
        model=model,
        config=onnx_config,
        opset=13,
        output=Path(onnx_path)
    )
    print("onnx_inputs:", onnx_inputs)
    print("onnx_outputs:", onnx_outputs)
    print("over")

保留[unused9]不分词

transformers模块

在AutoTokenizer.from_pretrained增加additional_special_tokens参数,如:

    added_token = ["[unused%s]" % i for i in range(100)]
    print("added_token:", added_token[:10])
    tokenizer = AutoTokenizer.from_pretrained(model_path, additional_special_tokens=added_token)

完整代码如下:

def lower_level(model_path, onnx_path="bert_std.onnx"):
    # load model and tokenizer
    added_token = ["[unused%s]" % i for i in range(100)]
    print("added_token:", added_token[:10])
    tokenizer = AutoTokenizer.from_pretrained(model_path, additional_special_tokens=added_token)
    dummy_model_input = tokenizer("hello bert", return_tensors="pt")
    unused_input = tokenizer("hello bert[unused17]", return_tensors="pt")

    print("dummy_model_input", dummy_model_input)
    print("unused_input:", unused_input)
    model = AutoModelForMaskedLM.from_pretrained(model_path)
    print("over")

tensorflow模块

		preprocessor = hub.load(bert_preprocess_path)
		okenize = tfm.nlp.layers.BertTokenizer(vocab_file=vocab_path, lower_case=True, 
                tokenizer_kwargs=dict(preserve_unused_token=True, token_out_type=tf.int32))

        bert_pack_inputs = hub.KerasLayer(
            preprocessor.bert_pack_inputs,
            arguments=dict(seq_length=seq_length))  # Optional argument.
       encoder = TFAutoModel.from_pretrained(checkpoint_dir, from_pt=True)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

bert结构模型的转换及[unusedxx]的不拆token 的相关文章

随机推荐

  • Vue3 —— 使用Vite配置环境变量

    文章目录 一 为什么要配置环境变量 二 在Vite中配置环境变量 1 环境变量和模式 2 环境变量 3 生产环境替换 4 env 文件 总结 一 为什么要配置环境变量 在一个产品的前端开发过程中 一般来说会经历本地开发 测试脚本 开发自测
  • Spring boot定制个性化banner(七彩佛祖版)

    1 在项目的src main resources目录下创建banner txt文件 2 创建完成banner txt文件后 即可以在文件中放入需要自定义的任意字符图案 本次以佛祖图案为例 代码直接复制放到banner txt文件即可 不用做
  • 2021年蓝桥杯Python常见考点【持续更新ing】

    目录 一 常用技巧 一 输入输出 1 一行输入数值 2 多行输入 二 列表 1 存储多行 用 2 从多行数字 转变为二维列表 3 怎样将以下列表转化为整数 三 元组 四 集合 二 常见内置函数 一 itertools 二 数学函数 三 数据
  • Redis主从复制与Redis集群

    Redis主从复制与Redis集群 前言 一 主从复制 1 是什么 2 能干嘛 3 怎么玩 主从复制 4 新建redis conf配置文件 5 主从集群常用3种 1 主从模式一 一主二从 2 主从模式二 薪火相传 3 主从模式三 反客为主
  • Python循环的技巧

    Python的for循环是coder最常用的语句之一 如果只是简单地对容器循环遍历 那便会少了很多美好的体验 像下面这样 for i in range 10 print i python提供了很多用于循环的技巧 这些方法能让代码更加简洁美观
  • [Linux] 输入命令ls -laF后的各字段含义解析

    在登陆Ubuntu之后 我们切换超级管理用户root su root 然后切换到其所在的主目录 cd 然后以该目录下的所有文件以及文件夹为例进行介绍 我们输入命令查看该目录下面的所有文件以及文件夹 包括隐藏文件 ls laF 然后显示的内容
  • 回调函数使用

    https www cnblogs com shenwen p 9046482 html
  • 用c语言简单实现通讯录(详解和具体代码)

    前言 一 明确通讯录的功能 1 查找通讯录上的姓名 性别 电话和住址 2 可以增加 删除或修改相关信息 二 如何实现通讯录的功能 1 使用struct函数 2 实现通讯录的步骤 1 初始化通讯录并打印目录 2 实现增加信息与展示通讯录 3
  • [ 注意力机制 ] 经典网络模型2——CBAM 详解与复现

    Author Horizon Max 编程技巧篇 各种操作小结 机器视觉篇 会变魔术 OpenCV 深度学习篇 简单入门 PyTorch 神经网络篇 经典网络模型 算法篇 再忙也别忘了 LeetCode 注意力机制 经典网络模型2 CBAM
  • mysql导出数据为文本,MySQL 文本文件的导入导出数据的方法

    搜索热词 MysqL写入数据通常用insert语句 如 insert into person values 张三 20 李四 21 王五 70 但有时为了更快速地插入大批量数据或交换数据 需要从文本中导入数据或导出数据到文本 一 建立测试表
  • 【TensorFlow】TensorBoard的使用(一)

    概述 TensorBoard是一个可视化工具 它可以用来展示网络图 张量的指标变化 张量的分布情况等 特别是在训练网络的时候 我们可以设置不同的参数 比如 权重W 偏置B 卷积层数 全连接层数等 使用TensorBoader可以很直观的帮我
  • 关于spring integration jpa 使用druid 连接池 不可恢复问题排查

    背景 2023年6月10日 测试说生产环境报错 有个job 没执行 我打开服务就报如下错 却看不到代码在哪报错 由于比较忙 直接暴力重启了应用 问题解决 2023年6月17日 测试说生产环境报错 有个job 又没执行 依旧是如上的错 等我有
  • 使用exe4j打包exe

    首先 需要下载一个exe4j的软件 网址 http www softpedia com get Authoring tools Setup creators exe4j shtml 现在主要说一下怎么打exe的过程 1 打开安装好的exe4
  • unity新动画系统之IK动画

    国际惯例 先来一段说明 IK动画全称Inverse Kinematics 即反向动力学 牵一发而动全身的既视感 代码如下 using System Collections using System Collections Generic u
  • Mac使用工具tree,打印项目目录树到Markdown

    主要使用tree这个工具 安装方法 brew install tree 使用方法是 tree 参数 目录 常用方法 显示当前目录及子目录结构 tree 只显示目录 不显示文件 tree d 保存打印的结果到文件 tree gt my pro
  • Python安装包的三种方式: pip在线安装、setup.py安装、whl文件安装

    之前在自己电脑上一直用 pip instal xx 来安装python的包 后来因为公司电脑的网络连接限制 无法通过正常联网的方式安装 所以总结了几种在线 或 离线安装包的方式 具体如下 在线安装 pip install xx 正常在线安装
  • Android android:configChanges的简介

    AndroidManifest xml 文件中 在声明Activity时 会有这样一个属性设置 即 android configChanges 现在就来简单介绍下吧 程序在运行时 一些设备的配置可能会改变 如 横竖屏的切换 键盘的可用性等
  • cannot find -lstdc++解决方案

    今天在ubuntu12 10 64位下编译32位android 4 04源码时报错 usr bin ld skipping incompatible usr lib gcc x86 64 linux gnu 4 5 4 libstdc so
  • Linux:进程(概念)

    学习目标 1 认识冯诺依曼系统 2 认识操作系统概念与定位 系统调用接口 3 理解进程的概念 PCB 4 理解进程的状态 fork创建进程 僵尸进程及孤儿进程 5 了解进程的调度 优先级 竞争性 独立性 并行 并发 6 理解环境变量 熟悉常
  • bert结构模型的转换及[unusedxx]的不拆token

    这里写自定义目录标题 前沿 torch格式转onnc 方法1 方法2 保留 unused9 不分词 transformers模块 tensorflow模块 前沿 业界主流的模型结构包括tensorflow和pytorch 很多时候两者的模型