pytorch crossentropy为nan

2023-11-19

pytorch crossentropy为nan

交叉熵损失函数的具体为:

loss = -(x*ln(z)+(1-x)*ln(1-z)) 
z = softmax(pred_x)

这样当z为0/0时会出现loss为nan的情况

本人的具体原因
网络中用了MultiHeadAttention,attention的mask全为0,这样attention就为nan,造成个别样本的输出特征全为nan。于是就自己用pytorch写了一个cross_entropy loss函数,剔除掉有时候个别为nan的样本。
github地址:Self_cross_entropy

参考解决方案

在pred_x上加一个很小的量,如1e-10
loss = crossentropy(out+1e-8, target)
1
采用更小的学习率
做梯度裁剪

pytorch 梯度裁剪

import torch.nn as nn

outputs = model(data)

loss= loss_fn(outputs, target)

optimizer.zero_grad()

loss.backward()

nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)

optimizer.step()

nn.utils.clip_grad_norm_ 的参数:

parameters – 一个基于变量的迭代器,会进行梯度归一化

max_norm – 梯度的最大范数

norm_type – 规定范数的类型,默认为L2


还可能是数据有问题
比如这位的.链接
[参考]

https://stats.stackexchange.com/questions/108381/how-to-avoid-nan-in-using-relu-cross-entropy
 

且新的变量自动全是variable类型,可顺利反向传播
实现好后运行结果出现大量的nan,无法正常运算,使用clamp限制loss计算值的范围
class CrossEntropy(nn.Module):
    def __init__(self):
        super(CrossEntropy, self).__init__()
        
    def forward(self, inputs, targets):
        ## torch中要想实现backward就不能使用np,不能用array,只能使用tensor,只有tensor才有requires_grad参数
        loss1=-targets*(torch.log(inputs)).cuda()
        loss=torch.sum(loss1.clamp(min=0.0001,max=1.0))

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch crossentropy为nan 的相关文章

随机推荐

  • 【Jboss】热部署

    版权声明 本文为博主原创文章 未经博主允许不得转载 https blog csdn net inforstack article details 47681803
  • 蓝桥杯:斐波那契数列最大公约数

    题目表示的很明确 要用两个算法 斐波那契数列是很经典的dp问题 最大公约数是很经典的辗转相除法 从而我理所应当的就定义一个数组存放斐波那契数列 long long int F 2021 0 F 1 1 F 2 1 for int i 3 i
  • 解决SQL Server占用服务器内存过高问题

    最近发现个问题 数据库服务器内存居高不下 64G的内存 几乎被占用100 结果差点把服务器给拖垮了 第一步 打开SQL Server Management Studio 在连接上右键 属性 第二步 内存选项卡 修改最大服务器内存的大小 如下
  • Android Killer的安装和配置 -安卓逆向的必备神器

    图文并茂 详细的不能再详细了 这你总不能学不会吧 都给我学会 安装包我已经放在了文章末尾 需要可以自取哟 1 下载安卓Android killer 首先 我们先下载打包好的Android killer 解压缩后可以得到一大堆没啥用的 文件
  • 用WinHex软件解析FAT32文件系统

    一 工欲善其事 1 准备工作 将一个U盘格式化为FAT32格式 在U盘内创建几个文件 最好是TXT文档 其中至少有一个是长文件 命名较长 2 补充知识 短文件名表示 长文件名表示 Note 当一个文件名为长文件名时 会由几个长文件名表示法和
  • Linux查看当前目录下各文件所占空间

    要查看当前目录下各文件所占空间 可以使用du命令 磁盘使用情况 配合sort命令来实现 以下是在Linux系统中执行的命令 du sh sort hr 解释一下这个命令 du sh 计算当前目录下每个文件和目录的总大小 并以易读的方式显示
  • 给定一个无重复元素的有序整数数组 nums 。 返回 恰好覆盖数组中所有数字 的 最小有序 区间范围列表。也就是说,nums 的每个元素都恰好被某个区间范围所覆盖,并且不存在属于某个范围但不属于 n

    class Solution public List
  • 让dapper支持Oracle

    之前的项目数据使用mssql和mysql ORM使用一个轻量级的dapper 感觉很方便 性能也比EF强 关键是语法灵活 上手容易 用这种框架开发了几个网站 感觉非常好 但新项目要使用oracle 就出问题了 dapper里的关键字 在or
  • 史上最全量化交易资源整理

    有些国外的平台 社区 博客如果连接无法打开 那说明可能需要 科学 上网 国内在线量化平台 BigQuant 你的人工智能量化平台 可以无门槛地使用机器学习 人工智能开发量化策略 基于python 提供策略自动生成器 镭矿 基于量化回测平台
  • PCL 大窗口可视化两个点云

    一 主要参数 viewer gt setFullScreen true 设置点云为全屏显示的2D俯视图 二 代码实现 1 一个大窗口可视化两个点云 include
  • 清除history内容

    history记录是记录在 bash history中的 history c 清除的是当前会话的记录 原来的记录是不会被清除的 可以直接删除 bash history 清空这个文件cat dev null gt bash history
  • c++ - 抽象类 和 多态当中一些问题

    抽象类 纯虚函数 在虚函数的后面写上 0 则这个函数为纯虚函数 class A public virtual void func 0 纯虚函数不需要写函数的定义 他有类似声明一样的结构 抽象类概念 我们把具有纯虚函数的类 叫做抽象类 所谓抽
  • C++ 惯用法之 CRTP

    背景 CRTP 是 一种 C 的设计方法 其巧妙的结合了继承和模板编程技术 可以用来给类提供额外的功能 CRTP 概述 CRTP 的基本特征表现为 基类是一个模板类 派生类在继承该基类时 将派生类自身作为模板参数传递给基类 实现示例 tem
  • ACE命令参数解析

    ACE提供了ACE Get Opt类来处理命令行参数选项 这个类是一个迭代器 用于解析按照自然数方式计数的参数向量 它包装了POSIX的getotp 函数的功能 但是与getopt 函数不同 ACE Get Opt类的每个实例都维护有自己的
  • 【Vue】终极笔记:面试必胜宝典,大厂面试题源码级详解 (持续更新!!!)

    Vue经典面试题源码级详解 1 Vue组件之间通信方式有哪些 分析 思路分析 回答范例 1 组件通信常用方式有以下8种 2 根据组件之间关系讨论组件通信最为清晰有效 2 v if 和 v for哪个优先级更高 分析 思路分析 回答范例 3
  • mysql运行语句时出现 FUNCTION *** does not exist

    我在运行MYSQL时 经常出现这种问题 一阵搜索后 原来问题出现在函数与括号之间的空格上 比如 写成 concat 这样就出错了 需要去掉空格 concat 就好了 资料来源 在这个网址找到方法 http blog 152 org 2009
  • [Spring Boot]08 IDEA接入MyBatisCodeHelper代码自动生成器

    目录 前言 一 插件市场安装插件 二 使用插件自动生成代码 前言 上次介绍了 原生mybatis的方法 06 Spring Boot接入mybatis通用mapper插件自动生成器 这次 再介绍下插件MyBatisCodeHelper Pr
  • P4wnP1 USB与赛门铁克反病毒绕过

    最近 我使用P4wnP1 image把我手头的Raspberry Pi Zero W转换成了一个bad USB 我的最终目标是运行远程命令shell 同时绕过已启用完全保护的最新版Symantec SEP 我通过创建自己的有效负载paylo
  • QGis二次开发 -- 源码编译终极篇

    由于是开源软件 QGis版本迭代比较快 在保持long term release版本的基础上 每个月都会有一个monthly release的新版本发布 源码工程变化快速 给想要上手编译开发的新人朋友带来了一些困惑 我之前分别写过QGis1
  • pytorch crossentropy为nan

    pytorch crossentropy为nan 交叉熵损失函数的具体为 loss x ln z 1 x ln 1 z z softmax pred x 这样当z为0 0时会出现loss为nan的情况 本人的具体原因 网络中用了MultiH