
JAX

JAX lets you write numerical code that runs almost unchanged on CPUs (ordinary computers), GPUs (graphics cards) or TPUs (Google's tensor processing units, available in the cloud). Its syntax closely follows NumPy. Parallelisation and compilation (via the XLA compiler) happen in the background. JAX also offers automatic differentiation of functions.
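A minimal sketch of the two points above, assuming JAX and NumPy are installed. The `jnp` namespace mirrors NumPy's function names, and `jax.grad` turns a scalar-valued Python function into a function that computes its gradient. The function `f` and the sample values are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

# The same expression written with NumPy and with jax.numpy (jnp)
x_np = np.linspace(0.0, 1.0, 5)
x = jnp.linspace(0.0, 1.0, 5)
total_np = np.sum(np.sin(x_np) ** 2)
total_jax = jnp.sum(jnp.sin(x) ** 2)

# Automatic differentiation: jax.grad returns a new function that computes
# the gradient of a scalar-valued function with respect to its first argument.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

df = jax.grad(f)
grad = df(x)

# Analytically, d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x), element by element.
expected = jnp.sin(2 * x)

print(total_np, total_jax)
print(grad)
print(expected)
```

Note that JAX computes in 32-bit floats by default, so its results can differ from NumPy's 64-bit results in the last few digits.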

The programming style is slightly different, although this mainly matters in larger projects: JAX is functional rather than imperative. Functions have no side effects; they only map inputs to outputs, strictly in the mathematical sense. In addition, calls must be reproducible: the same inputs always produce the same outputs. Pseudo-random numbers are nevertheless available in JAX (they are generated from an explicitly passed random key), so nothing stands in the way of statistical methods.
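The following sketch shows how reproducibility and pseudo-random numbers fit together: instead of a hidden global random state, every call to `jax.random` takes an explicit key. The seed `42` and the helper `estimate_second_moment` are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Randomness in JAX comes from an explicit key, not from hidden global state.
key = jax.random.PRNGKey(42)

# The same key always yields the same numbers, so the call is reproducible.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# For fresh numbers, split the key and use each subkey only once.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

# A small statistical use (hypothetical helper): Monte Carlo estimate of
# E[X^2] for X ~ N(0, 1), whose exact value is 1.
def estimate_second_moment(key, n):
    samples = jax.random.normal(key, (n,))
    return jnp.mean(samples ** 2)

m2 = estimate_second_moment(subkey, 100_000)

print(a)   # identical to b
print(c)   # drawn from a different key
print(m2)  # close to 1
```

Because the key is an ordinary function argument, `estimate_second_moment` stays a pure function: calling it twice with the same key gives the same estimate.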

Technically, JAX gives the programmer less control over the operations that are actually executed on the hardware. This abstraction is what allows the same code to run on CPUs and graphics cards, but it also means that code may run slower on a CPU than it would with NumPy alone.
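The hand-off to the hardware can be made explicit with `jax.jit`, which traces a Python function and passes the result to the XLA compiler; from that point on, XLA decides how the computation is carried out on the device. The sketch below also lists the devices JAX detected. `normalize` is a hypothetical example function, and no timings are claimed, since whether JAX beats NumPy on a CPU depends on the workload.

```python
import jax
import jax.numpy as jnp

# Devices JAX found; on an ordinary computer this is typically one CPU device.
print(jax.devices())

# jax.jit traces the function once per input shape/dtype and compiles it with XLA.
@jax.jit
def normalize(x):
    return (x - jnp.mean(x)) / jnp.std(x)

x = jnp.arange(10.0)
y = normalize(x)  # first call: trace, compile, then execute
y = normalize(x)  # same shape and dtype: the compiled code is reused

# JAX dispatches work asynchronously; block_until_ready waits for the result,
# which matters when measuring run times.
y.block_until_ready()
print(y)
```

The first call pays the compilation cost, so benchmarks should time later calls and include `block_until_ready()`.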

Not all program code can be ported to JAX without changes. This applies in particular to code with explicit loops and to code that modifies data structures in place (i.e. without copying them).
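A sketch of the two typical rewrites, under the assumption that the code is meant to run under `jax.jit`. In-place assignment is replaced by the `.at[...].set(...)` syntax, which returns a modified copy. A Python loop is replaced by `jax.lax.fori_loop`. (Plain Python loops do still work in JAX, but under `jax.jit` they are unrolled during tracing, and their bounds cannot depend on traced values.) The functions `cumulative_python` and `cumulative_jax` are hypothetical examples.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(5)

# NumPy-style in-place assignment fails because JAX arrays are immutable:
# x[0] = 1.0  # raises TypeError

# .at[...].set(...) returns a new array with the change applied.
y = x.at[0].set(1.0)
print(x)  # unchanged: still all zeros
print(y)  # first element is now 1.0

# An explicit Python loop that accumulates a sum of squares
def cumulative_python(n):
    total = 0.0
    for i in range(n):
        total = total + i * i
    return total

# The same computation with jax.lax.fori_loop; under jax.jit, n is a traced
# value, which a Python range() could not handle, but fori_loop can.
@jax.jit
def cumulative_jax(n):
    return jax.lax.fori_loop(0, n, lambda i, total: total + i * i, jnp.int32(0))

print(cumulative_python(10), cumulative_jax(10))
```

For simple reductions like this one, a vectorised expression such as `jnp.sum(jnp.arange(n) ** 2)` is often the most natural JAX rewrite.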

The usual imports for JAX are:

import jax.numpy as jnp
import jax

They are required in the examples.