Pytorch 实战RNN

2023-05-16

一、简单实例

# coding:utf8
import torch as t
from torch import nn
from torch.autograd import Variable

#输入词用10维词向量表示
#隐藏层用20维向量表示
#两层LSTM
rnn =nn.LSTM(10,20,2)

# 输入每句话有5个词
# 每个词由10维的词向量表示
#总共有3句话(batch size)
input =Variable(t.randn(5,3,10))

# 隐藏元的初始值(num_layers, batch_size,hidden_size)
h0 = Variable(t.zeros(2,3,20))
c0 = Variable(t.zeros(2,3,20))

# output 是最后一层所有所有隐藏元的值
# hn 和cn 是所有层(这里是两层)的最后一个隐藏元的值
output,(hn,cn) =rnn(input,(h0,c0))

print('output--->',output.size())
print('hn--->',hn.size())
print('cn--->',cn.size())

结果
注意: output的形状与LSTM的层数无关,只和序列长度有关,而hn和cn则只和层数有关,和序列长度无关。

二、 Pytorcn 做诗

2.1 数据预处理

原始数据:
在这里插入图片描述
由于原始数据为.json文件,前包含很多的脏数据在其中,需要对其进行预处理
原始数据来源:https://github.com/chinese-poetry/chinese-poetry

#coding:utf-8
import sys
import os
import json
import re

def parseRawData(author = None, constrain = None,filePath=None):
 
    def sentenceParse(para):
        # para = "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。)好是去塵俗,煙花長一欄。"
        #subn统计sub替换次数
        #清洗数据,去掉诗词中()、《》、{}中的内容
        result, number = re.subn("(.*)", "", para)
        result, number = re.subn("(.*)", "", para)
        result, number = re.subn("{.*}", "", result)
        result, number = re.subn("《.*》", "", result)
        result, number = re.subn("《.*》", "", result)
        result, number = re.subn("[\]\[]", "", result)
        r = ""
        for s in result:
            
            if s not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-']:
                r += s;
        r, number = re.subn("。。", "。", r)
        return r

    def handleJson(file):
        """
        获取json文件中的诗句
        """
        rst = []
        data = json.loads(open(file,encoding='utf8').read())
        for poetry in data:
            pdata = ""
            if (author!=None and poetry.get("author")!=author):
                continue
            p = poetry.get("paragraphs")
            flag = False
            for s in p:
                sp = re.split("[,!。]", s)
                for tr in sp:
                    if constrain != None and len(tr) != constrain and len(tr)!=0:
                        flag = True
                        break
                    if flag:
                        break
            if flag:
                continue
            for sentence in poetry.get("paragraphs"):
                pdata += sentence
            pdata = sentenceParse(pdata)
            if pdata!="":
                rst.append(pdata)
        return rst
    # print sentenceParse("")
    data = []
    for filename in os.listdir(filePath):
        if filename.startswith("poet.tang"):
            data.extend(handleJson(src+filename))
    return data

def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
 
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for x in sequences:
        if not hasattr(x, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(x))
        #存放每首诗的长度
        lengths.append(len(x))

    #诗的数量
    num_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.
    sample_shape = tuple()
    
    #for s in sequences:
        
#         if len(s) > 0:  # pylint: disable=g-explicit-length-test
#             sample_shape = np.asarray(s).shape[1:]
#             break

    #构造大小为(诗的数量,诗最长长度)的元组
    x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
    #print(x.shape)
    for idx, s in enumerate(sequences):
        if not len(s):  # pylint: disable=g-explicit-length-test
            continue  # empty list/array was found
        
         #截取方式:从后边截取还是从前边截取
        if truncating == 'pre':
            trunc = s[-maxlen:]  # pylint: disable=invalid-unary-operand-type
            
        elif truncating == 'post':
            trunc = s[:maxlen]
            #print(trunc)
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # check `trunc` has expected shape
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (trunc.shape[1:], idx, sample_shape))

        #填充方式,从后边填充还是从前边填充
        if padding == 'post':
            x[idx, :len(trunc)] = trunc
        elif padding == 'pre':
            x[idx, -len(trunc):] = trunc
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return x

def get_data(opt):
    """
    @param opt 配置选项 Config对象
    @return word2ix: dict,每个字对应的序号,形如u'月'->100
    @return ix2word: dict,每个序号对应的字,形如'100'->u'月'
    @return data: numpy数组,每一行是一首诗对应的字的下标
    """
    if os.path.exists(opt.pickle_path):
        data = np.load(opt.pickle_path, allow_pickle=True)
        data, word2ix, ix2word = data['data'], data['word2ix'].item(), data['ix2word'].item()
        return data, word2ix, ix2word

    # 如果没有处理好的二进制文件,则处理原始的json文件
    data = parseRawData(filePath=opt.data_path)
    words = {_word for _sentence in data for _word in _sentence}
    word2ix = {_word: _ix for _ix, _word in enumerate(words)}
    word2ix['<EOP>'] = len(word2ix)  # 终止标识符
    word2ix['<START>'] = len(word2ix)  # 起始标识符
    word2ix['</s>'] = len(word2ix)  # 空格
    ix2word = {_ix: _word for _word, _ix in list(word2ix.items())}

    # 为每首诗歌加上起始符和终止符
    for i in range(len(data)):
        data[i] = ["<START>"] + list(data[i]) + ["<EOP>"]
    # 将每首诗歌保存的内容由‘字’变成‘数’
    # 形如[春,江,花,月,夜]变成[1,2,3,4,5]
    new_data = [[word2ix[_word] for _word in _sentence]
                for _sentence in data]
    
    # 诗歌长度不够opt.maxlen的在前面补空格,超过的,删除末尾的
    pad_data = pad_sequences(new_data,
                             maxlen=opt.maxlen,
                             padding='pre',
                             truncating='post',
                             value=len(word2ix) - 1)

    # 保存成二进制文件
    print(pad_data)
    np.savez_compressed(opt.pickle_path,
                        data=pad_data,
                        word2ix=word2ix,
                        ix2word=ix2word)
    return pad_data, word2ix, ix2word

涉及到的配置文件

class Config(object):
 data_path = 'data/'  # 诗歌的文本文件存放路径
 pickle_path = 'data/tang.npz'  # 预处理好的二进制文件
 author = None  # 只学习某位作者的诗歌
 constrain = None  # 长度限制
 lr = 1e-3
 weight_decay = 1e-4
 DEVICE = torch.device('cuda'if torch.cuda.is_available() else 'cpu')
 epoch = 20
 batch_size = 128
 maxlen = 125  # 超过这个长度的之后字被丢弃,小于这个长度的在前面补空格
 seq_len= 48 #由于唐诗主要是五言绝句和七言绝句,各自加上一个标点符号为 6和8,选择一个公约数48,这样刚好凑够48
 plot_every = 20  # 每20个batch 可视化一次
 # use_env = True # 是否使用visodm
 env = 'poetry'  # visdom env
 max_gen_len = 200  # 生成诗歌最长长度
 debug_file = '/tmp/debugp'
 model_path = None  # 预训练模型路径
 prefix_words = '细雨鱼儿出,微风燕子斜。'  # 不是诗歌的组成部分,用来控制生成诗歌的意境
 start_words = '闲云潭影日悠悠'  # 诗歌开始
 acrostic = False  # 是否是藏头诗
 model_prefix = 'checkpoints/tang'  # 模型保存路径
 num_layers=2

opt = Config()

2.2 构建数据集

class PoemDataSet(Dataset):
 
 def __init__(self,opt):
     self.seq_len =opt.seq_len #诗的长度
     self.opt =opt #配置
     self.poem_data,self.word2ix,self.ix2word=self.get_raw_data() #数据集
     self.no_space_data = self.filter_space()
     
 def __getitem__(self,item): #获取每一行数据
     
     #每首诗的下一个字为上一个字的目标值
     inputs = self.no_space_data[item*self.seq_len:(item+1)*self.seq_len]
     labels = self.no_space_data[item*self.seq_len+1:(item+1)*self.seq_len+1]
     inputs = t.from_numpy(np.array(inputs))
     labels= t.from_numpy(np.array(labels))
     return inputs,labels
 
 def __len__(self): 
     return int(len(self.no_space_data)/self.seq_len)
 
 def filter_space(self):#将空格的数据过滤掉,并将原始数据平整到一维
     t_data =t.from_numpy(self.poem_data).view(-1)
     flat_data =t_data.numpy()
     no_space_data =[]
     for i in flat_data:
         if (i!=8292):
             no_space_data.append(i)
     return no_space_data
 
 def get_raw_data(self): #获取数据
     pad_data,word2ix,ix2word=get_data(self.opt)
     return pad_data,word2ix,ix2word
#获取数据
dataset=PoemDataSet(opt)
print(len(dataset))
data= next(iter(dataset))

print('*'*15+'inputs'+'*'*15)
print(data[0])
print([dataset.ix2word[d.item()] for d in data[0]])
print('*'*15+'labels'+'*'*15)
print(data[1])
print([dataset.ix2word[d.item()] for d in data[1]])

在这里插入图片描述

2.3 定义模型

class PoetryModel(nn.Module):
    def __init__(self,vocab_size,embedding_dim,hidden_dim):
        super(PoetryModel,self).__init__()
        self.hidden_dim=hidden_dim
        # 词向量层,词表大小 * 向量维度
        self.embeddings=nn.Embedding(vocab_size,embedding_dim)
        # 网络主要结构
        self.lstm = nn.LSTM(embedding_dim,hidden_dim,num_layers=opt.num_layers,batch_first=True)
        # 进行分类
        self.linear =nn.Linear(hidden_dim,vocab_size)
        
    
    def forward(self,input,hidden=None):
        batch_size,seq_len = input.size()
        
        if hidden is None:
            #h_0: (num_layers*bidirectional,batch_size,hidden_dim)
            #c_0: (num_layers*bidirectional,batch_size,hidden_dim)
            h_0 = input.data.new(opt.num_layers,batch_size,self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(opt.num_layers,batch_size,self.hidden_dim).fill_(0).float()
        else:
            h_0,c_0 = hidden
            
        # 输入 序列长度 * batch(每个汉字是一个数字下标),
        # 输出 序列长度 * batch * 向量维度
        embeds = self.embeddings(input)
        
        # output=batch_size * seq_len *  (num_directions*hidden_size)   
        # hn =(num_layers*num_directions) * batch_size * hidden_size
        output,hidden =self.lstm(embeds,(h_0,c_0))
        #print(output.size())
        output = self.linear(output.contiguous().view(batch_size*seq_len, -1))
        return output,hidden

2.4 训练

#模型定义
model =PoetryModel(len(dataset.word2ix),128,256)

if opt.model_path:
    model.load_state_dict(t.load(opt.model_path))
model.to(opt.DEVICE)

#损失函数
loss = nn.CrossEntropyLoss()
# loss = nn.BCELoss()  #二分类交叉熵损失函数
# loss = nn.BCEWithLogitsLoss() #二分类交叉熵损失函数 带log loss
# loss = nn.MSELoss()

#优化器
optim = t.optim.Adam(model.parameters(),lr=opt.lr)
#也可以选择SGD优化方法
# optimizer = torch.optim.SGD(model.parameters(),lr=1e-2)
 
#scheduler =StepLR(optim,step_size=10)

#获取数据
dataloader =DataLoader(dataset,batch_size=opt.batch_size,shuffle=True,num_workers=0)
#模型训练
def train(model,dataloader,ix2word,word2ix,device,optim,loss,epoch):
    
    model.train()
    train_loss = 0.0
    train_losses = []
    for epoch in range(epoch):
        for batch_idx, data in enumerate(dataloader):

            #如果模型没有设置 batch_first=True(将批次维度放到第一位)的话,需要将0维和1维互换位置。
            #transpose()函数的作用就是调换数组的行列值的索引值,类似于求矩阵的转置:
            # data = data.long().transpose(1,0).contiguous()
            # input, target = data[:-1, :], data[1:, :]
            optim.zero_grad()
            #inputs: batch_size * seq_len
            #targets: batch_size * seq_len
            inputs,targets =data[0].long().to(device),data[1].long().to(device)

            #模型训练
            outputs,_ = model(inputs)
            #print(output)
            los = loss(outputs,targets.view(-1))
            los.backward()
            optim.step()
            train_loss +=los.item()

            if (batch_idx+1)%200 ==0:
                print('train epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format(epoch, batch_idx * len(data[0]), len(dataloader.dataset),100. * batch_idx / len(dataloader), los.item()))

        train_loss *= opt.batch_size
        train_loss /=len(dataloader.dataset)
        print('\ntrain epoch: {}\t average loss: {:.6f}\n'.format(epoch,train_loss))
        train_losses.append(train_loss)
    
        #保存模型
        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))

2.5 测试

model.load_state_dict(torch.load(opt.model_prefix+'_17.pth'))  # 模型加载

def generate(model,start_words,ix2word,word2ix,max_gen_len,prefix_words=None):
    #读取唐诗的第一句
    start_words =list(start_words)
    start_word_len = len(start_words)
    results=[]
    
    #设置第一个词为<START>
    input = torch.Tensor([word2ix['<START>']]).view(1,1).long()
    input =input.to(opt.DEVICE)
    model = model.to(opt.DEVICE)
    model.eval()
    
    hidden =None
    index =0 #指示生成了多少句
    pred_word ='' #上一个词
    
    #控制意境
    if prefix_words:
        for word in prefix_words:
            output,hidden = model(input,hidden)
            input = Variable(input.data.new([word2ix[word]])).view(1,1)
            
    #生成藏头诗
    for i in range(max_gen_len): #诗的长度
        output,hidden = model(input,hidden)
        #获取生成字的index
        top_index = output.data[0].topk(1)[1][0].item()
        w = ix2word[top_index]
        #诗的第一个字必须是输入关键词的第一个
        if i==0:
            w= start_words[index]
            index+=1
            input = input.data.new([word2ix[w]]).view(1,1) 
        #如果遇到标志一句话的结尾,喂入下一个“头”
        if pred_word in {'。','!',','}:
            #如果生成的诗已经包含全部“头”,则介绍
            if index == start_word_len:
                break
            #把‘头’作为输入喂入模型
            else:
                w= start_words[index]
                index +=1
                input = input.data.new([word2ix[w]]).view(1,1)
        
        #否则,把上一次预测作为下一个词的输入
        else: 
             input = input.data.new([word2ix[w]]).view(1,1)
        results.append(w)
        pred_word=w
        # 结束标志
        if w =='<EOP>':
            del results[-1]
            break
    return results
result=generate(model,'深度学习',dataset.ix2word,dataset.word2ix,48)
print(result)

在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch 实战RNN 的相关文章

随机推荐

  • ftp命令大全详解

    来熟悉熟悉ftp命令 xff0c 对于服务器之间的文件传输太有用啦 xff0c 不会怎么能行呢 xff01 先来看看基础的命令 xff0c 包括了连接 xff0c 列出列表 xff0c 下载 xff0c 上传 xff0c 断开这最基础的命令
  • TCP/IP,Linux中使用信号量控制运行中的进程,使用signal函数绑定信号量和处理函数,替换信号量默认功能,信号量会打断sleep的休眠状态

    TCP IP xff0c Linux中使用信号量控制运行中的进程 xff0c 绑定信号量和处理函数 xff0c 信号量会打断sleep的休眠状态 一 Linux中使用信号量对进程的调控 xff1a 1 信号量是一个int值 xff0c 由操
  • 几种经典非线性滤波算法简单概括(EKF,UKF,CKF,PF)

    几种经典非线性滤波算法概括 xff08 EKF xff0c UKF xff0c CKF xff0c PF xff09 上一篇文章阐述了Kalman滤波算法 xff0c 该算法是在线性高斯下的最优滤波估计算法 但是在实际控制系统中 xff0c
  • 扩展卡尔曼滤波(EKF)算法详细推导及仿真(Matlab)

    扩展卡尔曼滤波 xff08 EKF xff09 算法详细推导及仿真 xff08 Matlab xff09 扩展卡尔曼滤波算法是解决非线性状态估计问题最为直接的一种处理方法 xff0c 尽管EKF不是最精确的 最优 滤波器 xff0c 但在过
  • uio驱动编写 实例1

    AUTHOR xff1a Joseph Yang 杨红刚 lt eagle rtlinux 64 gmail com gt CONTENT uio驱动编写 实例1 NOTE xff1a linux 3 0 LAST MODIFIED xff
  • raspberry pi pico, 如何在macos平台使用picoprobe,vscode来debug程序

    debugprobe 80元人民币 再买一块pico 刷上debug程序 xff0c 仅要16元 xff0c 当然用便宜的 在mac上的vs code总是遇见问题 单独运行openocd时 xff0c 也有问题 xff0c 出现错误 CMS
  • 1—类、域、方法和实例对象

    Java 是面向对象的高级编程语言 xff0c 类和对象是 Java 程序的构成核心 围绕着 Java 类和 Java 对象 xff0c 有三大基本特性 xff1a 封装是 Java 类的编写规范 继承是类与类之间联系的一种形式 而多态为系
  • 常用数学公式汇总

    常用数学公式汇总 一 基础代数公式 1 平方差公式 xff1a xff08 a xff0b b xff09 xff08 a xff0d b xff09 xff1d a2 xff0d b2 2 完全平方公式 xff1a xff08 a b x
  • Kubernetes--API Server资源隔离

    Kubernetes的一些功能特性也与公有云提供商密切相关 xff0c 例如 xff1a 负载均衡服务 弹性公网IP 存储服务等 xff0c 具体实现也需要与API Server通信 xff0c 也属于运行商内部重点保障的安全区域 此外 x
  • 公式提取方法

    Mathpix Snipping Tool和MathType配合用法 Mathpix Snipping Tool是一个可以提取数学公式的工具 xff0c 当我们写毕业论文或者结课报告或者参加数学建模等比赛的用到的公式 xff0c 可以用这款
  • (学习unix编程)关于文件流与文件描述符的区别

    文件描述符 xff08 就是整数 xff09 用于在一个进程内唯一的标识打开的文件 这假定了内核能够在用户进程的描述符和内核内部使用的机构之间 xff0c 建立一种关联 xff08 深入linux内核架构 xff09 由于唯一标识进程的结构
  • 2000页kubernetes操作手册,内容详细代码清晰,小白也能看懂

    现如今 xff0c Kubernetes业务已成长为新时代的IT基础设施 xff0c 并成为高级运维工程师 架构师 后端开发工程师的必修技术栈 毫无疑问 xff0c Kubernetes是云计算发展演进的一次彻底革命性的突破 xff0c 只
  • Linux安装nodejs和npm

    最近window系统转向linux系统开发 xff0c linux系统的确适合程序员的开发 作为前端安装了nodejs和npm xff0c 遇到了一些坑 xff0c 赶紧记录下来 第一种安装方法 xff1a 安装nodejs xff1a s
  • 查看core dumped的详细错误原因

    什么是Core Dump Core的意思是内存 Dump的意思是扔出来 堆出来 开发和使用Unix程序时 有时程序莫名其妙的down了 却没有任何的提示 有时候会提示core dumped 这时候可以查看一下有没有形如core 进程号的文件
  • IntelliJ IDEA创建Servlet最新方法 Idea版本2020.2.2以及IntelliJ IDEA创建Servlet 404问题(超详细)

    第一次用IntelliJ IDEA写java代码 xff0c 之前都是用eclipse xff0c 但eclipse太老了 下面为兄弟们奉上IntelliJ IDEA创建Servlet方法 xff0c 写这个的目的也是因为在网上找了很多资料
  • Linux下做C语言/C++开发的一些建议

    相对于Linux下的C C 43 43 开发 xff0c 在windows下的初学者往往容易入门 xff0c 原因是visual studio 这个强大的工具隐藏了很多的细节 xff0c 好多人甚至以为拖拖控件 xff0c 写写消息响应函数
  • Target ‘STM32F4xx‘ uses ARM-Compiler ‘Default Compiler Version 5‘ which is not available.找不到v5版本解决方法

    现在官网上没有v5的版本了 xff0c keil默认安装的是v6的版本 xff0c 如果工程想要运行以前的工程 xff0c 可以设置将工程的编辑器从v5转到v6 xff0c 下面是方法 xff1a 1 使用MDK打开工程 2 选择 Proj
  • 关于imu的介绍

    1 imu时惯性运动丹云 xff0c 包含加速度计和陀螺传感器的组合 它被用来检查加速度和角速度 xff08 IMU传感器 xff0c 你所需要知道的全部 知乎 xff09 虽然时外文翻译的 xff0c 凡是整体风格清晰 2 imu的使用
  • LSTM与GRU

    LSTM 与 GRU 一 综述 LSTM 与 GRU是RNN的变种 xff0c 由于RNN存在梯度消失或梯度爆炸的问题 xff0c 所以RNN很难将信息从较早的时间步传送到后面的时间步 LSTM和GRU引入门 xff08 gate xff0
  • Pytorch 实战RNN

    一 简单实例 span class token comment coding utf8 span span class token keyword import span torch span class token keyword as