图像分类如何得到每一类的预测概率?(结合python代码)

2023-11-16

要得到每一类的预测概率,首先通过torch.eq判断每个图片预测的准不准确,循环每个预测结果,得到没个结果对应的标签,如果准确,在该标签类的正确数量加一,在该类的总的数量加一。最后输出该类正确的数量除以该类总的数量就得到了该类的预测概率了。

# 查看单类准确率
        classes = ('0', '1', '2', '3','4')
        N_CLASSES = 5
        class_correct = list(0. for i in range(N_CLASSES))
        class_total = list(0. for i in range(N_CLASSES))
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                # print(val_labels.shape)
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                c = torch.eq(predict_y, val_labels.to(device)).squeeze()
                size = int(val_labels.shape[0])
                for i in range(size):
                    label = val_labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1

                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        for i in range(N_CLASSES):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

若该分类任务存在类间分类,每一类差距很小,想要使预测结果处于相邻类就算分类正确时,则需要先将val_loader的batch_size设置为1,再通过一系列if语句实现该效果。

# 查看单类准确率
        classes = ('0', '1', '2', '3','4')
        N_CLASSES = 5
        class_correct = list(0. for i in range(N_CLASSES))
        class_total = list(0. for i in range(N_CLASSES))
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                
                labels = val_labels.numpy()
                predict = predict_y.cpu().numpy()
                
                if labels == 0:
                    if predict==0 or predict==1:
                        c = True
                    else:
                        c = False
                elif labels == 4:
                    if predict==3 or predict==4:
                        c = True
                    else:
                        c = False
                else:
                    if predict==labels-1 or predict==labels or predict==labels+1:
                        c = True
                    else:
                        c = False
                
                size = int(val_labels.shape[0])
                for i in range(size):
                    label = val_labels[i]
                    
                    class_correct[label] += c
                    class_total[label] += 1

                acc += c
                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))
        for i in range(N_CLASSES):
            print('Accuracy of %5s : %2d %%' % (
                classes[i], 100 * class_correct[i] / class_total[i]))
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

图像分类如何得到每一类的预测概率?(结合python代码) 的相关文章

随机推荐

  • git 仓库迁移

    git 仓库迁移 文章目录 git 仓库迁移 在目标服务器建立新的git 功能仓库 设置git 仓库源 上传代码 验证是否成功 git远程仓库地址查看 在目标服务器建立新的git 功能仓库 git VM 0 5 centos git ini
  • MySQL中的IF语句使用

    MySQL中的IF语句 在 MySQL 数据库中 IF 语句是一种常见的条件控制语句 它可以根据指定的条件返回不同的结果 在本文中 我们将介绍 IF 语句的基本用法以及实际应用场景 IF函数 MySQL 提供了 IF 函数来实现 IF 语句
  • xcode4的自动完成功能(Code sense or Code Snippet)

    社区会员rainbird分享 自动完成包括两种含义 一种是输入字母的时候可以动态弹出一个列表 然后通过选择 提高输入效率 这种好像叫代码提示 Code sense 另一种就是输入几个字母的时候一回车 出来一串儿字符 Code Snippet
  • 把一个对象 转为JSON格式的方法

    List
  • svn的使用手册

    svn的使用手册 svn的使用手册 svn介绍 安装svn 安装VisualSVN server 安装TortoiseSVN 安装EclipseSVN插件 使用SVN Eclipse下使用SVN 合并冲突 分支 svn的使用手册 svn介绍
  • SpringBoot 实现定时任务

    定时任务 一 使用背景 二 定时任务的优点 三 SpringBoot 实现定时任务 3 0 项目结构 3 1 pom xml 3 2 启动类 3 3 服务类 3 4 cron表达式 3 4 1 时间范围 3 4 2 特殊字符 3 4 3 c
  • 启明云端分享

    提示 启明云端从2013年起就作为Espressif 乐鑫科技 大中华区合作伙伴 我们不仅用心整理了你在开发过程中可能会遇到的问题以及快速上手的简明教程 同时也用心推出了基于乐鑫的相关应用方案 希望你能第一时间了解并快速用上好的方案和产品
  • 微信支付接口常用参数及证书区分

    注意 服务商模式下 均是使用服务商的以下信息 1 证书 1 1商户api证书 v2和v3接口都需要使用 1 1 1获取方式 什么是商户API证书 如何获取商户API证书 商户api证书 里面介绍了如何获取商户证书的详细步骤 1 1 2作用
  • MyCAT 连接MySQL 8 注意事项

    一 问题产生 MyCat是一个基于MySQL协议的开源的分布式中间件 其核心是分库分表 但是目前MyCat仍主要面对MySQL 5 5 5 6 5 7版 对最新的MySQL 8尚未完全支持 需要用户对MySQL 8和MyCat的配置进行一系
  • Unity3d之Socket UDP协议

    原文地址 http blog csdn net dingkun520wy article details 49201245 一 Socket 套接字 UDP协议的特点 1 是基于无连接的协议 没有生成连接的延迟所以速度比TCP快 2 支持一
  • linux系统如何进入屏保,linux上屏保设置

    linux下屏保设置 Linux文本终端 字符界面屏保取消 在我们日常使用Linux过程中 经常遇到使用屏幕终端一段时间后 显示器关 闭 屏幕上没有任何显示 一段时间后 屏幕就会关闭 无任何显示 若此时系统死机或僵死 而且屏幕上有输出 当遇
  • 如何用js替换文本里的换行符 \n?

    如何用js替换文本里的换行符 n 有下面一段文本 在编辑器里的格式如下 div line1line2line3 div 切换到浏览器 显示如下 line1line2line3 这里我想使浏览器显示效果变成如下形式 line1 line2 l
  • python 多线程示例

    python 多线程示例 import queue import time import threading import threading from datetime import datetime 创建一个线程安全的队列 q queu
  • Moveit简单使用,在rviz中实现手动拖动-记录

    GAZEBO下载 一 首先需要准备模型文件 可以是自己的solidworks用URDF工具导出的 也可以是在网上下载的URDF文件包 1 我用的是solidworks手动导出的模型 b站博主导出SOLIDWORKS模型至URDF这个教程比较
  • 报错解决:SyntaxError: Non-UTF-8 code starting with ‘\xe7‘

    今天抓取数据时使用re对数据进行提取时遇到的问题 syntaxError Non UTF 8 code starting with xe7 意思是有的中文字符无法转成utf 8的形式 如图所示 这个是因为抓取的数据中有的中文字符识别不了 相
  • 深入理解 Spring 控制反转与依赖注入

    概览 对于 Spring 框架来说 控制反转 Inversion of Control IoC 和依赖注入 Dependency Injection DI 是个等同的概念 控制反转是通过依赖注入实现的 在这篇文章中 我们会详细介绍 IoC
  • 使用VS Code静态检查Android C/C++代码(clangd插件)

    前言 在前文使用VS Code更好的编写Android C C 代码 C C 插件 中主要介绍了如何更好的写代码 本文要探讨的是从 好写 到 写好 的问题 如何做静态代码检查 在查找资料中发现了Cppcheck和Clang Tidy等工具
  • 学位房如何查询学位真实性和户口是否被占用

    查户口有没占用 需要业主带上身份证 房产证去公安局户籍窗口查 他会口头告诉你这个地址的户口有没有人 不会出书面的东西 所以一定要听清楚 其实你和业主签三方合同的时候可以注明户口这方面的东西 比如多少号之前要迁走之类的 拿着房产证去公安局查户
  • JDBC连接MySQL8.0案例详解

    JDBC本质上是一个介于应用程序和数据库之间的公共接口 通过对这个接口的实现 我们可以建立应用程序和数据库之间的连接 便捷的访问数据库数据 不同版本的MySQL连接的参数是有一些小差别的 以下内容基于一个JDBC连接案例讲解连接数据库的过程
  • 图像分类如何得到每一类的预测概率?(结合python代码)

    要得到每一类的预测概率 首先通过torch eq判断每个图片预测的准不准确 循环每个预测结果 得到没个结果对应的标签 如果准确 在该标签类的正确数量加一 在该类的总的数量加一 最后输出该类正确的数量除以该类总的数量就得到了该类的预测概率了