torch的交叉熵损失函数(cross_entropy)计算(含python代码)

2023-11-16

1.调用

首先,torch的交叉熵损失函数调用方式为:

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

一般会写成:

import torch.nn.functional as F
F.cross_entropy(input, target)

2.参数说明

  • 输入张量)–(N, C), 其中C = 类别数;或在 2D 损失的情况下输入尺寸为(N, C, H, W) ,或在K≥1 在 K 维损失的情况下输入尺寸为 (N, C, d1, d2, ..., dK) 

  • target张量)-(N),其中每个值是 0target[i]≤​​​​​​​C-1, 或者在 K≥1 对于 K 维损失,目标张量的尺寸为(N, d1, d2, ..., dK)

  • weight ( Tensor , optional ) – 对每个类别的手动重新缩放权重。如果给定,则必须是大小为C的张量

  • size_average ( bool , optional ) – 不推荐使用。默认情况下,损失是批次中每个损失元素的平均值。请注意,对于某些损失,每个样本有多个元素。如果该字段size_average 设置为False,则对每个小批量的损失求和。当 reduce 为 时忽略False。默认:True

  • ignore_index ( int , optional ) – 指定一个被忽略且对输入梯度没有贡献的目标值。当size_average为 时 True,损失在未忽略的目标上取平均值。默认值:-100

  • reduce ( bool , optional ) – 不推荐使用。默认情况下,损失对每个小批量的观察进行平均或求和,取决于size_average。当reduceis 时False,返回每个批次元素的损失并忽略size_average。默认:True

  • reduce ( string optional ) – 指定应用于输出的缩减: 'none''mean''sum''none': 不会应用减少, 'mean': 输出的总和将除以输出中的元素数, 'sum': 输出将被求和。注意:size_average 和reduce正在被弃用,同时,指定这两个参数中的任何一个都将覆盖reduction. 默认:'mean'

3.举例说明

代码:

import torch
import torch.nn.functional as F
input = torch.randn(3, 5, requires_grad=True)
target = torch.randint(5, (3,), dtype=torch.int64)
loss = F.cross_entropy(input, target)
loss.backward()

变量输出:


input:
tensor([[-0.6314,  0.6876,  0.8655, -1.8212,  0.0963],
        [-0.5437,  0.2778, -0.1662, -0.0784, -0.6565],
        [-0.1164,  0.3882,  0.2487, -0.5318,  0.3943]], requires_grad=True)
target:
tensor([1, 0, 0])
loss:
tensor(1.6557, grad_fn=<NllLossBackward>)

4.注意

python里的torch.nn.functional.cross_entropy函数的实现是:

def cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100,
                  reduce=None, reduction='mean'):
    if size_average is not None or reduce is not None:
        reduction = _Reduction.legacy_get_string(size_average, reduce)
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)

注意1:输入张量不需要经过softmax,直接从fn层拿出来的张量就可以送入交叉熵中,因为在交叉熵中已经对输入input做了softmax了。

注意2:不用对label进行one_hot编码,因为nll_loss函数已经实现了类似one-hot过程,不同之处是当class = [1, 2, 3]时要处理成从0开始[0, 1, 2]。

这里把官方网站的地址也放这里:torch.nn.functional — PyTorch master documentationicon-default.png?t=LA92https://pytorch.org/docs/1.2.0/nn.functional.html#torch.nn.functional.cross_entropy

整理不易,欢迎一键三连!!!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch的交叉熵损失函数(cross_entropy)计算(含python代码) 的相关文章

随机推荐

  • 三维数据处理软件架构

    三维数据处理软件都包含哪些模块 三维数据处理软件 一般包含三个模块 数据管理和处理 三维渲染 UI 这与图形学的三个经典问题是相对应的 建模 渲染和交互 与一般常见的数据处理软件 比如图像视频处理 不同的是 这里的数据展示模块需要三维渲染
  • numpy中的mean()函数

    mean 函数定义 numpy mean a axis dtype out keepdims mean 函数功能 求取均值 经常操作的参数为axis 以m n矩阵举例 axis 不设置值 对 m n 个数求均值 返回一个实数 axis 0
  • VB联合Python开发

    用到 Python 首先你肯定得装一个Python吧 用3 x的 创建一个 py文件 说明 需要引用pythoncom 这个需要pip install pywin32 import pythoncom class PythonUtiliti
  • R语言缺失值填补

    本文主要介绍如何利用R语言进行数值型缺失值的填补 主要使用zoo包中的na aggregate na approx na locf 函数进行缺失值的均值填补 线性插值填补以及邻近值填补 install packages zoo librar
  • 开源是物联网的驱动力量

    本文转载至 http www infoq com cn articles open source as a driver of internet of things utm campaign infoq content utm source
  • Shell脚本攻略:通配符、正则表达式

    目录 一 理论 1 通配符 2 正则表达式 二 实验 1 通配符 2 正则表达式 一 理论 1 通配符 1 概念 通配符只用于匹配文件名 目录名等 不能用于匹配文件内容 而且是已存在的文件或者目录 各个版本的shell都有通配符 这些通配符
  • 《Android 开发艺术探索》笔记2--IPC机制

    Android 开发艺术探索 笔记2 IPC机制 思维导图 Android IPC简介 Android中的多进程的模式 IPC基础概念 Serializable接口 Parcelable接口 Android的几种跨进程的方式 使用Bundl
  • having where 你真的了解了吗?

    where group by group by 字句 和 where条件语句结合在一起使用 当结合在一起时 where在前 group by 在后 即先对select xx from xx的记录集合用where进行筛选 然后再使用group
  • QT 三种关联信号和槽的办法

    1 手动关联 connect ui gt showChildButton QPushButton clicked this MyWidget showChildDialog 2 自动关联 右键单击按钮弹出菜单中选择 转到槽 void MyD
  • Basic Level 1052 卖个萌 (20分)

    题目 萌萌哒表情符号通常由 手 眼 口 三个主要部分组成 简单起见 我们假设一个表情符号是按下列格式输出的 左手 左眼 口 右眼 右手 现给出可选用的符号集合 请你按用户的要求输出表情 输入格式 输入首先在前三行顺序对应给出手 眼 口的可选
  • vue添加水印踩坑

    介绍 前景 app页面添加水印展示 技术实现 watermark dom 完整代码 vue watermark 实现效果 功能描述 添加 删除 更新水印 引入 方式一 推荐 方便拓展 在index html引入相关文件 方式二 npm包引入
  • java byte[] 学习总结

    最近在学习netty 突然发现自己对字符数组是那么的陌生 吓死宝宝了 然后各种学习 然后测试 终于会用一些了 下线的都是本人的学习笔记 byte表是字符 一个字节 8位 可以组成2 8 256中不同数字 byte存值范围 128 127 1
  • pytorch基本使用_02

    import numpy as np import torch 从numpy引入tensor a np array 2 3 3 print torch from numpy a tensor 2 0000 3 3000 dtype torc
  • java线上CPU100%如何排查

    定位耗费CPU的进程 top c 就可以显示进程列表 然后输入P 按照cpu使用率排序 你会看到类似下面的东西 2 定位耗费CPU的线程 top Hp 1500 就是输入那个进程id就好了 然后输入P 按照cpu使用率排序 你会看到类型下面
  • 单片机c语言数码管显示0到9,单片机如何让8个数码管同时流水显示0到9,大家帮我看看!...

    按你的要求修改如下 include reg52 h 此文件中定义了单片机的一些特殊功能寄存器 typedef unsigned int u16 对数据类型进行声明定义 typedef unsigned char u8 sbit LSA P2
  • Java 网络编程UDP协议之发送数据和接收数据的详解

    博主前些天发现了一个巨牛的人工智能学习网站 通俗易懂 风趣幽默 忍不住也分享一下给大家 点击跳转到网站 UDP协议 用户数据报协议 User Datagram Protocol UDP是无连接通信协议 即在数据传输时 数据的发送端和接收端不
  • 《信号与系统》4.10.2工频干扰的滤除

    平台 版本 Multisim14 1 参考书籍 信号与系统 4 10 2工频干扰的滤除 工程上 滤除工频干扰比较常用的电路是无源双T陷波滤波器 图示双T的无源陷波滤波器电路 陷波器是某一小频率范围内的带阻滤波器 陷波器的一个常见的应用是滤除
  • Seaborn入门详细教程

    作者 luanhz 来源 小数志 Seaborn入门详细教程 导读 今天我们来介绍 seaborn 这是一个基于matplotlib进行高级封装的可视化库 相比之下 绘制图表更为集成化 绘图风格具有更高的定制性 教程目录 01 初始seab
  • 一文带你了解序列化与反序列化基本原理与操作

    文章目录 一 什么是序列化与反序列化 二 为什么我们需要序列化与反序列化 三 步骤说明 四 注意说明 五 代码说明 六 序列化与反序列化原理 一 什么是序列化与反序列化 序列化是指将对象转换为字节序列的过程 以便于存储或传输 在序列化过程中
  • torch的交叉熵损失函数(cross_entropy)计算(含python代码)

    1 调用 首先 torch的交叉熵损失函数调用方式为 torch nn functional cross entropy input target weight None size average None ignore index 100