【pytorch】迁移学习

2023-11-11

在很多场合中,没有必要从头开始训练整个卷积网络(随机初始化参数),因为没有足够丰富的数据集,而且训练也是非常耗时、耗资源的过程。通常,采用pretrain a ConvNet的方式,然后用ConvNet作为初始化或特征提取器。有两种迁移学习,对应着不同的应用场景。
  • 微调ConvNet:使用已有的model参数代替随机初始化参数进行训练。
  • ConvNet做为特征提取器:我们需要冻结所有的网络权重的更新,最后一层(全连接层)除外。通常,最后一个全连接层是需要根据需求进行修改,并使用一个新的随机权重进行训练。显然,整个网络只有这个层被训练。

pytorch提供了很多pre-trained models,如下:


下面以cifar10为例,cifar10有10类图像 ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')。我们将采用采用第二种方式,修改resnet-18的全连层,以达到cifar10识别目的。

加载数据

print('==> Preparing data..')
transform_train = transforms.Compose([
    #transforms.RandomCrop(224, padding=4),
    transforms.Scale(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Scale(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data/cifar', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../data/cifar', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=args.test_batch_size, shuffle=False, num_workers=2)


加载并修改模型

# ConvNet
model_ft = models.resnet18(pretrained=True)
print(model_ft)

for i, param in enumerate(model_ft.parameters()):
    param.requires_grad = False # 冻结参数的更新

num_ftrs = model_ft.fc.in_features #重新定义fc层,此时,会进行参数的更新。
model_ft.fc = nn.Linear(num_ftrs, 10)
print(model_ft)


训练

def train(epoch):
    model_ft.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)

        optimizer.zero_grad()
        output = model_ft(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(trainloader.dataset),
                100. * batch_idx / len(trainloader), loss.data[0]))

完整代码可以查看: tfygg/pytorch-tutorials
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【pytorch】迁移学习 的相关文章

  • PyTorch - 参数不变

    为了了解 pytorch 的工作原理 我尝试对多元正态分布中的一些参数进行最大似然估计 然而 它似乎不适用于任何协方差相关的参数 所以我的问题是 为什么这段代码不起作用 import torch def make covariance ma
  • 我可以使用逻辑索引或索引列表对张量进行切片吗?

    我正在尝试使用列上的逻辑索引对 PyTorch 张量进行切片 我想要与索引向量中的 1 值相对应的列 切片和逻辑索引都是可能的 但是它们可以一起吗 如果是这样 怎么办 我的尝试不断抛出无用的错误 类型错误 使用 ByteTensor 类型的
  • 如何避免 PyTorch 中的“CUDA 内存不足”

    我认为对于 GPU 内存较低的 PyTorch 用户来说 这是一个非常常见的消息 RuntimeError CUDA out of memory Tried to allocate X MiB GPU X X GiB total capac
  • pytorch grad 在 .backward() 之后为 None

    我刚刚安装火炬 1 0 0 on Python 3 7 2 macOS 并尝试tutorial https pytorch org tutorials beginner blitz autograd tutorial html sphx g
  • PyTorch 教程错误训练分类器

    我刚刚开始 PyTorch 教程使用 PyTorch 进行深度学习 60 分钟闪电战我应该补充一点 我之前没有编写过任何 python 但其他语言 如 Java 现在 我的代码看起来像 import torch import torchvi
  • 在 PyTorch 中原生测量多类分类的 F1 分数

    我正在尝试在 PyTorch 中本地实现宏 F1 分数 F measure 而不是使用已经广泛使用的sklearn metrics f1 score https scikit learn org stable modules generat
  • 一次热编码期间出现 RunTimeError

    我有一个数据集 其中类值以 1 步从 2 到 2 i e 2 1 0 1 2 其中 9 标识未标记的数据 使用一种热编码 self one hot encode labels 我收到以下错误 RuntimeError index 1 is
  • torch.mm、torch.matmul 和 torch.mul 有什么区别?

    阅读完 pytorch 文档后 我仍然需要帮助来理解之间的区别torch mm torch matmul and torch mul 由于我不完全理解它们 所以我无法简明地解释这一点 B torch tensor 1 1207 0 3137
  • 尝试理解 Pytorch 的 LSTM 实现

    我有一个包含 1000 个示例的数据集 其中每个示例都有5特征 a b c d e 我想喂7LSTM 的示例 以便它预测第 8 天的特征 a 阅读 nn LSTM 的 Pytorchs 文档 我得出以下结论 input size 5 hid
  • 在非单一维度 1 处,张量 a (2) 的大小必须与张量 b (39) 的大小匹配

    这是我第一次从事文本分类工作 我正在使用 CamemBert 进行二进制文本分类 使用 fast bert 库 该库主要受到 fastai 的启发 当我运行下面的代码时 from fast bert data cls import Bert
  • 下载变压器模型以供离线使用

    我有一个训练有素的 Transformer NER 模型 我想在未连接到互联网的机器上使用它 加载此类模型时 当前会将缓存文件下载到 cache 文件夹 要离线加载并运行模型 需要将 cache 文件夹中的文件复制到离线机器上 然而 这些文
  • pytorch 中的 keras.layers.Masking 相当于什么?

    我有时间序列序列 我需要通过将零填充到矩阵中并在 keras 中使用 keras layers Masking 来将序列的长度固定为一个数字 我可以忽略这些填充的零以进行进一步的计算 我想知道它怎么可能在 Pytorch 中完成 要么我需要
  • 如何计算 CNN 第一个线性层的维度

    目前 我正在使用 CNN 其中附加了一个完全连接的层 并且我正在使用尺寸为 32x32 的 3 通道图像 我想知道是否有一个一致的公式可以用来计算第一个线性层的输入尺寸和最后一个卷积 最大池层的输入 我希望能够计算第一个线性层的尺寸 仅给出
  • 将 Keras (Tensorflow) 卷积神经网络转换为 PyTorch 卷积网络?

    Keras 和 PyTorch 使用不同的参数进行填充 Keras 需要输入字符串 而 PyTorch 使用数字 有什么区别 如何将一个转换为另一个 哪些代码在任一框架中获得相同的结果 PyTorch 还采用参数 in channels o
  • Pytorch CUDA 错误:没有内核映像可用于在带有 cuda 11.1 的 RTX 3090 设备上执行

    如果我运行以下命令 import torch import sys print A sys version print B torch version print C torch cuda is available print D torc
  • PyTorch 中的交叉熵

    交叉熵公式 但为什么下面给出loss 0 7437代替loss 0 since 1 log 1 0 import torch import torch nn as nn from torch autograd import Variable
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • 如何计算cifar10数据的平均值和标准差

    Pytorch 使用以下值作为 cifar10 数据的平均值和标准差 变换 Normalize 0 5 0 5 0 5 0 5 0 5 0 5 我需要理解计算背后的概念 因为这些数据是 3 通道图像 我不明白什么是相加的 什么是除什么的等等
  • 样本()和r样本()有什么区别?

    当我从 PyTorch 中的发行版中采样时 两者sample and rsample似乎给出了类似的结果 import torch seaborn as sns x torch distributions Normal torch tens
  • 使用 PyTorch 分布式 NCCL 连接失败

    我正在尝试使用 torch distributed 将 PyTorch 张量从一台机器发送到另一台机器 dist init process group 函数正常工作 但是 dist broadcast 函数中出现连接失败 这是我在节点 0

随机推荐

  • 1056 组合数的和

    给定 N 个非 0 的个位数字 用其中任意 2 个数字都可以组合成 1 个 2 位的数字 要求所有可能组合出来的 2 位数字的和 例如给定 2 5 8 则可以组合出 25 28 52 58 82 85 它们的和为330 输入格式 输入在一行
  • Springboot3 + SpringSecurity + JWT + OpenApi3 实现认证授权

    Springboot3 SpringSecurity JWT OpenApi3 实现双token 目前全网最新的 Spring Security JWT 实现双 Token 的案例 收藏就对了 欢迎各位看友学习参考 此项目由作者个人创作 可
  • 即使失业,也要把第二个一万小时坚持下去

    这个月打的我有点懵逼 不知所措了 所以 在此写贴 即使失业 也要把第二个一万小时坚持下去 每天8小时学习 反正已经非工资收入九千了 基本上可以活下去了
  • Karma 自动化测试框架搭建文档

    一 前言 此文档为前端自动化单元测试框架 Karma 的搭建以及使用文档 二 准备环境 先列出我们此次搭建测试框架 Karma 必须的环境和包 1 node js node 引擎 2 npm node 包管理器 3 cnpm 可选 淘宝镜像
  • 数列分段(贪心入门)

    问题 对于给定的一个长度为 n 的正整数数列 ai 现要将其分成连续的若干段 并且每段和不超过 m 可以等于 m 问最少能将其分成多少段使得满足要求 算法复杂度为O n 思路 对于已给出数列 从前往后扫描一遍 在扫描过程中 不断记录当前最大
  • win10maven环境变量配置(简洁版):

    准备工作 下载了maven 可以官网下载 也可以通过其他途径获取 没安装之前 在命令行输入mvn v是这样的 解决方案 1 此电脑 属性 高级 环境变量 系统变量 2 新建变量 变量名 MAVEN HOME 值 本地maven的文件夹路径
  • 如何在Geany中添加python的中文注释

    在Geany中编译Python中直接添加中文注释会出现如下错误 只需要在程序的开始位置添加一句 coding utf 8
  • 全网最全Python兼职接单方式,赶快收藏!

    前言 近年来 Python凭借其简洁易入门的特点受到越来越多人群的青睐 当然这不仅仅是针对程序员来说 对于一些学生 职场人士也是如此 Python为什么会大受欢迎 1 Python还被大家称为 胶水语言 它适用于网站 桌面应用开发 自动化脚
  • Unity_DoTween_Path路径动画的使用

    using System Collections using System Collections Generic using System Linq using DG Tweening using UnityEngine public c
  • ResNet学习笔记

    目录 1 背景 2 BN Batch Normalization 层 3 residual结构 残差结构 1 背景 在 ResNet 之前 所有的神经网络都是通过卷积层和池化层的叠加组成的 人们认为卷积层和池化层的层数越多 获取到的图片特征
  • oracle 启动时出现ORA-01157: cannot identify/lock data和ORA-01110: data file 错误

    SQL gt shutdown ORA 01109 database not open Database dismounted ORACLE instance shut down SQL gt startup ORACLE instance
  • 微隔离(MSG)

    微隔离 MSG 参考文章 用 微隔离 实现零信任 什么是微隔离 当下哪家微隔离最靠谱 参考视频 不仅是防火墙 用微隔离实现零信任 定义 微隔离 Micro Segmentation 微隔离是一种网络安全技术 其核心的能力要求是聚焦在东西向流
  • NER标注----使用BILSTM模型训练招投标实体标注模型

    NER标注 BILSTM模型训练招投标实体标注模型 TOC NER标注 BILSTM模型训练招投标实体标注模型 前言 一 NER标注简介 二 从头开始训练一个NER标注器 二 使用步骤 1 引入库 2 数据处理 3 模型训练 前言 上文中讲
  • Python3 迭代器与生成器

    迭代器 迭代是Python最强大的功能之一 是访问集合元素的一种方式 迭代器是一个可以记住遍历的位置的对象 迭代器对象从集合的第一个元素开始访问 直到所有的元素被访问完结束 迭代器只能往前不会后退 迭代器有两个基本的方法 iter 和 ne
  • android 前后台保活 实现定位数据定时上传并展示轨迹 (下)

    上一篇地址 https blog csdn net qq 40803752 article details 86304508 上2篇写完了 保活 这一篇写进入业务逻辑 大概5分钟定一次位置 上传到服务器 并且展示 定位的话 我这里使用的百度
  • Qt5.9.0下载与安装(windows版本)

    1 下载 Qt5 9 0开源版本官网下载 选择图中2 3GB的安装包 即可进行下载 2 安装 双击安装包 弹出qt5 9 0的安装界面 点击下一步 这里的账户如果没有 可以不填 直接点Next 点击下一步 选择安装目录 勾选下面的勾选框 点
  • linux移除ntp,[笔记]Linux NTP命令

    推荐阅读 etc ntp conf 文件是ESX Linux NTP的主要配置文件 启动 停止 重启NTP 用下面的命令 root bigboy tmp service ntpd start root bigboy tmp service
  • 爬虫笔记2--爬取2345网站历史天气

    爬虫笔记2 爬取2345网站历史天气 最近需要获取某些地区的历史气象信息 墨迹天气无法获取历史数据 就在网上看了下 发现2345网站有相对完善的历史气象信息 就爬了下来并保存到MySql数据中 1 功能 本代码主要功能为 爬取2345天气历
  • vue中store模块化

    在进行书写store时 我们会分模块来管理我们的各个部分 我们会创建如下图目录 注意 每个模块中namespaced true是不可或缺的 export default namespaced true state mutations act
  • 【pytorch】迁移学习

    在很多场合中 没有必要从头开始训练整个卷积网络 随机初始化参数 因为没有足够丰富的数据集 而且训练也是非常耗时 耗资源的过程 通常 采用pretrain a ConvNet的方式 然后用ConvNet作为初始化或特征提取器 有两种迁移学习