torch.nn.CrossEntropyLoss 多个批次

2023-12-22

我目前正在与torch.nn.CrossEntropyLoss。据我所知,批量计算损失是很常见的。但是,是否有可能计算多个批次的损失?

更具体地说,假设我们给出了数据

import torch

features = torch.randn(no_of_batches, batch_size, feature_dim)
targets = torch.randint(low=0, high=10, size=(no_of_batches, batch_size))

loss_function = torch.nn.CrossEntropyLoss()

有没有一种方法可以一行计算

loss = loss_function(features, targets) # raises RuntimeError: Expected target size [no_of_batches, feature_dim], got [no_of_batches, batch_size]

?

先感谢您!


您可以计算多个交叉熵损失,但您需要自己进行减少。由于交叉熵损失假设特征暗淡始终是特征张量的第二维,因此您还需要首先对其进行排列。

loss_function = torch.nn.CrossEntropyLoss(reduction='none')
loss = loss_function(features.permute(0,2,1), targets).mean(dim=1)

这将导致loss张量与no_of_batches条目。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

torch.nn.CrossEntropyLoss 多个批次 的相关文章

随机推荐

  • 如何使用 Java 8 将方法传递给注释? [复制]

    这个问题在这里已经有答案了 我想将方法 传递给注释 这样的事情可能吗 MyAnnotation method MyClass myMethod private String myVariable 传递方法不是一种选择 相反 传递以下内容应该
  • jquery 中 $.each 内的appendTo()似乎会导致闪烁

    当appendTo 在 each里面时会导致闪烁 each jsob Table function i employee div class resultsdiv br span class resultName employee Emp
  • 在哪里可以找到“收缩期间发现警告”的警告?

    将 RxAndroid 和 Retrofit 库添加到我的 gradle 并编译后 我收到以下错误 显示在我的 Android Studio 消息面板中 Error Execution failed for task app transfo
  • Discord.py Bot 如何播放本地文件中的音频

    基本上就是标题 我已经安装了 ffmpeg 和discord py audio 我只需要了解它是如何工作的 找不到任何本地音频文件的教程 我无法理解文档中的任何内容 这是播放本地音频文件的功能 我在使用 FFmpeg 时遇到问题 因此我对
  • 选择过去 30 天的所有订单,并计算每天的数量

    我正在尝试选择过去 30 天内的所有订单来自一位客户 所以我需要 customer id customer id 和计算我每天有多少订单对于那一位客户 我需要得到这样的数组 Array 1 gt Array orders gt 41 dat
  • ref、toRef 和 toRefs 之间有什么区别

    我刚刚开始使用 Vue 3 和 Composition API 我想知道两者之间有什么区别ref toRef and toRefs Vue 3 ref A ref https v3 vuejs org api refs api html r
  • 如何正确使用 ASP.NET Core 共享框架或如何单独使用其程序集?

    情况 在我们的应用程序中 我们有一个WPF客户端 and an ASP NET Core 服务器 两者都使用 NET 5 我们将所有 DLL 存储在server和client当用户登录时从服务器下载所有必需的 DLL 最初 client独立
  • 在 Post 请求上触发 Socket

    尝试在 POST 请求上发出消息 收到错误消息 无法读取属性 emit未定义的 app post webhook orders updated function req res next io socket emit order Order
  • 在 Azure Devops 上的 cURL 请求中使用环境变量

    我正在尝试使用 Azure DevOps 上的 cURL 通过命令行任务将 zip 文件上传到 Netlify 显然我不想在 yaml 文件中包含 Netlify 访问令牌 因此我为它创建了一个秘密变量 使用 UI 设计器 并使用 然而我不
  • 为什么创建了很多spark-warehouse文件夹?

    我在ubuntu上安装了hadoop 2 8 1 然后在其上安装了spark 2 2 0 bin hadoop2 7 我使用 Spark shell 并创建了表格 我再次使用直线并创建了表格 我观察到创建了三个不同的文件夹 名为spark
  • 共享文件而不将其保存在外部存储上

    我使用以下代码允许用户共享位图 try File save dir Environment getExternalStorageDirectory FileOutputStream out new FileOutputStream save
  • 在 Word 中获取本地化/未本地化的样式名称 (VSTO)

    我有一个单词插件 需要帮助处理样式名称 我使用 get Style NameLocal 获得段落样式 这将返回本地化名称 具体取决于 Office 运行所用的语言 只要有内置样式 我就找到了一种方法来获取本地名称 方法是将 wdBuiltI
  • 在构造函数中初始化虚拟属性是否错误? [复制]

    这个问题在这里已经有答案了 在构造函数中初始化虚拟属性是否错误 它只是感觉不对 因为如果您重写派生类中的属性 该属性将首先使用基类构造函数中的值进行初始化 然后由派生类构造函数再次对其进行赋值 有没有其他方法可以做到这一点 我正在谈论这样的
  • php imagick setGravity 函数不适用于compositeImage() 函数

    我正在为一个项目使用 php Imagick 类 我尝试合成一个图像 改变图像的重力 我的意思是 我想将目标图像合成到中间或顶部中心 I use imageOrg gt setGravity imagick GRAVITY CENTER I
  • kafka + 如何避免磁盘存储空间不足

    我想描述我们的一个生产集群上的以下案例 我们有 HDP 版本 2 6 4 的 ambari 集群 集群包括 3 台 kafka 机器 每个 kafka 都有 5 T 的磁盘 我们看到的是所有kafka磁盘的大小都是100 所以kafka磁盘
  • R:分配数据框列的变量标签

    我正在努力处理 data frame 列的变量标签 假设我有以下数据框 更大数据框的一部分 data lt data frame age c 21 30 25 41 29 33 sex factor c 1 2 1 2 1 2 labels
  • memset bool 为 0 安全吗?

    假设我有一些legacy无法更改的代码 除非bug被发现 它包含以下代码 bool data 32 memset data 0 sizeof data 这是设置所有内容的安全方法吗bool在数组中到false value 更一般地说 安全吗
  • 从泛型类实现的接口调用泛型类中的泛型属性

    我有一个具有一个类型参数 T 的泛型类 我需要存储这些不同类型的通用对象的集合 因此我创建了一个通用类按照建议实现的接口here https stackoverflow com questions 754341 adding generic
  • 初学iphone问题:画一个矩形。我究竟做错了什么?

    试图找出我在这里做错了什么 已经尝试了几种方法 但我从未在屏幕上看到那个难以捉摸的矩形 现在 这就是我想做的一切 只需在屏幕上绘制一个矩形 除了 CGContextSetRGBFillColor 之外 我在所有内容上都收到 无效上下文 之后
  • torch.nn.CrossEntropyLoss 多个批次

    我目前正在与torch nn CrossEntropyLoss 据我所知 批量计算损失是很常见的 但是 是否有可能计算多个批次的损失 更具体地说 假设我们给出了数据 import torch features torch randn no