Pytorch 自己搭建的一个神经网络

2023-11-11

目录

数据集
dogs Vs Cats

import time

import torch.nn as nn
import torch.optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
import os
import torch.nn
import torchvision
from glob import glob
import numpy as np
from torch._C._cudnn import is_cuda
from torchvision import transforms
from torchvision.datasets import ImageFolder

# 数据集的起始目录
path = "./dogs-vs-cats/"
# 通过glob函数获取 目录下的所有jpg文件 *是通配符
files = glob(os.path.join(path, '*/*.jpg'))
# 输出文件个数
print(f'Total no of images{len(files)}')
# 记录文件个数
no_of_images = len(files)
# 创建可以用于创建验证数据集的混合索引
shuffle = np.random.permutation(no_of_images)
# os.mkdir(os.path.join(path,'valid'))
# 在这里对训练集的数据进行分类
for t in ['test1']:
    # 获取所有猫的图片
    files1 = glob(os.path.join(path + t, "/cat*.jpg"))
    print(files1)
    # 获取所有狗的图片
    files2 = glob(os.path.join(path + t, "/dog*.jpg"))
    temp = 0
    # 设置目录名
    dirs = ['cat', 'dog']
    # 根据目录 将对应图片分类存入
    for file in [files1, files2]:
        for sfile in file:
            print(sfile)
            folder1 = sfile.split("\\")[-1]
            folder2 = sfile.split("\\")[0]
            print(folder1, folder2)
            # 通过rename修改文件的路径
            os.rename(sfile, os.path.join(folder2, dirs[temp], folder1))
        temp += 1
# 用于混合排列文件
shuffle=np.random.permutation(no_of_images )
# transforms进行变换和加载图片
simple_transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
train=ImageFolder("dogs-vs-cats/train",simple_transform)
# 设置训练集的标签
train.class_to_idx={"cat":0,"dog":1}
train.classes=['cat','dog']
# 加载测试集合
test=ImageFolder("dogs-vs-cats/test1",simple_transform)
# 设置测试集合的标签
test.class_to_idx={"cat":0,"dog":1}
test.classes=['cat','dog']
# 将train和test数据加载到 数据加载器中
train_data_gen=torch.utils.data.DataLoader(train,batch_size=64,num_workers=0)
test_data_gen=torch.utils.data.DataLoader(test,batch_size=64,num_workers=0)

def train_model(model,criterion,optimizer,scheduler,num_epochs=25):
    # 保留开始时间 方便后面计算
    since=time.time()
    best_model_wts=model.state_dict()
    best_acc=0.0
    for epoch in range(num_epochs):
        print('Epoch{}/{}'.format(epoch,num_epochs-1))
        print('-'*10)
        # 每一轮都在训练和检验
        for phase in ['train','test1']:
            # 模型设置为训练模式
            if phase == 'train':
                scheduler.step()
                model.train(True)
            #  模型设置为评估模式
            else:
                model.train(False)
            running_loss=0.0
            running_correct=0
        # 对数据加载器 进行迭代
        for data in train_data_gen:
            inputs,labels=data

            if is_cuda:
                inputs=Variable(inputs.cuda())
                labels=Variable(labels.cuda())
            else:
                inputs,labels=Variable(inputs),Variable(labels)
            # 梯度清零
            optimizer.zero_grad();
            
            # 向前
            outputs=model(inputs)
            _,preds=torch.max(outputs.data,1)
            loss=criterion(outputs,labels)
            
            # 训练的时候 反向优化
            if phase=="train":
                loss.backward()
                optimizer.step()
                scheduler.step()
            
            # 统计
            running_loss+=loss.item()
            running_correct+=torch.sum(preds==labels.data)
            import carAndDog
        epoch_loss=running_loss/carAndDog.no_of_images
        epoch_acc=running_correct/carAndDog.no_of_images
        
        # 复刻模型
        if phase=='test1' and epoch_acc >best_acc:
            best_acc=epoch_acc
            best_model_wts=model.state_dict()
        print()
    time_elapsed=time.time()-since
    print('Training complete in {:.0f}m {:.0f}'.format(time_elapsed//60,time_elapsed%60))
    print("Best val Acc:{:4f}".format(best_acc))
    # 生成最优权重
    model.load_state_dict(best_model_wts)
    return model
# 创建算法实例
model_ft = torchvision.models.resnet18()
num_ftrs = model_ft.fc.in_features
# 配置全连接层
model_ft.fc = torch.nn.Linear(num_ftrs, 2)
# 使用GPU加速
if is_cuda:
    model_ft = model_ft.cuda()

learning_rate=0.001
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer_ft=torch.optim.SGD(model_ft.parameters(),lr=0.001,momentum=0.9)
exp_lr_scheduler=lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)
# 传参开始训练
train_model(model_ft,criterion,optimizer_ft,exp_lr_scheduler)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch 自己搭建的一个神经网络 的相关文章

随机推荐

  • Eclipse导入项目出现No projects are found to import、中文乱码等问题

    首先说一下导入的步骤 1 打开eclipse 左上角选择File gt 选择Import 2 选择Existing Project into Workspace 3 选择第一行的select root directory 然后选择你要导入包
  • #PRBS# PRBS7高速串行总线的常用测试码型

    PRBS的定义 PRBS Pseudo Random Binary Sequence 伪随机二进制序列 PRBS 码具有 随机 特性 是因为在 PRBS 码流中 二进制数 0 和 1 是随机出现的 但是它又和真正意义上的随机码不同 这种 随
  • 基于GRU门控循环网络的时间序列预测matlab仿真,对比LSTM网络

    目录 1 算法运行效果图预览 2 算法运行软件版本 3 部分核心程序 4 算法理论概述 5 算法完整程序工程 1 算法运行效果图预览 LSTM GRU 2 算法运行软件版本 matlab2022a 3 部分核心程序 构建GRU网络模型 la
  • 如何通过Geth、Node.js和UNIX/PHP访问以太坊节点

    本文旨在说明通过Geth Node js如何访问以太坊节点和UNIX下PHP如何访问以太坊节点 说明如何通过RPC使用此 A 以太坊节点 对于以太坊主网络使用RPC url http 85 214 51 53 8545 对于Ropsten测
  • 线上生产问题系列之-@Async使用不当引发的血案

    现象描述 突然客户群里反馈 线上某功能处理出现严重拥堵 再处理不好就要切换渠道 这个功能就是一个通知功能 客户依赖通知结果去完成他的业务逻辑 但是这个通知非常缓慢 严重拥堵 背景描述 常有这样一个需求场景 为了提高请求的吞吐量 在一个请求链
  • 奇偶调序

    题目描述 输入一个整数数组 调整数组中数字的顺序 使得所有奇数位于数组的前半部分 所有偶数位于数组的后半部分 要求时间复杂度为O n 分析与解法 最容易想到的办法是从头扫描这个数组 每碰到一个偶数 拿出这个数字 并把位于这个数字后面的所有数
  • Scala针对容器的操作(遍历、映射、过滤、归约)案例

    Scala针对容器的操作 遍历 映射 一 遍历操作 二 映射操作 2 1 map方法 2 2 flatmap方法 三 过滤操作 四 归约操作 一 遍历操作 Scala容器的标准遍历方法foreach def foreach U f Elem
  • react+ts+echarts5.x按需导入实现世界地图

    registerMap注册世界地图 1 获取世界地图geoJSON格式的文件 获取地图的渠道 这个步骤很重要 本人找了很久都没找到世界地图的GeoJSON文件 这个网址可以提供 并且也提供了各个国家的GeoJSON a 根据 在线实例 确定
  • 前端js采坑,一个函数中同时有多个ajax()异步请求

    在近期的项目中 问题 多个异步请求执行时 有两个请求的路径是相同的 导致结果只执行当中的一个异步请求 add function vm showList false vm title 新增 vm role deptName null dept
  • 微服务架构-Day7

    学习目标 学会微服务架构 对应项目hotel demo 学习笔记 1 数据聚合 聚合 aggregations 可以让我们极其方便的实现对数据的统计 分析 运算 实现这些统计功能的比数据库的sql要方便的多 而且查询速度非常快 可以实现近实
  • display设为inline-block时引发的高度问题,大坑

    今天在写小程序 点击让这个遮罩层显示 结果一直下移 莫名其妙 解决方案 在元素的CSS中添加 vertical align bottom
  • SQL-使用视图

    什么是视图 它们怎样工作 何时使用它们 如何利用视图简化执行的某些SQL操作 1 使用视图的原因 A 重用SQL语句 B 简化复杂的SQL操作 在编写查询后 可以方便地重用它而不必知道其基本查询 C 使用表的一部分而不是整个表 D 保护数据
  • 【Python】科学计算库Scipy简易入门

    0 导语 Scipy是一个用于数学 科学 工程领域的常用软件包 可以处理插值 积分 优化 图像处理 常微分方程数值解的求解 信号处理等问题 它用于有效计算Numpy矩阵 使Numpy和Scipy协同工作 高效解决问题 Scipy是由针对特定
  • vue-组件按需加载

    组件按需加载 路由配置 path name component gt import views vue 按需加载 在vue中配置路由时 可以在头部先引入组件 然后下面定义路由时 在指向到具体使用的组件 这种是页面运行时 组件全部加载 占内存
  • 严重: 子容器启动失败 java.util.concurrent.ExecutionException 信息: 正在摧毁协议处理器 ["http-nio-80"]WARNING: An illegal

    话不多说直接上错误 解决方案 由于一开始以为是tomcat和eclipseEE出现故障 将两个软件重新下载并配置环境但错误没有解决 然后又检查了JDK版本也没问题 最后肯定了是代码的问题 仔细检查后发现是servlet映射地址写重了 后来又
  • HAL库的RCC简介

    一 RCC的时钟树总览 时钟输入源有四个 选择器 预 分频器和倍频器 最终设置的频率 SYSCLK系统时钟 SYSCLK可以有三种方式得到 1 HSI内部高速时钟用的是RC振荡器 频率为8M 精度不高 没有经过分频器和倍频器 这种方式得到的
  • 空utf8文件占三字节的问题(Java空文本文件FileInputStream读取问题)

    1 文件创建情况 2 程序代码 public class Demo01 public static void main String args throws IOException File file new File a txt long
  • pycharm mysql 安装_pycharm安装mysql驱动包

    新的环境配置pycharm的项目时 发现pycharm不能连接到mysql数据库 由于安了java环境但是还没配置相关的库 并且jetbrains家的IDE一般都是java写的 于是猜想可能是java缺少mysql的驱动 1 先确保pyth
  • c++学习:2.变量声明和定义的关系

    为了支持分离式编译 c 语言将声明和定义区分开来 声明只有名字并无实体 定义创建于声明名字相关的实体 因此声明和定义最重要的区别 声明不申请存储空间 定义申请存储空间 变量能且只能被定义一次 但是可以被多次声明 注意这里说的变量定义和变量赋
  • Pytorch 自己搭建的一个神经网络

    目录 数据集 dogs Vs Cats import time import torch nn as nn import torch optim from torch autograd import Variable from torch