Just-in-time compilation
As we saw in the introduction to numpy, it pays to express loops and similar constructs as numpy operations rather than in Python logic. Sometimes, however, this is unavoidable because the required functionality is simply not available in numpy. In such cases, code compilation can help: the inefficient Python code is analysed once and translated into equivalent but more efficient code. As a consequence, the first call of a function is noticeably slower, since it has to be analysed and compiled, while all subsequent calls are substantially faster.
Suppose we had no function to calculate the norm of a vector. Then a pure Python implementation of this could look like this:
import numpy as np

def norm_simple(x):
    total = 0
    for val in x:
        total += val * val
    return np.sqrt(total)
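As a quick sanity check (a small sketch; the example vector is arbitrary), the result agrees with numpy's built-in np.linalg.norm:

x = np.array([3.0, 4.0])
print(norm_simple(x))       # 5.0
print(np.linalg.norm(x))    # 5.0, the built-in reference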
If we replace the numpy call with jax.numpy, we obtain a new function that we can compile with jit:
import jax
import jax.numpy as jnp

def norm_simple(x):
    total = 0
    for val in x:
        total += val * val
    return jnp.sqrt(total)

norm_simple_jit = jax.jit(norm_simple)
The first time norm_simple_jit is called, the code is analysed and converted into efficient code. Subsequent calls are then significantly faster.
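One way to see this effect (a minimal sketch; the absolute numbers depend heavily on the machine) is to time the first call against a later one. Because JAX dispatches work asynchronously, block_until_ready() is used so the timer only stops once the result is actually available:

import time

x = jnp.arange(100.0)

start = time.perf_counter()
norm_simple_jit(x).block_until_ready()   # first call: trace and compile
print("first call: ", time.perf_counter() - start)

start = time.perf_counter()
norm_simple_jit(x).block_until_ready()   # later calls reuse the compiled code
print("second call:", time.perf_counter() - start)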
This can lead to significant improvements even with simple code:
def third_power(mat):
    return mat * mat * mat

mat = np.random.uniform(size=(5000, 5000))
third_power_jit = jax.jit(third_power)
Here, third_power_jit needs only about a quarter of the time of third_power.
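A rough way to reproduce such a comparison (a sketch; the measured factor varies with hardware, backend and matrix size) is to time both versions with timeit after a warm-up call that triggers compilation:

import timeit

third_power_jit(mat).block_until_ready()   # warm-up: compile once

t_plain = timeit.timeit(lambda: third_power(mat), number=10)
t_jit = timeit.timeit(lambda: third_power_jit(mat).block_until_ready(), number=10)
print(f"plain: {t_plain:.3f} s, jit: {t_jit:.3f} s")

Much of the gain typically comes from XLA fusing the two multiplications into a single pass over the data, instead of materialising an intermediate 5000 x 5000 array as plain numpy does.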