使用估计器 api 避免 tf.data.Dataset.from_tensor_slices

2024-03-13

我正在尝试找出推荐的使用方法datasetapi 连同estimatorAPI。我在网上看到的所有内容都是以下内容的一些变体:

def train_input_fn():
   dataset = tf.data.Dataset.from_tensor_slices((features, labels))
   return dataset

然后可以将其传递给估计器的训练函数:

 classifier.train(
    input_fn=train_input_fn,
    #...
 )

but the 数据集指南 https://www.tensorflow.org/guide/datasets警告:

上面的代码片段会将特征和标签数组作为 tf.constant() 操作嵌入到 TensorFlow 图中。这对于小数据集效果很好,但会浪费内存——因为数组的内容将被复制多次——并且可能会遇到 tf.GraphDef 协议缓冲区的 2GB 限制。

然后描述了一种方法,该方法涉及定义占位符,然后用feed_dict:

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

但如果你正在使用estimatorapi,您没有手动运行会话。那么你如何使用dataset带有估计器的 api,同时避免了相关的问题from_tensor_slices()?


要使用可初始化或可重新初始化迭代器,您必须创建一个继承自 tf.train.SessionRunHook 的类,该类可以在训练和评估步骤期间多次访问会话。

然后,您可以使用这个新类来初始化迭代器,就像您通常在经典设置中所做的那样。您只需要将这个新创建的钩子传递给训练/评估函数或正确的训练规范即可。

以下是您可以根据自己的需求进行调整的简单示例:

class IteratorInitializerHook(tf.train.SessionRunHook):
    def __init__(self):
        super(IteratorInitializerHook, self).__init__()
        self.iterator_initializer_func = None # Will be set in the input_fn

    def after_create_session(self, session, coord):
        # Initialize the iterator with the data feed_dict
        self.iterator_initializer_func(session) 


def get_inputs(X, y):
    iterator_initializer_hook = IteratorInitializerHook()

    def input_fn():
        X_pl = tf.placeholder(X.dtype, X.shape)
        y_pl = tf.placeholder(y.dtype, y.shape)

        dataset = tf.data.Dataset.from_tensor_slices((X_pl, y_pl))
        dataset = ...
        ...

        iterator = dataset.make_initializable_iterator()
        next_example, next_label = iterator.get_next()


        iterator_initializer_hook.iterator_initializer_func = lambda sess: sess.run(iterator.initializer,
                                                                                    feed_dict={X_pl: X, y_pl: y})

        return next_example, next_label

    return input_fn, iterator_initializer_hook

...

train_input_fn, train_iterator_initializer_hook = get_inputs(X_train, y_train)
test_input_fn, test_iterator_initializer_hook = get_inputs(X_test, y_test)

...

estimator.train(input_fn=train_input_fn,
                hooks=[train_iterator_initializer_hook]) # Don't forget to pass the hook !
estimator.evaluate(input_fn=test_input_fn,
                   hooks=[test_iterator_initializer_hook])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用估计器 api 避免 tf.data.Dataset.from_tensor_slices 的相关文章

随机推荐