卷积神经网络识别花卉并分类另保存

2023-11-19

本篇博客转载自卷积神经网络训练花卉识别分类器

本篇博客的所有代码已上传至GitHub仓库,后续会更新各个文件夹及文件的详细说明,用者自取

由于卷积神经网络训练花卉识别分类器博客已将模型的训练、测试代码写好,且可以通过这篇博客获取到大神训练好的模型,这里我只简单记录一下copy并运行时遇到的问题。

一、VGG16_test.py遇到的问题及解决方法
1. 不知道电脑是否有CUDA

解决:在Terminal的python中输入下列代码

if torch.cuda.is_available():
    print("Has cuda")
else:
    print("No cuda")
2. 需要将加载模型的设备改为CPU

解决:将原代码中的代码片A对应地更改为代码片B

# 代码片A
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.load_state_dict(torch.load("VGG16_flower_200.pkl"))
net.to(device)
# 代码片B
device = torch.device("cpu")
net.load_state_dict(torch.load("VGG16_flower_200.pkl", map_location=device))
3. 提示各种文件路径不正确

解决:将所有的路径改为本机适用的,并将下载好的后缀名为.pkl的模型copy到./模型训练/VGG16文件夹中

至此测试数据集路径、模型路径的程序就可以运行起来了,且可以在CLI中看到识别花卉的正确性
二、自写分类.py
1. 大致思路
for循环遍历图片数据集
        net网络得到score
        softmax得到各种花的概率
        将最大概率的花的index映射成花名
        将图片内容和花名作为一对集合写进文件夹
2. 分类.py大部分代码仍copy原图片测试.py
# 原代码
from PIL import Image
import torchvision.transforms as transforms
from torchvision import models  #人家的模型
from torch.autograd import Variable
import torch
#from torchvision.datasets import ImageFolder
from torch import nn
#import VGG16_model


#数据预处理
data_transform = transforms.Compose([
    transforms.Resize((224,224), 2),                           #对图像大小统一
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[    #图像归一化
                             0.229, 0.224, 0.225])
         ])

#类别
#这个类别是我在训练的过程输出的训练集的类别,是按照训练的顺序排列的
data_classes = ['Cerasus', 'Dianthus', 'Digitalis_purpurea', 'Eschscholtzia', 
                'Gazania', 'Jasminum', 'Matthiola', 'Narcissus', 'Nymphaea', 
                'Pharbitis', 'Rhododendron', 'Rosa', 'Tithonia', 'Tropaeolum_majus', 
                'daisy', 'dandelion', 'peach_blossom', 'roses', 'sunflowers', 'tulips']

#读取数据
img = Image.open('./图片/向日葵.jpg') 
img=data_transform(img)#这里经过转换后输出的input格式是[C,H,W],网络输入还需要增加一维批量大小B
img = img.unsqueeze(0)#增加一维,输出的img格式为[1,C,H,W]

#类别
#train_dataset = ImageFolder(root='work/data/train/',transform=data_transform)
#data_classes = train_dataset.classes

#选择CPU还是GPU的操作
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

#选择模型

net = models.vgg16()
net.classifier = nn.Sequential(nn.Linear(25088, 4096),      #vgg16
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 20))


#读取参数
net.load_state_dict(torch.load("VGG16_flower_200.pkl",map_location=torch.device('cpu')))
net.eval()
net.to(device)

img = Variable(img)
score = net(img)#将图片输入网络得到输出
probability = nn.functional.softmax(score,dim=1)#计算softmax,即该图片属于各类的概率
max_value,index = torch.max(probability,1)#找到最大概率对应的索引号,该图片即为该索引号对应的类别
print()
print("识别为'{}'的概率为{}".format(data_classes[index.item()],max_value.item()))
3. 选择加载设备部分与读取参数部分的更改

解决:将原代码中的代码片A对应地更改为代码片B

# 代码片A
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.load_state_dict(torch.load("VGG16_flower_200.pkl"))
net.to(device)
# 代码片B
device = torch.device("cpu")
net.load_state_dict(torch.load("VGG16_flower_200.pkl", map_location=device))
4. 批量读取识别图片
1. 设置照片根目录并加载
root_path = "/Users/stone/PycharmProjects/pythonProject/flower_det/flower_recognition/data/myTest"
files = os.listdir(root_path)
2. for循环遍历识别图片,得到最大概率的花名
for file in files:
    if not os.path.isdir(file):
        img = Image.open(root_path+"/"+file)
        img = data_transform(img)  # 图片预处理
        img = img.unsqueeze(0) # 增加一维,输出的img格式为[1,C,H,W]
        img = Variable(img) # 转换格式
        score = net(img)  # 将图片输入网络得到输出
        probability = nn.functional.softmax(score, dim=1)  # 计算softmax,即该图片属于各类的概率
        max_value, index = torch.max(probability, 1)  # 找到最大概率对应的索引号,该图片即为该索引号对应的类别
3. 在for循环中利用CV2保存图片
cv_img = cv2.imread(root_path+"/"+file)
dir_name = "/Users/stone/PycharmProjects/pythonProject/flower_det/flower_recognition/data" \
                   + "/" \
                   + data_classes[index.item()]
if not os.path.exists(dir_name):
	os.mkdir(dir_name)
images = os.listdir(dir_name)
num_images = len(images)
filename = dir_name + "/" + str(num_images+1) + ".jpg"
cv2.imwrite(filename, cv_img)
至此,分类.py算是完成了基础功能,运行后,在dir_name文件夹中会保存识别分类好的花卉
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

卷积神经网络识别花卉并分类另保存 的相关文章

随机推荐

  • CoLab设置使用GPU和TPU

    tf2 4 0 from tensorflow python keras callbacks import EarlyStopping from tensorflow python keras layers import Embedding
  • mysql学习系列(2)--忘记mysql登录密码怎么办?

    系列文章目录 文章目录 系列文章目录 前言 一 登录mysql 二 操作步骤 1 找到mysql exe所在的文件夹 2 Win R打开cmd 进入bin文件夹 3 跳过mysql用户验证 3 net start mysql启动服务 总结
  • c++模板 --- 类模板、自定义类型当做模板参数

    生成一个类模板 类中用到了未知类型叫做类模板 用 template 修饰的类 这个类就是一个模板类 多用在数据结构中 忽略类型的问题 只要被 template 修饰 就是一个模板类 有没有用未知类型都是模板类 把模板当做一种特殊的数据类型即
  • hdoj1036

    讨厌的输入和输出 include
  • C语言枚举

    一 枚举类型 枚举类型 一个整型变量只有几种可能的值 值用枚举常量来表示 每个枚举常量可以用一个标识符来表示 也可以为它们指定一个整数值 如果没有指定 那么默认从 0 开始递增 在C 语言中 枚举类型是被当做 int 或者 unsigned
  • 计算机图形学----光线追踪----路径追踪

    基础知识预备 概率论 概率 期望 概率 值 PDF 概率密度函数 概率密度函数 p x 在数学中 连续型随机变量的概率密度函数 在不至于混淆时可以简称为密度函数 是一个描述这个随机变量的输出值 在某个确定的取值点附近的可能性的函数 也就是结
  • MySQL——数据类型以及对表结构的修改

    MySQL的数据类型 刚才我们在创建表的时候 说到了一个字段类型 所谓的字段类型就是这个字段能存放的数据的数据类型 在MySQL中有以下几种数据类型 数据类型 大小 字节 用途 格式 INT 4 整数 FLOAT 4 单精度浮点类型 DOU
  • Python可视化——绘制折线图

    绘制折线图 plot 1 准备工作 绘制可视化图形 将会使用到Matplotlib库中的pyplot包 Matplotlib是Python的绘图库 其中的pyplot包封装了很多画图的函数 Matplotlib pyplot 包含一系列类似
  • 动力节点老杜java基础视频笔记第一章 学前准备 (1)

    课堂截图 为什么使用截图工具 在听课的过程中 有的时候老师操作的比较快 通过截图的方式将老师的操作保存下来 以便后期的操作 另外截图之后的图片也可以用于笔记的记录 在笔记当中最好采用图文并茂的方式 这样更加利于知识的回顾 使用哪个截图工具
  • unity期末个人作品-落笔洞寻宝探险之旅(寻宝游戏)

    落笔洞寻宝探险之旅 unity寻宝游戏 下载链接在文章下方 为了增添生活的乐趣开发的这款落笔洞寻宝游戏 主要内容为人物在落笔山脉寻找金币 右上角有金币计数器 找到所有金币则获胜 山中有障碍物 触碰会掉血50 人物生命值为100 血量为0则游
  • 实验6Hive分组排序

    实验6Hive分组排序 实验目的及要求 掌握Hive中全局排序Order by 内部排序Sort by的用法及区别 掌握Hive中Group by分组语句的用法 了解Hive中Distribute By分区排序 Group By及Clust
  • 成年人正确学英语的方式

    成年人正确学英语的方式 本人女 毕业两年 2020年3月份开始准备考试商务英语bec的考试 到2020年5月29号考试 期间准备了3个月 最终以165分的成绩通过商务英语bec中级考试 本着以热爱学习乐于分享的精神 给大家开源下我的bec的
  • 【历史上的今天】10 月 2 日:ENIAC 计算机退休;贝尔德发明电视;香港科技大学办学

    整理 王启隆 透过 历史上的今天 从过去看未来 从现在亦可以改变未来 今天是 2021 年 10 月 2 日 在历史上今天发生的科技关键事件不比昨天要少 举世闻名的通用电子计算机 ENIAC 便在今天退休 我国享誉世界的学府香港科技大学正式
  • vector的讲解及模拟实现(c++)

    为了方便大家理解我们边模拟实现vector容器的常用操作 然后根据代码讲解如何使用vector的这些功能 这样的话相信可以帮助大家更好的理解vector 目录 一 vector的介绍 二 vector模拟实现的讲解 1 vector的模块分
  • 计算机两个硬盘如何区分,双硬盘电脑怎么设置主从盘?

    随着电脑中存储的资料逐步增加 我以前电脑的硬盘空间严重不足 所以我购买了一块新硬盘 但是 两块硬盘放在一起工作后 经常发生死机 运行速度慢等问题 我查了很多资料后 最终确定是主从盘设置方面出的问题 本文将为大家介绍我是如何解决问题的 一 区
  • 毕业设计-基于深度学习的新闻推荐算法研究

    目录 前言 课题背景和意义 实现技术思路 基于深度学习的新闻推荐方法 1 DNR中的 两段式 方法 2 DNR中的 融合式 方法 实现效果图样例 最后 前言 大四是整个大学期间最忙碌的时光 一边要忙着备考或实习为毕业后面临的就业升学做准备
  • ubuntu的root用户ssh远程登录问题

    ubuntu默认不允许root远端登录 其它创建的用户默认是可以的 编辑ssh服务的配置文件 cd etc ssh 修改sshd config文件 设置允许root用户远程登录 找到 PermitRootLogin prohibit pas
  • R语言基础——缺失数据

    R语言基础 缺失数据 缺失数据的分类 统计学家通常将缺失数据分为三类 它们都用概率术语进行描述 但思想都非常直观 我们将用sleep研究中对做梦时长的测量 有12个动物有缺失值 来依次阐述三种类型 1 完全随机缺失 若某变量的缺失数据与其他
  • 问题(四)No matching distribution found for anyjson==0.3.3

    前言 本章主要讲述安装anyjson时提示 No matching distribution found for anyjson 0 3 3 的解决方案 一 问题描述 描述 批量下载第三方包时 提示 找不到anyjson0 3 3的匹配分布
  • 卷积神经网络识别花卉并分类另保存

    本篇博客转载自卷积神经网络训练花卉识别分类器 本篇博客的所有代码已上传至GitHub仓库 后续会更新各个文件夹及文件的详细说明 用者自取 由于卷积神经网络训练花卉识别分类器博客已将模型的训练 测试代码写好 且可以通过这篇博客获取到大神训练好