一、简单实例
import numpy as np
import torch
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
# Minimal LSTM shape demo: input_size=10, hidden_size=20, num_layers=2.
rnn = nn.LSTM(10, 20, 2)
# Plain tensors are passed directly; the Variable wrapper is deprecated and
# adds nothing here.  Layout: (seq_len=5, batch=3, input_size=10).
input = t.randn(5, 3, 10)
# Initial hidden and cell states: (num_layers=2, batch=3, hidden_size=20).
h0 = t.zeros(2, 3, 20)
c0 = t.zeros(2, 3, 20)
output, (hn, cn) = rnn(input, (h0, c0))
print('output--->', output.size())
print('hn--->', hn.size())
print('cn--->', cn.size())
注意：output 的形状为 (seq_len, batch, hidden_size)，与 LSTM 的层数无关，只随序列长度变化；而 hn 和 cn 的形状为 (num_layers, batch, hidden_size)，只与层数有关，与序列长度无关。
二、 PyTorch 做诗
2.1 数据预处理
原始数据:
由于原始数据为 .json 文件，且其中包含很多脏数据，需要先对其进行预处理。
原始数据来源:https://github.com/chinese-poetry/chinese-poetry
import sys
import os
import json
import re
def parseRawData(author=None, constrain=None, filePath=None):
    """Load and clean the poems stored in the ``poet.tang*`` JSON files.

    Args:
        author: when not None, keep only poems whose "author" field equals it.
        constrain: when not None, keep only poems in which every non-empty
            clause (split on commas, exclamation marks and full stops) has
            exactly this many characters.
        filePath: directory holding the JSON files; None means the current
            working directory.

    Returns:
        A list of strings, one cleaned poem per entry, in sorted file order.
    """
    # Clause separators: the full-width forms plus their ASCII look-alikes.
    clause_sep = re.compile("[，！。,!]")

    def sentenceParse(para):
        """Strip annotations, brackets, digits and dashes from one poem."""
        # The parentheses must be literal.  An unescaped "(.*)" is a capture
        # group that matches the whole string and would erase the poem.
        # Matching stays greedy: everything between the first opening and the
        # last closing mark is dropped.
        result = re.sub(r"（.*）|\(.*\)", "", para)
        result = re.sub(r"\{.*\}", "", result)
        result = re.sub(r"《.*》", "", result)
        result = re.sub(r"[\]\[]", "", result)
        result = "".join(ch for ch in result if ch not in "0123456789-")
        # Removing an annotation can leave two full stops side by side.
        return re.sub("。。", "。", result)

    def handleJson(file):
        """Return the cleaned poems found in one JSON file."""
        with open(file, encoding='utf8') as fh:
            poems = json.load(fh)
        rst = []
        for poetry in poems:
            if author is not None and poetry.get("author") != author:
                continue
            paragraphs = poetry.get("paragraphs") or []
            # Reject the poem as soon as one non-empty clause has the wrong
            # length.
            if constrain is not None and any(
                len(clause) not in (0, constrain)
                for para in paragraphs
                for clause in clause_sep.split(para)
            ):
                continue
            pdata = sentenceParse("".join(paragraphs))
            if pdata != "":
                rst.append(pdata)
        return rst

    base_dir = "." if filePath is None else filePath
    data = []
    # Sorted so the poem order does not depend on the file system.
    for filename in sorted(os.listdir(base_dir)):
        if filename.startswith("poet.tang"):
            data.extend(handleJson(os.path.join(base_dir, filename)))
    return data
def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='pre',
                  value=0.):
    """Pad or truncate every sequence to one length and stack them.

    Args:
        sequences: a sized collection of sized sequences of scalars.
        maxlen: target length; None means the longest sequence's length.
        dtype: dtype of the returned array.
        padding: 'pre' puts the fill values before the data, 'post' after.
        truncating: 'pre' drops items from the front of sequences longer
            than maxlen, 'post' drops them from the end.
        value: fill value for the padded positions.

    Returns:
        A numpy array of shape (len(sequences), maxlen).

    Raises:
        ValueError: if `sequences` or one of its items has no length, if
            `padding` / `truncating` is not 'pre' or 'post', or if an item
            is not a flat sequence of scalars.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    # Validate the options once, up front, so a bad value is reported even
    # when every sequence is empty.
    if truncating not in ('pre', 'post'):
        raise ValueError('Truncating type "%s" not understood' % truncating)
    if padding not in ('pre', 'post'):
        raise ValueError('Padding type "%s" not understood' % padding)

    lengths = []
    for seq in sequences:
        if not hasattr(seq, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(seq))
        lengths.append(len(seq))

    if maxlen is None:
        # default=0 keeps an empty `sequences` from raising.
        maxlen = max(lengths, default=0)

    sample_shape = tuple()
    padded = np.full((len(sequences), maxlen) + sample_shape, value,
                     dtype=dtype)
    if maxlen == 0:
        # Nothing to copy; also avoids `seq[-0:]` selecting the whole
        # sequence.
        return padded

    for idx, seq in enumerate(sequences):
        if not len(seq):
            continue  # an empty sequence stays all padding
        trunc = seq[-maxlen:] if truncating == 'pre' else seq[:maxlen]
        trunc = np.asarray(trunc, dtype=dtype)
        if trunc.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (trunc.shape[1:], idx, sample_shape))
        if padding == 'post':
            padded[idx, :len(trunc)] = trunc
        else:
            padded[idx, -len(trunc):] = trunc
    return padded
def get_data(opt):
    """Return the padded poem matrix and vocabulary, building them if needed.

    When ``opt.pickle_path`` exists it is loaded and returned as is.
    Otherwise the raw JSON under ``opt.data_path`` is parsed, encoded,
    padded to ``opt.maxlen`` and saved to ``opt.pickle_path``.

    Args:
        opt: configuration object providing ``pickle_path``, ``data_path``
            and ``maxlen``; ``author`` and ``constrain`` are forwarded to
            parseRawData when present.

    Returns:
        data: numpy array, one row per poem holding character indices.
        word2ix: dict mapping each character to its index.
        ix2word: dict mapping each integer index back to its character.
    """
    if os.path.exists(opt.pickle_path):
        # NOTE(review): allow_pickle=True unpickles the stored dicts, which
        # can run arbitrary code.  Only load cache files you created.
        with np.load(opt.pickle_path, allow_pickle=True) as cached:
            data = cached['data']
            word2ix = cached['word2ix'].item()
            ix2word = cached['ix2word'].item()
        return data, word2ix, ix2word

    poems = parseRawData(author=getattr(opt, 'author', None),
                         constrain=getattr(opt, 'constrain', None),
                         filePath=opt.data_path)

    # Sorted so the index assignment is the same on every run.
    words = sorted({ch for poem in poems for ch in poem})
    word2ix = {ch: ix for ix, ch in enumerate(words)}
    # Special tokens: end of poem, start of poem, padding (always last).
    for token in ('<EOP>', '<START>', '</s>'):
        word2ix[token] = len(word2ix)
    ix2word = {ix: ch for ch, ix in word2ix.items()}

    start_ix, end_ix = word2ix['<START>'], word2ix['<EOP>']
    encoded = [[start_ix] + [word2ix[ch] for ch in poem] + [end_ix]
               for poem in poems]

    pad_data = pad_sequences(encoded,
                             maxlen=opt.maxlen,
                             padding='pre',
                             truncating='post',
                             value=word2ix['</s>'])
    print(pad_data)

    np.savez_compressed(opt.pickle_path,
                        data=pad_data,
                        word2ix=word2ix,
                        ix2word=ix2word)
    return pad_data, word2ix, ix2word
涉及到的配置文件
class Config(object):
    """Paths and hyper-parameters shared by the poem-generation code."""

    data_path = 'data/'  # directory scanned for the raw poem JSON files
    pickle_path = 'data/tang.npz'  # cache of the preprocessed data
    author = None  # author filter; None keeps every author
    constrain = None  # required clause length; None disables the check
    lr = 1e-3  # learning rate given to the Adam optimizer
    weight_decay = 1e-4  # NOTE(review): not read by the code shown here
    # Use the file's `t` alias for torch; prefer the GPU when one is present.
    DEVICE = t.device('cuda' if t.cuda.is_available() else 'cpu')
    epoch = 20  # presumably the number of training epochs - confirm at call site
    batch_size = 128  # DataLoader batch size
    maxlen = 125  # every poem is padded / truncated to this many tokens
    seq_len = 48  # length of one training sample
    plot_every = 20  # NOTE(review): not read by the code shown here
    env = 'poetry'  # NOTE(review): not read by the code shown here
    max_gen_len = 200  # NOTE(review): not read by the code shown here
    debug_file = '/tmp/debugp'  # NOTE(review): not read by the code shown here
    model_path = None  # checkpoint loaded before training when set
    prefix_words = '细雨鱼儿出,微风燕子斜。'  # NOTE(review): not read by the code shown here
    start_words = '闲云潭影日悠悠'  # NOTE(review): not read by the code shown here
    acrostic = False  # NOTE(review): not read by the code shown here
    model_prefix = 'checkpoints/tang'  # checkpoint files are '<prefix>_<epoch>.pth'
    num_layers = 2  # number of stacked LSTM layers


opt = Config()
2.2 构建数据集
class PoemDataSet(Dataset):
    """Fixed-length next-token samples cut from the poem corpus.

    The padded poem matrix is flattened, padding tokens are removed, and the
    remaining token stream is cut into consecutive windows of ``seq_len``.
    Each sample is ``(inputs, labels)``, where ``labels`` is ``inputs``
    shifted one position to the right.
    """

    def __init__(self, opt):
        """Load the data described by `opt` and strip the padding."""
        self.seq_len = opt.seq_len
        self.opt = opt
        self.poem_data, self.word2ix, self.ix2word = self.get_raw_data()
        self.no_space_data = self.filter_space()

    def __getitem__(self, item):
        """Return the `item`-th (inputs, labels) pair as two tensors."""
        # Raising IndexError lets plain `for x in dataset` terminate.
        if not 0 <= item < len(self):
            raise IndexError(item)
        start = item * self.seq_len
        stop = start + self.seq_len
        inputs = t.from_numpy(np.array(self.no_space_data[start:stop]))
        labels = t.from_numpy(np.array(self.no_space_data[start + 1:stop + 1]))
        return inputs, labels

    def __len__(self):
        """Number of complete samples.

        One token is held back so the last sample's shifted label window is
        still `seq_len` long.
        """
        return max((len(self.no_space_data) - 1) // self.seq_len, 0)

    def filter_space(self):
        """Return the flattened token stream without padding, as an array."""
        flat = np.asarray(self.poem_data).reshape(-1)
        # Look the padding index up instead of hard-coding it, so a different
        # vocabulary size still works.
        return flat[flat != self.word2ix['</s>']]

    def get_raw_data(self):
        """Return (padded data, word2ix, ix2word) from get_data."""
        return get_data(self.opt)
# Build the dataset and show its first sample, both as indices and as text.
dataset = PoemDataSet(opt)
print(len(dataset))
data = next(iter(dataset))
for title, tensor in zip(('inputs', 'labels'), data):
    print('*' * 15 + title + '*' * 15)
    print(tensor)
    print([dataset.ix2word[d.item()] for d in tensor])
2.3 定义模型
class PoetryModel(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear.

    Args:
        vocab_size: number of distinct tokens.
        embedding_dim: size of each token embedding.
        hidden_dim: size of the LSTM hidden state.
        num_layers: number of stacked LSTM layers; None (the default) reads
            ``opt.num_layers``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers=None):
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        # Stored on the instance so forward() builds states that match the
        # LSTM even if the global config changes later.
        self.num_layers = opt.num_layers if num_layers is None else num_layers
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim,
                            num_layers=self.num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        """Score the next token at every position.

        Args:
            input: integer tensor of token indices, shape (batch, seq_len).
            hidden: optional ``(h_0, c_0)`` pair; zero states are used when
                it is None.

        Returns:
            output: tensor of shape (batch * seq_len, vocab_size).
            hidden: the ``(h_n, c_n)`` pair returned by the LSTM.
        """
        batch_size, seq_len = input.size()
        embeds = self.embeddings(input)
        if hidden is None:
            # new_zeros inherits dtype and device from the embeddings.
            h_0 = embeds.new_zeros(self.num_layers, batch_size, self.hidden_dim)
            c_0 = embeds.new_zeros(self.num_layers, batch_size, self.hidden_dim)
        else:
            h_0, c_0 = hidden
        output, hidden = self.lstm(embeds, (h_0, c_0))
        # Merge batch and time so each row is one prediction.
        output = self.linear(output.reshape(batch_size * seq_len, -1))
        return output, hidden
2.4 训练
# Vocabulary size comes from the dataset; 128-d embeddings, 256-d LSTM state.
model = PoetryModel(len(dataset.word2ix), 128, 256)
if opt.model_path:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(t.load(opt.model_path, map_location=opt.DEVICE))
model.to(opt.DEVICE)
loss = nn.CrossEntropyLoss()
# NOTE(review): Config.weight_decay is defined but not passed to Adam here;
# confirm whether that is intended.
optim = t.optim.Adam(model.parameters(), lr=opt.lr)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True,
                        num_workers=0)
def train(model, dataloader, ix2word, word2ix, device, optim, loss, epoch,
          model_prefix=None, log_every=200):
    """Train `model` for `epoch` epochs, saving a checkpoint after each one.

    Args:
        model: module returning ``(logits, hidden)`` for a batch of inputs.
        dataloader: iterable of ``(inputs, targets)`` batches.
        ix2word: unused; kept for interface compatibility.
        word2ix: unused; kept for interface compatibility.
        device: device the batches are moved to.
        optim: optimizer updating the model's parameters.
        loss: criterion called as ``loss(logits, targets.view(-1))``.
        epoch: number of epochs to run.
        model_prefix: checkpoint prefix; files are written as
            ``'<prefix>_<epoch index>.pth'``.  None reads ``opt.model_prefix``.
        log_every: print progress every this many batches; 0 disables it.

    Returns:
        A list with the average per-sample loss of each epoch.
    """
    if model_prefix is None:
        model_prefix = opt.model_prefix
    ckpt_dir = os.path.dirname(model_prefix)
    if ckpt_dir:
        # t.save does not create missing directories.
        os.makedirs(ckpt_dir, exist_ok=True)

    model.train()
    train_losses = []
    for epoch_idx in range(epoch):
        # Reset per epoch so the reported average does not carry over.
        loss_sum = 0.0
        sample_count = 0
        for batch_idx, data in enumerate(dataloader):
            optim.zero_grad()
            inputs = data[0].long().to(device)
            targets = data[1].long().to(device)
            outputs, _ = model(inputs)
            los = loss(outputs, targets.view(-1))
            los.backward()
            optim.step()

            # Weight by the real batch size so a short final batch is
            # averaged correctly.
            loss_sum += los.item() * inputs.size(0)
            sample_count += inputs.size(0)

            if log_every and (batch_idx + 1) % log_every == 0:
                print('train epoch: {} [{}/{} ({:.0f}%)]\tloss: {:.6f}'.format(
                    epoch_idx, batch_idx * len(data[0]),
                    len(dataloader.dataset),
                    100. * batch_idx / len(dataloader), los.item()))

        train_loss = loss_sum / max(sample_count, 1)
        print('\ntrain epoch: {}\t average loss: {:.6f}\n'.format(epoch_idx, train_loss))
        train_losses.append(train_loss)
        # One checkpoint per epoch, named by the zero-based epoch index.
        t.save(model.state_dict(), '%s_%s.pth' % (model_prefix, epoch_idx))
    return train_losses
2.5 测试
# Restore the weights saved for epoch index 17.  Uses the file's `t` alias for
# torch; map_location keeps the load working when the checkpoint was written
# on a different device.
model.load_state_dict(t.load(opt.model_prefix + '_17.pth', map_location=opt.DEVICE))
def generate(model, start_words, ix2word, word2ix, max_gen_len,
             prefix_words=None, device=None):
    """Generate an acrostic poem whose clauses start with `start_words`.

    The first character of `start_words` opens the poem.  After every clause
    separator the next unused start character is forced in; all other
    characters are the model's top prediction.

    Args:
        model: module called as ``model(input, hidden)`` and returning
            ``(logits, hidden)``.
        start_words: string (or sequence) of clause-opening characters.
        ix2word: dict mapping index to character.
        word2ix: dict mapping character to index; must contain '<START>'.
        max_gen_len: maximum number of characters to produce.
        prefix_words: optional text fed through the model first, only to
            warm up the hidden state; it is not part of the result.
        device: device to run on; None reads ``opt.DEVICE``.

    Returns:
        The generated characters as a list.  Generation stops when the
        length limit is reached, when '<EOP>' is produced, or when a clause
        ends after every start character has been used.

    Raises:
        KeyError: if a start or prefix character is not in `word2ix`.
    """
    start_words = list(start_words)
    if not start_words:
        return []
    if device is None:
        device = opt.DEVICE
    start_word_len = len(start_words)
    # Full-width separators plus their ASCII look-alikes.
    clause_end = {'。', '！', '，', '!', ','}

    def as_input(word):
        """Wrap one character's index as a (1, 1) long tensor."""
        return t.tensor([[word2ix[word]]], dtype=t.long, device=device)

    model = model.to(device)
    model.eval()

    results = []
    cur = as_input('<START>')
    hidden = None
    index = 0       # number of start characters already used
    pred_word = ''  # character emitted on the previous step

    with t.no_grad():  # inference only; no autograd bookkeeping needed
        if prefix_words:
            for word in prefix_words:
                _, hidden = model(cur, hidden)
                cur = as_input(word)

        for i in range(max_gen_len):
            output, hidden = model(cur, hidden)
            w = ix2word[output[0].argmax().item()]

            if i == 0 or pred_word in clause_end:
                if index == start_word_len:
                    break  # every start character has been used
                # Override the prediction with the next start character.
                w = start_words[index]
                index += 1

            if w == '<EOP>':
                break
            cur = as_input(w)
            results.append(w)
            pred_word = w
    return results
# Generate at most 48 characters of an acrostic poem from the trained model.
result = generate(
    model,
    '深度学习',
    dataset.ix2word,
    dataset.word2ix,
    max_gen_len=48,
)
print(result)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)