微调Hugging Face中图像分类模型

2023-11-05

前言

  • 本文主要针对Hugging Face平台中的图像分类模型,在自己数据集上进行微调,预训练模型为Googlevit-base-patch16-224模型,模型简介页面
  • 代码运行于kaggle平台上,使用平台免费GPU,型号P100,笔记本地址,欢迎大家copy & edit
  • Github项目地址Hugging Face模型微调文档

依赖安装

  • 如果是在本地环境下运行,只需要同时安装3个包就好transformersdatasetsevaluate,即pip install transformers datasets evaluate
  • 在kaggle中因为accelerate包与环境冲突,所以需要从项目源进行安装,即:
import IPython.display as display
! pip install -U git+https://github.com/huggingface/transformers.git
! pip install -U git+https://github.com/huggingface/accelerate.git
! pip install datasets
display.clear_output()
  • 因为安装过程中会产生大量输出,所以使用display.clear_output()清空jupyter notebook的输出。

数据处理

  • 这里使用kaggle中的图像分类公共数据集,5 Flower Types Classification Dataset,数据结构如下:
 - flower_images
	 - Lilly
		 - 000001.jpg
		 - 000002.jpg
		 - ......
	 - Lotus
		 - 001001.jpg
		 - 001002.jpg
		 - ......
	 - Orchid
	 - Sunflower
  • 可以看到flower_images为主文件夹,Lilly,Lotus,Orchid,Sunflower为各类花的种类,每类花的图片数量均为1000张
  • 微调模型图像的数据集读取与加载需要使用datasets包中的load_dataset函数,有关该函数的文档
from datasets import load_dataset
from datasets import load_metric
# 加载本地数据集
dataset = load_dataset("imagefolder", data_dir="/kaggle/input/5-flower-types-classification-dataset/flower_images")
# 整合数据标签与下标
labels = dataset["train"].features["label"].names

label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

metric = load_metric("accuracy")
display.clear_output()
  • 如果想要查看图片,可以使用image来访问
example = dataset["train"][0]
example['image'].resize((224, 224))

请添加图片描述

  • 确定想要进行微调的模型,加载其配置文件,这里选择vit-base-patch16-224,关于transfromers包中的AutoImageProcessor类,from_pretrained方法,请参见文档
from transformers import AutoImageProcessor
model_checkpoint = "google/vit-base-patch16-224"
batch_size = 64
image_processor  = AutoImageProcessor.from_pretrained(model_checkpoint)
image_processor 
  • 根据vit-base-patch16-224预训练模型图像标准化参数标准化微调数据集,都是torchvision库中的一些常见变换,这里就不赘述了,重点是preprocess_trainpreprocess_val函数,分别用于标准化训练集与验证集。
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

train_transforms = Compose(
        [
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_val(example_batch):
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch
  • 划分数据集,并分别将训练集与验证集进行标准化
# 划分训练集与测试集
splits = dataset["train"].train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

display.clear_output()

微调模型

  • 加载预训练模型使用transformers包中AutoModelForImageClassification类,from_pretrained方法,参考文档
  • 需要注意的是ignore_mismatched_sizes参数,如果你打算微调一个已经微调过的检查点,比如google/vit-base-patch16-224(它已经在ImageNet-1k上微调过了),那么你需要给from_pretrained方法提供额外的参数ignore_mismatched_sizes=True。这将确保输出头(有1000个输出神经元)被扔掉,由一个新的、随机初始化的分类头取代,其中包括自定义数量的输出神经元。你不需要指定这个参数,以防预训练的模型不包括头。
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(model_checkpoint, 
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes = True)
display.clear_output()
  • 配置训练参数由TrainingArguments函数控制,该函数参数较多,参考文档
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned-eurosat",
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit = 5,
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",)
  • 我解释一下上面出现的一些参数
    • output_dir:模型预测和检查点的输出目录
    • remove_unused_columns:是否自动删除模型转发方法未使用的列
    • evaluation_strategy: 在训练期间采用的评估策略
    • save_strategy:在训练期间采用的检查点保存策略
    • save_total_limit:限制检查点的总数,删除较旧的检查点
    • learning_rateAdamW优化器的初始学习率
    • per_device_train_batch_size:训练过程中GPU/TPU/CPU核心batch大小
    • gradient_accumulation_steps:在执行向后/更新传递之前累积梯度的更新步数
    • per_device_eval_batch_size:评估过程中GPU/TPU/CPU核心batch大小
    • num_train_epochs:要执行的训练时期总数
    • warmup_ratio:用于学习率从0到线性预热的总训练步数的比率
    • logging_steps:记录steps间隔数
    • load_best_model_at_end:是否在训练结束时加载训练期间找到的最佳模型
    • metric_for_best_model:指定用于比较两个不同模型的指标
  • 制定评估指标函数
import numpy as np
import torch

def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
  • 传递训练配置,准备开始微调模型,Trainer函数,参考文档
trainer = Trainer(model,
                  args,
                  train_dataset=train_ds,
                  eval_dataset=val_ds,
                  tokenizer=image_processor,
                  compute_metrics=compute_metrics,
                  data_collator=collate_fn,)
  • 同样的,我解释一下上面的一些参数
    • model:训练、评估或用于预测的模型
    • args:调整训练的参数
    • train_dataset:用于训练的数据集
    • eval_dataset:用于评估的数据集
    • tokenizer:用于预处理数据的标记器
    • compute_metrics:将用于在评估时计算指标的函数
    • data_collator:用于从train_dataseteval_dataset的元素列表形成批处理的函数
  • 开始训练,并在训练完成后保存模型权重,模型训练指标变化,模型最终指标。
train_results = trainer.train()
# 保存模型
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
  • 在训练过程中可选择使用wandb平台对训练过程进行实时监控,但需要注册一个账号,获取对应api,个人推荐使用,当然也可以ctrl+q选择退出。
  • 训练输出:
Epoch	Training Loss	Validation Loss	Accuracy
1	0.384800	0.252986	0.948000
2	0.174000	0.094400	0.968000
3	0.114500	0.070972	0.978000
4	0.106000	0.082389	0.972000
5	0.056300	0.056515	0.982000
6	0.044800	0.058216	0.976000
7	0.035700	0.060739	0.978000
8	0.068900	0.054247	0.980000
9	0.057300	0.058578	0.982000
10	0.067400	0.054045	0.980000
11	0.067100	0.051740	0.978000
12	0.039300	0.069241	0.976000
13	0.029000	0.056875	0.978000
14	0.027300	0.063307	0.978000
15	0.038200	0.056551	0.982000
16	0.016900	0.053960	0.984000
17	0.021500	0.049470	0.984000
18	0.031200	0.049519	0.984000
19	0.030500	0.051168	0.984000
20	0.041900	0.049122	0.984000
***** train metrics *****
  epoch                    =         20.0
  total_flos               = 6494034741GF
  train_loss               =       0.1092
  train_runtime            =   0:44:01.61
  train_samples_per_second =       34.062
  train_steps_per_second   =        0.538

wandb平台指标可视化

请添加图片描述

请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述

评估模型

metrics = trainer.evaluate()
# some nice to haves:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

输出:

***** eval metrics *****
  epoch                   =       20.0
  eval_accuracy           =      0.984
  eval_loss               =      0.054
  eval_runtime            = 0:00:11.18
  eval_samples_per_second =     44.689
  eval_steps_per_second   =      0.715

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

微调Hugging Face中图像分类模型 的相关文章

  • 什么是JWT?

    在HTTP接口调用的时候 服务端经常需要对调用方做认证 以保证安全性 一种常见的认证方式是使用JWT Json Web Token 采用这种方式时 经常在header传入一个authorization字段 值为对应的jwt token 或者
  • python调用hive脚本_python 中写hive 脚本

    1 直接执行 sql脚本 import numpy as np import pandas as pd import lightgbm as lgb from pandas import DataFrame from sklearn mod

随机推荐

  • wkhtmltopdf 实现html 文档对象转换为pdf 文件核心功能代码

    环境说明 环境 windows10 客户端软件 wkhtmltopdf 记得配置wkhtmltopdf 全局环境变量 相关的安装步骤可以baidu或者google wkhtmltopdf 安装包已经上传csdn Java 核心功能代码 添加
  • javaweb 如何在前端根据数据画出图像曲线

    一个实现画板的程序 与我的项目没啥关系 发现一个实现图表的js程序 chartjs官网 我一会儿得学学 echarts的js的实例 唉 找不到一个好的方法 看了看echarts的官方文档 发现echarts确实非常强悍相比如chartjs来
  • 龙书笔记

    1 我们可以设置第四个参数w 当w设置为1时 为了让点可以恰当的转变 当w设置为0时 为了防止向量被平移 2 一个平面 n d 可以被当做一个4d向量来交换 将这个4D向量乘期望的变换矩阵的逆矩阵就可以了 3 顶点操作 并非所有的显卡都支持
  • 小程序练手

    上个星期学了一下小程序 然后写了一个项目练练手 主要实现了三个功能 实现文件的上传功能 实现评论功能 实现展示功能 这里记录一下云开发几个重要的点 首先的是文件的上传并预览 wxml
  • MYSQL02高级_目录结构、默认数据库、表文件、系统独立表空间

    文章目录 MySQL目录结构 查看默认数据库 MYSQL5 7和8表文件 系统 独立表空间 MySQL目录结构 如何查看关联mysql目录 root mysql8 find name mysql var lib mysql var lib
  • SpringSecurity学习笔记(九)RememberMe进阶

    参考视频 编程不良人 前面我们介绍了rememberMe的实现原理 从中我们可以思考这样一个问题 如果我们的cookie被非法用户获取 然后携带这个cookie进行访问我们的项目中的内容 就会导致非法用户登录 这个问题怎么解决呢 Remem
  • MySQL 字符串函数:字符串截取

    MySQL 字符串函数 字符串截取 MySQL 字符串截取函数 left right substring substring index 还有 mid substr 其中 mid substr 等价于 substring 函数 substr
  • linux 新建用户无 .profile 问题

    1 新建一个用户 其家目录下面默认生成什么文件由 etc skel 目录决定 就是 这个目录下面有什么新建用户后家目录就生成什么 2 新建一个用户可以由 d 参数指定家目录 如 useradd d home test u 500 g ora
  • 微信小程序wx.getUserInfo授权获取用户信息(头像、昵称)

    这个接口只能获得一些非敏感信息 例如用户昵称 用户头像 经过用户授权允许获取的情况下即可获得用户信息 至于openid这些 需要调取wx login来获取 index wxml
  • VS2013使用技巧汇总

    1 Peek View 在不新建TAB的情况下快速查看 编辑一个函数的代码 以前要看一个函数的实现 需要在使用的地方点击F12跳转到该函数 实际上这是很浪费时间的 VS2013Peek View便解决了这个问题 在光标移至某个函数下 按下a
  • go-diskqueue数据结构

    一 本文目的是介绍go diskqueue go diskqueue 应用于nsq https github com nsqio nsq 作用是存储内存装不下的消息到磁盘 并支持读取 go diskqueue https github co
  • stm32寄存器

    define RCC APB2ENR volatile unsigned int 0x40021000 0x18 RCC APB2ENR 1 lt lt 3 CRH是控制高八位引脚 CRL是控制低八位引脚 配置推挽输出 GPIOB CRH
  • Layui实现TreeTable(树形数据表格)

    参考 Layui实现TreeTable 树形数据表格 LayUI树形表格treetable使用详解 gitee ele admin treetable lay 文中涉及的treetable js 页面代码都可以在这下载gitee代码下载 直
  • Kubeadm 结合 Vagrant 自动化部署最新版 Kubernetes 集群

    之前写过一篇搭建 k8s 集群的教程 使用 kubeadm 搭建 kubernetes 集群 教程中用到了 kubeadm 和 vagrant 但是整个过程还是手动一步一步完成 创建节点 gt 节点配置 相关软件安装 gt 初始化 mast
  • mpeg gpcc编译

    gpcc 编译 1 编译 mkdir build cd build cmake make 2 生成配置文件 cd cfg 进入配置文件 bash scripts gen cdg sh 执行转换设置格式脚本 遇到 坑1 坑1 若能正常执行脚本
  • pip使用总结(持续更新)

    持续总结python pip遇到过的坑 1 pip镜像源 阿里镜像 临时 pip install xxx i http mirrors aliyun com pypi simple trusted host mirrors aliyun c
  • ms-repeat 循环

    ms repeat 可以写成 ms repeat el 后面的el 相当于给每个节点定义的变量名 还可以定义成ms repeat m避免与上级循环的变量重名 ul class times li a href el year a li ul
  • perp系列之七:perp手册

    perp系列之七 perp手册 版本说明 版本 作者 日期 备注 0 1 ZY 2019 5 29 初稿 目录 文章目录 perp系列之七 perp手册 版本说明 目录 1 该发行版包括以下手册页 perp intro 8 perp set
  • 服务器端安装jupyter notebook并在本地使用与环境配置一条龙服务【服务器上跑ipynb】

    linux服务器端安装jupyter notebook并在本地使用 1 生成配置文件 2 配置Jupyter notebook密码 3 修改配置文件 jupyter jupyter notebook config py 4 本地访问远端的服
  • 微调Hugging Face中图像分类模型

    前言 本文主要针对Hugging Face平台中的图像分类模型 在自己数据集上进行微调 预训练模型为Google的vit base patch16 224模型 模型简介页面 代码运行于kaggle平台上 使用平台免费GPU 型号P100 笔