使用 PyTorch 对自定义数据集进行二分类(基于Vision Transformer)

2023-11-02

内容

简短描述:ViT 的简短描述。

编码部分:使用 ViT 对自定义数据集进行二分类。

附录:ViT hypermeters 解释。

简短描述

视觉转换器是深度学习领域中流行的转换器之一。在视觉转换器出现之前,我们不得不在计算机视觉中使用卷积神经网络来完成复杂的任务。随着视觉转换器的引入,我们获得了一个更强大的计算机视觉任务模型。在本文中,我们将学习如何将视觉转换器用于图像分类任务。

下图总结了 Vision Transformer 的分类过程:

编码部分

第 1 步:创建 anaconda 环境并设置所需的库。

下载requirements.txt(链接如下),放在你VIT相关的工程文件夹下,激活anaconda环境:

https://drive.google.com/uc?export=download&id=14xiSObMiBNRPSbwyevZ_hRRk7V3R-txF

conda create --name vit_project python=3.8
conda activate vit_project
pip install -r requirements.txt

第 2 步:自定义数据集的文件夹结构。

确保分类数据集的文件夹结构与下图中的相同:

第 3 步:编码

用到的库:

from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
import torch.utils.data as data
import torchvision
from torchvision.transforms import ToTensor
torch.cuda.is_available()

超参数:

# Hyperparameters:
batch_size = 64 
epochs = 20
lr = 3e-5
gamma = 0.7
seed = 142
IMG_SIZE = 128
patch_size = 16
num_classes = 2

数据加载器:

train_ds = torchvision.datasets.ImageFolder("dataset_new_split/train", transform=ToTensor())
valid_ds = torchvision.datasets.ImageFolder("dataset_new_split/val", transform=ToTensor())
test_ds = torchvision.datasets.ImageFolder("dataset_new_split/test", transform=ToTensor())

# Data Loaders:
train_loader = data.DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=4)
valid_loader = data.DataLoader(valid_ds, batch_size=batch_size, shuffle=True,  num_workers=4)
test_loader  = data.DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=4)

构建模型:

# Training device:
device = 'cuda'

# Linear Transformer:
efficient_transformer = Linformer(dim=128, seq_len=64+1, depth=12, heads=8, k=64)

# Vision Transformer Model: 
model = ViT(dim=128, image_size=128, patch_size=patch_size, num_classes=num_classes, transformer=efficient_transformer, channels=3).to(device)

# loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)

# Learning Rate Scheduler for Optimizer:
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

自定义模型训练:

# Training:
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

        with torch.no_grad():
            epoch_val_accuracy = 0
            epoch_val_loss = 0
            
        for data, label in valid_loader:
            
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

模型保存和加载以备将来使用:

# Save Model:
PATH = "epochs"+"_"+str(epochs)+"_"+"img"+"_"+str(IMG_SIZE)+"_"+"patch"+"_"+str(patch_size)+"_"+"lr"+"_"+str(lr)+".pt"
torch.save(model.state_dict(), PATH)

模型评估——准确性:

# Performance on Valid/Test Data
def overall_accuracy(model, test_loader, criterion):
    
    '''
    Model testing 
    
    Args:
        model: model used during training and validation
        test_loader: data loader object containing testing data
        criterion: loss function used
    
    Returns:
        test_loss: calculated loss during testing
        accuracy: calculated accuracy during testing
        y_proba: predicted class probabilities
        y_truth: ground truth of testing data
    '''
    
    y_proba = []
    y_truth = []
    test_loss = 0
    total = 0
    correct = 0
    for data in tqdm(test_loader):
        X, y = data[0].to('cpu'), data[1].to('cpu')
        output = model(X)
        test_loss += criterion(output, y.long()).item()
        for index, i in enumerate(output):
            y_proba.append(i[1])
            y_truth.append(y[index])
            if torch.argmax(i) == y[index]:
                correct+=1
            total+=1
                
    accuracy = correct/total
    
    y_proba_out = np.array([float(y_proba[i]) for i in range(len(y_proba))])
    y_truth_out = np.array([float(y_truth[i]) for i in range(len(y_truth))])
    
    return test_loss, accuracy, y_proba_out, y_truth_out


loss, acc, y_proba, y_truth = overall_accuracy(model, test_loader, criterion = nn.CrossEntropyLoss())


print(f"Accuracy: {acc}")

print(pd.value_counts(y_truth))

模型评估——ROC 曲线:

# Plot ROC curve:

def plot_ROCAUC_curve(y_truth, y_proba, fig_size):
    
    '''
    Plots the Receiver Operating Characteristic Curve (ROC) and displays Area Under the Curve (AUC) score.
    
    Args:
        y_truth: ground truth for testing data output
        y_proba: class probabilties predicted from model
        fig_size: size of the output pyplot figure
    
    Returns: void
    '''
    
    fpr, tpr, threshold = roc_curve(y_truth, y_proba)
    auc_score = roc_auc_score(y_truth, y_proba)
    txt_box = "AUC Score: " + str(round(auc_score, 4))
    plt.figure(figsize=fig_size)
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1],'--')
    plt.annotate(txt_box, xy=(0.65, 0.05), xycoords='axes fraction')
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.xlabel("False Positive Rate (FPR)")
    plt.ylabel("True Positive Rate (TPR)")
#     plt.savefig('ROC.png')
plot_ROCAUC_curve(y_truth, y_proba, (8, 8))

模型评估混淆矩阵

from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

y_pred = []
y_true = []

net = model
# iterate over test data
for inputs, labels in test_loader:
        output = net(inputs) # Feed Network

        output = (torch.max(torch.exp(output), 1)[1]).data.cpu().numpy()
        y_pred.extend(output) # Save Prediction
        
        labels = labels.data.cpu().numpy()
        y_true.extend(labels) # Save Truth

# constant for classes
classes = ('cats', 'dogs')

# Build confusion matrix
cf_matrix = confusion_matrix(y_true, y_pred)
df_cm = pd.DataFrame(cf_matrix/np.sum(cf_matrix), index = [i for i in classes],
                     columns = [i for i in classes])
plt.figure(figsize = (12,7))
sn.heatmap(df_cm, annot=True)
# plt.savefig('cm.png')

新图像的模型推理:

# Inference on Single Images (cats-dogs):
test_image = "new_cat_image.jpg"
test_image_null = "new_dog_image.png"
image = Image.open(test_image)
image_null = Image.open(test_image_null)

# Define tensor transform and apply it:
data_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
image_t = data_transform(image).unsqueeze(0)
image_null_t = data_transform(image_null).unsqueeze(0)

# Labels:
for inputs, labels in test_loader:
        labels = labels.data.cpu().numpy()

# Prediction:
out_cat = model(image_t)
out_dog= model(image_null_t)
print("predicted cat tensor:", out_cat)
print("predicted dog tensor:", out_dog)
print("")
# Print:
if(labels[out_cat.argmax()]== 0):
    print("smoke")
else:
    print("else")
    
# Show Image:
plt.figure(figsize=(2, 2))
plt.imshow(image)
plt.show()
# Print:
if(labels[out_dog.argmax()]== 0):
    print("cat")
else:
    print("dog")
    
# Show Image Null:
plt.figure(figsize=(2, 2))
plt.imshow(image_null)
plt.show()

附录 :

1. image_size: int (w 或 h 的最大尺寸)

2. patch_size: int (# of patches, image_size 必须能被 patch_size 整除,必须大于 16)

3. num_classes: int (# of classes)

4. dim: int(线性变换后输出张量最后一维nn.Linear(..,dim))

5. depth: int (# of transformer blocks)

6. heads: int (# of heads in Multi-head Attention layer)
7. mlp_dim: int(MLP-前馈层的维度)
8. channels: int (图像通道 = 3)
9. dropout:float(在[0,1]之间——神经元的dropout率)
10. emb_dropout(在[0,1]之间——嵌入的dropout率——通常为0)

ViT 学习率和损失函数:

Optimizer: ADAM 优化器:ADAM

学习率:StepLR(每 #(step_size) 个纪元通过 gamma 衰减 LR)

损失函数:CrossEntropy(记得也试试 BinaryCrossEntropy:nn.BCELoss())

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 PyTorch 对自定义数据集进行二分类(基于Vision Transformer) 的相关文章

随机推荐

  • 哪些循环依赖问题Spring解决不了?

    前言 大家都知道 Spring 解决了循环依赖的问题 网上也可以搜到 Spring 是使用三级缓存来解决循环依赖的 但有些时候循环依赖问题还是会导致启动报错 也就说明 在某些情况下 Spring 是没有办法解决循环依赖问题的 我们就来探究一
  • JIRA工作流节点状态变化前弹出窗口填写日志或者备注

    一 定义弹出框的页面 1 进入问题管理页面 并点添加屏幕 2 自定义弹出页面的名称 3 点击添加后会进入配置页面 配置页面所包含的字段 二 在工作流中配置页面 1 进入工作流编辑页面双击需要添加弹出页面的流程 2 在弹出框中页面栏选择刚刚配
  • Python爬虫必备:浏览器开发者工具的使用,非常详细

    最近很多小伙伴说 不会用浏览器开发者工具 今天我们就一起来深入了解一下开发者工具 以谷歌浏览器为例 谷歌浏览器开发者工具中的Network 是我们学习经常用到的 那么你都知道他们每个功能的意义吗 因本人经常有360极速浏览器 谷歌内核 所以
  • vue pdf.js统计pdf的页数

    参考链接作者原文展示了PDF 我只需要一部分功能
  • Adobe进军AI第一步——Firefly试用体验

    在关于人工智能讨论度高居不下的今天 各个行业的领路企业也纷纷不甘落后 Adobe作为媒体界的行业标杆 就在近期推出了自己的人工智能图像应用 萤火虫firefly 虽然这只萤火虫刚刚 起飞 它已经展现的文字生图和能力算是及格 我分别在网页版和
  • spark-submit 碰到 Spark-submit:System memory 466092032 must be at least 471859200

    在利用spark进行分布式计算时 home hadoop spark spark 2 4 0 bin hadoop2 7 bin spark submit master yarn ALS py 以上代码是在centos7 利用spark集群
  • vim 一段代码整体移动

    方法1 可以用ctrl v 然后上下移动光标 再shift i进入编辑模式 然后按删除或者空格或者tab键来移动第一行 然后按ESC 就能整段代码动起来了 方法2 1 点击 esc 键进入命令模式 使用 set nu 显示行号 2 点击 e
  • 又是第一!GBASE南大通用蝉联中国分析型数据库管理系统市场TOP1

    报告指出 大数据时代 用户对数据分析的需求不断提升 希望从大量数据中获得新的数据价值 数据分析需求不断上升 分析型数据库市场保持稳定增长 GBASE南大通用作为分析型数据库市场的代表企业 位居本土厂商第一名 在分析型数据库市场 GBASE自
  • 插件分享

    前言 要问我Goby怎么样 我会坚定回答你 最强实时网络空间测绘 没有之一 初次发现Goby还是来自于同事 hq404的推荐 看完第一反应 真漂亮 我馋了 我要xxxxxx 其Logo和UI做的相当棒 当然不仅拥有华丽的外表 更让我深爱又离
  • python爬取新发地菜价

    import requests from bs4 import BeautifulSoup import csv url http www xinfadi com cn marketanalysis 0 list 1 shtml respo
  • 【机试练习】【C++】【PAT A1053】Path of Equal Weight(玄学一样的“段错误”)

    此题有较大的玄学 如果将cmp函数的默认返回值更改为true 则会出现最后一个测试用例的 段错误 在代码中以 我的天 玄学 标识出 include
  • Java阻塞队列

    目录 一 阻塞队列的特点 二 生产者 消费者 存在问题 三 阻塞队列 Java实现 属性 方法 put方法 生产者 线程专门调用的方法 get方法 消费者 线程专门调用的方法 执行顺序分析 图解 在我们上图的代码当中 如果把while改成i
  • Sharding-JDBC(八)5.3 系列升级解读

    目录 一 背景 二 影响范围 1 Maven 坐标调整 2 自定义算法调整 3 事务调整 4 配置文件调整 三 升级指导 1 新的 ShardingSphereDriver 数据库驱动 2 正在使用 Spring Boot Starter
  • 2023华为OD机试真题【找朋友/单调栈】

    题目描述 在学校中 N个小朋友站成一队 第i个小朋友的身高为height i 第i个小朋友可以看到的第一个比自己身高更高的小朋友j 那么j是i的好朋友 要求j gt i 请重新生成一个列表 对应位置的输出是每个小朋友的好朋友位置 如果没有看
  • python爬虫系列5--xpath

    教程地址 http www runoob com xpath xpath tutorial html XPath在python的爬虫学习中 起着举足轻重的地位 对比正则表达式re两者可以完成同样的工作 实现的功能也差不多 但XPath明显比
  • 用 STM32 通用定时器做微秒延时函数(STM32CubeMX版本)

    概述 在使用 DHT11 的时候 时序通信需要微秒来操作 STM32CubeMX 自带一个系统时钟 但是实现的是毫秒级别的 因此就自己用通用计时器实现一个 文章目录 概述 1 配置定时器时钟 2 计数器时钟频率及计数模式 预分频系数 计数器
  • tomcat调优的几个方面

    和早期版本相比最新的Tomcat提供更好的性能和稳定性 所以一直使用最新的Tomcat版本 现在本文使用下面几步来提高Tomcat服务器的性能 增加JVM堆内存大小 修复JRE内存泄漏 线程池设置 压缩 数据库性能调优 Tomcat本地库
  • css画间距可控制的虚线

    借助linear gradient dash div margin left 50px margin right 50px height 10px background linear gradient to left transparent
  • linux git代码明明是最新版本的,status为啥全是modified?

    解决办法 依次执行以下两句代码 git rm cached r git reset hard
  • 使用 PyTorch 对自定义数据集进行二分类(基于Vision Transformer)

    内容 简短描述 ViT 的简短描述 编码部分 使用 ViT 对自定义数据集进行二分类 附录 ViT hypermeters 解释 简短描述 视觉转换器是深度学习领域中流行的转换器之一 在视觉转换器出现之前 我们不得不在计算机视觉中使用卷积神