人脸识别损失函数简介与Pytorch实现:ArcFace、SphereFace、CosFace

2023-05-16

总结成一句话就是:在softmax基础上,对最后一层全连接的权重和输入特征进行归一化,重新放缩到半径为s的超平面,增加惩罚的margin训练,使得类型紧凑,类间变得远离

 

# SphereFace
class SphereProduct(nn.Module):
    r"""Implement of large margin cosine distance: :
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        m: margin
        cos(m*theta)
    """

    def __init__(self, in_features, out_features, m=4):
        super(SphereProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.base = 1000.0
        self.gamma = 0.12
        self.power = 1
        self.LambdaMin = 5.0
        self.iter = 0
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform(self.weight)

        # duplication formula
        # 将x\in[-1,1]范围的重复index次映射到y\[-1,1]上
        self.mlambda = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]
        """
        执行以下代码直观了解mlambda
        import matplotlib.pyplot as  plt

        mlambda = [
            lambda x: x ** 0,
            lambda x: x ** 1,
            lambda x: 2 * x ** 2 - 1,
            lambda x: 4 * x ** 3 - 3 * x,
            lambda x: 8 * x ** 4 - 8 * x ** 2 + 1,
            lambda x: 16 * x ** 5 - 20 * x ** 3 + 5 * x
        ]
        x = [0.01 * i for i in range(-100, 101)]
        print(x)
        for f in mlambda:
            plt.plot(x,[f(i) for i in x])
            plt.show()
        """

    def forward(self, input, label):
        # lambda = max(lambda_min,base*(1+gamma*iteration)^(-power))
        self.iter += 1
        self.lamb = max(self.LambdaMin, self.base * (1 + self.gamma * self.iter) ** (-1 * self.power))

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cos_theta = F.linear(F.normalize(input), F.normalize(self.weight))
        cos_theta = cos_theta.clamp(-1, 1)
        cos_m_theta = self.mlambda[self.m](cos_theta)
        theta = cos_theta.data.acos()
        k = (self.m * theta / 3.14159265).floor()
        phi_theta = ((-1.0) ** k) * cos_m_theta - 2 * k
        NormOfFeature = torch.norm(input, 2, 1)

        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cos_theta.size())
        one_hot = one_hot.cuda() if cos_theta.is_cuda else one_hot
        one_hot.scatter_(1, label.view(-1, 1), 1)

        # --------------------------- Calculate output ---------------------------
        output = (one_hot * (phi_theta - cos_theta) / (1 + self.lamb)) + cos_theta
        output *= NormOfFeature.view(-1, 1)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', m=' + str(self.m) + ')'

 

# CosFace
class AddMarginProduct(nn.Module):
    r"""Implement of large margin cosine distance: :
    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        s: norm of input feature
        m: margin
        cos(theta) - m
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.40):
        super(AddMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        phi = cosine - self.m
        # --------------------------- convert label to one-hot ---------------------------
        one_hot = torch.zeros(cosine.size(), device='cuda')
        # one_hot = one_hot.cuda() if cosine.is_cuda else one_hot
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s
        # print(output)

        return output

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', s=' + str(self.s) \
               + ', m=' + str(self.m) + ')'

 

代码实现:

# ArcFace
class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin

            cos(theta + m)
        """

    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        # Parameter 的用途:
        # 将一个不可训练的类型Tensor转换成可以训练的类型parameter
        # 并将这个parameter绑定到这个module里面
        # net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的
        # https://www.jianshu.com/p/d8b77cc02410
        # 初始化权重
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        # torch.nn.functional.linear(input, weight, bias=None)
        # y=x*W^T+b
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        # cos(a+b)=cos(a)*cos(b)-size(a)*sin(b)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            # torch.where(condition, x, y) → Tensor
            # condition (ByteTensor) – When True (nonzero), yield x, otherwise yield y
            # x (Tensor) – values selected at indices where condition is True
            # y (Tensor) – values selected at indices where condition is False
            # return:
            # A tensor of shape equal to the broadcasted shape of condition, x, y
            # cosine>0 means two class is similar, thus use the phi which make it
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        # 将cos(\theta + m)更新到tensor相应的位置中
        one_hot = torch.zeros(cosine.size(), device='cuda')
        # scatter_(dim, index, src)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s
        # print(output)

        return output

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

人脸识别损失函数简介与Pytorch实现:ArcFace、SphereFace、CosFace 的相关文章

随机推荐

  • O2OA人员身份,人员属性

    人员信息创建 从组织管理应用中进入个人管理界面后 xff0c 点击左侧上方的添加按钮 xff0c 如下图所示 xff1a 在右侧显示的界面中填写人员信息 xff1a 人员名称 手机号码 唯一编码 xff08 以上必填 xff0c 其他选填写
  • O2OA的SSO与单点认证

    SSO与单点认证 与其他系统实现单点登入 1 1 URL传递加密参数方式 这种方式是比较通用简单的实现方式 xff0c 应急门户将用户登录信息 xff08 用户ID xff09 以URL参数方式传递给被集成系统 xff0c 被集成系统通过接
  • newman和Jenkins(postname和Jenkins的结合使用)

    Newman介绍 Newman 是 Postman 推出的一个 nodejs 库 xff0c 直接来说就是 Postman 的json文件可以在命令行执行的插件 Newman 可以方便地运行和测试集合 xff0c 并用之构造接口自动化测试和
  • SmartBI入门(一)介绍和安装

    一 SmartBI系统介绍 商业智能 xff08 Business Intelligence xff0c 简称 xff1a BI xff09 xff0c 又称商业智慧或商务智能 xff0c 指用现代数据仓库技术 线上分析处理技术 数据挖掘和
  • SmartBI入门(二)配置SmartBI

    具体可以参考文档 Smartbi Config页面介绍 Smartbi V10帮助中心 SmartBI配置 如果是首次访问 xff0c 需要设置 管理员账号 密码 xff0c 以便下次登录配置界面时验证 xff0c 设置后用用户名密码登录即
  • 求助 关于A-Frame带有动画模型的导入

    哪位大神知道导入带有动画的模型后 如何调用模型自带的动画 gltf格式的
  • SmartBI入门(三)数据源配置

    1 设置数据连接 配置连接 2 选择数据表 创建的数据源 xff0c 点击数据库管理 xff0c 添加实际报表需要的数据表 3 数据库展现
  • 再见2014,你好2015

    过去就是过去了 2014年再见 xff0c 2015年你好 xff01 回首 总结也只是慰藉 1999年12月至今 xff0c 经历了整整十五个曾经 xff0c 这其中的波折 xff0c 怎是我这样的小辈能够理解的 借这个平台也只是为了感谢
  • idea报错 Artifact web:war exploded: Error during artifact deployment. See server log for details.

    因为tomcat把报错信息重定向到日志文件中了 xff0c 所以在控制台找不到报错信息 所以需要看一下tomcat日志文件报错信息 xff0c 一般情况下都有报错 日志的路径默认是在C Users 你的用户名 AppData Local J
  • Unity离线用户手册打开缓慢、卡顿

    Unity中文离线用户手册下载页面 https docs unity3d com cn 2019 4 Manual OfflineDocumentation html 文档包下载地址 xff08 需要FQ xff09 https stora
  • Python MapReduce 案例

    map t py import sys import re p 61 re compile r 39 w 43 39 for line in sys stdin ss 61 line strip split 39 39 for s in s
  • ARM基础学习-快速上下文切换技术

    FCSE的原理 快速上下文切换技术 xff08 FCSE xff09 通过修改系统中不同进程的虚拟地址 xff0c 避免在进行进程间切换时造成虚拟地址到物理地址的重映射 xff0c 从而提高系统性能 xff1b 通常情况下 xff0c 如果
  • linux下使用VNC

    安装 xff1a root 64 localhost vncserver You will require a password to access your desktops Password Verify New 39 localhos
  • svn忽略ignore文件记住方式(转)

    每个项目中的配置文件都有区别 xff0c 在本地开发和线上生产 xff0c 之前一直很懒 xff0c 不想去忽略提交一些配置文件 xff0c 只是在提交的时候排除掉 但是在项目上传部署的时候又必须小心 xff0c 害怕覆盖线上的配置 xff
  • OpenvSwitch完全使用手册

    本文主要参考 Overview of functionality and components 以及 Frequently Asked Questions 以及结合自己的理解 http sdnhub cn index php openv s
  • c++正则表达式regex_match和regex_seach使用

    1 xff1a 区别 regex match用来做全匹配 xff0c 检测字符串是否符合表达式规则 xff0c 如果字符串符合规则 xff0c 返回true xff1b regex seach用来匹配字符串中是否有和规则相匹配的字符串子串
  • ubuntu 14.04 ssh连接失败解决---离线安装openssh_server

    1 在联网机器上下载 deb 包 xff0c 因为我的离线机器为 Ubuntu14 04 64位 所以对应下载 amd64 Trusty 包 xff0c 以下为需要下载的 openssh server 及安装依赖的 deb 包名称 open
  • 使用 js 做一个简单学生系统

    HTML lt DOCTYPE html gt lt html gt lt head gt lt title gt 学生系统 lt title gt lt style type 61 34 text css 34 gt nianji wid
  • C#从基于FTPS的FTP server下载数据 (FtpWebRequest 的使用)SSL 加密

    FTPS xff0c 亦或是FTPES 是FTP协议的一种扩展 xff0c 用于对TLS和SSL协议的支持 本文讲述了如何从一个基于FTPS的Server中下载数据的实例 任何地方 xff0c 如有纰漏 xff0c 欢迎诸位道友指教 话不多
  • 人脸识别损失函数简介与Pytorch实现:ArcFace、SphereFace、CosFace

    总结成一句话就是 xff1a 在softmax基础上 xff0c 对最后一层全连接的权重和输入特征进行归一化 xff0c 重新放缩到半径为s的超平面 xff0c 增加惩罚的margin训练 xff0c 使得类型紧凑 xff0c 类间变得远离