Tensorflow中使用tfrecord方式读取数据

2023-05-16

前言

本博客默认读者对神经网络与Tensorflow有一定了解,对其中的一些术语不再做具体解释。并且本博客主要以图片数据为例进行介绍,如有错误,敬请斧正。

使用Tensorflow训练神经网络时,我们可以用多种方式来读取自己的数据。如果数据集比较小,而且内存足够大,可以选择直接将所有数据读进内存,然后每次取一个batch的数据出来。如果数据较多,可以每次直接从硬盘中进行读取,不过这种方式的读取效率就比较低了。此篇博客就主要讲一下Tensorflow官方推荐的一种较为高效的数据读取方式——tfrecord。

从宏观来讲,tfrecord其实是一种数据存储形式。使用tfrecord时,实际上是先读取原生数据,然后转换成tfrecord格式,再存储在硬盘上。而使用时,再把数据从相应的tfrecord文件中解码读取出来。那么使用tfrecord和直接从硬盘读取原生数据相比到底有什么优势呢?其实,Tensorflow有和tfrecord配套的一些函数,可以加快数据的处理。实际读取tfrecord数据时,先以相应的tfrecord文件为参数,创建一个输入队列,这个队列有一定的容量(视具体硬件限制,用户可以设置不同的值),在一部分数据出队列时,tfrecord中的其他数据就可以通过预取进入队列,并且这个过程和网络的计算是独立进行的。也就是说,网络每一个iteration的训练不必等待数据队列准备好再开始,队列中的数据始终是充足的,而往队列中填充数据时,也可以使用多线程加速。

下面,本文将从以下4个方面对tfrecord进行介绍:

  1. tfrecord格式简介
  2. 利用自己的数据生成tfrecord文件
  3. 从tfrecord文件读取数据
  4. 实例测试

1. tfrecord格式简介

这部分主要参考了另一篇博文,Tensorflow 训练自己的数据集(二)(TFRecord)

tfecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

从上述代码可以看出,tf.train.Example 的数据结构很简单。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),浮点数列表(FloatList )或整数列表(Int64List )。例如我们可以将图片转换为字符串进行存储,图像对应的类别标号作为整数存储,而用于回归任务的ground-truth可以作为浮点数存储。通过后面的代码我们会对tfrecord的这种字典形式有更直观的认识。

2. 利用自己的数据生成tfrecord文件

先上一段代码,然后我再针对代码进行相关介绍。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from scipy import misc
import scipy.io as sio


def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))


root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecords_filename = root_path + 'tfrecords/train.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)


height = 300
width = 300
meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
meanvalue = meanfile['mean']

txtfile = root_path + 'txt/train.txt'
fr = open(txtfile)

for i in fr.readlines():
    item = i.split()
    img = np.float64(misc.imread(root_path + '/images/train_images/' + item[0]))
    img = img - meanvalue
    maskmat = sio.loadmat(root_path + '/mats/train_mats/' + item[1])
    mask = np.float64(maskmat['seg_mask'])
    label = int(item[2])
    img_raw = img.tostring()
    mask_raw = mask.tostring()
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'name': _bytes_feature(item[0]),
        'image_raw': _bytes_feature(img_raw),
        'mask_raw': _bytes_feature(mask_raw),
        'label': _int64_feature(label)}))

    writer.write(example.SerializeToString())

writer.close()
fr.close()

代码中前两个函数(_bytes_feature和_int64_feature)是将我们的原生数据进行转换用的,尤其是图片要转换成字符串再进行存储。这两个函数的定义来自官方的示例。
接下来,我定义了数据的(路径-label文件)txtfile,它大概长这个样子:

txtfile

这里稍微啰嗦下,介绍一下我的实验内容。我做的是一个multi-task的实验,一支task做分割,一支task做分类。所以txtfile中每一行是一个样本,每个样本又包含3项,第一项为图片名称,第二项为相应的ground-truth segmentation mask的名称,第三项是图片的标签。(txtfile中内容形式无所谓,只要能读到想读的数据就可以)

接着回到主题继续讲代码,之后我又定义了即将生成的tfrecord的文件路径和名称,即tfrecord_filename,还有一个writer,这个writer是进行写操作用的。

接下来是图片的高度、宽度以及我事先在整个数据集上计算好的图像均值文件。高度、宽度其实完全没必要引入,这里只是为了说明tfrecord的生成而写的。而均值文件是为了对图像进行事先的去均值化操作而引入的,在大多数机器学习任务中,图像去均值化对提高算法的性能还是很有帮助的。

最后就是根据txtfile中的每一行进行相关数据的读取、转换以及tfrecord的生成了。首先是根据图片路径读取图片内容,然后图像减去之前读入的均值,接着根据segmentation mask的路径读取mask(如果只是图像分类任务,那么就不会有这些额外的mask),txtfile中的label读出来是string格式,这里要转换成int。然后图像和mask数据也要用相应的tosring函数转换成string。

真正的核心是下面这一小段代码:

example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'name': _bytes_feature(item[0]),
        'image_raw': _bytes_feature(img_raw),
        'mask_raw': _bytes_feature(mask_raw),
        'label': _int64_feature(label)}))

writer.write(example.SerializeToString())

这里很好地体现了tfrecord的字典特性,tfrecord中每一个样本都是一个小字典,这个字典可以包含任意多个键值对。比如我这里就存储了图片的高度、宽度、图片名称、图片内容、mask内容以及图片的label。对于我的任务来说,其实height、width、name都不是必需的,这里仅仅是为了展示。键值对的键全都是字符串,键起什么名字都可以,只要能方便以后使用就可以。

定义好一个example后就可以用之前的writer来把它真正写入tfrecord文件了,这其实就跟把一行内容写入一个txt文件一样。代码的最后就是writer和txt文件对象的关闭了。

最后在指定文件夹下,就得到了指定名字的tfrecord文件,如下所示:

tfrecord文件

需要注意的是,生成的tfrecord文件比原生数据的大小还要大,这是正常现象。这种现象可能是因为图片一般都存储为jpg等压缩格式,而tfrecord文件存储的是解压后的数据。

3. 从tfrecord文件读取数据

还是代码先行。

from scipy import misc
import tensorflow as tf
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt

root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'

def read_and_decode(filename_queue, random_crop=False, random_clip=False, shuffle_batch=True):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
      serialized_example,
      features={
          'height': tf.FixedLenFeature([], tf.int64),
          'width': tf.FixedLenFeature([], tf.int64),
          'name': tf.FixedLenFeature([], tf.string),                           
          'image_raw': tf.FixedLenFeature([], tf.string),
          'mask_raw': tf.FixedLenFeature([], tf.string),                               
          'label': tf.FixedLenFeature([], tf.int64)
      })

    image = tf.decode_raw(features['image_raw'], tf.float64)
    image = tf.reshape(image, [300,300,3])

    mask = tf.decode_raw(features['mask_raw'], tf.float64)
    mask = tf.reshape(mask, [300,300])

    name = features['name']

    label = features['label']
    width = features['width']
    height = features['height']

#    if random_crop:
#        image = tf.random_crop(image, [227, 227, 3])
#    else:
#        image = tf.image.resize_image_with_crop_or_pad(image, 227, 227)

#    if random_clip:
#        image = tf.image.random_flip_left_right(image)


    if shuffle_batch:
        images, masks, names, labels, widths, heights = tf.train.shuffle_batch([image, mask, name, label, width, height],
                                                batch_size=4,
                                                capacity=8000,
                                                num_threads=4,
                                                min_after_dequeue=2000)
    else:
        images, masks, names, labels, widths, heights = tf.train.batch([image, mask, name, label, width, height],
                                        batch_size=4,
                                        capacity=8000,
                                        num_threads=4)
    return images, masks, names, labels, widths, heights

读取tfrecord文件中的数据主要是应用read_and_decode()这个函数,可以看到其中有个参数是filename_queue,其实我们并不是直接从tfrecord文件进行读取,而是要先利用tfrecord文件创建一个输入队列,如本文开头所述那样。关于这点,到后面真正的测试代码我再介绍。

在read_and_decode()中,一上来我们先定义一个reader对象,然后使用reader得到serialized_example,这是一个序列化的对象,接着使用tf.parse_single_example()函数对此对象进行初步解析。从代码中可以看到,解析时,我们要用到之前定义的那些键。对于图像、mask这种转换成字符串的数据,要进一步使用tf.decode_raw()函数进行解析,这里要特别注意函数里的第二个参数,也就是解析后的类型。之前图片在转成字符串之前是什么类型的数据,那么这里的参数就要填成对应的类型,否则会报错。对于name、label、width、height这样的数据就不用再解析了,我们得到的features对象就是个字典,利用键就可以拿到对应的值,如代码所示。

我注释掉的部分是用来做数据增强的,比如随机的裁剪与翻转,除了这两种,其他形式的数据增强也可以写在这里,读者可以根据自己的需要,决定是否使用各种数据增强方式。

函数最后就是使用解析出来的数据生成batch了。Tensorflow提供了两种方式,一种是shuffle_batch,这种主要是用在训练中,随机选取样本组成batch。另外一种就是按照数据在tfrecord中的先后顺序生成batch。对于生成batch的函数,建议读者去官网查看API文档进行细致了解。这里稍微做一下介绍,batch的大小,即batch_size就需要在生成batch的函数里指定。另外,capacity参数指定数据队列一次性能放多少个样本,此参数设置什么值需要视硬件环境而定。num_threads参数指定可以开启几个线程来向数据队列中填充数据,如果硬件性能不够强,最好设小一点,否则容易崩。

4. 实例测试

实际使用时先指定好我们需要使用的tfrecord文件:

root_path = '/mount/temp/WZG/Multitask/Data/'
tfrecord_filename = root_path + 'tfrecords/test.tfrecords'

然后用该tfrecord文件创建一个输入队列:

filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                    num_epochs=3)

这里有个参数是num_epochs,指定好之后,Tensorflow自然知道如何读取数据,保证在遍历数据集的一个epoch中样本不会重复,也知道数据读取何时应该停止。

下面我将完整的测试代码贴出:

def test_run(tfrecord_filename):
    filename_queue = tf.train.string_input_producer([tfrecord_filename],
                                                    num_epochs=3)
    images, masks, names, labels, widths, heights = read_and_decode(filename_queue)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    meanfile = sio.loadmat(root_path + 'mats/mean300.mat')
    meanvalue = meanfile['mean']


    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        for i in range(1):
            imgs, msks, nms, labs, wids, heis = sess.run([images, masks, names, labels, widths, heights])
            print 'batch' + str(i) + ': '
            #print type(imgs[0])

            for j in range(4):
                print nms[j] + ': ' + str(labs[j]) + ' ' + str(wids[j]) + ' ' + str(heis[j])
                img = np.uint8(imgs[j] + meanvalue)
                msk = np.uint8(msks[j])
                plt.subplot(4,2,j*2+1)
                plt.imshow(img)
                plt.subplot(4,2,j*2+2)
                plt.imshow(msk, vmin=0, vmax=5)
            plt.show()

        coord.request_stop()
        coord.join(threads)

函数中接下来就是利用之前定义的read_and_decode()来得到一个batch的数据,此后我又读入了均值文件,这是因为之前做了去均值处理,如果要正常显示图片需要再把均值加回来。

再之后就是建立一个Tensorflow session,然后初始化对象。这些是Tensorflow基本操作,不再赘述。下面的这两句代码非常重要,是读取数据必不可少的。

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

然后是运行sess.run()拿到实际数据,之前只是相当于定义好了,并没有得到真实数值。为了简单起见,我在之后的循环里只测试了一个batch的数据,关于tfrecord的标准使用我也建议读者去官网的数据读取部分看看示例。循环里对数据的各种信息进行了展示,结果如下:

结果展示

从图片的名字可以看出,数据的确是进行了shuffle的,标签、宽度、高度、图片本身以及对应的mask图像也全部展示出来了。

测试函数的最后,要使用以下两句代码进行停止,就如同文件需要close()一样:

coord.request_stop()
coord.join(threads)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow中使用tfrecord方式读取数据 的相关文章

随机推荐

  • Android屏幕尺寸适配常见方案smallestWidth

    前言 介于目前的Android设备存在有不同的屏幕尺寸 xff0c 屏幕分辨率 xff0c 像素密度 xff0c Android应用在开发的过程必须要考虑到屏幕尺寸适配的问题 xff0c 以保证在不同尺寸的Android设备上都能够正常运行
  • 奇数阶魔方阵!

    import java util Scanner public class Test5 打印 魔方阵 所谓的魔方阵是指这样的方阵 xff0c 它的行 列 对角线元素之和均相等 以下是奇数阶魔方阵 xff01 xff01 xff01 xff0
  • 基于Python高光谱遥感影像处理实例

    前言 在写波段配准相关代码时经常需要用到tif影像的波段合成和分解 xff0c 虽然可以用ENVI才处理 xff0c 但是每次都要打开再设置一些参数有些麻烦 xff0c 所以本着 独立自主 自力更生 的原则就写了些脚本来处理这个需求 又写了
  • 基于SIFT的图像Matlab拼接教程

    前言 图像拼接技术 xff0c 将普通图像或视频图像进行无缝拼接 xff0c 得到超宽视角甚至360度的全景图 xff0c 这样就可以用普通数码相机实现场面宏大的景物拍摄 利用计算机进行匹配 xff0c 将多幅具有重叠关系的图像拼合成为一幅
  • PyTorch 进行多步时间序列预测详细教程

    一 前言 Encoder decoder 模型提供了最先进的结果 xff0c 可以对语言翻译等 NLP 任务进行排序 多步时间序列预测也可以视为 seq2seq 任务 xff0c 可以使用编码器 解码器模型 本文提供了一个Encoder d
  • PERSIANN 降雨数据使用教程

    一 前言 PERSIANN xff0c 使用人工神经网络从遥感信息中估算降水 xff0c 是一种基于卫星的降水检索算法 xff0c 可提供近乎实时的降雨信息 该算法使用来自全球地球同步卫星的红外 IR 卫星数据作为降水信息的主要来源 红外图
  • 基于Pyqt5快速构建应用程序详细教程

    一 介绍 图形用户界面 xff0c 更广为人知的名称是 GUI xff0c 是当今大多数个人计算机的一个特征 它为不同计算技能水平的用户提供了直观的体验 尽管 GUI 应用程序可能会使用更多资源 xff0c 但由于其点击式特性 xff0c
  • 基于Python的PROSAIL模型介绍以及使用

    1 介绍 PROSAIL是两种模型耦合得到的 SAIL是冠层尺度的辐射传输模型 xff0c 把冠层假设成是连续的且具有给定几何形状和密度的水平均匀分布的介质层 xff0c 从而模拟入射辐射与均匀介质之间的相互作用 xff0c 具体还是挺复杂
  • 关于VS中LNK1120与errorLNK2019问题

    最近遇到了该问题 xff0c 再查找了一些资料后 xff0c 发现了针对自己问题的解决方法 xff0c 贴出来让大家一起学习一下 其实如果这两个问题同时出现 xff0c 很可能不是链接库缺了lib xff0c 而是编译中添加的源没有被实例化
  • PCL—低层次视觉—点云分割(基于凹凸性)

    转自 xff1a http www cnblogs com ironstark p 5027269 html PCL 低层次视觉 点云分割 xff08 基于凹凸性 xff09 1 图像分割的两条思路 场景分割时机器视觉中的重要任务 xff0
  • 【ENVI入门系列】13.分类后处理

    原文地址 xff1a ENVI入门系列 13 分类后处理 作者 xff1a ENVI IDL中国 版权声明 xff1a 本教程涉及到的数据提供仅练习使用 xff0c 禁止用于商业用途 目录 分类后处理 1 概述 2 分类后处理 2 1 小斑
  • ENVI神经网络工具参数和使用方法

    原文地址 xff1a ENVI神经网络工具参数和使用方法 作者 xff1a pengheligis xff08 1 xff09 Activation xff1a 选择活化函数 对数 xff08 Logistic xff09 和双曲线 xff
  • Android中依赖版本统一管理

    前言 在Android的实际开发中 xff0c 我们会经常使用到多Module开发 xff0c 而当我们修改一些版本信息或者SDK升级时 xff0c 可能涉及多个Module都需要修改 显然逐个修改Module中的build gradle文
  • 详解使用pscp命令Linux文件上传与下载

    一 上传 2 开始 运行 cmd进入到 dos模式输入以下命令 以下是代码片段 xff1a pscp D java apache tomcat 5 5 27 webapps szfdc rardev 64 192 168 68 249 ho
  • 二进制的表白

    没能提起勇气对她进行表白 xff0c 只能寄托于0 1代码记录下对你的喜欢 01000101 01110110 01100101 01101110 00100001 01001001 00100000 01101100 01101111 0
  • java 去除或者替换字符串里面的数字或者字母

    package testPattern import java util regex Matcher import java util regex Pattern public class TestPattern 64 param args
  • python机器学习之scikit安装

    scikit是Python很容易上手的第三方库 下面介绍一下安装过程中遇到的问题 环境是 xff1a win32 43 python27 安装scikit需要安装numpy和scipy 很多教程都会选择使用easy install或者pip
  • 【Windows批处理】交互界面设计

    echo off cls title 终极多功能修复 menu cls color 0A echo span class token keyword echo span span class token operator span span
  • Mac下AndroidStudio报错macMissing essential plugin:org.jetbrains.android Please reinstall Android Studio

    在Mac环境下升级Android studio时报如下错误 xff1a Missing essential plugin org jetbrains android Please reinstall Android Studio from
  • Tensorflow中使用tfrecord方式读取数据

    前言 本博客默认读者对神经网络与Tensorflow有一定了解 xff0c 对其中的一些术语不再做具体解释 并且本博客主要以图片数据为例进行介绍 xff0c 如有错误 xff0c 敬请斧正 使用Tensorflow训练神经网络时 xff0c