当 state_is_tuple=True 时如何设置 TensorFlow RNN 状态?

2024-01-10

我写了一个使用 TensorFlow 的 RNN 语言模型 https://github.com/wpm/tfrnnlm。该模型被实现为RNN班级。图结构是在构造函数中构建的,而RNN.train and RNN.test方法运行它。

当我移动到训练集中的新文档时,或者当我想在训练期间运行验证集时,我希望能够重置 RNN 状态。我通过管理训练循环内的状态,通过提要字典将其传递到图中来实现这一点。

在构造函数中我像这样定义了 RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

训练循环看起来像这样

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict={self.x:x, self.y:y, self.state:state})

x and y是文档中的批量训练数据。我的想法是,我在每批之后传递最新的状态,除非我开始一个新文档,当我通过运行将状态清零时self.reset_state.

这一切都有效。现在我想更改我的 RNN 以使用推荐的state_is_tuple=True。但是,我不知道如何通过 feed 字典传递更复杂的 LSTM 状态对象。我也不知道要传递什么参数self.state = tf.placeholder(...)我的构造函数中的行。

这里正确的策略是什么?仍然没有太多示例代码或文档dynamic_rnn可用的。


TensorFlow 问题2695 https://github.com/tensorflow/tensorflow/issues/2695 and 2838 https://github.com/tensorflow/tensorflow/issues/2838显得相关。

A 博客文章 http://www.wildml.com/2016/08/rnns-in-tensorflow-a-practical-guide-and-undocumented-features/on WILDML 解决了这些问题,但没有直接阐明答案。

也可以看看TensorFlow:记住下一批的 LSTM 状态(有状态 LSTM) https://stackoverflow.com/questions/38241410/tensorflow-remember-lstm-state-for-next-batch-stateful-lstm.


Tensorflow 占位符的一个问题是你只能使用 Python 列表或 Numpy 数组来提供它(我认为)。因此,您无法将运行之间的状态保存在 LSTMStateTuple 的元组中。

我通过将状态保存在这样的张量中解决了这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

LSTM 层中有两个组件,细胞状态 and 隐藏状态,这就是“2”的由来。 (这篇文章很棒:https://arxiv.org/pdf/1506.00019.pdf https://arxiv.org/pdf/1506.00019.pdf)

构建图表时,您可以解压并创建元组状态,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

然后你按照通常的方式得到新的状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

事情不应该是这样的……也许他们正在研究解决方案。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

当 state_is_tuple=True 时如何设置 TensorFlow RNN 状态? 的相关文章

随机推荐