Pytorch:交叉熵损失中的权重

2024-01-11

我试图通过一个实际的例子来理解 CrossEntropyLoss 中的权重是如何工作的。所以我首先运行标准 PyTorch 代码,然后手动运行。但损失并不相同。

from torch import nn
import torch
softmax=nn.Softmax()
sc=torch.tensor([0.4,0.36])
loss = nn.CrossEntropyLoss(weight=sc)
input = torch.tensor([[3.0,4.0],[6.0,9.0]])
target = torch.tensor([1,0])
output = loss(input, target)
print(output)
>>1.7529

现在进行手动计算,首先对输入进行 softmax:

print(softmax(input))
>>
tensor([[0.2689, 0.7311],
        [0.0474, 0.9526]])

然后正确类别概率的负对数并乘以相应的权重:

((-math.log(0.7311)*0.36) - (math.log(0.0474)*0.4))/2
>>
0.6662

我在这里缺少什么?


要计算班级的班级权重,请使用sklearn.utils.class_weight.compute_class_weight(class_weight, *, classes, y) 在这里阅读 https://scikit-learn.org/stable/modules/generated/sklearn.utils.class_weight.compute_class_weight.html
这将返回一个数组,即weight.
eg .

x = torch.randn(20, 5) 
y = torch.randint(0, 5, (20,)) # classes
class_weights=class_weight.compute_class_weight('balanced',np.unique(y),y.numpy())
class_weights=torch.tensor(class_weights,dtype=torch.float)
 
print(class_weights) #([1.0000, 1.0000, 4.0000, 1.0000, 0.5714])

然后将其传递给nn.CrossEntropyLoss的权重变量

criterion = nn.CrossEntropyLoss(weight=class_weights,reduction='mean')

loss = criterion(...)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Pytorch:交叉熵损失中的权重 的相关文章

  • openpyxl 2.4.2:保存后公式生成的单元格值为空

    我使用 openpyxl 打开文件 编辑一些单元格并保存更改 这是一个例子 import openpyxl book openpyxl load workbook sheet path sheet book active for row i
  • 使用 python 制作本地服务器应用程序的最佳方法

    我想要简单轻松地集成 python 和 vba 人们 如果他们在阅读本文后亲自见到我 阅读本文可能会杀了我 但我正在使用 django 开发服务器来实现此目的 有没有什么简单又好的方法 仅举个例子 我想使用 python 模块 openpy
  • 如何屏蔽 PyTorch 权重参数中的权重?

    我正在尝试在 PyTorch 中屏蔽 强制为零 特定权重值 我试图掩盖的权重是这样定义的def init class LSTM MASK nn Module def init self options inp dim super LSTM
  • 为什么我不能导入 geopandas?

    我唯一的代码行是 import geopandas 它给了我错误 OSError Could not find libspatialindex c library file 以前有人遇到过这个吗 我的脚本运行得很好 直到出现此错误 请注意
  • 如何使用pycaffe重构caffe网络

    我想要的是 加载网络后 我将分解一些特定的图层并保存新的网络 例如 原网 数据 gt conv1 gt conv2 gt fc1 gt fc2 gt softmax New net 数据 gt conv1 1 gt conv1 2 gt c
  • Python 中的六边形自组织映射

    我在寻找六边形 自组织映射 http en wikipedia org wiki Self organizing map在Python上 准备好模块 如果存在的话 绘制六边形单元格的方法 将六边形单元作为数组或其他方式使用的算法 About
  • Python 中 genfromtxt() 的可变列数?

    我有一个 txt具有不同长度的行的文件 每一行都是代表一条轨迹的一系列点 由于每条轨迹都有自己的长度 因此各行的长度都不同 也就是说 列数从一行到另一行不同 据我所知 genfromtxt Python 中的模块要求列数相同 gt gt g
  • Sorted(key=lambda: ...) 背后的语法[重复]

    这个问题在这里已经有答案了 我不太明白背后的语法sorted 争论 key lambda variable variable 0 Isn t lambda随意的 为什么是variable在看起来像的内容中陈述了两次dict 我认为这里的所有
  • python ttk treeview:如何选择并设置焦点在一行上?

    我有一个 ttk Treeview 小部件 其中包含一些数据行 如何设置焦点并选择 突出显示 指定项目 tree focus set 什么也没做 tree selection set 0 抱怨 尽管小部件明显填充了超过零个项目 但未找到项目
  • 当x轴不连续时如何删除冗余日期时间 pandas DatetimeIndex

    我想绘制一个 pandas 系列 其索引是无数的 DatatimeIndex 我的代码如下 import matplotlib dates as mdates index pd DatetimeIndex 2000 01 01 00 00
  • Python:随时接受用户输入

    我正在创建一个可以做很多事情的单元 其中之一是计算机器的周期 虽然我将把它转移到梯形逻辑 CoDeSys 但我首先将我的想法放入 Python 中 我将进行计数 只需一个简单的操作 counter 1 print counter 跟踪我处于
  • 行为:如何从另一个文件导入步骤?

    我刚刚开始使用behave http pythonhosted org behave 一个Pythonic BDD框架 使用小黄瓜语法 http docs behat org guides 1 gherkin html 行为需要一个特征 例
  • 反加入熊猫

    我有两个表 我想附加它们 以便仅保留表 A 中的所有数据 并且仅在其键唯一时添加表 B 中的数据 键值在表 A 和 B 中是唯一的 但在某些情况下键将出现在表 A 和 B 中 我认为执行此操作的方法将涉及某种过滤联接 反联接 以获取表 B
  • Python unicode 字符代码?

    有没有办法将 Unicode 字符 插入 Python 3 中的字符串 例如 gt gt gt import unicode gt gt gt string This is a full block s unicode charcode U
  • 在wxpython中使用wx.TextCtrl并在按钮单击后显示数据的简单示例 - wx新手

    我正在学习 python 并尝试使用 wxpython 进行 UI 开发 也没有 UI exp 我已经能够创建一个带有面板 按钮和文本输入框的框架 我希望能够在文本框中输入文本 并让程序在单击按钮后对输入框中的文本执行操作 我可以获得一些关
  • 在谷歌C​​olab中使用cv2.imshow()

    我正在尝试通过输入视频来对视频进行对象检测 cap cv2 VideoCapture video3 mp4 在处理部分之后 我想使用实时对象检测来显示视频 while True ret image np cap read Expand di
  • WindowsError:[错误 5] 访问被拒绝

    我一直在尝试终止一个进程 但我的所有选项都给出了 Windows 访问被拒绝错误 我通过以下方式打开进程 一个python脚本 test subprocess Popen sys executable testsc py 我想杀死那个进程
  • Pandas 在特定列将数据帧拆分为两个数据帧

    I have pandas我组成的 DataFrameconcat 一行由 96 个值组成 我想将 DataFrame 从值 72 中分离出来 这样 一行的前 72 个值存储在 Dataframe1 中 接下来的 24 个值存储在 Data
  • Google App Engine 中的自定义身份验证

    有谁知道或知道我可以在哪里学习如何使用 Python 和 Google App Engine 创建自定义身份验证流程 我不想使用 Google 帐户进行身份验证 并且希望能够创建自己的用户 如果不是专门针对 Google App Engin
  • 如何在SqlAlchemy中执行“左外连接”

    我需要执行这个查询 select field11 field12 from Table 1 t1 left outer join Table 2 t2 ON t2 tbl1 id t1 tbl1 id where t2 tbl2 id is

随机推荐

  • 为什么通过引用捕获变量的 lambda 不能转换为函数指针?

    如果我有一个通过引用捕获所有自动变量的 lambda 为什么不能转换为函数指针呢 常规函数可以像通过引用捕获所有内容的 lambda 一样修改变量 那么为什么不一样呢 换句话说 我想 lambda 和 a 之间的功能区别是什么 捕获列表和常
  • 使用 Laravel 进行 Flutter FCM

    我正在使用 Laravel 作为我的应用程序后端 并希望按主题向我的 flutter 应用程序发送推送通知 现在我在我的 flutter 应用程序中实现了 firebase 消息传递 作为 registerOnFirebase fireba
  • 立即运行 Jenkins 作业

    我有一个非常轻量级的作业 应该在触发时立即执行 而不是等待一个小时才能完成当前作业 据我了解 一个蝇量级任务就是我想要的 它将创建一个临时执行器 专门用于该任务 我怎样才能让一个工作作为蝇量级运行 我最近也遇到了同样的问题 我的公司有很多
  • 我应该如何折叠 Python 中的元素? [复制]

    这个问题在这里已经有答案了 例如 l a 1 b 2 a 2 collapsed l dict a 1 2 b 2 如何最好地从l to collapsed l 从某种意义上说 我想要某种方式来概括我正在崩溃的 领域 以及哪个领域 我认为这
  • eslint 禁用扩展覆盖

    如果你有一个覆盖 你想 降级 js解析器 你如何关闭extends来自父母 parserOptions很容易被覆盖 因为它是基于密钥的 extends因为空数组不执行任何操作 因为它尝试将空列表附加到原始数组 如果您将其设置为null 您会
  • Zend Framework notEmpty 验证器 setRequired

    我看过其他的问题 https stackoverflow com questions 3871460 zend form setrequiredtrue or addvalidatornotempty 谷歌搜索这个 我的问题是 当我提交带有
  • std::variant 在 MSVC 和 gcc 中的行为不同

    Update 这是一个 C 标准缺陷 已在 C 20 P0608R3 中修复 另外 VS 2019 16 10 修复了这个错误 std c 20 MSVC 19 28 拒绝以下代码 但 gcc 10 2 接受它并输出true false i
  • 如何以编程方式为 UINavigationController 子类化 UINavigationBar?

    我正在使用自定义的drawRect函数来绘制UINavigationBar在我的 iOS4 应用程序中 它不使用图像 仅使用 CoreGraphics 因为你不能在中实现drawRectUINavigationBariOS5 中的类别 Ap
  • JavaScript 唯一浏览器 ID

    有没有办法在javascript中为浏览器创建一个唯一的ID 我说的不是每次生成时都是随机的 ID 而是生成该 ID 的浏览器所特有的 ID 而且还考虑了运行该 ID 的计算机 Example Windows 7 Chrome 可能会生成
  • Ionic Zip 仅提取特定文件夹

    我有一个案例 我需要使用 C Ionic zip 库提取 Zip 文件 Zip 文件包含多个文件夹 我想提取特定文件夹并将其复制到特定目的地 例如名为 abc zip 的 Zip 文件和目录结构如下 父目录 gt 子目录1 gt 文件a 文
  • 从文档大纲(书签)中获取页码

    我正在使用 itext7 库来操作一些现有的 PDF 由于某种原因 我无法从大纲中获取页码 我想我应该以某种方式从Pdf目的地 http itextsupport com apidocs itext7 latest com itextpdf
  • 每天都会对 Java 8 Stream API 中的实体进行惰性排序吗?

    我有一个很大的 Java 8 Stream Stream
  • 字符串的哈希函数

    我正在用 C 语言研究哈希表 并且正在测试字符串的哈希函数 我尝试的第一个功能是添加 ascii 代码并使用模 100 但我的第一次数据测试结果很差 130 个单词有 40 次碰撞 最终输入数据将包含 8000 个单词 它是存储在文件中的字
  • 如何显示R中两个日期之间发生的事件

    我的问题看起来很简单 我希望如此 我有一个数据框 其中包含疾病诊断日期 指示患者服用哪种药物 或暴露和未暴露组 的二元变量 药物的开始和停止日期以及总体停止日期 ID Diag date Treatment End date Drug st
  • c3p0中的资源无法检出的原因是什么?

    因此 我正在研究 c3p0 API 来调试我们的一个生产问题 该问题导致在检查连接时出现堆栈溢出错误 我发现下面的评论BasicResourcePool班级的checkoutResource method This function rec
  • Vuetify 标准设置(babel/eslint)图像加载失败

    我正在开发一个 VueJS 项目 并尝试在轮播上加载图像 我正在使用标准设置并将图像放在资产文件夹中 我引用图像 URL
  • Storybook 需要导出默认的 Ant Design 组件才能应用样式

    我希望使用 Ant Design 设计一些 React 组件 并将它们记录在 Storybook 中 故事书和组件都编写正确且有效 模态故事 js import React from react import action from sto
  • python中具有相同名称的对象引用不同的id

    在下面的代码片段中 两个对象名为div在第 1 行和第 2 行创建 python如何区分两者div在同一作用域下创建的对象 When id 应用于两个对象 对于相似的命名对象会显示两个不同的地址 为什么会这样呢 def div a b re
  • webclient 方法对我的 Silverlight 应用程序不可用

    尝试用 C 进行基本的 Web 客户端数据拉取 这些方法在 Visualstudio 中不可用 并且代码无法编译 snip WebClient client new WebClient byte resp client DownloadDa
  • Pytorch:交叉熵损失中的权重

    我试图通过一个实际的例子来理解 CrossEntropyLoss 中的权重是如何工作的 所以我首先运行标准 PyTorch 代码 然后手动运行 但损失并不相同 from torch import nn import torch softma