我正在尝试找出推荐的使用方法dataset
api 连同estimator
API。我在网上看到的所有内容都是以下内容的一些变体:
def train_input_fn():
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
return dataset
然后可以将其传递给估计器的训练函数:
classifier.train(
input_fn=train_input_fn,
#...
)
but the 数据集指南 https://www.tensorflow.org/guide/datasets警告:
上面的代码片段会将特征和标签数组作为 tf.constant() 操作嵌入到 TensorFlow 图中。这对于小数据集效果很好,但会浪费内存——因为数组的内容将被复制多次——并且可能会遇到 tf.GraphDef 协议缓冲区的 2GB 限制。
然后描述了一种方法,该方法涉及定义占位符,然后用feed_dict
:
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
sess.run(iterator.initializer, feed_dict={features_placeholder: features,
labels_placeholder: labels})
但如果你正在使用estimator
api,您没有手动运行会话。那么你如何使用dataset
带有估计器的 api,同时避免了相关的问题from_tensor_slices()
?