如何使用 Huggingface Trainer 微调 gpt-j

2024-01-08

我正在尝试使用 Huggingface 训练器微调 gpt-j 但惨败。我遵循了引用 bert 的示例,但是当然,gpt-j 模型并不完全类似于 bert 模型。

该错误表明模型没有产生损失,这很好,但我不知道如何让它产生损失或如何改变训练者的期望。

我正在使用变形金刚 4.22.2。在尝试使用 GPU 在 Paperspace 上执行任何操作之前,我希望先在 CPU 上实现此功能。我确实使用 GPU 进行了初步尝试,但收到了相同的错误,使用 cuda 的代码略有不同。

我怀疑我的做法是完全错误的。我发现了一个使用 8 位量化微调 gpt-j 的非常古老的示例,但即使该存储库也表示它已被弃用。

我不确定我的错误是否在于使用了 bert 示例中找到的compute_metrics(),或者是否是其他原因。任何意见,将不胜感激。或者,也许这是我提供配置的标签的问题,但我尝试了不同的排列。

我了解损失函数是什么,但我不知道在这种情况下应该如何配置它。

My Code:

from transformers import Trainer, TrainingArguments, AutoModelForCausalLM
from transformers import GPTJForCausalLM, AutoTokenizer
from datasets import load_dataset
import time
import torch
import os
import numpy as np
import evaluate
import sklearn

start = time.time()

GPTJ_FINE_TUNED_FILE = "./fine_tuned_models/gpt-j-6B"

print("Loading model")
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", low_cpu_mem_usage=True)
model.config.pad_token_id = model.config.eos_token_id

print("Loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token

print("Loading dataset")
current_dataset = load_dataset("wikitext", 'wikitext-103-v1')
current_dataset['train'] = current_dataset['train'].select(range(1200))


def tokenize_function(examples):
    current_tokenizer_result = tokenizer(examples["text"], padding="max_length", truncation=True)
    return current_tokenizer_result


print("Splitting and tokenizing dataset")
tokenized_datasets = current_dataset.map(tokenize_function, batched=True)
small_train_dataset = tokenized_datasets["train"].select(range(100))

print("Preparing training arguments")

training_args = TrainingArguments(output_dir=GPTJ_FINE_TUNED_FILE,
                                  report_to='all',
                                  logging_dir='./logs',
                                  per_device_train_batch_size=1,
                                  label_names=['input_ids', 'attention_mask'],  # 'logits', 'past_key_values'
                                  num_train_epochs=1,
                                  no_cuda=True
                                  )

metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset
)

print("Starting training")
trainer.train()
print(f"Finished fine-tuning in {time.time() - start}")

这会导致错误和堆栈跟踪:

  File "xxx\ft_v3.py", line 66, in <module>
  File "xxx\venv\lib\site-packages\transformers\trainer.py", line 1521, in train
    return inner_training_loop(
  File "xxx\venv\lib\site-packages\transformers\trainer.py", line 1763, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "xxx\venv\lib\site-packages\transformers\trainer.py", line 2499, in training_step
    loss = self.compute_loss(model, inputs)
  File "xxx\venv\lib\site-packages\transformers\trainer.py", line 2544, in compute_loss
    raise ValueError(
ValueError: The model did not return a loss from the inputs, only the following keys: logits,past_key_values. For reference, the inputs it received are input_ids,attention_mask.

我找到了似乎有效的方法,尽管现在我的内存不足并正在研究处理它的方法。

data_collat​​or 参数似乎解决了我遇到的确切问题。

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
    data_collator=data_collator,
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用 Huggingface Trainer 微调 gpt-j 的相关文章

随机推荐

  • Google Cloud SQL 上的 1290 错误

    我今天在使用 MySQL Workbench 时才开始收到此错误 并注意到它早在周六就出现在我的应用程序中 还有人收到吗 知道可能是什么原因造成的吗 ERROR错误代码 1290 MySQL 服务器正在使用 read only 选项运行 因
  • Microsoft Visual Studio 2017 依赖于每个 Xamarin 操作

    我最近安装了 Visual Studio 2017 当我创建 Xamarin android 项目或单击 xamarin 设置时 它会挂起 当单击任意位置时 它会显示 Microsoft Visual Studio 正忙 Xamarin 版
  • 如何将 setOnFocusChangeListener 与 RecyclerView 结合使用?

    我在 RecyclerView 的适配器类中的 onBindViewHolder 上有以下内容 holder answerEditText setOnFocusChangeListener new View OnFocusChangeLis
  • 在 VS2010 中将设计器与 WPF 的 XAML 窗口分离

    我在 Visual Studio 2010 中没有看到用于将设计器窗口与 WPF 的 XAML 窗口分开的按钮 我有三个屏幕 我想要一个全屏 XAML 窗口和一个全屏设计器窗口 像往常一样打开 XAML 文件 在解决方案资源管理器中 右键单
  • 我如何打开不同的linux终端以在python中输出不同类型的调试信息?

    我需要将不同的信息输出到不同的终端实例 而不是在同一输出流中打印它们 例如 std err 或 std out 例如 我有 5 种信息说 A E 需要显示在同一桌面上的不同终端窗口上 看起来像 终端1 终端2 端子3 端子4 端子5 我知道
  • 用前导 0 填充计数器到 9,然后用 php 删除前导零

    尝试用前导 0 填充最多 9 个 然后删除 01 02 03 04 05 06 07 08 09 10 11 12 14 到目前为止我有这个 您还可以使用str pad http us php net manual en function
  • Firefox 和 Opera 中的 Webfont 平滑和抗锯齿

    我的网站上使用了定制的网络字体 为了设置渲染输出的样式 我使用了以下代码 webkit text stroke width 05px webkit text stroke color white webkit font smoothing
  • 如何只加载某些层的权重?

    我想获取某些层的权重 不是全部 因为架构不同 来自model trained并初始化model untrained用它 我怎样才能用 Keras 做到这一点 如果你有一个函数create model 它返回一个 Keras 模型 examp
  • 如何将QMainWindow设置为模态窗口?

    我正在使用 QMainWindow 进行项目的 GUI 开发 我遇到的一个问题是当一个窗口正在运行时阻止所有其他可见窗口获取输入 我不能使用QDialog 因为需要QMainWindow的丰富功能 如何将特定窗口声明为模态窗口 我尝试过QW
  • BeanPostProcessor 混乱

    我试图理解 Spring 中的 BeanPostProcessor 但不明白它的作用 BeanPostProcessor 定义了在这些点调用的两个方法是否正确 在初始化之前 init 方法或 afterPropertiesSet 但实例已创
  • Camel Splitter并行处理数组列表-并发访问问题

    使用 Camel 拆分 ArrayList 并最多 10 个线程并行处理每个项目 以下是配置 线程池配置文件设置为最大线程数 10
  • 计算闰年的Java代码

    我正在关注 Java 的艺术与科学 一书 它展示了如何计算闰年 本书使用了ACM Java Task Force 的库 这是本书使用的代码 import acm program public class LeapYear extends C
  • C# Thread.Sleep(0) 是什么意思?

    意思是没有延迟吗 一本书上说如下 Thread Sleep 0 放弃线程的当前时间片 立即 主动将CPU交给其他线程 这是否意味着即使应该执行一条语句 给 sleep 0 也会暂时跳过执行 0表示没有minimum控制权将返回给线程之前的时
  • 使用Python列表作为队列的效率

    一位同事最近编写了一个程序 其中使用 Python 列表作为队列 换句话说 他用了 append x 当需要插入物品时 pop 0 当需要移除物品时 我知道Python有collections deque http docs python
  • 如何让 TProgressBar 停止滞后?

    我有一个运行大量操作的应用程序 并且我正在尝试使用 TProgressBar 来跟踪正在发生的情况 我设置了多个步骤 并调用 StepIt 来增加进度条 问题是 它并没有很好地跟上 它似乎不喜欢直接跳到正确的位置 而是逐渐滑动到正确的位置
  • 如何获取调用别名方法的名称?

    我正在调用方法link to admin然后我给另一个方法起了别名simple link to def link to admin name url options My stuff here link to name url option
  • 使 Http DefaultClient 的execute()非常慢

    我的 HttpDefaultClient 的 execute 方法的执行方法存在大量性能问题 我目前正在使用它来将数据发布到服务器 接收 JSON 并反序列化数据 我的手机打电话需要 8 到 30 秒 如果我切换到 Wifi 速度相当快 在
  • Wix:安装过程中忽略对话框中的属性更改

    我在 Wix 文件中有一个属性 该属性公开用于在对话框中进行编辑 在下面的示例中 它是 MyProperty 该属性用于创建注册表项 但是 如果在对话框中更改属性 则不会使用更改后的值 而是默认值 SomeProperty 但是 如果我在另
  • Python 多处理问题?

    我有一个包含 500 个输入文件的文件夹 所有文件的总大小约为 500 MB 我想写一个python执行以下操作的脚本 1 将所有输入文件加载到内存中 2 初始化一个空的python稍后将使用的列表 参见项目符号 4 3 启动 15 个不同
  • 如何使用 Huggingface Trainer 微调 gpt-j

    我正在尝试使用 Huggingface 训练器微调 gpt j 但惨败 我遵循了引用 bert 的示例 但是当然 gpt j 模型并不完全类似于 bert 模型 该错误表明模型没有产生损失 这很好 但我不知道如何让它产生损失或如何改变训练者