Seq2Seq 模型在几次迭代后学会仅输出 EOS 令牌 (<\s>)

2024-01-18

我正在创建一个接受过训练的聊天机器人康奈尔电影对话语料库 https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html using NMT https://github.com/tensorflow/nmt.

我的代码部分基于https://github.com/bshao001/ChatLearner https://github.com/bshao001/ChatLearner and https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot https://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot

在训练期间,我打印从批次中馈送到解码器的随机输出答案以及我的模型预测的相应答案以观察学习进度。

我的问题:仅经过大约 4 次迭代训练后,模型就学会了输出 EOS 代币(<\s>)对于每个时间步长。即使训练仍在继续,它也始终将其输出作为其响应(使用 logits 的 argmax 确定)。模型偶尔会输出一系列周期作为答案,但这种情况很少发生。

我还在训练期间打印前 10 个 logit 值(不仅仅是 argmax),以查看其中是否有正确的单词,但它似乎是在预测词汇中最常见的单词(例如 i、you、?、. )。即使是这前 10 个单词在训练过程中也没有太大变化。

我已确保正确计算编码器和解码器的输入序列长度,并添加了 SOS (<s>)和 EOS(也用于填充)相应的代币。我也表演masking在损失计算中。

这是一个示例输出:

训练迭代 1:

Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s> 
<\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration 
administration winston winston winston magazines magazines magazines 
magazines

...

训练迭代 4:

Decoder Input: <s> i guess i had it coming . let us call it settled . 
<\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> 
<\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>


经过几次迭代后,它决定只预测 EOS(很少预测某些时期)

我不确定是什么导致了这个问题,并且已经被困在这个问题上一段时间了。任何帮助将不胜感激!

Update:我让它训练超过十万次迭代,但它仍然只输出 EOS(和偶尔的周期)。经过几次迭代后,训练损失也没有减少(从一开始就保持在 47 左右)


最近我也在研究seq2seq模型。 我以前遇到过你的问题,就我而言,我通过更改损失函数来解决它。

你说你用面膜,所以我猜你用tf.contrib.seq2seq.sequence_loss就像我一样。

我改为tf.nn.softmax_cross_entropy_with_logits,并且它可以正常工作(并且计算成本更高)。

(编辑 05/10/2018。请原谅,我需要编辑,因为我发现我的代码中有一个严重的错误)

tf.contrib.seq2seq.sequence_loss可以很好地工作,如果形状logits ,targets , mask是对的。 正如官方文档中定义的:tf.contrib.seq2seq.sequence_loss https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/sequence_loss

loss=tf.contrib.seq2seq.sequence_loss(logits=decoder_logits,
                                      targets=decoder_targets,
                                      weights=masks) 

#logits:  [batch_size, sequence_length, num_decoder_symbols]  
#targets: [batch_size, sequence_length] 
#weights: [batch_size, sequence_length] 

嗯,即使形状不符合,它仍然可以工作。但结果可能很奇怪(很多#EOS #PAD...等)。

自从decoder_outputs,以及decoder_targets可能具有与所需相同的形状(在我的例子中,我的decoder_targets有形状[sequence_length, batch_size])。 所以尝试使用tf.transpose帮助您重塑张量。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Seq2Seq 模型在几次迭代后学会仅输出 EOS 令牌 (<\s>) 的相关文章

  • 中断 Select 以添加另一个要在 Python 中监视的套接字

    我正在 Windows XP 应用程序中使用 TCP 实现点对点 IPC 我正在使用select and socketPython 2 6 6 中的模块 我有三个 TCP 线程 一个读取线程通常会阻塞select 一个通常等待事件的写入线程
  • 元组有什么用?

    我现在正在学习 Python 课程 我们刚刚介绍了元组作为数据类型之一 我阅读了它的维基百科页面 但是 我无法弄清楚这种数据类型在实践中会有什么用处 我可以提供一些需要一组不可变数字的示例吗 也许是在 Python 中 这与列表有何不同 每
  • 安装了 32 位的 Python,显示为 64 位

    我需要运行 32 位版本的 Python 我认为这就是我在我的机器上运行的 因为这是我下载的安装程序 当我重新运行安装程序时 它会将当前安装的 Python 版本称为 Python 3 5 32 位 然而当我跑步时platform arch
  • Python 中的舍入浮点问题

    我遇到了 np round np around 的问题 它没有正确舍入 我无法包含代码 因为当我手动设置值 而不是使用我的数据 时 返回有效 但这是输出 In 177 a Out 177 0 0099999998 In 178 np rou
  • 需要在python中找到print或printf的源代码[关闭]

    很难说出这里问的是什么 这个问题是含糊的 模糊的 不完整的 过于宽泛的或修辞性的 无法以目前的形式得到合理的回答 如需帮助澄清此问题以便重新打开 访问帮助中心 help reopen questions 我正在做一些我不能完全谈论的事情 我
  • 删除flask中的一对一关系

    我目前正在使用 Flask 开发一个应用程序 并且在删除一对一关系中的项目时遇到了一个大问题 我的模型中有以下结构 class User db Model tablename user user id db Column db String
  • 将 python2.7 与 Emacs 24.3 和 python-mode.el 一起使用

    我是 Emacs 新手 我正在尝试设置我的 python 环境 到目前为止 我已经了解到在 python 缓冲区中使用 python mode el C c C c将当前缓冲区的内容加载到交互式 python shell 中 显然使用了什么
  • 立体太阳图 matplotlib 极坐标图 python

    我正在尝试创建一个与以下类似的简单的立体太阳路径图 http wiki naturalfrequent com wiki Sun Path Diagram http wiki naturalfrequency com wiki Sun Pa
  • 在Python中连接反斜杠

    我是 python 新手 所以如果这听起来很简单 请原谅我 我想加入一些变量来生成一条路径 像这样 AAAABBBBCCCC 2 2014 04 2014 04 01 csv Id TypeOfMachine year month year
  • 使用 xlrd 打开 BytesIO (xlsx)

    我正在使用 Django 需要读取上传的 xlsx 文件的工作表和单元格 使用 xlrd 应该可以 但因为文件必须保留在内存中并且可能不会保存到我不知道如何继续的位置 本例中的起点是一个带有上传输入和提交按钮的网页 提交后 文件被捕获req
  • 如何在 Python 中解析和比较 ISO 8601 持续时间? [关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 我正在寻找一个 Python v2 库 它允许我解析和比较 ISO 8601 持续时间may处于不同单
  • Jupyter Notebook 找不到 Python 模块

    不知道发生了什么 但每当我使用 ipython 氢 原子 或 jupyter 笔记本时都找不到任何已安装的模块 我知道我安装了 pandas 但笔记本说找不到 我应该补充一点 当我正常运行脚本时 python script py 它确实导入
  • 使用特定颜色和抖动在箱形图上绘制数据点

    我有一个plotly graph objects Box图 我显示了箱形 图中的所有点 我需要根据数据的属性为标记着色 如下所示 我还想抖动这些点 下面未显示 Using Box我可以绘制点并抖动它们 但我不认为我可以给它们着色 fig a
  • Pandas 将多行列数据帧转换为单行多列数据帧

    我的数据框如下 code df Car measurements Before After amb temp 30 268212 26 627491 engine temp 41 812730 39 254255 engine eff 15
  • 为什么 Pickle 协议 4 中的 Pickle 文件是协议 3 中的两倍,而速度却没有任何提升?

    我正在测试 Python 3 4 我注意到 pickle 模块有一个新协议 因此 我对 2 个协议进行了基准测试 def test1 pickle3 open pickle3 wb for i in range 1000000 pickle
  • 在本地网络上运行 Bokeh 服务器

    我有一个简单的 Bokeh 应用程序 名为app py如下 contents of app py from bokeh client import push session from bokeh embed import server do
  • python import inside函数隐藏现有变量

    我在我正在处理的多子模块项目中遇到了一个奇怪的 UnboundLocalError 分配之前引用的局部变量 问题 并将其精简为这个片段 使用标准库中的日志记录模块 import logging def foo logging info fo
  • Python ImportError:无法导入名称 __init__.py

    我收到此错误 ImportError cannot import name life table from cdc life tables C Users tony OneDrive Documents Retirement retirem
  • 实现 XGboost 自定义目标函数

    我正在尝试使用 XGboost 实现自定义目标函数 在 R 中 但我也使用 python 所以有关 python 的任何反馈也很好 我创建了一个返回梯度和粗麻布的函数 它工作正常 但是当我尝试运行 xgb train 时它不起作用 然后 我
  • 使用for循环时如何获取前一个元素? [复制]

    这个问题在这里已经有答案了 可能的重复 Python 循环内的上一个和下一个值 https stackoverflow com questions 1011938 python previous and next values inside

随机推荐