JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. It is developed by Google with contributions from Nvidia and other community contributors.
It is described as bringing together a modified version of autograd (automatic differentiation, which obtains the gradient function by differentiating the original function) and OpenXLA's XLA (Accelerated Linear Algebra). It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch. The primary features of JAX are:
# Providing a unified NumPy-like interface to computations that run on CPU, GPU, or TPU, in local or distributed settings (see the sketch after this list).
# Built-in Just-In-Time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem.
# Efficient evaluation of gradients via its automatic differentiation transformations.
# Automatic vectorization of functions, efficiently mapping them over arrays representing batches of inputs.
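As a brief illustration of the NumPy-like interface mentioned in the first item, the following minimal sketch (with illustrative values that are not part of the original article) evaluates an element-wise array expression with jax.numpy:
# imports
import jax.numpy as jnp
# element-wise array operations, written as in NumPy
x = jnp.linspace(0.0, 1.0, 5)
y = jnp.sin(x) ** 2 + jnp.cos(x) ** 2
# prints approximately [1. 1. 1. 1. 1.]; the computation runs on CPU, GPU, or TPU as available
print(y)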
grad
The code below demonstrates the grad function's automatic differentiation.
# imports
from jax import grad
import jax.numpy as jnp
# define the logistic function
def logistic(x):
    return jnp.exp(x) / (jnp.exp(x) + 1)
# obtain the gradient function of the logistic function
grad_logistic = grad(logistic)
# evaluate the gradient of the logistic function at x = 1
grad_log_out = grad_logistic(1.0)
print(grad_log_out)
The final line should output:
0.19661194
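This agrees with the analytical derivative of the logistic function, logistic′(x) = logistic(x) × (1 − logistic(x)), which at x = 1 evaluates to approximately 0.19661.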
jit
The code below demonstrates the jit function's optimization through fusion.
# imports
from jax import jit
import jax.numpy as jnp
# define the cube function
def cube(x):
    return x * x * x
# generate data
x = jnp.ones((10000, 10000))
# create the jit version of the cube function
jit_cube = jit(cube)
# apply the cube and jit_cube functions to the same data for speed comparison
cube(x)
jit_cube(x)
The computation time for jit_cube(x) should be noticeably shorter than that for cube(x). Increasing the size of the array created by jnp.ones will further accentuate the difference.
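Because JAX dispatches work asynchronously, a fair timing comparison should block until the result is ready. The following sketch, which continues the example above using Python's standard time module (not part of the original example), shows one way to measure this:
# time both versions, blocking on the result so the measurement reflects the actual computation
import time
for name, fn in [("cube", cube), ("jit_cube", jit_cube)]:
    fn(x).block_until_ready()  # warm-up run; also triggers compilation for jit_cube
    start = time.perf_counter()
    fn(x).block_until_ready()
    print(name, time.perf_counter() - start)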
vmap
The code below, a method taken from a larger class, demonstrates the vmap function's vectorization of a per-example gradient computation.
# imports
from functools import partial
from jax import vmap
import jax.numpy as jnp
# define a method that computes per-example gradients over a batch of inputs
def grads(self, inputs):
    # fix the network parameters as the first argument of the gradient function
    in_grad_partial = partial(self._net_grads, self._net_params)
    # vectorize the single-example gradient function over the batch dimension
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    # flatten the per-example gradients into a 2-D array
    flat_grads = jnp.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads
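Because the fragment above depends on methods of its surrounding class, the following self-contained sketch (with hypothetical function and variable names) illustrates the same idea:
# imports
from jax import vmap
import jax.numpy as jnp
# define a function of a single example
def square_norm(x):
    return jnp.dot(x, x)
# vectorize it over the leading (batch) dimension of a hypothetical batch of 8 examples
batch = jnp.arange(24.0).reshape(8, 3)
batched_square_norm = vmap(square_norm)
# prints one squared norm per row of the batch
print(batched_square_norm(batch))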
pmap
The code below demonstrates the pmap function's parallelization of matrix multiplication across devices.
# import pmap and random from JAX; import JAX NumPy
from jax import pmap, random
import jax.numpy as jnp
# generate 2 random matrices of dimensions 5000 x 6000, one per device
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)
# without data transfer, in parallel, perform a local matrix multiplication on each CPU/GPU
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)
# without data transfer, in parallel, obtain the mean for both matrices on each CPU/GPU separately
means = pmap(jnp.mean)(outputs)
print(means)
The final line should print the values:
[1.1566595 1.1805978]
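Note that this example assumes at least two devices are visible to JAX, since pmap maps each element of the leading axis onto its own device. A minimal sketch for checking the available devices:
# count and list the devices visible to JAX; pmap needs at least as many devices as mapped elements
import jax
print(jax.device_count())
print(jax.devices())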
See also
* NumPy
* TensorFlow
* PyTorch
* CUDA
* Accelerated Linear Algebra
External links
* Documentation:
* Colab (Jupyter/IPython) Quickstart Guide:
* TensorFlow's XLA: (Accelerated Linear Algebra)
* YouTube TensorFlow Channel "Intro to JAX: Accelerating Machine Learning research":
* Original paper: