使用tf.data.Dataset.from_tensor_slices五步加载数据集

2023-10-31

前言:

最近在学习tf2
数据加载感觉蛮方便的
这里记录下使用 tf.data.Dataset.from_tensor_slices 进行加载数据集.
使用tf2做mnist(kaggle)的代码

思路

Step0: 准备要加载的numpy数据
Step1: 使用 tf.data.Dataset.from_tensor_slices() 函数进行加载
Step2: 使用 shuffle() 打乱数据
Step3: 使用 map() 函数进行预处理
Step4: 使用 batch() 函数设置 batch size
Step5: 根据需要 使用 repeat() 设置是否循环迭代数据集

代码

import tensorflow as tf
from tensorflow import keras

def load_dataset():
	# Step0 准备数据集, 可以是自己动手丰衣足食, 也可以从 tf.keras.datasets 加载需要的数据集(获取到的是numpy数据) 
	# 这里以 mnist 为例
	(x, y), (x_test, y_test) = keras.datasets.mnist.load_data()
	
	# Step1 使用 tf.data.Dataset.from_tensor_slices 进行加载
	db_train = tf.data.Dataset.from_tensor_slices((x, y))
	db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
	
	# Step2 打乱数据
	db_train.shuffle(1000)
	db_test.shuffle(1000)
	
	# Step3 预处理 (预处理函数在下面)
	db_train.map(preprocess)
	db_test.map(preprocess)

	# Step4 设置 batch size 一次喂入64个数据
	db_train.batch(64)
	db_test.batch(64)

	# Step5 设置迭代次数(迭代2次) test数据集不需要emmm
	db_train.repeat(2)

	return db_train, db_test

def preprocess(labels, images):
	'''
	最简单的预处理函数:
		转numpy为Tensor、分类问题需要处理label为one_hot编码、处理训练数据
	'''
	# 把numpy数据转为Tensor
	labels = tf.cast(labels, dtype=tf.int32)
	# labels 转为one_hot编码
	labels = tf.one_hot(labels, depth=10)
	# 顺手归一化
	images = tf.cast(images, dtype=tf.float32) / 255
	return labels, images
  1. one_hot 编码: 小姐姐给你解释去 (我在使用自带的fit函数进行训练的时候,发现报错维度不正确,原来是不需要one_hot编码)

  2. shuffle()函数的数值: 源码链接, 内容我贴图了
    函数定义源码
    我找到一个比较好的解释: 简书真是好东西

  3. 我发现 自己的数据使用tf.data.Dataset.from_tensor_slices(x, y)加载时, 一定要x在前y在后。。。没仔细看函数说明,否则会导致bug的emmm

  4. 使用了该函数之后, fit的时候是不支持 validation_split 这个参数提供的功能的~

总结

五个步骤很重要 比较简单的方式加载数据 当然还有其他方法加载 之后再说叭
此外, 建议读读api tf.data.Dataset 里好东西太多了~

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用tf.data.Dataset.from_tensor_slices五步加载数据集 的相关文章

随机推荐

  • Dynamics 365 APP -- 清晰定义你的系统职责范围

    今天我们来看看Dynamics 365 的一个新feature APP 对的 没错是APP 各位小伙伴肯定很诧异 难道Dynamics 365又推出了新版本的APP吗 如果各位这么想的话就中了博主的招了 因为博主故意放了一个烟雾弹 今天要讲
  • Windows下把CUDA程序生成dll库并在项目中调用dll中的函数

    如何把自己写的cuda代码生成dll库 方便集成到其他主项目中去进行调用呢 这里总结了一个基本流程 操作环境 Windows10 visual studio2017 cuda10 2 opencv4 2都已经安装并配置好了 主题1 cuda
  • 西门子PLC内部的数据类型大全

    西门子PLC的数据类型种类繁多 本文进行了收集 并指明了适用范围 长度 供需要进行数据采集和分析的朋友们参考 本表格整理自博图V14 不保证更高级版本不会新增数据类型 请使用中注意 类别 数据类型 长度 位 长度 字节 S7 300 400
  • php 递归面试题_8个PHP数组面试题,php数组试题

    8个PHP数组面试题 php数组试题 网上找的PHP数组题 准备自己做一遍并且记录下来 1 写函数创建长度为10的数组 数组中的元素为递增的奇数 首项为1 复制代码 代码如下 function arrsort first length ar
  • Python 十大装 B 语法【Python干货】

    Python 是一种代表简单思想的语言 其语法相对简单 很容易上手 不过 如果就此小视 Python 语法的精妙和深邃 那就大错特错了 本文精心筛选了最能展现 Python 语法之精妙的十个知识点 并附上详细的实例代码 如能在实战中融会贯通
  • Exception 处理之最佳实践

    作者 Gunjan Doshi 2003 11 19 译者注 本文算是一篇学习笔记 仅供学习参考使用 有不妥之处 还请指出 2003 12 04 本文是Exception处理的一篇不错的文章 从Java Exception的概念介绍起 依次
  • L3 Hive操作

    示例 1 建表 create table t dml detail id bigint sale date date province string city string product id bigint cnt double amt
  • Yarn的安装详解?Yarn的各种系统安装详解

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 如何在不同系统环境中安装Yarn Yarn在各种系统的安装详解 Yarn的安装详细的教程 希望能帮助一些程序袁 工具 原料 电脑 Yarn Windows安装详解 1 可以
  • NJUPT南邮

    设计可用于该实验的进程控制块 进程控制块至少包括进程号 状态和要求服务时间 动态或静态创建多个进程 模拟操作系统四种进程调度算法 先进先出 短作业优先 高优先级优先 高相应比优先中的任意两种 调度所创建的进程并显示调度结果 package
  • matlab2021版关于csv文件读写的一些方法

    首先给出一些演示数据 直接给出来大家看起来都方便 完整代码在最后 有基础的可以直接看代码 下面是data csv的文件内容 可以看得出里面有文本也有数值 代码 名称 最新价 涨跌额 涨跌幅 买入 卖出 昨收 今开 最高 最低 成交量 成交额
  • MySQL binlog 日志解析

    很多时候 当我们的业务数据产生了不正常的变化 但却无法得知这类操作是在哪里进行 并且如何进行 单单从程序当面排查很费力 那么就需要通过分析数据库日志来得到历史执行 SQL 根据 SQL 执行逻辑来确认代码位置 进而确认是否是 BUG 亦或是
  • configure –prefix 的用法

    源码的安装一般由有这三个步骤 配置 configure 编译 make 安装 make install 其中 prefix选项就是配置安装的路径 如果不配置该选项 安装后可执行文件默认放在 usr local bin 库文件默认放在 usr
  • spring boot 根据模板生成新的文件(exec),复杂一点的,新的导出

  • Timesat提取物候信息并绘图

    Timesat提取物候信息并绘图 前言 一 准备数据 1 将 Tiff 数据转成 dat 数据并生成 Filelist 2 使用python生成单个像素点text时序数据 二 Timesat 打开时序数据 1 处理dat文件 提取物候信息
  • Windows 随意切换node版本

    第一步 先清空本地安装的node js版本 卸载 删除 第二步 安装nvm管理工具 先关掉360等软件 不然会弹出警告 1 从官网下载安装包 https github com coreybutler nvm windows releases
  • Java实现微信扫码登录并实现认证授权

    Java实现微信扫码登录并实现认证授权 1 登录流程及原理 1 1 OAuth2协议 网站应用微信登录是基于OAuth2 0协议标准构建的微信OAuth2 0授权登录系统 在进行微信OAuth2 0授权登录接入之前 在微信开放平台注册开发者
  • java中的date_Java中Date类型详解

    一 Date类型的初始化 1 Date int year int month int date 直接写入年份是得不到正确的结果的 因为java中Date是从1900年开始算的 所以前面的第一个参数只要填入从1900年后过了多少年就是你想要得
  • Spring:IOC控制反转、@Bean和@Component、日志、注入、注解

    Spring核心知识点 Spring 核心功能 IOC 控制反转 和 AOP 面向切面编程 一 什么是IOC Inversion of Control 控制反转 1 主动控制 2 控制反转 二 使用原生Spring创建Demo项目 1 导入
  • 基于python的opencv入门到精通(一)

    记录自己从0开始成长的研究生生活 文章目录 前言 一 Anaconda是什么 二 已经安装了python如何与Anaconda共存 三 如何将PyCharm与Anaconda进行关联 四 配置Anaconda源 五 如何彻底删除python
  • 使用tf.data.Dataset.from_tensor_slices五步加载数据集

    前言 最近在学习tf2 数据加载感觉蛮方便的 这里记录下使用 tf data Dataset from tensor slices 进行加载数据集 使用tf2做mnist kaggle 的代码 思路 Step0 准备要加载的numpy数据