通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类

2023-11-12

实验目的

通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类。

实验环境

import tensorflow as tf

print(tf.__version__)

output:
2.3.0

实验步骤

(一) 数据获取和预处理

1.1 数据选择 TensorFlow 官方提供的花朵图片数据,经如下代码获取:

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
img_dir= tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)

1.2 读取图片:这里,我们通过 tf.keras.preprocessing.image_dataset_from_directory 函数批量读入图片。

import pathlib
# 数据保存路径
data_dir = pathlib.Path(data_dir)

BATCH_SIZE = 32  # BATCH size 设为32
img_height = 180  # 读取图片后,高度转换为180像素
img_width = 180  # 读取图片后,宽度转换为180像素

# 读入images (training data)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(img_dir,
                                                               shuffle=True, 
                                                               validation_split=0.2, 
                                                               seed=123, 
                                                               subset='training', 
                                                               batch_size=BATCH_SIZE,
                                                               image_size=(img_height, img_width)
                                                              )

# 读入images(test data)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(img_dir, 
                                                             shuffle=True,
                                                             validation_split=0.2,
                                                             seed=123,
                                                             subset='validation',
                                                             image_size=(img_height, img_width),
                                                             batch_size=BATCH_SIZE
                                                            )

1.3 查看训练数据的前9张图片.

plt.figure(figsize=(6, 6))
for imgs, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3,3,i+1)
        plt.imshow(imgs[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

在这里插入图片描述

(二) 通过 tf 的基础类,自定义模型

class Mymodel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        # 定义normalization 层
        self.normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)
        
        # 定义数据增强层
        self.aug1 = tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal')
        self.aug2 = tf.keras.layers.experimental.preprocessing.RandomRotation(0.1)
        self.aug3 = tf.keras.layers.experimental.preprocessing.RandomZoom(0.1)
        
        # 定义cov1
        self.cov1 = tf.keras.layers.Conv2D(16, (3,3), padding='same', activation='relu', name='cov1')
        self.pool1 = tf.keras.layers.MaxPool2D(name='pool1')
        
        # 定义cov2
        self.cov2 = tf.keras.layers.Conv2D(32, (3,3), padding='same', activation='relu', name='cov2')
        self.pool2 = tf.keras.layers.MaxPool2D(name='pool2')
        
        # 定义cov3
        self.cov3 = tf.keras.layers.Conv2D(64, (3,3), padding='same', activation='relu', name='cov3')
        self.pool3 = tf.keras.layers.MaxPool2D(name='pool3')
        
        # 定义 Dropout
        self.dropout = tf.keras.layers.Dropout(0.2)
        
        # 定义 flatten
        self.flatten = tf.keras.layers.Flatten()
        
        # 定义 Dense
        self.dense1 = tf.keras.layers.Dense(128, activation='relu', name='dense1')
        self.dense2 = tf.keras.layers.Dense(5)
        
        
    def call(self, img):
        # 执行normalization
        X = self.normalization_layer(img)
        
        # 执行aug
        X = self.aug1(X)
        X = self.aug2(X)
        X = self.aug3(X)
        
        X = self.cov1(X)
        X = self.pool1(X)
        
        X = self.cov2(X)
        X = self.pool2(X)
        
        X = self.cov3(X)
        X = self.pool3(X)
        
        X = self.flatten(X)
        X =  self.dense1(X)
        X = self.dense2(X)
        
        return X

(三) 定义损失函数

def loss(y_true, y_predict):
    return tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true, y_predict)

(四) 定义优化函数

optimizer = tf.keras.optimizers.Adam()

(五) 定义训练函数

def train_step(batch_inp, batch_targ, model):
    with tf.GradientTape() as tape:
        dense_ = model(batch_inp)
    
        batch_loss = loss(batch_targ, dense_)
    gradients = tape.gradient(batch_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return batch_loss

(六) 训练模型

# 实例化模型
model = Mymodel()

epochs = 50  # 训练50个epoch
els = [] # 存储每个epoch的损失函数,用于后续绘图
for epoch in range(epochs):
    epoch_loss = 0
	# 由于我的计算机显存太小,这里每个epoch只取前20个batch进行训练
    for batch, (inp, targ) in enumerate(train_ds.take(20)):
        batch_loss = train_step(inp, targ, model)
        epoch_loss += batch_loss.numpy()
    print('epoch {}: {:.4f}'.format(epoch, epoch_loss/10))
    els.append(epoch_loss/10)

训练过程如下:

epoch 0: 0.5867
epoch 1: 0.6709
epoch 2: 0.6393
epoch 3: 0.6831
epoch 4: 0.6870
epoch 5: 0.6461
epoch 6: 0.4888

Loss 随训练过程的变化情况:

在这里插入图片描述

(七) 通过模型进行预测

预测的代码来之 TensorFlow 官方社区。

sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)

img = keras.preprocessing.image.load_img(
    sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0) # Create a batch

predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])

print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
    .format(class_names[np.argmax(score)], 100 * np.max(score))
)

图片为:

在这里插入图片描述

预测结果:

This image most likely belongs to sunflowers with a 97.69 percent confidence.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类 的相关文章

  • 在vue中使用for循环有异步请求,每次都只获取到最后一个数据解决办法

    我预想是将标签数组 dynamicTags 使用for循环取出每个值 在遍历标签数组将值一一添加到数据库中 奈何for循环结束了 才去执行put请求 这就导致了只添了数组的最后一个值 原因是axios put请求是异步请求 解决方案 使用递
  • 分析整理文献的具体步骤_VOSviewer文献综述

    文献综述在科技论文和学位论文中占据着重要地位 是论文中的一个重要且不可或缺的部分 它是研究者在对某一学科 专业或专题的大量文献进行阅读 整理 筛选 分析 综合和提炼的基础上 用自己的语言综合叙述研究状况的情报研究成果 因此 文献综述的好坏直
  • 计算宽度_桥梁有效宽度计算,看看很有用!

    有效分布宽度实质上是剪力滞效应的反应 由于目前桥梁设计多用二维平面解析 故荷载的有效分布宽度仍需要计算 不过还有很多深层次问题还不能合理解答 有待进一步研究和探讨 各中间跨正弯矩段取该跨计算跨径的0 2倍 边跨正弯矩段取该跨计算跨径的0 2

随机推荐

  • 回溯法-装载问题

    子集树问题 和 子集树的0 1背包问题类似 但是没有考虑价格 include
  • 【Parallels Desktop】解决Sorry, This Application Cannot Be Run Under A Virtual Machine

    问题描述 解决步骤 Win R 或Cmd R 打开 运行 窗口 输入regedit并点击 确定 打开注册表编辑器 依次展开HKEY LOCAL MACHINE HARDWARE ACPI DSDT文件夹 鼠标右键点击PRLS 选择 重命名
  • Redis第二讲 Redis数据持久化AOF和RDB

    RDB快照 snapshot 在默认情况下 Redis 将内存数据库快照保存在名字为 dump rdb 的二进制文件中 你可以对 Redis 进行设置 让它在 N 秒内数据集至少有 M 个改动 这一条件被满足时 自动保存一次数据集 save
  • 【修仙境界】等级划分

    文章目录 一 下境界 1 炼气 一共13层 2 筑基 分初 中 后期和大圆满 3 结丹 分初 中 后期和大圆满 4 元婴 分初 中 后期和大圆满 5 化神 分初 中 后期和大圆满 二 中境界 1 炼虚 分初 中 后期和大圆满 2 合体 分初
  • C++ 编程出错的地方(考试选择题易错点)

    一 int IsSvn int n if n 7 0 return 1 要判断这个数能不能被7整除 你就只返回1吗 那岂不是只返回1 没有0的情况了 应该改为 int IsSvn int n if n 7 0 return 1 else r
  • 2021年电赛模块化程序总结

    文章目录 1 ADC0804 2 LCD1602 3 AD9854 1 ADC0804 集成A D转换器品种繁多 选用时应综合考虑各种因素选取集成芯片 一般逐次比较型A D转换器用的比较多 ADC0804就是这类单片集成A D转换器 ADC
  • 9、HTML:有序列表(ol),无序列表(ul),描述列表(dl、dt、dd)详解

    1 什么是列表 什么是列表 什么是有序列表 什么是无序列表 上面写的 3 句话就是一个列表 你懂得 2 有序列表 有序列表 英文叫做 ordered list 所以标签也是取这个词组的首字母 ol ol标签括起来的范围就是有序列表的范围 而
  • Win11怎么修改c盘用户名?

    Win11怎么修改c盘用户名 不知道的小伙伴们可以学起来了 谨慎操作 以下的方法提供给你 希望对你有所帮助 Win11更改C盘user用户名教程 一 开启Administrator权限并登入 搜索框搜索cmd 右击以管理员身份运行 出现cm
  • C++每日一问:C++ 内存管理——内存泄漏及处理

    2 内存泄漏 2 1 C 中动态内存分配引发问题的解决方案 假设我们要开发一个String类 它可以方便地处理字符串数据 我们可以在类中声明一个数组 考虑到有时候字符串极长 我们可以把数组大小设为200 但一般的情况下又不需要这么多的空间
  • 唯一分解定理(分解质因子)

    唯一分解定理 每个大于一的自然数均可写为质数的积 而且这些素因子按大小排列之后 写法只有一种方式 最简单的写法 include
  • matlab绘制正弦函数、幅度调制初步、Inner matrix dimensions must agree错误

    以sin 2 f t 表达式来绘制正弦图像 必须给定数值序列才能绘制出图像 t必须给定一个数值序列 然后计算出 y sin 函数值序列 以t为横轴 y为纵轴 就绘制出了图像 先给出f 4 在这里是有几个周期 采样率Fs 100 matlab
  • flask从入门到精通,知识讲解+代码演示 day1

    flask从入门到精通 知识讲解 代码演示 day1 文章目录 flask从入门到精通 知识讲解 代码演示 day1 一 flask是什么 二 使用步骤 1 创造flask项目 2 初入flask 3 flask代码初运行 4 flask从
  • Spring Cloud实战(五)-声明式接口模块

    接着上一篇 Spring Cloud实战 四 配置中心 现在开始搭建api模块 一 声明式接口模块api 1 pom xml
  • 数学建模-相关性分析(Matlab)

    注意 代码文件仅供参考 一定不要直接用于自己的数模论文中 国赛对于论文的查重要求非常严格 代码雷同也算作抄袭 如何修改代码避免查重的方法 https www bilibili com video av59423231 清风数学建模 一 基础
  • GPU与GPGPU泛淡

    GPU与GPGPU泛淡 GPU Graphics Processing Unit 也即显卡 是一种专门在个人电脑 工作站 游戏机和一些移动设备 如平板电脑 智能手机等 上作图像运算工作的微处理器 它已经是个人PC和移动设备上不可或缺的芯片
  • C#数据类型之枚举类型

    一 枚举类型的定义 public enum 枚举名称 枚举数据类型 枚举的数据类型可以省略 默认类型为int 枚举项1 枚举项的值 枚举项的值是整数可以自己设置 枚举项2 枚举项3 例如 public enum month ushort 一
  • Clion + mysql (win/Mac + 本地/远程)

    新手教程 那些年我用clion操作mysql的一些经验教训 本文目录 使用clion自带的数据库工具 对数据库进行操作 连接本地数据库 建库 建表 编辑表格 修改字段名 查询数据 插入新的数据 sql常用语句 mysql版 win Clio
  • 口罩检测——数据准备(2)

    文章目录 前言 一 数据介绍 二 数据标注 三 数据转换 总结 前言 上一篇文章中小编讲解了口罩检测的环境要求 在这一篇文章中我们就正式进入项目的讲解 我们从数据准备开始 数据是模型快乐的源泉 没有高质量的数据 再好的模型也白搭 一 数据介
  • Flink消费Rabbit数据,写入HDFS - 使用 BucketingSink

    一 应用场景 Flink 消费 Kafka 数据进行实时处理 并将结果写入 HDFS 二 Bucketing File Sink 由于流数据本身是无界的 所以 流数据将数据写入到分桶 bucket 中 默认使用基于系统时间 yyyy MM
  • 通过 Tensorflow 的基础类,构建卷积神经网络,用于花朵图片的分类

    实验目的 通过 Tensorflow 的基础类 构建卷积神经网络 用于花朵图片的分类 实验环境 import tensorflow as tf print tf version output 2 3 0 实验步骤 一 数据获取和预处理 1