人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别

2023-11-11

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别,BiLSTM+CRF 模型是一种常用的序列标注算法,可用于词性标注、分词、命名实体识别等任务。本文利用pytorch搭建一个BiLSTM+CRF模型,并给出数据样例,通过一个简单的命名实体识别(NER)任务来演示模型的训练和预测过程。文章将分为以下几个部分:

1. BiLSTM+CRF模型的介绍
2. BiLSTM+CRF模型的数学原理
3. 数据准备
4. 模型搭建
5. 训练与评估
6. 预测
7. 总结

1. BiLSTM+CRF模型的介绍

BiLSTM+CRF模型结合了双向长短时记忆网络(BiLSTM)和条件随机场(CRF)两种技术。BiLSTM用于捕捉序列中的上下文信息,而CRF用于解决标签之间的依赖关系。实际上,BiLSTM用于为每个输入序列生成一个特征向量,然后将这些特征向量输入到CRF层,以便为序列中的每个元素分配一个标签。BiLSTM 和 CRF 结合在一起,使模型即可以像 CRF 一样考虑序列前后之间的关联性,又可以拥有 LSTM 的特征抽取及拟合能力。

2.BiLSTM+CRF模型的数学原理

假设我们有一个序列 x = ( x 1 , x 2 , . . . , x n ) \boldsymbol{x} = (x_1, x_2, ..., x_n) x=(x1,x2,...,xn),其中 x i x_i xi 是第 i i i 个位置的输入特征。我们要对每个位置进行标注,即为每个位置 i i i 预测一个标签 y i y_i yi。标签集合为 Y = y 1 , y 2 , . . . , y n \mathcal{Y}={y_1, y_2, ..., y_n} Y=y1,y2,...,yn,其中 y i ∈ L y_i \in \mathcal{L} yiL L \mathcal{L} L 表示标签的类别集合。

BiLSTM用于从输入序列中提取特征,它由两个方向的LSTM组成,分别从前向后和从后向前处理输入序列。在时间步 t t t,BiLSTM的输出为 h t ∈ R 2 d h_t \in \mathbb{R}^{2d} htR2d,其中 d d d 是LSTM的隐藏状态维度。具体来说,前向LSTM从左至右处理输入序列 x \boldsymbol{x} x,输出隐状态序列 h → = ( h 1 → , h 2 → , . . . , h n → ) \overrightarrow{h}=(\overrightarrow{h_1},\overrightarrow{h_2},...,\overrightarrow{h_n}) h =(h1 ,h2 ,...,hn ),其中 h t → \overrightarrow{h_t} ht 表示在时间步 t t t 时前向LSTM的隐藏状态;后向LSTM从右至左处理输入序列 x \boldsymbol{x} x,输出隐状态序列 h ← = ( h 1 ← , h 2 ← , . . . , h n ← ) \overleftarrow{h}=(\overleftarrow{h_1},\overleftarrow{h_2},...,\overleftarrow{h_n}) h =(h1 ,h2 ,...,hn ),其中 h t ← \overleftarrow{h_t} ht 表示在时间步 t t t 时后向LSTM的隐藏状态。则每个位置 i i i 的特征表示为 h i = [ h i → ; h i ← ] h_i=[\overrightarrow{h_i};\overleftarrow{h_i}] hi=[hi ;hi ],其中 [ ⋅ ; ⋅ ] [\cdot;\cdot] [;] 表示向量拼接操作。

CRF用于建模标签之间的关系,并进行全局优化。CRF模型定义了一个由 Y \mathcal{Y} Y 构成的联合分布 P ( y ∣ x ) P(\boldsymbol{y}|\boldsymbol{x}) P(yx),其中 y = ( y 1 , y 2 , . . . , y n ) \boldsymbol{y} = (y_1, y_2, ..., y_n) y=(y1,y2,...,yn) 表示标签序列。具体来说,CRF模型将标签序列的概率分解为多个位置的条件概率的乘积,即

P ( y ∣ x ) = ∏ i = 1 n ψ i ( y i ∣ x ) ∏ i = 1 n − 1 ψ i , i + 1 ( y i , y i + 1 ∣ x ) P(\boldsymbol{y}|\boldsymbol{x})=\prod_{i=1}^{n}\psi_i(y_i|\boldsymbol{x}) \prod_{i=1}^{n-1}\psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) P(yx)=i=1nψi(yix)i=1n1ψi,i+1(yi,yi+1x)

其中 ψ i ( y i ∣ x ) \psi_i(y_i|\boldsymbol{x}) ψi(yix) 表示在位置 i i i 时预测标签为 y i y_i yi 的条件概率, ψ i , i + 1 ( y i , y i + 1 ∣ x ) \psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) ψi,i+1(yi,yi+1x) 表示预测标签为 y i y_i yi y i + 1 y_{i+1} yi+1 的联合概率。这些条件概率和联合概率可以用神经网络来建模,其中输入为位置 i i i 的特征表示 h i h_i hi

CRF模型的全局优化问题可以通过对数似然函数最大化来实现,即

max ⁡ y log ⁡ P ( y ∣ x ) = ∑ i = 1 n log ⁡ ψ i ( y i ∣ x ) ∑ i = 1 n − 1 log ⁡ ψ i , i + 1 ( y i , y i + 1 ∣ x ) \max_{\boldsymbol{y}}\log P(\boldsymbol{y}|\boldsymbol{x}) = \sum_{i=1}^{n}\log\psi_i(y_i|\boldsymbol{x}) \sum_{i=1}^{n-1}\log\psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) ymaxlogP(yx)=i=1nlogψi(yix)i=1n1logψi,i+1(yi,yi+1x)
其中 y \boldsymbol{y} y 是所有可能的标签序列。可以使用动态规划算法(如维特比算法)来求解全局最优标签序列。

综上所述,BiLSTM+CRF模型的数学原理可以表示为:

P ( y ∣ x ) = ∏ i = 1 n ψ i ( y i ∣ x ) ∏ i = 1 n − 1 ψ i , i + 1 ( y i , y i + 1 ∣ x ) P(\boldsymbol{y}|\boldsymbol{x}) = \prod_{i=1}^{n}\psi_i(y_i|\boldsymbol{x}) \prod_{i=1}^{n-1}\psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) P(yx)=i=1nψi(yix)i=1n1ψi,i+1(yi,yi+1x)

其中

ψ i ( y i ∣ x ) = exp ⁡ ( W o T h i + b o T y i ) ∑ y i ′ ∈ L exp ⁡ ( W o T h i + b o T y i ′ ) \psi_i(y_i|\boldsymbol{x}) = \frac{\exp(\boldsymbol{W}_o^{T}\boldsymbol{h}_i + \boldsymbol{b}_o^{T}\boldsymbol{y}i)}{\sum{y_i'\in\mathcal{L}}\exp(\boldsymbol{W}_o^{T}\boldsymbol{h}_i + \boldsymbol{b}_o^{T}\boldsymbol{y}_i')} ψi(yix)=yiLexp(WoThi+boTyi)exp(WoThi+boTyi)

ψ i , i + 1 ( y i , y i + 1 ∣ x ) = exp ⁡ ( W t T y i , i + 1 ) ∑ y i ′ ∈ L ∑ y i + 1 ′ ∈ L exp ⁡ ( W t T y i ′ , i + 1 ′ ) \psi_{i,i+1}(y_i,y_{i+1}|\boldsymbol{x}) = \frac{\exp(\boldsymbol{W}t^{T}\boldsymbol{y}{i,i+1})}{\sum_{y_i'\in\mathcal{L}}\sum_{y_{i+1}'\in\mathcal{L}}\exp(\boldsymbol{W}t^{T}\boldsymbol{y}{i',i+1}')} ψi,i+1(yi,yi+1x)=yiLyi+1Lexp(WtTyi,i+1)exp(WtTyi,i+1)

其中 W o \boldsymbol{W}_o Wo b o \boldsymbol{b}_o bo 是输出层的参数, W t \boldsymbol{W}_t Wt 是转移矩阵, h i \boldsymbol{h}_i hi 是位置 i i i 的特征表示, y i \boldsymbol{y}i yi 是位置 i i i 的标签表示, y i , i + 1 \boldsymbol{y}{i,i+1} yi,i+1 是位置 i i i i + 1 i+1 i+1 的标签联合表示。

在这里插入图片描述

3. 数据准备

下面我将使用一个简单的命名实体识别(NER)任务来演示模型的训练和预测过程。数据集包含了一些句子,每个句子中的单词都被标记为“B-PER”(人名开始)、“I-PER”(人名中间)、“B-LOC”(地名开始)、“I-LOC”(地名中间)或“O”(其他)。

数据样例:

John B-PER
lives O
in O
New B-LOC
York I-LOC
. O

4. 模型搭建

首先,我们需要安装PyTorch库:

pip install torch

接下来,我们将使用PyTorch搭建BiLSTM+CRF模型。完整的模型代码如下:

import torch
import torch.nn as nn
import torch.optim as optim

from TorchCRF import CRF

class BiLSTM_CRF(nn.Module):
    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
                            num_layers=1, bidirectional=True)

        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
        self.crf = CRF(self.tagset_size)

    def forward(self, sentence):
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
        lstm_out, _ = self.lstm(embeds)
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def loss(self, sentence, tags):
        feats = self.forward(sentence)
        return -self.crf(torch.unsqueeze(feats, 0), tags)

    def predict(self, sentence):
        feats = self.forward(sentence)
        return self.crf.decode(torch.unsqueeze(feats, 0))

5. 训练与评估

接下来,我们将使用训练数据对模型进行训练,并在每个epoch后打印损失值和准确率。

def train(model, optimizer, data):
    for epoch in range(10):
        total_loss = 0
        total_correct = 0
        total_count = 0
        for sentence, tags in data:
            model.zero_grad()
            loss = model.loss(sentence, tags)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            prediction = model.predict(sentence)
            total_correct += sum([1 for p, t in zip(prediction, tags) if p == t])
            total_count += len(tags)

        print(f"Epoch {epoch + 1}: Loss = {total_loss / len(data)}, Accuracy = {total_correct / total_count}")

6. 预测

最后,我们将使用训练好的模型对新的句子进行预测。

def predict(model, sentence):
    prediction = model.predict(sentence)
    return [p for p in prediction]

7. 总结

用训练好的模型对新的句子进行预测。

def predict(model, sentence):
    prediction = model.predict(sentence)
    return [p for p in prediction]

7. 总结

本文介绍了如何使用PyTorch搭建一个BiLSTM+CRF模型,并通过一个简单的命名实体识别(NER)任务来演示模型的训练和预测过程。希望这篇文章能帮助你理解BiLSTM+CRF模型的原理,并为你的实际项目提供参考作用哦。

更新精彩的模型搭建与应用请持续关注哦!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别 的相关文章

随机推荐

  • 10、Java 8 - 接口 ( interface ) 默认方法

    总所周知 在 Java 7 和之前的版本中 接口 interface 是不能包含具体的方法实现的 比如 下面的代码 是会报错的 public class InterfaceDefaultMethodTester public static
  • Java 读取freemarker模板(html)转换成String

    2019独角兽企业重金招聘Python工程师标准 gt gt gt Java代码 package com main util import freemarker template Template import org slf4j Logg
  • springboot日期时间格式全局处理

    import com fasterxml jackson datatype jsr310 deser LocalDateTimeDeserializer import com fasterxml jackson datatype jsr31
  • Mysql进阶优化篇01——四万字详解数据库性能分析工具(深入、全面、详细,收藏备用)

    前 言 作者简介 半旧518 长跑型选手 立志坚持写10年博客 专注于java后端 专栏简介 mysql进阶 主要讲解mysql数据库进阶知识 包括索引 数据库调优 分库分表等 文章简介 本文将介绍数据库优化的步骤 思路 性能分析工具 比如
  • js 中的隐士转换 + ==规则

    ToString 1 数组中的null或undefined 会被当做空字符串处理 2 普通对象 转为字符串相当于直接使用Object prototype toString 返回 object Object ToNumber 1 null 转
  • Linux学习之基础工具一

    1 Linux 软件包管理器 yum 首先我们需要知道的是在Linux下 现存的软件和指令是一定的 而有的时候我们想需要更多的指令或者软件 而这在Linux本身下是没有的 故我们可以利用指令yum指令安装或卸载你想要或者不需要的软件 ubu
  • k8s学习pod第七天

    init Container 初始化容器是一类只运行一次的容器 本质是也是容器 不同容器间启动有先后顺序 只有前面的容器运行成功了 后面的容器才能运行 初始化容器的场景 在其他容器运行之前做个初始化 比如配置文件生成 环境变量生成 有先后顺
  • OpenCV——分水岭算法

    目录 一 分水岭算法 1 概述 2 图像分割概念 3 分水岭算法原理 二 主要函数 三 C 代码 四 结果展示 1 原始图像 2 分割结果 五 参考链接 一 分水岭算法 1 概述 分水岭算法是一种图像分割常用的算法 可以有效地将图像中的目标
  • Javascript高级程序设计——15-1.匿名函数和闭包

    1 匿名函数 表示没有定义函数名的函数 案例1 1 简单的匿名函数 function 单独的匿名函数无法执行 alert Lee 案例1 2 将匿名函数赋值给一个变量 var box function return Lee alert bo
  • 复数矩阵计算行列式

    项目上需要对复矩阵的行列式计算 根据计算一般矩阵行列式的代码改成了复矩阵行列式计算 include
  • 性能测试中TPS上不去的几种原因

    中TPS一直上不去 是什么原因 这篇文章 就具体说说在实际压力测试中 为什么有时候TPS上不去的原因 先来解释下什么叫TPS TPS Transaction Per Second 每秒事务数 指服务器在单位时间内 秒 可以处理的事务数量 一
  • Python库的使用说明

    目录 1 第三方库索引网站 2 第三方安装 2 1 pip工具介绍 2 2 pip工具安装 2 2 1 list 命令查看已安装的库列表 2 2 2 uninstall 命令 2 2 3 show 命令 2 2 4 download 命令
  • C++标准模板库 迭代器 iterator 详解(二)

    迭代器提供对一个容器中的对象的访问方法 并且定义了容器中对象的范围 迭代器就如同一个指针 事实上 C 的指针也是一种迭代器 但是 迭代器不仅仅是指针 因此你不能认为他们一定具有地址值 例如 一个数组索引 也可以认为是一种迭代器 迭代器有各种
  • [NOI2009]植物大战僵尸【拓扑+最大权闭合子图】

    题目链接 BZOJ 1565 看到这道题之后很容易想到的就是最大权闭合子图了 但是却有个问题就是要去除掉那些环 因为构成了环之后 相当于是无敌的状态 它们就永远不会得到贡献 并且环之后的点也是得不到贡献的 所以 这里利用拓扑 知道哪些点是可
  • 「Qt」事件概念

    0 引言 在本文所属专栏的前面的文章里 我们介绍了Qt的 信号 Signal 与 槽 Slot 机制 信号 Signal 与 槽 Slot 机制是 Qt 框架用于多个对象之间通信的 是 Qt 的核心特性 也是 Qt 与其他框架最大的不同之处
  • anaconda中spyder改变背景颜色(黑色)

    spyder挺好用的 但是未定义的背景颜色实在不好看 纯属个人审美 下面开始更换背景图 打开spyder 依此点击 Tools 再点击preference 喜爱 选择Syntax coloring Scheme调成Monokai 这是我喜欢
  • python+selenium+unittest自动化测试框架

    前言 关于自动化测试的介绍 网上已有很多资料 这里不再赘述 UI自动化测试是自动化测试的一种 也是测试金字塔最上面的一层 selenium是应用于web的自动化测试工具 支持多平台 多浏览器 多语言来实现自动化 优点如下 开源 免费且对we
  • pyecharts在数据可视化中的应用 (二)(pyecharts绘制树图、矩形树图、地理热力图、词云图、相关性矩阵等图)

    1 使用以下JSON数据绘制树图 矩形树图 from pyecharts import options as opts from pyecharts charts import Tree data name flare children n
  • Android 系统性能优化(57)---MTK 平台开关机、重启时间优化

    MTK 平台开关机 重启时间优化 开关机 重启时间优化 开机性能优化 是用功能和其它因素多方面平衡的结果 片面追求单方面的性能没有太大意义 有些产品设计开机动画非常酷炫 动画图片过多 高帧率会影响开机速度 这时就需要看是开机速度优先还是体验
  • 人工智能(pytorch)搭建模型8-利用pytorch搭建一个BiLSTM+CRF模型,实现简单的命名实体识别

    大家好 我是微学AI 今天给大家介绍一下人工智能 pytorch 搭建模型8 利用pytorch搭建一个BiLSTM CRF模型 实现简单的命名实体识别 BiLSTM CRF 模型是一种常用的序列标注算法 可用于词性标注 分词 命名实体识别