Huggingface-4.8.2自定义训练

2023-11-05

Huggingface走到4.8.2这个版本,已经有了很好的封装。训练一个语言网络只需要调用Trainer.train(...)即可完成。如果要根据自己的需求修改训练的过程,比如自定义loss,输出梯度,直接修改huggingface的源码显然是不可取的了。好在huggingface提供了相应的接口,让我们可以深入到训练过程中,加入自定义的内容。根据官方的教程,有两种推荐的方法:

  1. 重载trainer中的方法,将其修改为我们需要的内容。比如trainer.compute_loss()这个函数,它定义了如何计算loss,我们只需要修改其中的逻辑,就可以自定义loss的计算。
  2. 使用callbacks。callbacks可以查看训练过程中一些关键变量的值,并根据其状态做出相应的决策,比如early stop。

关于trainer和callbacks这两个的官方文档分别是这里这里,这两个方法都可以很优雅地修改原有的逻辑。但个人感觉重载trainer的方法是一种更灵活也更强大的方法。callbacks其实只能查看提供的一些变量,并且也只是查看,不能做出修改。而重载方法可以定义任意的全新的函数。接下来给出这两种方法的两个例子。

重载方法

在官方给的教程中是一个重载compute loss的例子,这里给一个不一样的,定义trainging_step的例子,代码如下:

class PrintGradientTrainer(Trainer):

    def training_step(self, model, inputs):
        model.train()
        inputs = self._prepare_inputs(inputs)

        loss = self.compute_loss(model, inputs)

        loss.backward()
        
        # ------------------------new added codes.--------------------------
        for name, param in model.named_parameters():
            if param.requires_grad:
                if param.grad is not None:
                    print("{}, gradient: {}".format(name, param.grad.mean()))
                else:
                    print("{} has not gradient".format(name))
        # ------------------------new added codes.--------------------------
        return loss.detach()

# originally the Trainer() is called
#trainer = Trainer(
#    model=model, args=training_args, train_dataset=small_train_dataset, #eval_dataset=small_eval_dataset,
#    tokenizer=tokenizer, data_collator=data_collator
#)

# Now call the new defined PrintGradientTrainer()
trainer = PrintGradientTrainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,
    tokenizer=tokenizer, data_collator=data_collator
)

trainer.train()

只给出了关键部分的代码,其他的就按照正常写即可。

Callbacks

这个方法也需要定义一个原本的TrainerCallback的子类,然后重载原有的空的callbacks方法。代码实例如下,这个例子打出了现在是第几个epoch。

class MyCallback(TrainerCallback):
    def on_step_begin(self, args, state, control, **kwargs):
        print("train step start")
        control.should_log = False
        control.should_evaluate = False
        control.should_save = False
        print('---------------------------------------',state.epoch)
        # return self.call_event("on_step_begin", args, state, control)
trainer = PrintGradientTrainer(
    model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset,
    tokenizer=tokenizer, data_collator=data_collator,callbacks=[MyCallback()]
)

在定义trainer的时候,给callbacks加入自己定义的类就可以了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Huggingface-4.8.2自定义训练 的相关文章

随机推荐

  • Open mv识别三角形的办法

    文章目录 前言 带着问题来看 一 函数 二 使用方法 1 find line segments 2 img find template 三 摄像情况及终端结果 1 find line segments 2 img find template
  • 初始C语言——利用Ascll码进行字母大小写转换

    打开Ascll码表 你会发现大写字母和小写字母之间存在这样的关系 图片来自 https img blog csdnimg cn 54404234b42348d6a33bc1c4d5ab24e5 png 小写字母的值始终比大写字母多32 de
  • Node.js

    Node js Node js基础 概念 简单的说 Node js 就是运行在服务端的 JavaScript Node js 是一个基于Chrome JavaScript 运行时建立的一个平台 Node js是一个事件驱动I O服务端Jav
  • (五)决策树

    一 决策树 决策树是监督学习算法 下面为一些样本 本质上是一种特征去结果的相关度 比如你的信贷情况与能否还贷的相关度肯定高 而你有没有结婚的相关度肯定低 二 信息增益 三 ID3算法
  • php 未支付取消订单,【php】用户提交订单,30分钟后没付款取消订单功能分析

    我先在要做这样的功能 用户在创建订单后 订单表中记入的是未付款状态 如果用户在30分钟后 还未付款 然后就把该订单给取消 关于用户创建订单 30分钟后还没付款 取消该订单的逻辑是怎么实现的 我自己的想了两个方案 1 客户端记入这个订单 如果
  • MindNode 5 for Mac(思维导图软件)中文版

    绘制流程图 思维导图 规划图 信息图等自然少不了这款MindNode 5 for Mac 作为优质的思维导图软件 mindnode5 mac破解版的功能很全面 添加文字 链接 图片 扩展注释等非常便捷 而且mindnode 5 破解版会智能
  • Rocketmq原理&最佳实践

    一 MQ背景 选型 消息队列作为高并发系统的核心组件之一 能够帮助业务系统解构提升开发效率和系统稳定性 主要具有以下优势 削峰填谷 主要解决瞬时写压力大于应用服务能力导致消息丢失 系统奔溃等问题 系统解耦 解决不同重要程度 不同能力级别系统
  • Python开发篇——基于React-Dropzone开发上传组件

    这次我要讲述的是在React Flask框架上开发上传组件的技巧 我目前主要以React开发前端 在这个过程中认识到了许多有趣的前端UI框架 React Bootstrap Ant Design Material UI Bulma等 而比较
  • Linux操作系统知识点总结

    1 什么是Linux系统 Linux 全称GNU Linux 是一种免费使用和自由传播的类UNIX操作系统 其内核由林纳斯 本纳第克特 托瓦兹 Linus Benedict Torvalds 于1991年10月5日首次发布 它主要受到Min
  • Qt 实现自定义Ui控件例子,以自定义的Slider为例(QWidget)

    说明 Qt可以比较方便地实现自定义控件在Qt Creator中使用 网上也有很多大神的控件可以使用 但是如果想要自己简单定制也可以按照这个流程 本文的要点 1 如何实现一个自定义控件 本文使用的方法有两个步骤 先在一个普通项目中实现使用 新
  • FreeRTOS学习笔记(3、信号量、互斥量的使用)

    FreeRTOS学习笔记 3 信号量 互斥量的使用 前言 往期学习笔记链接 学习工程 信号量 semaphore 两种信号量的对比 信号量的使用 1 创建信号量 2 give 3 take 4 删除信号量 使用计数型信号量实现同步功能 使用
  • zookeeper结构和命令

    zookeeper特性 1 Zookeeper 一个leader 多个follower组成的集群 2 全局数据一致 每个server保存一份相同的数据副本 client无论连接到哪个server 数据都是一致的 3 分布式读写 更新请求转发
  • 选择、插入、归并、希尔、快速排序算法性能比较总结

    1 概述 本文对比较常用且比较高效的排序算法进行了总结和解析 并贴出了比较精简的实现代码 包括选择排序 插入排序 归并排序 希尔排序 快速排序等 算法性能比较如下图所示 2 选择排序 选择排序的第一趟处理是从数据序列所有n个数据中选择一个最
  • MyBatis-扩展-PageHelpler分页插件使用

    PageHelper是MyBatis中非常方便的第三方分页插件 官方文档 https github com pagehelper MybatisPageHelper blob master README zh md 我们可以对照官方文档的说
  • tomcat的日志记录有哪些?

    Tomcat 是一个常用的 Java Web 服务器 它可以生成各种类型的日志记录 以下是 Tomcat 的一些常见日志记录 访问日志 Access Logs 记录所有进入 Tomcat 服务器的 HTTP 请求 这些日志包含有关请求的详细
  • jdk源码调试显示变量

    原文地址 http my oschina net xionghui blog 497361 Java是一门开源的程序设计语言 喜欢研究源码的java开发者总会忍不住debug一下jdk源码 虽然官方的jdk自带了源码包src zip 然而在
  • LeetCode 面试题01.01. 判定字符是否唯一的两种解法

    本文唯一重点 按位取与的运算优先级比较低 至少比 和 都低 注意加括号 题目概述 题解 一 哈希表 思路是简单的 用第一个下标做字符 第二个下标做字符出现的次数 先遍历一遍字符串 把次数都统计好 然后再遍历一遍字符串 如果查询到某个字符的c
  • 解决中文乱码问起

    Java对数据库进行CRUD操作出现乱码 先查看web xml有没有配置字符编码过滤器
  • es6把多个class方法合并在一起

    前言 es6新增的class方法 现在想把他们多个合并到一起 最终生成一个新方法出来 思路 我们新建3个文件 分别为index js login js main js login js 和 main js是两个 class函数 将他们合并到
  • Huggingface-4.8.2自定义训练

    Huggingface走到4 8 2这个版本 已经有了很好的封装 训练一个语言网络只需要调用Trainer train 即可完成 如果要根据自己的需求修改训练的过程 比如自定义loss 输出梯度 直接修改huggingface的源码显然是不