class balanced loss pytorch 实现

2023-05-16

cb loss pytorch 实现,可直接调用
参考:https://github.com/vandit15/Class-balanced-loss-pytorch/blob/master/class_balanced_loss.py

import numpy as np
import torch
import torch.nn.functional as F



def focal_loss(logits, labels, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
      logits: A float tensor of size [batch, num_classes].
      labels: A float tensor of size [batch, num_classes].
      alpha: A float tensor of size [batch_size]
        specifying per-example weight for balanced cross entropy.
      gamma: A float scalar modulating loss from hard and easy examples.

    Returns:
      focal_loss: A float32 scalar representing normalized total loss.
    """
    bce_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
        modulator = torch.exp(-gamma * labels * logits - gamma * torch.log(1 + torch.exp(-1.0 * logits)))

    loss = modulator * bce_loss

    weighted_loss = alpha * loss
    loss = torch.sum(weighted_loss)
    loss /= torch.sum(labels)
    return loss


class ClassBalancedLoss(torch.nn.Module):
    def __init__(self, samples_per_class=None, beta=0.9999, gamma=0.5, loss_type="focal"):
        super(ClassBalancedLoss, self).__init__()
        if loss_type not in ["focal", "sigmoid", "softmax"]:
            loss_type = "focal"
        if samples_per_class is None:
            num_classes = 5000
            samples_per_class = [1] * num_classes
        effective_num = 1.0 - np.power(beta, samples_per_class)
        weights = (1.0 - beta) / np.array(effective_num)
        self.constant_sum = len(samples_per_class)
        weights = (weights / np.sum(weights) * self.constant_sum).astype(np.float32)
        self.class_weights = weights
        self.beta = beta
        self.gamma = gamma
        self.loss_type = loss_type


    def update(self, samples_per_class):
        if samples_per_class is None:
            return
        effective_num = 1.0 - np.power(self.beta, samples_per_class)
        weights = (1.0 - self.beta) / np.array(effective_num)
        self.constant_sum = len(samples_per_class)
        weights = (weights / np.sum(weights) * self.constant_sum).astype(np.float32)
        self.class_weights = weights



    def forward(self, x, y):
        _, num_classes = x.shape
        labels_one_hot = F.one_hot(y, num_classes).float()
        weights = torch.tensor(self.class_weights, device=x.device).index_select(0, y)
        weights = weights.unsqueeze(1)
        if self.loss_type == "focal":
            cb_loss = focal_loss(x, labels_one_hot, weights, self.gamma)
        elif self.loss_type == "sigmoid":
            cb_loss = F.binary_cross_entropy_with_logits(x, labels_one_hot, weights)
        else:  # softmax
            pred = x.softmax(dim=1)
            cb_loss = F.binary_cross_entropy(pred, labels_one_hot, weights)
        return cb_loss


def test():
    torch.manual_seed(123)
    batch_size = 10
    num_classes = 5
    x = torch.rand(batch_size, num_classes)
    y = torch.randint(0, 5, size=(batch_size,))
    samples_per_class = [1, 2, 3, 4, 5]
    loss_type = "focal"
    loss_fn = ClassBalancedLoss(samples_per_class, loss_type=loss_type)
    loss = loss_fn(x, y)
    print(loss)


if __name__ == '__main__':
    test()

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

class balanced loss pytorch 实现 的相关文章

  • pytorch 的 IDE 自动完成

    我正在使用 Visual Studio 代码 最近尝试了风筝 这两者似乎都没有 pytorch 的自动完成功能 这些工具可以吗 如果没有 有人可以推荐一个可以的编辑器吗 谢谢你 使用Pycharmhttps www jetbrains co
  • Haskell 中的异构多态性(正确方法)

    让一个模块来抽象Area操作 错误的定义 class Area someShapeType where area someShapeType gt Float module utilities sumAreas Area someShape
  • 在其抽象超类中使用子类的泛型类型?

    在我的代码中有以下抽象超类 public abstract class AbstractClass
  • java中类的命名约定 - 全部大写

    在 Java 中 当类全部大写时 如何命名它 例如 如果我想创建一个班级来选择某些人成为 VIP 我应该将类命名为 VIPSelector 还是 VipSelector Thanks 你的两个选择都有效 类的主要目标是让它们以大写字母开头
  • 在 Pytorch 中估计高斯模型的混合

    我实际上想估计一个以高斯混合作为基本分布的归一化流 所以我有点被火炬困住了 但是 您可以通过估计 torch 中高斯模型的混合来在代码中重现我的错误 我的代码如下 import numpy as np import matplotlib p
  • C# 中的类和模块有什么用

    有人可以解释一下类和模块之间的区别吗 你什么时候使用其中一种而不是另一种 我正在使用 C 更新 我的意思是相当于 VB 模块的 C 版本 这在很大程度上取决于您所指的 模块 Visual Basic 的模块 C 中没有真正等效的 VB Ne
  • 保存具有自定义前向功能的 Bert 模型并将其置于 Huggingface 上

    我创建了自己的 BertClassifier 模型 从预训练开始 然后添加由不同层组成的我自己的分类头 微调后 我想使用 model save pretrained 保存模型 但是当我打印它并从预训练上传时 我看不到我的分类器头 代码如下
  • 将 __DIR__ 常量与字符串连接作为数组值,该数组值是 PHP 中的类成员

    谁能告诉我为什么这不起作用 这只是我在其他地方尝试做的事情的一个粗略的例子 stuff array key gt DIR value 但是 这会产生错误 PHP Parse error syntax error unexpected exp
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1
  • 在类中使用 std::chrono::high_resolution_clock 播种 std::mt19937 的正确方法是什么?

    首先 大家好 这是我在这里提出的第一个问题 所以我希望我没有搞砸 在写这篇文章之前我用谷歌搜索了很多 我对编码 C 很陌生 我正在自学 考虑到有人告诉我 只为任何随机引擎播种一次是一个很好的做法 我在这里可能是错的 什么是正确 最佳 更有效
  • 参考接口创建对象

    引用变量可以声明为类类型或接口类型 如果变量声明为接口类型 则它可以引用实现该接口的任何类的任何对象 根据上面的说法我做了一个理解上的代码 正如上面所说声明为接口类型 它可以引用实现该接口的任何类的任何对象 但在我的代码中显示display
  • 删除向量类成员

    我有一个 A 类 其成员是另一个 B 类的对象指针向量 class A std vector
  • C++:获取器和设置器?

    我正在尝试编写一些代码来为以下数据的 ID 号 名字 姓氏 期中成绩和期末成绩创建 getter 和 setter 这些数据位于我正在编写的班级的文本文件中 10601 ANDRES HYUN 88 91 94 94 89 84 94 84
  • 动态创建类 - Python

    我需要动态创建一个类 为了更详细地讲 我需要动态创建 Django 的子类Form class 通过 动态 我打算根据用户提供的配置创建一个类 e g 我想要一个名为CommentForm这应该子类化Form class 该类应该有一个选定
  • 使类只能从特定类实例化

    假设我有 3 节课class1 class2 and class3 我怎样才能拥有它class1只能通过实例化class2 class1 object new class1 但不是 class3 或任何其他类 我认为它应该与修饰符一起使用
  • PHP - 扩展 __construct

    我想知道你是否可以帮助我 我有两个类 一个扩展了另一个 B 类将由各种不同的对象扩展 并用于常见的数据库交互 现在我希望 B 类能够处理其连接和断开连接 而无需来自 A 类或任何外部输入的指示 据我了解 问题是扩展类不会自动运行其 cons
  • 无法从 C# WPF 中的另一个窗口调用方法

    好吧 假设我有两个窗户 在第一个中我有一个方法 public void Test Label Content works 在第二个方法中 我称此方法为 MainWindow mw new MainWindow mw Test 但什么也没发生
  • Delphi - 如果没有创建类,为什么这个函数可以工作?

    考虑这个类 unit Unit2 interface type TTeste class private texto string public function soma a b integer string end implementa
  • 无法在类对象的 ArrayList 中存储值。 (代码已编辑)

    这基本上是一个 Java 代码转换器 它涉及一个 GUI 让用户输入类类型 名称和方法 为了存储值 我创建了一个类VirtualClass与ArrayList
  • 类函数/变量在使用之前是否必须声明?

    所以我在学习课程时偶然发现了一些对我来说相当尴尬的事情 class Nebla public int test printout return x void printout printout2 private int x y void p

随机推荐

  • 阿克曼前轮转向车gazebo模型

    想要一个阿克曼转向结构车的gazebo模型 xff0c 要求能够用ros话题控制前进速度和前轮转角 令人惊讶的是 xff0c 网上基本没有这种模型 racecar模型 首先古月居提供了一个racecar的模型 xff0c 可以控制速度和前轮
  • Jetson TX2 在docker容器中import torch 报错的处理方式

    1 Jetson TX2 信息 xff1a 驱动版本 xff1a JetPack 4 6 1 2 docker信息 xff1a docker 镜像 xff1a pull 了 nvcr io nvidia l4t ml r32 7 1 py3
  • 史上最详细的PID教程——理解PID原理及优化算法

    Matlab动态PID仿真及PID知识梳理 云社区 华为云 huaweicloud com 位置式PID与增量式PID区别浅析 Z小旋 CSDN博客 增量式pid https zhuanlan zhihu com p 38337248 期望
  • JAVA经典试卷(理工)

    一 判断题 xff08 本大题共20小题 xff0c 每小题1分 xff0c 总计20分 xff09 1 xff0e final类能派生子类 2 xff0e 子类要调用父类的方法 xff0c 必须使用super关键字 3 xff0e Jav
  • git与gitee学习笔记

    随着时间推移 xff0c 除去常量 xff0c 任何事物都是在变化的 xff0c 如果用一根曲线表示 xff0c 横轴代表时间 xff0c 纵轴代表事物量 xff0c 那么所绘制的曲线 xff0c 在时间足够长的情况下 xff0c 必然是高
  • docker基础命令操作---镜像操作

    1 搜索官方仓库镜像 xff1a docker search image name 镜像名 例如 xff1a docker search nginx 命令执行结果参数说明 xff1a 参数 说明 NAME 镜像名称 DESCRIPTION
  • ESP8266连接天猫精灵(一)

    背景 接触天猫精灵后 xff0c 就想作一些小东西能接入天猫精灵 查看官网的文档后 xff0c 选择了ESP系列 xff0c 官方在文档中也比较推荐 读技术文档是个很难受的事情 xff0c 容易犯困 xff0c 最好有可以操作的设备 准备如
  • Windows下Boost库的安装与使用

    目录 1 基本介绍 2 下载安装 3 配置boost环境 xff08 VS2010 xff09 4 测试 1 基本介绍 Boost库是为C 43 43 语言标准库提供扩展的一些C 43 43 程序库的总称 xff0c 由Boost社区组织开
  • 嵌入式JetSon TX2上使用RealSense D435 (外加IMU芯片) 运行RTAB-Map与VINS-MONO的全流程记录

    本周成功的在JetSon TX2上移植了Vins Mono与RTAB Map xff0c 并使用摄像头RealSense D435顺利跑通了这两个框架 中间遇到了各种各样神奇的问题 xff0c 踩坑无数 xff0c 现整理记录一下整体流程
  • 微信公众号本地开发调试 - 无公网IP,内网穿透

    文章目录 前言1 配置本地服务器2 内网穿透2 1 下载安装cpolar内网穿透2 2 创建隧道 3 测试公网访问4 固定域名4 1 保留一个二级子域名4 2 配置二级子域名 5 使用固定二级子域名进行微信开发 前言 在微信公众号开发中 x
  • opencv图像通道 8UC1?

    转载自博主 64 马卫飞 https blog csdn net maweifei article details 51221259 CV lt bit depth gt S U F C lt number of channels gt b
  • gazebo中urdf、xacro、sdf模型文件关系

    gazebo的模型是用xml格式的文本文件来描述的 具体有三种形式 xff1a urdf xacro sdf urdf urdf是老的gazebo模型格式 xff0c 本身有一些缺陷 xff0c 也缺一些功能 但是网上很多gazebo模型都
  • 1_树莓派开启ssh服务

    树莓派3 开启 SSH 服务 原文链接 xff1a https blog csdn net qq 16775293 article details 88385393 文章目录 1 使用管理工具2 启动服务3 自动启动服务 3 1 Windo
  • 树莓派4b串口通信配置

    树莓派4b本身是两个串口 xff0c 运行ls dev al如下 xff1a 请注意 xff1a 在默认状态下 xff0c serial0 就是GPIO14 15 是映射到ttyS0的 xff08 就是MINI串口 xff1a dev tt
  • Pandas第三次作业20200907

    练习1 读取北向 csv 指定trade date为行索引 查看数据的基本信息 有无缺失值 对其缺失值进行处理 删除缺失值所在行 查看数据的基本信息 查看数据是否清洗完毕 index列没啥用 将index列删除 观察数据是否有重复行 将重复
  • 新手入门板卡硬件调试

    硬件电路调试步骤 新手入门板卡硬件调试一看 观察焊接情况二测 测量阻抗三接触式上电调试遇到的问题一般解决思路电源供电运放出现震荡测量时GND的选取振铃现象 新手入门板卡硬件调试 一看 观察焊接情况 1 拿到板卡后 xff0c 首先观察下焊接
  • 用shell 命令获取占用cpu 最多的前五位

    通常情况下使用ps axu 来获得系统中所有进程占用资源情况 xff0c 通常也可以使用top 命令来动态的获得系统中资源占用最多的进程 假设我们使用ps aux gt file tmp来获取linux系统中的进程占用资源情况 xff0c
  • 关于准确率accuracy和召回率recall的理解

    假设有100个样本 xff0c 其中正样本70 xff0c 负样本30 xff0c 这个是由数据集本身决定的 xff0c 机器要做的就是判别这100个样本中哪几个样本是正样本 xff0c 哪几个样本是负样本 现在机器做出了预测 xff1a
  • pytorch BERT文本分类保姆级教学

    pytorch BERT文本分类保姆级教学 本文主要依赖的工具为huggingface的transformers xff0c 更详细的解释可以查阅文档 定义模型 模型定义主要是tokenizer config和model的定义 xff0c
  • class balanced loss pytorch 实现

    cb loss pytorch 实现 xff0c 可直接调用 参考 xff1a https github com vandit15 Class balanced loss pytorch blob master class balanced