You can install the Python package JAX (https://github.com/google/jax/blob/main/setup.py) with some extra packages depending on your environment.
For GPU:
pip install jax[cuda] --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
For TPU:
pip install jax[tpu] --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html
How do I add these --find-links URLs to pyproject.toml?
[build-system]
requires = ["setuptools>=67.6.0"]
build-backend = "setuptools.build_meta"
[project]
name = "minimal_example"
version = '0.0.1'
requires-python = ">=3.9"
dependencies = [
"seqio-nightly[gcp,cache-tasks]",
"t5[gcp]",
"t5x @ git+https://github.com/google-research/t5x.git"
]
[project.optional-dependencies]
cpu = ["jax[cpu]"]
gpu = ["jax[cuda]" , "t5x[gpu] @ git+https://github.com/google-research/t5x.git"]
tpu = ["jax[tpu]", "t5x[tpu] @ git+https://github.com/google-research/t5x.git"]
dev = ["pytest", "mkdocs"]
If I run pip install -e . then I get a working installation.
But running pip install -e ".[gpu]" gives me a ResolutionImpossible error.
And running pip install -e ".[tpu]" gives me:
Packages installed from PyPI cannot depend on packages which are not also hosted on PyPI.
jax depends on libtpu-nightly@ https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20210615-py3-none-any.whl
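For context on what I have considered so far: as far as I can tell, the [project] table in pyproject.toml has no field for --find-links, so the URLs may have to be passed to pip at install time instead. Below is a minimal sketch of that workaround, not a confirmed fix. The names pip_install_command and FIND_LINKS are hypothetical helpers made up for illustration; the URLs are the ones from the commands above. It only builds the command, and I am assuming (not claiming) that adding --find-links is enough to get past the ResolutionImpossible and direct-URL errors.

```python
import shlex
import sys

# --find-links URLs copied from the pip commands earlier in this question.
# The dict name and keys are illustrative; keys match the extras in pyproject.toml.
FIND_LINKS = {
    "gpu": "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html",
    "tpu": "https://storage.googleapis.com/jax-releases/libtpu_releases.html",
}


def pip_install_command(extra=None):
    """Build the argument list for an editable install of the current project.

    If the chosen extra has a known --find-links URL, append it so pip can
    also look at that page when resolving wheels.
    """
    target = f".[{extra}]" if extra else "."
    cmd = [sys.executable, "-m", "pip", "install", "-e", target]
    if extra in FIND_LINKS:
        cmd += ["--find-links", FIND_LINKS[extra]]
    return cmd


# Show the command that would be run for the gpu extra (nothing is installed here).
print(shlex.join(pip_install_command("gpu")))
```

The returned list can be handed to subprocess.run(cmd, check=True) to actually perform the install. pip also reads the same option from the PIP_FIND_LINKS environment variable or from a requirements file line starting with --find-links, which might be simpler than a wrapper script.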