利用 Pytorch 加载词向量库文件

2023-11-07

1. 示例代码

完整代码:

import torch
from torch.nn import Embedding

# 载入讯飞词向量文件
word_vector_file = '../Downloads/tencent-ailab-embedding-zh-d100-v0.2.0-s.txt' # 库文件的路径。本程序选择了最小的一个库文件。
word_vectors = {}
with open(word_vector_file, 'r', encoding='utf-8') as f:
    for line in f:
        word, vector_str = line.split(' ', 1)
        vector = torch.FloatTensor([float(x) for x in vector_str.split()])
        word_vectors[word] = vector

# 构建词典和词向量矩阵
words = list(word_vectors.keys())
word_dict = {w: i for i, w in enumerate(words)}
word_vectors_matrix = torch.stack(list(word_vectors.values()))

# 构建Embedding层
embedding = Embedding.from_pretrained(word_vectors_matrix)

# 得到某个单词的词向量
word = "勇敢"
word_idx = word_dict.get(word, None)
if word_idx is not None:
    word_vector = embedding(torch.LongTensor([word_idx]))
else:
    # 如果单词不在词向量文件中,则随机初始化一个词向量
    word_vector = embedding.weight.mean(dim=0, keepdim=True)

2. 数据流分析

  • step-01: 词向量文件.txt => word_vectors。word_vectors 是一个 (key = 词,value = 词向量) 的词典。
  • step-02:word_vectors.keys() => word_dict。word_dict 是一个词的索引倒排表,给定单词,可查出顺序号。因为后面会把词向量表转换成张量,因此,查询单词对用的词向量,必须根据顺序号才可以。
  • step-03:word_vectors.values() => word_vectors_matrix。把词向量列表转成矩阵,详见后面的代码详解。
  • step-04:word_vectors_matrix => embedding。词向量矩阵转成嵌入层。嵌入层是神经网络中的一部分,用于将词汇表中的单词(由数字表示)转换为词向量,以便神经网络进行学习和预测。
  • step-05: word => word_idx。根据单词,得到它在词向量表中的顺序号。
  • step-06:word_idx + embedding => word_vector。查询顺序号为 word_idx 的词向量。

3. 代码详解

word, vector_str = line.split(' ', 1)

这行代码将字符串“line”根据第一个空格分割成两个部分,并将分割后的结果保存在两个变量“word”和“vector_str”中。

具体来说,变量“word”是第一个空格之前的子串,变量“vector_str”是第一个空格之后的剩余子串。参数“1”表示最多只对“line”进行一次分割,因此如果“line”中有多个空格,只会将第一个空格作为分隔符。

这种操作常见于通过空格或其他特定字符来分割文本数据,并将分割后的结果存储到不同的变量中进行进一步处理。例如,在NLP任务中,可能会将每个单词的向量存储为一个文本文件,每行格式为“word value1 value2 … valuen”,该语句就可以帮助将每行数据分解成单词和向量值两部分。

vector = torch.FloatTensor([float(x) for x in vector_str.split()])

这行代码首先将变量“vector_str”按空格分割成一组字符串,然后使用列表推导式将每个字符串转化为浮点数,并将它们作为一个列表传递给torch.FloatTensor()函数。该函数将得到的列表转换为PyTorch张量(tensor),并赋值给变量“vector”。

这里的张量是PyTorch中的一种数据结构,类似于多维数组,可以用来存储和进行高效计算。该语句构建了一个浮点型的一维张量(即向量),其中每个元素对应于分隔符分割的字符串所表示的浮点数。

总之,这一行代码的目的是将一个表示向量的字符串转化为一个一维的PyTorch张量,以便后续的计算和处理。

word_vectors[word] = vector

这行代码将变量“vector”赋值给字典“word_vectors”的键“word”,其中“word_vectors”是一个存储单词向量的字典变量,而“word”是具体的单词。

这里“vector”是一个一维PyTorch张量(即向量),其表示了“word”所对应的向量。通过将其赋值给字典中的相应键,“word_vectors”就可以存储该单词的向量表示。

字典是 Python 中的一种 可变数据类型,它是一种键-值(key-value)映射的集合,可用于存储各种类型的数据。在这里,字典“word_vectors”可以用于存储单词及其对应的向量表示,从而方便后续的机器学习或自然语言处理任务中的使用。

words = list(word_vectors.keys())

这行代码将创建一个名为“words”的列表变量,其中包含了“word_vectors”字典中所有键(即所有单词)的列表。

具体而言,函数“word_vectors.keys()”返回一个迭代器对象,其中包含“word_vectors”字典中所有的键。通过将这个迭代器对象传递给“list()”函数,可以将其转换为一个包含所有键的列表。代码把这个列表赋给变量“words”,以便于我们能够归纳地处理字典中的每个键值对。

这行代码通常用于在对单词向量进行处理时,需要遍历所有单词,并进行一些操作的情况下。例如,可能需要计算两个单词之间的相似度,这就需要比较两个向量的余弦距离值。此时需要一个包含所有单词的列表,以便于按顺序获取每个单词的向量表示,并进行一些数学计算和处理。

word_dict = {w: i for i, w in enumerate(words)}

这行代码创建了一个名为“word_dict”的字典变量,其中包含了单词列表“words”中的所有单词,并为每个单词赋予了一个唯一的整数ID。每个单词都作为键,与其对应的整数ID作为值。

具体而言,代码使用了Python中的一种语法规则“字典推导式”,即通过一行代码在创建字典时完成键值对的定义和赋值操作。

上述代码执行的具体步骤如下:

  1. 循环遍历“words”列表中每个单词。循环过程中,“enumerate(words)”函数用来同时获取单词列表“words”中每个单词的索引i和该单词w。这样可以方便地记录每个单词在列表中的位置。

  2. 对于每个遍历到的单词w,将它与对应的整数i一起作为键值对添加到字典变量“word_dict”中。这里“{w: i}”表示一个包含单个键值对的字典,表示单词w和对应的整数i的映射关系。

  3. 最终将所有遍历到的单词都添加到了字典“word_dict”中,并为每个单词赋予了一个唯一的整数ID。

这行代码的主要用途是在文本分析任务中,将单词映射到唯一的整数ID,便于模型处理和优化。这是因为许多模型需要将输入的单词表示成向量形式,而这通常需要通过查询包含单词及其向量表示的字典进行实现。在此过程中,将单词映射到唯一的整数ID可以避免语义相似的不同单词造成的混淆和误差。

word_vectors_matrix = torch.stack(list(word_vectors.values()))

这行代码的作用是将一个包含所有词向量的字典转换成一个tensor形式的矩阵。

具体而言,代码使用了pytorch库中的torch.stack函数和Python中的list结构。

首先,代码中的word_vectors是一个字典,其中每个键表示一个单词,而每个值则是对应单词的词向量,即一个一维tensor。

然后,代码使用Python内置函数list()将所有的词向量拼接成一个列表,此时得到一个列表,列表中的每个元素都是具有相同维数的张量;接着使用torch.stack()按照行的方式将它们组合为一个新张量。这就是代码的目的:

  1. 首先通过“list(word_vectors.values())”将所有词向量以列表的形式存储并传给torch.stack函数。

  2. torch.stack函数调用后会将所有输入的元素沿着新的维度将它们拼接起来,返回一个新的张量对象。

  3. 最终将所有单词的词向量堆叠成一个矩阵对象“word_vectors_matrix”。此时“word_vectors_matrix”矩阵的维度为(n, d),其中n是词典中单词数量,d是每个单词的词向量维度。

这种将词向量矩阵转换成tensor的方式可以方便进行大规模的计算和矩阵操作,像词向量相加或者平均池化,同时也可以便于将这个Tensor作为神经网络的输入。

embedding = Embedding.from_pretrained(word_vectors_matrix)

这行代码的作用是将预先训练好的词向量矩阵“word_vectors_matrix”传入到嵌入层(Embedding)中,并返回一个新的嵌入层对象“embedding”。

具体而言,Embedding.from_pretrained()函数是在PyTorch中用来生成一个预训练的嵌入层的方法。它使用了预训练的权重来初始化网络中的嵌入层,这里是使用预先训练好的词向量矩阵来初始化嵌入层。

通过指定参数“word_vectors_matrix”,Embedding.from_pretrained()函数会将所有单词的词向量矩阵作为输入,生成一个新的嵌入层对象。可以看作是一种单词向量表征方法,每个单词都映射到一个d维的向量上。嵌入层是神经网络中的一部分,用于将词汇表中的单词(由数字表示)转换为词向量,以便神经网络进行学习和预测。

嵌入层“embedding”中保存着一个用于查找任意词汇表中单词对应嵌入值的方法,通过对嵌入层调用该方法可以返回任意单词对应的预训练的向量。 这个向量可以代表一个词在N维向量空间内的位置,可以作为神经网络的输入,在自然语言处理任务中进行模型训练和推理。由于词向量是通过预训练得到的,并且已经捕捉到了文本数据集中单词之间的相关性,这可以加速神经网络的收敛,并提高预测性能。

word = "勇敢"
word_idx = word_dict.get(word, None)

这段代码是在一个程序中使用了一个字典变量word_dict,该字典中包含了一些词语及其在程序中所对应的数字编号。其中word = "勇敢"表示要查找的词语是“勇敢”,word_idx是用于存储查询结果的变量,get()方法可以用来获取指定key对应的value值。如果word_dict中存在“勇敢”这个词,那么该函数会返回这个词的编号并赋值给word_idx变量;如果word_dict中不存在这个词,则返回None。

word_vector = embedding(torch.LongTensor([word_idx]))

这段代码是用PyTorch实现的嵌入层对输入词语进行词向量编码的过程。其中word_idx表示输入词在词典中的索引,torch.LongTensor([word_idx])会将该索引值转化为一个LongTensor类型的变量作为输入,传递给embedding函数进行编码。embedding函数是PyTorch中的一个嵌入层函数,接收一个整数型的Tensor作为输入,将其转化成一个词向量矩阵输出。为加速运算,通常会采用GPU进行计算,因此这里使用了torch.LongTensor()来将word_idx转换为PyTorch张量的格式以便在GPU上运行。最终经过词向量编码的结果被存储在word_vector变量中。

word_vector = embedding.weight.mean(dim=0, keepdim=True)

这段代码是用PyTorch实现的计算嵌入层的权重矩阵embedding中所有词向量的平均值,然后将结果赋值给word_vector。其中,embedding.weight表示嵌入层的权重矩阵,该矩阵的大小为vocabulary_size * embedding_dim,vocabulary_size表示词典大小,embedding_dim表示词向量维度。mean(dim=0, keepdim=True)表示对embedding.weight按照第0维(也就是第一维)进行取平均值,即对所有词向量进行了平均操作。keepdim=True表示不改变张量的形状,保持和原始张量一样的形状。最后得到的平均值被存储在一个大小为1 * embedding_dim的张量中。因此,word_vector即为所有词向量的平均值,仍然是一个embedding_dim维的向量。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

利用 Pytorch 加载词向量库文件 的相关文章

随机推荐

  • Java初学疑问之接口为什么能运行Object的方法

    public class CommonTest public static void main String args Animal animal new Dog animal toString 为什么能运行该方法 class Dog im
  • 通过清华大学镜像和pip进行安装

    通过清华大学镜像和pip进行安装 有时候网络不佳时 直接通过pip安装可能会很慢或者不成功 因此可以借助清华镜像 可以在使用pip的时候加参数 i https pypi tuna tsinghua edu cn simple 以gensim
  • 前端实战:小实例1——导航栏

    前言 一个导航栏可看作一个列表 在 HTML 使用 ul 标签和 li 标签元素进行结构表示 在 CSS 中进行样式处理 对应标签元素的具体用法可查看 HTML常见标签介绍 实现思路 使用 div 包装导航栏 用 ul 和 li 标签展示导
  • EasyPoi 数据导入导出,贼方便

    1 maven坐标
  • 银行卡编码规则及检验算法详解

    一 银行卡结构 XXXXXX XXXXXXXXXXXX X 发卡行标识代码 自定义位 校验码 根据ISO标准 银行卡长度一般在13 19位 国际上也有12位的 银联标准卡卡长度一般是在16 19位 双组织卡也有13 19位的 二 发卡行标识
  • grid - 显式网格

    显式网格布局包含 行 列 列 grid template columns page color fff grid padding 1 display grid grid gap 1px grid template rows 50px 100
  • 养生指南 4 : 睡眠 与 外因

    参考 老中医给的100条养生建议 强烈推荐 1 睡眠 1 睡觉 是养生第一要素 睡觉的时间 应该是 晚 21 00 早3 00 因为这个时间是一天的 冬季 冬季主藏 冬季不藏 春夏不长 即第 2 天没精神 早起如在寅时三点至五点 此时切忌郁
  • Python数据分析与可视化------NumPy第三方库

    目录 数据的维度 NumPy CSV文件 多维数据的存取 NumPy的便捷式文件截取 NumPy的随机数函数子库 NumPy的统计函数 NumBy的梯度函数 图像的数组表示 图像的变换 数据的维度 维度 一组数据的组织形式 一维数据 由对等
  • 1.出现需要keil突破内存限制

    出现 error L6050U The code size of this image 37186 bytes exceeds the maximum allowed for this version of the linker 是因为超出
  • openlayers绘制圆形区域,消除误差的一种方法

    我需要以某点为圆心 以某长度 单位米 为半径 在地图上绘制圆形区域 前提 地图显示 图层和数据源的创建与设置方法这里就不详细描述了 直接上关键部分 一开始 我使用如下代码实现圆形区域的绘制 绘制以坐标 1 1 为中心 200000米为半径的
  • Codeforces Round #553 (Div. 2)

    A Maxim and Biology time limit per test 1 second memory limit per test 256 megabytes input standard input output standar
  • 无法通过http://burp获取BurpSuite证书的解决方法

    为了能够对https协议的数据进行抓取必须安装BurpSuite的证书 但在下载证书的过程中出现了问题 官方和百度下载证书的方法都是在能够抓取http的状态下访问http burp下载证书 但http burp页面却加载不出来 百度了很久也
  • 【Bootstrap】Bootstrap基础导航栏(响应式导航菜单)

    Bootstrap基础导航栏 响应式导航菜单
  • 自动化测试(五):自动化测试框架的搭建和基于yaml热加载的测试用例的设计

    该部分是对自动化测试专栏前四篇的一个补充 本次参考以下文章实现一个完整的谷歌翻译接口自动化测试 1 python小脚本 Yaml配置文件动态加载 2 python做接口测试的学习记录day8 pytest自动化测试框架之热加载和断言封装 目
  • 如何自己开发一个Android APP(4)——JAVA

    资源使用 在java文件中 通过资源id完成对资源的访问 可以通过对象 getId 的方法得到组件 因为XML布局文件与java文件实际上是不互通的 也就是说我们的xml只控制外观 当你需要为某个地方作出某些设置时 java必须先获取到这个
  • vue面试面试

    MVVM model js对象data view dom模板代码 viewmodel vue实例 ViewModel负责把Model的数据同步到View 还负责把View的修改同步回Model 实现数据 视图分离 数据不会影响视图 框架优缺
  • 美国专利知识

    1 美国专利查看网站 http patft uspto gov 2 美国专利类型 Application Type APT This field contains a single digit number which indicates
  • ASP网页给服务器传参,asp.net页面与页面之间传参数值方法(post传值和get传值)

    一 利用POST传值 传值asp文件send aspx 接受asp文件receive aspx string username Ruquest Form receive 一 get方法传值 QueryString 也叫查询字符串 这种方法将
  • 计算机视觉——图像视觉显著性检测

    目录 系列文章目录 零 问题描述 一 图像显著性检测 1 定义 2 难点 二 常用评价标准和计算方法 1 综述 2 ROS曲线详述 2 1 混淆矩阵 2 2 ROC曲线简介 2 3 ROC曲线绘制及其判别标准 2 4 ROC曲线补充 三 F
  • 利用 Pytorch 加载词向量库文件

    1 示例代码 完整代码 import torch from torch nn import Embedding 载入讯飞词向量文件 word vector file Downloads tencent ailab embedding zh