Resnet18卷积神经网络实现图片分类算法(代码全注释)

2023-11-11

1.类的定义

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    expansion = 1#是否可以调用

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:#残差结构实虚判断
            identity = self.downsample(x)#虚函数

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,#残差结构
                 blocks_num,#残差结构数目
                 num_classes=1000,#训练集分类数目
                 include_top=True,#复杂可选
                 groups=1,
                 width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64#卷积核个数

        self.groups = groups#不用
        self.width_per_group = width_per_group#不用

        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,#输入通道数 卷积核个数
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)#通道个数和卷核个数一样
        self.relu = nn.ReLU(inplace=True)#激活
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)#池化
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:#默认为Ture
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)#平均池化
            self.fc = nn.Linear(512 * block.expansion, num_classes)#全连接层 512*1 分类数

        for m in self.modules():#卷积层初始化!!!
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):#通过上面定义好的残差结构造层 残差结构 卷积核个数 残差结构个数 步长默认为一
        downsample = None#默认实
        if stride != 1 or self.in_channel != channel * block.expansion:#用不到
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = []
        layers.append(block(self.in_channel,#第一层虚线残差结构
                            channel,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        self.in_channel = channel * block.expansion

        for _ in range(1, block_num):#后面全是实线残差结构
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)#*列表去括号

    def forward(self, x):#正向传播
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)#压平 变成一维矩阵
            x = self.fc(x)#全连接

        return x


def resnet18(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, include_top=include_top)

2.训练分类模型

 

import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from sun import resnet18


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")#优先使用Gpu0如果有 没有则cpu
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),#随机裁剪224*224
                                     transforms.RandomHorizontalFlip(),#随机翻转
                                     transforms.ToTensor(),#转化成矩阵吧 !!
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#标准化处理
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    #data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  #os.getcwd()获取当前文件目录 ..返回上一级目录../..返回上两级目录 abspath()返回绝对路径
    #data_root=r"D:\artificial intelligence\Cat and dog recognition\project_data"
    #image_path = os.path.join(data_root, "train1")  # flower data set path
    image_trpath_folder=r"D:\artificial intelligence\Cat and dog recognition\project_data\train"
    #assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=image_trpath_folder,#图片文件夹加载
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx#获取图片类型的字典
    cla_dict = dict((val, key) for key, val in flower_list.items())#将字典反过来
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)#编码成json
    with open('class_indices.json', 'w') as json_file:#写入json文件
        json_file.write(json_str)

    batch_size = 16#批处理数量处理
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,#加载图片
                                               batch_size=batch_size, shuffle=True,#批处理数量
                                               # 16
                                               num_workers=nw)#!!!

    image_vapath_folder=r"D:\artificial intelligence\Cat and dog recognition\project_data\test"#测试文件夹加载
    validate_dataset = datasets.ImageFolder(root=image_vapath_folder,#预处理
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,#图片加载
                                                  batch_size=batch_size, shuffle=False,#不洗牌
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = resnet18()#resnet网络对象
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet18_pre.pth"
    #assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    #net.load_state_dict(torch.load(model_weight_path, map_location=device))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 2)#全连接层
    net.to(device)

    # 损失函数
    loss_function = nn.CrossEntropyLoss()

    #
    #params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(net.parameters(), lr=0.0001)#学习rate优化器

    epochs = 20#训练次数
    best_acc = 0.0#准确rate初始化
    save_path = './resNet18.pth'#权重保存
    train_steps = len(train_loader)
    for epoch in range(epochs):#
        # train
        net.train()#训练
        running_loss = 0.0
        train_bar = tqdm(train_loader)#添加训练进度条 返回迭代器
        """ 
        enumerate()
        names = ["Alice","Bob","Carl"]
        for index,value in enumerate(names):
            print(f'{index}: {value}')
        

        """
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()#清空之前的梯度信息
            logits = net(images.to(device))#这参数
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        net.eval()# 启用dropout方法
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # loss = loss_function(outputs, test_labels)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

3.利用训练好的模型进行图片的识别和分类 

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import json
from sun import resnet18
data_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
preimage_path=r"D:\artificial intelligence\Cat and dog recognition\project_data\test\test1\182.jpg"
image=Image.open(preimage_path)
plt.imshow(image)#显示图片格式
image=data_transform(image)
image=torch.unsqueeze(image,dim=0)#扩充维度batch 第0个维度一张图片3通道 224*224
try:
    json_file=open('./class_indices.json','r')
    class_indict=json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

model=resnet18(num_classes=2)
model_weight_path="./resNet18.pth"
model.load_state_dict(torch.load(model_weight_path))#载入模型
model.eval()#关闭dropout

with torch.no_grad():#默认进行反向传播 这个不进行
    output=torch.squeeze(model(image))#模型载入图片计算并压缩
    pre_list=torch.softmax(output,dim=0)#在列上进行概率计算
    pre=torch.argmax(pre_list).numpy()#取最大
print(class_indict[str(pre)],pre_list[pre].item())
plt.show()

总结:本人第一次写博文,希望大家多多支持,有不懂得地方可以找我交流或者评论区留言。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Resnet18卷积神经网络实现图片分类算法(代码全注释) 的相关文章

  • 如何屏蔽 PyTorch 权重参数中的权重?

    我正在尝试在 PyTorch 中屏蔽 强制为零 特定权重值 我试图掩盖的权重是这样定义的def init class LSTM MASK nn Module def init self options inp dim super LSTM
  • 打印 scrapy 请求的“响应”

    我正在尝试学习 scrapy 在遵循教程的同时 我正在尝试进行细微的调整 我想简单地从请求中获取响应内容 然后我会将响应传递到教程代码中 但我无法发出请求并获取响应内容 建议就好 from scrapy http import Respon
  • 替换字符串列表中的 \x00 的最佳方法?

    我有一个来自已解析 PE 文件的值列表 其中包括 x00每个部分末尾的空字节 我希望能够删除 x00字符串中的字节而不删除所有字节 x 文件中的 s 我试过做 replace and re sub 但并没有取得太大成功 使用Python 2
  • 如何在 Ubuntu 上安装 Python 模块

    我刚刚用Python写了一个函数 然后 我想将其做成模块并安装在我的 Ubuntu 11 04 上 这就是我所做的 创建 setup py 和 function py 文件 使用 Python2 7 setup py sdist 构建分发文
  • 如何自动替换多个文件的文本内容中的字符?

    我有一个文件夹 myfolder包含许多乳胶表 我需要替换其中每个字符 即替换任何minus sign by an en dash 只是为了确定 我们正在替换连字符INSIDE该文件夹中的所有 tex 文件 我不关心 tex 文件名 手动执
  • Python 中 genfromtxt() 的可变列数?

    我有一个 txt具有不同长度的行的文件 每一行都是代表一条轨迹的一系列点 由于每条轨迹都有自己的长度 因此各行的长度都不同 也就是说 列数从一行到另一行不同 据我所知 genfromtxt Python 中的模块要求列数相同 gt gt g
  • Python3 查找 2 个列表中有多少个差异才能相等

    假设我们有 2 个列表 always具有相同的长度和always包含字符串 list1 sot sot ts gg gg gg list2 gg gg gg gg gg sot 我们需要找到 其中有多少项list2应该改变 以便它等于lis
  • python ttk treeview:如何选择并设置焦点在一行上?

    我有一个 ttk Treeview 小部件 其中包含一些数据行 如何设置焦点并选择 突出显示 指定项目 tree focus set 什么也没做 tree selection set 0 抱怨 尽管小部件明显填充了超过零个项目 但未找到项目
  • 如何为多组精灵创建随机位置?

    我尝试使用 blit 和 draw 方法进行 for 循环 并为 PlayerSprite 和 Treegroup 使用不同的变量 for PlayerSprite in Treegroup surface blit PlayerSprit
  • 字典的嵌套列表

    我正在尝试创建dict通过嵌套list groups Group1 A B Group2 C D L y x 0 for y in x if y x 0 for x in groups d k v for d in L for k v in
  • Python int 太大,无法放入 SQLite

    我收到错误 OverflowError Python int 太大 无法转换为 SQLite INTEGER 来自以下代码块 该文件约25GB 因此必须分部分读取 length 6128765 Works on partitions of
  • ValueError:无法插入 ID,已存在

    我有这个数据 ID TIME 1 2 1 4 1 2 2 3 我想按以下方式对数据进行分组ID并计算每组的平均时间和规模 ID MEAN TIME COUNT 1 2 67 3 2 3 00 1 如果我运行此代码 则会收到错误 ValueE
  • python中的sys.stdin.fileno()是什么

    如果这是非常基本的或之前已经问过的 我很抱歉 我用谷歌搜索但找不到简单且令人满意的解释 我想知道什么sys stdin fileno is 我在代码中看到了它 但不明白它的作用 这是实际的代码块 fileno sys stdin filen
  • 使用 lambda 函数更改属性值

    我可以使用 lambda 函数循环遍历类对象列表并更改属性值 对于所有对象或满足特定条件的对象 吗 class Student object def init self name age self name name self age ag
  • 是否可以强制浮点数的指数或有效数匹配另一个浮点数(Python)?

    这是我前几天试图解决的一个有趣的问题 是否可以强制一个的有效数或指数float与另一个人一样float在Python中 出现这个问题是因为我试图重新调整一些数据 以便最小值和最大值与另一个数据集匹配 然而 我重新调整后的数据略有偏差 大约小
  • CSV 在列中查找最大值并附加新数据

    大约两个小时前 我问了一个关于从网站读取和写入数据的问题 从那时起 我花了最后两个小时试图找到一种方法来从输出的 A 列读取最大日期值 将该值与刷新的网站数据进行比较 并将任何新数据附加到 csv 文件而不覆盖旧的或创建重复项 目前 100
  • 如何在单独的文件中使用 FastAPI Depends 作为端点/路由?

    我在单独的文件中定义了一个 Websocket 端点 例如 from starlette endpoints import WebSocketEndpoint from connection service import Connectio
  • 具有指定置信区间的 Seaborn 条形图

    我想在 Seaborn 条形图上绘制置信区间 但我已经计算出置信区间 如何让 Seaborn 绘制我的置信区间而不是尝试自行计算它们 例如 假设我有以下 pandas DataFrame x pd DataFrame Group 1 0 5
  • 从时间序列生成日期特征

    我有一个数据框 其中包含如下列 Date temp data holiday day 01 01 2000 10000 0 1 02 01 2000 0 1 2 03 01 2000 2000 0 3 30 01 2000 200 0 30
  • 使用 numpy 加速 for 循环

    下一个 for 循环如何使用 numpy 获得加速 我想这里可以使用一些奇特的索引技巧 但我不知道是哪一个 这里可以使用 einsum 吗 a 0 for i in range len b a numpy mean C d e f b i

随机推荐