Ich suche nach einem definitiven, narrensicheren Befehlssatz, um eine funktionierende GPU-Installation zu erhalten.
Systemspezifikationen:
- Betriebssystem: Ubuntu 18.04 LTS
- GPU: 8x NVIDIA Quadro RTX 8000
- NVIDIA-Treiber: 550.144.03
- CUDA-Version (vom Treiber gemeldet): 12.4
- Python: 3.10 (verwaltet von Conda)
Ich habe für jeden Versuch sorgfältig neue Conda-Umgebungen erstellt, um sicherzustellen, dass es keine Konflikte gibt.
Versuch Nr. 1: Die empfohlene Standardmethode
Dies ist der offiziell empfohlene Befehl.
Code: Select all
conda create -n jax_test python=3.10 -y
conda activate jax_test
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
- Erwartetes Ergebnis: Ein großes, mehrere Gigabyte großes Jaxlib-Rad mit enthaltenen CUDA-Bibliotheken sollte heruntergeladen und installiert werden.
- Tatsächliches Ergebnis: pip ignoriert konsequent die [cuda12_pip]-Direktive, lädt die kleine CPU-Version von jaxlib (89,9 MB) herunter und gibt eine Warnung. Der Verifizierungsbefehl bestätigt diesen Fehler:
Code: Select all
WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
Dies ist die professionelle Problemumgehung, um die Installation eines bestimmten GPU-Rads zu erzwingen.
Code: Select all
# In a clean environment...
pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
pip install jax==0.4.23 "numpy
Mobile version