Hugging Face custom model fine-tuning, training and testing -- multi-task BERT

2023-11-16

Background:

We need to turn BERT into a multi-task model, but the official implementation only supports multi-class and binary classification, not multiple tasks. Making it multi-task means modifying the output layer, the loss, the evaluation, and so on. If you need to append an extra fc layer (or similar) to the end of BERT, you can follow the same approach.

Code

Modifying the model

Here BertForSequenceClassification is turned into a multi-task model.

import torch
import torch.nn as nn
from typing import List, Optional, Tuple, Union
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import add_start_docstrings_to_model_forward, add_code_sample_docstrings, add_start_docstrings

_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION = "textattack/bert-base-uncased-yelp-polarity"
_CONFIG_FOR_DOC = "BertConfig"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
_SEQ_CLASS_EXPECTED_LOSS = 0.01
BERT_START_DOCSTRING = r"""

    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`BertConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
BERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    """
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """,
    BERT_START_DOCSTRING,
)
class BertForSequenceClassification_Multitask(BertPreTrainedModel):
    def __init__(self, config, task_output_dims):
        super().__init__(config)
        self.task_output_dims = task_output_dims

        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        # one linear head per task; use config.hidden_size rather than a hard-coded 768
        # so the model also works with non-base checkpoints
        self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, output_dim) for output_dim in task_output_dims])
        # keep the stock single-task head so the non-multi-task branches in
        # forward(), which reference self.classifier, still work
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_SEQUENCE_CLASSIFICATION,
        output_type=SequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
    )
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy). For
            `problem_type == "multi_task_classification"`, pass a `(batch_size, num_tasks)` tensor holding one class
            index per task.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        # problem_type should already be set to "multi_task_classification"
        # (e.g. via from_pretrained) before the first forward pass; otherwise
        # the stock single-head branch below is used
        if self.config.problem_type == "multi_task_classification":
            # one logits tensor per task: a list of (batch_size, task_output_dim) tensors
            logits = [classifier(pooled_output) for classifier in self.classifiers]
        else:
            logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif labels.dim() > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    # multi-task labels come as a (batch_size, num_tasks) integer
                    # tensor, one class index per task
                    self.config.problem_type = "multi_task_classification"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
            elif self.config.problem_type == "multi_task_classification":
                loss_fct = CrossEntropyLoss()
                loss_list=[loss_fct(logits[i],labels[:,i]) for i in range(len(self.task_output_dims))]
                loss=torch.sum(torch.stack(loss_list))
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# When calling the model:
# the original call was
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path, num_labels=2, hidden_dropout_prob=dropout)
# it now becomes
model = BertForSequenceClassification_Multitask.from_pretrained(pretrained_model_name_or_path, num_labels=len(pjwk_cates), hidden_dropout_prob=dropout, task_output_dims=[6, 63], problem_type="multi_task_classification")
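
To sanity-check the multi-task head, something like the following should work. This is a minimal sketch, not from the original post: the example sentences, the label values, and the reuse of pretrained_model_name_or_path are illustrative assumptions.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path)
batch = tokenizer(["example sentence one", "example sentence two"], padding=True, return_tensors="pt")

# multi-task labels: shape (batch_size, num_tasks); column i holds the class
# index for task i, matching the labels[:, i] indexing in forward()
labels = torch.tensor([[0, 10], [5, 62]])  # legal values for task_output_dims=[6, 63]

outputs = model(**batch, labels=labels)
print(outputs.loss)  # summed CrossEntropyLoss over the two tasks
# for evaluation, take an argmax per task head
preds = [task_logits.argmax(dim=-1) for task_logits in outputs.logits]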

Loading the model at test time

At test time, when loading a checkpoint, the library's config validation does not include problem_type = "multi_task_classification", so it needs to be added. You can wait for the error and add it at the reported location. In my installation the file is /home/anaconda3/envs/bert/lib/python3.8/site-packages/transformers/configuration_utils.py, line 347.

# add "multi_task_classification" to the whitelist
allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification", "multi_task_classification")
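
If you would rather not edit the installed transformers source, one possible workaround (my sketch, not from the original post) is to load without problem_type and assign it on the config afterwards; as far as I can tell the allowed_problem_types check only runs inside PretrainedConfig.__init__, so a plain attribute assignment is not validated:

# sketch: avoid patching configuration_utils.py by setting problem_type
# after the model has been constructed
model = BertForSequenceClassification_Multitask.from_pretrained(
    pretrained_model_name_or_path,
    num_labels=len(pjwk_cates),
    hidden_dropout_prob=dropout,
    task_output_dims=[6, 63],
)
model.config.problem_type = "multi_task_classification"

Note that once a fine-tuned checkpoint has been saved, its config.json contains the custom problem_type and will fail the __init__ validation again on load, so at test time you either keep the whitelist patch above or remove the problem_type entry from the checkpoint's config.json and set it in code as shown.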