JAX基本用法以及GCN实现

2023-10-27

JAX定位

  • JAX 不是一个深度学习框架或深度学习库,其设计初衷也不是成为一个深度学习框架或深度学习库。

  • JAX 的定位科学计算(Scientific Computing)和函数转换(Function Transformations)的交叉融合。深度学习只是 JAX 功能的一小部分。特色功能如下:

    • 即时编译(Just-in-Time Compilation)
    • 自动并行化(Automatic Parallelization)
    • 自动向量化(Automatic Vectorization)
    • 自动微分(Automatic Differentiation)
  • 两大部分内容:

    1. 对标Numpy的科学计算库,可以在GPU和TPU上运行
    2. 深度学习需要用到的底层计算工具。

基本用法

jit即时编译加速

  1. 纯函数:输入全部作为参数,结果全部作为输出。不使用一切外部变量。

  2. jax transforms将语句翻译成简单的数据计算流(tracing),追踪数值/变量的变换轨迹

在这里插入图片描述

import jax
import jax.numpy as jnp

global_list = []

def f(x):
    global_list.append(x)    #  side-effect
	print(x)    			 #  side-effect
    return 2*x*x+3*x+3

jaxpr = jax.make_jaxpr(f)
jaxpr(3)
# make_jaxpr翻译(转换)后函数:
{ lambda ; a:i32[]. let
    b:i32[] = mul a 2
    c:i32[] = mul b a
    d:i32[] = mul a 3
    e:i32[] = add c d
    f:i32[] = add e 3
  in (f,) }
  1. jit将代码直接编译成机器码,静态编译。由谷歌开发的XLA(加速线性代数)编译器完成。

  2. 会将编译后的代码进行缓存。只有输入变量形状改变,静态参数改变,才会重新编译。

    # 举一个调用全局变量的问题:
    g = 0
    
    def f(x):
        return x + g
    
    jit_f = jax.jit(f)	
    
    print ("First call: ", jit_f(3.))
    

    输出:

    First call:  3.0
    

    如果此时将外部变量g改为10,再次运行程序:

    g = 10.  # Update the global
    jit_f(4.)
    

    输出:

    DeviceArray(4., dtype=float32, weak_type=True)
    
    jit_f(jnp.array([4.]))	# 输入参数的shape改变
    
    output:DeviceArray([14.], dtype=float32)
    
  3. 不支持分支语句,对循环语句有限制条件,强烈建议使用特定的函数

    1. cond

      def cond(pred, true_fun, false_fun, operand):
          if pred:
           	 return true_fun(operand)
          else:
           	 return false_fun(operand)
      
    2. while_loop

      def while_loop(cond_fun, body_fun, init_val):
          val = init_val
          while cond_fun(val):
            val = body_fun(val)
          return val
      
    3. fori_loop

      body_fun有两个参数:i和中间变量

      def fori_loop(start, stop, body_fun, init_val):
          val = init_val
          for i in range(start, stop):
            	val = body_fun(i, val)
          return val
      
    4. scan

      • f: 双参数状态变化函数
      • init: carry的初始值
      • xs: 输入的变量
      def scan(f, init, xs, length=None):
          if xs is None:
            	xs = [None] * length
          carry = init
      	# core:
          ys = []
          for x in xs:
            	carry, y = f(carry, x)
            	ys.append(y)
          return carry, np.stack(ys)
      

vmap自动向量化,批处理

  1. 输入一个函数,输出一个函数。

  2. 自动对参数进行切分。默认对全部数组参数按照最高维进行切分。切分好的切片按照原函数函数分组进行计算,无法分组或分组后无法计算会报错,结果会stack进行堆砌。

    from jax import numpy as jnp
    from jax import random, jit, grad, vmap
    
    x = jnp.array([[1, 2, 3], [0, 1, 2]]) # 2 * 3
    y = jnp.array([[1, 2, 3], [0, 0, 0],[1, 2, 3], [0, 1, 2] ]) # 4 * 3
    def fun(a1, a2):
        return (a1 * a2)
    vmap_fun = vmap(fun)
    print(vmap_fun(x, y))
    
    ValueError: vmap got inconsistent sizes for array axes to be mapped:
    arg 0 has shape (2, 3) and axis 0 is to be mapped
    arg 1 has shape (4, 3) and axis 0 is to be mapped
    so
    arg 0 has an axis to be mapped of size 2
    arg 1 has an axis to be mapped of size 4
    
  3. in_axes指定输入数组的切片轴。out_axes指定输出函数的堆砌轴。

    # 定义函数:
    f = lambda x,w : jnp.dot(w,x)
    
    # 定义batch_x, w。
    x_batch = jax.random.normal(jax.random.PRNGKey(55), (4, 5, 3))
    w = jax.random.normal(jax.random.PRNGKey(42), (100, 5))
    
    batch_a = jax.vmap(f, in_axes=(0,None), out_axes=0)(x_batch, w)
    
    print(batch_a,shape)
    ==================================
      out:
      (4, 100, 3)
    

pmap

  1. 单机多卡
  2. 和vmap类似。会把切片分散到不同的GPU里进行并行计算。会把用到的参数在每个GPU里复制一份。
  3. 使用并行结果的时候利用API去获得不同GPU上的结果。

grad自动微分

  1. ·输入一个函数,输入一个函数

  2. 多参数函数,可指定对哪几个参数进行微分。默认只对第一个参数进行微分。

    import jax
    
    def f(x, y):
        return 2*x*x + 3*y + 3 
    
    x = 10
    y = 5
    
    jax.grad(f)(x, y) 
    jax.grad(f, argnums=(0,1,))(x, y) 
    

    输出:返回的是一个元组,可以被索引:

    DeviceArray(40., dtype=float32, weak_type=True) 
    
    (
    DeviceArray(40., dtype=float32, weak_type=True),
    DeviceArray(3., dtype=float32, weak_type=True)
    ) 
    

Pytree

JAX中的pytree指的是使用python容器(比如list、dict、tuple、OrderedDict、None、namedtuple等)储存的树状结构的数据(e.g., lists of lists of dicts)。如果一些数据没有被python容器装起来,那么它就是子叶数据(比如数值、数组、类、字符串),pytree中可以嵌套pytree。

嵌套式的list/dict/tuple结构,常常用来做神经网络的参数。

jnumpy

  1. APINumpyAPI几乎完全一样。

    import numpy as np
    
    import jax.numpy as np
    
  2. 产生随机数的方式不一样。

    np.random.seed(seed)
    np.random.uniform()	# 0.54881350
    np.random.unifrom() # 0.71518936
    
    key = jax.random.PRNGKey(seed)	# key:DeviceArray([0, 0], dtype=uint32)
    x = jax.random.uniform(key)	# 0.41845703
    x = jax.random.uniform(key) # 0.41845703
    
    key, subkey = jax.random.split(key)
    x = jax.random.uniform(subkey) # 0.10546897
    

    key 和 x 之间看似有一种映射关系。

    在jax中使用随机数的精髓:永远不用重复使用你的key,善用jax.random.split()函数。

    为了兼容jax的并行化、可重复以及可矢量化。

  3. 数组不可变

    # NumPy: mutable arrays
    x = np.arange(10)
    x[0] = 10
    
    # JAX: immutable arrays
    x = jnp.arange(10)
    x[0] = 10  # 报错
    
    y = x.at[0].set(10)
    new_array = index_update(old_array, index[1, :], 1.)
    new_array = index_add(old_array, index[::2, 3:], 7)
    

    允许就地改变变量使得程序分析和转换非常困难。

  4. 索引超出范围不报错,返回最后一个

    jnp.arange(10)[11]
    --------------------------------------------
    out: DeviceArray(9, dtype=int32)
    

JAX PyTorch GCN实现对比

1. 导包
import jax
import jax.numpy as np
from jax import lax, random	# 随机数包
from jax.experimental import stax	# 计算模型
from jax.experimental.stax import Relu, LogSoftmax # 激活函数
from jax.nn.initializers import glorot_normal, glorot_uniform, normal, uniform, zeros
import optax	# 优化器
import jax.nn as nn	# nn库
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import matplotlib.pyplot as plt
2. 定义图卷积
def GraphConvolution(out_dim, bias=False, sparse=False):
  
    def matmul(A, B, shape):
        if sparse:
            return sp_matmul(A, B, shape)
        else:
            return np.matmul(A, B)

    def init_fun(rng, input_shape):
        output_shape = input_shape[:-1] + (out_dim,)
        k1, k2 = random.split(rng)
        W_init, b_init = glorot_uniform(), zeros
        W = W_init(k1, (input_shape[-1], out_dim))
        if bias:
            b = b_init(k2, (out_dim,))
        else:
            b = None
        return output_shape, (W, b)
    
    def apply_fun(params, feature, adj):
        W, b = params
        support = np.dot(feature, W)
        out = matmul(adj, support, support.shape[0])
        if bias:
            out += b
        return out

    return init_fun, apply_fun
class GraphConvolution(nn.Module):
	def __init__(self, input_dim, output_dim, use_bias=True):
        """图卷积: L*X*\theta
        Args:
        ----------
        input_dim: int
        节点输入特征的维度
        output_dim: int
        输出特征维度
        use_bias : bool, optional
        是否使用偏置
        """
        super(GraphConvolution, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.use_bias = use_bias
        self.weight = nn.Parameter(torch.Tensor(input_dim, output_dim))
        if self.use_bias:
        	self.bias = nn.Parameter(torch.Tensor(output_dim))
        else:
        	self.register_parameter('bias', None)
        self.reset_parameters()
        
    def reset_parameters(self):
        init.kaiming_uniform_(self.weight)
        if self.use_bias:
            init.zeros_(self.bias)

    def forward(self, adjacency, input_feature):
        """邻接矩阵是稀疏矩阵, 因此在计算时使用稀疏矩阵乘法
        Args:
        -------
        adjacency: torch.sparse.FloatTensor
        邻接矩阵
        input_feature: torch.Tensor
        输入特征
        """
        support = torch.mm(input_feature, self.weight)
        output = torch.sparse.mm(adjacency, support)
        if self.use_bias:
            output += self.bias
        retur output

总结: jax没有PyTorch的model类。jax的函数尽可能写成纯函数,不保留内部数据,数据尽可能作为参数,外界传入。

3. 定义神经网络
def GCN(nhid: int, nclass: int, sparse: bool = False):
    
    gc1_init, gc1_fun = GraphConvolution(nhid, sparse=sparse)
    gc2_init, gc2_fun = GraphConvolution(nclass, sparse=sparse)

    init_funs = [gc1_init, gc2_init]

    def init_fun(rng, input_shape):
        params = []
        for init_fun in init_funs:
            rng, layer_rng = random.split(rng)
            input_shape, param = init_fun(layer_rng, input_shape)
            params.append(param)
        return input_shape, params

    def apply_fun(params, feature, adj, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2 = random.split(rng, 2)
        x = gc1_fun(params[0], feature, adj, rng=k1)
        x = nn.relu(x)
        x = gc2_fun(params[1], x, adj, rng=k2)
        x = nn.log_softmax(x)
        return x
    
    return init_fun, apply_fun
class GcnNet(nn.Module):
    """
    定义一个包含两层GraphConvolution的模型
    """
    def __init__(self, input_dim=1433):
        super(GcnNet, self).__init__()
        self.gcn1 = GraphConvolution(input_dim, 16)
        self.gcn2 = GraphConvolution(16, 7)
        
    def forward(self, adjacency, feature):
        h = F.relu(self.gcn1(adjacency, feature))
        logits = self.gcn2(adjacency, h)
        return logits
4.训练
  1. 模型初始化

    init_fun, predict_fun = GCN(nhid=hidden, nclass=labels.shape[1],sparse=args.sparse)
    _, init_params = init_fun(init_key, input_shape)
    
    # 模型初始化
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GcnNet().to(device)
    
  2. 定义损失函数

    @jit
    def loss(params, batch):
        """
        The idxes of the batch indicate which nodes are used to compute the loss.
        """
        inputs, targets, adj, rng, idx = batch
        preds = predict_fun(params, inputs, adj, rng=rng)
        ce_loss = -np.mean(np.sum(preds[idx] * targets[idx], axis=1))
        l2_loss = 5e-4 * optimizers.l2_norm(params)**2 # tf doesn't use sqrt
        return ce_loss + l2_loss
    
    # 损失函数使用交叉熵
    criterion = nn.CrossEntropyLoss().to(device)
    
  3. 定义优化器

    optimizer = optax.adam(start_learning_rate)
    opt_state = optimizer.init(init_params)	# 优化器状态初始化
    
    # 优化器使用Adam
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
  4. 正(反)向传播

    for epoch in range(num_epochs):
        grads = jax.grad(compute_loss)(params, xs, ys)
      	updates, opt_state = optimizer.update(grads, opt_state)		# updates更新参数的方式
      	params = optax.apply_updates(params, updates)
    
    for epoch in range(epochs):
        logits = model(tensor_adjacency, tensor_x) # 前向传播
        loss = criterion(logits, train_y) # 计算损失值
        optimizer.zero_grad()
        loss.backward() # 反向传播计算参数的梯度
        optimizer.step() # 使用优化方法进行梯度更新
    

    jax的前向传播过程其实定义在了损失函数里面。在更新的时候,会调用损失函数求梯度,也就前向传播了。update完成了前向传播,求损失值,求梯度更新参数的过程。

总结: 最大的区别就是jax是面向纯函数编程,PyTorch是面向对象编程。

Neural Network Libraries

  • Flax - Centered on flexibility and clarity.
  • Haiku - Focused on simplicity, created by the authors of Sonnet at DeepMind.
  • Objax - Has an object oriented design similar to PyTorch.
  • Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax.
  • Trax - “Batteries included” deep learning library focused on providing solutions for common workloads.
  • Jraph - Lightweight graph neural network library.
  • Neural Tangents - High-level API for specifying neural networks of both finite and infinite width.
  • HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax).
  • Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

参考资料

Google JAX Notebook

JAX 中文教程

JAX官方文档

2022年,我该用JAX吗?GitHub 1.6万星,这个年轻的工具并不完美

Implementing Graph Neural Networks with JAX

《深入浅出图神经网络:GNN原理解析》(刘忠雨 李彦霖 周洋)

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

JAX基本用法以及GCN实现 的相关文章

随机推荐

  • Arduino教程四——u8g2库OLED屏幕显示

    1 功能 u8g2库OLED屏幕显示英文 OLED 0 96寸 128X64 对于这几个参数进行说明 0 96指的是屏幕的显示尺寸0 96inch 128 64指的是屏幕的分辨率为128 64 128列64行 u8g2 屏幕显示 固定搭配
  • Arthas(阿尔萨斯) 的安装与使用

    arthas官方文档 https arthas aliyun com doc index html点击此处进入 是Alibaba开源的Java诊断工具 深受开发者喜爱 在线排查问题 无需重启 动态跟踪Java代码 实时监控JVM状态 Art
  • 11月20日 如何在场景开启Debug,自定义AI任务,EQS,创建自己的环境任务,使用Pawn环境检测来检测周围的环境,让AI动作更顺滑(动画混合

    如何在场景开启Debug 按F1开启线框模式 按 打开Debug数据栏 按数字键3打开EQSDEBUG 开启距离场debug 自定义AI任务 创建BTTask RangeAttack h Fill out your copyright no
  • 使用msfconsole拿到win2008 R2的shell并进行维权二(权限维持)

    声明 本博文仅供学习交流使用 不可用于任何违法犯罪活动 由此带来的任何法律后果 本人概不承担 使用msfconsole拿到win2008 R2的shell并进行维权二 权限维持 四 维权后门 4 1查询服务器信息 4 1 1查看当前用户以及
  • linux挂载磁盘超时问题解决记录

    上周公司一台k8节点nfs挂载超时 同事反映 这个盘挂载是有问题 开始各种排查 都没问题 最后排查到nfs server节点iptables规则限制所致 记录一下这次的排查过程 1 server端排查 看配置 检查 showmount e
  • 拆机小白的联想小新I1000内存升级过程

    终于有时间升级一下我的4GB内存的联想小新I1000了 原想着如果可以扩展的话 内存升到最高 硬盘加装一个不用太大的SSD硬盘 把系统就装在SSD上面 机械就只作为一个存储的硬盘 可惜联想小新I1000不支持呀 内存条和硬盘都只是一个卡槽
  • 三、OpenCV图像的预处理——二值化与自适应阈值

    教程汇总 python基础入门系列 定义 图像的二值化 就是将图像上的像素点的灰度值设置为0或255 也就是将整个图像呈现出明显的只有黑和白的视觉效果 一幅图像包括目标物体 背景还有噪声 要想从多值的数字图像中直接提取出目标物体 常用的方法
  • 矩阵求秩

    矩阵的秩怎么计算 这个问题一下子我居然不知道怎么下手 虽然本科的时候学过线性代数 但是好久不用 很多东西都忘了 今天略微梳理一下吧 最简单直观的方法 化成行最简形 或行阶梯形 然后数一下非零行数 例如 将矩阵做初等行变换后 非零行的个数叫行
  • Python 实现多个类别数据的直方图区间层面累积堆叠

    Python 实现多个类别数据的直方图区间层面累积堆叠 数据可视化是数据科学中不可缺少的一部分 它能够帮助我们更好地理解和分析数据 直方图是一种常用的数据可视化方法 它可以将数据分布情况以柱状图的形式展示出来 如果存在多个类别的数据 我们可
  • mysql convert函数 解决读取double为科学计数法问题

    convert顾名思义就是转化 cast差不多 MySQL CONVERT 函数 参考手册 为什么需要这个函数 mysql是弱类型的 where stringcol 1 and intcol 1 都行 会自动转化 那我为什么还要呢 mysq
  • 错误:编码GBK的不可映射字符解决方案(亲测有效)

    CMD编译运行JAVA程序出现的错误 原要求 这次作业要求用命令行输出 但是java命令后显示的是中文乱码 也有的出现错误 编码GBK的不可映射字符 原因 引用 由于JDK是国际版的 我们在用javac exe编译时 编译程序首先会获得我们
  • 插入mysql,Cause: com.mysql.cj.jdbc.exceptions.MysqlDataTruncation:Data truncation: Data too long

    插入mysql 报错 Error updating database Cause com mysql cj jdbc exceptions MysqlDataTruncation Data truncation Data too long
  • Legal or Not HDU - 3342 拓扑排序 判环

    这道题的意思是 给你n个点 m行关系数据 左 gt 右 判断有无环的出现 方法 直接拓扑排序 如果能正常排序完 这个就是无环的有向图DAG 如果不能 在拓扑排序的过程中有些点的入度经过去边操作之后一直不为零 就是有环的存在 include
  • GPT4.0一句话实现各类图表制作,让数据可视化变得更简单!类图、流程图、ER图.....

    不知道大家有没有被ER建模工具复杂的操作按钮给困扰过 在作者学习ER建模时 曾希望能直接画出类图 但最终还是不得不学习繁琐的操作流程 然而 随着GPT的出现 AI现在也可以绘制UML图了 今天要向大家分享一个AI工具 它能够借助强大的GPT
  • STM32多中断模式

    1 基本概念 ARM Coetex M3内核共支持256个中断 其中16个内部中断 240个外部中断和可编程的256级中断优先级的设置 STM32目前支持的中断共84个 16个内部 68个外部 还有16级可编程的中断优先级的设置 仅使用中断
  • STM32与BLE蓝牙通信 Android APP配置(一)

    事物的难度远远低于对事物的恐惧 0 前言 最近完成了一个基于BLE蓝牙通信的简单APP 在这里记录下来 供大家参考希望能给需要的人解决疑惑 这个APP中一共是两个界面 第一个界面实现打开蓝牙 关闭蓝牙 扫描蓝牙和显示扫描的结果 通过选择扫描
  • 【接口测试基础】第十四篇

    iHRM项目实战 简介 功能模块 技术架构 前端 以Node js为核心的Vue js前端技术生态架构 后端 SprintBoot SprintCloud SprintMVC SprintData Spring全家桶 MySQL Redis
  • java list stream 去除 null_Stream流的这些操作,你得知道,对你工作有很大帮助

    作者 扬帆 起航 原文 https blog csdn net qq 43677736 Stream流 Stream 流 是一个来自数据源的元素队列并支持聚合操作 元素是特定类型的对象 形成一个队列 Java中的Stream并不会存储元素
  • SPSS软件学习

    1 文件 新建 数据 2 修改变量信息 在这里插入图片描述 3 查看数据基本情况 分析 描述统计 描述 4 相关性分析 相关 双变量 结果显示Pearson相关系数为 0 902 P值小于0 01 相关关系具有统计学意义 但实际上它们不一定
  • JAX基本用法以及GCN实现

    JAX定位 JAX 不是一个深度学习框架或深度学习库 其设计初衷也不是成为一个深度学习框架或深度学习库 JAX 的定位科学计算 Scientific Computing 和函数转换 Function Transformations 的交叉融