PyTorch预训练和微调:以VGG16为例

2023-11-04

预训练和微调代码

数据集:CIFAR10
CIFAR-10数据集由10类32x32的彩色图片组成,一共包含60000张图片,每一类包含6000图片。其中50000张图片作为训练集,10000张图片作为测试集。数据集介绍来自:CIFAR10

在这里插入图片描述
图片来源:https://paperswithcode.com/dataset/cifar-10

预训练模型: vgg16

代码

# Imports
import torch
import torchvision
import torch.nn as nn  # All neural network modules, nn.Linear, nn.Conv2d, BatchNorm, Loss functions
import torch.optim as optim  # For all Optimization algorithms, SGD, Adam, etc.
import torch.nn.functional as F  # All functions that don't have any parameters
from torch.utils.data import (
    DataLoader,
)  # Gives easier dataset managment and creates mini batches
import torchvision.datasets as datasets  # Has standard datasets we can import in a nice way
import torchvision.transforms as transforms  # Transformations we can perform on our dataset
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
num_classes = 10
learning_rate = 1e-3
batch_size = 1024
num_epochs = 2


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

# Load pretrain model & modify it

model = torchvision.models.vgg16(weights='DEFAULT')
# # If you want to do finetuning then set requires_grad = False
# # Remove these two lines if you want to train entire model,
# # and only want to load the pretrain weights.
# for param in model.parameters():
#     param.requires_grad = False
for param in model.parameters():
    param.requires_grad = False

model.avgpool = Identity() # 站位层,使得该层啥事不做
model.classifier = nn.Sequential(nn.Linear(512, 100),
                                 nn.ReLU(),
                                 nn.Linear(100, 10)) # 修改原模型的后几层
model.to(device)


# Load Data
train_dataset = datasets.CIFAR10(
    root="dataset/", 
    train=True, 
    transform=transforms.ToTensor(), 
    download=True)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train Network
for epoch in range(num_epochs):
    losses = []

    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        # Get data to cuda if possible
        data = data.to(device=device)
        targets = targets.to(device=device)

        # forward
        scores = model(data)
        loss = criterion(scores, targets)

        losses.append(loss.item())
        # backward
        optimizer.zero_grad()
        loss.backward()

        # gradient descent or adam step
        optimizer.step()

    print(f"Cost at epoch {epoch} is {sum(losses)/len(losses):.5f}")

# Check accuracy on training & test to see how good our model


def check_accuracy(loader, model):
    if loader.dataset.train:
        print("Checking accuracy on training data")
    else:
        print("Checking accuracy on test data")

    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

        print(
            f"Got {num_correct} / {num_samples} with accuracy {float(num_correct)/float(num_samples)*100:.2f}%"
        )

    model.train()


check_accuracy(train_loader, model)

测试结果

Checking accuracy on training data
Got 29449 / 50000 with accuracy 58.90%

可以看到本次预训练模型的导入,测试结果并不理想。但并不妨碍我们对Pytorch预训练和微调的学习。

参考来源

【1】 https://www.youtube.com/watch?v=qaDe0qQZ5AQ&list=PLhhyoLH6IjfxeoooqP9rhU3HJIAVAJ3Vz&index=8
【2】https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_pretrain_finetune.py

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

PyTorch预训练和微调:以VGG16为例 的相关文章

  • softmax_cross_entropy_with_logits 的 PyTorch 等效项

    我想知道 TensorFlow 是否有等效的 PyTorch 损失函数softmax cross entropy with logits TensorFlow 是否有等效的 PyTorch 损失函数softmax cross entropy
  • python 可以检测它运行在哪个操作系统下吗?

    python 可以检测操作系统 然后为文件系统构建 if else 语句吗 我需要将 Fn 字符串中的 C CobaltRCX 替换为 FileSys 字符串 import os path csv from time import strf
  • 类属性在功能上依赖于其他类属性

    我正在尝试使用静态类属性来定义另一个静态类属性 我认为可以通过以下代码来实现 f lambda s s 1 class A foo foo bar f A foo 然而 这导致NameError name A is not defined
  • 在 Python 中使用 sec 函数的反函数

    我正在创建一个程序 用于计算从一定高度范围和设定初始速度发射射弹的最佳角度 在我需要使用的最终方程中 存在一个反 sec 函数 它导致了一些麻烦 我已经导入了数学并尝试使用 asec 无论如何 但是数学似乎无法计算反秒函数 我也明白 sec
  • 使用正则表达式解析 Snort 警报文件

    我正在尝试使用 Python 中的正则表达式从 snort 警报文件中解析出源 目标 IP 和端口 和时间戳 示例如下 03 09 14 10 43 323717 1 2008015 9 ET MALWARE User Agent Win9
  • 无法包含外部 pandas 文档 Pycharm v--2018.1.2

    我无法包含外部 pandas 文档Pycharm v 2018 1 2 例如 numpy gt http docs scipy org doc numpy reference generated module name element na
  • VSCode pytest 测试发现失败

    Pytest 测试发现失败 用户界面指出 Test discovery error please check the configuration settings for the tests 输出窗口显示 Test Discovery fa
  • 如何在 Windows 上使用 Python 3.6 来安装 Python 2.7

    我想问一下如何使用pip install对于 Python 2 7 当我之前安装并使用 Python 3 6 时 我现在必须使用 Windows 上的 Python 版本 pip install 继续安装 Python 3 6 我需要使用以
  • 使用循环将对象添加到列表(python)

    我正在尝试使用 while 循环将对象添加到列表中 基本上这就是我想做的 class x pass choice raw input pick what you want to do while choice 0 if choice 1 E
  • 使用 python 将文本发送到带有逗号分隔符的列

    如何使用分隔符 在 Excel 中将一列分成两列 并使用 python 命名标题 这是我的代码 import openpyxl w openpyxl load workbook DDdata xlsx active w active a a
  • urllib2.urlopen() 是否实际获取页面?

    当我使用 urllib2 urlopen 时 我在考虑它只是为了读取标题还是实际上带回整个网页 IE 是否真的通过 urlopen 调用或 read 调用获取 HTML 页面 handle urllib2 urlopen url html
  • 在谷歌C​​olab中使用cv2.imshow()

    我正在尝试通过输入视频来对视频进行对象检测 cap cv2 VideoCapture video3 mp4 在处理部分之后 我想使用实时对象检测来显示视频 while True ret image np cap read Expand di
  • python中的sys.stdin.fileno()是什么

    如果这是非常基本的或之前已经问过的 我很抱歉 我用谷歌搜索但找不到简单且令人满意的解释 我想知道什么sys stdin fileno is 我在代码中看到了它 但不明白它的作用 这是实际的代码块 fileno sys stdin filen
  • 在pycharm中调试python代码

    这个问题类似于this https stackoverflow com questions 10240018 how to use pycharm to debug python script一 我正在尝试调试pyethapp https
  • Scrapy 蜘蛛无法工作

    由于到目前为止没有任何效果 我开始了一个新项目 python scrapy ctl py startproject Nu 我完全按照教程操作 创建了文件夹和一个新的蜘蛛 from scrapy contrib spiders import
  • CSV 在列中查找最大值并附加新数据

    大约两个小时前 我问了一个关于从网站读取和写入数据的问题 从那时起 我花了最后两个小时试图找到一种方法来从输出的 A 列读取最大日期值 将该值与刷新的网站数据进行比较 并将任何新数据附加到 csv 文件而不覆盖旧的或创建重复项 目前 100
  • Pandas 在特定列将数据帧拆分为两个数据帧

    I have pandas我组成的 DataFrameconcat 一行由 96 个值组成 我想将 DataFrame 从值 72 中分离出来 这样 一行的前 72 个值存储在 Dataframe1 中 接下来的 24 个值存储在 Data
  • 将 Scikit-Learn OneHotEncoder 与 Pandas DataFrame 结合使用

    我正在尝试使用 Scikit Learn 的 OneHotEncoder 将 Pandas DataFrame 中包含字符串的列替换为 one hot 编码的等效项 我的下面的代码不起作用 from sklearn preprocessin
  • 使用“pythonw”(而不是“python”)运行应用程序时找不到模块

    我尝试了这个最小的例子 from flask import Flask app Flask name app route def hello world return Hello World if name main app run deb
  • 如何识别图形线条

    我有以下格式的路径的 x y 数据 示例仅用于说明 seq p1 p2 0 20 2 3 1 20 2 4 2 20 4 4 3 22 5 5 4 22 5 6 5 23 6 2 6 23 6 3 7 23 6 4 每条路径都有多个点 它们

随机推荐

  • Windows下Sqlmap环境安装教程详解

    更多编程教程请到 菜鸟教程 https www piaodoo com Sqlmap安装 Sqlmap gt gt 基于Python的自动化渗透测试工具 安装工具前需要进行Python的环境准备 Python环境的安装 1 1 下载与安装
  • 软件安全测试包含哪些内容和方法?安全测试报告的必要性

    软件安全测试是一种通过模拟真实攻击的方式 对软件系统进行全面的安全性评估和测试 以发现潜在的安全漏洞和弱点 是确保软件系统安全性的重要措施 在进行软件安全测试时 我们需要了解测试的内容和方法 以及为什么进行安全测试报告的必要性 一 软件安全
  • 2021-05-24

    JDBC 目录 JDBC 一 idea下创建JDBC项目 1 下载所需JDBC驱动 2 连接数据库 3 创建JDBC项目 二 JDBC常用类及常用方法介绍 1 DrivaerManger 驱动管理对象 2 Connection 数据库连接对
  • Chat GPT Access denied——最新解决方法

    前几天没怎么捣鼓ChatGpt 看到网上铺天盖地的被封号的消息 心想我不会也被封了吧 立马上线一探究竟 结果喜提Access denied 难道我要告别Chatgpt了 不 我不甘心 然后就是一顿操作根据网上各路大神帖子提供的方案 都一一失
  • Nginx参数配置详细说明【全局、http块、server块、events块】【已亲测】

    Nginx重点参数配置说明 本文包含Nginx参数配置说明全局块 http块 server块 events块共计30多个参数配置与解释 其中常见参数包含配置错误出现的错误日志 能让你更快的解决问题 该文的所有参数大部分经过单独测试 错误都是
  • Vue修改数据页面不更新的问题解决

    第一种场景 动态给对象新增属性或者删除属性是不会触发视图刷新的 Vue识别不到 第二种场景 通过数组下标修改数组中的元素或者手动修改数组的长度 Vue识别不到 解决方法1 静默刷新 使用v if的特性 在修改值之后将元素销毁 然后在修改后的
  • Java GUI编程(Swing)(窗口 面板 弹窗)

    目录 一 窗口 面板 Swing 重点 重点 重点 如果想给窗口进行背景颜色 必须要给窗口JFrame实例化 否则其他没有颜色 例如 jframe setBackground Color BLUE 背景无颜色 Container conta
  • Minio 部署

    minio 官网 https www minio org cn 部署文档 https www minio org cn docs minio container operations install deploy manage deploy
  • SQL中JOIN和UNION

    join 是两张表做交连后里面条件相同的部分记录 可以是不同字段 产生一个记录集 union是产生的两个记录集 字段要一样的 并在一起 成为一个新的记录集 JOIN用于按照ON条件联接两个表 主要有四种 INNER JOIN 内部联接两个表
  • Vision Transformer(ViT)

    1 前言 本文讲解Transformer模型在计算机视觉领域图片分类问题上的应用 Vision Transformer ViT 本人全部文章请参见 博客文章导航目录 本文归属于 计算机视觉系列 2 Vision Transformer Vi
  • Linux rootfs(根文件系统讲解)

    rootfs 其实就是 针对特定的操作系统的架构 一种实现的形式 具体表现为 特定的目录 就理解为windows的文文件夹 目录之间的关系 即组织架构 以及特定的各种文件 boot 系统启动的相关文件 如内核 initrd 以及grup b
  • proxy_set_header Host $host;

    server listen 80 server name www yuetai net cn 核心代码 rewrite https server name 1 permanent location proxy set header X Re
  • HC-02蓝牙串口模块的配置和使用

    HC 02蓝牙串口模块是基于蓝牙2 0并兼容BLE的双模蓝牙数传模块 带底板的蓝牙模块如下图 模块可以作为从机与HC 05或HC 06的主机设备通信 也可以和手机通信 模块在上电未连接蓝牙的时候LED快闪指示 这时可使用串口助手AT指令配置
  • 远场(far-field)语音识别的主流技术有哪些

    转自 https www zhihu com question 48537863 远场 far field 语音识别的主流技术有哪些 以amazon echo为首的一批智能硬件正在崛起 这些硬件实现语音识别功能时面临的一个挑战性的问题就是如
  • apache 2.4 + php 5.5 配置

    网上流传的大都是 Apache 2 2 和 php 5 4的配置 还那么多人转来转去 害苦了多少入门的新手 以下内容适合apache 2 4 php 5 5 mysql或者mariadb的安装配置不说了 从apachelounge下载Apa
  • PCB如何添加3D模型

    我们在布线完PCB后 可以通过按下键盘的数字 3 来查看自己的板子3D模式下的样子 然后我们可以将一些元器件的3D模型添加上去 看一下板子焊接上元器件是什么样子 所以我们可以手动将3D元器件添加上去也可以在选择元器件的封装时候添加其3D封装
  • 监督学习-贝叶斯分类器

    贝叶斯分类器 1 原理 先验概率 某个事件B发生的概率P B 条件 后验 概率 事件B在另一事件A已发生条件下的发生概率P B A 联合概率 两个事件共同发生的概率P A B P B A 2 多个离散属性的条件概率 样本x是n维的特征向量
  • iOS APP打包上传到APPstore的最新步骤

    一 前言 作为一名iOS开发者 把辛辛苦苦开发出来的App上传到App Store是件必要的事 但是很多人还是不知道该怎么上传到App Store上 下面就来详细讲解一下具体流程步骤 二 准备 一个已付费的开发者账号 可分为 账号类型分为个
  • VS2019查找库函数文件夹的方法

    我的路径 C Program Files x86 Windows Kits 10 Include 10 0 19041 0 ucrt 方法 1 新建一个项目 打上所想找到的库函数 2 右击 gt 转到文档 3 右击选项卡 gt 打开文件夹
  • PyTorch预训练和微调:以VGG16为例

    文章目录 预训练和微调代码 测试结果 参考来源 预训练和微调代码 数据集 CIFAR10 CIFAR 10数据集由10类32x32的彩色图片组成 一共包含60000张图片 每一类包含6000图片 其中50000张图片作为训练集 10000张