使用 LSTM 教程代码来预测句子中的下一个单词?

2023-12-27

我一直在尝试理解示例代码https://www.tensorflow.org/tutorials/recurrent https://www.tensorflow.org/tutorials/recurrent你可以在以下位置找到https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py

(使用张量流1.3.0。)

我总结了(我认为的)我的问题的关键部分,如下:

 size = 200
 vocab_size = 10000
 layers = 2
 # input_.input_data is a 2D tensor [batch_size, num_steps] of
 #    word ids, from 1 to 10000

 cell = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.BasicLSTMCell(size) for _ in range(2)]
    )

 embedding = tf.get_variable(
      "embedding", [vocab_size, size], dtype=tf.float32)
 inputs = tf.nn.embedding_lookup(embedding, input_.input_data)

inputs = tf.unstack(inputs, num=num_steps, axis=1)
outputs, state = tf.contrib.rnn.static_rnn(
    cell, inputs, initial_state=self._initial_state)

output = tf.reshape(tf.stack(axis=1, values=outputs), [-1, size])
softmax_w = tf.get_variable(
    "softmax_w", [size, vocab_size], dtype=data_type())
softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
logits = tf.matmul(output, softmax_w) + softmax_b

# Then calculate loss, do gradient descent, etc.

我最大的问题是给定句子的前几个单词,如何使用生成的模型实际生成下一个单词建议?具体来说,我想象流程是这样的,但我无法理解注释行的代码是什么:

prefix = ["What", "is", "your"]
state = #Zeroes
# Call static_rnn(cell) once for each word in prefix to initialize state
# Use final output to set a string, next_word
print(next_word)

我的子问题是:

  • 为什么使用随机(未初始化、未经训练的)词嵌入?
  • 为什么使用softmax?
  • 隐藏层是否必须与输入的维度匹配(即 word2vec 嵌入的维度)
  • 我如何/可以引入预先训练的 word2vec 模型,而不是未初始化的模型?

(我将它们作为一个问题来问,因为我怀疑它们都是相互关联的,并且与我的理解中的某些差距有关。)

我期望在这里看到的是加载现有的 word2vec 词嵌入集(例如使用 gensim 的KeyedVectors.load_word2vec_format()),在加载每个句子时将输入语料库中的每个单词转换为该表示,然后 LSTM 会吐出相同维度的向量,我们将尝试找到最相似的单词(例如使用 gensim 的similar_by_vector(y, topn=1)).

使用 softmax 是否可以让我们免于相对较慢的速度similar_by_vector(y, topn=1) call?


顺便说一句,对于我的问题中预先存在的 word2vec 部分使用预训练的 word2vec 和 LSTM 进行单词生成 https://stackoverflow.com/q/42064690/841830很相似。然而,目前那里的答案并不是我正在寻找的。我希望得到一个简单的英语解释,为我打开灯,并填补我理解中的任何空白。在lstm语言模型中使用预训练的word2vec? https://stackoverflow.com/questions/44614097/use-pre-trained-word2vec-in-lstm-language-model是另一个类似的问题。

UPDATE: 使用语言模型张量流示例预测下一个单词 https://stackoverflow.com/q/33773661/841830 and 使用 LSTM ptb 模型张量流示例预测下一个单词 https://stackoverflow.com/q/36286594/841830是类似的问题。然而,两者都没有显示代码实际获取句子的前几个单词,并打印出对下一个单词的预测。我尝试粘贴第二个问题的代码,以及https://stackoverflow.com/a/39282697/841830 https://stackoverflow.com/a/39282697/841830(附带一个 github 分支),但无法在没有错误的情况下运行。我认为它们可能适用于 TensorFlow 的早期版本?

另一个更新:还有一个问题询问基本上相同的事情:从 Tensorflow 示例预测 LSTM 模型的下一个单词 https://stackoverflow.com/q/42333101/841830它链接到使用语言模型张量流示例预测下一个单词 https://stackoverflow.com/q/33773661/841830(再说一次,那里的答案并不完全是我想要的)。

如果还不清楚,我正在尝试编写一个名为的高级函数getNextWord(model, sentencePrefix), where model是我从磁盘加载的先前构建的 LSTM,并且sentencePrefix是一个字符串,例如“Open the”,它可能返回“pod”。然后我可能会用“Open the pod”来调用它,它将返回“bay”,依此类推。

一个示例(使用字符 RNN,并使用 mxnet)是sample()函数显示在接近末尾处https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/simple-rnn.ipynb https://github.com/zackchase/mxnet-the-straight-dope/blob/master/chapter05_recurrent-neural-networks/simple-rnn.ipynb您可以致电sample()在训练期间,但您也可以在训练后调用它,并使用您想要的任何句子。


主要问题

加载单词

加载自定义数据而不是使用测试集:

reader.py@ptb_raw_data

test_path = os.path.join(data_path, "ptb.test.txt")
test_data = _file_to_word_ids(test_path, word_to_id)  # change this line

test_data应包含单词 ID(打印出word_to_id用于映射)。例如,它应该如下所示: [1, 52, 562, 246] ...

显示预测

我们需要返回 FC 层的输出(logits) 在调用中sess.run

ptb_word_lm.py@PTBModel.__init__

    logits = tf.reshape(logits, [self.batch_size, self.num_steps, vocab_size])
    self.top_word_id = tf.argmax(logits, axis=2)  # add this line

ptb_word_lm.py@run_epoch

  fetches = {
      "cost": model.cost,
      "final_state": model.final_state,
      "top_word_id": model.top_word_id # add this line
  }

后来在函数中,vals['top_word_id']将有一个整数数组,其中包含顶部单词的 ID。查找此内容word_to_id来确定预测词。我不久前用小模型做了这个,尽管困惑度是标题中预测的,但 top 1 的准确率相当低(20-30% iirc)。

子问题

为什么使用随机(未初始化、未经训练的)词嵌入?

你必须询问作者,但在我看来,训练嵌入使得这更像是一个独立的教程:它没有将嵌入视为黑匣子,而是展示了它是如何工作的。

为什么使用softmax?

最终的预测是not由与隐藏层输出的余弦相似度确定。 LSTM 之后有一个 FC 层,它将嵌入状态转换为最终单词的 one-hot 编码。

以下是神经网络中的操作和维度的草图:

word -> one hot code (1 x vocab_size) -> embedding (1 x hidden_size) -> LSTM -> FC layer (1 x vocab_size) -> softmax (1 x vocab_size)

隐藏层是否必须与输入的维度匹配(即 word2vec 嵌入的维度)

从技术上来说,不。如果您查看 LSTM 方程,您会发现 x(输入)可以是任意大小,只要适当调整权重矩阵即可。

我如何/可以引入预先训练的 word2vec 模型,而不是未初始化的模型?

我不知道,抱歉。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 LSTM 教程代码来预测句子中的下一个单词? 的相关文章

随机推荐

  • 设计模式:异常/错误处理

    是否有任何资源 网络或书籍 描述异常处理 错误处理设计模式 有很多关于如何编写干净代码的文献 也有很多涉及设计模式的书籍 然而 我从未见过任何设计模式涵盖以下问题 在何处以及如何最好地处理错误 以及如何最好地将低级函数中出现的错误向上传播到
  • 为多个版本的 Visual Studio 开发 Visual Studio 插件

    我的任务是为 Visual Studio 开发一些扩展以供我们内部使用 这些必须支持几个不同版本的 Visual Studio VS2008 2010 和 2012 VS2005 是一个很好的选择 但不是必需的 我希望以尽可能一致的方式开发
  • Django F 似乎不起作用?

    嗯 出于某种原因 即使在最简单的模型上 我似乎也无法让 F 正常工作 这里是 Django 1 9 x 最简单的形式是 TestAccount class TestAccount models Model decimal models De
  • Docker 运行 - 用户组未按预期工作?

    我有一个通过串行端口进行通信的脚本 dev ttyUSB0 我想从 Docker 映像中运行它 但是我似乎没有权限从图像中执行此操作 我按照以下步骤操作 在我的主机上 如果我运行ln l dev ttyUSB0 I get crw rw 1
  • 使用 SCons 进行真正的分层构建?

    所以我读过这里有关分层构建的问题 例如 使用 SCons 创建分层构建 https stackoverflow com questions 3709321 creating a hierarchical build with scons 我
  • 如何使用 Django migrate 命令跳过迁移?

    首先 我问的是1 7中引入的Django迁移 而不是south 假设我有迁移001 add field x 002 add field y 并且两者都应用于数据库 现在我改变主意并决定恢复第二次迁移并将其替换为另一个迁移003 add fi
  • 重写as_json没有效果?

    我试图在我的一个模型中覆盖 as json 部分是为了包含来自另一个模型的数据 部分是为了删除一些不必要的字段 据我所知 这是 Rails 3 中的首选方法 为了简单起见 假设我有类似的方法 class Country lt ActiveR
  • hostPID 和 hostIPC 选项在 kubernetes pod 中意味着什么?

    在 kubernetes pod yaml 规范文件中 您可以使用以下命令将 pod 设置为使用主机的网络hostNetwork true 我在任何地方都找不到关于什么的好的 适合初学者 解释hostPID true and hostIPC
  • 时间:2019-05-17 标签:c#WinformMSChartreverseYAxis

    我在 Windows 窗体上使用 MSChart 控件 我想有一个下降 Y 轴通过使用AxisY IsReversed true 但仍将 X 轴保留在底部 默认情况下 当我使用AxisY IsReversed true 然后 X 轴上升到顶
  • 将文件保存在我的项目内的指定文件夹中

    我正在创建一个 Xml 文件 我想将其保存在解决方案内项目内的指定文件夹中 我可以在解决方案资源管理器中访问它 如何指定路径以便创建文件夹并将 xml 文件保存在其中 目前 它在我的项目的根目录中创建文件 但我无法在解决方案资源管理器中查看
  • 使带有索引数组的 for 循环更快

    我有以下问题 我有带有重复索引的索引数组 并且想将值添加到数组中 如下所示 grid array xidx yidx zidx data 然而 由于我有重复的索引 这不起作用 因为 numpy 将创建一个临时数组 这会导致重复索引的数据被分
  • VueJS:选择同一文件时不会触发输入文件选择事件

    我们如何在 Vue Js 中文件输入检测相同文件输入的变化
  • 后退按钮未显示在导航控制器中

    我已将一个视图控制器中的表单元格的显示序列添加到嵌入导航控制器中的另一个表视图 当我单击第一个视图中的单元格时 segue 按预期工作并显示新视图 但是 后退 按钮 带有原始视图的标题 不会出现在导航栏中 我搜索了 SO 发现过去提出过很多
  • 设置为背景的 SVG 线性渐变在 Edge 和 IE 中不起作用

    我使用带有线性渐变颜色的 SVG 形状 background url imgUrlBase element svg 除了 Edge 和 IE 之外 在任何地方都可以正常工作 形状显示正确 但不是渐变 只有纯色 由于多种原因 简单使用 png
  • 通过 Instagram API 使用 php 获取照片和点赞

    使用 Instagram 时client idAPI 请求如下 https api instagram com v1 users https api instagram com v1 users 用户 ID media recent cli
  • 如何在 Tkinter 文本框中设置对齐方式

    Question 如何更改特定行的对齐方式ScrolledTextTkinter 中的小部件 我原来的错误的原因是什么 背景 我目前正在开发 Tkinter 文本框应用程序 并且正在寻找更改行对齐方式的方法 最终 我希望能够更改特定行左对齐
  • C++ 中数组的静态边界检查

    我需要一些关于我正在学习的编程语言课程中的问题的指导 我们需要想出一种在 C 中实现数组类的方法 以便静态检查对其元素的访问是否存在溢出 我们不会使用 C 11 静态断言 或任何其他黑盒解决方案 这是一个理论问题 而不是我出于编码目的所需的
  • 来自数组的 PHP 值,其中键位于另一个数组中

    由于某种原因 我正在为此苦苦挣扎 我有2个数组 第一个是名为 colsArray 的标准数组 如下所示 Array 0 gt fName 1 gt lName 2 gt city 第二个是一个名为 query data 的多维数组 如下所示
  • 网站性能测试:如何最好地估计计算机性能?

    我的网页中有一些浏览器密集型 CSS 和动画 我想确定用户是否拥有快速的 PC 以便我可以相应地缩放内容以提供最佳体验 我在用http Detectmobilebrowser com http detectmobilebrowser com
  • 使用 LSTM 教程代码来预测句子中的下一个单词?

    我一直在尝试理解示例代码https www tensorflow org tutorials recurrent https www tensorflow org tutorials recurrent你可以在以下位置找到https git