JAX 仅在 jit 下的数组切片上应用函数

2024-04-13

我正在使用 JAX,我想执行类似的操作

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

这不能在以下情况下执行jit。有没有办法做到这一点jax.ops or jax.lax? 我想用jax.ops.index_update(x, idx, y)但我找不到计算方法y不会再次遇到同样的问题。


The 之前的回答 https://stackoverflow.com/a/68423274/2937831由 @rvinas 使用dynamic_slice如果您的索引是静态的,效果很好,但您也可以使用动态索引来完成此操作jnp.where。例如:

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
  mask = jnp.arange(x.shape[0]) < index
  return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

JAX 仅在 jit 下的数组切片上应用函数 的相关文章

随机推荐