pytorch混合常量、变量

2023-11-16

有矩阵 X ∈ R n × d X\in\R^{n\times d} XRn×d 和指示向量 m ∈ { 0 , 1 } n m\in\{0,1\}^n m{0,1}n,其中 m i = 1 m_i=1 mi=1 指明的行是常量,不可训练,即 requires_grad=False;而 m i = 0 m_i=0 mi=0 对应的行是 learnable 的变量,requires_grad=True(如缺失数据)。此处为此实现一个 wrapper 类,使其调用类似一般 tensor。

preliminaries

  • 验证:constant 和 variable 放在同一个 tensor 里,能否正常计算梯度,即 constant 无梯度、variable 有梯度。
  • 两种写法:concatenating、预分配空间 + copying。
  • 结论:两种都可以
import torch


X = torch.arange(12).view(4, 3).float()
print(X)
mask = torch.tensor([1, 0, 1, 0]).int()
n_var = (0 == mask).sum()

# 常量部分
X_const = X[mask > 0]
print(X_const)

# 变量部分
X_var = torch.normal(0, 1, size=[n_var, X.size(1)])
X_var.requires_grad_(True)
print(X_var)


print("写法 1. grad_fn=<CatBackward>")
ic, ip = 0, 0
X_mix = []
for i in range(X.size(0)):
    if mask[i] > 0:
        X_mix.append(X_const[ic:ic+1])
        ic += 1
    else:
        X_mix.append(X_var[ip:ip+1])
        ip += 1
X_mix = torch.cat(X_mix, dim=0)
print(X_mix)


"""print("写法 2. grad_fn=<CopySlices>")
ic, ip = 0, 0
X_mix = torch.zeros_like(X)
for i in range(X.size(0)):
    if mask[i] > 0:
        X_mix[i] = X_const[ic:ic+1]
        ic += 1
    else:
        X_mix[i] = X_var[ip:ip+1]
        ip += 1
print(X_mix)
"""


loss = ((X - X_mix) ** 2).sum()
loss.backward()
print("--- grad ---")
print(X_const.grad)
print(X_var.grad)
print("--- update ---")
X_var.data -= X_var.grad
print(X_var)

wrapper class & sample

  • MixVar 是 wrapper 类
  • 一个 reconstruction 的例子
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


class MixVar(nn.Module):
    """mixture of constants & trainable variables"""

    def __init__(self, X, const_mask, init_val=None, process_fn=None):
        """
        Input:
            X: [n, d], FULL matrix including both constants & (placeholders of) variables
            const_mask: [n], in {0, 1}, indicating whether the i-th item is constant
            init_val: constant initializer of variables
            process_fn: something to do before returning the var,
                e.g. normalization, activation, etc.
        """
        super(MixVar, self).__init__()
        self.X = X
        self.const_mask = const_mask
        self.process_fn = process_fn
        self.full_indices = np.arange(X.size(0))

        assert X.size(0) == const_mask.size(0)
        n = X.size(0)  # 总数据量,包括 constant 和 variable
        n_const = const_mask.sum()  # constant 数
        n_var = n - n_const  # variable 数
        assert n_var > 0, "* constant only, no need to use this class"
        size = [n_var, X.size(1)]

        # variable 另外放在 `self.weight` 里
        # 注意此时其 indexing 和 constant 已**不同**
        # 所以需要下面的 id map
        if init_val is None:
            self.weight = Parameter(torch.Tensor(*size))
            self.reset_parameters()
        else:
            self.weight = Parameter(init_val * torch.ones(*size, dtype=torch.float))

        # map the full id in `X` to the relative one in `weight`
        _cnt = 0
        self.id_map = {}
        for i in range(n):
            if 0 == const_mask[i]:
                self.id_map[i] = _cnt
                _cnt += 1
        assert _cnt == n_var

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, index=None):
        """MUST use this function for slicing instead of slicing manually"""
        if index is None:
            index = self.full_indices
        res = torch.zeros(index.shape[0], self.X.size(1),
            dtype=self.X.dtype).to(self.weight.device)
        for i in range(index.shape[0]):
            _idx = index[i]
            if self.const_mask[_idx] > 0:
                res[i] = self.X[_idx].to(self.weight.device)
            else:
                res[i] = self.weight[self.id_map[_idx]]

        if self.process_fn:
            res = self.process_fn(res)
        return res

    def extra_repr(self):
        return 'size={}'.format(self.X.size())


# 一个使用例子
X = torch.arange(12).view(6, 2).float()
print("original:\n", X)
mask = torch.tensor([1, 0, 1, 0, 0, 1]).int()
X_mix = MixVar(X, mask)

indices = np.arange(X.size(0))
optimizer = torch.optim.SGD(X_mix.parameters(), lr=0.1)
batch_size = 2
for epoch in range(100):
    for i in range(0, X.size(0), batch_size):
        index = indices[i: i + batch_size]
        loss = F.mse_loss(X[index], X_mix(index))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


print("reconstructed:\n", X_mix().data)
  • 输出
original:
tensor([[ 0.,  1.],
        [ 2.,  3.],
        [ 4.,  5.],
        [ 6.,  7.],
        [ 8.,  9.],
        [10., 11.]])
reconstructed:
tensor([[ 0.0000,  1.0000],
        [ 1.9914,  2.9829],
        [ 4.0000,  5.0000],
        [ 5.9679,  6.9619],
        [ 7.9526,  8.9491],
        [10.0000, 11.0000]])
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

pytorch混合常量、变量 的相关文章

随机推荐

  • 聊一聊基础的CPU寄存器~

    寄存器 CPU内部的存储单元 用于存放从内存读取而来的数据 包括指令 和CPU运算的中间结果 使用寄存器来临时存放数据而不直接操作内存原因如下 CPU的工作原理决定了有些操作只能在CPU内部进行 CPU读写寄存器的速度比读写内存的速度要快很
  • 前端开发实习总结参考范文(合集)

    前端开发实习总结篇一 今天就简单聊聊上面的Struts Spring Hibernate吧 Struts 代表 表示层 Spring代表 业务逻辑层 Hibernate则代表持久层 他们是目前在Java Web编程开发中用得最多的框架 其实
  • 使用hutool读取excel多sheet文件

    首先要使用hutool 可以加载maven
  • 华为手机一直android,华为手机内存不够用?这5个文件夹常清理,可以腾出近10个G内存...

    华为手机的用户量在急剧增加 当然 随时而来的就是许多使用问题 用户反馈最多的就是手机运行问题 手机使用时间一长 就会卡顿 尤其是处理紧急问题时遇到手机怠工 真是没救了 手机卡顿很大程度上是内存问题 平时使用不当造成手机内垃圾信息过多 占用手
  • R语言 第四章 初级绘图(5)课后练习,保存图形,layout函数,绘制组合图形,添加图例

    关注公众号凡花花的小窝 收获更多的考研计算机专业编程相关的资料 添加图例 当图形中包含的数据不止一组时 图例可以帮助你辨别出每个条形 扇形区域或折线各代表哪一类数据 此时 可以使用legend函数来在画布中添加图例 对图形进行相应说明 le
  • nginx root 和alise

    Nginx静态服务配置 详解root和alias指令 简书 jianshu com 静态文件 Nginx以其高性能著称 常用与做前端反向代理服务器 同时nginx也是一个高性能的静态文件服务器 通常都会把应用的静态文件使用nginx处理 配
  • Android下NestedScrolling机制与CoordinatorLayout之源码分析

    1 CoordinatorLayout依赖库 旧版本导入CoordinatorLayout依赖 implementation com android support design 28 0 0 升级Android X后的依赖 impleme
  • STUN和TURN技术浅析

    原文地址 http www h3c com cn MiniSite Technology Circle Net Reptile The Five Home Catalog 201206 747038 97665 0 htm 在现实Inter
  • pytorch中attention的两种实现方式

    class AttnDecoderRNN nn Module def init self hidden size output size dropout p 0 1 max length MAX LENGTH super AttnDecod
  • XMind2TestCase思维导图测试用例转Excel使用方法

    很多测试工程师习惯于用思维导图写测试用例 结构会比较清晰 但是我们通常把思维导图的用例整理至excel或者导入其他工具如禅道 testlink tapd来执行用例或存档 如果再逐条把思维导图转为excel会比较浪费时间 有没有工具可以把思维
  • 数字高程信息30m分辨率SRTM DEM数据下载与拼接(ENVI)

    数据下载 本次下载的数据是SRTMDEM数据 该数据分辨率为30m 可以到官网下载官网地址 http gdex cr usgs gov gdex 官网数据下载需要注册信息 如果部分区域可从网盘下载 网盘地址 链接 https pan bai
  • LeetCode第26题,删除排序数组中的重复项

    LeetCode 高频题 数组篇 26 删除排序数组中的重复项 大家好 我是Panda 今天分享的是LeetCode第26题 删除排序数组中的重复项 力扣题目链接 LeetCode 26 题目描述 给你一个 升序排列 的数组 nums 请你
  • layui后台表格的增删改查

    完整案例 github自己下下来 就是个很一般的ssm项目 但基本功能都有 已部署到云平台 后台管理员地址 暑假时候没做完凑合看吧 账号 17679210786 密码 123456 前后台都是 前台可以自己用手机号注册 别删除原来的内容 先
  • 攻防世界-MISC-练习区-12(功夫再高也怕菜刀)

    题目描述 菜狗决定用菜刀和菜鸡决一死战 这是攻防世界里面训练区的一道流量分析题 用wireshark 打开流量包 然后一级搜索http 二级用分组字节流搜索flag 按CTRL F 并找到no 1367 在Line based text d
  • 移动NB模块M5311(lwm2m协议登录详解)

    身为一个通信专业大三狗 第一次和别人对接项目今天属于我的功能总算是结束了 接下来就是等待联调 心情愉悦 首先NB是什么 这个我就不详细的解释了 我相信大多数人看这篇文章是以实践为开始的 那么多余的就不说了 接下来说具体流程 首先M5311模
  • 确实有必要好好学英语

    前言 工作已经6年多了 最近忽然明悟一些道理 零度觉得分享出来可能可以帮助一些人 这些道理可能很多成功的 牛逼的人早就知道这些了 随着技术的迭代更新越来越快 新技术不断产生 很多很多人都在焦虑 但是有一个道理的确是这样的 你不学习 未来终将
  • 【微信小程序】项目开发-----百度翻译API接口开发微信翻译小程序

    开发环境 微信开发者工具 V1 02 1902010版本以上 开发语言 JavaSript语言 HTML语言 API接口 百度翻译开发平台开放接口 界面预览 开发 基础配置 1 app js App onLaunch function 展示
  • AVPlay播放视频

    property nonatomic retain nullable AVPlayer player NSString urlStr NSBundle mainBundle pathForResource demo mp4 ofType n
  • 将灰度图片转成三通道(RGB)图片(MatLab)

    运行程序报错 RuntimeError output with shape 1 224 224 doesn t match the broadcast shape 3 224 224 报错原因 原模型输入的图片为RGB三通道 我输入的为单通
  • pytorch混合常量、变量

    有矩阵 X R n d X in R n times d