
Installation
Pour installer le paquet JAX (sur Olympe), veuillez suivre les étapes suivantes :
Allouez un nœud de calcul contenant des GPUs :
salloc -N 1 -n 36 --gres=gpu:1 --time=01:00:00 --mem=20G
Connectez-vous au nœud :
ssh olympevolta<numero_du_noeud>
Charger cuda version 11.5, conda et creer l'environement avec le fichier .yml fournie en bas (il faut creer le fichier jax-gpu_environment.yml et copier le contenu):
module purge module load cuda/11.7 module load conda/4.9.2 proxychains4 conda env create -f jax-gpu_environment.yml
jax-gpu_environment.yml :
name: jax-gpu
channels:
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- brotli-python=1.1.0=py38h17151c0_1
- bzip2=1.0.8=hd590300_5
- c-ares=1.23.0=hd590300_0
- ca-certificates=2023.11.17=hbcca054_0
- certifi=2023.11.17=pyhd8ed1ab_0
- charset-normalizer=3.3.2=pyhd8ed1ab_0
- cuda-version=11.5=h6c6c5af_2
- cudatoolkit=11.5.2=hbdc67f6_12
- cudnn=8.8.0.121=hcdd5f01_4
- idna=3.6=pyhd8ed1ab_0
- importlib-metadata=7.0.0=pyha770c72_0
- importlib_metadata=7.0.0=hd8ed1ab_0
- jax=0.4.13=pyhd8ed1ab_0
- jaxlib=0.4.12=cuda112py38h67cd1f8_201
- ld_impl_linux-64=2.40=h41732ed_0
- libabseil=20230125.3=cxx17_h59595ed_0
- libblas=3.9.0=20_linux64_openblas
- libcblas=3.9.0=20_linux64_openblas
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_3
- libgfortran-ng=13.2.0=h69a702a_3
- libgfortran5=13.2.0=ha4646dd_3
- libgomp=13.2.0=h807b86a_3
- libgrpc=1.56.2=h3905398_1
- liblapack=3.9.0=20_linux64_openblas
- libnsl=2.0.1=hd590300_0
- libopenblas=0.3.25=pthreads_h413a1c8_0
- libprotobuf=4.23.3=hd1fb520_1
- libsqlite=3.44.2=h2797004_0
- libstdcxx-ng=13.2.0=h7e041cc_3
- libuuid=2.38.1=h0b41bf4_0
- libzlib=1.2.13=hd590300_5
- ml_dtypes=0.2.0=py38h53bb729_2
- nccl=2.19.4.1=h0800d71_0
- ncurses=6.4=h59595ed_2
- numpy=1.24.4=py38h59b608b_0
- openssl=3.2.0=hd590300_1
- opt_einsum=3.3.0=pyhc1e730c_2
- packaging=23.2=pyhd8ed1ab_0
- pip=23.3.1=pyhd8ed1ab_0
- platformdirs=4.1.0=pyhd8ed1ab_0
- pooch=1.8.0=pyhd8ed1ab_0
- pysocks=1.7.1=pyha2e5f31_6
- python=3.8.18=hd12c33a_0_cpython
- python_abi=3.8=4_cp38
- re2=2023.03.02=h8c504da_0
- readline=8.2=h8228510_1
- requests=2.31.0=pyhd8ed1ab_0
- scipy=1.10.1=py38h59b608b_3
- setuptools=68.2.2=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- urllib3=2.1.0=pyhd8ed1ab_0
- wheel=0.42.0=pyhd8ed1ab_0
- xz=5.2.6=h166bdaf_0
- zipp=3.17.0=pyhd8ed1ab_0
Note : proxychains4 permet d'accéder à Internet depuis le nœud de calcul.
Test jax sur GPU
Une fois l'installation ce bien passé, se placer sur un nœud de calcul Volta et faire ces trois étapes si c'est pas déjà fait:
module load cuda/11.7 module load conda/4.9.2 conda activate jax-gpu
Générer le script test Python :
cat << EOF > jax_test_gpu.py import jax import jax.numpy as jnp import os # Check the available GPU devices jax.devices() # Define a simple function to run on GPU def gpu_add(a, b): return jax.device_put(a + b) # Create some arrays x_gpu = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000), dtype=jnp.float32) y_gpu = jax.random.normal(jax.random.PRNGKey(1), (1000, 1000), dtype=jnp.float32) # Run the function on GPU result_gpu = gpu_add(x_gpu, y_gpu) # Check the result print(result_gpu) EOF
Si lors de l'exécution du code suivant aucun message ne mentionne la non-utilisation des GPUs, cela signifierait que tout est correct :
python jax_test_gpu.py