Pytorch学习(2)——训练词向量的代码

2023-05-16

教程:https://www.bilibili.com/video/BV1vz4y1R7Mm?p=2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud

from collections import Counter
import numpy as np
import random
import math

import pandas as pd
import scipy
import sklearn
from sklearn.metrics.pairwise import cosine_similarity
USE_CUDA = torch.cuda.is_available()
random.seed(1)
np.random.seed(1)
torch.manual_seed(1)
if USE_CUDA:
    torch.cuda.manual_seed(1)
C = 3  # content window
K = 100  # number of negtive samples
NUM_EPOCHS = 2
MAX_VOCAB_SIZE = 30000
BATCH_SIZE = 128
LEARNING_RATE = 0.2
EMBEDDING_SIZE = 100 
def word_tokenize(text):
    return text.split()
  • 从文本文件中读取所有的文字,通过这些文本创建一个vocabulary
  • 由于单词数量可能太大,我们只选取最常见的MAX_VOCAB_SIZE个单词
  • 添加一个UNK单词表示所有不常见的单词
  • 我们需要记录单词到index的mapping, 以及index到单词的mapping, 单词的count, 单词的normalized frequency, 单词总数
with open('text8/text8.train.txt', 'r') as fin:
    text = fin.read()
    
text = word_tokenize(text)
vocab = dict(Counter(text).most_common(MAX_VOCAB_SIZE-1))
vocab['<unk>'] = len(text) - np.sum(list(vocab.values()))
idx_to_word = [word for word in vocab.keys()]
word_to_idx = {word:i for i, word in enumerate(idx_to_word)}
word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
word_counts = np.array([count for count in vocab.values()], dtype=np.float32)
word_freqs = word_counts / np.sum(word_counts)
word_freqs = word_freqs ** (0.75)
word_freqs = word_freqs / np.sum(word_freqs)

实现DataLoader

返回batch的数据

class WordEmbeddingDataset(tud.Dataset):
    def __init__(self, text, word_to_idx, idx_to_word, word_freqs, word_counts):
        super()
        self.text_encoded = [word_to_idx.get(word, word_to_idx['<unk>']) for word in text]
        self.text_encoded = torch.LongTensor(self.text_encoded)
        self.word_to_idx = word_to_idx   
        self.word_freqs = torch.Tensor(word_freqs)
        self.word_counts = torch.Tensor(word_counts)
    
    def __len__(self):
        return len(self.text_encoded)
    
    def __getitem__(self, idx):
        center_word = self.text_encoded[idx]
        pos_indices = list(range(idx-C, idx)) + list(range(idx+1, idx+C+1))  # window内的单词(写错了)
        pos_indices = [i % len(self.text_encoded) for i in pos_indices]  # 取余,放置超出长度
        pos_words = self.text_encoded[pos_indices]  # 周围单词
        neg_words = torch.multinomial(self.word_freqs, K * pos_words.shape[0], True)  # 负采样单词
        # Returns a tensor where each row contains num_samples indices sampled from 
        # the multinomial probability distribution located in the corresponding row
        # of tensor input. (多项式概率分布)
        return center_word, pos_words, neg_words
dataset = WordEmbeddingDataset(text, word_to_idx, idx_to_word, word_freqs, word_counts)
dataloader = tud.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
dataset.word_freqs.shape
class EmbeddingModel(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(EmbeddingModel, self).__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        
        self.in_embed = nn.Embedding(vocab_size, embed_size)
        self.out_embed = nn.Embedding(vocab_size, embed_size)
    
    def forward(self, input_labels, pos_labels, neg_labels):
        # input_label: [batch_size]
        # pos_labels: [batch_size, (window_size * 2)]
        # neg_labels: [batch_size, (window_size * 2 * K)]
        input_embedding = self.in_embed(input_labels)  # [batch_size, embed_size]
        pos_embedding = self.out_embed(pos_labels)  # [batch_size, (window_size*2), embed_size]
        neg_embedding = self.out_embed(neg_labels)  # [batch_size, (window_size*2)*K, embed_size]
        
        input_embedding = input_embedding.unsqueeze(2)  # [batch_size, embed_size, 1]
        pos_dot = torch.bmm(pos_embedding, input_embedding).squeeze(2)
        # documention of torch.bmm: batch matrix-matrix product
        # Performs a batch matrix-matrix product of matrices stored in input and mat2.
        # if A.shape=[batch_size, n, m], B.shape=[batch_size, m, p],
        # then torch.bmm(A, B).shape is [batch_size, n, p]
        neg_dot = torch.bmm(neg_embedding, -input_embedding).squeeze(2)
        
        log_pos = F.logsigmoid(pos_dot).sum(1)
        log_neg = F.logsigmoid(neg_dot).sum(1)
        loss= log_pos+log_neg
        return -loss
    
    def input_embeddings(self):
        return self.in_embed.weight.data.cpu().numpy()
model = EmbeddingModel(MAX_VOCAB_SIZE, EMBEDDING_SIZE)
if USE_CUDA:
    model = model.cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
for e in range(NUM_EPOCHS):
    for i, (input_labels, pos_labels, neg_labels) in enumerate(dataloader):
#         print(input_labels, pos_labels, neg_labels)
#         if i > 5: break
        input_labels = input_labels.long()
        pos_labels = pos_labels.long()
        neg_labels = neg_labels.long()
        if USE_CUDA:
            input_labels = input_labels.cuda()
            pos_labels = pos_labels.cuda()
            neg_labels = neg_labels.cuda()
        
        optimizer.zero_grad()
        loss = model(input_labels, pos_labels, neg_labels).mean()
        loss.backward()
        optimizer.step()

        # if i % 100 == 0:
        print("epoch", e, "iteration", i, loss.item())
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch学习(2)——训练词向量的代码 的相关文章

随机推荐

  • Redhat系列系统在线镜像源

    目录 前言Redhat7镜像源1 阿里云镜像源2 清华大学镜像源3 网易镜像源4 华为镜像源 Redhat8镜像源1 阿里云镜像源2 清华大学镜像源3 网易镜像源4 华为镜像源5 阿里云Rocky镜像源6 阿里云anolis镜像源 Redh
  • SuSE Enterprise linux安装mysql笔记

    目录 前言1 下载mysql二进制安装包2 解压MySQL安装包3 创建MySQL用户4 初始化mysql实例5 首次登录mysql6 修改登录密码 前言 本次安装MySQL的版本是8 0 30的二进制压缩包 xff0c 安装环境是SuSE
  • PostgresSql在linux下源码安装笔记

    目录 前言1 下载源码包并上传2 编译源码并安装3 本地登录PostgreSql4 客户端登录PostgreSql 前言 PostgreSql安装版本是14 5 xff0c 安装环境是Redhat Enterprise Linux serv
  • 判断两个IP地址(ipv4)是否在同一个网段

    我们通常会遇到的ip地址是这样的 xff1a ip地址 xff1a 192 168 227 205 子网掩码 xff1a 255 255 255 0 ip地址 xff1a 192 168 226 202 子网掩码 xff1a 255 255
  • 局域网搭建Linux镜像源

    前言 一般情况在企业的局域网内 xff0c 是不连接外网的 xff0c 所以像阿里云这样的在线的镜像源就用不了 xff0c 我相信大家个人在虚拟机里面连的就是阿里云镜像源了 xff0c 而且局域网内服务器较多的话 xff0c 本地挂载镜像源
  • ubuntu22.04 server安装

    目录 1 安装首页2 选择安装语言3 安装器4 选择键盘布局5 选择安装类型6 设置网络连接7 配置镜像源地址8 磁盘分区9 创建登录用户10 配置安装openssh server11 配置安装其他额外的软件12 开始安装系统13 重启系统
  • linux安装OceanBase数据库

    1 下载OceanBase数据库安装包 OceanBase官网下载页面 2 解压安装包并安装 tar xzf oceanbase all in one 4 0 0 0 beta 100120221102135736 el7 x86 64 t
  • linux下安装mysql客户端client

    1 下载mysql客户端 MySQL的Linux客户端官网下载地址 根据Linux的系统版本选择下载对应的rpm安装包 xff08 如下所示 xff09 xff0c 这里选择的是mysql8 0 27版本的redhat8系列的MySQL客户
  • linux下mysql的三种安装方法

    目录 1 离线安装 xff08 tar gz安装包 xff09 2 离线安装 xff08 rpm安装包 xff09 3 在线安装 xff08 yum安装 xff09 前言 安装环境 Redhat Enterprise Linux 8 1 离
  • linux+window+macos下的JDK安装

    1 Linux中安装JDK xff08 1 xff09 下载Linux版本的jdk压缩包 xff08 2 xff09 解压 tar zxvf 压缩包名 例如 xff1a tar zxvf jdk 8u251 linux x64 tar gz
  • bootstrap-table源码函数解读-sprintf

    var sprintf 61 function str var args 61 arguments flag 61 true i 61 1 str 61 str replace s g function var arg 61 args i
  • openGauss数据库的使用

    目录 前言1 启动 停止 重启数据库 xff08 1 xff09 极简版启动 停止 重启命令 xff08 2 xff09 企业版启动 停止 重启命令 2 登录数据库 xff08 1 xff09 登录数据库时的基本连接参数 xff08 2 x
  • openGauss数据库的安装(2.0.0极简版安装)

    目录 前言1 安装环境准备2 创建用户和用户组3 正式安装4 启动数据库实例并测试 前言 这里主要结合官网的文档 xff0c 安装系统环境是官网推荐的openEuler 20 03LTS openGauss数据库版本是openGauss 2
  • openGauss数据库安装(2.0.0企业版安装)

    目录 1 准备环境2 预安装3 正式安装4 启动并登录数据库 前言 此次数据库的系统安装环境仍然是openEuler20 03LTS openGauss安装版本是2 0 0版本 xff0c 相对于极简版安装 xff0c 确实多了一些工具 x
  • openEuler22.03安装

    目录 1 安装2 登录3 修改登录密码输错限制次数 1 安装 如果在此时没有设置网络 xff0c 那么需要在登录后可以编辑 etc sysconfig network scripts ifcfg ens160文件 xff0c 如下红框部分所
  • Linux查看日志常用命令

    第一种 xff1a 查看实时变化的日志 比较吃内存 最常用的 xff1a tail f app log 默认最后10行 xff0c 相当于增加参数 n 10 tail 200f app log 最后200行 xff0c 某一时刻往前推 Ct
  • ubuntu查看文件和文件夹大小

    在实际使用ubuntu时候 xff0c 经常要碰到需要查看文件以及文件夹大小的情况 有时候 xff0c 自己创建压缩文件 xff0c 可以使用 ls hl 查看文件大小 参数 h 表示Human Readable xff0c 使用GB MB
  • NLTK下载错误的终极解决办法

    Downloading package brown to C Users Ken AppData Roaming nltk data Error downloading 39 brown 39 from lt https raw githu
  • Tensorboard 不显示数据的问题

    No dashboards are active for the current data set Probable causes You haven 39 t written any data to your event files Te
  • Pytorch学习(2)——训练词向量的代码

    教程 xff1a https www bilibili com video BV1vz4y1R7Mm p 61 2 span class token keyword import span torch span class token ke