【动手学】36 图片增广_代码

2023-11-11

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
img = d2l.Image.open('./data/tmp2E5F.png')
d2l.plt.imshow(img)  

<matplotlib.image.AxesImage at 0x7fd01038c3d0>

 

def apply(img,aug,num_rows=2,num_cols=4,scalel=1.5):
    Y = [aug(img) for _ in range(num_rows*num_cols)]
    d2l.show_images(Y,num_rows,num_cols,scale=scalel)

左右翻转图像

apply(img,torchvision.transforms.RandomHorizontalFlip())

上下翻转

apply(img,torchvision.transforms.RandomVerticalFlip())

 

 随机剪裁

shape_aug = torchvision.transforms.RandomResizedCrop(
        (200,200),scale=(0.1,1),ratio=(0.5,2))
apply(img,shape_aug)

随机改变图片亮度

apply(img,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0,saturation=0,hue=0))

 

 随机改变图片色调

apply(img,torchvision.transforms.ColorJitter(brightness=0,contrast=0,saturation=0,hue=0.5))

随机更改图片亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)

color_aug = torchvision.transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5)
apply(img,color_aug)

 

 结合多种图像增广方法

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    color_aug,shape_aug
])
apply(img,augs)

 使用图像增广进行训练

!mkdir ./data/cifar10

all_images = torchvision.datasets.CIFAR10(
            train=True,root='./data/cifar10',download=True)
d2l.show_images([all_images[i][0] for i in range(32)],4,8,scale=0.8)

 只使用最简单的随机左右翻转

train_augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()
])
test_augs = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

定义辅助函数,以便于读取图像和应用图像增广

def load_cifar10(is_train,augs,batch_size):
    dataset = torchvision.datasets.CIFAR10(
        root = './data/cifar10',train=is_train,
        transform = augs,download = False
    )
    
    dataloader = torch.utils.data.DataLoader(
            dataset,batch_size=batch_size,shuffle=is_train,num_workers=4)
    return dataloader

定义一个函数,使用多gpu对模型进行训练和评估

def train_batch_ch13(net,X,y,loss,trainer,devices):
    if isinstance(X,list):
        X = [x.to(devices[0] for x in X)]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred ,y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred,y)
    return train_loss_sum,train_acc_sum

def train_ch13(net,train_iter,test_iter,loss,trainer,num_epochs,devices=d2l.try_all_gpus()):
    timer,num_batches = d2l.Timer(),len(train_iter)
    animator = d2l.Animator(xlabel='epoch',xlim=[1,num_epochs],ylim=[0,1],legend=['train loss','train acc','test acc'])
#     net = nn.DataParallel(net,device_ids=devices).to(devices[0])
    net = net.cuda()
    for epoch in range(num_epochs):
        #4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i ,(features,labels) in enumerate(train_iter):
            timer.start()
            l,acc =train_batch_ch13(
                net,features,labels,loss,trainer,devices)
            metric.add(l,acc,labels.shape[0],labels.numel())
            timer.stop()
            if(i+1)%(num_batches//5) == 0 or i ==num_batches-1:
                animator.add(epoch+(i+1)/num_batches,(metric[0]/metric[2],metric[1]/metric[3],None))
            test_acc = d2l.evaluate_accuracy_gpu(net,test_iter)
            animator.add(epoch+1,(None,None,test_acc))
        print(f'loss{metric[0]/metric[2]:.3f},train acc'
              f'{metric[1]/metric[3]:.3f},test acc{test_acc:.3f}')
        print(f'{metric[1]*num_epochs/timer.sum():.1f}exmples/sec on {str(devices)}'
             )
batch_size,devices,net =256,d2l.try_all_gpus(),d2l.resnet18(10,3)

def init_weights(m):
    if type(m) in [nn.Linear,nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)
net.apply(init_weights)

def train_with_data_aug(train_augs,test_augs,net,lr=0.001):
    train_iter = load_cifar10(True,train_augs,batch_size)
    test_iter = load_cifar10(False,test_augs,batch_size)
    loss = nn.CrossEntropyLoss(reduction='none')
    trainer = torch.optim.Adam(net.parameters(),lr=lr)
    train_ch13(net,train_iter,test_iter,loss,trainer,10,devices)
    
train_with_data_aug(train_augs,test_augs,net)

loss0.169,train acc0.942,test acc0.829
1654.1exmples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

 

train_with_data_aug(test_augs,test_augs,net)          #这里是在训练集上使用测试集的增广方式。而不是代表测试集

loss0.033,train acc0.989,test acc0.857
1722.6exmples/sec on [device(type='cuda', index=0), device(type='cuda', index=1)]

 显卡用的是TITAN xp 和 GeForce RTX 2080

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【动手学】36 图片增广_代码 的相关文章

  • 为什么这些双精度数的返回值为-1.#IND?

    I have double score cvMatchContourTrees CT1 CT2 CV CONTOUR TREES MATCH I1 0 0 cout lt
  • opencv createsamples没有错误,但是没有找到样本

    我在用着this http coding robin de 2013 07 22 train your own opencv haar classifier html教程 我正在根据我的正面图像创建大量样本 我正在使用 Windows 这是
  • opencv中矩阵的超快中值(与matlab一样快)

    我正在 openCV 中编写一些代码 想要找到一个非常大的矩阵数组 单通道灰度 浮点数 的中值 我尝试了几种方法 例如对数组进行排序 使用 std sort 和选择中间条目 但与 matlab 中的中值函数相比 它非常慢 准确地说 在 ma
  • 在 RGB 图像上绘制多类语义分割透明叠加

    我有语义分割掩码的结果 值在 0 1 之间 需要大津阈值来确定什么是积极的 我想直接在 RGB 图像上绘制 在 RGB 图像上每个预测类具有不同的随机颜色 我使用以下内容绘制了具有单一颜色的单个蒙版 是否有一个包或简单的策略可以为多类别做到
  • 如何在给定目标大小的情况下在 python 中调整图像大小,同时保留纵横比?

    首先 我觉得这是一个愚蠢的问题 对此感到抱歉 目前 我发现计算最佳缩放因子 目标像素数的最佳宽度和高度 同时保留纵横比 的最准确方法是迭代并选择最佳缩放因子 但是必须有更好的方法来做到这一点 一个例子 import cv2 numpy as
  • 如何在opencv python中为图像添加边框

    如果我有如下图所示的图像 如何在图像周围添加边框 以便最终图像的整体高度和宽度增加 但原始图像的高度和宽度保持在中间 下面的代码添加了一个大小恒定的边框10像素到原始图像的所有四个边 对于颜色 我假设您想要使用背景的平均灰度值 这是我根据图
  • 我是否必须使用我的数据库训练 Viola-Jones 算法才能获得准确的结果?

    我尝试提取面部数据库的面部特征 但我认识到 Viola Jones 算法在两种情况下效果不佳 当我尝试单独检测眼睛时 当我尝试检测嘴巴时 运作不佳 检测图像的不同部分 例如眼睛或嘴巴 或者有时会检测到其中几个 这是不可能的情况 我使用的图像
  • Python:opencv warpPerspective 既不接受 2 个也不接受 3 个参数

    我发现单应矩阵如下特征匹配 单应性教程 https docs opencv org 3 4 1 d1 de0 tutorial py feature homography html using M mask cv2 findHomograp
  • 多视图几何

    我从相距一定距离的两台相同品牌的相机捕获了两张图像 捕获了相同的场景 我想计算两个相机之间的现实世界旋转和平移 为了实现这一点 我首先提取了两张图像的 SIFT 特征并进行匹配 我现在有基本矩阵也单应性矩阵 然而无法进一步进行 有很多混乱
  • 使用 ffmpeg 或 OpenCV 处理原始图像

    看完之后维基百科页面 http en wikipedia org wiki Raw image format原始图像格式 是任何图像的数字负片 为了查看或打印 相机图像传感器的输出具有 进行处理 即转换为照片渲染 场景 然后以标准光栅图形格
  • OpenCV C++ 如何知道每行的轮廓数进行排序?

    我有一个二值图像 https i stack imgur com NRLVv jpg在这张图片中 我可以使用重载的函数轻松地对从上到下 从左到右找到的轮廓进行排序std sort 我首先通过以下方式从上到下排序 sort contours
  • ffmpeg AVFrame 到 opencv Mat 转换

    我目前正在开发一个使用 ffmpeg 解码接收到的帧的项目 解码后 我想将 AVFrame 转换为 opencv Mat 帧 以便我可以在 imShow 函数上播放它 我拥有的是字节流 我将其读入缓冲区 解码为 AVFrame f fope
  • 如何使用 python、openCV 计算图像中的行数

    我想数纸张 所以我正在考虑使用线条检测 我尝试过一些方法 例如Canny HoughLines and FLD 但我只得到处理过的照片 我不知道如何计算 有一些小线段就是我们想要的线 我用过len lines or len contours
  • OpenCV IP 相机应用程序崩溃 [h264 @ 0xxxxx] 访问单元中缺少图片

    我在 cpp 中有一个 opencv 应用程序 它使用 opencv 的简单结构捕获视频流并将其保存到视频文件中 它与我的网络摄像头完美配合 但是 当我运行它从 IP 摄像机捕获流时 它可能会在大约十秒后崩溃 我的编译命令是 g O3 IP
  • 如何确定与视频中物体的距离?

    我有一个从行驶中的车辆前面录制的视频文件 我将使用 OpenCV 进行对象检测和识别 但我停留在一方面 如何确定距已识别物体的距离 我可以知道我当前的速度和现实世界的 GPS 位置 但仅此而已 我无法对我正在跟踪的对象做出任何假设 我计划用
  • 如何绘制更大的边界框和仅裁剪边界框文本 Python Opencv

    我正在使用 easyocr 来检测图像中的文本 该方法给出输出边界框 输入图像如下所示 Image 1 Image 2 使用下面的代码获得输出图像 But I want to draw a Single Bigger bounding bo
  • OpenCV:如何从网络摄像头获取原始 YUY2 图像?

    你知道如何获得吗raw YUY2来自网络摄像头的图像 使用 OpenCV DirectShow 无 VFW http opencv willowgarage com wiki CameraCapture http opencv willow
  • OpenCV 2.2 和多 CPU - opencv_haartraining.exe 是多线程的吗?

    我在 VS 2010 上构建了 OpenCV 2 2 启用了 TBB 3 支持 我确保所有项目都有正确的 tbb lib 目录 并将 tbb lib 列为依赖项 通过隐藏 tbb dll 进行验证 果然 haartraining exe 抱
  • VideoCapture.read() 返回过去的图像

    我在跑python3 6 with openCV on the Raspberry pi OS is Raspbian 代码的大致结构如下 The image以时间间隔 3 5 分钟 捕获 被捕获image在函数中处理并返回度量 精度的种类
  • BRISK 特征检测器检测零个关键点

    下面显示的 Brisk 探测器没有给我任何关键点 有人可以提出一个问题吗 我将尝试用一些代码解释我在下面所做的事情 include opencv2 features2d features2d hpp using namespace cv u

随机推荐

  • 高合汽车旗下可进化超跑SUV高合HiPhi X亮相海口国际新能源车展

    2021年1月8日 高端新能源智能出行品牌高合汽车旗下高合HiPhi X亮相第三届海口国际新能源汽车展览会 华人运通高合汽车创始人丁磊在现场透露 上市至今高合HiPhi X限量3000辆创始版车型即将预订售罄 累计收获了32000多位留资用
  • 【广州华锐互动】AR远程巡检系统在设备维修保养中的作用

    随着科技的不断发展 AR 增强现实 远程巡检系统在设备检修中发挥着越来越重要的作用 这种系统可以将AR技术与远程通信技术相结合 实现对设备检修过程的实时监控和远程指导 提高设备检修的效率和质量 首先 AR远程巡检系统可以帮助检修人员更好地理
  • NodeJs应用场景【学习路线图】

    Nodejs学习路线图 从零开始nodejs系列文章 将介绍如何利Javascript做为服务端脚本 通过Nodejs框架web开发 Nodejs框架是基于V8的引擎 是目前速度最快的Javascript引擎 chrome浏览器就基于V8
  • 【LeetCode-Java】155. Min Stack

    1 原题 链接 https leetcode com problems min stack Design a stack that supports push pop top and retrieving the minimum eleme
  • 史上最全STL常用容器及其底层存储结构总结

    各大容器的特点 可以用下标访问的容器有 既可以插入也可以赋值 vector deque map 特别要注意一下 vector和deque如果没有预先指定大小 是不能用下标法插入元素的 序列式容器才可以在容器初始化的时候制定大小 关联式容器不
  • [vue3]子组件给父组件传值context.emit

    子组件 用context emit去触发事件 父组件 还是想vue2那样接收
  • 基于Python+Flask实现一个简易网页验证码登录系统案例

    在当今的互联网世界中 为了防止恶意访问 许多网站在登录和注册表单中都采用了验证码技术 验证码可以防止机器人自动提交表单 确保提交行为背后有一个真实的人类用户 本文将向您展示如何使用Python的Flask框架来创建一个简单的验证码登录系统
  • sparksql报错

    执行时报错 org apache spark sql AnalysisException Unable to generate an encoder for inner class cn itcast spark sql Intro Per
  • linux常用文本编辑命令

    cat 命令 cat 命令用于查看纯文本文件 内容较少的 格式为 cat 选项 文件 cat命令常用于查看内容较少的纯文本文件 more 命令 more 命令用于查看纯文本文件 内容较多的 格式为 more 选项 文件 more 命令会在最
  • Ubuntu16.04 LTS自带的Python3.5升级到Python3.7详细记录

    起因 有些第三方库运行只支持Python3 5以上 以及需要使用pip3安装 因此不得不升级Python版本 主要步骤为python官方源码安装 然后修改Python3和pip3的软连接即可 具体升级步骤 安装依赖 sudo apt get
  • Python数据分析与机器学习项目实战

    时值蚂蚁上市之际 马云在上海滩发表演讲 马云的核心逻辑其实只有一个 在全球数字经济时代 有且只有一种金融优势 那就是基于消费者大数据的纯信用 我们不妨称之为数据信用 它比抵押更靠谱 它比担保更保险 它比监管更高明 它是一种面向未来的财产权
  • 迁移学习入门,新手该如何下手?

    推荐迁移学习技术的实用入门图书 自然语言处理迁移学习实战 加纳 保罗 阿祖雷 Paul Azunre 著 李想 朱仲书 张世武 译 一本书带你读懂ChatGPT背后的技术 自然语言处理迁移学习 解锁机器学习新境界 从浅层到深度 掌握NLP迁
  • 2016 OWASP Mobile TOP 10 中文版

    M1 平台使用不当 这个类别包括平台功能的滥用 或未能使用平台的安全控制 它可能包括 Android 的意图 intent 平台权限 TouchID 的误用 密钥链 KeyChain 或是移动操作系统中的其他一些安全控制 有几种方式使移动应
  • javax.net.ssl.SSLHandshakeException: sun.security.validator.ValidatorExcepti

    问题现象 Java Spring应用发送数据报如下问题 AxisFault faultCode http schemas xmlsoap org soap envelope Server userException faultSubcode
  • 网络编程先导知识

    目录 1 什么是网络协议 2 什么是Socket Socket主要类型 3 C S和B S架构 4 网络字节序和主机字节序 5 局域网和广域网 6 IP地址和端口的概念 1 什么是网络协议 为了在计算机网络中做到有条不紊地交换数据 就必须遵
  • android.content包-----ClipboardManager

    ClipboardManager类介绍 Clipboardmanager类通过getSystemService String 方法进行实例化操作 ClipboardManger类的相关方法很简单 包含set和get剪切板的数据 剪切板的数据
  • Tesseract-OCR 中文识别(附上源码)

    简介 光学字符识别 OCR Optical Character Recognition 是指对文本资料进行扫描 然后对图像文件进行分析处理 获取文字及版面信息的过程 OCR技术非常专业 一般多是印刷 打印行业的从业人员使用 可以快速的将纸质
  • Arcmap卫星影像去黑边(彻底去除黑边)

    在处理栅格数据时 我们常常会遇到一个问题 下载下来的卫星影像数据在Arcmap等软件上会出现黑边问题 如图 出现黑边的原因是因为我们下载影像图层是按外接矩形下载的 所以下载时矩形内没图的地方会填充透明色 透明后下下来后就会用黑色代替 那么我
  • 计算机组成原理-复习题2

    二 简答题 43 请写出8位定点原码整数中能表示的最大正数 最小正数 最大负数和最小负数的机器数形式 并用十进制表示其数值范围 答 最大正数 01111111 最小正数 00000001 最大负数 10000001 最小负数 1111111
  • 【动手学】36 图片增广_代码

    matplotlib inline import torch import torchvision from torch import nn from d2l import torch as d2l d2l set figsize img