bert下albert_chinese_small实现文本分类

2023-11-05

import torch
from transformers import BertTokenizer, BertModel, BertConfig
import numpy as np
from torch.utils import data
from sklearn.model_selection import train_test_split
import pandas as pd

pretrained = r'albert_chinese_small'
tokenizer = BertTokenizer.from_pretrained(pretrained)
# 判断传递的预训练模型地址是否在PRETRAINED_VOCAB_ARCHIVE_MAP中,若不在则会将这个路径+VOCAB_NAME拼接成vocab.txt的路径
model = BertModel.from_pretrained(pretrained)
config = BertConfig.from_pretrained(pretrained)

inputtext = "今天心情情很好啊"
tokenized_text = tokenizer.encode(inputtext)
input_ids = torch.tensor(tokenized_text).view(-1, len(tokenized_text))
outputs = model(input_ids)
# outputs[0].shape, outputs[1].shape
# config.hidden_size, config.embedding_size, config.max_length

class AlbertClassfier(torch.nn.Module):
    def __init__(self, bert_model, bert_config, num_class):
        super(AlbertClassfier, self).__init__()
        self.bert_model = bert_model
        self.dropout = torch.nn.Dropout(0.4)
        self.fc1 = torch.nn.Linear(bert_config.hidden_size, bert_config.hidden_size)
        self.fc2 = torch.nn.Linear(bert_config.hidden_size, num_class)

    def forward(self, token_ids):
        # 不太明白为啥要用句向量,那命名实体识别的时候模型怎么写
        bert_out = self.bert_model(token_ids)[1]  # 句向量 [batch_size,hidden_size]
        bert_out = self.dropout(bert_out)
        bert_out = self.fc1(bert_out)
        bert_out = self.dropout(bert_out)
        bert_out = self.fc2(bert_out)  # [batch_size,num_class]
        return bert_out


albertBertClassifier = AlbertClassfier(model, config, 2)
device = torch.device("cuda:0") if torch.cuda.is_available() else 'cpu'
albertBertClassifier = albertBertClassifier.to(device)


def get_train_test_data(pos_file_path, neg_file_path, max_length=100, test_size=0.2):
    data = []
    label = []
    pos_df = pd.read_excel(pos_file_path, header=None)
    pos_df.columns = ['content']
    for index, row in pos_df.iterrows():
        row = row['content']
        ids = tokenizer.encode(row.strip(), max_length=max_length, padding='max_length', truncation=True)
        data.append(ids)
        label.append(1)

    neg_df = pd.read_excel(neg_file_path, header=None)
    neg_df.columns = ['content']
    for index, row in neg_df.iterrows():
        row = row['content']
        ids = tokenizer.encode(row.strip(), max_length=max_length, padding='max_length', truncation=True)
        data.append(ids)
        label.append(0)
    X_train, X_test, y_train, y_test = train_test_split(data, label, test_size=test_size, shuffle=True)
    return (X_train, y_train), (X_test, y_test)


pos_file_path = r"pos_sim.xlsx"
neg_file_path = r"neg_sim.xlsx"
(X_train, y_train), (X_test, y_test) = get_train_test_data(pos_file_path, neg_file_path)
len(X_train), len(X_test), len(y_train), len(y_test), len(X_train[0])
" ".join([str(i) for i in X_train[0]])
tokenizer.decode(X_train[0]), y_train[0]


class DataGen(data.Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return np.array(self.data[index]), np.array(self.label[index])


train_dataset = DataGen(X_train, y_train)
test_dataset = DataGen(X_test, y_test)
train_dataloader = data.DataLoader(train_dataset, batch_size=10)
test_dataloader = data.DataLoader(test_dataset, batch_size=10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(albertBertClassifier.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
for epoch in range(50):
    loss_sum = 0.0
    accu = 0
    albertBertClassifier.train()
    for step, (token_ids, label) in enumerate(train_dataloader):
        token_ids = token_ids.to(device)
        label = label.to(device).long()
        out = albertBertClassifier(token_ids)
        loss = criterion(out, label)
        optimizer.zero_grad()
        loss.backward()  # 反向传播
        optimizer.step()  # 梯度更新
        loss_sum += loss.cpu().data.numpy()
        accu += (out.argmax(1) == label).sum().cpu().data.numpy()

    test_loss_sum = 0.0
    test_accu = 0
    albertBertClassifier.eval()
    for step, (token_ids, label) in enumerate(test_dataloader):
        token_ids = token_ids.to(device)
        label = label.to(device).long()
        with torch.no_grad():
            out = albertBertClassifier(token_ids)
            loss = criterion(out, label)
            test_loss_sum += loss.cpu().data.numpy()
            test_accu += (out.argmax(1) == label).sum().cpu().data.numpy()
    print("epoch % d,train loss:%f,train acc:%f,test loss:%f,test acc:%f" % (
        epoch, loss_sum / len(train_dataset), accu / len(train_dataset), test_loss_sum / len(test_dataset),
        test_accu / len(test_dataset)))


对应文件我将存在csdn资源库中,或者加群753035545,我将上传资源

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

bert下albert_chinese_small实现文本分类 的相关文章

  • 【Linux】bert-base-cased 不在缓存需要从 s3 上下载的问题

    CONTENT bert base cased 手动下载 xff0c 更名 位置 home xxxx cache torch pytorch transformers bert base cased 下载地址文件名同名 json 文件内容h
  • pytorch BERT文本分类保姆级教学

    pytorch BERT文本分类保姆级教学 本文主要依赖的工具为huggingface的transformers xff0c 更详细的解释可以查阅文档 定义模型 模型定义主要是tokenizer config和model的定义 xff0c
  • Bert演变总结

  • NLP领域的预训练模型(Transformer、BERT、GPT-2等)

    英文原文链接 https www analyticsvidhya com blog 2019 03 pretrained models get started nlp 1 介 绍 如今 自然语言处理 Natural Language Pro
  • bert处理超过512的长文本(强制改变位置编码position_embeddings )

    最近在做 NER 任务的时候 需要处理最长为 1024 个字符的文本 BERT 模型最长的位置编码是 512 个字符 超过512的部分没有位置编码可以用了 处理措施 将bert的位置编码认为修改成 1 1024 前512维使用原始的 1 5
  • 基于Keras_bert模型的Bert使用与字词预测

    基于Keras bert模型的Bert使用与字词预测 学习参考杨老师的博客 请支持原文 一 Keras bert 基础知识 1 1 kert bert库安装 1 2 Tokenizer文本拆分 1 3 训练和使用 构建模型 模型训练 使用模
  • bert模型蒸馏实战

    由于bert模型参数很大 在用到生产环境中推理效率难以满足要求 因此经常需要将模型进行压缩 常用的模型压缩的方法有剪枝 蒸馏和量化等方法 比较容易实现的方法为知识蒸馏 下面便介绍如何将bert模型进行蒸馏 一 知识蒸馏原理 模型蒸馏的目的是
  • 【论文阅读笔记】BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding

    BERT的出现使我们终于可以在一个大数据集上训练号一个深的神经网络 应用在很多NLP应用上面 BERT Pre training of Deep Bidirectional Transformers for Language Underst
  • Bert CNN信息抽取

    Github参考代码 https github com Wangpeiyi9979 IE Bert CNN 数据集来源于百度2019语言与智能技术竞赛 在上述链接中提供下载方式 感谢作者提供的代码 1 信息抽取任务 给定schema约束集合
  • BERT:Pre-training of Deep Bidirectional Transformers for Language Understanding

    BERT 个人翻译 并不权威 paper https arxiv org pdf 1810 04805 pdf BERT Pre training of Deep Bidirectional Transformers for Languag
  • Bert的MLM任务loss原理

    bert预训练有MLM和NSP两个任务 其中MLM是类似于 完形填空 的方式 对一个句子里的15 的词进行mask 通过双向transformer feedforward rediual add layer norm完成对每个词的embed
  • Bert模型做多标签文本分类

    Bert模型做多标签文本分类 参考链接 BERT模型的详细介绍 图解BERT模型 从零开始构建BERT 强推 李宏毅2021春机器学习课程 我们现在来说 怎么把Bert应用到多标签文本分类的问题上 注意 本文的重点是Bert的应用 对多标签
  • NLP之BERT和GPT

    NLP之BERT和GPT杂谈 我们介绍了好几种获取句子表征的方法 然而值得注意的是 我们并不是只对如何获取更好的句子表征感兴趣 在评估他们各自模型性能的时候所采取的方法 回过头去进行梳理 发现 无论是稍早些的 InferSent 还是 20
  • Transformer 架构和 BERT、GPT 和 T5 的兴起:初学者指南

    在广阔且不断发展的人工智能 AI 领域 有些创新不仅会留下深刻的印象 而且会带来巨大的影响 他们重新定义了整个领域的轨迹 在这些突破性的创新中 Transformer 架构成为变革的灯塔 这类似于工业革命期间蒸汽机的发明 推动人工智能进入一
  • 【BERT类预训练模型整理】

    BERT类预训练模型整理 1 BERT的相关内容 1 1 BERT的预训练技术 1 1 1 掩码机制 1 1 2 NSP Next Sentence Prediction 1 2 BERT模型的局限性 2 RoBERTa的相关内容 2 1
  • 关于Bert被质疑利用“虚假统计性提示”的ACL论文

    曾经狂扫11项记录的谷歌NLP模型BERT 近日遭到了网友的质疑 该模型在一些基准测试中的成功仅仅是因为利用了数据集中的虚假统计线索 如若不然 还没有随机的结果好 这项研究已经在Reddit得到了广泛的讨论 引用自 新智元 真的不想那么标题
  • ChatGPT 最好的替代品

    前两天我们邀请了微软工程师为我们揭秘 ChatGPT 直播期间有个读者问到 有了 ChatGPT BERT 未来还有发展前途吗 我想起来最近读过的一篇博客 最好的 ChatGPT 替代品 不过聊到这俩模型 就不得不提到 Transforme
  • [论文精读]BERT

    BERT Pre training of Deep Bidirectional Transformers for Language Understanding Abstract 作者介绍了一种新的语言模型 叫做BERT 是来自transfo
  • [Python人工智能] 三十三.Bert模型 (2)keras-bert库构建Bert模型实现文本分类

    从本专栏开始 作者正式研究Python深度学习 神经网络及人工智能相关知识 前一篇文章开启了新的内容 Bert 首先介绍Keras bert库安装及基础用法 这将为后续文本分类 命名实体识别提供帮助 这篇文章将通过keras bert库构建
  • 基于tensorflow2.0+使用bert获取中文词、句向量并进行相似度分析

    本文基于transformers库 调用bert模型 对中文 英文的稠密向量进行探究 开始之前还是要说下废话 主要是想吐槽下 为啥写这个东西呢 因为我找了很多文章要么不是不清晰 要么就是基于pytorch 所以特地写了这篇基于tensorf

随机推荐

  • Android开发指南!2021中级Android开发面试解答,完整版开放下载

    Google 为了帮助 Android 开发者更快更好地开发 App 推出了一系列组件 这些组件被打包成了一个整体 称作 Android Jetpack 它包含的组件如下图所示 老的 support 包被整合进了 Jetpack 例如上图
  • 混合策略纳什均衡——附例题及解析

    目录 引入 混合纳什均衡 例题 求法 引入 假设这样一种对局 甲乙两人抽扑克牌 扑克牌只有两种花色 红和黑 两张牌花色相同算甲胜 反之乙胜 那么甲乙双方应该如何设定自己抽出不同花色的概率呢 比如 设甲抽红牌的概率P 60 那么黑牌概率就是1
  • [架构之路-204]- 常见的需求分析技术:结构化分析与面向对象分析

    目录 前言 1 1 3 需求分析概述 导言 11 3 1需求分析的任务 1 绘制系统上下文范围关系图 2 创建用户界面原型 3 分析需求的可行性 4 确定需求的优先级 5 为需求建立模型 最难的一项任务 SA and OOA 6 创建数据字
  • EMI滤波器设计概念

    EMI滤波器设计概念 1 1 基本概念 在开关电源的设计里 为了对策传导干扰大都会在输入端前端加入EMI滤波器 因传导测试是由AC端来做量测 因此滤波器愈靠近接收器效果愈好 让所有的干扰都可经由滤波器做衰减 而一般滤波器是经由电感与电容组合
  • AMR文件格式的解释

    一 什么是AMR AMR WB 全称Adaptive Multi Rate和Adaptive Multi Rate Wideband 主要用于移动设备的音频 压缩比比较大 但相对其他的压缩格式质量比较差 由于多用于人声 通话 效果还是很不错
  • docker swarm 集群构建及服务管理

    文章目录 一 集群构建及部分配置 1 环境准备 2 swarm 初始化 3 worker子节点加入 4 查看集群信息 1 查看 swarm 集群节点 2 查看各节点 swarm 信息 5 swarm 证书配置 二 集群服务管理 1 创建集群
  • elasticsearch8.2 http开启鉴权

    Elasticsearch 早期的版本配置鉴权 由于插件收费 所以配置起来比较麻烦 但是最近发现Elasticsearch的8 2版本中可以配置https及鉴权的操作 所以记录一下给想要获取该知识的人 分享一下 第一步 修改elastics
  • java导入自定义类_java如何引入自己定义的类,即import语句该如何写?

    我写了2个java的小程序Time java和MyTime java 其内容分别如下 Time java 文件的内容publicclassTime privateinthour privateintminute privateintseco
  • 怎样做自媒体视频剪辑赚钱?

    不想真人出镜 但是想做自媒体赚钱 除了发布图文作品和音频作品外 我们还可以做视频剪辑发布到自媒体平台上 简单的说就是剪辑一些现有的视频作品 重新剪辑成一个新的作品并发布到自媒体平台上获得收益 不说多了 每天收益100 200还是不难的 新手
  • 函数指针做函数参数

    什么是函数指针 当我们定义一个函数的时候 编译器会为这个函数分配一段内存空间 而这段内存空间的首地址就是函数指针 函数指针的定义 函数返回值类型 指针变量名 函数参数列表 int p int int 这个语句就定义了一个指向函数的指针变量
  • python uiautomation mac os_(selenium+python)_UI自动化01_Mac下selenium环境搭建

    前言 Selenium是一个用于Web网页UI自动化测试的开源框架 可以驱动浏览器模拟用户操作 支持多种平台 Windows Mac OS Linux 和多种浏览器 IE Firefox Chrome Safari 可以用多种语言 Java
  • java new file会创建文件吗_Java高级——文件与I/O流

    简介 本文分为四个部分 首先是介绍File类 概括了一下概念 构造方法及常用方法等 其次是描述了面对对象的三大特征 再次是对抽象类进行了简单的概述 最后从特性 使用等等几个方面对接口进行了一定的描述 一 File类 1 File类概念 1
  • STM32F103系列控制的OLED IIC 4针

    最近在研究四针的OLED 先上个效果图 总工程文件评论区留下邮箱我会发送 硬件部分 有开发板的直接用开发板就好 没有的去某宝买一块STM32F103C8T6 10元左右 类似这种 接线部分 OLED一共有四个接口 本别是SCL 时钟 SDA
  • Qt-OpenCV学习笔记--高级形态转换--morphologyEx()

    概述 OpenCV提供了一个综合的形态转换工具 morphologyEx 集成了腐蚀运算 膨胀运算 开运算 闭运算 梯度运算 顶帽运算 黑帽运算 函数 void cv morphologyEx InputArray src OutputAr
  • 【云原生--Kubernetes】Helm 工具安装

    文章目录 一 Helm 概述 1 1 Helm 简介 1 2 Helm重要概念 1 3Helm2 组件 1 4Helm2 工作原理 1 5 Helm2与Helm3区别 二 Helm部署 三 Helm常用命令 3 1 chart仓库管理 3
  • 记5.28大促压测的性能优化—线程池相关问题

    目录 1 环境介绍 2 症状 3 诊断 4 结论 5 解决 6 对比java实现 废话就不多说了 本文分享下博主在5 28大促压测期间解决的一个性能问题 觉得这个还是比较有意思的 值得总结拿出来分享下 博主所服务的部门是作为公共业务平台 公
  • 【100天精通python】Day8:数据结构_元组Tuple的创建、删除、访问、修改、推导系列操作

    目录 1 创建元组 2 删除元组 3 访问元组元素 4 多个值的同时赋值和交换 5 修改元组元素 6 元组推导式 7 元组运算符 8 元组常用场景 9 元组 Tuple 和列表 List 的区别 元组 tuple 是 Python 中的一种
  • mysql创建表

    http www cnblogs com yunf archive 2011 04 20 2022193 html 说明 此文件包含了blog数据库中建立所有的表的Mysql语句 在sql语句中注意 约束的概念 1 实体完整性约束 主键 唯
  • 链栈的基本操作

    define CRT SECURE NO WARNINGS 链栈 include
  • bert下albert_chinese_small实现文本分类

    import torch from transformers import BertTokenizer BertModel BertConfig import numpy as np from torch utils import data