如何通过tensorflow的tf.data API加载pickle文件

2023-11-22

我的数据存储在磁盘上的多个 pickle 文件中。我想使用tensorflow的tf.data.Dataset将数据加载到训练管道中。我的代码是:

def _parse_file(path):
    image, label = *load pickle file*
    return image, label
paths = glob.glob('*.pkl')
print(len(paths))
dataset = tf.data.Dataset.from_tensor_slices(paths)
dataset = dataset.map(_parse_file)
iterator = dataset.make_one_shot_iterator()

问题是我不知道如何实施_parse_file功能。该函数的参数,path,是张量类型。我试过

def _parse_file(path):
    with tf.Session() as s:
        p = s.run(path)
        image, label = pickle.load(open(p, 'rb'))
    return image, label

并收到错误消息:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'arg0' with dtype string
     [[Node: arg0 = Placeholder[dtype=DT_STRING, shape=<unknown>, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

在网上搜索一番后我仍然不知道该怎么做。我将感谢任何给我提示的人。


我自己已经解决了这个问题。我应该使用tf.py_func就像在这个doc.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何通过tensorflow的tf.data API加载pickle文件 的相关文章

随机推荐