为 Seq2Seq 模型添加注意力层

2023-11-29

我已经构建了编码器-解码器的 Seq2Seq 模型。我想为其添加一个注意力层。我尝试添加注意力层通过这个但这没有帮助。

这是我最初的代码,没有注意

# Encoder
encoder_inputs = Input(shape=(None,))
enc_emb =  Embedding(num_encoder_tokens, latent_dim, mask_zero = True)(encoder_inputs)
encoder_lstm = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(enc_emb)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero = True)
dec_emb = dec_emb_layer(decoder_inputs)
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()

这是我在解码器中添加注意层后的代码(编码器层与初始代码中的相同)

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero = True)
dec_emb = dec_emb_layer(decoder_inputs)
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
attention = dot([decoder_lstm, encoder_lstm], axes=[2, 2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_lstm], axes=[2,1])
decoder_combined_context = concatenate([context, decoder_lstm])
decoder_outputs, _, _ = decoder_combined_context(dec_emb,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()

执行此操作时,我收到错误

 Layer dot_1 was called with an input that isn't a symbolic tensor. Received type: <class 'keras.layers.recurrent.LSTM'>. Full input: [<keras.layers.recurrent.LSTM object at 0x7f8f77e2f3c8>, <keras.layers.recurrent.LSTM object at 0x7f8f770beb70>]. All inputs to the layer should be tensors.

有人可以帮忙在这个架构中安装一个注意力层吗?


点积需要在张量输出上计算...在编码器中您正确定义了编码器输出,在解码器中您必须添加decoder_outputs, state_h, state_c = decoder_lstm(enc_emb, initial_state=encoder_states)

现在的点积是

attention = dot([decoder_outputs, encoder_outputs], axes=[2, 2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_outputs], axes=[2,1])

连接不需要initial_states。你必须在你的 rnn 层中定义它:decoder_outputs, state_h, state_c = decoder_lstm(enc_emb, initial_state=encoder_states)

这是完整的例子

编码器+解码器

# dummy variables
num_encoder_tokens = 30
num_decoder_tokens = 10
latent_dim = 100

encoder_inputs = Input(shape=(None,))
enc_emb =  Embedding(num_encoder_tokens, latent_dim, mask_zero = True)(encoder_inputs)
encoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(enc_emb)
# We discard `encoder_outputs` and only keep the states.
encoder_states = [state_h, state_c]

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero = True)
dec_emb = dec_emb_layer(decoder_inputs)
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(dec_emb,
                                     initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()

带有注意力的解码器

# Set up the decoder, using `encoder_states` as initial state.
decoder_inputs = Input(shape=(None,))
dec_emb_layer = Embedding(num_decoder_tokens, latent_dim, mask_zero = True)
dec_emb = dec_emb_layer(decoder_inputs)
# We set up our decoder to return full output sequences,
# and to return internal states as well. We don't use the
# return states in the training model, but we will use them in inference.
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, state_h, state_c = decoder_lstm(dec_emb, initial_state=encoder_states)
attention = dot([decoder_outputs, encoder_outputs], axes=[2, 2])
attention = Activation('softmax')(attention)
context = dot([attention, encoder_outputs], axes=[2,1])
decoder_outputs = concatenate([context, decoder_outputs])
decoder_dense = Dense(num_decoder_tokens, activation='softmax')(decoder_outputs)

# Define the model that will turn
# `encoder_input_data` & `decoder_input_data` into `decoder_target_data`
model = Model([encoder_inputs, decoder_inputs], decoder_dense)
model.summary()
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为 Seq2Seq 模型添加注意力层 的相关文章

随机推荐

  • 为什么百分比高度在我的 div 上不起作用? [复制]

    这个问题在这里已经有答案了 我有两个高度为90 的div 但是显示不一样 我尝试在它们周围放置一个外部 div 但这没有帮助 此外 在 Firefox Chrome Opera 和 Safari 上也是如此 有人可以解释为什么我遇到这个问题
  • 使用数据触发控制故事板,但只触发一次

    我使用数据触发器来控制一些故事板 但它只能触发一次
  • Java - 同步线程 - 输出顺序错误

    在玩了一年 Java 之后 我正在阅读 Java 完整参考 第 9 版 到目前为止 我对这本书很满意 但我现在在同步线程方面遇到了一个非常奇怪的问题 package syncro class Callme void call String
  • 同时等待具有独立延续的多个异步调用

    在多种情况下 我需要调用多个异步调用 来自同一个事件处理程序 这些调用可以彼此独立地进行 每个调用都有自己的延续来更新 UI 以下简单的实现导致三个异步操作按顺序执行 private async void button Click obje
  • Google 地图 JavaScript API - fitbounds 与 setCenter 一起使用

    我一直在寻找解决这个问题的方法 但我似乎找不到解决这个问题的东西 我得到的最接近的是这个线程 但这行不通 我想做的是基于一组运行良好的标记来运行 fitbounds 但我还想根据用户位置 plunk 中的弹跳标记 将地图居中 并仍将所有标记
  • 打印 pandas 数据框时抑制描述性输出

    假设我有数据框 c a np random random 6 2 c pd DataFrame a c columns A B 打印第 0 行值 print c loc 0 结果是 A 0 220170 B 0 261467 Name 0
  • 如何渲染大量相似的物体?

    我有大量对象 至少 10 000 个粒子 例如三角形 正方形 圆形或球体 实际上现在我有一个对象 我渲染了很多次 它看起来像这样 for int i 0 i
  • android 版 admob 入门 - 对文档感到困惑

    我刚刚开始考虑将 Admob 广告放入我正在构建的 Android 应用程序中 到目前为止 还没有好的结果 我一直在遵循从 adMob 网站下载的 AdMod Android SDK Instructions pdf 中的示例 但感到困惑
  • Android:AsyncTask 的处理程序

    我将 AsyncTask 与 ProgressDialog 结合使用 查看我的代码 我在 onPostExecute 中遇到问题 如果任务是第一次运行 它会在handleMessage 中收到progressDialog 的Null Poi
  • 反转 pandas 中的 get_dummies 编码

    列名称为 ID 1 2 3 4 5 6 7 8 9 col 值为 0 或 1 我的数据框如下所示 ID 1 2 3 4 5 6 7 8 9 1002 0 1 0 1 0 0 0 0 0 1003 0 0 0 0 0 0 0 0 0 1004
  • SignalR(v2.2.0) OnDisconnected 设置用户离线

    我使用以下代码在组中添加用户 并使用以下代码将用户保存在该特定组的数据库中 SERVER public class ChatHub Hub public async Task JoinRoom string user Id string r
  • 将小数转换为任何基数? [关闭]

    Closed 这个问题不符合堆栈溢出指南 目前不接受答案 我知道 strtoll 但它将任何基数基数 2 到 36 之间 转换为十进制 我需要通过将十进制转换为任何基数基数来执行相反的操作 一个例子是十进制 130 基数 12 AA 以下代
  • 如何在 Python 中根据 DTD 文件验证 xml

    我需要验证 XML 字符串 而不是文件 针对 DTD 描述文件 这怎么能在python 另一个不错的选择是lxml的验证我觉得用起来很愉快 取自 lxml 站点的一个简单示例 from StringIO import StringIO fr
  • 在 WAMP PHP Google+ 项目中安装 Composer,PHP 无法识别

    我正在关注这个 PHP Google 教程我正在尝试在我的 WAMP 目录中安装作曲家 C wamp www gplus quickstart php gt curl s https getcomposer org installer ph
  • 更改 WiFi MAC 地址 [关闭]

    Closed 这个问题不符合堆栈溢出指南 目前不接受答案 我是致力于该项目的开发人员之一薮猫项目我们正在考虑使用华为创意U1850作为我们Android开发的默认平台 活动 我们从当地经销商之一购买了几部手机 在澳大利亚这里 我们注意到 我
  • 为什么Java中每次long和double都工作时会有这么多类型的数字?

    现在我一直在尝试学习Java编程 我想知道为什么我们使用这样的东西Float short and int当我们可以只是使用Long and Double 我不明白那部分 很好的问题 特别是如果你来自这样的语言JavaScript它不区分数字
  • 如何在 SeekBar 上显示最大值和最小值?

    我正在尝试做的事情 我想实施一个SeekBar在 Android 应用程序中SeekBar我还想显示最大值和最小值 最小值始终为 0 但最大值取决于剪辑长度 例如 0 180 有没有办法显示用户移动时选择的值 在搜索栏本身上 SeekBar
  • Visual Studio Code 更新后,HTML 文件中的智能 Javascript 建议不再起作用

    我使用 Visual Studio Code 已有几个月了 我已经习惯了里面的聪明建议
  • 为什么Java程序需要“main()”方法?

    这只是一个命名约定 为什么从 shell 执行程序时不能调用任何方法 例如 gt java myPackage MyClass myOwnEntryPoint String str 是的 这是一个命名约定 继承自C 这样做的好处是 只需查看
  • 为 Seq2Seq 模型添加注意力层

    我已经构建了编码器 解码器的 Seq2Seq 模型 我想为其添加一个注意力层 我尝试添加注意力层通过这个但这没有帮助 这是我最初的代码 没有注意 Encoder encoder inputs Input shape None enc emb