有没有正确的方法来子类 Tensorflow 的数据集?

2024-01-26

我正在研究可以处理自定义 Tensorflow 数据集的不同方法,并且我习惯于查看PyTorch 的数据集 https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files,但是当我去看的时候Tensorflow 的数据集 https://www.tensorflow.org/guide/data_performance,我看到了这个例子:

class ArtificialDataset(tf.data.Dataset):
  def _generator(num_samples):
    # Opening the file
    time.sleep(0.03)

    for sample_idx in range(num_samples):
      # Reading data (line, record) from the file
      time.sleep(0.015)

      yield (sample_idx,)

  def __new__(cls, num_samples=3):
    return tf.data.Dataset.from_generator(
        cls._generator,
        output_signature = tf.TensorSpec(shape = (1,), dtype = tf.int64),
        args=(num_samples,)
        )

但出现了两个问题:

  1. 看起来它所做的就是当对象被实例化时,__new__方法只是调用tf.data.Dataset.from_generator静态方法。那么为什么不直接调用它呢?为什么有一个甚至子类化的点tf.data.Dataset?是否有任何方法可以使用tf.data.Dataset?
  2. 有没有一种方法可以像数据生成器一样做到这一点,其中一个人填写一个__iter__方法同时继承自tf.data.Dataset?我不知道,就像
class MyDataLoader(tf.data.Dataset):
  def __init__(self, path, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.data = pd.read_csv(path)

  def __iter__(self):
    for datum in self.data.iterrows():
      yield datum

非常感谢大家!


问题1

该示例只是将数据集与生成器封装在类中。它继承自tf.data.Dataset因为from_generator() https://www.tensorflow.org/api_docs/python/tf/data/Dataset#from_generator返回一个tf.data.Dataset基于对象。然而,没有方法tf.data.Dataset如示例中所示使用。因此,回答问题1:是的,可以直接调用而不使用类。

问题2

是的。可以这样做。

另一种类似的方法是使用tf.keras.utils.Sequence https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence like here https://medium.com/@mrgarg.rajat/training-on-large-datasets-that-dont-fit-in-memory-in-keras-60a974785d71.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

有没有正确的方法来子类 Tensorflow 的数据集? 的相关文章

随机推荐