NLP中BERT在文本二分类中的应用

2023-11-20

最近参加了一次kaggle竞赛Jigsaw Unintended Bias in Toxicity Classification,经过一个多月的努力探索,从5月20日左右到6月26日提交最终的两个kernel,在public dataset上最终排名为4%(115/3167),说实话以前也并没有怎么接触过NLP方面的东西,对深度学习的理解也不是特别深刻。
BERT是目前非常火的NLP模型,采用两段式的训练方式,分为pretrain和fine-tune两部分,pretrain部分由谷歌在TPU集群上训练完成,并给出部分模型供免费下载使用。其中包括’cased_l-24_h-1024_a-16’, ‘chinese_l-12_h-768_a-12’, ‘uncased_l-12_h-768_a-12’, ‘uncased_l-24_h-1024_a-16’, ‘multi_cased_l-12_h-768_a-12’, ‘cased_l-12_h-768_a-12’
在这里插入图片描述
github上的下载地址为
https://github.com/google-research/bert
在本次比赛中主要采用的是BERT-Base,uncased的这个model,训练全数据180万左右的样本,样本长度设为220,在四卡Titan的GPU卡上训练的时间接近6小时,BERT-Large,uncased的训练时间大概是33小时左右。uncased和cased的区别在于uncased将全部样本变为小写,而cased则要区分大小写,将cased和uncased的模型训练结果进行融合也会有一定程度的提升。
比赛的主要代码是在一个以色列大佬的kernrl上进行修改,比赛结束时已经被fork了将近1595,代码地址
https://www.kaggle.com/yuval6967/toxic-bert-plain-vanila,非常感谢这位无私的大佬

代码详解

将样本文本转为成bert-format

def convert_lines(example, max_seq_length,tokenizer):
    max_seq_length -=2
    all_tokens = []
    longer = 0
    for text in tqdm_notebook(example):
        tokens_a = tokenizer.tokenize(text)
        if len(tokens_a)>max_seq_length:
            tokens_a = tokens_a[:max_seq_length]
            longer += 1
        one_token = tokenizer.convert_tokens_to_ids(["[CLS]"]+tokens_a+["[SEP]"])+[0] * (max_seq_length - len(tokens_a))
        all_tokens.append(one_token)
    print(longer)
    return np.array(all_tokens)

应用BERT模型进行词嵌入,转换成sequences

bert_config = BertConfig('../input/bert-pretrained-models/uncased_l-12_h-768_a-12/uncased_L-12_H-768_A-12/'+'bert_config.json')
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH, cache_dir=None,do_lower_case=True)
train_df = pd.read_csv(os.path.join(Data_dir,"train.csv")).sample(num_to_load+valid_size,random_state=SEED)
train_df['comment_text'] = train_df['comment_text'].astype(str) 
sequences = convert_lines(train_df["comment_text"].fillna("DUMMY_VALUE"),MAX_SEQUENCE_LENGTH,tokenizer)
#此处可以对其进行保存,下次可以直接调用npy文件
#np.save('sequences.npy',sequences)
#sequences = np.load('../wql/wql_base_sequences.npy')

通过uncased_L-12_H-768_A-12文件下的bert_model.ckpt和bert_config.json文件,在working目录下生成pytorch_model.bin文件,并copy其中的bert_config.json到working目录下

BERT_MODEL_PATH = '../input/bert-pretrained-models/uncased_l-12_h-768_a-12/uncased_L-12_H-768_A-12/'
convert_tf_checkpoint_to_pytorch.convert_tf_checkpoint_to_pytorch(
    BERT_MODEL_PATH + 'bert_model.ckpt',
BERT_MODEL_PATH + 'bert_config.json',
WORK_DIR + 'pytorch_model.bin')
shutil.copyfile(BERT_MODEL_PATH + 'bert_config.json', WORK_DIR + 'bert_config.json')

实例化模型

train_dataset = torch.utils.data.TensorDataset(torch.tensor(X,dtype=torch.long), torch.tensor(y,dtype=torch.float))
#实例化模型,确保working目录下有bert_config.json文件和pytorch_model.bin文件,
model = BertForSequenceClassification.from_pretrained('./working/',cache_dir=None,num_labels=1)
#四卡并行计算的代码
devices = [0,1,2,3]
if len(devices) > 1:
    print('use multi gpus')
    model = nn.DataParallel(model, device_ids=devices)
    model = model.cuda(devices[0])
#定义权重衰减的参数
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
model=model.train()

设置bert训练的参数

lr=2e-5
#若batch_size大可能导致cuda memory out的问题,可以通过减低batchsize来解决,在训练large时,batchsize设过4
batch_size = 100
#用于参数更新的加速,设为2,参数更新的次数为1的1/2
accumulation_steps=2
save_steps = 1000
checkpoint = None
# 全数据训练的epoch次数
EPOCHS = 1
num_train_optimization_steps = int(EPOCHS*len(train) / batch_size / accumulation_steps)
optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=lr,  #在epoch2时可适当降低lr,例如取其1/2
                     warmup=0.05,   #当lr下降的时候,也可以降低warmup,例如epoch2时设为0
                     t_total=num_train_optimization_steps)

开始进行一个epoch的全样本数据训练,batch_size为100

tq = tqdm_notebook(range(EPOCHS))
for epoch in tq:
    #将avg_loss和avg_accuracy写入txt文件
    file_name = 'loss_log_' + 'epoch' + str(epoch) + '.txt'
    file = open(file_name, 'w', encoding='utf-8')
    
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)
    avg_loss = 0.
    avg_accuracy = 0.
    optimizer.zero_grad()   
    for (x_batch, y_batch) in tqdm(train_loader):
        y_pred = model(x_batch.cuda(), attention_mask=(x_batch>0).cuda(), labels=None)
        loss = nn.随意一个loss-function()(y_pred, y_batch.cuda())
        loss.backward()
        optimizer.step()  # Now we can do an optimizer step
        optimizer.zero_grad()
        avg_loss += loss.item() / len(train_loader)
        avg_accuracy += torch.mean(((torch.sigmoid(y_pred[:,0])>0.5) == (y_batch[:,0]>0.5).cuda()).to(torch.float) ).item()/len(train_loader)
        i += 1
        
        file.write('batch' + str(i) + '\t' + 'avg_loss' + '=' + str(avg_loss) + '\t' + 'avg_accuracy' + '=' + str(avg_accuracy) + '\n')
    file.close()
    
    file_path = output_model_file + str(epoch) +'.bin'
    #保存bert model用于迁移学习或测试
    torch.save(model.state_dict(), file_path)
    print(f'loss:{avg_loss} accuracy:{avg_accuracy}')

应用过程中的一点理解

就我目前对bert的理解和在比赛中调参的过程来浅谈一点经验,大佬贴出了bert模型的模板,几乎所有选手都采用这位大佬的模型作为原本进行改写,有几个重要点可以进行改写来提高预测的准确率

loss-function

首先是loss-function,这决定了这次比赛的成败,如果没法写出一个适合比赛的lossfunction,那么比赛最终成绩也就是可以fork到的最高分。虽然大佬在bert fine-tune中只给出了F.binary_cross_entropy_with_logits,但是在其他LSTM的kernel上却给出了下面这样比较优秀的loss

def custom_loss(data, targets):
     bce_loss_1 = nn.BCEWithLogitsLoss(weight=targets[:,1:2])(data[:,:1],targets[:,:1])
     bce_loss_2 = nn.BCEWithLogitsLoss()(data[:,2:],targets[:,2:])
     return (bce_loss_1 * loss_weight) + bce_loss_2

可以通过一定的补充,应用到bert模型中去,仅仅是这样一步就使得bert的预测准确率提高了很多

epoch

在对全样本数据进行fine-tune的过程中发现,训练到第三个epoch的时候,bert模型就发生了过拟合,通过小样本数据多次进行lr的调整之后也得到同样的结果,因此epoch的数量设置为2,比赛结束后一些高分单个bert的epoch数量也确实是2。

参数调整

在我看来fine-tune过程中可以调节的参数应该有以下几个,lr、warmup、dropout、accumulation_steps、batch_size、weight_decay,MAX_SEQUENCE_LENGTH

hidden_dropout_prob: The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention probabilities.

accumulation_steps减小从实践效果来说可以略微提升训练效果。
weight_decay 调整之后并没有收到准确率提升的效果
lr的话,第一个epoch为2e-5,第二个epoch为1e-5是相对较好的lr参数
batch_size对于结果影响并不大,越大的batchsize计算速度越快,但是对于gpu的显存要求更高
MAX_SEQUENCE_LENGTH 规定了sequence的长度,长度越长对显存要求也越高
warmup 一个高分的bert模型中,epoch0 lr为2e-5,warmup为0.05,epoch1 lr为1e-5,warmup为0,对结果有一定的提升
给出一高分的bert模型地址 https://www.kaggle.com/hanyaopeng/single-bert-base-with-0-94376
其实我对bert理解也只停留在模型的应用上,如有了解的大佬看到觉得不对的地方,希望指出错误。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

NLP中BERT在文本二分类中的应用 的相关文章

  • 波特 油炸的去梗

    为什么波特词干算法在线 http text processing com demo stem http text processing com demo stem stem fried to fri并不是fry 我不记得任何以以下结尾的单词
  • 在 python 中快速/优化 N-gram 实现

    python 中哪种 ngram 实现速度最快 我试图分析 nltk 与 scott 的 zip http locallyoptimal com blog 2013 01 20 elegant n gram Generation in py
  • 如何在keras中使用Bert作为长文本分类中的段落编码器来实现网络?

    我正在做一个长文本分类任务 文档中有超过 10000 个单词 我计划使用 Bert 作为段落编码器 然后将段落的嵌入逐步输入 BiLSTM 网络如下 输入 batch size max paragraph len max tokens pe
  • 有什么工具可以以编程方式将日语句子转换为其罗马字(语音阅读)? [关闭]

    Closed 这个问题是无关 help closed questions 目前不接受答案 Input 日本 好 Output 日本 ga sukidesu 遗憾的是 无法通过 Google Translate API 进行语音阅读 KAKA
  • 创建向量空间

    我有一个问题 我有很多文档 每一行都是由某种模式构建的 当然 我有这一系列的图案 我想创建一些向量空间 然后通过某种规则来向量这个模式 我还不知道这个规则是什么 即使这个模式像我的向量空间的 质心 然后向量当前文档的每一行 再次按照此规则
  • 如何获取与某个单词相关的相似单词?

    我正在尝试解决一个 nlp 问题 其中我有一个单词字典 例如 list 1 phone android chair netflit charger macbook laptop sony 现在 如果输入是 phone 我可以轻松地使用 in
  • 我应该如何使用 scikit learn 对以下列表进行矢量化?

    我想用 scikit 进行矢量化学习一个有列表的列表 我转到有训练文本的路径 我阅读了它们 然后我得到如下内容 corpus this is spam SPAM this is ham HAM this is nothing NOTHING
  • 在非单一维度 1 处,张量 a (2) 的大小必须与张量 b (39) 的大小匹配

    这是我第一次从事文本分类工作 我正在使用 CamemBert 进行二进制文本分类 使用 fast bert 库 该库主要受到 fastai 的启发 当我运行下面的代码时 from fast bert data cls import Bert
  • IOB 准确度和精密度之间的差异

    我正在使用命名实体识别和分块器对 NLTK 进行一些工作 我使用重新训练了分类器nltk chunk named entity py为此 我采取了以下措施 ChunkParse score IOB Accuracy 96 5 Precisi
  • 比较文本文档含义的最佳方法?

    我正在尝试找到使用人工智能和机器学习方法来比较两个文本文档的最佳方法 我使用了 TF IDF Cosine 相似度和其他相似度度量 但这会在单词 或 n gram 级别上比较文档 我正在寻找一种方法来比较meaning的文件 最好的方法是什
  • 使用我自己的训练示例训练 spaCy 现有的 POS 标记器

    我正在尝试在我自己的词典上训练现有的词性标注器 而不是从头开始 我不想创建一个 空模型 在spaCy的文档中 它说 加载您想要统计的模型 下一步是 使用add label方法将标签映射添加到标记器 但是 当我尝试加载英文小模型并添加标签图时
  • Blenderbot 微调

    我一直在尝试微调 HuggingFace 的对话模型 Blendebot 我已经尝试过官方拥抱脸网站上给出的传统方法 该方法要求我们使用 trainer train 方法来完成此操作 我使用 compile 方法尝试了它 我尝试过使用 Py
  • 否定句子的算法

    我想知道是否有人熟悉算法句子否定的任何尝试 例如 给定一个句子 这本书很好 请提供任意数量的意思相反的替代句子 例如 这本书不好 甚至 这本书不好 显然 以高精度实现这一点可能超出了当前 NLP 的范围 但我确信在这个主题上已经有了一些工作
  • 如何训练斯坦福 NLP 情感分析工具

    地狱大家 我正在使用斯坦福核心 NLP 包 我的目标是对推文直播进行情感分析 按原样使用情感分析工具对文本 态度 的分析非常差 许多积极因素被标记为中性 许多消极因素被评为积极 我已经在文本文件中获取了超过一百万条推文 但我不知道如何实际获
  • Node2vec 的工作原理

    我一直在读关于node2vec https cs stanford edu jure pubs node2vec kdd16 pdf嵌入算法 我有点困惑它是如何工作的 作为参考 node2vec 由 p 和 q 参数化 并通过模拟来自节点的
  • ANEW 字典可以用于 Quanteda 中的情感分析吗?

    我正在尝试找到一种方法来实施英语单词情感规范 荷兰语 以便使用 Quanteda 进行纵向情感分析 我最终想要的是每年的 平均情绪 以显示任何纵向趋势 在数据集中 所有单词均由 64 名编码员按照 7 分李克特量表在四个类别上进行评分 这提
  • gensim如何计算doc2vec段落向量

    我正在看这篇论文http cs stanford edu quocle paragraph vector pdf http cs stanford edu quocle paragraph vector pdf 它指出 段落向量和词向量被平
  • 使用正则表达式标记化进行 NLP 词干提取和词形还原

    定义一个函数 名为performStemAndLemma 它需要一个参数 第一个参数 textcontent 是一个字符串 编辑器中给出了函数定义代码存根 执行以下指定任务 1 对给出的所有单词进行分词textcontent 该单词应包含字
  • NLTK 中的 wordnet lemmatizer 不适用于副词 [重复]

    这个问题在这里已经有答案了 from nltk stem import WordNetLemmatizer x WordNetLemmatizer x lemmatize angrily pos r Out 41 angrily 这是 nl
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下

随机推荐

  • 国际软件项目经理的七大素质

    国际软件项目经理的七大素质 1 在一个或多个应用领域内使用整合了道德 法律和经济问题的工程方法来设计合适的解决方案 2 懂得确定客户需求并将其转换成软件需求的过程 3 履行项目经理的职责 善于处理技术和管理方面的事务 4 懂得并使用有用的项
  • 人脸特征点检测

    CVPR2016刚刚落下帷幕 本文对面部特征点定位的论文做一个简单总结 让大家快速了解该领域最新的研究进展 希望能给读者们带来启发 CVPR2016相关的文章大致可以分为三大类 处理大姿态问题 处理表情问题 处理遮挡问题 1 姿态鲁棒的人脸
  • 描述性能测试工作中的完整过程?

    有简单接触 采用的工具是Jmeter 进行轻量级的压力测试 1 确定好压力测试的功能模块 首先用Jmeter录制脚本 然后对脚本进行优化 2 对一些数据进行参数化 利用CSV导入存在txt文档里面的数据 3 设计测试场景 4 执行压力测试
  • 如何在windows的DOS窗口中正常显示中文(UTF-8字符)

    打开CMD exe命令行窗口 通过 chcp命令改变代码页 UTF 8的代码页为65001 ANSI OEM 简体中文 GBK为936 window default OEM 美国为437 如果chcp命令得到437 那么一定不能显示中文 此
  • 无法安装vmnet8虚拟网络适配器、vmware network editor未响应、注册失败,请检查账号数据库配置是否正确的解决

    文章目录 虚拟网络适配器安装 vmware network editor未响应 注册失败 请检查账号数据库配置是否正确的解决 关于第一次安装虚拟机的 全文约 423 字 预计阅读时长 2分钟 虚拟网络适配器安装 vmware network
  • rol/ror in c++

    template
  • 20天拿下华为OD笔试之【BFS】2023Q1A-微服务的集成测试【闭着眼睛学数理化】全网注释最详细分类最全的华为OD真题题解

    BFS 2023Q1A 微服务的集成测试 题目描述与示例 题目描述 现在有 n 个容器服务 服务的启动可能有一定的依赖性 有些服务启动没有依赖 其次服务自身启动加载会消耗一些时间 给你一个 nxn 的二维矩阵 useTime 其中 useT
  • simulink仿真adc采样和epwm输出基础知识讲解

    F28027 12位ADC 2的y次方 tbclk 计数时钟的频率 tprd 一个周期内记得个数 1 tbclk 每次计一个数的时间 一个pwm周期的时间 pwm的周期 时基计数器 CRT 计数时钟由系统时钟分频来的 比较寄存器 CMR 决
  • 大数据、数据分析和数据挖掘的区别

    大数据 数据分析 数据挖掘的区别是 大数据是互联网的海量数据挖掘 而数据挖掘更多是针对内部企业行业小众化的数据挖掘 数据分析就是进行做出针对性的分析和诊断 大数据需要分析的是趋势和发展 数据挖掘主要发现的是问题和诊断 1 大数据 big d
  • 软件项目管理的平衡原则和高效原则

    1 平衡原则 在我们讨论软件项目为什么会失败时 列出了很多的原因 答案有很多 如管理问题 技术问题 人员问题等等 但是 有一个根本的问题是最容易被忽视的 也是软件系统的用户 软件开发商 销售代理商最不愿证实的 那就是 需求 资源 工期 质量
  • 计算机网络 网络层——IP数据报 详记

    IP 数据报的格式 一个 IP 数据报由首部和数据两部分组成 首部的前一部分是固定长度 共 20 字节 是所有 IP 数据报必须具有的 在首部的固定部分的后面是一些可选字段 其长度是可变的 IP数据报首部的固定部分中的各字段 版本 占4位
  • 信号量机制

    简介 信号量是一种数据结构 信号量的值与相应资源的使用情况有关 信号量的值由P V操作改变 常用信号量 整型信号量 整型信号量S的等待 唤醒机制 P V操作 wait S while S lt 0 do no op s signal S S
  • python字符串与列表

    字符串 字符串定义 输入输出 定义 切片是指对操作的对象截取其中一部分的操作 适用范围 字符串 列表 元组都支持切片操作 切片的语法 起始下标 结束 步长 字符串中的索引是从 0 开始的 最后一个元素的索引是 1 字符串的常见操作 查找 f
  • centos7搭建ftp服务器及ftp配置讲解

    ftp 即文件传输 它是INTERNET上仍然常用的最老的网络协议之一 它为系统提供了通过网络与远程服务器传输的简单方法 FTP服务器包的名称为vsftpd 一 vsftpd安装 并简单配置启动 安装 很简单 一句话 yum install
  • Socket接收数据耗时

    1 遇到问题 首先说明一下我遇到的问题 服务端传递Byte数组 长度在900w 客户端接收时会耗时10s 我的代码是这样的 2 Socket缓冲区 http t zoukankan com bigberg p 7747419 html 每个
  • 即刻掌握python格式化输出的三种方式 (o゜▽゜)o☆

    目录 1 f 转化的格式化输出方式 2 格式化输出的方法 3 format 格式化输出的方法 1 f 转化的格式化输出方式 只需要在我们要格式化输出的内容开头引号的前面加上 f 在字符串内要转义的内容用 括起来即可 模板 print f x
  • 企业微信登录-前端实现

    企业微信登录 企业微信登录 前端具体实现 下面代码中配置项的字段具体用途说明可以阅读企业微信开发者说明文档 我们通过提供的企业微信登录组件来进行站内登录 下面是我封装的登录组件以及使用方法 weChatLogin vue 封装的组件
  • hudi-hive-sync

    hudi hive sync Syncing to Hive 有两种方式 在hudi 写时同步 使用run sync tool sh 脚本进行同步 1 代码同步 改方法最终会同步元数据 但是会抛出异常 val spark SparkSess
  • spring:AOP面向切面编程+事务管理

    目录 一 Aop Aspect Oriented Programming 二 springAOP实现 1 XML实现 2 注解实现 三 spring事务管理 一 Aop Aspect Oriented Programming 将程序中的非业
  • NLP中BERT在文本二分类中的应用

    最近参加了一次kaggle竞赛Jigsaw Unintended Bias in Toxicity Classification 经过一个多月的努力探索 从5月20日左右到6月26日提交最终的两个kernel 在public dataset