有监督对比loss计算

2023-11-13

https://blog.csdn.net/wf19971210/article/details/116715880

关于对比损失

  无监督对比损失,通常视数据增强后的图像与原图像互为正例。而对于有监督对比损失来说,可以将同一batch中标签相同的视为正例,与它不同标签的视为负例。对比学习能够使得同类更近,不同类更远。有监督对比损失公式如下。

有监督对比损失数学公式

Pytorch实现有监督对比损失

  话不多说,直接看代码。为了更好的说明有监督对比损失的整个实现过程,以下代码没有经过系统整理,从一个例子,一步一步地计算出损失。若是理解了每一步,那系统整理应该没什么问题。

1.通过cos计算相似度

import torch
import torch.nn.functional as F
T = 0.5  #温度参数T
label = torch.tensor([1,0,1,0,1])
n = label.shape[0]  # batch
#假设我们的输入是5 * 3  5是batch,3是句向量
representations = torch.tensor([[1, 2, 3],[1.2, 2.2, 3.3],
                                [1.3, 2.3, 4.3],[1.5, 2.6, 3.9],
                                [5.1, 2.1, 3.4]])
#这步得到它的相似度矩阵
similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
#这步得到它的label矩阵,相同label的位置为1
similarity_matrix = torch.exp(similarity_matrix/T)
print('similarity_matrix is *****')
print(similarity_matrix)

  结果

similarity_matrix is *****
tensor([[7.3891, 7.3851, 7.3241, 7.3777, 4.9964],
        [7.3851, 7.3891, 7.3172, 7.3872, 5.1341],
        [7.3241, 7.3172, 7.3891, 7.3079, 4.9291],
        [7.3777, 7.3872, 7.3079, 7.3891, 5.2278],
        [4.9964, 5.1341, 4.9291, 5.2278, 7.3891]])

2.创建各种mask

mask = torch.ones_like(similarity_matrix) * (label.expand(n, n).eq(label.expand(n, n).t())) - torch.eye(n, n )
#这步得到它的不同类的矩阵,不同类的位置为1
mask_no_sim = torch.ones_like(mask) - mask
#这步产生一个对角线全为0的,其他位置为1的矩阵
mask_dui_jiao_0 = torch.ones(n ,n) - torch.eye(n, n )
#这步给相似度矩阵求exp,并且除以温度参数T
print('mask is *****')
print(mask)

print('mask_no_sim is *****')
print(mask_no_sim)

print('mask_dui_jiao_0 is *****')
print(mask_dui_jiao_0)

结果为

mask is *****
tensor([[0., 0., 1., 0., 1.],
        [0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [1., 0., 1., 0., 0.]])
mask_no_sim is *****
tensor([[1., 1., 0., 1., 0.],
        [1., 1., 1., 0., 1.],
        [0., 1., 1., 1., 0.],
        [1., 0., 1., 1., 1.],
        [0., 1., 0., 1., 1.]])
mask_dui_jiao_0 is *****
tensor([[0., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1.],
        [1., 1., 0., 1., 1.],
        [1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 0.]])

3.相应创建各种矩阵

#这步将相似度矩阵的对角线上的值全置0,因为对比损失不需要自己与自己的相似度
similarity_matrix = similarity_matrix*mask_dui_jiao_0
print('similarity_matrix is *******')
print(similarity_matrix)

#这步产生了相同类别的相似度矩阵,标签相同的位置保存它们的相似度,其他位置都是0,对角线上也为0
sim = mask*similarity_matrix
print('sim is ')
print(sim)

#用原先的对角线为0的相似度矩阵减去相同类别的相似度矩阵就是不同类别的相似度矩阵
no_sim = similarity_matrix - sim
print('no_sim is ')
print(no_sim)
#把不同类别的相似度矩阵按行求和,得到的是对比损失的分母(还差一个与分子相同的那个相似度,后面会加上)
no_sim_sum = torch.sum(no_sim , dim=1)

结果为

similarity_matrix is *******
tensor([[0.0000, 7.3851, 7.3241, 7.3777, 4.9964],
        [7.3851, 0.0000, 7.3172, 7.3872, 5.1341],
        [7.3241, 7.3172, 0.0000, 7.3079, 4.9291],
        [7.3777, 7.3872, 7.3079, 0.0000, 5.2278],
        [4.9964, 5.1341, 4.9291, 5.2278, 0.0000]])
sim is 
tensor([[0.0000, 0.0000, 7.3241, 0.0000, 4.9964],
        [0.0000, 0.0000, 0.0000, 7.3872, 0.0000],
        [7.3241, 0.0000, 0.0000, 0.0000, 4.9291],
        [0.0000, 7.3872, 0.0000, 0.0000, 0.0000],
        [4.9964, 0.0000, 4.9291, 0.0000, 0.0000]])
no_sim is 
tensor([[0.0000, 7.3851, 0.0000, 7.3777, 0.0000],
        [7.3851, 0.0000, 7.3172, 0.0000, 5.1341],
        [0.0000, 7.3172, 0.0000, 7.3079, 0.0000],
        [7.3777, 0.0000, 7.3079, 0.0000, 5.2278],
        [0.0000, 5.1341, 0.0000, 5.2278, 0.0000]])

4.计算分母的矩阵

'''
将上面的矩阵扩展一下,再转置,加到sim(也就是相同标签的矩阵上),然后再把sim矩阵与sim_num矩阵做除法。
至于为什么这么做,就是因为对比损失的分母存在一个同类别的相似度,就是分子的数据。做了除法之后,就能得到
每个标签相同的相似度与它不同标签的相似度的值,它们在一个矩阵(loss矩阵)中。
'''
no_sim_sum_expend = no_sim_sum.repeat(n, 1).T
print('no_sim_sum_expend is ')
print(no_sim_sum_expend)
sim_sum  = sim + no_sim_sum_expend

结果为

no_sim_sum_expend is 
tensor([[14.7628, 14.7628, 14.7628, 14.7628, 14.7628],
        [19.8363, 19.8363, 19.8363, 19.8363, 19.8363],
        [14.6251, 14.6251, 14.6251, 14.6251, 14.6251],
        [19.9134, 19.9134, 19.9134, 19.9134, 19.9134],
        [10.3618, 10.3618, 10.3618, 10.3618, 10.3618]])

5.计算对比loss

loss = torch.div(sim , sim_sum)
    '''
    由于loss矩阵中,存在0数值,那么在求-log的时候会出错。这时候,我们就将loss矩阵里面为0的地方
    全部加上1,然后再去求loss矩阵的值,那么-log1 = 0 ,就是我们想要的。
    '''
    loss = mask_no_sim + loss + torch.eye(n, n )
    #接下来就是算一个批次中的loss了
    loss = -torch.log(loss)  #求-log
    #loss = torch.sum(torch.sum(loss, dim=1) )/(2*n)  #将所有数据都加起来除以2n
    #print(loss)  #0.9821
    #最后一步也可以写为---建议用这个, (len(torch.nonzero(loss)))表示一个批次中样本对个数的一半
    loss = torch.sum(torch.sum(loss, dim=1)) / (len(torch.nonzero(loss)))
    

6.完整的计算

def sup_constrive(representations, label,T):
    n = label.shape[0]
    similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
    #这步得到它的label矩阵,相同label的位置为1
    mask = torch.ones_like(similarity_matrix) * (label.expand(n, n).eq(label.expand(n, n).t())) - torch.eye(n, n)
    
    #这步得到它的不同类的矩阵,不同类的位置为1
    mask_no_sim = torch.ones_like(mask) - mask
    #这步产生一个对角线全为0的,其他位置为1的矩阵
    mask_dui_jiao_0 = torch.ones(n ,n) - torch.eye(n, n )
    #这步给相似度矩阵求exp,并且除以温度参数T
    similarity_matrix = torch.exp(similarity_matrix/T)
    #这步将相似度矩阵的对角线上的值全置0,因为对比损失不需要自己与自己的相似度
    similarity_matrix = similarity_matrix*mask_dui_jiao_0
    #这步产生了相同类别的相似度矩阵,标签相同的位置保存它们的相似度,其他位置都是0,对角线上也为0
    sim = mask*similarity_matrix
    #用原先的对角线为0的相似度矩阵减去相同类别的相似度矩阵就是不同类别的相似度矩阵
    no_sim = similarity_matrix - sim
    #把不同类别的相似度矩阵按行求和,得到的是对比损失的分母(还差一个与分子相同的那个相似度,后面会加上)
    no_sim_sum = torch.sum(no_sim , dim=1)
    '''
    将上面的矩阵扩展一下,再转置,加到sim(也就是相同标签的矩阵上),然后再把sim矩阵与sim_num矩阵做除法。
    至于为什么这么做,就是因为对比损失的分母存在一个同类别的相似度,就是分子的数据。做了除法之后,就能得到
    每个标签相同的相似度与它不同标签的相似度的值,它们在一个矩阵(loss矩阵)中。
    '''
    no_sim_sum_expend = no_sim_sum.repeat(n, 1).T
    sim_sum  = sim + no_sim_sum_expend
    loss = torch.div(sim , sim_sum)
    '''
    由于loss矩阵中,存在0数值,那么在求-log的时候会出错。这时候,我们就将loss矩阵里面为0的地方
    全部加上1,然后再去求loss矩阵的值,那么-log1 = 0 ,就是我们想要的。
    '''
    loss = mask_no_sim + loss + torch.eye(n, n )
    #接下来就是算一个批次中的loss了
    loss = -torch.log(loss)  #求-log
    #loss = torch.sum(torch.sum(loss, dim=1) )/(2*n)  #将所有数据都加起来除以2n
    #print(loss)  #0.9821
    #最后一步也可以写为---建议用这个, (len(torch.nonzero(loss)))表示一个批次中样本对个数的一半
    loss = torch.sum(torch.sum(loss, dim=1)) / (len(torch.nonzero(loss)))
    
    return loss

x = torch.rand(8,64)
label = torch.tensor([0,2,3,2,1,1,3,1])
sup_constrive(x, label,T=0.1)

大致实现过程就是这样,如果有什么问题可以随时提出。或者有什么更好的实现方法,也欢迎共享。若你要使用该损失发文章,请引用:

  “Chen, L., Wang, F., Yang, R. et al. Representation learning from noisy user-tagged data for sentiment classification. Int. J. Mach. Learn. & Cyber. (2022). https://doi.org/10.1007/s13042-022-01622-7

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

有监督对比loss计算 的相关文章

随机推荐

  • python函数闭包

    闭包 函数的闭包与函数的嵌套类似 它返回的不是一个值 而是一个函数 也就是说在函数内定义函数 如加法函数 def sum a def add b return a b 内部函数add 引用了外部函数sum 的变量a return add 外
  • echarts自定义X轴、Y轴间距

    echarts自定义X轴 Y轴间距 1 自定义间距 1 自定义间距 最近做一个项目 要求x y 轴间距自定义 因为项目数据X轴为时间轴 Y轴为对数数据轴 由于x轴的时间轴各段时间点返回密度不均匀 所以一开始用interval 官网上spli
  • 面试——软件测试

    自我介绍 Web与app测试的区别 App与小程序测试的区别 小程序的兼容性测试怎么测 小程序测试需要分别测试Android和iOS吗 还是怎么测试 Android小程序和iOS小程序的测试区别 测试流程 介绍一下项目 公司有几个测试 秒杀
  • maven 同时配置远程仓库和中央仓库的方法 mirroOf 标签意义

    问题描述 在公司内做maven项目开发时使用的都是公司内部搭建的私有远程仓库做项目开发 所以导致setting文件的设置如下
  • 华为OD机试 - 用连续自然数之和来表达整数(Java)

    题目描述 一个整数可以由连续的自然数之和来表示 给定一个整数 计算该整数有几种连续自然数之和的表达式 且打印出每种表达式 输入描述 一个目标整数T 1 lt T lt 1000 输出描述 该整数的所有表达式和表达式的个数 如果有多种表达式
  • 【vue3+elementplus】el-table的操作列使用子组件渲染按钮,按钮权限改变,父给子传值,子组件的dom不更新的解决方案

    起初是因为我使用了这个回答里面的组件去渲染表格操作列 需求 点击某个按钮 表格数据改变 按钮的权限也随着该数据变化而变化 问题 表格行数据变了 给子组件传的值也变了 在watch中也监听了 但是子组件的dom就是不更新 原因 重新获取表格数
  • 单键控制单片机电源开关电路

    原文地址 http www jichudianlu com archives 168 相关文章 1 问答 单片机控制电源开关 https bbs elecfans com jishu 1698980 1 1 html 2 由MCU控制的开关
  • 野火 RT1052 移植网卡功能(LAN8720A)

    野火 RT1052 移植网卡功能 LAN8720A 开发环境 RT Thread v4 0 2 master SOC i MX RT1050 Board 野火 RT1052 目的 在 RT Thread 系统上进行网络通讯 背景描述 1 首
  • 一维随机变量的常见分布、期望、方差及其性质与推导过程

    文章目录 必须知道的概率论知识 一维变量 离散随机变量 def 常见分布 几何分布 期望 方差 二项分布 b n p 期望 方差 泊松分布 P
  • 小小圣诞树来了

    作者 小刘在这里 每天分享云计算网络运维课堂笔记 疫情之下 你我素未谋面 但你一定要平平安安 一 起努力 共赴美好人生 夕阳下 是最美的 绽放 愿所有的美好 再疫情结束后如约而至 目录 圣诞树 一 代码 圣诞树 一 代码 import tu
  • postgres之jsonb属性的简单操作

    jsonb的一些简单操作 增删改查 更新操作 attributes属性为jsonb类型 方法定义 jsonb set target jsonb path text new value jsonb create missing boolean
  • MySql 笔记

    数据结构 B TREE 二叉树 顺序增长依次查询效率低 红黑树 数据多了深度越深 效率自然低了 HASH 查询条件限制 B TREE 度 degree 节段的数据存储个数 叶节点具有 相同的深度 叶节点的指针为空 节点的数据key从左到右递
  • vue3 多种方法的锚点定位

    在 Vue 3 中 可以通过多种方式实现锚点定位 包括使用原生的 JavaScript 方法和利用 Vue Router 提供的导航守卫等 下面我会分别介绍这些方法 1 使用原生 JavaScript 方法 在 Vue 3 中 你可以使用
  • 【Hadoop生态圈】7.离线OLAP引擎Hive入门教程

    文章目录 1 简介 2 架构分析 3 环境准备 4 使用客户端工具操作hive 4 1 数据库操作 4 2 DDL操作 4 2 1 创建表 4 2 2 导入数据到hive表中 4 2 3 指定列和行分隔符创建表 4 2 4 数据类型 4 3
  • [已解决]jeesite生成页面的弹窗问题

    jeesite生成的页面如需弹窗layer写法会有问题 actions push a href class btnList title i class fa fa check i a nbsp data confirm text 提示信息
  • ansible安装nginx

    ansible安装nginx 定义一个ansible组 把nginx tar包传到ansible主机 ansible 组名 m shell a yum y install pcre devel open devel gcc gcc c ng
  • Golang 单元测试详尽指引

    文末有彩蛋 作者 yukkizhang 腾讯 CSIG 专项技术测试工程师 本篇文章站在测试的角度 旨在给行业平台乃至其他团队的开发同学 进行一定程度的单元测试指引 让其能够快速的明确单元测试的方式方法 本文主要从单元测试出发 对Golan
  • IntelliJ IDEA 进行js Debug调试

    idea的js调试目前看来不同给力 一是玩转它需要安装谷歌插件支持 二是貌似存在一些bug 一 新建一个jsp并打上断点 二 调试 idea出现提示 安装JetBrains IDE Support支持 问题出现了 点击其中连接却一直连不上
  • [其他]IDEA中Maven项目配置国内源

    配置国内源主要解决了 在maven项目中pom xml下载jar包失败或过慢的问题 在IDEA中的设置分成两种 设置当前项目与新创项目 我们就需要两种都进行设置 不然只有在当前项目配置了国内源 新创项目的时候还是默认的状态 由于下面两种设置
  • 有监督对比loss计算

    https blog csdn net wf19971210 article details 116715880 关于对比损失 无监督对比损失 通常视数据增强后的图像与原图像互为正例 而对于有监督对比损失来说 可以将同一batch中标签相同