首先准备好一个本地语料文件（UTF-8 编码），这里命名为“唐诗三百首.txt”，内容如下图所示（原图片链接在转载时被截断：https://img-blog.csdnimg.）。
图片:
## 代码如下
import numpy as np, os
from collections import Counter
from warnings import filterwarnings
filterwarnings('ignore')
from keras.utils import to_categorical
from keras.models import Sequential, load_model
from keras.layers import Conv1D, MaxPool1D, GlobalMaxPool1D, Dense,Flatten
# Hyper-parameters and file paths (each defined exactly once; `window` used to
# be assigned twice, so editing only one copy had no effect).
corpus_path = '唐诗三百首.txt'  # UTF-8 text corpus, read as one character stream
len_chr = 1000        # vocabulary size: keep the len_chr most frequent characters
window = 24           # number of context characters fed to the model
filters = 20          # base number of Conv1D filters
kernel_size = 5       # base Conv1D kernel width
times = 4             # NOTE(review): not referenced anywhere in the visible code
batch_size = 250      # training mini-batch size (only effective if passed to model.fit)
epochs = 2            # training epochs (only effective if passed to model.fit)
filepath = 'model.hdf5'  # where the trained model is saved / loaded
# Read the whole corpus as one continuous character stream (newlines dropped).
with open(corpus_path, encoding='utf-8') as f:
    seq_chr = f.read().replace('\n', '')
len_seq = len(seq_chr)

# Vocabulary: the len_chr most frequent characters, most frequent first.
chr_ls = [c for c, _ in Counter(seq_chr).most_common(len_chr)]
chr2id = {c: i for i, c in enumerate(chr_ls)}
id2chr = {i: c for c, i in chr2id.items()}


def c2i(c):
    """Return the id of character *c*, or a random in-vocabulary id if unknown.

    The random id is bounded by the real vocabulary size, which can be smaller
    than len_chr for a small corpus, so it always decodes through id2chr.
    """
    if c in chr2id:
        return chr2id[c]
    return np.random.randint(len(chr_ls))


# Map the corpus to ids. Going through c2i (instead of chr2id[c]) avoids a
# KeyError on rare characters that did not make the top-len_chr vocabulary.
seq_id = [c2i(c) for c in seq_chr]
构造输入 x 与输出 y：
def reshape(x):
    """Convert one or more id windows into model input.

    The ids are reshaped to (-1, window, 1) and divided by len_chr.
    """
    return np.reshape(x, (-1, window, 1)) / len_chr


# Sliding windows over the id stream: every run of `window` ids is an input
# sample, and the id immediately after it is the one-hot encoded target.
x = reshape([seq_id[start:start + window] for start in range(len_seq - window)])
y = to_categorical([seq_id[start + window] for start in range(len_seq - window)],
                   num_classes=len_chr)
模型:
def CNNmodel():
    """Build and compile the character-level 1-D CNN.

    Layers: Conv1D(filters, 2*kernel_size) -> MaxPool1D ->
    Conv1D(2*filters, kernel_size) -> Flatten -> Dense(len_chr, softmax).
    Compiled with Adam and categorical cross-entropy, reporting accuracy.
    No input shape is declared here.
    """
    stack = [
        Conv1D(filters, kernel_size * 2, padding='same', activation='relu'),
        MaxPool1D(),
        Conv1D(filters * 2, kernel_size, padding='same', activation='relu'),
        Flatten(),
        Dense(len_chr, activation='softmax'),
    ]
    net = Sequential(stack)
    net.compile(optimizer='adam', loss='categorical_crossentropy',
                metrics=['accuracy'])
    return net
# Train with the configured batch size and epoch count (previously fit() fell
# back to Keras' defaults and ignored both settings), then save to disk.
model = CNNmodel()
model.fit(x, y, batch_size=batch_size, epochs=epochs)
model.save(filepath)
# Random (temperature) sampling:
def draw_sample(predictions, temperature):
    """Sample an index from a probability vector re-weighted by *temperature*.

    Each probability p becomes p**(1/temperature) and the result is
    renormalised. Temperatures below 1 sharpen the distribution toward the
    arg-max; temperatures above 1 flatten it.

    Args:
        predictions: 1-D array-like of non-negative class probabilities.
        temperature: positive float.

    Returns:
        The sampled index as an int.
    """
    pred = np.asarray(predictions, dtype='float64')
    with np.errstate(divide='ignore'):  # log(0) -> -inf is fine: exp(-inf) == 0
        logits = np.log(pred) / temperature
    # Subtract the maximum before exponentiating. Otherwise a small temperature
    # underflows every weight to 0, and the normalisation below yields NaN.
    weights = np.exp(logits - np.max(logits))
    pred = weights / np.sum(weights)
    return int(np.argmax(np.random.multinomial(1, pred, 1)))
预测函数:
def predict(t, pred):
    """Generate `window` characters after the seed ids *pred*, then print them.

    Args:
        t: sampling temperature. A falsy value (None/0) selects greedy arg-max
            decoding; otherwise draw_sample is used with temperature t.
        pred: list of character ids used as the seed context. It is copied,
            so the caller's list is left unchanged and can be reused as the
            seed for another call.

    Returns:
        The generated text (the last `window` ids decoded), which is also printed.
    """
    if t:
        print('随机采样,温度:%.1f' % t)
    else:
        print('贪婪采样')
    # Work on a copy. Appending to the caller's list made every later call in a
    # loop start from the previous call's output instead of the original seed.
    pred = list(pred)
    for _ in range(window):
        x_pred = reshape(pred[-window:])
        # verbose=0 hides Keras' progress bar, which otherwise appears per step.
        y_pred = model.predict(x_pred, verbose=0)[0]
        i = draw_sample(y_pred, t) if t else int(np.argmax(y_pred))
        pred.append(i)
    text = ''.join([id2chr[i] for i in pred[-window:]])
    print('\033[033m%s\033[0m' % text)
    return text
# Load the model from the checkpoint at `filepath`.
model = load_model(filepath, compile=True)
# Entry point: interactive loop. For each title, build a seed context and
# generate text greedily and with temperatures 1, 2 and 3.
if __name__ == '__main__':
    while True:
        try:
            title = input('输入标题').strip() + '。'
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D) or Ctrl-C: leave the loop without a traceback.
            print()
            break
        len_t = len(title)
        # Random corpus offset, rounded down to a multiple of 12. Presumably this
        # aligns with 5-character verse lines (5 chars + punctuation, twice);
        # confirm against the corpus layout.
        randint = np.random.randint(len_seq - window + len_t)
        randint = int(randint // 12 * 12)
        # Seed = up to (window - len_t) corpus ids followed by the title's ids.
        pred = seq_id[randint: randint + window - len_t] + [c2i(c) for c in title]
        for t in (None, 1, 2, 3):
            # Pass a copy so every decoding strategy starts from the same seed
            # (predict may append generated ids to the list it is given).
            predict(t, list(pred))
预测示范:
总代码
import numpy as np, os
from collections import Counter
from warnings import filterwarnings
filterwarnings('ignore')
from keras.utils import to_categorical
from keras.models import Sequential, load_model
from keras.layers import Conv1D, MaxPool1D, GlobalMaxPool1D, Dense,Flatten
# Hyper-parameters and file paths (each defined exactly once; `window` used to
# be assigned twice, so editing only one copy had no effect).
corpus_path = '唐诗三百首.txt'  # UTF-8 text corpus, read as one character stream
len_chr = 1000        # vocabulary size: keep the len_chr most frequent characters
window = 24           # number of context characters fed to the model
filters = 20          # base number of Conv1D filters
kernel_size = 5       # base Conv1D kernel width
times = 4             # NOTE(review): not referenced anywhere in the visible code
batch_size = 250      # training mini-batch size (only effective if passed to model.fit)
epochs = 2            # training epochs (only effective if passed to model.fit)
filepath = 'model.hdf5'  # where the trained model is saved / loaded
# Read the whole corpus as one continuous character stream (newlines dropped).
with open(corpus_path, encoding='utf-8') as f:
    seq_chr = f.read().replace('\n', '')
len_seq = len(seq_chr)

# Vocabulary: the len_chr most frequent characters, most frequent first.
chr_ls = [c for c, _ in Counter(seq_chr).most_common(len_chr)]
chr2id = {c: i for i, c in enumerate(chr_ls)}
id2chr = {i: c for c, i in chr2id.items()}


def c2i(c):
    """Return the id of character *c*, or a random in-vocabulary id if unknown.

    The random id is bounded by the real vocabulary size, which can be smaller
    than len_chr for a small corpus, so it always decodes through id2chr.
    """
    if c in chr2id:
        return chr2id[c]
    return np.random.randint(len(chr_ls))


# Map the corpus to ids. Going through c2i (instead of chr2id[c]) avoids a
# KeyError on rare characters that did not make the top-len_chr vocabulary.
seq_id = [c2i(c) for c in seq_chr]
def reshape(x):
    """Convert one or more id windows into model input.

    The ids are reshaped to (-1, window, 1) and divided by len_chr.
    """
    return np.reshape(x, (-1, window, 1)) / len_chr


# Sliding windows over the id stream: every run of `window` ids is an input
# sample, and the id immediately after it is the one-hot encoded target.
x = reshape([seq_id[start:start + window] for start in range(len_seq - window)])
y = to_categorical([seq_id[start + window] for start in range(len_seq - window)],
                   num_classes=len_chr)
def CNNmodel():
    """Build and compile the character-level 1-D CNN.

    Layers: Conv1D(filters, 2*kernel_size) -> MaxPool1D ->
    Conv1D(2*filters, kernel_size) -> Flatten -> Dense(len_chr, softmax).
    Compiled with Adam and categorical cross-entropy, reporting accuracy.
    No input shape is declared here.
    """
    stack = [
        Conv1D(filters, kernel_size * 2, padding='same', activation='relu'),
        MaxPool1D(),
        Conv1D(filters * 2, kernel_size, padding='same', activation='relu'),
        Flatten(),
        Dense(len_chr, activation='softmax'),
    ]
    net = Sequential(stack)
    net.compile(optimizer='adam', loss='categorical_crossentropy',
                metrics=['accuracy'])
    return net
"""
model=CNNmodel()
model.fit(x,y)
model.save(filepath)
"""
model = load_model(filepath)
def draw_sample(predictions, temperature):
    """Sample an index from a probability vector re-weighted by *temperature*.

    Each probability p becomes p**(1/temperature) and the result is
    renormalised. Temperatures below 1 sharpen the distribution toward the
    arg-max; temperatures above 1 flatten it.

    Args:
        predictions: 1-D array-like of non-negative class probabilities.
        temperature: positive float.

    Returns:
        The sampled index as an int.
    """
    pred = np.asarray(predictions, dtype='float64')
    with np.errstate(divide='ignore'):  # log(0) -> -inf is fine: exp(-inf) == 0
        logits = np.log(pred) / temperature
    # Subtract the maximum before exponentiating. Otherwise a small temperature
    # underflows every weight to 0, and the normalisation below yields NaN.
    weights = np.exp(logits - np.max(logits))
    pred = weights / np.sum(weights)
    return int(np.argmax(np.random.multinomial(1, pred, 1)))
def predict(t, pred):
    """Generate `window` characters after the seed ids *pred*, then print them.

    Args:
        t: sampling temperature. A falsy value (None/0) selects greedy arg-max
            decoding; otherwise draw_sample is used with temperature t.
        pred: list of character ids used as the seed context. It is copied,
            so the caller's list is left unchanged and can be reused as the
            seed for another call.

    Returns:
        The generated text (the last `window` ids decoded), which is also printed.
    """
    if t:
        print('随机采样,温度:%.1f' % t)
    else:
        print('贪婪采样')
    # Work on a copy. Appending to the caller's list made every later call in a
    # loop start from the previous call's output instead of the original seed.
    pred = list(pred)
    for _ in range(window):
        x_pred = reshape(pred[-window:])
        # verbose=0 hides Keras' progress bar, which otherwise appears per step.
        y_pred = model.predict(x_pred, verbose=0)[0]
        i = draw_sample(y_pred, t) if t else int(np.argmax(y_pred))
        pred.append(i)
    text = ''.join([id2chr[i] for i in pred[-window:]])
    print('\033[033m%s\033[0m' % text)
    return text
if __name__ == '__main__':
    # Interactive loop: for each title, build a seed context and generate text
    # greedily and with temperatures 1, 2 and 3.
    while True:
        try:
            title = input('输入标题').strip() + '。'
        except (EOFError, KeyboardInterrupt):
            # End of input (Ctrl-D) or Ctrl-C: leave the loop without a traceback.
            print()
            break
        len_t = len(title)
        # Random corpus offset, rounded down to a multiple of 12. Presumably this
        # aligns with 5-character verse lines (5 chars + punctuation, twice);
        # confirm against the corpus layout.
        randint = np.random.randint(len_seq - window + len_t)
        randint = int(randint // 12 * 12)
        # Seed = up to (window - len_t) corpus ids followed by the title's ids.
        pred = seq_id[randint: randint + window - len_t] + [c2i(c) for c in title]
        for t in (None, 1, 2, 3):
            # Pass a copy so every decoding strategy starts from the same seed
            # (predict may append generated ids to the list it is given).
            predict(t, list(pred))
低配版数据集：唐诗一百首.txt
高配版数据集：唐诗三百首.txt
数据集的规模决定了最终的生成效果，几百首诗作为训练语料还是太少了。
电气工程的计算机萌新:余登武。写博文不容易,如果你觉得本文对你有用,请点个赞支持下。谢谢。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)