From ee3f4100a07619d9b5e9794e5a920ae5177892db Mon Sep 17 00:00:00 2001 From: Jaxen Godfrey <80279129+jaxengodfrey@users.noreply.github.com> Date: Fri, 31 Jan 2025 13:07:10 -0800 Subject: [PATCH] Update installation.rst --- docs/installation.rst | 50 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 44 insertions(+), 6 deletions(-) diff --git a/docs/installation.rst b/docs/installation.rst index 91d7dc5a..fe0ccee8 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -10,28 +10,66 @@ Clone the repository git clone https://github.com/FarrOutLab/GWInferno.git cd gwinferno -Recommended to use conda to set up your environment with python versions newer than at least 3.9: +It is recommended to use conda to set up your environment with python versions newer than at least 3.12: -For CPU usage only create an environment and install requirements and GWInferno with: +--------------------------------------------------------------------- +For CPU +--------------------------------------------------------------------- + +For CPU usage only, create an environment and install requirements and GWInferno with: .. code-block:: bash cd gwinferno - conda create -n gwinferno python=3.10 + conda create -n gwinferno python=3.12 conda activate gwinferno conda install -c conda-forge numpyro h5py pip install --upgrade pip pip install -r pip_requirements.txt python -m pip install . -To enable JAX access to CUDA enabled GPUs we need to specify specific versions to install (See `JAX `_ installation instructions for more details). For a GPU enabled environment use: +--------------------------------------------------------------------- +For GPU +--------------------------------------------------------------------- + +To enable JAX access to CUDA enabled GPUs, we need to specify specific versions to install. The following procedure will only work for Linux x86_64 and Linux aarch64; for other platforms see `Jax documentation `_. + +Jax recommends installing Nvidia CUDA and cuDNN with pip wheels. If you use local installations of CUDA and cuDNN, which could be the case for a remote cluster, then you'll need to install jax from the single CUDA wheel variant it offers. As of writing, this wheel is only compatible with CUDA >= 12.1 and cuDNN >= 9.1 < 10.0. See `JAX `_ installation instructions for more details. + +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Installation process for CUDA installed via pip: +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +.. code-block:: bash + + cd gwinferno + conda create -n gwinferno_gpu python=3.12 + conda activate gwinferno_gpu + pip install --upgrade pip + +Next, install CUDA and cuDNN pip wheels. See `Nvidia's CUDA quickstart guide `_ for the CUDA installation procedure and the `cuDNN documentation `_ for the cuDNN installation procedure. Once that has finished, continue with these steps: + +.. code-block:: bash + + pip install --upgrade "jax[cuda12]" + pip install numpyro[cuda] + pip install -r pip_requirements.txt + python -m pip install . + +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Installation process for locally installed CUDA and cuDNN: +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Ensure `Nvidia CUDA `_ and `cuDNN `_ are installed locally. .. code-block:: bash cd gwinferno - conda create -n gwinferno_gpu python=3.10 + conda create -n gwinferno_gpu python=3.12 conda activate gwinferno_gpu - conda install -c nvidia -c conda-forge jaxlib=*=*cuda* jax cuda-nvcc numpyro h5py pip install --upgrade pip + pip install --upgrade "jax[cuda12_local]" + pip install numpyro[cuda] pip install -r pip_requirements.txt python -m pip install .