在 TensorFlow 中将多个字节读取到单个值中

2024-05-05

我尝试以 TensorFlow 中 cifar10 示例中描述的类似方式读取标签:

 ....
 label_bytes = 2 # it was 1 in the original version
 result.key, value = reader.read(filename_queue)
 record_bytes = tf.decode_raw(value, tf.uint8)
 result.label = tf.cast(tf.slice(record_bytes, [0], [label_bytes]), tf.int32)
 ....

问题是,如果label_byte大于 1(例如 2),result.label似乎变成了两个元素的张量(每个元素都是 1 字节)。我只想代表连续的label_bytes将字节转换为单个值。我怎么做?

Thanks


创建第二个解码器,用它解码 int16 并将第一个元素作为标签

shorts = tf.decode_raw(value, tf.int16)
result.label = tf.cast(shorts[0], tf.int32)

可能有更好的解决方案,但它有效。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 TensorFlow 中将多个字节读取到单个值中 的相关文章

随机推荐