How to Install JAX with ROCm Acceleration on Ubuntu 24.04

  • Last Created On Dec 13, 2024

Introduction

JAX is an open-source library designed for high-performance numerical computing and machine learning research. It provides tools for automatic differentiation, GPU/TPU acceleration, and just-in-time compilation to optimize code execution. ROCm, AMD's platform for GPU computing, enables JAX to utilize the power of AMD GPUs for faster computations. A ROCm-enabled JAX container is a pre-configured, portable environment that includes JAX optimized for AMD GPUs. By using this container, developers and researchers can easily leverage AMD GPUs for their JAX-based machine learning tasks without worrying about setting up dependencies or hardware configurations, streamlining their workflow.

In this article, you download and run a ROCm-supported JAX container, then install JAX on the host with Pip for the ROCm compute platform.

Use ROCm Supported JAX Containers

In this section, you download and run a ROCm-supported JAX container and check GPU availability from inside it.

  1. Pull the ROCm-supported JAX container image.

    console
    $ docker pull rocm/jax:latest
    
  2. Run a temporary Docker container.

    console
    $ docker run --rm -it --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --shm-size 8G rocm/jax:latest
    

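    The above command starts an interactive container (-it) that is removed when you exit (--rm). The --device options expose the GPU devices (/dev/kfd and /dev/dri) to the container for ROCm-supported JAX workloads, --security-opt seccomp=unconfined disables the default seccomp profile, and --shm-size 8G sets the size of the container's shared memory to 8 GB.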

  3. Verify GPU availability from the container.

    console
    $ rocm-smi
    $ python3 -c 'import jax; print(jax.devices())'
    

    The rocm-smi command should list the AMD GPUs visible to the container along with their status, and the Python command should print the devices that JAX detects.

  4. Exit and destroy the temporary container.

    console
    $ exit
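
The device check in step 3 only confirms that JAX can see a device. As a rough complement, the sketch below also runs a small matrix product so that the check exercises the GPU. It is an illustration rather than part of the original procedure: the helper name list_accelerators is hypothetical, and the script assumes you run it with python3 inside the container before exiting.

```python
import importlib.util


def list_accelerators():
    """Return JAX devices whose platform is not "cpu" (empty if JAX is missing)."""
    if importlib.util.find_spec("jax") is None:
        return []  # JAX is not installed in this interpreter
    import jax

    return [d for d in jax.devices() if d.platform != "cpu"]


accelerators = list_accelerators()
if accelerators:
    import jax.numpy as jnp

    # Multiply two 512x512 matrices of ones; every entry of the result is 512.
    x = jnp.ones((512, 512))
    y = (x @ x).block_until_ready()  # wait for the asynchronous computation
    print("Accelerators:", accelerators)
    print("Matmul check (should be 512.0):", float(y[0, 0]))
else:
    print("No accelerator visible: JAX is missing or fell back to the CPU.")
```

If the script reports no accelerator inside the container, the usual first things to check are the --device options passed to docker run.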
    

Install JAX on Host using Pip

In this section, you install jaxlib, the JAX ROCm Plugin, and JAX on the host machine using Pip, and then check for GPU availability.

  1. Fetch the Python version.

    console
    $ python3 -V
    
  2. Fetch the ROCm version.

    console
    $ amd-smi version
    
  3. Navigate to the JAX ROCm GitHub fork releases page.

  4. Find the installation commands by matching the Python version you retrieved in step 1 with the ROCm version you retrieved in step 2 in the latest release notes.

  5. Copy the jaxlib and JAX ROCm Plugin installation commands from the release notes and run them in your terminal.

    • jaxlib: It acts as a bridge between JAX and the hardware it runs on, such as CPUs, GPUs, or TPUs. Without jaxlib, JAX cannot execute computations efficiently on the target hardware.
    • JAX ROCm Plugin: It is a specialized component that enables JAX to utilize AMD GPUs via the ROCm platform. This plugin ensures compatibility and optimization for JAX workloads on AMD hardware, enabling tasks like automatic differentiation and parallelized matrix operations to run efficiently on ROCm-supported GPUs.
  6. Install the JAX Python package.

    console
    $ python3 -m pip install jax
    
  7. Verify GPU availability.

    console
    $ python3 -c 'import jax; print(jax.devices())'
    

    The output of the above command should display all devices along with their respective IDs.
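
Two parts of the steps above can be checked with a short script: the Python version match in step 4 and the result of the installation in steps 5 and 6. The sketch below uses only the standard library. It prints the CPython tag of the current interpreter (wheel filenames conventionally carry a tag such as cp312 for Python 3.12; that the release notes list one wheel per tag is an assumption) and then lists every installed package whose name starts with "jax". The helper name installed_jax_packages is hypothetical.

```python
import sys
from importlib import metadata

# CPython wheel tag for this interpreter, for example "cp312" for Python 3.12.
python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}"
print("Look for wheels tagged:", python_tag)


def installed_jax_packages():
    """Map each installed distribution whose name starts with "jax" to its version."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("jax"):
            found[name] = dist.version
    return found


# Print the packages in a pip-style "name==version" form for easy comparison.
for name, version in sorted(installed_jax_packages().items()):
    print(f"{name}=={version}")
```

After a successful installation, the list should include jax, jaxlib, and the ROCm plugin package, which makes it easy to compare their versions against the release notes.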

Conclusion

In this article, you downloaded and ran a ROCm-supported JAX container with access to ROCm-supported devices. You also installed JAX on the host using Pip for the ROCm compute platform.
