自定义实现nn.CrossEntropyLoss损失函数

2023-11-04

nn.CrossEntropyLoss是在PyTorch中常用的交叉熵损失函数。它主要用于解决多分类问题,但也可以用于解决二分类问题。该函数有两个输入参数。第一个参数是网络的最后一层的输出,是一个二维数组,其中每个向量包含不同类别的概率值。第二个参数是传入的标签,即某个类别的索引值,表示样本对应的真实类别。

在有时需要修改模型时,可能需要修改损失函数的运算过程。直接改变PyTorch中的代码可能会有些麻烦,因此我们可以先按照nn.CrossEntropyLoss的原理自己实现一个版本,然后在此基础上进行修改。

关于nn.CrossEntropyLoss的原理,这篇博客已经有详细介绍,因此在此不再赘述。简单来说,该函数先进行softmax操作,然后再进行log操作。最后,对每个样本的标签处的预测值求和,取平均值,再取绝对值变为一个正数。以下是博客链接:
https://blog.csdn.net/Lucinda6/article/details/116162198

但是按照这篇文章实现时会遇到了一些问题,有时候先sofmax再log时,有时会softmax出来0导致log到inf.直接用LogSoftmax是可以的,按这个实现时的代码如下:

def selfCrossEntropyLoss(selfx,target):
    ls = nn.LogSoftmax(dim=1)
    selfloss=ls(selfx).double()
    selfloss = abs(torch.sum(selfloss) / len(selfx))
    return selfloss

但是还是有问题,sum和len算多维数据时会算不全,只算了一个维度,比如shape是(100,61)时,len只会算到100,最后改了这个后,又和torch的原交叉熵比会有一点小差别,比如71和65的差别,大差不差只能是。

为了在不同维度的张量上使用,后来将nn.LogSoftmax(dim=1)修改为nn.LogSoftmax(dim=-1),在计算平均log_softmax值时,将abs()函数改为torch.mean()函数,能提高精度。最后,用了gather()函数和unsqueeze()函数对预测值进行索引,并使用squeeze()函数去除不必要的维度,以计算损失值。这样算下来结果是和torch的原交叉熵一样的,代码如下:

def selfCrossEntropyLoss2(selfx, target):
    ls = nn.LogSoftmax(dim=-1)
    selfloss = ls(selfx).double()
    selfloss = torch.mean(selfloss.gather(1, target.unsqueeze(1)).squeeze()) * -1
    return selfloss

ok这样就可以在nn.CrossEntropyLoss的基础上添加自己的其他魔改了。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

自定义实现nn.CrossEntropyLoss损失函数 的相关文章

  • PyTorch:tensor.cuda()和tensor.to(torch.device(“cuda:0”))之间有什么区别?

    在 PyTorch 中 以下两种将张量 或模型 发送到 GPU 的方法有什么区别 Setup X np array 1 3 2 3 2 3 5 6 1 2 3 4 X model X torch DoubleTensor X Method
  • PoseWarping:如何矢量化此 for 循环(z 缓冲区)

    我正在尝试使用地面真实深度图 姿势信息和相机矩阵将帧从视图 1 扭曲到视图 2 我已经能够删除大部分 for 循环并将其矢量化 除了一个 for 循环 扭曲时 由于遮挡 视图 1 中的多个像素可能会映射到视图 2 中的单个位置 在这种情况下
  • pytorch通过易失性变量反向传播错误

    我试图通过多次向后传递迭代来运行它并在每个步骤更新输入 从而最小化相对于某个目标的一些输入 第一遍运行成功 但在第二遍时出现以下错误 RuntimeError element 0 of variables tuple is volatile
  • 二维数组的按行 numpy.isin [重复]

    这个问题在这里已经有答案了 我有两个数组 A np array 3 1 4 1 1 4 B np array 0 1 5 2 4 5 2 3 5 是否可以使用numpy isin二维数组按行排列 我想检查一下是否A i j is in B
  • 使用 pytorch 获取可用 GPU 内存总量

    我正在使用 google colab 免费 Gpu 进行实验 并想知道有多少 GPU 内存可供使用 torch cuda memory allocated 返回当前占用的 GPU 内存 但我们如何使用 PyTorch 确定总可用内存 PyT
  • 预训练 Transformer 模型的配置更改

    我正在尝试为重整变压器实现一个分类头 分类头工作正常 但是当我尝试更改配置参数之一 config axis pos shape 即模型的序列长度参数时 它会抛出错误 Reformer embeddings position embeddin
  • 如何在 google colab 中运行 matlab .m 文件

    我目前正在尝试运行这个存储库https github com Fanziapril mvfnet https github com Fanziapril mvfnet这需要一个步骤 Run the Matlab ModelGeneratio
  • pytorch grad 在 .backward() 之后为 None

    我刚刚安装火炬 1 0 0 on Python 3 7 2 macOS 并尝试tutorial https pytorch org tutorials beginner blitz autograd tutorial html sphx g
  • torchvision.transforms.Normalize 是如何操作的?

    我不明白如何标准化Pytorch works 我想将平均值设置为0和标准差1跨越张量中的所有列x形状的 2 2 3 一个简单的例子 gt gt gt x torch tensor 1 2 3 4 5 6 7 8 9 10 11 12 gt
  • 在 PyTorch 中原生测量多类分类的 F1 分数

    我正在尝试在 PyTorch 中本地实现宏 F1 分数 F measure 而不是使用已经广泛使用的sklearn metrics f1 score https scikit learn org stable modules generat
  • 如何在pytorch中查看DataLoader中的数据

    我在 Github 上的示例中看到类似以下内容 如何查看该数据的类型 形状和其他属性 train data MyDataset int 1e3 length 50 train iterator DataLoader train data b
  • BatchNorm 动量约定 PyTorch

    Is the 批归一化动量约定 http pytorch org docs master modules torch nn modules batchnorm html 默认 0 1 与其他库一样正确 例如Tensorflow默认情况下似乎
  • 如何更新 PyTorch 中神经网络的参数?

    假设我想将神经网络的所有参数相乘PyTorch 继承自的类的实例torch nn Module http pytorch org docs master nn html torch nn Module by 0 9 我该怎么做呢 Let n
  • Pytorch GPU 使用率低

    我正在尝试 pytorch 的例子https pytorch org tutorials beginner blitz cifar10 tutorial html https pytorch org tutorials beginner b
  • PyTorch 中的连接张量

    我有一个张量叫做data形状的 128 4 150 150 其中 128 是批量大小 4 是通道数 最后 2 个维度是高度和宽度 我有另一个张量叫做fake形状的 128 1 150 150 我想放弃最后一个list array从第 2 维
  • 如何在 PyTorch 中对子集使用不同的数据增强

    如何针对不同的情况使用不同的数据增强 转换 Subset在 PyTorch 中吗 例如 train test torch utils data random split dataset 80000 2000 train and test将具
  • 如何使用 pytorch 同时迭代两个数据加载器?

    我正在尝试实现一个接收两张图像的暹罗网络 我加载这些图像并创建两个单独的数据加载器 在我的循环中 我想同时遍历两个数据加载器 以便我可以在两个图像上训练网络 for i data in enumerate zip dataloaders1
  • Pytorch 与 joblib 的 autograd 问题

    将 pytorch 的 autograd 与 joblib 混合似乎存在问题 我需要并行获取大量样本的梯度 Joblib 与 pytorch 的其他方面配合良好 但是 与 autograd 混合时会出现错误 我做了一个非常小的例子 显示串行
  • 使用 PyTorch 分布式 NCCL 连接失败

    我正在尝试使用 torch distributed 将 PyTorch 张量从一台机器发送到另一台机器 dist init process group 函数正常工作 但是 dist broadcast 函数中出现连接失败 这是我在节点 0
  • 在requirements.txt中包含.whl安装

    如何将其包含在requirements txt 文件中 对于Linux pip install http download pytorch org whl cu75 torch 0 1 12 post2 cp27 none linux x8

随机推荐

  • c#使用多线程的几种方式介绍

    本文主要介绍了c 使用多线程的几种方式 通过示例学习c 的多线程使用方式 大家参考使用吧 1 不需要传递参数 也不需要返回参数 ThreadStart是一个委托 这个委托的定义为void ThreadStart 没有参数与返回值 代码如下
  • Docker使用

    1 下载安装 在linux下安装docker一共有三步 更新软件包列表 sudo apt get update 安装docker sudo apt get install docker ce 检查docker是否安装成功 docker ve
  • MES管理系统项目失败的原因,总结三点

    MES是一款管理系统 建设效果参差不齐 但是MES管理系统项目以胜利的寥寥无几 因为MES管理系统 主要面向管理人员 管理人员希望打开工厂黑河 然而工厂的数据来源基本都是由执行层提供的 建设MES生产管理系统的诉求与国家统计局需求是一样的
  • Chat GPT介绍

    推荐一个在线使用网站 ChatGPT Next Web chatnext top 可以免费使用 但有次数限制 体验一下ChatGPT还是不错的 次数用完可以充钱28 8元成为永久会员 我不是打广告 我只想让更多的人体验和接触ChatGPT
  • android 难题,Android开发中遇到的难题与解决方案

    引用资源文件错误 导致运行失败 无法确定错误位置 解决方案 在Android Studio的Terminal控制台输入 gradlew compileDebugSources 获取webView的高度 public void initVie
  • [windows][UI] WM_MOUSEACTIVATE

    当用户单击一个非激活的顶级窗体 或非激活的顶级窗体的子窗体时 系统就会发送WM MOUSEACTIVATE消息 还包括其他消息 给顶级窗体或子窗体 该消息在WM NCHITTEST消息之后 但在button down消息之前 当把 WM M
  • swift 类型判断 Dictory Array

    一 类型的判断 1 is 的介绍 Swift 中类型的判断的关键词是 is is操作用来判断某一个对象是否是某一个特定的类 它会返回一个bool类型的值 2 is的使用方法 1 gt is 的一般判断 Swift 系统也会自动判断 类型的一
  • C++/Python程序读取命令行参数

    C 程序读取命令行参数 include
  • 傅里叶变换(FT)数学解析推导学习总结

    写在前面 本文是一篇非常容易理解 同时会很有收获的傅里叶变换推导教程 文章是学习B站DR CAN老师傅里叶级数和傅里叶变换系列课程后的学习总结 主要目的以个人复习巩固为主 同时也分享给大家一些心得以及非常好的一位老师 附上链接 DR CAN
  • 串口的基本定义以及RS232,RS485和UART,USAT,SPI的联系和区别

    1 什么是串口 一个bit一个bit传输数据的方式称之为串口 串行接口 2 串口的种类 同步串口 带有同步时钟线的串口传输方式 异步串口 不带同步时钟线的串口传输方式 需要双方约定传输速度 3 串口的组成 串口由物理电气层和协议层组成 3
  • java字符串判断相等_java判断字符串是否相等的方法

    java判断字符串是否相等的方法 1 java中字符串的比较 我们经常习惯性的写上if str1 str2 这种写法在java中可能会带来问题 example1 String a abc String b abc 那么a b将返回true
  • 转载:算力计算

    一 GOPS与FLOPS 1 1 FLOPS FLOPS定义 是 每秒所执行的浮点运算次数 floating point operations per second 的缩写 它常被用来估算电脑的执行效能 尤其是在使用到大量浮点运算的科学计算
  • 1334. 阈值距离内邻居最少的城市

    1334 阈值距离内邻居最少的城市 原题链接 完成情况 解题思路 参考代码 Dijkstra Dijkstra 小顶堆 Floyd martix方法 原题链接 1334 阈值距离内邻居最少的城市 https leetcode cn prob
  • 裸机服务器和虚拟机的用途和好处

    裸机服务器 用户可以根据需要自定义存储区域 用户几乎可以在世界的每个角落访问他们的数据 用户还将拥有最高级别的数据加密 只有使用最新技术的用户才能访问 由于这些服务器有专门的用户 因此具有安全性和监管优势 它具有很高的处理能力 用户可以完全
  • Poi版本升级优化

    Poi 3 17前后版本api使用差异 1 升级缘由 最近公司prod环境出现因为Excel文件下载数据量过大导致应用out of memory 然后就需要找到内存溢出的原因及优化方案 经分析 得出以下结论 1 1 事故原因 1 应用场景发
  • 四合天地软件测试系统,GZ-2017025软件测试赛题.-全国职业院校技能大赛.doc

    GZ 2017025软件测试赛题 全国职业院校技能大赛 doc 2017年全国职业院校技能大赛高职组 软件测试 项目竞赛任务书 2017年全国职业院校技能大赛 高职组 软件测试 赛项执委会制 2017年5月 目录 一 赛程说明3 二 竞赛技
  • ElasticSearch启动流程指令及注意事项

    elasticsearch es的集群部署 第一步 创建普通用户 注意 ES不能使用root用户来启动 必须使用普通用户来安装启动 这里我们创建一个普通用户以及定义一些常规目录用于存放我们的数据文件以及安装包等 创建一个es专门的用户 必须
  • 第一个python代码,第一个错误。python是对缩进严格要求的代码。

    在编写第一个条件判断语句的代码中 就遇到了第一个错误 运行py时提示 仔细对照了一下代码 发现原来时缩进格式错误 并很不明显 条件语句的if换行一般是缩进四个空格 但个人觉得以其按四个空格 不如直接按一下tab键来得简洁明了 我两种方法都试
  • SpringCloud 商城系统搭建之Ribbon (基于Ribbon + RestTemplate)

    Spring Cloud 服务调用方式 Spring Cloud有两种服务调用方式 一种是Ribbon RestTemplate 另一种是feign 在这一篇文章首先讲解下基于Ribbon RestTemplate Ribbon简介 Rib
  • 自定义实现nn.CrossEntropyLoss损失函数

    nn CrossEntropyLoss是在PyTorch中常用的交叉熵损失函数 它主要用于解决多分类问题 但也可以用于解决二分类问题 该函数有两个输入参数 第一个参数是网络的最后一层的输出 是一个二维数组 其中每个向量包含不同类别的概率值