交叉熵:pytorch版本 vs 日常版本

2023-11-16

首先看下平时我们所说的交叉熵:
传送门
在信息论中,交叉熵可认为是对预测分布q(x)用真实分布p(x)来进行编码时所需要的信息量大小。而在机器学习的分类问题中,真实分布p(x)是one-hot形式,表明独属于one-hot中1对应的角标的那个类,因此这也是为什么交叉熵常用于做分类问题的损失函数。

H ( p , q ) = ∑ x p ( x ) log ⁡ 1 q ( x ) = − ∑ x p ( x ) log ⁡ q ( x ) \begin{aligned} H(p, q) &=\sum_{x} p(x) \log \frac{1}{q(x)} \\ &=-\sum_{x} p(x) \log q(x) \end{aligned} H(p,q)=xp(x)logq(x)1=xp(x)logq(x)

那么pytorch里的交叉熵是这样的吗?我们测试下:
pytorch:

import torch
loss = torch.nn.CrossEntropyLoss(reduction = "none")
pred = torch.tensor([[0.0,1.0],[0.4,0.6,],[0.8,0.2]])
label = torch.tensor([1,0,0])
print(loss(pred,label))
# 输出:tensor([0.3133, 0.7981, 0.4375])

手动:

import math
# [0.0,1.0]和[1]
res = -math.log(0.1) = -1 * 0 = 0
#  [0.4,0.6,]和[0]
res = -math.log(0.4) = 0.916290731874155
#  [0.8,0.2]和[0]
res = -math.log(0.8) = 0.2231435513142097

很显然和通过torch得到的结果不同。那么看下pytorch文档里的交叉熵公式。

torch.nn.CrossEntropyLoss是nn.logSoftmax()和nn.NLLLoss()整合起来的版本,其中NLLLoss()是negative log likelihood loss,负对数似然(损失)函数和交叉熵(损失)函数背后的思想或者说得到的过程有些不同。这里先介绍下“似然”的概念:在机器学习中,似然函数是一种关于模型中参数的函数。“似然性(likelihood)”和"概率(probability)"词意相似,但在统计学中它们有着完全不同的含义:概率用于在已知参数的情况下,预测接下来的观测结果;似然性用于根据一些观测结果,估计给定模型的参数可能值。因此负对数似然函数是希望通过已知的训练数据的标注去找到一组模型的参数,让模型的预测结果贴合训练数据的标注结果,寻找这组模型参数的过程就是模型训练的过程,也是最小化负似然函数的过程;而交叉熵则是熵的概念,即用真实分布编码预测分布时所需要的信息量大小,信息量越小,两个分布越接近,因此最小化这个的过程就是训练模型的过程。

虽然两者背后思想有些不同,但是最后呈现的公式是一样的,即上面提到平时所说的交叉熵公式。

那么nn.NLLLoss()已经可以用来做交叉熵公示了,为什么还要有torch.nn.CrossEntropyLoss呢?这是因为nn.CrossEntropyLoss()是考虑具体训练过程做了优化得到的版本:
在具体训练过程中,假设batch_size=32, num_classes = 10, 那么经过最后的Linear得到一个batch的预测结果的shape为[32,10],其中每条数据在不同类别上的值有可能小于0,且所有类别值加起来不为1,因此这不能算概率值,所以需要softmax。softmax之后结果再取log,这样做的目的是将乘法改成加法减少计算量,同时保障函数的单调性。因此最后将nn.logSoftmax()和nn.NLLLoss()结合得到了nn.CrossEntropyLoss()。

下面看下具体怎么结合的:
torch.nn.NLLLoss:
官方地址
这个 w n w_n wn是focal loss里的α,数据不平衡时用的,因此一般没有
ℓ ( x , y ) = L = { l 1 , … , l N } ⊤ , l n = − w y n x n , y n , w c = \ell(x, y)=L=\left\{l_{1}, \ldots, l_{N}\right\}^{\top}, \quad l_{n}=-w_{y_{n}} x_{n, y_{n}}, \quad w_{c}= (x,y)=L={l1,,lN},ln=wynxn,yn,wc= weight [ c ] ⋅ 1 { c ≠ [c] \cdot 1\{c \neq [c]1{c= ignore_index } \} }

torch.nn.LogSoftmax:
官方地址
LogSoftmax ⁡ ( x i ) = log ⁡ ( exp ⁡ ( x i ) ∑ j exp ⁡ ( x j ) ) \operatorname{LogSoftmax}\left(x_{i}\right)=\log \left(\frac{\exp \left(x_{i}\right)}{\sum_{j} \exp \left(x_{j}\right)}\right) LogSoftmax(xi)=log(jexp(xj)exp(xi))

因为在NLLLoss()中true label也是one-hot,即只有true label那个类参与计算,因此将NLLLoss()的 x n y n x_ny_n xnyn代入logSoftmax分子的 x i x_i xi,得到一个公式,然后再根据log(xy) = logx + logy化简:
1.9.1 pytorch官方地址
在这里插入图片描述
上面用1.9.1的pytorch是因为从1.10.1开始,公式没有展示化简后那步,不便于本文理解,并且1.10.1的交叉熵实现了label smooth。

因此理解上述内容后,根据pytorch的交叉熵公式再计算下:

import math
loss = torch.nn.CrossEntropyLoss(reduction = "none")
# 1.[0.0, 1.0]和[1]
# torch
loss(torch.tensor([[0.0,1.0]]), torch.tensor([1]))
# 输出:tensor([0.3133])
# 手写
-1+math.log(math.exp(0)+math.exp(1))
# 输出:0.3132616875182228

# 2.[0.0, 1.0]和[0]
# torch
loss(torch.tensor([[0.0,1.0]]), torch.tensor([0]))
# 输出:tensor([1.3133])
# 手写
-0+math.log(math.exp(0)+math.exp(1))
# 输出:1.3132616875182228

# 3.[0.0, 0.0, 1.0]和[2]
loss(torch.tensor([[0.0,0.0,1.0]]), torch.tensor([2]))
# 输出:tensor([0.5514])
-1+math.log(math.exp(0)+math.exp(0)+math.exp(1))
# 输出:0.5514447139320509

# 4.模拟一个batch
entroy=torch.nn.CrossEntropyLoss() # reduction默认为mean
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
# 输出:tensor(1.1142)
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
    output = 0
    length = len(target)
    for i in range(length):
        hou = 0
        for j in input[i]:
            hou += np.exp(j)
        output += -input[i][target[i]] + np.log(hou)
    return np.around(output / length, 4)
print(cross_entorpy(input, target))
# 输出:1.1142

对数似然、负对数似然
交叉熵、负对数似然
pytorch交叉熵公式推导以及代码证明
pytorch交叉熵

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

交叉熵:pytorch版本 vs 日常版本 的相关文章

随机推荐

  • 使用mysql_upgrade升级mysql5.1至5.6的数据库升级实施方案

    本方案是因为在工作中遇到的一个mysql主从功能配置的问题所引起的 有一个处在从位置上的mysql是5 1版本的 从5 1到5 6的mysql各种系统管理功能 像系统表表结构 日志文件格式等等均不一致 这时直接以5 1版本去作为一个5 6版
  • 信安软考 第十二章 网络安全审计技术

    一 网络安全审计概述 网络安全审计是指对网络信息系统的安全相关活动信息进行获取 记录 存储 分析和利用的工作 网络安全审计的作用在于建立 事后 安全保障措施 保存网络安全事件及行为信息 为网络安全事件分析提供线索及证据 以便发现潜在的网络安
  • rocketMq启动broker报错找不到或无法加载主类 Files\Java\jdk1.8.0_171\lib\dt.jar;C:\Program]

    假如弹出提示框提示 错误 找不到或无法加载主类 xxxxxx 1 打开runbroker cmd 将 CLASSPATH 加上英文双引号 切勿别加中文双引号 2 打开runserver cmd 同理 将 CLASSPATH 加上英文双引号
  • hexo+GitHub Pages一键搭建部署博客

    文章目录 前言 博客相关配置 matery主题相关配置 1 什么是 Hexo 2 准备工作 3 生成博客 4 更换主题 5 部署到github pages 总结 前言 现在技术更新迭代是非常的快 尤其是web方面 所以当前搭建一个博客差不多
  • Ubuntu16.04安装ROS Kinetic详细步骤

    文章目录 ROS安装 配置Ubuntu软件仓库 设置sources list 设置密钥 更新Debian软件包索引 安装ROS 初始化 rosdep 环境配置 构建工厂依赖 测试安装 开发环境 ROS安装 ROS Kinetic只支持Wil
  • CompletableFuture使用(二)

    CompletableFuture创建异步任务后 get 方法是阻塞到Future完成后返回结果 对于构建异步系统 需要将回调附加到CompletableFuture上 当Future完成时自动调用 就可以使用thenApply thenA
  • 【直播+福利】生产压测环境,如何做好安全保障?

    互联网 数字经济的不断发展使得系统架构不断演变 实现了从 单线程 到 多线程 多组件 再到 分布式 微服务 的一个跨越 分布式系统的复杂程度是公认的 牵一发而动全身 想要保障系统的稳定可用是所有企业的共有难题 生产全链路压测应运而生 可实际
  • Ansible-角色部署LAMP

    配置主机 root ansible cd etc ansible root ansible ansible ls ansible cfg hosts roles root ansible ansible vim hosts dev node
  • [图文]Openfiler应用篇(四) FTP和Quota

    本篇我们讨论openfiler FTP和Quota 磁盘配额 的应用 openfiler FTP和Quota功能必须在开启帐户功能的条件下才能使用 一 FTP应用 1 开启FTP 点击主菜单Services 在Manage Services
  • git的使用和规划

    1 拉取项目 在拉取项目的时候使用git rebase 这样分支管理更加清晰 2 提交项目 commit的时候不要把不希望别人看到的改到都commit上 commit的时候 要检查修改的文件代码书写是否正确 下图中打钩文件为想要提交的文件
  • SQL中EXISTS理解使用

    SQL中EXISTS的理解使用 关联子查询 EXISTS理解使用 关联子查询 在讲述EXISTS用法之前 先讲述一下关联子查询 关联子查询 是指在内查询中需要借助于外查询 而外查询离不开内查询的执行 举个栗子 在Oracle中自带的EMP表
  • Objective-C块block介绍

    块的定义 返回值类型 形参类型 形参1 形参类型 形参2 块执行体 以上是一个块的写法 1 返回值类型可以省略 形参也可以参略 但是形参的括号不能参略 NSLog 123 通常我们需要反复调用块 因为块相当于一个匿名的函数 我们调用它时可以
  • 在VMware中设置ubuntu与Windows共享文件夹

    本机系统 win7 使用vmware安装的unbutu 之前在win7上下载了一些文档和软件 想在虚拟机中使用 结果发现读取不了这些文件 头疼了一下午 从网上搜索了很多资源 发现没有一个完整的文章可以一次搞定 头疼 这里就总结一下我的方法
  • I2C与SPI通信总线协议

    仅以寄存器地址为8Bit的器件为例 例如MPU6500 LSM6DS3 I2C通信协议 I2C 的要点是了解I2C通信帧的组成部分 START起始位 STOP停止位 ACK NACK信号 从机器件地址 从机寄存器地址 I2C读的时序比较繁琐
  • K8S访问控制------认证(authentication )、授权(authorization )体系

    一 账号分类 在K8S体系中有两种账号类型 User accounts 用户账号 即针对human user的 Service accounts 服务账号 即针对pod的 这两种账号都可以访问 API server 都需要经历认证 授权 准
  • Linux根目录爆满,解决(/dev/mapper/rhel-root 98%问题)

    1 首先确定是否是磁盘空间不足 输入命令 df h 查看磁盘信息 发现已经使用率达到96 所有需要删除大文件数据 2 其次查找大文件 du h max depth 1 命令代表寻找当前目录 哪个文件夹占用空间最大 进入根目录 root vl
  • 六级英语词汇

    genuine d enju n fake If this offer is genuine I will gladly accept it 如果这份帮助是真诚的 我将愉快地接受它 一 单词关 whereas we r z conj 然而
  • [YOLO专题-17]:YOLO V5 - 如何把YOLO训练数据集批量转换成带矩形框的图片

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 https blog csdn net HiWangWenBing article details 122344955 目录 前言 第1章
  • 利用Spring框架在前端实现对数据库的增删改查

    在前端页面上显示购物数据库数据 并且可以这增 删 改 查 1 首先在WEB 配置文件
  • 交叉熵:pytorch版本 vs 日常版本

    首先看下平时我们所说的交叉熵 传送门 在信息论中 交叉熵可认为是对预测分布q x 用真实分布p x 来进行编码时所需要的信息量大小 而在机器学习的分类问题中 真实分布p x 是one hot形式 表明独属于one hot中1对应的角标的那个