Automatic differentiation
Normally, to implement the derivative of a function, you first have to derive it by hand and then implement the result separately. This is error-prone, because either the derivation or the implementation can be wrong. In addition, discrepancies can arise between the original function and its derivative if the two are not updated together. It is also time-consuming.
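To see the problem concretely, here is a minimal sketch (the function and its name are made up for illustration) in which the function and its hand-derived derivative are two separate pieces of code that must be kept in sync:
# Manual approach: the function and its hand-derived derivative
# are maintained as two separate pieces of code.
def f(x):
    return x**3 + 2*x

def df_manual(x):
    # Derived by hand: d/dx (x**3 + 2*x) = 3*x**2 + 2.
    # If f changes, this function must be updated separately.
    return 3*x**2 + 2

df_manual(2.)  # 14.0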
Automatic differentiation (AD) allows a sequence of code functions to be interpreted as mathematical functions. If the result of one Python function is passed into another, this is formally a composition of the two functions, so the chain rule can be applied. If the derivative of each individual step is known, the derivative of the entire function can be computed. JAX is such a library, providing derivatives for (almost) all functions in numpy.
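As a small sketch of this idea (the function names here are chosen only for illustration), composing two Python functions and differentiating the result applies the chain rule automatically:
import jax

def square(x):
    return x**2

def shift(x):
    return x + 3.

def composed(x):
    # Passing the result of one function into another is formally a composition:
    # composed(x) = square(shift(x)) = (x + 3)**2
    return square(shift(x))

# Chain rule: d/dx (x + 3)**2 = 2*(x + 3), which is 8 at x = 1.
float(jax.grad(composed)(1.))  # 8.0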
First, a simple example:
def f(x):
    return x**2

df = jax.grad(f)
float(df(2.))  # 4.0
The function jax.grad takes a function as an argument and calculates its derivative with respect to the first argument. It can also be applied multiple times:
def f(x):
    return x**2

df = jax.grad(f)
ddf = jax.grad(df)
float(ddf(2.))  # 2.0
The return value is an array, not a number, and must therefore be converted into a number with float.
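To make this explicit, here is a small sketch; the check assumes a recent JAX version, where jax.Array is the public array type:
import jax

df = jax.grad(lambda x: x**2)
result = df(2.)
isinstance(result, jax.Array)  # True (in recent JAX versions)
float(result)                  # 4.0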
The function jax.grad is also able to differentiate functions that have several arguments:
def f(x, y):
    return x**2 + 2*y**2

df = jax.grad(f, argnums=0)
float(df(2., 3.))  # 4.0
The argument argnums specifies with respect to which argument the function is differentiated. If the function is to be differentiated with respect to both arguments, we must pass a tuple:
def f(x, y):
    return x**2 + 2*y**2

df = jax.grad(f, argnums=(0, 1))
df(2., 3.)  # (4.0, 12.0)
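As a brief usage sketch (restating the function above for completeness), the returned tuple can be unpacked into the individual partial derivatives:
import jax

def f(x, y):
    return x**2 + 2*y**2

# jax.grad with a tuple for argnums returns a tuple of partial derivatives,
# which can be unpacked directly.
dfdx, dfdy = jax.grad(f, argnums=(0, 1))(2., 3.)
float(dfdx)  # 4.0
float(dfdy)  # 12.0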
If the argument is a vector, jnp.array must be used instead of np.array:
import jax.numpy as jnp

def norm(x):
    return jnp.sqrt(jnp.dot(x, x))

vector = jnp.array((1., 2.))
df = jax.grad(norm)
norm(vector)  # 2.236068
df(vector)    # [0.4472136, 0.8944272], i.e. x / norm(x)
For functions that return a vector, use the function jax.jacobian instead of jax.grad:
def test(x):
    return x

vector = jnp.array((1., 2.))
jax.jacobian(test)(vector)  # [[1., 0.],
                            #  [0., 1.]]
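The identity function naturally has the identity matrix as its Jacobian. As a slightly less trivial sketch (the function g is made up for illustration), here is a function whose components mix the inputs:
import jax
import jax.numpy as jnp

def g(x):
    # Maps a 2-vector to a 2-vector: (x0 * x1, x0**2)
    return jnp.array((x[0] * x[1], x[0]**2))

x = jnp.array((1., 2.))
jax.jacobian(g)(x)  # [[2., 1.],
                    #  [2., 0.]]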
As an example, here is a way to implement a harmonic potential in JAX and automatically obtain the derivative:
def potential(x):
    # Matrix of pairwise distances |x_i - x_j|
    distances = jnp.abs(x[jnp.newaxis, :] - x[:, jnp.newaxis])
    return (distances**2).sum() / 2

jax.grad(potential)(vector)  # [-2., 2.]
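As a quick sanity check (a sketch continuing the example above; potential_grad_manual is a hypothetical helper and the closed form is derived by hand), the automatic gradient can be compared with the analytic one, dV/dx_i = 2*(n*x_i - sum(x)):
# Analytic gradient, derived by hand for comparison:
# dV/dx_i = 2 * (n * x_i - sum(x))
def potential_grad_manual(x):
    return 2 * (x.size * x - x.sum())

potential_grad_manual(vector)  # [-2., 2.]
jax.grad(potential)(vector)    # [-2., 2.]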
If you want to calculate value and gradient at the same time, you can do this with jax.value_and_grad:
jax.value_and_grad(potential)(vector)  # (1., [-2., 2.])