预训练 Transformer 模型的配置更改

2024-03-28

我正在尝试为重整变压器实现一个分类头。分类头工作正常,但是当我尝试更改配置参数之一 - config.axis_pos_shape 即模型的序列长度参数时,它会抛出错误;

Reformer.embeddings.position_embeddings.weights.0 的大小不匹配:从检查点复制形状为 torch.Size([512, 1, 64]) 的参数,当前模型中的形状为 torch.Size([64, 1, 64] )。 Reformer.embeddings.position_embeddings.weights.1 的大小不匹配:从检查点复制形状为 torch.Size([1, 1024, 192]) 的参数,当前模型中的形状为 torch.Size([1, 128, 192] )。

配置:

{
  "architectures": [
    "ReformerForSequenceClassification"
  ],
  "attention_head_size": 64,
  "attention_probs_dropout_prob": 0.1,
  "attn_layers": [
    "local",
    "lsh",
    "local",
    "lsh",
    "local",
    "lsh"
  ],
  "axial_norm_std": 1.0,
  "axial_pos_embds": true,
  "axial_pos_embds_dim": [
    64,
    192
  ],
  "axial_pos_shape": [
    64,
    256
  ],
  "chunk_size_feed_forward": 0,
  "chunk_size_lm_head": 0,
  "eos_token_id": 2,
  "feed_forward_size": 512,
  "hash_seed": null,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.05,
  "hidden_size": 256,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": true,
  "layer_norm_eps": 1e-12,
  "local_attention_probs_dropout_prob": 0.05,
  "local_attn_chunk_length": 64,
  "local_num_chunks_after": 0,
  "local_num_chunks_before": 1,
  "lsh_attention_probs_dropout_prob": 0.0,
  "lsh_attn_chunk_length": 64,
  "lsh_num_chunks_after": 0,
  "lsh_num_chunks_before": 1,
  "max_position_embeddings": 8192,
  "model_type": "reformer",
  "num_attention_heads": 2,
  "num_buckets": [
    64,
    128
  ],
  "num_chunks_after": 0,
  "num_chunks_before": 1,
  "num_hashes": 1,
  "num_hidden_layers": 6,
  "output_past": true,
  "pad_token_id": 0,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 100
    }
  },
  "vocab_size": 320
}

Python代码:

config = ReformerConfig()
config.max_position_embeddings = 8192
config.axial_pos_shape=[64, 128]

#config = ReformerConfig.from_pretrained('./cnp/config.json', output_attention=True)

model = ReformerForSequenceClassification(config)
model.load_state_dict(torch.load("./cnp/pytorch_model.bin"))

我遇到了同样的问题,尝试将 Reformer 预训练中使用的默认最大序列长度 65536 (128*512) 的大小减半。

正如@cronoik 提到的,你必须:

  1. 负载预训练塑身机
  2. 通过删除不必要的重量来调整其大小以满足您的需要
  3. 保存这个新模型
  4. 加载这个新模型来执行您想要的任务

这些不必要的权重是来自位置嵌入层的权重。在 Reformer 模型中,使用轴向位置编码策略来学习位置嵌入(而不是像 BERT 这样的固定嵌入)。轴向位置编码以一种内存有效的方式存储位置嵌入,使用两个小张量而不是一个大张量。

然而,位置嵌入的思想仍然完全相同,即为每个位置获得不同的嵌入。

也就是说,理论上(如果我在某个地方误解了,请纠正我),删除最后一个位置嵌入以匹配您的自定义最大序列长度不会损害性能。你可以参考这个帖子来自 HuggingFace https://huggingface.co/blog/reformer查看轴向位置编码的更详细描述并了解在哪里截断位置嵌入张量。

我已成功调整大小并使用自定义最大长度为 32768 (128*256) 的 Reformer,代码如下:

# Load intial pretrained model
model = ReformerForSequenceClassification.from_pretrained('google/reformer-enwik8', num_labels=2)

# Reshape Axial Position Embeddings layer to match desired max seq length       
model.reformer.embeddings.position_embeddings.weights[1] = torch.nn.Parameter(model.reformer.embeddings.position_embeddings.weights[1][0][:256])

# Update the config file to match custom max seq length
model.config.axial_pos_shape = 128, 256
model.config.max_position_embeddings = 128*256 # 32768

# Save model with custom max length
output_model_path = "path/to/model"
model.save_pretrained(output_model_path)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

预训练 Transformer 模型的配置更改 的相关文章

随机推荐

  • 在 php 中集成 payfort api 时遇到问题

    我正在关注 https docs start payfort com references api https docs start payfort com references api 实施 Payfort 付款 Api 的文档 但我遇到
  • Kendo 自动完成显示两个建议列表

    我的 Kendo 自动完成控件成功检索 Json 列表 不幸的是 它调用了 MVC 控制器方法两次并创建了两个建议列表 重复列表直接显示在第一个列表后面 当从第一个建议列表中选择一个值时 该列表会消失 但重复列表仍然可见 我正在使用自动完成
  • Angular“=”范围不适用于驼峰命名法

    我是指令的范围属性 我使用时效果很好show作为属性名称 span span
  • 如何在 pip 安装期间编译 C++ 依赖项?

    我想让我的 python 代码可以使用 pip 但是 我的代码依赖于另一个不可 pip 的库 所以 当用户调用时我需要以某种方式编译源代码pip install 我怎样才能做到这一点 我无法通过简单的谷歌搜索找到好的参考资料 我建议看看 l
  • 使用请求对象 Flask 获取 json 响应

    网络服务 app route get details def getDetails cur execute select from employee rows cur fetchall columns desc 0 for desc in
  • CSS / HTML 导航和徽标位于同一行

    我不知道如何将它们放在同一条线上 http codepen io anon pen dovZdQ http codepen io anon pen dovZdQ div class navigation bar div img src lo
  • 如何修复/调整 ggplot geom_tile 中每个带的宽度

    这是我的问题的示例数据 sampledata lt matrix c 1 60 1 60 rep 0 1 each 60 sample 1 3 120 replace T ncol 3 colnames sampledata lt c Ti
  • 如何进行递归子文件夹搜索并返回列表中的文件?

    我正在编写一个脚本 以递归方式遍历主文件夹中的子文件夹并构建特定文件类型的列表 我的脚本有问题 目前设置如下 for root subFolder files in os walk PATH for item in files if ite
  • Jquery AJAX:服务器端验证失败时如何显示Flash错误消息?

    我正在使用 Jquery 表单插件通过 ajax 提交表单 我已经在我的模型中的服务器端设置了验证 现在 当验证失败时 我想使用 ajax 向用户显示相同的 flash error 消息 如果验证成功 我可以显示 flash notice
  • Xcode 卡在索引上

    我已经工作了两个月的项目无缘无故停止工作 因为 Xcode 卡在 索引 上 我无法再构建该项目了 如果我尝试构建 Xcode 就会冻结 我必须强制退出 这种情况仅发生在该项目中 我尝试清理所有派生数据 但没有帮助 我正在使用 Xcode 4
  • jqgrid 更改单元格值并保持编辑模式

    我在网格中使用内联编辑 在某些情况下我想更改列内单元格的值 我用 setCell 更改它 效果很好 我的问题是 更改后 单元格失去了编辑模式 而该行的所有其他单元格都处于编辑模式 我想在更改单元格后将其保持在编辑模式 现在我所做的是保存该行
  • 种子中的 DHT

    我正在编写一个 P2P 实现 我希望将其去中心化 然而我在掌握如何做时遇到了一些困难DHT https en wikipedia org wiki Distributed hash table在像 BitTorrent 这样的协议中是有效的
  • 帮助正确计算atan2

    我需要计算线之间的角度 我需要计算atan 所以我正在使用这样的代码 static inline CGFloat angleBetweenLinesInRadians2 CGPoint line1Start CGPoint line1End
  • python中“追加”和“+”有什么区别? [复制]

    这个问题在这里已经有答案了 我不知道有什么区别f and g 功能中f 每当调用函数时 列表 L 就会累积 但在功能上g 它不是 def f a L L append 2 print L def g a L L L 2 print L pr
  • SQL Server 2008中的递归同表查询

    我在 SQL Server 2008 数据库中有下表 Id Name ParentFolder 1 Europe NULL 2 Asia NULL 3 Germany 1 4 UK 1 5 China 2 6 India 2 7 Scotl
  • echo 函数跳转到 Div 之外

    我创建了一个用于 gettext 翻译的函数 该函数位于头文件中 function ex text echo gettext text 当我使用函数 ex 时它会翻译该函数中的任何文本 效果很好 尽管当我在另一个内部有 div 的函数中使用
  • 使用 Apache Lucene 对 MySQL 数据库建立索引,并保持它们同步

    当MySQL中添加一个新项目时 它也必须被Lucene索引 当现有项目从 MySQL 中删除时 它也必须从 Lucene 的索引中删除 这个想法是编写一个脚本 通过调度程序 例如 CRON 任务 每 x 分钟调用一次 这是保持 MySQL
  • 简单的 Perl websocket 客户端

    我正在尝试用 Perl 编写一个简单的 websocket 客户端 use Protocol WebSocket Client my client Protocol WebSocket gt new url gt ws myserver p
  • 使用多核的 Numpy np.einsum 数组乘法

    我用MKL编译了numpy 1 6 2和scipy 希望有更好的性能 目前我有一个严重依赖 np einsum 的代码 并且我被告知 einsum 不适用于 MKL 因为几乎没有矢量化 所以我想用 np dot 和切片重新编写一些代码 只是
  • 预训练 Transformer 模型的配置更改

    我正在尝试为重整变压器实现一个分类头 分类头工作正常 但是当我尝试更改配置参数之一 config axis pos shape 即模型的序列长度参数时 它会抛出错误 Reformer embeddings position embeddin