Wie installiere ich JAX mit CUDA unter Linux korrekt, wenn „jax[cuda12_pip]“ ständig auf die CPU-Version zurückgreift?Python

Python-Programme
Anonymous
 Wie installiere ich JAX mit CUDA unter Linux korrekt, wenn „jax[cuda12_pip]“ ständig auf die CPU-Version zurückgreift?

Post by Anonymous »

Ich versuche, JAX mit GPU-Unterstützung auf einem leistungsstarken, dedizierten Linux-Server zu installieren, aber ich stecke in einer Zwickmühle fest, in der jede offizielle Installationsmethode auf andere Weise fehlschlägt, was immer dazu führt, dass JAX auf die CPU zurückfällt.
Ich suche nach einem definitiven, narrensicheren Befehlssatz, um eine funktionierende GPU-Installation zu erhalten.
Systemspezifikationen:
  • Betriebssystem: Ubuntu 18.04 LTS
  • GPU: 8x NVIDIA Quadro RTX 8000
  • NVIDIA-Treiber: 550.144.03
  • CUDA-Version (vom Treiber gemeldet): 12.4
  • Python: 3.10 (verwaltet von Conda)
Was ich versucht habe
Ich habe für jeden Versuch sorgfältig neue Conda-Umgebungen erstellt, um sicherzustellen, dass es keine Konflikte gibt.
Versuch Nr. 1: Die empfohlene Standardmethode
Dies ist der offiziell empfohlene Befehl.

Code: Select all

conda create -n jax_test python=3.10 -y
conda activate jax_test
pip install "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --no-cache-dir
  • Erwartetes Ergebnis: Ein großes, mehrere Gigabyte großes Jaxlib-Rad mit enthaltenen CUDA-Bibliotheken sollte heruntergeladen und installiert werden.
  • Tatsächliches Ergebnis: pip ignoriert konsequent die [cuda12_pip]-Direktive, lädt die kleine CPU-Version von jaxlib (89,9 MB) herunter und gibt eine Warnung. Der Verifizierungsbefehl bestätigt diesen Fehler:

Code: Select all

WARNING: jax 0.6.2 does not provide the extra 'cuda12-pip'
Downloading jaxlib-0.6.2-cp310-cp310-manylinux2014_x86_64.whl (89.9 MB)
...
$ python -c "import jax; print(jax.devices())"
WARNING: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]
Versuch Nr. 2: Die direkte URL-Methode
Dies ist die professionelle Problemumgehung, um die Installation eines bestimmten GPU-Rads zu erzwingen.

Code: Select all

# In a clean environment...
pip install "https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.23+cuda12.cudnn88-cp310-cp310-manylinux2014_x86_64.whl"
pip install jax==0.4.23 "numpy

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post