为何pytorch nn.KLDivLoss()损失计算为负数?

2023-11-11

参考文献:https://www.zhihu.com/question/384982085
先来看一下KL散度的定义
在这里插入图片描述
在这里插入图片描述
这里是要用分布Q为标签(原始分布),分布P作为预测值(预测分布)

在pytorch中,nn.KLDivLoss()的计算公式如下:
在这里插入图片描述
上图y为标签,x为预测值,则pytorch应该以如下代码使用

lossfunc = nn.KLDivLoss()
loss = lossfunc(预测值, 标签值)

但是,由于计算公式中,预测值x的输入要是对数形式,而标签值y则不需要,所以如果我们要对预测值和标签值的softmax值求KL散度就需要如下:

temp = 1 #温度系数
probs = torch.Tensor([[2, 6, 8], [7, 1, 2], [1, 9, 2.3], [1.9, 2.8, 5.4]])
target = torch.Tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.4, 0.3, 0.3]])
loss = lossfunc(F.log_softmax(probs / temp, dim=1), F.softmax(target / temp, dim=1))#如果probs和target已经是softmax的形式,就只需要给probs取对数输入就行了

也就是说要给输入的预测值预先取个对数,这样计算结果就不为负数了

错误示范:

probs = torch.Tensor([[2, 6, 8], [7, 1, 2], [1, 9, 2.3], [1.9, 2.8, 5.4]])
target = torch.Tensor([[0.8, 0.1, 0.1], [0.1, 0.7, 0.2], [0.5, 0.2, 0.3], [0.4, 0.3, 0.3]])
loss1 = lossfunc(F.softmax(probs / temp, dim=1), F.softmax(target / temp, dim=1))
loss2 = lossfunc(probs, target)

这样两个散度算出来就都是负数

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

为何pytorch nn.KLDivLoss()损失计算为负数? 的相关文章

随机推荐

  • 数字IC后端设计技术全局观

    数字IC后端设计flow 不含DFT 数字IC后端设计工具 DC 用于逻辑综合 FM 用于形式验证 ICC 用于物理实现 PrimeTime 用于STA 步骤 或文件类型 简述 RTL Register Transfer Level v文件
  • mysql8.0收费价格,MySQl 8.0遇到的坑

    报错 Illuminate Database QueryException SQLSTATE HY000 1045 Access denied for user root localhost using password NO SQL cr
  • Trying to access array offset on value of type int

    问题描述 出现报错信息 先百度翻译 试图访问int类型值的数组偏移量 通过翻译得知 int型的数据被其他不能使用的类型使用了 个人理解 关于这块 php7 4升级之后会有这个bug 网上大多人是说 7 4 版本的向后不兼容更改 非数组的数组
  • valgrind Massif

    valgrind检查内存泄露 valgrind 程序 内存泄漏问题 我们有memcheck工具来检查 很爽 但是有时候memcheck工具查了没泄漏 程序一跑 内存还是狂飙 这又是什么问题 其实memcheck检查的内存泄漏只是狭义的内存泄
  • Docker——安装和启动

    一 环境准备 1 安装Linux虚拟机软件 VMware或VirtualBox 比VMware更小巧轻便且免费 此处安装VirtulaBox 2 安装Linux虚拟系统 在管理中选择导入虚拟电脑 记得选中重新初始化所有网卡的MAC地址 双击
  • Dynamics CRM 365 如何设置经典登录页面

    Don t be surprised If you don t see classic interface post your sign up for dynamics 365 Okay let s face it We are losin
  • 复选框check的选中、不选中设置以及判断是否选中

    复选框的设置 一 JavaScript判断是否选中checkbox框 二 JavaScript设置选中checkbox框 三 JavaScript移除选中checkbox框 四 使用jQuery判断是否选中checkbox框 五 使用jQu
  • 国密(1) - 私钥Key文件( PEM格式)编解码方法

    详细的PEM文件格式解析 PEM文件 是按照私钥的ASN 1的格式 RFC5208 5915 5480 进行DER编码后输出二进制串的基础上 再进行Base64的编码 也就是每6个bit为一组 生成一个ascii码字符 需要4组6个bit
  • 学习笔记59—收藏这7个在线配色神器,再也不愁配色灵感了

    在设计中配色方案是必要的 也是让设计师头疼的一个问题 所以 编辑专为大家整理了一波配色神器网站 不用下载任何应用程序 打开即用 不仅能快速的做出符合设计概念的颜色组合 且有很多样品供你确认的工具 设计新手们千万别错过了 一 Khroma h
  • 【macOS】Win通过VNC远程控制Macbook

    Win通过VNC远程控制Macbook 参考 https zhuanlan zhihu com p 74162964 仅局域网内可用 Macbook配置 进入 电脑设置 勾选两个选项 Windows配置 安装VNC Viewer https
  • openpyxl操作表格的基本用法

    创建文件 以及创建xlsx表格 from openpyxl import Workbook load workbook import os 创建excel文件 默认会有一个sheet命名的表 def create xlsx path nam
  • Beginng_Rust(译):借用和生命周期(第二十二章)

    在本章中 您将学习 借用 和 生命周期 的概念 哪些是关于借用的典型编程错误 即困扰系统软件 Rust严格语法如何使用借用检查器来防止此类典型错误 插入块的方式如何限制借用范围 为什么函数返回引用需要生命周期指示符 如何使用寿命指定符来表示
  • 应用层的原理

    目录 应用层协议原理 网络应用程序体系结构 客户 服务器 P2P 混合模式 UDP TCP 所有能产生网络流量的程序 应用层协议原理 网络应用程序体系结构 客户 服务器 P2P 混合模式 UDP TCP 可供应用程序使用的运输服务 因特网提
  • 解决liquibase.exception.LockException: Could not acquire change log lock. Currently locked by XXXX

    项目启动后报liquibase exception LockException Could not acquire change log lock 解决方案 执行下面语句 use job job为你的数据库 select from DATA
  • HTML5 history新特性pushState、replaceState

    DOM中的window对象通过window history方法提供了对浏览器历史记录的读取 让你可以在用户的访问记录中前进和后退 从HTML5开始 我们可以开始操作这个历史记录堆栈 1 History 使用back forward 和go
  • windows dll 装载过程

    windows dll 装载过程 2010 12 04 19 13 56 分类 Windows系统平台上 你可以将独立的程序模块创建为较小的DLL Dynamic Linkable Library 文件 并可对它们单独编译和测试 在运行时
  • MySQL--事务+存储引擎+表类型+视图+用户管理

    目录 1 事务 1 1 概念 1 2 回退事务 1 3提交事务 1 4事务细节注意点 1 5事务的隔离级别 1 5 1 介绍 1 5 2 解决这些安全性问题 1 5 3演示脏读 1 5 4避免脏读 演示不可重复发生 1 5 5 演示不可重复
  • Hexo 博客利用 Nginx 实现中英文切换

    本文记录了对 Hexo 博客进行中英文切换的配置过程 实现同一应用共用模版 任何页面可以切换到另一语言的对应页面 并对未明确语言的访问地址 根据浏览器语言进行自动跳转 实现细则 中英文地址区分 博客中文首页 https chanvinxia
  • Filter内存马浅析

    1 何谓内存马 以Tomcat为例 内存马主要利用了Tomcat的部分组件会在内存中长期驻留的特性 只要将我们的恶意组件注入其中 就可以一直生效 直到容器重启 Java内存shell有很多种 大致分为 1 动态注册filter 2 动态注册
  • 为何pytorch nn.KLDivLoss()损失计算为负数?

    参考文献 https www zhihu com question 384982085 先来看一下KL散度的定义 这里是要用分布Q为标签 原始分布 分布P作为预测值 预测分布 在pytorch中 nn KLDivLoss 的计算公式如下 上