Installation
Pour installer le paquet JAX (sur Olympe), veuillez suivre les étapes suivantes :
Allouez un nœud de calcul contenant des GPUs :
salloc -N 1 -n 36 --gres=gpu:1 --time=01:00:00 --mem=20GConnectez-vous au nœud :
ssh olympevolta<numero_du_noeud>Charger cuda version 11.5, conda et creer l'environement avec le fichier .yml fournie en bas (il faut creer le fichier jax-gpu_environment.yml et copier le contenu):
module purge module load cuda/11.7 module load conda/4.9.2 proxychains4 conda env create -f jax-gpu_environment.yml
jax-gpu_environment.yml :
name: jax-gpuchannels:- conda-forge- defaultsdependencies:- _libgcc_mutex=0.1=conda_forge- _openmp_mutex=4.5=2_gnu- brotli-python=1.1.0=py38h17151c0_1- bzip2=1.0.8=hd590300_5- c-ares=1.23.0=hd590300_0- ca-certificates=2023.11.17=hbcca054_0- certifi=2023.11.17=pyhd8ed1ab_0- charset-normalizer=3.3.2=pyhd8ed1ab_0- cuda-version=11.5=h6c6c5af_2- cudatoolkit=11.5.2=hbdc67f6_12- cudnn=8.8.0.121=hcdd5f01_4- idna=3.6=pyhd8ed1ab_0- importlib-metadata=7.0.0=pyha770c72_0- importlib_metadata=7.0.0=hd8ed1ab_0- jax=0.4.13=pyhd8ed1ab_0- jaxlib=0.4.12=cuda112py38h67cd1f8_201- ld_impl_linux-64=2.40=h41732ed_0- libabseil=20230125.3=cxx17_h59595ed_0- libblas=3.9.0=20_linux64_openblas- libcblas=3.9.0=20_linux64_openblas- libffi=3.4.2=h7f98852_5- libgcc-ng=13.2.0=h807b86a_3- libgfortran-ng=13.2.0=h69a702a_3- libgfortran5=13.2.0=ha4646dd_3- libgomp=13.2.0=h807b86a_3- libgrpc=1.56.2=h3905398_1- liblapack=3.9.0=20_linux64_openblas- libnsl=2.0.1=hd590300_0- libopenblas=0.3.25=pthreads_h413a1c8_0- libprotobuf=4.23.3=hd1fb520_1- libsqlite=3.44.2=h2797004_0- libstdcxx-ng=13.2.0=h7e041cc_3- libuuid=2.38.1=h0b41bf4_0- libzlib=1.2.13=hd590300_5- ml_dtypes=0.2.0=py38h53bb729_2- nccl=2.19.4.1=h0800d71_0- ncurses=6.4=h59595ed_2- numpy=1.24.4=py38h59b608b_0- openssl=3.2.0=hd590300_1- opt_einsum=3.3.0=pyhc1e730c_2- packaging=23.2=pyhd8ed1ab_0- pip=23.3.1=pyhd8ed1ab_0- platformdirs=4.1.0=pyhd8ed1ab_0- pooch=1.8.0=pyhd8ed1ab_0- pysocks=1.7.1=pyha2e5f31_6- python=3.8.18=hd12c33a_0_cpython- python_abi=3.8=4_cp38- re2=2023.03.02=h8c504da_0- readline=8.2=h8228510_1- requests=2.31.0=pyhd8ed1ab_0- scipy=1.10.1=py38h59b608b_3- setuptools=68.2.2=pyhd8ed1ab_0- tk=8.6.13=noxft_h4845f30_101- urllib3=2.1.0=pyhd8ed1ab_0- wheel=0.42.0=pyhd8ed1ab_0- xz=5.2.6=h166bdaf_0- zipp=3.17.0=pyhd8ed1ab_0
Note : proxychains4 permet d'accéder à Internet depuis le nœud de calcul.
Test jax sur GPU
Une fois l'installation ce bien passé, se placer sur un nœud de calcul Volta et faire ces trois étapes si c'est pas déjà fait:
module load cuda/11.7 module load conda/4.9.2 conda activate jax-gpu
Générer le script test Python :
cat << EOF > jax_test_gpu.py
import jax
import jax.numpy as jnp
import os
# Check the available GPU devices
jax.devices()
# Define a simple function to run on GPU
def gpu_add(a, b):
return jax.device_put(a + b)
# Create some arrays
x_gpu = jax.random.normal(jax.random.PRNGKey(0), (1000, 1000), dtype=jnp.float32)
y_gpu = jax.random.normal(jax.random.PRNGKey(1), (1000, 1000), dtype=jnp.float32)
# Run the function on GPU
result_gpu = gpu_add(x_gpu, y_gpu)
# Check the result
print(result_gpu)
EOFSi lors de l'exécution du code suivant aucun message ne mentionne la non-utilisation des GPUs, cela signifierait que tout est correct :
python jax_test_gpu.py