利用pytorch训练网络---垃圾分类,(resnet18)

2023-11-19

数据集包含6种垃圾,分别为cardboard(纸箱),glass(玻璃)、metal(金属)、paper(纸)、plastic(塑料)、其他废品(trash),数据数量较小,仅供学习。

数据集标准备工作,包括将数据集分为训练集和测试集,制作标签文件。代码utils.py

import os
import shutil
import json
path="e://dataset//Garbage_classification"#此路径为上图中六类的目录,可根据自己数据集路径修改
classes=[garbage for garbage in os.listdir(path)]

if os.path.exists(os.path.join(os.getcwd(),'train'))==False:
    os.makedirs(os.path.join(os.getcwd(),'train'))
if os.path.exists(os.path.join(os.getcwd(),'val'))==False:
    os.makedirs(os.path.join(os.getcwd(),'val'))
f = open("garbage_train.json", 'w')
g = open("garbage_val.json", 'w')
for garbage in classes:
    s = 0
    for imgname in os.listdir(os.path.join(path,garbage)):

        if s%7!=0:
            data = {'name': imgname, 'label':classes.index(garbage)}
            jsondata = json.dumps(data)
            f.write(jsondata)
            shutil.copy(os.path.join(path, garbage, imgname),os.path.join(os.getcwd(),'train'))
        else:
            data = {'name': imgname, 'label': classes.index(garbage)}
            jsondata = json.dumps(data)
            g.write(jsondata)
            shutil.copy(os.path.join(path, garbage, imgname),os.path.join(os.getcwd(),'val'))
        s+=1

运行上述代码会生成下图的文件夹。

接下来,我们写一个数据集预处理的类,data.py.  root是上图处理得到的数据集的根目录,datajson是两个json文件夹

from PIL import Image
import torch
import os
import json
class MyDataset(torch.utils.data.Dataset):  # 创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
    def __init__(self, root, datajson, transform=None, target_transform=None):  # 初始化一些需要传入的参数
        super(MyDataset, self).__init__()
        fh = open(datajson, 'r')  # 按照传入的路径和txt文本参数,打开这个文本,并读取内容
        load_dict = json.load(fh)
        imgs = [] # 创建一个名为img的空列表,一会儿用来装东西
        for line in load_dict: # 按行循环txt文本中的内容

            #line = line.rstrip()# 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
            #words = line.split()  # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
            imgs.append((line['name'], int(line['label'])))  # 把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定

        self.root=root
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform


    def __getitem__(self, index):
          fn, label = self.imgs[index]  # fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
          img = Image.open(os.path.join(self.root,fn)).convert('RGB')  # 按照path读入图片from PIL import Image # 按照路径读取图片

          if self.transform is not None:
              img = self.transform(img)  # 是否进行transform
          return img, label  # return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容

    def __len__(self):  # 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
        return len(self.imgs)

再定义一下resnet网络。resnet.py ,这里需要说明一下,由于数据集不够大,很多图片没有超过224,我拟定输入为112,这里有多种resnet系列选择,我用的是最简单的resnet18.

import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Basic Block for resnet 18 and resnet 34

    """

    #BasicBlock and BottleNeck block 
    #have different output size
    #we use class attribute expansion
    #to distin
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

利用pytorch训练网络---垃圾分类,(resnet18) 的相关文章

  • 图像分类(1),数据预处理

    本文介绍如何使用pytorh利用预训练模型进行图像分类 主要参考Transfer Learning Tutorial和 具体代码可以参考Image classification 下载代码文件 git clone https github c
  • 经典卷积神经网络(CNN)图像分类算法详解

    本文原创 转载请引用 https blog csdn net dan teng article details 87192430 CNN图像分类网络 一点废话 CNN网络主要特点是使用卷积层 这其实是模拟了人的视觉神经 单个神经元只能对某种
  • Shuffle Net系列【V1—V2】

    1 ShuffleNet V1 1 1 Abstract 我们提出了一个极其效率的CNN架构 ShuffleNet 其专为计算能力非常有限的移动设备设计 这个新的架构利用了两个新的操作 pointwise group conv和channe
  • 图像分类_PyTorch图像数据分类

    图像分类数据集中最常用的是手写数字识别数据集MNIST 但大部分模型在MNIST上的分类精度都超过了95 为了更直观地观察算法之间的差异 我们将使用一个图像内容更加复杂的数据集Fashion MNIST 这个数据集也比较小 只有几十M 没有
  • ILSVRC竞赛详细介绍(ImageNet Large Scale Visual Recognition Challenge)

    ILSVRC ImageNet Large Scale Visual Recognition Challenge 是近年来机器视觉领域最受追捧也是最具权威的学术竞赛之一 代表了图像领域的最高水平 ImageNet数据集是ILSVRC竞赛使用
  • 保姆级使用PyTorch训练与评估自己的ResNeXt网络教程

    文章目录 前言 0 环境搭建 快速开始 1 数据集制作 1 1 标签文件制作 1 2 数据集划分 1 3 数据集信息文件制作 2 修改参数文件 3 训练 4 评估 5 其他教程 前言 项目地址 https github com Fafa D
  • Yolov5-7.0图像分类算法修改Resnet18/50主干网络流程

    网上大多数都是基于yolov5算法的目标检测网络进行修改主干网络 我最近在尝试图像分类算法 流程如下 以resnet50为例 1 打开models下的common py文件 添加下面的代码 模型 resnet50 class resnet5
  • ConvNeXt网络详解

    ConvNeXt 论文名称 A ConvNet for the 2020s 论文下载链接 https arxiv org abs 2201 03545 论文对应源码链接 https github com facebookresearch C
  • 图像分类之花卉图像分类(二)数据预处理代码

    经过上一节数据增强 我们来说说数据预处理吧 首先我们要知道图片进入网络训练都是要统一大小格式的 所以我们需要对训练集和验证集的图片进行裁剪 让他们大小统一 注意测试集不用裁剪 我选择裁剪成了64 64的 没改源码的裁剪大小 其实图片大些识别
  • 图像分类,物体检测,语义分割,实例分割的联系和区别

    从10月中旬开始 科研转为 Object Segment 即物体分割 这属于图像理解范畴 图像理解包含众多 如图像分类 物体检测 物体分割 实例分割等若干具体问题 每个问题研究的范畴是什么 或者说每个问题中 对于某幅图像的处理结果是什么 整
  • 人脸图像数据增强

    为什么要做数据增强 在计算机视觉相关任务中 数据增强 Data Augmentation 是一种常用的技术 用于扩展训练数据集的多样性 它包括对原始图像进行一系列随机或有规律的变换 以生成新的训练样本 数据增强的主要目的是增加模型的泛化能力
  • 模型实战(6)之Alex实现图像分类:模型原理+训练+预测(详细教程!)

    Alex实现图像分类 模型原理 训练 预测 图像分类或者检索任务在浏览器中的搜索操作 爬虫搜图中应用较广 本文主要通过Alex模型实现猫狗分类 并且将可以复用的开源模型在文章中给出 数据集可以由此下载 Data 本文将从以下内容做出讲述 1
  • 计算机视觉系列-2-图像分类

    给定一张输入图像 图像分类的任务是判断该图像属于哪类 如果是多任务分类 可以用于分类该图像包含哪个类别 深度学习作为机器学习中非常重要的分支 在图像领域中应用非常广泛 在图像分类任务中 通常采用卷积层 CNN 提取特征 加上全连接层进行分类
  • 通用图片分类项目

    generalImageClassification 文章目录 generalImageClassification 1 数据准备 1 1 开源数据集 1 2 利用特定网站爬数据 2 分类模型的选择 3 代码结构及使用方法 3 1 代码结构
  • mmclassification

    mmclassification 一 MMCLS项目 0 下载链接 Torch安装方法 CPU pip install torch i https download pytorch org whl torch stable html 指定清
  • 深度学习与计算机视觉[CS231N] 学习笔记(4.1):反向传播(Backpropagation)

    在学习深度学习的过程中 我们常用的一种优化参数的方法就是梯度下降法 而一般情况下 我们搭建的神经网络的结构是 输入 权重矩阵 损失函数 如下图所示 而在给定输入的情况下 为了使我们的损失函数值达到最小 我们就需要调节权重矩阵 使之满足条件
  • DenseNet学习与实现

    Densely Connected Convolutional Networks 提出了DenseNet 它用前馈的方式连接每一层与所有其他层 L层网络共有 L L 1
  • CNN中特征融合的一些策略

    Introduction 特征融合的方法很多 如果数学化地表示 大体可以分为以下几种 X Y textbf X textbf Y X Y X
  • 利用pytorch训练网络---垃圾分类,(resnet18)

    数据集包含6种垃圾 分别为cardboard 纸箱 glass 玻璃 metal 金属 paper 纸 plastic 塑料 其他废品 trash 数据数量较小 仅供学习 数据集标准备工作 包括将数据集分为训练集和测试集 制作标签文件 代码
  • 保姆级使用PyTorch训练与评估自己的ConvNeXt网络教程

    文章目录 前言 0 环境搭建 快速开始 1 数据集制作 1 1 标签文件制作 1 2 数据集划分 1 3 数据集信息文件制作 2 修改参数文件 3 训练 4 评估 5 其他教程 前言 项目地址 https github com Fafa D

随机推荐

  • 目录以及目录下的重要文件

    第1章 mv mv move 移动 remove 移除 类似于windows的剪切 语法格式 mv 参数选项 源文件 目标文件 mv 源文件 新的文件名 root oldboyedu lnb 移动 oldboy下的oldboy txt文件
  • css line-height

    项目中看到 line height 1 所以来总结一下 line height 属性 line height 定义 line height 属性设置行间的距离 行高 line height 不允许使用负值 属性可能的值 值 描述 norma
  • RT-Thread操作系统 AT组件源码分析(以 EC20 为例)

    文章目录 1 AT 组件 1 1 AT 组件调试信息级别设置 1 2 AT 命令打印使能设置 1 3 GPRS 网络注册状态检查 1 4 EC200x 是否能连接外网日志输出 1 5 AT 设备注册过程 1 6 AT 设备类注册过程 1 7
  • win10java jdk安装

    配置环境变量 右击我的电脑 属性 点击Path 新建 添加最后两行到自己的Path环境中 转载链接 https www cnblogs com suni p 8279672 html
  • 求二叉树中两个指定节点的最短距离

    给定一个二叉树 找到该树中两个指定节点间的最短距离 思路 求最近公共祖先节点 然后再求最近公共祖先节点到两个指定节点的路径 再求两个节点的路径之和 const shortestDistance function root p q let l
  • Vue练习题(带解析)

    Vue基础入门 一 填空题 Vue是一套构建 用户界面 的渐进式框架 MVVM主要包含3个部分 分别是Model View和 ViewModel Vue中通过 refs 属性获取相应DOM元素 在进行Vue调试时 通常使用 vue devt
  • Qt实现隐藏按钮功能

    在Qt design界面中添加pushbutton 按钮 选中pushbutton 在右下角有按钮属性相关的修改内容 选中flat 按钮外围边框已消失 此时还差一步 需要修改一下 找到stylesheet 输入以下内容 输入 backgro
  • Go并发异步请求秀动抢票

    继上次python请求秀动接口 这次我将采用性能最佳的Go语言重构 tips 因分享了太多人 有人以此向外获利 所以停止分享 之前采用python异步请求 三次请求购票接口的思路 鉴于秀动app的防护措施愈来愈强 我将采用发挥go语言的协程
  • 实现一个在线抽奖系统,就算是个小白看了也能做出来(附源码)

    在线抽奖系统 1 项目介绍 1 功能介绍 2 开发环境与技术栈 3 项目演示 2 项目准备 1 代码框架 源码 2 数据库设计 3 后端对前端接口的实现 1 用户的登录 注册 注销 2 查询奖项设置 修改抽奖人数 3 新增 修改 删除奖项
  • JavaScript Table行填充

    使用JS脚本操作Table元素 在不同浏览器中操作方法不尽相同 当新建一行之后 IE中可以使用单元格操作来完成单元格的添加 而在FireFox中无法正确通过单元格来操作 而只能使用 td td 来实现 因此 在编写填充函数时 要注意检测浏览
  • 基础算法题——帅到没朋友(唯一性)

    帅到没朋友 当芸芸众生忙着在朋友圈中发照片的时候 总有一些人因为太帅而没有朋友 本题就要求你找出那些帅到没有朋友的人 输入格式 输入第一行给出一个正整数N 100 是已知朋友圈的个数 随后N行 每行首先给出一个正整数K 1000 为朋友圈中
  • 用TensorFlow.js实现AI换脸 !所以你知道某些网站视频的明星是怎么来的了吗?

    前言 相信很多小伙伴对TensorFlow js早已有所耳闻 它是一个基于JavaScript的深度学习库 可以在Web浏览器中运行深度学习模型 AI换脸是一种基于深度学习的图像处理技术 将一张人脸照片的表情 头发 嘴唇等特征转移到另一张人
  • python遇到can not import xxx错误

    一种不容易被发现的问题是循环引用导致该问题的发生 具体可参考 ImportError cannot import name xxxxxx 的三种类型的解决方法 Activewaste的博客 CSDN博客 cannot import name
  • Android NDK 编译 三方库记录 及 jni库封装问题

    因工作需求 要将原先的c 库跨平台编译 在Android上运行 其依赖了几个第三方库 也需要一起编译 在此做个记录 所需工具 centos 系统上完成 1 cmake 3 15 6 2 ndk android ndk r21e NDK 下载
  • Python自动抢红包,从此再也不会错过微信红包了!

    作者 上海小胖 来源 Python专栏 ID xpchuiit 目录 0 引言 1 环境 2 需求分析 3 前置准备 4 抢红包流程回顾 5 代码梳理 6 后记 0 引言 提到抢红包 就不得不提Xposed框架 它简直是个抢红包的神器 但使
  • windows禁用输入法

    Rime 呼出菜单的快捷键 ctrl grave 跟 vs code 呼出底部命令行的快捷键冲突了 每次用 vs code 时都会用 ctrl space 将输入法禁用 让它变成一个圈叉 由 1 这个快捷键是 windows 系统禁用输入法
  • vue中的事件修饰符

    vue中的事件修饰符 1 prevent 阻止默认事件 常用 a href http www baidu com a 2 stop 阻止事件冒泡 常用 margin top 20px demo1 height 50px background
  • Es中查询数据存在某个字段或者数据的不存在某个字段(must_not,must的使用)

    一 存在 二 不存在 包含两种意思 1 这条数据根本就没有这个字段 2 这条数据的字段的值为null
  • 区块链大作业前期热身报告

    作业内容 使用已有的开源区块链系统FISCO BCOS 完成私有链的搭建以及新节点的加入 截图说明搭建流程 自行编写一个智能合约并部署到私有链上 同时完成合约调用 截图说明部署流程 使用命令查看一个区块 并对各个字段进行解释 单群组FISC
  • 利用pytorch训练网络---垃圾分类,(resnet18)

    数据集包含6种垃圾 分别为cardboard 纸箱 glass 玻璃 metal 金属 paper 纸 plastic 塑料 其他废品 trash 数据数量较小 仅供学习 数据集标准备工作 包括将数据集分为训练集和测试集 制作标签文件 代码