mmclassification数据集并训练

2023-11-17

mmclassification数据集并训练

1.数据集准备

import numpy as np
import os
import shutil

生成train.txt和val.txt

train_path = './train'
train_out = './train.txt'
val_path = './valid'
val_out = './val.txt'

data_train_out = './train_filelist'
data_val_out = './val_filelist'


def get_filelist(input_path,output_path):
    with open(output_path,'w') as f:
        for dir_path,dir_names,file_names in os.walk(input_path):   #dir_path 文件夹作为标签
            if dir_path != input_path:
                label = int(dir_path.split('\\')[-1]) -1  
            #print(label)
            for filename in file_names:
                f.write(filename +' '+str(label)+"\n")

def move_imgs(input_path,output_path):
    for dir_path, dir_names, file_names in os.walk(input_path):
        for filename in file_names:
            #print(os.path.join(dir_path,filename))
            source_path = os.path.join(dir_path,filename)

            shutil.copyfile(source_path, os.path.join(output_path,filename))

get_filelist(train_path,train_out)
get_filelist(val_path,val_out)
move_imgs(train_path,data_train_out)
move_imgs(val_path,data_val_out)

dataset

import numpy as np

from .builder import DATASETS
from .base_dataset import BaseDataset

修改configs文件

# Copyright (c) OpenMMLab. All rights reserved.
from .base_dataset import BaseDataset
from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
                      build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
from .multi_label import MultiLabelDataset
from .samplers import DistributedSampler, RepeatAugSampler
from .voc import VOC
from .my_filelist import MyFilelist  #右边是我的py文件的名字,右边是我自己起的类名

# 初始化函数,注册机制
__all__ = [
    'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
    'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
    'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
    'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'MyFilelist'
]
@DATASETS.register_module()
class MyFilelist(BaseDataset):
    CLASSES = [
        '有102个。。。',  #0号
        '我懒得写名字了,有102个。。。',  #1号
           ]

#加载标注
    def load_annotations(self):
        assert isinstance(self.ann_file, str)
#这些东西根据自己的数据自己写的
        data_infos = []
        with open(self.ann_file) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for filename, gt_label in samples:
                info = {'img_prefix': self.data_prefix}
                info['img_info'] = {'filename': filename}
                info['gt_label'] = np.array(gt_label, dtype=np.int64)
                data_infos.append(info)
            return data_infos

修改mmclassification代码

#-*-coding:utf-8-*-
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=102, #
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5)))
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),  #look capbility of GPU
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=8,
    workers_per_gpu=2,
    train=dict(
        type='MyFilelist',
        data_prefix='./train_filelist',   #site of train
        ann_file='./train.txt', 
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='RandomResizedCrop', size=224),
            dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='ToTensor', keys=['gt_label']),
            dict(type='Collect', keys=['img', 'gt_label'])
        ]),
    val=dict(
        type='MyFilelist',
        data_prefix='F:./val_filelist',
        ann_file='./val.txt',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', size=(256, -1)),
            dict(type='CenterCrop', crop_size=224),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ]),
    test=dict(
        type='MyFilelist',
        data_prefix='./val_filelist',
        ann_file='./val.txt',  # ?????
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(type='Resize', size=(256, -1)),
            dict(type='CenterCrop', crop_size=224),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img'])
        ]))
evaluation = dict(interval=1, metric='accuracy')
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='step', step=[30, 60, 90])
runner = dict(type='EpochBasedRunner', max_epochs=100)
checkpoint_config = dict(interval=50)
log_config = dict(interval=100, hooks=[dict(type='TextLoggerHook')])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
work_dir = './work_dirs/resnet18_8xb32_in1k'
gpu_ids = [0]

训练成功在这里插入图片描述

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

mmclassification数据集并训练 的相关文章

随机推荐

  • 【STM32】STM32F103基于CubeIDE移植ThreadX

    前言 本来ThreadX全家桶是无缝接入STM32单片机的 但是今天突然发现ST官方没有X CUBE AZRTOS F1这个软件包 使用CubeMX添加工程组件的时候 也没有ThreadX可以选择 因此就有了此文 Keil环境下的移植可以参
  • ad pcb界面变成灰色无法编辑

    在pcb库中编译后有一个error 点之后就无法编辑pcb了 解决方法 右键 gt 过滤器 filter gt 清除过滤器 clear filter
  • 【CVPR2022论文精读DiffFace】DiffFace: Diffusion-based Face Swapping with Facial Guidance

    CVPR2022论文精读DiffFace DiffFace Diffusion based Face Swapping with Facial Guidance 0 前言 Abstract 1 Introduction 2 Related
  • Python数据分析--读取npz文件

    使用pycharm的朋友们 需要在解释器来安装相应库 有想练手的朋友 用百度网盘 链接 https pan baidu com s 1aOTPRsqkgX4isGDhMjLdlQ 提取码 1234 国民生产总值案例 读取npz文件 第一步
  • caffe SigmoidLayer 学习

    SimgoidLayer 的定义 neuron layer h template
  • 一文教你如何编写测试用例

    一 通用测试用例八要素 1 用例编号 2 测试项目 3 测试标题 4 重要级别 5 预置条件 6 测试输入 7 操作步骤 8 预期输出 二 具体分析通用测试用例八要素 1 用例编号 一般是数字和字符组合成的字符串 可以包括 下划线 单词缩写
  • 【学习笔记】 pytorch的使用语法和代码实例

    数据类型 1 torch FloatTensor 用于生成数据类型为浮点型的Tensor 传递给torch FloatTensor的参数可以是一个列表 也可以是一个维度值 torch randn 用于生成数据类型为浮点型且维度指定的随机Te
  • Java解析cron表达式实战

    目录 前言 实战 依赖 code 执行结果 前言 前面讲了CentOS中安装crontab以及cron表达式的规则说明 在实际开发中我们经常会用到 有时候我们懒得记规则的时候 我们就会用一些工具网站去解析 例如我常用的 https www
  • Vulhub靶场环境搭建

    在Ubantu系统上搭建靶场环境 一 ubantu系统准备 1 更新安装列表 sudo apt get update 2 安装docker io sudo apt install docker io 查看是否安装成功 docker v 3
  • Centos二进制安装Geth以太坊客户端

    环境准备 yum install git yum install golang 获取二进制包 网站 https geth ethereum org downloads wget https gethstore blob core windo
  • spring 5.x 系列第9篇 —— 整合mongodb (xml配置方式)

    一 项目说明 1 1 项目结构 配置文件位于 resources 下 项目以单元测试的方式进行测试 1 2 相关依赖 除了 Spring 的基本依赖外 需要导入 MongoDB 的整合依赖
  • JSON使用示例

    1 什么是json JSON 说白了就是JavaScript用来处理数据的一种格式 这种格式非常简单易用 JSON支持的语言非常多 包括JavaScript C PHP Java等等 这是由于JSON是独立于语言的轻量级的数据交换格式 2
  • webpack性能优化,CDN内容分发分发网络

    CDN英文全称Content Delivery Network 中文翻译即为内容分发网络 当用户输入url后 首先向LDNS 本地DNS 发起域名解析请求 LDNS检查缓存中是否有该url的IP地址记录 如果有 则直接返回给用户 如果没有
  • MongoDB入门

    MongoDB MongoDB相关概念 业务应用场景 传统的关系型数据库 如MySQL 在数据操作的 三高 需求以及应对Web2 0的网站需求面前 显得力不从心 解释 三高 需求 High performance 对数据库高并发读写的需求
  • (十六)ADC转换实验

    本节主要是回顾有关于ADC的对应内容 我们这章通过一个AD芯片xpt2046来读取外部电压的变化 将电压的数字量显示在数码管上 关于ADC 我们都知道单片机内部都是数字量 就是1或者0 而我们的电流电压在传递的时候是模拟量 也就是模拟量很可
  • 由于找不到 libmmd.dll,无法继续执行代码。试试替换libmmd.dll文件可能会解决此问题

    由于找不到 libmmd dll 无法继续执行代码 重新安装程序可能会解决此问题 解决方法 1 右键桌面快捷图标 打开文件所在的位置 在这个文件夹下搜索libmmd dll 2 将搜索出来的libmmd dll复制到MAXSON下的CINE
  • NLTK Downloader出现 [Error 11004]getaddrinfo failed的错误时怎么解决

    当打开NLTK下载器时 弹出 Error 11004 getaddrinfo failed的提示窗口 打开NLTK下载器 import nltk nltk download 出现这样的问题时要怎么解决 很多人都走错了思路导致浪费了不少时间在
  • 音视频开发基础概述 - PCM、YUV、H264、常用软件介绍

    前言 相对而言 音视频开发算是有些门槛的 记得我第一次接触的时候 看别人的博客都看不懂 特别是写代码的时候 非常痛苦 只能抄别人的代码 却不知道为什么要这么写 也不知道应该怎么调整 后来总结了一下 痛苦的原因是在写代码之前没有掌握相关的基础
  • 【2023最新版】Windows11家庭版:安卓子系统(WSA)安装及使用教程【全网最详细】

    目录 一 准备工作 1 检查虚拟化功能 2 找到 Wndows功能 3 启用Hyper V和虚拟机平台 4 家庭版安装Hyper V 若步骤3找不到Hyper V 二 安装安卓子系统 1 进入开发者选项 2 下载Windows Subsys
  • mmclassification数据集并训练

    mmclassification数据集并训练 1 数据集准备 import numpy as np import os import shutil 生成train txt和val txt train path train train out