AttributeError:模块“jaxlib.xla_extension”没有属性“PmapFunction”

2024-03-14

有人可以帮我修复在 check_not_jax_transformed(f) 中的“/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)”时出现以下错误吗?多谢。

"AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'"

jaxlib.xla_extension.PmapFunction在 jaxlib 版本 0.72 中添加;听起来您安装了较旧的 jaxlib 版本。您应该使用以下方法更新它:

pip install -U jaxlib

注意:如果您使用 GPU/TPU,则应使用适当的加速器特定安装命令,位于https://github.com/google/jax#installation https://github.com/google/jax#installation.

如果这不起作用,请检查您的 Python 版本。 jaxlib 从 0.1.70 版本开始需要 Python 3.7 或更高版本,因此如果您使用的是 Python 3.6,则需要先升级 Python,然后才能升级到更新的 jaxlib。

看来有问题的行已添加到haiku在您发布问题之前几个小时打包:https://github.com/deepmind/dm-haiku/commit/e6a13af352a8b46d355ac1b7131b64c615cfcf57 https://github.com/deepmind/dm-haiku/commit/e6a13af352a8b46d355ac1b7131b64c615cfcf57如果您不想更新 jaxlib,另一个选择是安装稳定版本dm-haiku而不是使用开发版本:

pip install dm-haiku==0.0.5
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

AttributeError:模块“jaxlib.xla_extension”没有属性“PmapFunction” 的相关文章

随机推荐