正确使用 tfds.load() 中的 Cifar-10 数据集

2024-04-04

我正在尝试使用 Cifar-10 数据集来练习我的 CNN 技能。

如果我这样做就可以了:

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

但我试图使用tfds.load()我不明白该怎么做。

有了这个我就下载了,

train_ds, test_ds = tfds.load('cifar10', split=['train','test'])

现在我尝试了这个但不起作用,

assert isinstance(train_ds, tf.data.Dataset)
assert isinstance(test_ds, tf.data.Dataset)
(train_images, train_labels) = tuple(zip(*train_ds))
(test_images, test_labels) = tuple(zip(*test_ds))

有人可以告诉我实现它的方法吗?

谢谢你!


您可以按如下方式执行此操作。

import tensorflow as tf 
import tensorflow_datasets as tfds

train_ds, test_ds = tfds.load('cifar10', split=['train','test'], as_supervised=True)

These train_ds and test_ds are tf.data.Dataset对象,所以你可以使用map, batch,以及与其中每一个类似的功能。

def normalize_resize(image, label):
    image = tf.cast(image, tf.float32)
    image = tf.divide(image, 255)
    image = tf.image.resize(image, (28, 28))
    return image, label

def augment(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_saturation(image, 0.7, 1.3)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    image = tf.image.random_brightness(image, 0.1)
    return image, label 

train = train_ds.map(normalize_resize).cache().map(augment).shuffle(100).batch(64).repeat()
test = test_ds.map(normalize_resize).cache().batch(64)

现在,我们可以通过train and test直接到model.fit.

model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28, 3)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(10, activation="softmax"),
        ]
    )

model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
model.fit(
    train,
    epochs=5,
    steps_per_epoch=60000 // 64,
    validation_data=test, verbose=2
)
Epoch 1/5
17s 17ms/step - loss: 2.0848 - accuracy: 0.2318 - val_loss: 1.8175 - val_accuracy: 0.3411
Epoch 2/5
11s 12ms/step - loss: 1.8827 - accuracy: 0.3144 - val_loss: 1.7800 - val_accuracy: 0.3595
Epoch 3/5
11s 12ms/step - loss: 1.8383 - accuracy: 0.3272 - val_loss: 1.7152 - val_accuracy: 0.3904
Epoch 4/5
11s 11ms/step - loss: 1.8129 - accuracy: 0.3397 - val_loss: 1.6908 - val_accuracy: 0.4060
Epoch 5/5
11s 11ms/step - loss: 1.8022 - accuracy: 0.3461 - val_loss: 1.6801 - val_accuracy: 0.4081
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

正确使用 tfds.load() 中的 Cifar-10 数据集 的相关文章

随机推荐

  • 使用按钮导航到导航窗口中的另一个页面

    我正在尝试使用 WPF 中的导航命令框架在 WPF 应用程序 桌面 不是 XBAP 或 Silverlight 内的页面之间导航 我相信我已经正确配置了所有内容 但它不起作用 我构建并运行时没有错误 在输出窗口中没有收到任何绑定错误 但我的
  • 将 CSS3 动画/变换与滚动事件链接起来

    我正在寻找一种将 CSS3 过渡链接到滚动事件的方法 我正在寻找类似的功能http nizoapp com http nizoapp com 当您到达页面上的某个滚动点时 某些元素会移动 我知道你必须使用 jQuery 调用滚动事件 使用偏
  • 具有多个 S3 源的 AWS CloudFront

    我想配置 AWS CloudFront CDN 以提供来自两个 AWS S3 存储桶的 HTML 静态内容 第一个存储桶应在根目录中托管对象 第二个存储桶应在特定子路径中托管对象 S3配置 第一个桶 myapp home 应将主页和所有其他
  • C# 中的排序列表

    如何根据项目的整数值对列表进行排序 该列表就像 1 5 3 6 11 9 NUM1 NUM0 结果应该是这样的 1 3 5 6 9 11 NUM0 NUM1 有什么想法可以使用 LINQ 或 Lambda 表达式来做到这一点吗 提前致谢 这
  • 导入 _imaging 时 DLL 加载失败:

    我正在尝试运行我的 Python 程序 这些是我要导入的模块 从 tkinter 导入 从 functools 导入部分将 numpy 导入为 np 导入 matplotlib matplotlib use TkAgg 从 matplotl
  • 如何使用 CSS 将自定义位图字体嵌入到网站中

    如何使用 CSS 将自定义位图字体嵌入到我的网站中 我已尝试以下操作 但它只是恢复为后备字体 font face font family AgendaSemibold src url Agenda Semibold bmap format
  • JPA:检查实体对象是否已持久化

    有没有一个通用的方法可以 if entity is persisted before entity entity merge else entity persist 那么包含上述逻辑的方法在任何地方都是安全的吗 如果您需要知道对象是否已经在
  • Google Analytics API 3 - 错误:“invalid_grant”,说明:“”,Uri:“”

    我今天用谷歌搜索了这个问题的生命 分辨率为零 我正在尝试使用服务帐户构建一个非常简单的 Google Analytics 数据请求控制台应用程序 我已在 Google Developers Console 中设置了所有必需的详细信息 但收到
  • TFDMoniFlatFileClientLink 不规则地不跟踪到文件

    我有一个TFDMoniFlatFileClientLink在表单上 文件名设置为d temp monitor txt 追踪 真 TFDConnection Params MonitorBy mbFlatFile 这有时有效 有时则不跟踪任何
  • Python-创建一个以变量为名称的文本文件

    所以我正在做一个项目 我的程序创建一个名为 十个绿色瓶子 的文本文件 并在其中写入 10 个绿色瓶子歌曲 我已经成功地使其工作 但我想让它变得更好 我首先让用户可以选择瓶子的数量 效果很好 现在我只希望名称与用户输入的瓶子数量相关 即 如果
  • 为什么 Linux 可以在多处理中接受套接字?

    该代码在 Linux 上运行良好 但在 Windows 下失败 这是预期的 我知道多处理模块使用fork 产生一个新进程 并且父进程拥有的文件描述符 即打开的套接字 因此由子进程继承 然而 据我了解 您可以通过多处理发送的唯一数据类型需要是
  • B 树和 2-3-4 树之间的区别

    B 树和 2 3 4 树有什么区别 另外 你如何找到每个的最大和最小高度 链接到维基百科 http en wikipedia org wiki 2 3 4 tree and引用 2 3 4 树是 4 阶 B 树 A 2 3 4 is a B
  • 改善 python numpy 代码的运行时间

    我有一个代码可以将垃圾箱重新分配给一个大的numpy大批 基本上 大型数组的元素已以不同的频率进行采样 最终目标是将整个数组重新组合到固定的容器中freq bins 对于我拥有的数组来说 代码有点慢 有什么好的方法可以提高这段代码的运行时间
  • 将视频上传到 Google App Engine Blobstore

    我试图将视频文件与具有一堆属性的记录相关联 但似乎无法允许用户以一种形式执行所有操作 命名视频 提供描述并回答一些问题 然后上传文件 以下是我想要执行的步骤 用户将看到一个包含表单的页面 其中包含以下字段 名称 描述 文件选择器 文件被存储
  • 为什么replace()函数不起作用? [复制]

    这个问题在这里已经有答案了 我正在使用 Selenium 抓取一个网站 当我获取元素列表 标题 的文本时 它会打印以下内容 Countyarrow upward Reportingarrow upward Totalarrow upward
  • SQL 到 LINQ 工具 [关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 如何解决java中的连接重置异常?

    我正在尝试将视频上传到 youtube 我的代码是 private static void uploadVideo YouTubeService service throws IOException System out println F
  • 将位图资源存储在静态变量中

    我有一个显示小位图的视图 它在我的应用程序的许多地方使用 特别是列表视图 目前 每次使用创建该视图的实例时 我都会加载此位图BitmapFactory decodeResource resource id 我意识到 我可以通过将该位图加载到
  • 在 Azure 应用程序设置中配置我的连接字符串,并将其在我的 web.config 中配置为环境变量

    我在Azure中有一个Web应用程序 并且在其应用程序设置中配置了connectionStrgin 但我不知道如何将此配置设置为应用程序web config Net 中的环境变量 有人有一些文档或知道如何实现这一点 到目前为止我已经查看了微
  • 正确使用 tfds.load() 中的 Cifar-10 数据集

    我正在尝试使用 Cifar 10 数据集来练习我的 CNN 技能 如果我这样做就可以了 train images train labels test images test labels datasets cifar10 load data