【偷偷卷死小伙伴Pytorch20天-day16-损失函数】

2023-10-26

最近这几天忙着开学返校的事情,终于几番周折回到了学校,继续pytorch的学习打卡!!

一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective = Loss + Regularization)

Pytorch中的损失函数一般在训练模型时候指定。

注意Pytorch中内置的损失函数的参数和tensorflow不同,是y_pred在前,y_true在后,而Tensorflow是y_true在前,y_pred在后。

对于回归模型,通常使用的内置损失函数是均方损失函数nn.MSELoss 。

对于二分类模型,通常使用的是二元交叉熵损失函数nn.BCELoss (输入已经是sigmoid激活函数之后的结果)
或者 nn.BCEWithLogitsLoss (输入尚未经过nn.Sigmoid激活函数) 。

对于多分类模型,一般推荐使用交叉熵损失函数 nn.CrossEntropyLoss。
(y_true需要是一维的,是类别编码。y_pred未经过nn.Softmax激活。)

此外,如果多分类的y_pred经过了nn.LogSoftmax激活,可以使用nn.NLLLoss损失函数(The negative log likelihood loss)。
这种方法和直接使用nn.CrossEntropyLoss等价。

如果有需要,也可以自定义损失函数,自定义损失函数需要接收两个张量y_pred,y_true作为输入参数,并输出一个标量作为损失函数值。

Pytorch中的正则化项一般通过自定义的方式和损失函数一起添加作为目标函数。

如果仅仅使用L2正则化,也可以利用优化器的weight_decay参数来实现相同的效果。

一、内置损失函数

import numpy as np
import pandas as pd
import torch 
from torch import nn 
import torch.nn.functional as F 


y_pred = torch.tensor([[10.0,0.0,-10.0],[8.0,8.0,8.0]])
y_true = torch.tensor([0,2])

# 直接调用交叉熵损失
ce = nn.CrossEntropyLoss()(y_pred,y_true)
print(ce)

# 等价于先计算nn.LogSoftmax激活,再调用NLLLoss
y_pred_logsoftmax = nn.LogSoftmax(dim = 1)(y_pred)
nll = nn.NLLLoss()(y_pred_logsoftmax,y_true)
print(nll)

在这里插入图片描述

内置的损失函数一般有类的实现和函数的实现两种形式。

如:nn.BCE 和 F.binary_cross_entropy 都是二元交叉熵损失函数,前者是类的实现形式,后者是函数的实现形式。

实际上类的实现形式通常是调用函数的实现形式并用nn.Module封装后得到的。

一般我们常用的是类的实现形式。它们封装在torch.nn模块下,并且类名以Loss结尾。

常用的一些内置损失函数说明如下。

  • nn.MSELoss(均方误差损失,也叫做L2损失,用于回归)

  • nn.L1Loss (L1损失,也叫做绝对值误差损失,用于回归)

  • nn.SmoothL1Loss (平滑L1损失,当输入在-1到1之间时,平滑为L2损失,用于回归)

  • nn.BCELoss (二元交叉熵,用于二分类,输入已经过nn.Sigmoid激活,对不平衡数据集可以用weigths参数调整类别权重)

  • nn.BCEWithLogitsLoss (二元交叉熵,用于二分类,输入未经过nn.Sigmoid激活)

  • nn.CrossEntropyLoss (交叉熵,用于多分类,要求label为稀疏编码,输入未经过nn.Softmax激活,对不平衡数据集可以用weigths参数调整类别权重)

  • nn.NLLLoss (负对数似然损失,用于多分类,要求label为稀疏编码,输入经过nn.LogSoftmax激活)

  • nn.CosineSimilarity(余弦相似度,可用于多分类)

  • nn.AdaptiveLogSoftmaxWithLoss (一种适合非常多类别且类别分布很不均衡的损失函数,会自适应地将多个小类别合成一个cluster)

二、自定义损失函数

自定义损失函数接收两个张量y_pred,y_true作为输入参数,并输出一个标量作为损失函数值。

也可以对nn.Module进行子类化,重写forward方法实现损失的计算逻辑,从而得到损失函数的类的实现。

下面是一个Focal Loss的自定义实现示范。Focal Loss是一种对binary_crossentropy的改进损失函数形式。

它在样本不均衡和存在较多易分类的样本时相比binary_crossentropy具有明显的优势。

它有两个可调参数,alpha参数和gamma参数。其中alpha参数主要用于衰减负样本的权重,gamma参数主要用于衰减容易训练样本的权重。

从而让模型更加聚焦在正样本和困难样本上。这就是为什么这个损失函数叫做Focal Loss。

class FocalLoss(nn.Module):
    
    def __init__(self,gamma=2.0,alpha=0.75):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self,y_pred,y_true):
        bce = torch.nn.BCELoss(reduction = "none")(y_pred,y_true)
        p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
        alpha_factor = y_true * self.alpha + (1 - y_true) * (1 - self.alpha)
        modulating_factor = torch.pow(1.0 - p_t, self.gamma)
        loss = torch.mean(alpha_factor * modulating_factor * bce)
        return loss
#困难样本
y_pred_hard = torch.tensor([[0.5],[0.5]])
y_true_hard = torch.tensor([[1.0],[0.0]])

#容易样本
y_pred_easy = torch.tensor([[0.9],[0.1]])
y_true_easy = torch.tensor([[1.0],[0.0]])

focal_loss = FocalLoss()
bce_loss = nn.BCELoss()

print("focal_loss(hard samples):", focal_loss(y_pred_hard,y_true_hard))
print("bce_loss(hard samples):", bce_loss(y_pred_hard,y_true_hard))
print("focal_loss(easy samples):", focal_loss(y_pred_easy,y_true_easy))
print("bce_loss(easy samples):", bce_loss(y_pred_easy,y_true_easy))

#可见 focal_loss让容易样本的权重衰减到原来的 0.0005/0.1054 = 0.00474
#而让困难样本的权重只衰减到原来的 0.0866/0.6931=0.12496

# 因此相对而言,focal_loss可以衰减容易样本的权重。

在这里插入图片描述
FocalLoss的使用完整范例可以参考下面中自定义L1和L2正则化项中的范例,该范例既演示了自定义正则化项的方法,也演示了FocalLoss的使用方法。

三、自定义L1和L2正则化项

通常认为L1 正则化可以产生稀疏权值矩阵,即产生一个稀疏模型,可以用于特征选择。

而L2 正则化可以防止模型过拟合(overfitting)。一定程度上,L1也可以防止过拟合。

下面以一个二分类问题为例,演示给模型的目标函数添加自定义L1和L2正则化项的方法。

这个范例同时演示了上一个部分的FocalLoss的使用。

1.准备数据

import numpy as np 
import pandas as pd 
from matplotlib import pyplot as plt
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader,TensorDataset
import torchkeras 
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

#正负样本数量
n_positive,n_negative = 200,6000

#生成正样本, 小圆环分布
r_p = 5.0 + torch.normal(0.0,1.0,size = [n_positive,1]) 
theta_p = 2*np.pi*torch.rand([n_positive,1])
Xp = torch.cat([r_p*torch.cos(theta_p),r_p*torch.sin(theta_p)],axis = 1)
Yp = torch.ones_like(r_p)

#生成负样本, 大圆环分布
r_n = 8.0 + torch.normal(0.0,1.0,size = [n_negative,1]) 
theta_n = 2*np.pi*torch.rand([n_negative,1])
Xn = torch.cat([r_n*torch.cos(theta_n),r_n*torch.sin(theta_n)],axis = 1)
Yn = torch.zeros_like(r_n)

#汇总样本
X = torch.cat([Xp,Xn],axis = 0)
Y = torch.cat([Yp,Yn],axis = 0)


#可视化
plt.figure(figsize = (6,6))
plt.scatter(Xp[:,0],Xp[:,1],c = "r")
plt.scatter(Xn[:,0],Xn[:,1],c = "g")
plt.legend(["positive","negative"]);

在这里插入图片描述

ds = TensorDataset(X,Y)

ds_train,ds_valid = torch.utils.data.random_split(ds,[int(len(ds)*0.7),len(ds)-int(len(ds)*0.7)])
dl_train = DataLoader(ds_train,batch_size = 100,shuffle=True,num_workers=2)
dl_valid = DataLoader(ds_valid,batch_size = 100,num_workers=2)

2.定义模型

class DNNModel(torchkeras.Model):
    def __init__(self):
        super(DNNModel, self).__init__()
        self.fc1 = nn.Linear(2,4)
        self.fc2 = nn.Linear(4,8) 
        self.fc3 = nn.Linear(8,1)
        
    def forward(self,x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        y = nn.Sigmoid()(self.fc3(x))
        return y
        
model = DNNModel()

model.summary(input_shape =(2,))

在这里插入图片描述

3.训练模型

# 准确率
def accuracy(y_pred,y_true):
    y_pred = torch.where(y_pred>0.5,torch.ones_like(y_pred,dtype = torch.float32),
                      torch.zeros_like(y_pred,dtype = torch.float32))
    acc = torch.mean(1-torch.abs(y_true-y_pred))
    return acc

# L2正则化
def L2Loss(model,alpha):
    l2_loss = torch.tensor(0.0, requires_grad=True)
    for name, param in model.named_parameters():
        if 'bias' not in name: #一般不对偏置项使用正则
            l2_loss = l2_loss + (0.5 * alpha * torch.sum(torch.pow(param, 2)))
    return l2_loss

# L1正则化
def L1Loss(model,beta):
    l1_loss = torch.tensor(0.0, requires_grad=True)
    for name, param in model.named_parameters():
        if 'bias' not in name:
            l1_loss = l1_loss +  beta * torch.sum(torch.abs(param))
    return l1_loss

# 将L2正则和L1正则添加到FocalLoss损失,一起作为目标函数
def focal_loss_with_regularization(y_pred,y_true):
    focal = FocalLoss()(y_pred,y_true) 
    l2_loss = L2Loss(model,0.001) #注意设置正则化项系数
    l1_loss = L1Loss(model,0.001)
    total_loss = focal + l2_loss + l1_loss
    return total_loss

model.compile(loss_func =focal_loss_with_regularization,
              optimizer= torch.optim.Adam(model.parameters(),lr = 0.01),
             metrics_dict={"accuracy":accuracy})

dfhistory = model.fit(30,dl_train = dl_train,dl_val = dl_valid,log_step_freq = 30)

在这里插入图片描述

# 结果可视化
fig, (ax1,ax2) = plt.subplots(nrows=1,ncols=2,figsize = (12,5))
ax1.scatter(Xp[:,0],Xp[:,1], c="r")
ax1.scatter(Xn[:,0],Xn[:,1],c = "g")
ax1.legend(["positive","negative"]);
ax1.set_title("y_true");

Xp_pred = X[torch.squeeze(model.forward(X)>=0.5)]
Xn_pred = X[torch.squeeze(model.forward(X)<0.5)]

ax2.scatter(Xp_pred[:,0],Xp_pred[:,1],c = "r")
ax2.scatter(Xn_pred[:,0],Xn_pred[:,1],c = "g")
ax2.legend(["positive","negative"]);
ax2.set_title("y_pred");

在这里插入图片描述

四、通过优化器实现L2正则化

如果仅仅需要使用L2正则化,那么也可以利用优化器的weight_decay参数来实现。

weight_decay参数可以设置参数在训练过程中的衰减,这和L2正则化的作用效果等价。

before L2 regularization:  

gradient descent: w = w - lr * dloss_dw   

after L2 regularization:  

gradient descent: w = w - lr * (dloss_dw+beta*w) = (1-lr*beta)*w - lr*dloss_dw  

so (1-lr*beta)is the weight decay ratio.

Pytorch的优化器支持一种称之为Per-parameter options的操作,就是对每一个参数进行特定的学习率,权重衰减率指定,以满足更为细致的要求。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

【偷偷卷死小伙伴Pytorch20天-day16-损失函数】 的相关文章

随机推荐

  • 回路电感详细介绍(环路电感)

    相比于硬件工程师 PCB工程师对环路电感更敏感 因为环路电感和走线强相关 不管是信号完整性还是电源完整性都涉及到这个概念 一旦电路结构确定 环路电感也随之确定 如果环路电感初期评估失误将会给后期改版带来巨大风险 更多资料请关注公众号 工程师
  • 操作系统 java模拟主存储器空间的分配和回收

    文章目录 实验原理 算法流程图 代码 结果 实验原理 模拟在可变分区管理方式下采用最先适应算法实现主存分配和回收 1 可变分区方式是按作业需要的主存空间大小来分割分区的 当要装入一个作业时 根据作业需要的主存量查看是否有足够的空闲空间 若有
  • 创建第一个Qt Widget项目

    创建第一个Qt Widget项目步骤 1 选择文件 Ctrl n 2 新建文件或项目 3 Qt Widget Application 4 输入项目名称FirstApplication 选择存储的位置 5 选择构建套件Desktop Qt Q
  • numpy一维数组永远为列向量

    import numpy as np a np array 1 3 4 5 print a shape a np transpose a print a shape print a a np ravel a print a shape pr
  • 静态分析简介

    一 程序静态分析简介 Program Static Analysis 程序静态分析简介 Program Static Analysis 是指在不运行代码的方式下 通过词法分析 语法分析 控制流 数据流分析等技术对程序代码进行扫描 验证代码是
  • 【mysql安装报错(已解决)】ERROR 1045 (28000): Access denied for user ‘root‘@‘localhost‘ (using password: YES)

    1 说在开头 我的 mysql 版本是 8 0 27 的 安装的时候 感觉每一步都没有错 但是就是不行 到连接本地数据库时 发现一直连不上 搞了好久 一直报下面的错 ERROR 1045 28000 Access denied for us
  • 一文带你聊聊MYSQL的锁和MVCC

    如果你觉得内容对你有帮助的话 不如给个赞 鼓励一下更新 本文内容总结自极客时间 MySQL实战45讲 专栏 LBCC 单版本控制 锁 基于锁的并发控制 这种方案比较简单粗暴 就是一个事务去读取一条数据的时候 就上锁 不允许其他事务来操作 当
  • 【Python自动化】生成带装饰图形的渐变背景文字封面

    Python自动化专栏 利用文字生成固定比例且带有装饰图形的封面 文章目录 一 背景介绍 二 功能介绍 效果预览 功能清单 三 过程拆解 1 渐变背景层 2 装饰图形层 3 半透明遮罩层 4 文字层 四 完整代码 参考文档 一 背景介绍 在
  • echarts地图map下钻到镇街、KMZ文件转GeoJson、合成自定义区域

    echarts 地图map下钻到镇街 KMZ文件转GeoJson 合成自定义区域 我们可以通过 http datav aliyun com tools atlas 阿里旗下的高德地图提供的api 可以获取到中国各个省份 区级 县级的json
  • NAT技术的主要实现方式及其对网络应用程序的使用影响

    网络地址转换 NAT 是接入广域网 WLAN 的一种技术 能够将私有 保留 地址转化为合法的IP地址 它被广泛应用于各种类型Internet接入方式和各种类型的网络中 NAT的实现方式有三种 静态转换 动态转换和端口多路复用 静态转换设置起
  • Linux审计与日志安全加固

    审计和日志服务配置 auditctl 审计数据配置 日志文件最大参数 在储存策略 etc audit audit conf 中配置max log file
  • 高德地图精确查找与定位RegeocodeQuery与GeocodeQuery

    根据输入的字符串精确查找位置 用GeocodeQuery查找坐标 然后根据获取到的坐标 用RegeocodeQuery查询地址 例子中用了两个页面 一个是显示地址信息及定位的页面 另一个是搜索页面 点击搜索结果返回显示页面 显示信息并定位
  • iOS经典面试题总结--内存管理

    2019独角兽企业重金招聘Python工程师标准 gt gt gt 我根据自己的情况做了一下总结 答案是我总结的 如有答的不好的地方 希望批评指正以及交流 谢谢 内存管理 1 什么是ARC ARC是automatic reference c
  • 【darknet】2、yolov4模型训练之模型训练

    文章目录 1 进行模型训练数据准备 1 1 划分训练和验证集 1 2 将数据标注格式转换为YOLO格式 2 修改配置文件 2 1 新建cfg vechle names 2 2 新建cfg vechle data 2 3 根据所选模型的不同
  • java连接数据库的Connection中的prepareStatement与createStatement的区别

    这两者的区别主要在于如何构造执行sql语句的对象 1 对于prepareStatement来说 其执行返回的是一个prepareStatement对象 而这个方法的描述是这样的 prepareStatement String sql 创建一
  • 在mac上安装gradle(超详细,直接按步骤操作即可轻松搞定)

    在mac上安装gradle 超详细 直接按步骤操作即可轻松搞定 第一步 就是先download最新版本的gradle 网址如下 http gradle org gradle download 然后将下载下来的zip包放解压到本地任意的路径上
  • input 标签里 value值从数据库读取出来的值显示一半或者没显示原因

    存进数据库的字符如下 读取数据出来显示如下 毒 这家超市被星巴克称为 价格警察 这段话没显示出来 原因 这样出来的是value 比海底捞服务更 毒 这家超市被星巴克称为 价格警察 input value值中的双引号被作为value值的结束符
  • 求二元函数最大值matlab,利用matlab, 二元函数求最大值

    求二元函数 z 0 2323 x 2 0 2866 2 2 0 5406 a0 2 1 0203 a0 2 x 2 x 2 y 2 0 5 tanh 2 x 2 y 2 0 5 x 2 0 5733 u0 2 的最大值 变量x和y都是在0
  • React -css in js框架style-components

    原文 https www jianshu com p 27788be90605 前言 前端飞一般的发展中 衍生出各式各样的框架 框架的目的是减轻开发人员的开发难度 提高效率 以前网页开发的原则是关注点分离 意思是各种技术只负责自己的领域 不
  • 【偷偷卷死小伙伴Pytorch20天-day16-损失函数】

    最近这几天忙着开学返校的事情 终于几番周折回到了学校 继续pytorch的学习打卡 一般来说 监督学习的目标函数由损失函数和正则化项组成 Objective Loss Regularization Pytorch中的损失函数一般在训练模型时