Installation
Pour installer le paquet JAX (sur Olympe), veuillez suivre les étapes suivantes :
Créez un environnement conda avec Python 3.9 pour JAX-GPU :
module purge module load conda
/4.9.2conda create -n jax-gpu python=3.9
Allouez un nœud de calcul contenant des GPUs :
salloc -N 1 -n 36 --gres=gpu:1 --time=01:00:00 --mem=20G
Connectez-vous au nœud :
ssh olympevolta<numero_du_noeud>
Charger cuda version 11.5, conda et Activer l'environement:
module purge module load cuda/11.7 module load conda/4.9.2 conda activate jax-gpu
Créer un repertoire pip_tmp et setter la variable TMPDIR vers ce repertoire pour eviter un erreur de dépassement de mémoire de pip
mkdir /tmpdir/$(whoami)/pip_tmp export TMPDIR=/tmpdir/$(whoami)/pip_tmp
Installez tous les paquets conda dont vous avez besoin et après installez JAX:
proxychains4 pip install --upgrade "jax[cuda11_pip]==0.4.12;" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note : proxychains4 permet d'accéder à Internet depuis le nœud de calcul.
Test jax sur GPU
Une fois l'installation ce bien passé, se placer sur un nœud de calcul Volta et faire ces trois étapes si c'est pas déjà fait:
module load cuda/11.7 module load conda/4.9.2 conda activate jax-gpu
Générer le script test Python :
cat << EOF > jax_test_gpu.py import jax import jax.numpy as jnp import os # Check the available GPU devices jax.devices() # Define a simple function to run on GPU def gpu_add(a, b): return jax.device_put(a + b) # Create some arrays x_gpu = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000), dtype=jnp.float32) y_gpu = jax.random.normal(jax.random.PRNGKey(1), (1000, 1000), dtype=jnp.float32) # Run the function on GPU result_gpu = gpu_add(x_gpu, y_gpu) # Check the result print(result_gpu) EOF
Si lors de l'exécution du code suivant aucun message ne mentionne la non-utilisation des GPUs, cela signifierait que tout est correct :
python jax_test_gpu.py