如何使用 Tensorflow 数据集管道进行可变长度输入?

2023-11-24

我正在 Tensorflow 中通过不同长度的数字序列数据集训练循环神经网络,并一直在尝试使用tf.data用于创建高效管道的 API。但是我似乎无法让这个东西发挥作用

我的方法

我的数据集是一个 NumPy 形状数组[10000, ?, 32, 2]它作为文件保存在我的磁盘上.npy格式。这里的?表示元素在第二维中具有可变长度。 10000 表示数据集中的小批量的数量,32 表示小批量的大小。

我在用np.load打开这个数据集,我正在尝试创建一个tf.data.Dataset对象使用from_tensor_slices方法,但似乎只有在所有输入张量具有相同形状时才有效!

我尝试阅读docs但他们只给出了一个非常简单的例子。

My code

因此 numpy 文件已生成如下 -

dataset = []
for i in xrange(num_items):
  #add an element of shape [?, 32, 2] to the list where `?` takes
  # a random value between [1, 40]
  dataset.append(generate_random_rnn_input())

with open('data.npy', 'w') as f:
  np.save(f, dataset)

下面给出的代码是我尝试创建一个tf.data.Dataset object

# dataset_list is a list containing `num_items` number of itesm
# and each item has shape [?, 32, 2]
dataset_list = np.load('data.npy')

# error, this doesn't work!
dataset = tf.data.Dataset.from_tensor_slices(dataset_list)

我得到的错误是“TypeError:预期的二进制或unicode字符串,得到数组([[[0.0875, 0.], ...”

继续,仍然需要帮助!

所以我尝试了@mrry的答案,现在我可以创建一个数据集对象。However,我无法按照教程中所述使用迭代器迭代此数据集。这就是我的代码现在的样子 -

dataset_list = np.load('data.npy')

dataset = tf.data.Dataset.from_generator(lambda: dataset_list, 
                                         dataset_list[0].dtype,
                                         tf.TensorShape([None, 32, 2]))

dataset = dataset.map(lambda x : tf.cast(x, tf.float32))

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  print sess.run(next_element) # The code fails on this line

我得到的错误是AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'。我完全不知道这意味着什么。

这是完整的堆栈跟踪 -

2018-05-15 04:19:25.559922: W tensorflow/core/framework/op_kernel.cc:1261] Unknown: exceptions.AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'
Traceback (most recent call last):

  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/ops/script_ops.py", line 147, in __call__
    ret = func(*args)

  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 378, in generator_py_func
    nest.flatten_up_to(output_types, values), flattened_types)

AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'


2018-05-15 04:19:25.559989: W tensorflow/core/framework/op_kernel.cc:1273] OP_REQUIRES failed at iterator_ops.cc:891 : Unknown: exceptions.AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'
Traceback (most recent call last):

  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/ops/script_ops.py", line 147, in __call__
    ret = func(*args)

  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 378, in generator_py_func
    nest.flatten_up_to(output_types, values), flattened_types)

AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'


     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_DOUBLE], token="pyfunc_1"](arg0)]]
Traceback (most recent call last):
  File "pipeline_test.py", line 320, in <module>
    tf.app.run()
  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "pipeline_test.py", line 316, in main
    train(FLAGS.num_training_iterations, FLAGS.report_interval, FLAGS.report_interval_verbose)
  File "pipeline_test.py", line 120, in train
    print(sess.run(next_element))
  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 905, in run
    run_metadata_ptr)
  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1140, in _run
    feed_dict_tensor, options, run_metadata)
  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    run_metadata)
  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.UnknownError: exceptions.AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'
Traceback (most recent call last):

  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/ops/script_ops.py", line 147, in __call__
    ret = func(*args)

  File "/home/vastolorde95/virtualenvs/thesis/local/lib/python2.7/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 378, in generator_py_func
    nest.flatten_up_to(output_types, values), flattened_types)

AttributeError: 'numpy.dtype' object has no attribute 'as_numpy_dtype'


     [[Node: PyFunc = PyFunc[Tin=[DT_INT64], Tout=[DT_DOUBLE], token="pyfunc_1"](arg0)]]
     [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,32,2]], output_types=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]

正如你所注意到的,tf.data.Dataset.from_tensor_slices()仅适用于可以转换为(密集)的对象tf.Tensor or a tf.SparseTensor。将可变长度 NumPy 数据转换为最简单的方法Dataset是使用tf.data.Dataset.from_generator(), 如下:

dataset = tf.data.Dataset.from_generator(lambda: dataset_list, 
                                         tf.as_dtype(dataset_list[0].dtype),
                                         tf.TensorShape([None, 32, 2]))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用 Tensorflow 数据集管道进行可变长度输入? 的相关文章

随机推荐

  • 我怎样才能知道c中指针变量分配的内存大小[重复]

    这个问题在这里已经有答案了 我在这种情况下遇到了一些问题 您能请您提供一下想法吗 main char p NULL p char malloc 2000 sizeof char printf size of p d n sizeof p 在
  • => 在 Linq 表达式中意味着什么 [重复]

    这个问题在这里已经有答案了 虽然这是一个重复的问题 但我以前从未在代码中见过表达式 gt 如果我知道这是一个 lambda 表达式 我就会用 google 搜索并自己找出答案 谢谢 我是 Linq 的新手 所以当我在这段代码中遇到 gt 时
  • 如何制作 Django 查询集来选择组内具有最大值的记录

    这是我的 Django 类 class MyClass models Model my integer models IntegerField created ts models DateTimeField default datetime
  • 在 Python 中查找箭头键的值:为什么它们是三元组?

    我正在尝试查找本地系统分配给箭头键的值 特别是在 Python 中 我正在使用以下脚本来执行此操作 import sys tty termios class Getch def call self fd sys stdin fileno o
  • 如何防止 Excel 在宏计算时渲染电子表格?

    我的宏用数字更新一个大型电子表格 但它运行速度非常慢 因为 Excel 在计算时渲染结果 如何在宏完成之前阻止 Excel 渲染输出 我使用了两种建议的解决方案 Application ScreenUpdating False Applic
  • 如何在 thymeleaf 中处理和连接字符串

    我有一个字符串列表 这是我感兴趣的属性名称 我想连接这些字符串的值 但不使用属性名称 而是使用它们的属性值 我看到起点是 strings listJoin 但是我怎么能说将列表中的元素与属性文件中的值相匹配呢 该列表将是 name addr
  • 创建一个可根据其内容调整大小的 QDockWidget

    我有一个应用程序 需要在运行时根据用户输入以编程方式将固定大小的子窗口小部件添加到停靠窗口小部件 我想将这些小部件添加到 Qt RightDockArea 上的停靠栏 从上到下直到空间不足 然后创建一个新列并重复 本质上与流程布局示例相反
  • cpp中的“[=]”是什么意思

    请检查下面的代码 NodeScheduleLambda this 0 01f this gt removeFromParentAndCleanup true 那里面的 是什么意思呢 有谁可以帮帮我吗 谢谢 lambda 是一种未命名 匿名函
  • NewDirectByteBuffer 是否在本机代码中创建副本

    我正在 C 中创建两个数组 这两个数组将在 java 端读取 env gt NewDirectByteBuffer env gt NewByteArray 这些函数会复制我发送的缓冲区吗 我是否需要在 C 端的堆上创建缓冲区 或者是否可以在
  • Selenium IDE:如何在未找到元素或出现错误时继续执行脚本

    我需要你的帮助 我只想在 Firefox 上继续我的 Selenium IDE 脚本 即使出现错误或未找到元素 我正在使用 HTML 格式的脚本 在下一个命令中使用该元素之前 您必须显式检查该元素是否存在 这可能会导致错误并中断脚本的执行
  • HTTP 请求失败! HTTP/1.1 503 服务暂时不可用

    我正在使用函数 file get contents 从网页获取内容 有些网站运行良好 但大多数都给我这个错误 failed to open stream HTTP request failed HTTP 1 1 503 Service Te
  • 具有不同文本大小的 TextView

    是否可以在一个 TextView 中设置不同的 textSize 我知道我可以使用以下方法更改文本样式 TextView textView TextView findViewById R id textView Spannable span
  • 防止堆上未对齐的数据

    我正在构建一个使用 SSE 内在函数的类层次结构 因此该类的一些成员需要 16 字节对齐 对于堆栈实例我可以使用 declspec align 像这样 typedef declspec align 16 float Vector 4 cla
  • Azure 容器应用程序每 30 秒重新启动一次

    我有一个基于的 Azure 容器应用程序托管后台服务模型 它本质上只是一个长期运行的控制台应用程序 它覆盖了BackgroundService ExecuteAsync方法并等待停止信号 通过传递的取消令牌 当我在 Docker 中本地运行
  • 在 jQuery 中,将数字格式化为小数点后两位的最佳方法是什么?

    这就是我现在所拥有的 number val parseFloat number val toFixed 2 我觉得很乱 我认为我没有正确链接这些函数 我是否必须为每个文本框调用它 或者我可以创建一个单独的函数吗 如果您要对多个领域执行此操作
  • 与区域设置无关的 strtod 实现

    我有一个库需要解析始终使用点 的双数 作为小数点分隔符 不幸的是 对于这种情况 strtod 尊重可能使用不同分隔符的语言环境 因此解析可能会失败 我无法 setlocale 它不是线程安全的 所以我现在正在寻找一个干净的独立于语言环境的
  • 同步三个线程

    在采访中被问到这个问题 试图解决它 但没有成功 我想到使用 CyclicBarrier 有三个线程 T1 打印 1 4 7 T2 打印 2 5 8 T3 打印 3 6 9 如何同步这三个来打印序列 1 2 3 4 5 6 7 8 9 我尝试
  • 从 2 个向量的串联构建一个向量

    有没有办法构建一个vector作为 2 的串联vectors 除了创建一个辅助函数 例如 const vector
  • 延迟加载 Angular 13+ 模块,无需使用已弃用的编译器

    我曾与加载和实例化 Angular 模块 不带路由器 但现在 在 Angular 13 中 我发现实例化 NgModule 的常用编译器工具已被弃用 这是我加载模块的常用代码 const moduleFactory await this c
  • 如何使用 Tensorflow 数据集管道进行可变长度输入?

    我正在 Tensorflow 中通过不同长度的数字序列数据集训练循环神经网络 并一直在尝试使用tf data用于创建高效管道的 API 但是我似乎无法让这个东西发挥作用 我的方法 我的数据集是一个 NumPy 形状数组 10000 32 2