LSTM生成文本(字符级别)

2023-11-12

20200817 -

引言

在网上看到过一些利用深度学习来生成文本的文章,不管生成宋词也好,生成小说也好,各种各样,都是利用深度学习的模型来生成新的东西。之前的时候,我也一直觉得,他们这种生成方式,应该就是记忆性的东西,他并没有真正的从语义的角度上理解这个文章。当然,我自己也是才疏学浅,本身就不是专门搞这种东西的人。
本篇文章中,记录一下我在网上看到的一篇利用LSTM生成文本的文章。需要注意的几个点是
1)训练过程中,输入的是什么
2)根据输出,预测的又是什么
3)最后输出的内容是否可读,又是否有意义,是否有意义是否只能从人的角度来检测

LSTM生成文本

本篇文章主要参考了另一篇文章[1],主要记录一下对数据的处理过程。

问题描述(文本生成)

利用深度学习模型生成文本,是通过已有的文本作为训练集,然后生成新的文本。
但是从我阅读完整个文章来看,他就是学习了训练文本中的一些模式,比如他文章也提到,多少个字符之后就该换行了,然后这个文章就换行了。

数据预处理流程

文章[1]中采用的数据源是《爱丽丝梦游仙境》的文章,同时是针对字符级别来进行预测。

数据输入与输出

在文章[1]中,并没有利用词嵌入的方式来将字符进行向量化,而是统计了全部的字符之后,全部按照ASCII码的数值来统计,还包含了一些特殊字符,比如"\n",","等。从这种处理的方式来看,它就是制造了一种方式,**通过输入训练字符,然后输出字符的形式来生成完整的文本。**那么,它输入的是多大,输出的又是多长的字符呢。下面来具体介绍。

在文章[1]中,它采用的方式是,**定义一个滑动窗口,滑动窗口在整个本文上一直滑动,然后输出是滑动窗口文本的下一个字符。**类似时序数据预测一样的流程,可以从它预处理数据的代码来看。

# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
	seq_in = raw_text[i:i + seq_length]
	seq_out = raw_text[i + seq_length]
	dataX.append([char_to_int[char] for char in seq_in])
	dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print "Total Patterns: ", n_patterns

从代码中可以看到,其采用的方式就是滑动窗口一直滑动到倒数第一个字符,每次选取这些字符的后一个字符作为后续预测的结果。

数据输出的过程

关于具体到底是怎么训练模型的,这里就不不多说了,因为他预测的是一个字符,需要一个多类别的交叉熵作为损失函数,同时将结果进行one-hot编码。下面重点来说一下他们生成文本的过程。

从前文的理解中可以发现,每次输入是一个固定长度的滑动窗口大小的字符串,然后输出一个字符作为预测结果。从模型的结构来说,如果是这种角度的话,那么你的输出必然也是固定长度的内容(当然可以通过一些技巧改变这个长度,这里暂不考虑)。

既然如此,就需要一个种子(模型需要长度的字符串)来驱动模型来进行持续生成,下面来看一下代码。

# pick a random seed
start = numpy.random.randint(0, len(dataX)-1)
pattern = dataX[start]
print "Seed:"
print "\"", ''.join([int_to_char[value] for value in pattern]), "\""
# generate characters
for i in range(1000):
	x = numpy.reshape(pattern, (1, len(pattern), 1))
	x = x / float(n_vocab)
	prediction = model.predict(x, verbose=0)
	index = numpy.argmax(prediction)
	result = int_to_char[index]
	seq_in = [int_to_char[value] for value in pattern]
	sys.stdout.write(result)
	pattern.append(index)
	pattern = pattern[1:len(pattern)]
print "\nDone."

上述代码的整体过程就是,每次将新预测的字符添加到尾部,然后将滑动窗口往后一位,这样就是持续生成了。
注:不过,这让我想起来之前做时序数据的东西的时候,本身你预测出来的东西可能就是错的,你还用错的东西继续来作为输入,这不是积累误差吗。当然,这只是我的理解

小节

从上述的讲解中,基本上明白了,这里的文本生成是通过一个滑动窗口的字符来预测下一个字符。在原文中,也提到了其生成的单词有些的确是没有意义的。所以,看来这里还是有待提升。
当然,这里只是记录一种思路,具体的生成过程还是需要去考虑。
完整代码:

# Load LSTM network and generate text
import sys
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils
# load ascii text and covert to lowercase
filename = "wonderland.txt"
raw_text = open(filename, 'r', encoding='utf-8').read()
raw_text = raw_text.lower()
# create mapping of unique chars to integers, and a reverse mapping
chars = sorted(list(set(raw_text)))
char_to_int = dict((c, i) for i, c in enumerate(chars))
int_to_char = dict((i, c) for i, c in enumerate(chars))
# summarize the loaded data
n_chars = len(raw_text)
n_vocab = len(chars)
print "Total Characters: ", n_chars
print "Total Vocab: ", n_vocab
# prepare the dataset of input to output pairs encoded as integers
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
	seq_in = raw_text[i:i + seq_length]
	seq_out = raw_text[i + seq_length]
	dataX.append([char_to_int[char] for char in seq_in])
	dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
print "Total Patterns: ", n_patterns
# reshape X to be [samples, time steps, features]
X = numpy.reshape(dataX, (n_patterns, seq_length, 1))
# normalize
X = X / float(n_vocab)
# one hot encode the output variable
y = np_utils.to_categorical(dataY)
# define the LSTM model
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(y.shape[1], activation='softmax'))
# load the network weights
filename = "weights-improvement-19-1.9435.hdf5"
model.load_weights(filename)
model.compile(loss='categorical_crossentropy', optimizer='adam')
# pick a random seed
start = numpy.random.randint(0, len(dataX)-1)
pattern = dataX[start]
print "Seed:"
print "\"", ''.join([int_to_char[value] for value in pattern]), "\""
# generate characters
for i in range(1000):
	x = numpy.reshape(pattern, (1, len(pattern), 1))
	x = x / float(n_vocab)
	prediction = model.predict(x, verbose=0)
	index = numpy.argmax(prediction)
	result = int_to_char[index]
	seq_in = [int_to_char[value] for value in pattern]
	sys.stdout.write(result)
	pattern.append(index)
	pattern = pattern[1:len(pattern)]
print "\nDone."

它这里使用checkpoint的方法来记录损失函数最低的模型。

思考

前文已经把文章[1]的整体思路给记录下来了,但是也引发了我的思考。一直以来都有这些问题困扰着我,配合这篇文章来说一下就是,LSTM模型到底学会了什么呢?这个东西我怎么解释呢?每次看到文章总说,LSTM模型能够学到长依赖,但是这个依赖是什么呢?之前使用时序数据的时候,这个依赖可能是利用历史数据来拟合后面数据的数值关系,但是这里又是什么关系呢?这些字符我可以给他任意编码,虽然代码中进行了归一化。
所以,这个我感觉才是我应该思考的东西,这一点其实挺难懂的。

参考

[1]Text Generation With LSTM Recurrent Neural Networks in Python with Keras

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

LSTM生成文本(字符级别) 的相关文章

  • celery任务eta已关闭,使用rabbitmq

    我使用教程中的默认设置和在 ubuntu 上运行的rabbitmq 使 Celery 任务正常进行 当我毫不延迟地安排任务时 一切都很好 但是当我给他们一个预计时间时 他们会被安排在未来 就好像我的时钟在某个地方关闭了一样 下面是一些请求任
  • 如何打印前面有一定数量空格的整数?

    C has printf Xd Y 它只打印整数 X 并使其在控制台窗口上占据 Y 空格 例如 printf 3d 10 console 10 printf 5d 5 console 5 我如何在 python 3 中使用它 This pr
  • 按每个元素中出现的数字对字符串列表进行排序[重复]

    这个问题在这里已经有答案了 我有一个脚本 其目的是对不断下载到服务器上的空间数据集文件进行排序和处理 我的列表目前大致如下 list file t00Z wrff02 grib2 file t00Z wrff03 grib2 file t0
  • virtualenvwrapper 函数在 shell 脚本中不可用

    所以 我再一次制作了一个很棒的 python 程序 它让我的生活变得更加轻松 并节省了大量时间 当然 这涉及到一个 virtualenv 用mkvirtualenvvirtualenvwrapper 的功能 该项目有一个requiremen
  • caffe安装:opencv libpng16.so.16链接问题

    我正在尝试在 Ubuntu 14 04 机器上使用 python 接口编译 caffe 我已经安装了 Anaconda 和 opencvconda install opencv 我还安装了咖啡中规定的所有要求 并更改了注释块makefile
  • NumPy 数组与 SQLite

    我在 Python 中见过的最常见的 SQLite 接口是sqlite3 但是有什么东西可以很好地与 NumPy 数组或 rearray 配合使用吗 我的意思是 它可以识别数据类型 不需要逐行插入 并提取到 NumPy rec 数组中 有点
  • DataFrame.loc 的“索引器太多”

    我读了关于切片器的文档 http pandas pydata org pandas docs stable advanced html using slicers一百万次 但我从来没有理解过它 所以我仍在试图弄清楚如何使用loc切片Data
  • pandas 数据框的最大大小

    我正在尝试使用读取一个有点大的数据集pandas read csv or read stata功能 但我不断遇到Memory Errors 数据帧的最大大小是多少 我的理解是 只要数据适合内存 数据帧就应该没问题 这对我来说不应该是问题 还
  • 无法使用Python请求会话模块登录网站

    我刚刚开始进行网络抓取 对于我的第一个项目 我尝试使用 requests Session 登录 artofproblemsolving com 并访问另一个用户的帐户 这是我的代码 import requests LOGIN URL htt
  • 对法语文本进行词形还原[关闭]

    Closed 这个问题需要多问focused help closed questions 目前不接受答案 我有一些法语文本需要以某种方式进行处理 为此 我需要 首先 将文本标记为单词 然后对这些单词进行词形还原以避免多次处理相同的词根 据我
  • cxfreeze virtualenv 中缺少 distutils 模块

    从 python3 2 项目运行 cxfreeze 二进制文件时 我收到以下运行时错误 project dist project distutils init py 13 UserWarning The virtualenv distuti
  • Python 中“is”运算符的语义是什么?

    如何is运算符确定两个对象是否相同 它是如何工作的 我找不到它的记录 来自文档 http docs python org reference datamodel html 每个对象都有一个身份 一个类型 和一个值 对象的身份 一旦发生就永远
  • 如何在 python 中使用交叉验证执行 GridSearchCV

    我正在执行超参数调整RandomForest如下使用GridSearchCV X np array df features all features y np array df gold standard labels x train x
  • 在python中使用编解码器utf-8打开文件错误

    我在 windows xp 和 python 2 6 4 上执行以下代码 但它显示 IOError 如何打开名称带有 utf 8 编解码器的文件 gt gt gt open unicode txt euc kr encode utf 8 T
  • 如何将 pandas DataFrame 转换为 TimeSeries?

    我正在寻找一种将 DataFrame 转换为 TimeSeries 而不拆分索引和值列的方法 有任何想法吗 谢谢 In 20 import pandas as pd In 21 import numpy as np In 22 dates
  • 在Python中随机交错2个数组

    假设我有两个数组 a 1 2 3 4 b 5 6 7 8 9 我想将这两个数组交错为变量 c 注意 a 和 b 不一定具有相同的长度 但我不希望它们以确定性的方式交错 简而言之 仅仅压缩这两个数组是不够的 我不想要 c 1 5 2 6 3
  • 如何在类型提示中定义元组或列表的大小

    有没有办法在参数的类型提示中定义元组或列表的大小 目前我正在使用这样的东西 from typing import List Optional Tuple def function name self list1 List Class1 if
  • 将二进制数据视为文件对象?

    在此代码片段 由另一个人编写 中 self archive是一个大文件的路径并且raw file是以二进制数据形式读取的文件内容 with open self archive rb as f f seek offset raw file s
  • Django 中使用外键的抽象基类继承

    我正在尝试在 Django 支持的网站上进行模型继承 以遵守 DRY 我的目标是使用一个名为 BasicCompany 的抽象基类来为三个子类提供通用信息 Butcher Baker CandlestickMaker 它们位于各自的应用程序
  • 如何创建简单的梯度下降算法

    我正在研究简单的机器学习算法 从简单的梯度下降开始 但在尝试用 python 实现它时遇到了一些麻烦 这是我试图重现的示例 我获得了有关房屋的数据 居住面积 以英尺为单位 和卧室数量 以及最终的价格 居住面积 英尺2 2104 卧室 3 价

随机推荐

  • 程序媛菜鸡面经(八 - offer篇)

    投简历 简历是要多投的 但是有时候投多了简历也会有问题 头条 没有面试机会 在看过简历后HR发邮件告知我 从简历上能看出你是一位很优秀的人 但看不出你在前端 技术方面的竞争力 当时投的是旧版简历 于是我回邮问简历有误能否重申 至今未有回音
  • 子网掩码的作用

    IP地址由网络和主机两部分标识组成 IP地址由 网络标识 网络地址 和 主机标识 主机地址 两部分组成 在局域网内相互间通信的网络必须具有相同网络地址 也叫相同的网段 在同一个网段内每个设备的主机地址都不相同 在IPV4中 IP地址由32位
  • Vue中query与params两种传参的区别

    query语法 this router push path 地址 query id 123 这是传递参数 this route query id 这是接受参数 params语法 this router push name 地址 params
  • linux系统哪个好用

    linux系统哪个好用 1 Ubuntu服务器 Ubuntu是众所周知的最佳LinuxServerDistro 它能为您提供出色的用户体验 如果你是Linux世界的新手 选择Ubuntu作为你的服务器发行版将是最好的 使用此服务器 您可以做
  • Mac系统如何在圣诞节让电脑屏幕下雪?

    对于苹果 Mac 电脑上的 终端 应用 可能大家在平时用得不多 所以对它应该都会比较陌生 其实这个终端应用是用于让用户可以直接输入一些系统指令 让它执行相应的操作 比如简单的显示当前目录中的文件 显示日期与时间 删除文件等操作都是可以的 今
  • Android项目Gradle: Download gradle-6.5-bin.zip一直卡住解决方法

    1 首先停止gradle的下载 通过迅雷或浏览器将gradle下载下来 下载地址为 https services gradle org distributions gradle 6 5 bin zip 其他版本的gradle同理 2 打开C
  • 二级MS Office高级应用

    1 在长度为n的有序线性表中进行二分查找 最坏的情况下需要比较的次数是 O log2n 以2为底n对数 解析 当有序线性表为顺序存储时才可以用二分查找 可以证明的是对于长度为n的有序线性表 最坏的情况下 二分查找只需要比较O log2n 次
  • 数据仓库开发之路之一--准备工作

    在数据仓库的开发过程中 需要熟悉大量的概念以及相关工具的使用 还需要了解宏观上的各种开发流程 串联起来完成最终的数据仓库项目的开发 本篇介绍一些准备工作 包括涉及到的工具介绍 以及开发过程的描述 记录学习研究的印记 并和大家讨论研究存在的相
  • conda upgrade --all惹的祸,该怎么解决?

    本想要安装scikit surprise库 由于环境问题 就更新一下 谁知道差点酿成大祸 anaconda不灵了 无论什么语句都报错 jupyter notebook 不能用 navigator也打不开 万念俱灰了 导致我想要重装anaco
  • atx860和java_捷安特XTC800和ATX860有什么区别

    展开全部 区别比较大 简单说 ATX 8xx就是e69da5e887aa62616964757a686964616f31333431353237ATX 6xx的 局部升级 轮组由26寸换为27 5寸 车架外观改进 变速套件等级略微提高 仅此
  • mmclassification 训练自定义数据

    1 mmclassification 安装 如果环境已安装mmclassification 请跳过该步骤 mmclassification框架安装与调试验证请参考博客 mmclassification安装与调试 Coding的叶子的博客 C
  • STM32基于IIC协议的温湿度(AHT20)采集

    STM32基于IIC协议的温湿度 AHT20 采集 文章目录 STM32基于IIC协议的温湿度 AHT20 采集 1 IIC总线协议 1 1 什么是IIC协议 1 2 IIC协议的物理层和协议层 1 2 1 物理层 1 2 2 协议层 1
  • orm模型的查询方法集合

    目录 3 4 1 基本查询 3 4 2过滤查询 3 4 2 2 模糊查询 3 4 2 3 空查询 3 4 2 4 范围查询 3 4 2 5 比较查询 3 4 2 6 日期查询 3 4 3 1 F对象 3 4 3 2 Q对象 values 返
  • Aborted (core dumped) Assertion `Engine.getNbBindings() == 4' failed.

    记录一次特别粗心的错误 错误代码位置 assert的作用是现计算表达式 expression 如果其值为假 即为0 那么它先向stderr打印一条出错信息 然后通过调用 abort 来终止程序运行 需要 inputname 3 output
  • 垂直广告是什么意思_网上常看到带货这个词,那么带货到底是什么意思?又要怎么通过平台带货呢?...

    网上常看到带货这个词 那么带货到底是什么意思 又要怎么通过平台带货呢 直播带货就是通过短视频平台 吸引消费者来购买自己所售卖的产品 可以投放广告或是与达人合作进行带货 短视频 品牌营销优势 新一代广告宠儿 5G时代即将来临 人们越来越习惯且
  • HTML5----FormData实例用法

    ajax 异步上传文件 1 前言 在网页与后台的交互中 用的最多的网络交互方式之一就是ajax ajax 是免刷新页面就能从进行post与get方式的提交表单和获取服务端数据 但是在原先的ajax中 是不能携带文件上传的 但是由于h5里面的
  • Mysql数据库手册

    数据库基本概念 1 数据库 就是数据的仓库 由表 关系 操作对象组成 2 表 由行和列组成 数据都存放在表中 由于mysql是关系数据库 所以表又被称为关系 3 字段 就是属性 4 记录 一行数据就是一条记录 也是一条实体 需要设置主键 5
  • 基于惯性动作捕捉技术进行快速动画制作教程

    长久以来动画制作流程上有着诸多不可回避的问题 尤其在于角色动画的制作周期和动画效果方面 一般来说 每一秒钟的角色动画都需要动画师手动关键帧制作耗费8小时才能完成 也就是说 一个动画师每个月只能制作出22秒动画 动作捕捉技术为动画制作者带来福
  • Elasticsearch-基本命令

    基本命令 创建索引 添加数据 删除数据 简单查询 复杂查询1 复杂查询2 获得所有index 获得所有mapping type 根据某个字段精确查找 api的分组查询 bool查询 创建索引 put http localhost 9200
  • LSTM生成文本(字符级别)

    20200817 引言 在网上看到过一些利用深度学习来生成文本的文章 不管生成宋词也好 生成小说也好 各种各样 都是利用深度学习的模型来生成新的东西 之前的时候 我也一直觉得 他们这种生成方式 应该就是记忆性的东西 他并没有真正的从语义的角