程序记录(一)VGG16猫狗分类

2023-11-16

import torch
from torchvision import datasets, models, transforms
import os
from torch.utils.data import DataLoader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import time
# 下载数据集,保存到DogsVSCats文件夹。
data_dir = "DogsVSCats"
# 定义函数:子文件夹train、valid下的图片转换为224*224的tensor形式,并做归一化处理。
data_transform = {x: transforms.Compose([transforms.Resize([224, 224]),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
                  for x in ["train", "valid"]}
# 定义数据集所在路径,调用转换函数。
image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x])
                  for x in ["train", "valid"]}
# 根据以上定义加载数据集
dataloader = {x: DataLoader(dataset=image_datasets[x], batch_size=16, shuffle=True) for x in ["train", "valid"]}
# 
X_example, y_example = next(iter(dataloader["train"]))
example_clasees = image_datasets["train"].classes
index_classes = image_datasets["train"].class_to_idx

model = models.vgg16(pretrained=True)
for parma in model.parameters():
    parma.requires_grad = False

model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 4096),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(4096, 4096),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(4096, 2))
Use_gpu = torch.cuda.is_available()
if Use_gpu:
    model = model.cuda()

cost = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters())

loss_f = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.00001)

epoch_n = 2
time_open = time.time()

for epoch in range(epoch_n):
    print("Epoch{}/{}".format(epoch,epoch_n-1))
    print("-"*10)

    for phase in ["train","valid"]:
        if phase == "train":
            print("Training...")
            model.train(True)
        else:
            print("Validing...")
            model.train(False)

        running_loss = 0.0
        running_corrects = 0

        for batch, data in enumerate(dataloader[phase], 1):
            X, y = data
            if Use_gpu:
                X, y = Variable(X.cuda()), Variable(y.cuda())
            else:
                X, y = Variable(X), Variable(y)

            y_pred = model(X) # 是否需要cuda

            _, pred = torch.max(y_pred.data, 1)

            optimizer.zero_grad()

            loss = loss_f(y_pred, y)

            if phase == "train":
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            running_corrects += torch.sum(pred == y.data)

            if batch%500 == 0 and phase == "train":
                print("Batch{}, Train Loss :{:.4f}, Train ACC: {:.4f}".format(batch, running_loss/batch,
                                                                              100*running_corrects/(16*batch)))
        epoch_loss = running_loss*16/len(image_datasets[phase])
        epoch_acc = 100*running_corrects/len(image_datasets[phase])
        print("{} Loss:{:.4f} ACC:{:.4f}%".format(phase, epoch_loss, epoch_acc))
        time_end = time.time() - time_open
        print(time_end)


本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

程序记录(一)VGG16猫狗分类 的相关文章

  • 如何查找分布式dask中任务失败的原因?

    我正在开发一个分布式计算系统dask distributed 我通过以下方式提交给它的任务Executor map功能有时会失败 而其他看起来相同的功能却可以成功运行 该框架是否提供了诊断问题的方法 update我所说的失败是指增加 Bok
  • 使用Python的工业视觉相机[关闭]

    就目前情况而言 这个问题不太适合我们的问答形式 我们希望答案得到事实 参考资料或专业知识的支持 但这个问题可能会引发辩论 争论 民意调查或扩展讨论 如果您觉得这个问题可以改进并可能重新开放 访问帮助中心 help reopen questi
  • 在 python 2 和 3 的spyder之间切换

    根据我在文档中了解到的内容 它指出您只需使用命令提示符创建一个新变量即可轻松在 2 个 python 环境之间切换 如果我已经安装了 python 2 7 则 conda create n python34 python 3 4 anaco
  • 为什么方法无法访问类变量?

    我试图理解Python中的变量作用域 除了我不明白为什么类变量不能从其方法访问的部分之外 大多数事情对我来说都很清楚 在下面的例子中mydef1 无法访问a 但如果a可以在全局范围 类定义之外 声明 class MyClass1 a 25
  • 如何有条件地组合两个相同形状的 numpy 数组

    这听起来很简单 但我想我把它想得太复杂了 我想创建一个数组 其元素是从两个形状相同的源数组生成的 具体取决于源数组中哪个元素更大 为了显示 import numpy as np array1 np array 2 3 0 array2 np
  • 如何用spaCy获取依赖树?

    我一直在尝试寻找如何使用 spaCy 获取依赖树 但我找不到任何有关如何获取树的信息 只能在如何导航树 https spacy io usage examples subtrees 如果有人想轻松查看 spacy 生成的依赖关系树 一种解决
  • 返回不包括指定键的字典副本

    我想创建一个函数 返回字典的副本 不包括列表中指定的键 考虑这本词典 my dict keyA 1 keyB 2 keyC 3 致电without keys my dict keyB keyC 应该返回 keyA 1 我想用一行简洁的字典理
  • 无法安装时间模块

    我试过了pip install time and sudo H pip install time 但我不断收到错误 找不到满足要求时间的版本 从 版本 未找到时间匹配的发行版 我正在 PyCharm 中工作 但真正没有意义的是我可以在 Py
  • 远程控制或脚本打开 Office 从 Python 编辑 Word 文档

    我想 最好在 Windows 上 在特定文档上启动 Open Office 搜索固定字符串并将其替换为我的程序选择的另一个字符串 我该如何从外部 Python 程序中做到这一点 OLE 什么 原生 Python 脚本解决方案 The doc
  • 使用 Python 抓取维基百科数据

    我正在尝试从以下内容中检索 3 列 NFL 球队 球员姓名 大学球队 维基百科页面 http en wikipedia org wiki 2008 NFL draft 我是 python 新手 一直在尝试使用 beautifulsoup 来
  • 一起使用 Argparse 和 Json

    我是 Python 初学者 我想知道 Argparse 和 JSON 是否可以一起使用 说 我有变量p q r 我可以将它们添加到 argparse 中 parser add argument p param1 help x variabl
  • 以编程方式将列名称添加到 numpy ndarray

    我正在尝试将列名称添加到 numpy ndarray 然后按名称选择列 但这不起作用 我无法判断问题是在添加名称时出现 还是在稍后尝试调用它们时出现 这是我的代码 data np genfromtxt csv file delimiter
  • 在 Windows 上将 Word2vec 与 Tensorflow 结合使用

    In 本教程文件 https github com tensorflow models blob master tutorials embedding word2vec py L45通过 Tensorflow 找到以下行 第 45 行 来加
  • 如何检查包含 NaN 的列表 [关闭]

    Closed 这个问题需要细节或清晰度 help closed questions 目前不接受答案 在我的 for 循环中 我的代码生成一个如下所示的列表 list 0 0 0 0 sum 0 0 0 0 该循环生成所有其他数字向量 但它也
  • 如何将同步函数包装在异步协程中?

    我在用着aiohttp https github com aio libs aiohttp构建一个 API 服务器 将 TCP 请求发送到单独的服务器 发送 TCP 请求的模块是同步的 对于我来说是一个黑匣子 所以我的问题是这些请求阻塞了整
  • Python time.sleep - 永不醒来

    我认为这将是那些简单的问题之一 但它让我感到困惑 停止媒体 我是对的 找到了解决方案 查看答案 我正在使用 Python 的单元测试框架来测试多线程应用程序 很好而且很直接 我有 5 个左右的工作线程监视一个公共队列 以及一个为它们制作工作
  • Pandas DataFrame:如何计算组中第一行和最后一行的差异?

    这是我的熊猫数据框 import pandas as pd import numpy as np data column1 338 519 871 1731 2693 2963 3379 3789 3910 4109 4307 4800 4
  • 在 python 查询参数中使用 %20 而不是 + 作为空格

    我使用 python requests 编写了以下 python 脚本 http requests readthedocs org en latest http requests readthedocs org en latest impo
  • 使用Python重命名目录中的多个文件

    我正在尝试使用以下 Python 脚本重命名目录中的多个文件 import os path Users myName Desktop directory files os listdir path i 1 for file in files
  • Python - 打印漂亮的 XML 为空标签文本创建开始和结束标签

    我正在编写一个 python 应用程序 它创建一个 ElementTree XML 然后使用 minidom 的 toprettyxml 将其写入文件 final tree minidom parseString ET tostring r

随机推荐

  • 信息安全产品认证

    文章目录 一 引言 二 网络关键设备和网络安全专用产品安全认证证书 2 1 背景 2 2 产品目录 2 3 认证依据标准 2 4 认证机构 2 5 商密产品检测认证目录 与 网络关键设备和网络安全专用产品目录 的关系 三 中国国家信息安全产
  • 20个常见的Java错误以及规避方法

    原文 50 Common Java Errors and How to Avoid Them Part 1 作者 Angela Stringfellow 翻译 雁惊寒 译者注 本文介绍了20个常见的Java编译器错误 每种错误都包含了代码片
  • MKP勒索病毒:了解最新变种mkp,以及如何保护您的数据

    导言 在数字化时代 mkp 勒索病毒成为了网络安全领域的一大威胁 它采用高级加密技术 将您的数据文件锁定 要求支付赎金以解锁 本文将详细介绍 mkp 勒索病毒的工作原理 如何恢复被它加密的数据文件 以及如何采取预防措施来降低受攻击的风险 如
  • lambdaQuery用法

    lambdaQuery用法 LambdaQueryWrapper
  • pandas DataFrame行或列的删除方法

    pandas DataFrame的增删查改总结系列文章 pandas DaFrame的创建方法 pandas DataFrame的查询方法 pandas DataFrame行或列的删除方法 pandas DataFrame的修改方法 此文我
  • uniapp之微信小程序开发教程及如何合理使用WebSocket(实时监听)+workman聊天系统+linux系统配置阿里云端口

    添加链接描述 添加链接描述 thinphp6 1 workerman文档 添加链接描述 https www kancloud cn manual thinkphp6 0 1147857 workerman手册 https www worke
  • 软件的最低测试方法

    前言 1 1 引言 对于大部分软件系统 如何测试及有效的测试 是一个很头痛的问题 在软件工程上 测试是软件工程中极其重要的一部分 但在具体的实际情况上 无论是时间 人手及资源的调配等原因 使国内大部分软件公司没有进行过理论上的完整的测试 本
  • JAVA变量与数据类型

    人生不如意之事十有八九 在最好的年纪要努力充实自己 莫等空悲切白了少年头 而是要及时当勉励 岁月不待人 一 java变量 变量概述 1 内存中存储的一个存储区域 2 该存储区域内的数据在同一类型范围内不断变化 3 变量是程序中最基本的存储单
  • 老虎证券美股策略——将动量策略日频调仓改成月频

    最近策略频繁回撤 跑不赢标普500指数 所以对策略简单修改 以待后效 新加入的代码 def get if trade day infile open countday dat r incontent infile read infile c
  • Linux系统中负载较高问题排查思路与解决方法

    Load 就是对计算机干活多少的度量 Load Average 就是一段时间 1分钟 5分钟 15分钟 内平均Load linux服务器出现高负载的情况下 一般都有一些具体的症状 比如cpu 内存等被耗尽 磁盘IO或者网络等出现问题 下面通
  • CentOS7下安装LNMP以及phpMyAdmin

    两种安装 第一种 下载 可以到官网找 版本 https www phpmyadmin net downloads cd 到你要下载的位置 wget https files phpmyadmin net phpMyAdmin 4 4 12 p
  • Maven中pom文件内scope标签中import、parent 、dependencies、dependencyManagement详解

    首先介绍parent 如果父项目中有这些依赖
  • linux-sed命令

    目录 1 linux shell sed获取某一段字符串 2 linux shell shell脚本中 sed n取出某一行赋给一个变量 3 linux shell sed查询某一行 1 linux shell sed获取某一段字符串 如果
  • 网络层(四)

    网络层 我们说过 网络层主要讲的就是ip编址和路由选择算法 更准确的说 应该是网际IP协议 网际IP协议主要说明了各个主机和服务器的ip编址规则 了解IP编址前 我们需要知道IP数据报 IP数据报在网络层中传输 我们看一下IP数据报的结构
  • STM32F103ZET6【标准库函数开发】------PB3,PB4当做普通IO口,重定义

    一 如题 我在设计原理图的时候将PB3和PB4当做了普通IO口 结果按照一般配置的方法操作后 PB3 PB4并没有输出自己想要的信号 配置如下 void MOTOR GPIO Init void 初始化 GPIO InitTypeDef G
  • 人社练兵比武怎样挣积分 python 源码在线答题源码

    可以自动答题积分 不明白如何用的可以联系我 下面2个函数是学练习的 需要用的库为selenium time re pickle 题库需要收集 def dan 单选或多选 j browser find element by xpath id
  • java绘制(可视化)树结构图

    以JPanel组件为画板 继承JPanel类并重写paint Graphics g 函数 在函数中使用画笔g绘制树结构图 实例代码 3个java源文件 Main java DrawNode java DrawTree java 1 Main
  • 周订单量趋势

    周订单量趋势 PreAuthorize hasAuthority admin statistics home chart order week ApiOperation value 周订单量趋势 RequestMapping value c
  • 《Perl语言入门》读书笔记(六)哈希

    1 哈希特点 哈希是一种数据结构 与数组相同点 能容纳任意多的值 而哈希的检索方式与数组不同 数组是以数字下标检索 而哈希中的值 value 以唯一的名字 key 检索 key value一一对应 乱序排列 类似一桶数据 由于检索方式不同
  • 程序记录(一)VGG16猫狗分类

    import torch from torchvision import datasets models transforms import os from torch utils data import DataLoader from t