Hugging Face——MLM预训练掩码语言模型方法

2023-11-06

对于许多涉及 Transformer 模型的 NLP 程序, 我们可以简单地从 Hugging Face Hub 中获取一个预训练的模型, 然后直接在你的数据上对其进行微调, 以完成手头的任务。只要用于预训练的语料库与用于微调的语料库没有太大区别, 迁移学习通常会产生很好的结果。

但是, 在某些情况下, 你需要先微调数据上的语言模型, 然后再训练特定于任务的head。

这种在域内数据上微调预训练语言模型的过程通常称为 领域适应。 它于 2018 年由 ULMFiT推广, 这是使迁移学习真正适用于 NLP 的首批神经架构之一 (基于 LSTM)。 下图显示了使用 ULMFiT 进行域自适应的示例; 在本节中, 我们将做类似的事情, 但使用的是 Transformer 而不是 LSTM!
在这里插入图片描述

如何训练?

加载模型

依托于Hugging Face,根据提供的API,选择AutoModelForMaskedLM用于加载模型:

from transformers import AutoModelForMaskedLM

model_checkpoint = "Hub中的仓库/模型名称"
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)

AutoModelForMaskedLM在源码中已经默认给配置好了MLM的 Head,这里以BERT为例——BertForMaskedLM

源码中可以先锁定到forward方法中:

def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)  # 这是相比于BERTModel不同的地方

代码中self.cls是这样定义的:

self.cls = BertOnlyMLMHead(config)

代码中BertOnlyMLMHead是这样定义的:

class BertOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

继续套娃,BertLMPredictionHead是这样定义的:

class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

终于破案了

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Hugging Face——MLM预训练掩码语言模型方法 的相关文章

随机推荐

  • Vue监听滚动实现导航栏锚点定位

    父级 div div div div div div div div div 子组件 ul class compreDiagnosisInfoUl li class active li ul div
  • 变分自编码器 (Variational Autoencoders, VAEs)

    Contents Dimensionality reduction PCA and autoencoders Dimensionality reduction Principal components analysis PCA Autoen
  • VMware Workstation克隆虚拟机(CentOS系统)

    问题 VMware Workstation克隆虚拟机 CentOS系统 下面具体说明下 如何在VMware Workstation中克隆一台已有的虚拟机 方法 如下图所示 要克隆名叫 CentOS7 base的一台虚拟 右键该虚拟机选择 管
  • userManager.do不可用问题

    dao层和业务层都可以成功添加 但在页面上调用Servlet显示不可用 原因 userAdd jsp被放到了web的子目录usermanager下面了 所以定位不到servlet资源了 修改 userManager do 成功解决问题
  • 怎么查EI论文的检索号

    论文题目 Study on joint probability density algorithm in multi sensor data fusion 那位热心人帮忙查一下检索号是多少啊 去http www engineeringvil
  • OpenCV调用cv2.imshow显示错误 “The function is not implemented. Rebuild the library with Windows”的解决办法

    在Windows环境下 已经安装了opencv python 读取图片 处理都没有问题 唯独显示就会出错 说 The function is not implemented Rebuild the library with Windows
  • Android 蓝牙开发(六)hfp连接

    转载请注明出处 http blog csdn net vnanyesheshou article details 71106622 本文已授权微信公众号 fanfan程序媛 独家发布 扫一扫文章底部的二维码或在微信搜索 fanfan程序媛
  • nginx+php 出现404错误解决方法

    http www 51ou com browse linuxwt 32263 html 错误日志 装好 nginx 1 0 5 与 php 5 3 6 php fpm 迫不及待的测试 info php 但是只返回了空白页 什么也没有输出 以
  • 关于解决Linux(ubuntu) 中不允许root用户ssh远程登录的问题

    当我们在ubuntu中登录ssh的时候 会出现如下问题 是因为系统默认禁止root用户登录ssh 此时我们可以这样解决 1 首先 按Ctrl C退出密码输入界面 2 然后输入 su 一定是su 不是su 3 编辑sshd config文件
  • VS快捷键大全(超详细)

    本文主要介绍VS编译器下的快捷键 文章目录 1 项目相关的快捷键 2 编辑相关的键盘快捷键 3 导航相关的键盘快捷键 4 调试相关的键盘快捷键 5 搜索相关的键盘快捷键 1 项目相关的快捷键 Ctrl Shift B 生成项目 Ctrl A
  • 2020.11.14 数组的相对排序

    2020 11 14 数组的相对排序 题目描述 给你两个数组 arr1 和 arr2 arr2 中的元素各不相同 arr2 中的每个元素都出现在 arr1 中 对 arr1 中的元素进行排序 使 arr1 中项的相对顺序和 arr2 中的相
  • CORE-ESP32C3

    目录 参考博文 项目官方地址 显示效果 硬件准备 软件版本 日志及soc下载工具 软件使用 接线示意图 硬件接线 一 Elink驱动管脚适配 二 天气信息获取 API使用方式 接口格式 注意需不需要tls http apicn luatos
  • Error[Pe147]: declaration is incompatible with "__nounwind __interwork __softfp unsigned long __get_

    原文地址 http www emcu it ARM Compiler IAR IAR tips and tricks html IAR tips and tricks Home Page STM32 home page CMSIS buil
  • vue项目兼容IE11

    1 npm安装babel polyfill npm install babel polyfill save dev 2 在入口文件main js中引入 import babel polyfill 3 如果也是用了官方脚手架vue cli 还
  • Tomcat项目500报错处理方法之一

    今天做的项目出现了Tomcat500错误 根据错误提示是 java lang ClassNotFoundException 搜索了很多解决方法 最终找到一个解决方案 https blog csdn net u011008029 articl
  • 类#是公共的,应在名为#.java的文件中声明

    1 如果类A被声明为公共的 public 那么必须将类A保存在名为A java的文件中 2 反之 在一个文件中最多包含一个顶级的公共类 并且该公共类的名字与文件名相同 比如文件A java中 允许定义一个或多个类 但最多允许一个顶级的公共类
  • 又一波Microsemi招聘信息

    Position Manager ASIC Design Business Unit ESC PerformanceStorage Location Shanghai China Youwill build and lead an IC d
  • Deeplabcut----(3)新建自己的训练(多只动物)

    多动物的标注比单动物复杂了N倍 动物回来会穿梭 以至于一时不知道是哪只 需要对比着视频观看找对应的 1 打开deeplabcut 新建训练 2 编辑配置文件设置动物数量 关节 骨架等 单击Edit config file开始编辑 可以按照自
  • RT-DETR原理与简介(干翻YOLO的最新目标检测项目)

    概述与简介 RT DETR是一种实时目标检测模型 它结合了两种经典的目标检测方法 Transformer和DETR Detection Transformer Transformer是一种用于序列建模的神经网络架构 最初是用于自然语言处理
  • Hugging Face——MLM预训练掩码语言模型方法

    对于许多涉及 Transformer 模型的 NLP 程序 我们可以简单地从 Hugging Face Hub 中获取一个预训练的模型 然后直接在你的数据上对其进行微调 以完成手头的任务 只要用于预训练的语料库与用于微调的语料库没有太大区别