我想保存 LSTM 的最终状态,以便在恢复模型时将其包含在内并可用于预测。如下所述,当我使用时,保护程序仅了解最终状态tf.assign
。但是,这会引发错误(也将在下面解释)。
在训练期间,我总是将最终的 LSTM 状态反馈回网络,如中所述这个帖子 https://stackoverflow.com/questions/39112622/how-do-i-set-tensorflow-rnn-state-when-state-is-tuple-true。以下是代码的重要部分:
构建图表时:
self.init_state = tf.placeholder(tf.float32, [
self.n_layers, 2, self.batch_size, self.n_hidden
])
state_per_layer_list = tf.unstack(self.init_state, axis=0)
rnn_tuple_state = tuple([
tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0],
state_per_layer_list[idx][1])
for idx in range(self.n_layers)
])
outputs, self.final_state = tf.nn.dynamic_rnn(
cell, inputs=self.inputs, initial_state=rnn_tuple_state)
在训练期间:
_current_state = np.zeros((self.n_layers, 2, self.batch_size,
self.n_hidden))
_train_step, _current_state, _loss, _acc, summary = self.sess.run(
[
self.train_step, self.final_state,
self.merged
],
feed_dict={self.inputs: _inputs,
self.labels:_labels,
self.init_state: _current_state})
当我稍后从检查点恢复模型时,最终状态也不会恢复。如中所述这个帖子 https://stackoverflow.com/questions/39112622/how-do-i-set-tensorflow-rnn-state-when-state-is-tuple-true问题是 Saver 不知道新的状态。该帖子还提出了一个解决方案,基于tf.assign
。遗憾的是,我无法使用建议的
assign_op = tf.assign(self.init_state, _current_state)
self.sess.run(assign_op)
因为 self.init state 不是一个变量而是一个占位符。我收到错误
AttributeError:“张量”对象没有属性“分配”
我已经尝试解决这个问题几个小时了,但我无法让它工作。
任何帮助表示赞赏!
EDIT:
我已将 self.init_state 更改为
self.init_state = tf.get_variable('saved_state', shape=
[self.n_layers, 2, self.batch_size, self.n_hidden])
state_per_layer_list = tf.unstack(self.init_state, axis=0)
rnn_tuple_state = tuple([
tf.contrib.rnn.LSTMStateTuple(state_per_layer_list[idx][0],
state_per_layer_list[idx][1])
for idx in range(self.n_layers)
])
outputs, self.final_state = tf.nn.dynamic_rnn(
cell, inputs=self.inputs, initial_state=rnn_tuple_state)
在训练期间,我不为 self.init_state 提供值:
_train_step, _current_state, _loss, _acc, summary = self.sess.run(
[
self.train_step, self.final_state,
self.merged
],
feed_dict={self.inputs: _inputs,
self.labels:_labels})
但是,我仍然无法运行分配操作。知道我得到
类型错误:预期 float32 传递给操作“分配”的参数“值”,得到 (LSTMStateTuple(c=array([[ 0.07291573, -0.06366599, -0.23425588, ..., 0.05307654,