Kaggle Feedback Prize 3比赛总结:针对层级的训练策略

2023-05-16

Last Layers Re-initialization

我们不使用所有层的预训练权重,而是使用原始的Transformer初始化来重新初始化指定的层数。重新初始化的层会破坏这些特定块的预训练知识。我们知道较低的预训练层学习更多的全局一般特征,而靠近输出的较高的层则更专注于预训练任务。因此初始化较高的层,并重新训练能够让网络更好的学习当前特定的任务。下面的例子是初始化 roberta 最后两层。

from transformers import AutoConfig
from transformers import AutoModelForSequenceClassification


reinit_layers = 2
_model_type = 'roberta'
_pretrained_model = 'roberta-base'
config = AutoConfig.from_pretrained(_pretrained_model)
model = AutoModelForSequenceClassification.from_pretrained(_pretrained_model)

if reinit_layers > 0:
    print(f'Reinitializing Last {reinit_layers} Layers ...')
    encoder_temp = getattr(model, _model_type)
    for layer in encoder_temp.encoder.layer[-reinit_layers:]:
        for module in layer.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=config.initializer_range)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
    print('Done.!')

LLRD - Layerwise Learning Rate Decay

LLRD是一种对顶层采用较高学习率,对底层采用较低学习率的方法。这是通过设置顶层的学习率和使用乘法衰减率从上到下逐层降低学习率来实现的。这是因为在网络的底层一般学习的是全局的一般信息,所以预训练的权重效果会更好,而对于网络的高层是针对特定任务的权重,因此需要较高的学习率来加速更新学习。代码实现如下。

def get_optimizer_grouped_parameters(
    model, model_type, 
    learning_rate, weight_decay, 
    layerwise_learning_rate_decay
):
    no_decay = ["bias", "LayerNorm.weight"]
    # initialize lr for task specific layer
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if "classifier" in n or "pooler" in n],
            "weight_decay": 0.0,
            "lr": learning_rate,
        },
    ]
    # initialize lrs for every layer
    num_layers = model.config.num_hidden_layers
    layers = [getattr(model, model_type).embeddings] + list(getattr(model, model_type).encoder.layer)
    layers.reverse()
    lr = learning_rate
    for layer in layers:
        optimizer_grouped_parameters += [
            {
                "params": [p for n, p in layer.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": weight_decay,
                "lr": lr,
            },
            {
                "params": [p for n, p in layer.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                "lr": lr,
            },
        ]
        # 学习率衰减
        lr *= layerwise_learning_rate_decay
    return optimizer_grouped_parameters

然后将定义好的参数,放入优化器中:

grouped_optimizer_params = get_optimizer_grouped_parameters(
    model, _model_type, 
    learning_rate, weight_decay, 
    layerwise_learning_rate_decay
)
optimizer = AdamW(
    grouped_optimizer_params,
    lr=learning_rate,
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Kaggle Feedback Prize 3比赛总结:针对层级的训练策略 的相关文章

  • 在Kaggle手写数字数据集上使用Spark MLlib的朴素贝叶斯模型进行手写数字识别

    昨天我在Kaggle上下载了一份用于手写数字识别的数据集 xff0c 想通过最近学习到的一些方法来训练一个模型进行手写数字识别 这些数据集是从28 28像素大小的手写数字灰度图像中得来 xff0c 其中训练数据第一个元素是具体的手写数字 x
  • Feedback Network for Image Super-Resolution(SRFBN)---翻译

    attention xff1a 只详细翻译了重点部分 摘要 图像超分辨率 xff08 SR xff09 的最新进展展现了深度学习的力量 xff0c 可以实现更好的重建性能 然而 xff0c 现有的基于深度学习的图像SR方法尚未充分利用人类视
  • 【CVPR2019】超分辨率文章,SRFBN: Feedback Network for Image Super-Resoluition

    论文地址 代码 CVPR的单图像超分辨率文章 xff0c 主要是用回传机制来提高超分辨率的效果 xff0c 且不引入过多的参数 主要是设计了一个feedback模块 xff0c 多次回传 xff0c 如下图所示 xff1a 上一次feedb
  • kaggle邮箱不能验证+安装python的Speedml库

    注册kaggle账号遇到一些问题 下面是具体问题和解决方案 希望遇到同样问题的小伙伴不要再踩到坑啦 1kaggle邮箱不能验证You did not enter the correct captcha response Please try
  • Kaggle平台持续运行项目最多9小时的解决方法

    在Kaggle平台运行自己的项目经常遇到9小时就中断的问题 很多时候到9小时项目并没有跑完 导致前面的时间都浪费了 没能能到最终结果 有一个解决方案是分开运行项目 如果一共需要跑200轮 则拆分成两次跑 一次100轮 这100轮要保证能够在
  • Kaggle 数据竞赛

    文章目录 一 前言 二 主要内容 1 评估 2 时间线 3 奖金 4 代码要求 三 总结 CSDN 叶庭云 https yetingyun blog csdn net 一 前言 使用机器学习技术 通过匿名健康特征的测量数据来检测疾病 比赛目
  • X和Y为不同函数关系时pd.corr()的输出结果

    from pandas import DataFrame Series import pandas as pd import numpy as np import math 当X和Y为log 关系时 python x Series np a
  • json文件中数据类别个数统计与类别信息可视化

    将json文件保存的数据信息利用URL下载数据以后 希望将统计出数据集中每一类图片个数 且进行可视化 看数据分布是否均匀 然后在进行相应的操作 数据还是kaggle比赛中提供的数据集 json文件内容如下 python实现上述要求 导入相应
  • kaggle通过API下载数据集主要事项及指定路径保存

    每次下载新的数据集都需要重新操作接受规则 Rules gt 下载 json 文件 gt 将新的 json 文件放入到 kaggle 文件夹中 否则下载时 会出现错误 更改默认下载地址 kaggle config set n path v l
  • Kaggle 数据集导入 Jupyter Notebook

    我正在尝试将一些数据从 kaggle 导入到笔记本中 我收到的错误是 401 未经授权 但我已接受比赛规则并且能够下载数据 这是我正在运行的代码 from kaggle api kaggle api extended import Kagg
  • 是否有相当于 iOS 推送通知反馈服务的 Android GCM?

    我们的网络应用程序向 iOS 和 Android 设备发送推送通知请求 对于 iOS Apple 推送通知服务具有反馈服务 因此您可以检测哪些设备已卸载您的应用程序 然后将其从数据库中删除 Android GCM 有类似的反馈服务吗 如果没
  • train.default(x, y, Weights = w, ...) 中的错误:无法确定最终调整参数

    我对机器学习非常陌生 正在尝试Kaggle 上的森林覆盖预测竞赛 但我很早就挂断了 当我运行下面的代码时 出现以下错误 Error in train default x y weights w final tuning parameters
  • 在 Google Colab 中导入 Cats-vs-Dogs 数据集时出错

    尝试使用以下命令下载 Cats vs Dogs TensorFlow 数据集时tfds模块 我收到以下错误 DownloadError Traceback most recent call last
  • Feedback.js 服务器 API

    反馈 js http experiments hertzen com jsfeedback 是一个很棒的 jquery 插件 允许您创建反馈表单 其中包括在客户端浏览器上创建的屏幕截图以及表单 如何将捕获的图像和用户的评论发送到服务器端 a
  • 在jshell中创建自定义反馈模式

    从 jshell 中 set Feedback 的文档来看 有以下几种内置模式 verbose normal concise and silent 是否可以打造一种兼具简洁和静音功能的反馈模式 或者我们可以改变上述任何一种模式吗 或者我们可
  • 在 R 中下载 Kaggle zip 文件

    我正在尝试直接从 R 代码本身的 Kaggle 空间下载 zip 文件 不幸的是 它的效果并不好 这是发生的事情 对于旧金山犯罪数据集 请访问https www kaggle com c sf crime data https www ka
  • AttributeError:“Simple_Imputer”对象在 PyCaret 中没有属性“fill_value_categorical”

    我正在使用 PyCaret 并收到错误 AttributeError Simple Imputer object has no attribute fill value categorical 尝试创建一个基本实例 pip install
  • 如何删除 Apple APNS 反馈收到的设备令牌

    我成功通过 PHP 获取 Apple APNS 反馈数据 我得到的结构 经过一些处理 看起来像这样 时间戳 设备令牌 我的问题是如何知道应该从数据库中删除哪些设备令牌并停止向它们发送通知 Regardz Mladjo 时间戳是这里的关键元素
  • xgboost中的eval_metric和feval有什么区别?

    有什么区别feval and eval metric在xgb train中 这两个参数仅用于评估目的 Kaggle 的帖子提供了一些见解 https www kaggle com c prudential life insurance as
  • *Python 内的 Kaggle API 文档?

    我想写一个python从 Kaggle com 下载公共数据集的脚本 Kaggle API 是用 python 编写的 但是我能找到的几乎所有文档和资源都是关于如何在命令行中使用该 API 的 而关于如何使用kaggle图书馆内python

随机推荐