使用 vmap 时,Jax 不支持不可散列的静态参数

2024-01-05

这与这个问题 https://stackoverflow.com/questions/65612989/jax-cannot-find-the-static-argnums。经过一些工作,我设法将其更改为最后一个错误。代码现在看起来像这样。

import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp


def p_tau(z, tau, alpha=1.5):
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


def body(kwargs, x):
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

@jax.partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
    z = (alpha - 1) * z_input

    tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z, tau, alpha)
    return result / result.sum()

@jax.partial(jax.jit, static_argnums=(1,3,))
def _entmax(input, axis=-1, alpha=1.5, T=20):
    result = vmap(jax.partial(map_row, alpha, T), axis)(input)
    return result

@jax.partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    return _entmax(input, axis, alpha, T)

@jax.partial(jax.jit, static_argnums=(0,2,))    
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    return (weight*entmax(input, 0, 1.5, 20)).sum()

jax.jit(value_and_grad(toy))(input, weight)

这导致(我希望的)是最终的错误,即

Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.batching.BatchTracer'> for function map_row is non-hashable.

这很奇怪,因为我想我已经标记了每个地方axis看起来是静态的,但它仍然告诉我它是被追踪的。


当你写一个partial带有位置参数的函数,这些参数首先被传递。所以这:

jax.partial(map_row, alpha, T)

本质上等价于:

lambda z_input: map_row(alpha, T, z_input)

注意参数的顺序不正确——这就是导致错误的原因:你正在传递z_input,一个不可散列的跟踪器,指向一个预计是静态的参数。

您可以通过替换来修复此问题partial上面的声明:

lambda z: map_row(z, alpha, T)

然后你的代码就会正确运行。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用 vmap 时,Jax 不支持不可散列的静态参数 的相关文章

随机推荐

  • 无法使用 BatchNorm 层导入冻结图

    我基于此训练了一个 Keras 模型repo https github com bonlime keras deeplab v3 plus 训练后 我将模型保存为检查点文件 如下所示 sess tf keras backend get se
  • 使用 JQuery(立即)检测对 的所有更改

    价值的体现有多种方式
  • 使用 PowerShell 匹配存储在变量中的字符串

    我正在尝试创建一个备份脚本来移动超过 30 天的文件 但我希望能够从列表中排除文件夹 a C Temp Exclude test b C Temp Exclude 如果我运行以下命令 a match b 下列的 Guy Guy Thomas
  • Select2 使用ajax响应数据生成id

    我的 JSON 响应数据不包含 ID 字段 而 Select2 需要该字段才能显示结果 在文档中 他们提供了一种生成 id 的方法 但是我无法这样做 有人可以提供一个关于如何执行此操作的示例吗 到目前为止我已经尝试过了 itemSearch
  • Net Core 2 中 HandleErrorAttribute 的等效项

    我正在将 Net 4 6 2 项目迁移到 Net Core 2 相当于什么HandleErrorAttribute 第 2 行以下接收错误 public static void RegisterGlobalFilters GlobalFil
  • 在 Java 中使用 volatile 关键字的完整示例?

    我需要一个简单的使用示例volatileJava 中的关键字 由于不使用而导致行为不一致volatile 理论部分volatile用法对我来说已经很清楚了 首先 没有保证由于非易失性变量而暴露缓存的方式 您的 JVM 可能一直对您非常友善
  • 访问VBA:根据非绑定列在组合框中查找项目

    我在 Access 表单上有一个两列组合框 表示键到代码的映射 组合框的第一列是 绑定列 即 当MyComboBox Value叫做 我需要动态设置Value我的组合框基于第二列中找到的值 例如 如果我的组合框源是 Value Code A
  • 蓝牙+模拟鼠标

    有人知道是否可以制作一个应用程序通过蓝牙模拟触摸屏鼠标或触控板 如何使 PC 或 MAC 将我识别为鼠标设备 问候 胡安 您应该查看蓝牙 HID 规范 这可能是可能的 具体取决于您用来模拟鼠标 触控板的设备堆栈 我不熟悉 Android 上
  • 由于错误 800a025e,无法完成操作

    这个错误在 IE10 11 中意味着什么 Error Could not complete the operation due to error 800a025e 我该如何调试它 它说的是这一行 this nativeSelection r
  • TabPanel 中的 gwt ScrollPanel:没有垂直滚动条

    EDIT 我通过调整组件内的大小来修复空白行为VerticalPanel 这似乎对面板尺寸产生了影响 但控制台却忽略了这一点 我不太明白怎么办 但是 我的面板仍然没有显示垂直滚动条 在 GWT 项目中 我具有以下结构 Page DockLa
  • 如何检查我是否有 Base Clearcase 或 UCM?

    我是 ClearCase 的新手 我以前用过理性协同 我们在项目中使用 ClearCase 进行版本控制 在我的旧项目中 我使用了合理的协同作用 其中我们为文件中的任何修改创建 任务 我了解到我们在 ClearCase 中有活动 我想在我们
  • 用于启用/禁用用户的 Firebase 触发器

    在 Firebase Auth 控制台中 每个用户都有一个选项 例如启用 禁用其帐户 如何在 Firebase 函数和 Android 应用程序中触发此事件 函数无法在这种事件上触发 至少现在还没有 函数只会在这些情况下触发 查看doc h
  • Java 中的函数式编程 [关闭]

    Closed 这个问题不符合堆栈溢出指南 help closed questions 目前不接受答案 Java 有没有一个好的函数式编程库 我正在寻找类似的东西谓词 http msdn microsoft com en us library
  • 读取调制解调器固件版本:Android

    我正在开发一个 iPhone 和 Android 中的应用程序 我需要阅读Modem Firmware Version正如 iPhone 开发人员在他身边所做的那样 我在 Internet SO 上搜索 但找不到与我的问题相关的任何内容 是
  • 在 JUnit5 (Eclipse) 中创建 TestSuite

    我在 eclipse 中创建了多个测试用例 java 文件 JUnit 的版本是 JUnit5 现在 我尝试通过 eclipse GUI 创建 Junit TestSuite 在创建过程中 我没有在可用版本中看到 JUnit5 这是我为创建
  • 浏览器中自动完成下拉菜单的样式

    例如 在许多网站上 当输入用户名时 会在显示先前输入的位置出现一个下拉菜单 以便用户可以轻松选择某些内容而不用输入 我知道您可以通过让表单或输入具有以下属性来在浏览器中关闭此功能autocomplete off 问题是当我想要它打开并且输入
  • 告诉 `endl` 不要刷新

    我的程序打印大量短行cout 作为一个稍微做作的例子 我的线条看起来有点像这样 cout lt lt The variable s value is lt
  • 使用核心数据实体更新表节标题的有效方法?

    我为我的 UITableView 使用 NSFetchedResultsController 它显示了我存储在核心数据中的一堆事件 我想做的是按相对日期 即今天 明天 本周等 对表格进行分组 每个事件都有一个开始日期 我尝试在事件实体中创建
  • 为什么索引会使查询变得非常慢?

    有一天我回答了一个question https stackoverflow com questions 5642880 slow mysql query 5642908 5642908就这样 被认为是正确的 但答案给我留下了很大的疑问 不久
  • 使用 vmap 时,Jax 不支持不可散列的静态参数

    这与这个问题 https stackoverflow com questions 65612989 jax cannot find the static argnums 经过一些工作 我设法将其更改为最后一个错误 代码现在看起来像这样 im