RAF_DB数据集分类_3

2023-05-16

混淆矩阵

这里ECANet太长了,我这里直接利用resnet代替一下,你可以直接替换,然后把权重对应好即可,这只是一个简单的混淆矩阵生成,没有太多美化。

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import  DataLoader
import torchvision
from torch.nn import init
import math
from torchvision import transforms
import torch.nn.functional as F
val_path = "./RAF-DB/test"
val_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
])
batch_size = 256
val_data = torchvision.datasets.ImageFolder(val_path, transform=val_transform)
data1_val = DataLoader(val_data, batch_size=batch_size, shuffle=True,drop_last=True)

def plot_confusion_matrix(cm, savename, title='Confusion Matrix'):

    plt.figure(figsize=(12, 8), dpi=100)
    np.set_printoptions(precision=2)

    # 在混淆矩阵中每格的概率值
    ind_array = np.arange(len(classes))
    x, y = np.meshgrid(ind_array, ind_array)
    for x_val, y_val in zip(x.flatten(), y.flatten()):
        c = cm[y_val][x_val]
        if c > 0.001:
            plt.text(x_val, y_val, "%0.2f" % (c,), color='red', fontsize=15, va='center', ha='center')
    
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()
    xlocations = np.array(range(len(classes)))
    plt.xticks(xlocations, classes, rotation=90)
    plt.yticks(xlocations, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')
    
    # offset the tick
    tick_marks = np.array(range(len(classes))) + 0.5
    plt.gca().set_xticks(tick_marks, minor=True)
    plt.gca().set_yticks(tick_marks, minor=True)
    plt.gca().xaxis.set_ticks_position('none')
    plt.gca().yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    plt.gcf().subplots_adjust(bottom=0.15)
    
    # show confusion matrix
    # plt.savefig(savename, format='png')
    plt.show()

# classes表示不同类别的名称,比如这有6个类别
classes = ['Anger', 'Disgust', 'Fear', 'Happiness','Neutral', 'Sadness','Surprise']

from torchvision import models
resnet = models.resnet18()
class SKNet(nn.Module):
    def __init__(self, num_class=7):
        super(SKNet, self).__init__()
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_class)

    def forward(self, x):
        
        x = self.features(x)
        out = self.avgpool(x)
        out = torch.flatten(out,1)
        out = self.fc(out)
        return out

model_path = "./OneModel/迁移学习/ResNet18.pkl"
sknet = SKNet()
checkpoint = torch.load(model_path,map_location='cpu')
sknet.load_state_dict(checkpoint['model'])
del checkpoint



def evalute_(model,val_loader):
    model.eval()

    for batchidx, (x, label) in enumerate(val_loader):
        with torch.no_grad():
            print(batchidx)
            y1 = model(x)
            _, preds1 = torch.max(F.softmax(y1,dim=1), 1)
            if batchidx!=0:
                y = torch.cat((y,preds1),dim=0)
                labels = torch.cat((labels,label),dim=0)
            else:
                y = preds1
                labels = label
    print(y.shape)
    print(labels.shape)
    assert y.shape == labels.shape
    y = y.numpy()
    
    return y,labels

y_pred,y_true = evalute_(model=sknet,val_loader=data1_val)


# 获取混淆矩阵
cm = confusion_matrix(y_true, y_pred)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
plot_confusion_matrix(cm_normalized, './confusion_matrix.png', title='confusion matrix')
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

RAF_DB数据集分类_3 的相关文章

  • ubuntu装机并设置远程连接

    step1 ubuntu16装系统的过程略过 step2 联网 step3 apt get更新 sudo apt get update step4 安装ssh 安装 openssh 服务 sudo apt get install opens
  • ros学习之串口通信(数据读取),并进行发布

    串口参数 波特率 9600 起始位 1 数据位 8 停止位 1 奇偶校验 无 例如超声波模组地址为0X01 则主机发送 0X55 0XAA 0X01 0X01 checksum checksum 61 帧头 43 用户地址 43 指令 am
  • 在Ubuntu上使用LVM对ROOT进行在线扩容

    前提 xff1a 在安装ubuntu的时候 xff0c 是使用LVM进行分区管理的 背景 xff1a 我在安装的时候 xff0c 选择了500G大小 xff0c 磁盘总大小1T xff0c 现在想扩成1T 扩容前 xff1a yang 64
  • realsense D435i双目IMU 数据集

    realsense D435i 双目IMU数据集 使用双目 43 IMU的数据双目内参双目IMU外参 使用双目 43 IMU的数据 双目内参 model type PINHOLE camera name camera image width
  • MobaXterm 登录出现 Network error :Connection timed out

    本来用SSH连接正在操作 xff0c 突然连接不好Linux xff0c 无法登陆 xff0c 出现Network error Connection timed out错误 还以为是自己哪里操作出错了 xff0c 打开本机 cmd命令窗口
  • 消息队列总结

    一 为什么需要无锁队列 xff1f 二 无锁队列是什么 xff1f 三 无锁队列是如何实现的 xff1f span class token keyword inline span span class token class name yq
  • 姿态传感器—MPU6050

    姿态传感器 MPU6050 简介寄存器数字运动处理器 DMP遇到的问题1 初始化是要水平放置 且 按照上电时的方位为基准 xff08 正点原子提供的例程 xff09 简介 MPU6050是一款六轴 xff08 三轴加速度 43 三轴角速度
  • 卡尔曼滤波的优点总结

    卡尔曼滤波的优点不在于它的估计的偏差小多少 xff0c 而在于它巧妙的融合了观测数据与估计数据 xff0c 对误差进行闭环管理 xff0c 将误差限定在一定范围 xff0c 试想 xff0c 如果没有两者的信息融合 xff0c 只有估计数据
  • 个人简历2021

    标题 个人简历 日期 2021 09 27 23 42 57 标签 简历 分类 工作 职业发展 说下我的个人简历吧 xff0c 希望大家能够了解我 xff0c 一起在技术这条路上一直走下去 个人信息 姓名性别年龄现居地址邮箱陈作立男29上海
  • 深入理解图优化与g2o:图优化篇 - 半闲居士 - 博客园 转

    深入理解图优化与g2o xff1a 图优化篇 半闲居士 博客园
  • 二次型优化问题矩阵求导解法

    二次型求导 风之舞555 博客园 https www csdn net tags MtTaEgzsOTU2NzAxLWJsb2cO0O0O html
  • SQL2000 好书 《SQL Server 2000数据库管理与开发技术大全》----求是科技 人民邮电出版社

    SQL2000 好书 SQL Server 2000数据库管理与开发技术大全 求是科技 人民邮电出版社
  • grub启动

    grub启动 如何修复引导 现象 开机直接进入grub rescue模式 解决方案 第一步 xff1a 退出rescue模式 一般只需要设置prefix变量 span class token comment 通过ls 命令查看所有的磁盘 s
  • aruco安装 配合realsense 使用

    使用github安装 网址 xff1a http www uco es investiga grupos ava node 26 git clone到本地之后 xff0c catkin make即可开始使用 使用apt安装 span cla
  • VS连接realsense D435i摄像头(4)——使用PCL绘制点云图

    本篇主要是在使用PCL绘制点云过程中遇到的问题 xff0c 初始化参照该博客 电脑系统 xff1a win10 x64Visual Studio 2019Realsense D435i摄像头使用语言 xff1a C xff0c C 43 4
  • MobaXterm 无法显示弹框或界面

    MobaXterm 无法显示弹框或界面的解决方案之一 xff1a Settings gt Configuration gt X11 gt Xorg version xff1a 选择Mobox 1 20 4 版本越新越好 亲测可用
  • VINS_FUSION的global融合思想

    VINS FUSION的global融合思想 文章目录 VINS FUSION的global融合思想 使用全局融合的原因 GPS的缺点 融合的目的 算法架构 观测和状态约束关系 GPS 融合思路 GPS残差计算 代码段 參考文献 使用全局融

随机推荐