Bert的MLM任务loss原理

2023-11-10

         bert预训练有MLM和NSP两个任务,其中MLM是类似于“完形填空”的方式,对一个句子里的15%的词进行mask,通过双向transformer+feedforward+rediual_add+layer_norm完成对每个词的embedding编码,然后对mask的这个词进行预测,预测过程相当于做多分类,类别的个数是词汇的总个数,将mask的词的emb经过MLP变换生成在每个类别词汇上的logits 概率,label是mask位置上真实词在整个词汇上的one-hot编码,将logits和label计算交叉熵,又做了加权平均,即可得出MLM的loss,过程如下:

         源码中的get_masked_lm_output()方法过程解析:

1、输入input_tensor:[batch,maskednums, embed_size]

2、经过线性变换+layernorm:[batch,maskednums, 768]

3、logits:将embedding table[3万,768]作为变换矩阵,计算logits:[batch,maskednums, 3万],相当于得出每个被盖住词在3万个词上的概率,其实就是3万个类别多分类

4、labels:one-hot编码[maskednums,3万]

5、计算交叉熵:[bactch, maskednums]

6、loss:加权平均得出一个实数

def get_masked_lm_output(bert_config, input_tensor, output_weights, positions,
                         label_ids, label_weights):
  """Get loss and log probs for the masked LM."""
  input_tensor = gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/predictions"):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=bert_config.hidden_size,
          activation=modeling.get_activation(bert_config.hidden_act),
          kernel_initializer=modeling.create_initializer(
              bert_config.initializer_range))
      input_tensor = modeling.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[bert_config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])

    one_hot_labels = tf.one_hot(
        label_ids, depth=bert_config.vocab_size, dtype=tf.float32)

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

  return (loss, per_example_loss, log_probs)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Bert的MLM任务loss原理 的相关文章

  • Zabbix整合钉钉实战

    1 基本告警思路 钉钉类似于微信 但是偏向于办公方向 可以通过钉钉的群机器人功能 实现将告警信息通过机器人发送到钉钉群 群机器人是钉钉群的高级扩展功能 群机器人可以将第三方服务的信息聚合到群聊中 实现自动化的信息同步 2 添加钉钉机器人 首
  • YOLOV5学习笔记(一)——原理概述

    目录 一 目标检测概述 1 1 数据集介绍 1 2 性能指标 1 2 1 混淆矩阵 1 2 2 IOU 边界框回归 1 2 3 AP mAP 1 2 4 检测速度 1 3 YOLO发展史 1 3 1 算法思想 1 3 2 YOLOv5网络架
  • NLP学习(九)文本向量化doc2vec及案例实现-Python3实现

    doc2vec基本原理 A distributed memory model 训练句向量的方法和词向量的方法非常类似 训练词向量的核心思想就是说可以根据每个单词的上下文预测 也就是说上下文的单词对是有影响的 那么同理 可以用同样的方法训练d
  • opengl库中gl glu glut glaux的区别

    GL GLU GLUT的关系 OpenGL中的gl库是核心库 glu是实用库 glut是实用工具库 gl是核心 glu是对gl的部分封装 glut是OpenGL的跨平台工具库 gl中包含了最基本的3D函数 而glu似乎对gl的辅助 如果算数
  • AIGC零基础30天学习——CLIP与zero-shot

    在前面对CLIP的学习中 对zero shot prediction环节一直有一些疑惑 zero shot是什么 它该如何进行操作 1 zero shot是什么 zero shot是指零样本学习 和zero shot相关联的概念包括many
  • 用java写一个简单的考勤管理系统

    我可以为您提供一个参考 您可以使用Java语言来编写一个简单的考勤管理系统 具体的步骤如下 1 定义考勤类 它包括考勤日期 考勤时间以及考勤状态 2 定义考勤管理类 实现考勤的添加 删除 查看等操作 3 定义用户类 它包括用户的姓名 职位
  • GO语言网络编程(并发编程)并发介绍,Goroutine

    GO语言网络编程 并发编程 并发介绍 Goroutine 1 并发介绍 进程和线程 A 进程是程序在操作系统中的一次执行过程 系统进行资源分配和调度的一个独立单位 B 线程是进程的一个执行实体 是CPU调度和分派的基本单位 它是比进程更小的

随机推荐

  • 深入源码分析Spring boot 集成Pagehelper

    引入依赖
  • Unity 代码实现多个Image帧动画播放

    using UnityEngine using System Collections using System Collections Generic using UnityEngine UI using System RequireCom
  • 小米9008刷机授权补丁_学会手机刷机这几种方法,这些问题都可以迎刃而解

    智能手机bug很多 尤其是安卓系统的手机 不仅玩游戏卡 运行慢 有时候手机无法正常开机 或者是无法开机 一些功能不能使用 有的是手机系统造成的 只要通过给手机刷机 这些问题都可以迎刃而解 很多人刷机一般都是去手机维修店 但是你看完这篇文章
  • golang中strings.split的使用,分割

    package main import fmt strings func main fmt Printf q n strings Split a b b fmt Printf q n strings Split a boy a girl a
  • 图技术在 LLM 下的应用:知识图谱驱动的大语言模型 Llama Index

    LLM 如火如荼地发展了大半年 各类大模型和相关框架也逐步成型 可被大家应用到业务实际中 在这个过程中 我们可能会遇到一类问题是 现有的哪些数据 如何更好地与 LLM 对接上 像是大家都在用的知识图谱 现在的图谱该如何借助大模型 发挥更大的
  • Jenkins构建(8):Jenkins 执行远程shell :Send files or execute commands over SSH

    Jenkins 执行远程shell Send files or execute commands over SSH 一 远程执行shell命令 python脚本 1 环境配置 管理Jenkins gt Configure System 模块
  • idea 国内插件库_IDEA 超实用使用技巧分享(长篇)

    前言 工欲善其事 必先利其器 最近受部门的邀请 给入职新人统一培训IDEA 发现有很多新人虽然日常开发使用的是IDEA 但是还是很多好用的技巧没有用到 只是用到一些基本的功能 蛮浪费IDEA这个优秀的IDE 同时 在这次分享之后 本人自己也
  • 排序算法——基数排序(C语言)

    基数排序的概念 什么是基数排序 基数排序是一种和快排 归并 希尔等等不一样的排序 它不需要比较和移动就可以完成整型的排序 它是时间复杂度是O K N 空间复杂度是O K M 基数排序的思想 基数排序是一种借助多关键字的思想对单逻辑关键字进行
  • python爬虫从零开始_python爬虫---从零开始(一)初识爬虫

    我们开始来谈谈python的爬虫 1 什么是爬虫 网络爬虫是一种按照一定的规则 自动地抓取万维网信息的程序或者脚本 另外一些不常使用的名字还有蚂蚁 自动索引 模拟程序或者蠕虫 互联网犹如一个大蜘蛛网 我们的爬虫就犹如一个蜘蛛 当在互联网遇到
  • 计算机网络mask是什么意思,mask是什么意思

    你知道mask是什么意思吗 可能你在网络上偶尔会看到这样的词 但网络上的新词多到数不清 根本没有时间去仔细去了解 下面就让我们带你一起 来详细了解一下mask是什么意思吧 mask是什么意思 假面具 伪装 遮蔽物 All guests wo
  • ppt拖动就复制_PPT快捷键丨这些快捷键可助你事半功倍

    工欲善其事 必先利其器 如果你常用的快捷键只有Ctrl C Ctrl V 那你要仔细看下这篇文章了 PS 这个键盘是PPT做的哦 后台回复 键盘 获取源文件 快捷键 顾名思义就是快和方便 所以能熟练使用PPT快捷键 会使我们变得更高效 桔子
  • Shiro和Spring Security对比

    一 Shiro简介 1 什么是Shiro Shiro是apache旗下一个开源框架 它将软件系统的安全认证相关的功能抽取出来 实现用户身份 认证 权限授权 加密 会话管理等功能 组成了一个通用的安全认证框架 2 Shiro 的特点 Shir
  • VMware虚拟机连不上网络,最详细排查解决方案

    虚拟机连不上网 ping某个网站时并显示此信息 ping www baidu com Name or service not known 步骤一 排查Windows自身问题 有可能这个问题不是你虚拟机有问题 而是装虚拟机的Windows本身
  • 【数据结构】数组和字符串

    本文是对leetbook 数组和字符串 学习完成后的总结 数组和字符串 数组简介 寻找数组的中心索引 搜索插入位置 合并区间 二维数组简介 旋转矩阵 零矩阵 对角线遍历 字符串简介 最长公共前缀 最长回文子串 翻转字符串里的单词 实现 st
  • 前端开发同步和异步的区别?

    在前端开发中 同步 一般指的是在代码运行的过程中 从上到下逐步运行代码 每一部分代码运行完成之后 下面的代码才能开始运行 异步 指的是当我们需要一些代码在执行的时候不会影响其他代码的执行 也就是在执行代码的同时 可以进行其他的代码的执行 不
  • 转:安装MySQL遇到MySQL Server Instance Configuration Wizard未响应的解决办法

    问题 安装了MySQL之后进入配置界面的时候 总会显示 MySQL Server Instance Configuration Wizard未响应 一直卡死 解决办法 Win7系统中 以管理员的权限登录系统 将C盘的ProgramData中
  • postman接口测试要点及错误总结

    本文主要针对接口测试工具postman出现的常见错误及解决办法进行了总结 请求分类及具体传参介绍 GET请求 GET请求是最常见的请求类型 最常用于向服务器查询信息 必要时 可以将查询字符串参数追加到URL的末尾 以便将信息发送给服务器 P
  • 机器学习的特征工程

    机器学习的特征工程 一 数据集 Kaggle网址 https www kaggle com datasets UCI数据集网址 http archive ics uci edu ml scikit learn网址 http scikit l
  • 蓝桥杯-基础训练-龟兔赛跑预测

    问题描述 话说这个世界上有各种各样的兔子和乌龟 但是研究发现 所有的兔子和乌龟都有一个共同的特点 喜欢赛跑 于是世界上各个角落都不断在发生着乌龟和兔子的比赛 小华对此很感兴趣 于是决定研究不同兔子和乌龟的赛跑 他发现 兔子虽然跑比乌龟快 但
  • Bert的MLM任务loss原理

    bert预训练有MLM和NSP两个任务 其中MLM是类似于 完形填空 的方式 对一个句子里的15 的词进行mask 通过双向transformer feedforward rediual add layer norm完成对每个词的embed