数据集包含6种垃圾,分别为cardboard(纸箱),glass(玻璃)、metal(金属)、paper(纸)、plastic(塑料)、其他废品(trash),数据数量较小,仅供学习。
数据集标准备工作,包括将数据集分为训练集和测试集,制作标签文件。代码utils.py
import os
import shutil
import json
path="e://dataset//Garbage_classification"#此路径为上图中六类的目录,可根据自己数据集路径修改
classes=[garbage for garbage in os.listdir(path)]
if os.path.exists(os.path.join(os.getcwd(),'train'))==False:
os.makedirs(os.path.join(os.getcwd(),'train'))
if os.path.exists(os.path.join(os.getcwd(),'val'))==False:
os.makedirs(os.path.join(os.getcwd(),'val'))
f = open("garbage_train.json", 'w')
g = open("garbage_val.json", 'w')
for garbage in classes:
s = 0
for imgname in os.listdir(os.path.join(path,garbage)):
if s%7!=0:
data = {'name': imgname, 'label':classes.index(garbage)}
jsondata = json.dumps(data)
f.write(jsondata)
shutil.copy(os.path.join(path, garbage, imgname),os.path.join(os.getcwd(),'train'))
else:
data = {'name': imgname, 'label': classes.index(garbage)}
jsondata = json.dumps(data)
g.write(jsondata)
shutil.copy(os.path.join(path, garbage, imgname),os.path.join(os.getcwd(),'val'))
s+=1
运行上述代码会生成下图的文件夹。
接下来,我们写一个数据集预处理的类,data.py. root是上图处理得到的数据集的根目录,datajson是两个json文件夹
from PIL import Image
import torch
import os
import json
class MyDataset(torch.utils.data.Dataset): # 创建自己的类:MyDataset,这个类是继承的torch.utils.data.Dataset
def __init__(self, root, datajson, transform=None, target_transform=None): # 初始化一些需要传入的参数
super(MyDataset, self).__init__()
fh = open(datajson, 'r') # 按照传入的路径和txt文本参数,打开这个文本,并读取内容
load_dict = json.load(fh)
imgs = [] # 创建一个名为img的空列表,一会儿用来装东西
for line in load_dict: # 按行循环txt文本中的内容
#line = line.rstrip()# 删除 本行string 字符串末尾的指定字符,这个方法的详细介绍自己查询python
#words = line.split() # 通过指定分隔符对字符串进行切片,默认为所有的空字符,包括空格、换行、制表符等
imgs.append((line['name'], int(line['label']))) # 把txt里的内容读入imgs列表保存,具体是words几要看txt内容而定
self.root=root
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index] # fn是图片path #fn和label分别获得imgs[index]也即是刚才每行中word[0]和word[1]的信息
img = Image.open(os.path.join(self.root,fn)).convert('RGB') # 按照path读入图片from PIL import Image # 按照路径读取图片
if self.transform is not None:
img = self.transform(img) # 是否进行transform
return img, label # return很关键,return回哪些内容,那么我们在训练时循环读取每个batch时,就能获得哪些内容
def __len__(self): # 这个函数也必须要写,它返回的是数据集的长度,也就是多少张图片,要和loader的长度作区分
return len(self.imgs)
再定义一下resnet网络。resnet.py ,这里需要说明一下,由于数据集不够大,很多图片没有超过224,我拟定输入为112,这里有多种resnet系列选择,我用的是最简单的resnet18.
import torch
import torch.nn as nn
class BasicBlock(nn.Module):
"""Basic Block for resnet 18 and resnet 34
"""
#BasicBlock and BottleNeck block
#have different output size
#we use class attribute expansion
#to distin