如何将 tf.contrib.seq2seq.Helper 用于非嵌入数据?

2024-05-18

我正在尝试使用 tf.contrib.seq2seq 模块对某些数据(仅 float32 向量)进行预测,但我使用 TensorFlow 中的 seq2seq 模块找到的所有示例都用于翻译,因此用于嵌入。

我正在努力准确理解 tf.contrib.seq2seq.Helper 为 Seq2Seq 架构所做的事情以及如何在我的案例中使用 CustomHelper。

这就是我现在所做的:

import tensorflow as tf 
from tensorflow.python.layers import core as layers_core

input_seq_len = 15 # Sequence length as input
input_dim = 1 # Nb of features in input

output_seq_len = forecast_len = 20 # horizon length for forecasting
output_dim = 1 # nb of features to forecast


encoder_units = 200 # nb of units in each cell for the encoder
decoder_units = 200 # nb of units in each cell for the decoder

attention_units = 100

batch_size = 8


graph = tf.Graph()
with graph.as_default():

    learning_ = tf.placeholder(tf.float32)

    with tf.variable_scope('Seq2Seq'):

        # Placeholder for encoder input
        enc_input = tf.placeholder(tf.float32, [None, input_seq_len, input_dim])

        # Placeholder for decoder output - Targets
        target = tf.placeholder(tf.float32, [None, output_seq_len, output_dim])


        ### BUILD THE ENCODER

        # Build RNN cell
        encoder_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_units)

        initial_state = encoder_cell.zero_state(batch_size, dtype=tf.float32)

        # Run Dynamic RNN
        #   encoder_outputs: [batch_size, seq_size, num_units]
        #   encoder_state: [batch_size, num_units]
        encoder_outputs, encoder_state = tf.nn.dynamic_rnn(encoder_cell, enc_input, initial_state=initial_state)

        ## Attention layer

        attention_mechanism_bahdanau = tf.contrib.seq2seq.BahdanauAttention(
            num_units = attention_units, # depth of query mechanism
            memory = encoder_outputs, # hidden states to attend (output of RNN)
            normalize=False, # normalize energy term
            name='BahdanauAttention')

        attention_mechanism_luong = tf.contrib.seq2seq.LuongAttention(
            num_units = encoder_units,
            memory = encoder_outputs,
            scale=False,
            name='LuongAttention'
        )


        ### BUILD THE DECODER

        # Simple Dense layer to project from rnn_dim to the desired output_dim
        projection = layers_core.Dense(output_dim, use_bias=True, name="output_projection")

        helper = tf.contrib.seq2seq.TrainingHelper(target, sequence_length=[output_seq_len for _ in range(batch_size)])
 ## This is where I don't really know what to do in my case, is this function changing my data into [ GO, data, END] ?

        decoder_cell = tf.nn.rnn_cell.BasicLSTMCell(decoder_units)

        attention_cell = tf.contrib.seq2seq.AttentionWrapper(
            cell = decoder_cell,
            attention_mechanism = attention_mechanism_luong, # Instance of AttentionMechanism
            attention_layer_size = attention_units,
            name="attention_wrapper")

        initial_state = attention_cell.zero_state(batch_size=batch_size, dtype=tf.float32)
        initial_state = initial_state.clone(cell_state=encoder_state)

        decoder = tf.contrib.seq2seq.BasicDecoder(attention_cell, initial_state=initial_state, helper=helper, output_layer=projection)

        outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder=decoder)


        # Loss function:

        loss = 0.5*tf.reduce_sum(tf.square(outputs[0] - target), -1)
        loss = tf.reduce_mean(loss, 1)
        loss = tf.reduce_mean(loss)

        # Optimizer

        optimizer = tf.train.AdamOptimizer(learning_).minimize(loss)

我知道 Seq2seq 架构的训练状态和推理状态有很大不同,但我不知道如何使用模块中的帮助器来区分两者。 我使用这个模块是因为它对于注意力层非常有用。 如何使用 Helper 为解码器创建 ['Go' , [input_sequence]] ?


None

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何将 tf.contrib.seq2seq.Helper 用于非嵌入数据? 的相关文章

随机推荐

  • 您可以将 Docker 映像直接拉入 IBM Cloud Kubernetes 集群吗?

    TL DR 抱歉 如果这是基础知识 我正在学习 Kubernetes 我尝试在 IBM Cloud 中创建 Kubernetes 部署 但失败了 该部署在我的本地 minikube 上运行良好 但在 IBM Cloud 中失败 我是否需要使
  • 我的 linq select 不起作用,我的 foreach 起作用

    我有以下 LINQSelect这是行不通的 Data Select d gt d Value IsDirty true Not working 我的较长解决方法确实如此 foreach var d in Data d Value IsDir
  • 将 Bootstrap 与 Bower 一起使用

    我正在尝试将 Bootstrap 与 Bower 一起使用 但由于它克隆了整个存储库 因此没有 CSS 和其他内容 这是否意味着我需要在我自己的构建过程中包含构建 Bootstrap 或者如果我错了 正确的工作流程是什么 I finally
  • 添加自定义过渡会导致 xib 加载错误的屏幕尺寸

    我正在尝试向具有 xib 的 UIViewController 添加自定义过渡 我尝试了几种方法 但它们都有相同的问题 视图显示的屏幕尺寸错误 我当前的示例基于以下教程 使用 Swift 在 iOS 中自定义 UIViewControlle
  • 为什么这个谓词格式会变成 '= nil'

    有人建议这个线程 https stackoverflow com questions 40686005 nspredicate crash after swift 3 migration与我的问题完全相同 但是 我的应用程序没有崩溃 并且我
  • 使 .net web api 队列请求以“单线程”方式运行

    我们有一个 c net Web API 服务调用代码 该代码无法一次处理多个数据库请求 该系统适用于需求相对较小的账单在线支付 我们无法控制代码来进行可以解决问题的更改 另一个使用相同代码的小组使用 WCF API 和服务配置将并发请求限制
  • 猪的组连接等效吗?

    试图在 Pig 上完成这个任务 寻找 MySQL 的 group concat 等效项 例如 在我的表中 我有以下内容 3fields userid clickcount pagenumber 155 2 12 155 3 133 155
  • 如何解决 Laravel 8 UI 分页问题?

    我在尝试最近发布的 laravel 8 时遇到了问题 我试图找出变化是什么以及它是如何工作的 当我这样做时 我遇到了分页 laravel 8 UI 变得混乱的问题 不知何故它发生了 有人可以帮助我吗 或者经历过同样的事情 像这样我在 lar
  • Inc 函数 Inno Setup

    这可能非常简单 但是当我尝试编译包含以下内容的程序时 Inc Count 在 Inno Setup 中我不断得到 未知标识符 Inc 我相信这就是在 Pascal 中递增整数的方式 并且对如何继续这里感到困惑 我正在使用 Inno Setu
  • Rails 资源单数还是复数?

    我有一条搜索路线 我想将其设为单数 但是当我指定单数路线时 它仍然会生成复数控制器路线 这是应该的样子吗 resource search Gives me search POST search format action gt create
  • 在 Windows 7 中使用 ActivePerl @ARGV 为空

    我有以下 Perl 脚本 我正在尝试使用 ActivePerl 在 Windows 7 中运行它 c Perl64 bin perl exe w use strict my mp3splt exe c Program Files x86 m
  • 为 PostgreSQL 查询选择正确的索引

    简化表 CREATE TABLE products product no integer PRIMARY KEY sales integer status varchar 16 category varchar 16 CREATE INDE
  • 如何在 Swift 中获取字典中最后输入的值?

    如何获取 Swift 字典中最后输入的值 例如 我如何从下面获取值 CCC var dictionary Dictionary
  • Eslint 从另一个文件确定全局变量

    我试图以这样的方式设置 ESLint 使其在对实际目标文件进行 linting 之前解析全局声明文件 这样我就不必将所有确实是全局的函数和变量声明为全局 而是让解析器弄清楚 In 一些 模块 js function do something
  • 可以读取目标文件吗?

    我很好奇 obj文件 我几乎不知道它们是什么 或者它们包含什么 所以我用 Vim 文本编辑器打开它们 我在里面发现了一种类似外星人的语言 有什么办法可以理解它们代表什么以及它们的内容是什么 另外 它们的用途是什么 Thanks Sure 但
  • Python-将标题写入csv

    目前我正在用 python 编写查询 将数据从 oracle dbo 导出到 csv 文件 我不知道如何在文件中写入标题 try connection cx Oracle connect user pass tns name cursor
  • 在真实设备上展示测试广告

    这是我的代码 let request GADRequest request testDevices kGADSimulatorID XXXX2F32d69CCA859FFB559D0FEA3CF6483D08A6 adView load r
  • Python最大递归,关于sys.setrecursionlimit()的问题

    我有一个问题sys setrecursionlimit 来自蟒蛇docs https docs python org 2 library sys html sys setrecursionlimit这个函数 将Python解释器堆栈的最大深
  • Pyqt5 中的 QThreads:这是官方 QThread 文档的正确 C++ 到 Python 翻译吗?

    关于如何实例化和使用的官方文档QThread可以在这里找到 http doc qt io qt 5 qthread html http doc qt io qt 5 qthread html 该文档描述了两种基本方法 1 工作对象方法和 2
  • 如何将 tf.contrib.seq2seq.Helper 用于非嵌入数据?

    我正在尝试使用 tf contrib seq2seq 模块对某些数据 仅 float32 向量 进行预测 但我使用 TensorFlow 中的 seq2seq 模块找到的所有示例都用于翻译 因此用于嵌入 我正在努力准确理解 tf contr