My raw_data is the PTB dataset.
I generate batches with the following code:
def generate_batches(raw_data, batch_size, unrollings):
    global data_index
    data_len = len(raw_data)
    num_batches = data_len // batch_size
    inputs = []
    labels = []
    print (num_batches, data_len, batch_size)
    for j in xrange(unrollings) :
        inputs.append([])
        labels.append([])
        for i in xrange(batch_size) :
            inputs[j].append(raw_data[i + data_index])
            labels[j].append(raw_data[i + data_index + 1])
        data_index = (data_index + batch_size) % len(raw_data)
    return inputs, labels
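To make the batch layout concrete, here is a Python 3 sketch of the same loop run on a toy list instead of PTB. Assumptions: range replaces xrange, data_index is passed in and returned instead of being a global (only so the example is self-contained), and the unused num_batches/print lines are dropped.

```python
def generate_batches(raw_data, batch_size, unrollings, data_index=0):
    """Python 3 sketch of the batching loop above.

    Assumption: data_index is passed in and returned rather than kept global.
    """
    data_len = len(raw_data)
    inputs, labels = [], []
    for j in range(unrollings):
        inputs.append([])
        labels.append([])
        for i in range(batch_size):
            # Same indexing as the original: each label is the next token.
            inputs[j].append(raw_data[i + data_index])
            labels[j].append(raw_data[i + data_index + 1])
        data_index = (data_index + batch_size) % data_len
    return inputs, labels, data_index


inputs, labels, next_index = generate_batches(list(range(10)), batch_size=2, unrollings=3)
print(inputs)      # [[0, 1], [2, 3], [4, 5]]
print(labels)      # [[1, 2], [3, 4], [5, 6]]
print(next_index)  # 6
```

The result is a plain nested Python list: unrollings inner lists (one per time step), each holding batch_size token ids, so batch_inputs[i] is the piece that lines up with the i-th input placeholder. Like the original, the inner loop does not wrap its index, so near the end of raw_data it can raise IndexError.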
During the session run, the generated batches are fed into feed_dict as shown in the following code:
for step in xrange(num_steps) :
    batch_inputs, batch_labels = generate_batches(train_dataset, batch_size, unrollings=5)
    feed_dict = dict()
    for i in range(unrollings):
        feed_dict = {train_inputs : batch_inputs, train_labels : batch_labels}
    _, l, predictions, lr = session.run([optimizer, loss, train_prediction, learning_rate], feed_dict=feed_dict)
The training inputs and labels are defined as follows:
for _ in range(unrollings) :
    train_data.append(tf.placeholder(shape=[batch_size], dtype=tf.int32))
    train_label.append(tf.placeholder(shape=[batch_size, 1], dtype=tf.float32))
train_inputs = train_data[:unrollings]
train_labels = train_label[:unrollings]
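Here train_inputs and train_labels are Python lists of placeholders, and in the feed_dict line above those lists are used directly as dict keys. A common pattern with TF 1.x-style placeholders is instead to map each individual placeholder to its own time step's batch. The sketch below shows that pattern without TensorFlow: Placeholder is a hypothetical stand-in that, like a graph tensor, is hashable by object identity, and the batches are toy values.

```python
class Placeholder:
    """Hypothetical stand-in for tf.placeholder: hashable by object identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "Placeholder(%r)" % self.name


unrollings, batch_size = 3, 2
train_inputs = [Placeholder("input_%d" % k) for k in range(unrollings)]
train_labels = [Placeholder("label_%d" % k) for k in range(unrollings)]

# Toy batches in the layout generate_batches returns: one list per time step.
batch_inputs = [[0, 1], [2, 3], [4, 5]]
batch_labels = [[1, 2], [3, 4], [5, 6]]

feed_dict = {}
for i in range(unrollings):
    # Each key is a single placeholder object (hashable), never a Python list.
    feed_dict[train_inputs[i]] = batch_inputs[i]
    # Label placeholders have shape [batch_size, 1], so turn each id into a column.
    feed_dict[train_labels[i]] = [[y] for y in batch_labels[i]]

print(len(feed_dict))  # 6
```

With real placeholders the same loop would build one entry per placeholder, 2 × unrollings entries in total.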
First, I got the error TypeError: unhashable type: 'list'.
I converted the batch_input lists to tuples using tuple(batch_input[i]),
as explained clearly in Python dictionary: TypeError: unhashable type: 'list' https://stackoverflow.com/questions/8532146/python-dictionary-typeerror-unhashable-type-list.
That fixed the first error, but then I got this one: TypeError: unhashable type: 'numpy.ndarray'
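For reference, the rule behind both messages: a dict key must be hashable, and a tuple is hashable only if every element is. Converting to a tuple therefore helps only when the elements themselves are hashable; a tuple of lists still fails, and a tuple holding NumPy arrays fails with the numpy.ndarray variant of the message. The helper key_error is hypothetical, written only for this illustration.

```python
def key_error(key):
    """Hypothetical helper: None if `key` works as a dict key, else the TypeError text."""
    try:
        {key: "value"}  # building the dict hashes the key first
    except TypeError as exc:
        return str(exc)
    return None


print(key_error(("a", "b")))          # None: a tuple of strings is hashable
print(key_error(["a", "b"]))          # message contains "unhashable type: 'list'"
print(key_error(([1, 2], [3, 4])))    # a tuple of lists fails the same way
```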