从单词图像及其转录的列表中,我尝试创建和读取稀疏序列标签(例如tf.nn.ctc_loss
) 用一个tf.train.slice_input_producer
,避免
将预打包的训练数据序列化到磁盘中TFRecord
format
的明显局限性tf.py_func
,
任何不必要或过早的填充,以及
将整个数据集读取到 RAM。
主要问题似乎是将字符串转换为标签序列(aSparseTensor
)需要tf.nn.ctc_loss
.
例如,字符集在(有序)范围内[A-Z]
,我想转换文本标签字符串"BAD"
到序列标签类别列表[1,0,3]
.
我想要读取的每个示例图像都包含文本作为文件名的一部分,因此可以直接提取并直接在 python 中进行转换。 (如果有办法在 TensorFlow 计算中做到这一点,我还没有找到。)
之前的几个问题都扫过这些问题,但是一直没能整合成功。例如,
Tensorflow读取带有标签的图像 https://stackoverflow.com/questions/34340489/tensorflow-read-images-with-labels显示了一个带有离散、分类标签的简单框架,
我是从这个模型开始的。
如何使用 TensorFlow 加载稀疏数据? https://stackoverflow.com/questions/36917807/how-to-load-sparse-data-with-tensorflow很好地解释了加载稀疏数据的方法,但假设
预包装tf.train.Example
s.
有没有办法整合这些方法?
另一个例子(SO问题#38012743)显示了我如何延迟从字符串到列表的转换,直到将文件名出队进行解码之后,但它依赖于tf.py_func
,其中有警告。 (我应该担心他们吗?)
我认识到“SparseTensors 不能很好地处理队列”(根据 tf 文档),因此在批处理之前可能需要对结果(序列化?)进行一些巫术,甚至在计算发生的地方进行返工;我对此持开放态度。
按照 MarvMind 的大纲,这里是一个基本框架,其中包含我想要的计算(迭代包含示例文件名的行,提取每个标签字符串并转换为序列),但我尚未成功确定“Tensorflow”方式来执行此操作。
感谢您的正确“调整”,一个更适合我的目标的策略,或者一个指示tf.py_func
不会破坏训练效率或下游的其他东西(例如,加载经过训练的模型以供将来使用)。
编辑(+7 小时)我找到了缺失的操作来修补问题。虽然仍然需要验证它与下游 CTC_Loss 的连接,但我已经检查了下面编辑的版本是否正确批处理并读取图像和稀疏张量。
out_charset="ABCDEFGHIJKLMNOPQRSTUVWXYZ"
def input_pipeline(data_filename):
filenames,seq_labels = _get_image_filenames_labels(data_filename)
data_queue = tf.train.slice_input_producer([filenames, seq_labels])
image,label = _read_data_format(data_queue)
image,label = tf.train.batch([image,label],batch_size=2,dynamic_pad=True)
label = tf.deserialize_many_sparse(label,tf.int32)
return image,label
def _get_image_filenames_labels(data_filename):
filenames = []
labels = []
with open(data_filename)) as f:
for line in f:
# Carve out the ground truth string and file path from
# lines formatted like:
# ./241/7/158_NETWORK_51375.jpg 51375
filename = line.split(' ',1)[0][2:] # split off "./" and number
# Extract label string embedded within image filename
# between underscores, e.g. NETWORK
text = os.path.basename(filename).split('_',2)[1]
# Transform string text to sequence of indices using charset, e.g.,
# NETWORK -> [13, 4, 19, 22, 14, 17, 10]
indices = [[i] for i in range(0,len(text))]
values = [out_charset.index(c) for c in list(text)]
shape = [len(text)]
label = tf.SparseTensorValue(indices,values,shape)
label = tf.convert_to_tensor_or_sparse_tensor(label)
label = tf.serialize_sparse(label) # needed for batching
# Add data to lists for conversion
filenames.append(filename)
labels.append(label)
filenames = tf.convert_to_tensor(filenames)
labels = tf.convert_to_tensor_or_sparse_tensor(labels)
return filenames, labels
def _read_data_format(data_queue):
label = data_queue[1]
raw_image = tf.read_file(data_queue[0])
image = tf.image.decode_jpeg(raw_image,channels=1)
return image,label