Machine Learning and Quantum Alchemy: Day 4
Part of the course Machine Learning and Quantum Alchemy.
Make sure to run jax in double precision mode:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
Exercises
- Implement the Lennard-Jones potential, which approximately describes a diatomic potential, as a function LJ(r: float, sigma: float, epsilon: float) -> float:
\[V(r)=4\varepsilon \left[\left(\frac{\sigma}{r}\right)^{12}-\left(\frac{\sigma}{r}\right)^{6}\right]\]
Here, \(\varepsilon\) models the depth of the potential well, \(\sigma\) sets the position at which the potential crosses zero, and \(r\) is the distance at which the potential is evaluated.
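A minimal sketch of one possible implementation, written with jnp so it can be differentiated later:
def LJ(r: float, sigma: float, epsilon: float) -> float:
    # reuse (sigma/r)^6 for both the repulsive (^12) and attractive (^6) terms
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)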
- Plot the potential for \(\varepsilon=1\) and \(\sigma=1\) over the interval \([0.95, 2.5]\).
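One way to plot it, assuming matplotlib is available:
import matplotlib.pyplot as plt

rs = jnp.linspace(0.95, 2.5, 200)
plt.plot(rs, LJ(rs, 1.0, 1.0))  # LJ broadcasts elementwise over the array of distances
plt.xlabel("r")
plt.ylabel("V(r)")
plt.show()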
- Obtain the numerical derivative w.r.t. the distance \(r\) by using finite differences with a displacement of \(\delta = 10^{-5}\). Plot it.
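A sketch of a symmetric (central) finite difference; the name dLJ_fd is just illustrative:
def dLJ_fd(r: float, sigma: float, epsilon: float, delta: float = 1e-5) -> float:
    # difference quotient around r with displacement delta
    return (LJ(r + delta, sigma, epsilon) - LJ(r - delta, sigma, epsilon)) / (2 * delta)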
- Obtain the analytical derivative of the function by hand, implement it as dLJ(r: float, sigma: float, epsilon: float) -> float, and plot it together with the finite difference derivative.
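For reference, differentiating by hand gives \(V'(r)=\frac{24\varepsilon}{r}\left[\left(\frac{\sigma}{r}\right)^{6}-2\left(\frac{\sigma}{r}\right)^{12}\right]\), which could be implemented along these lines:
def dLJ(r: float, sigma: float, epsilon: float) -> float:
    # analytical derivative of the Lennard-Jones potential w.r.t. r
    sr6 = (sigma / r) ** 6
    return 24 * epsilon * (sr6 - 2 * sr6**2) / r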
- Read the page on automatic differentiation for yourself and try some of the examples. Then use jax.grad to differentiate the function LJ and plot the results.
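A sketch of how jax.grad could be used here; by default it differentiates w.r.t. the first argument, i.e. \(r\):
dLJ_ad = jax.grad(LJ)          # derivative w.r.t. r
dLJ_ad(1.5, 1.0, 1.0)          # evaluate at a single distance
rs = jnp.linspace(0.95, 2.5, 200)
jax.vmap(dLJ_ad, in_axes=(0, None, None))(rs, 1.0, 1.0)  # evaluate on the whole grid for plotting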
- Obtain the derivative of LJ for all three arguments \(\varepsilon\), \(\sigma\), and \(r\) as well as the function value in a single jax call.
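One way to do this in a single call is jax.value_and_grad with argnums, sketched here:
value, grads = jax.value_and_grad(LJ, argnums=(0, 1, 2))(1.5, 1.0, 1.0)
# value is LJ(1.5, 1.0, 1.0); grads is the tuple of derivatives w.r.t. r, sigma, and epsilon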
- How can we obtain higher order derivatives?
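One possibility is to nest jax.grad calls, e.g. for the second derivative w.r.t. \(r\):
d2LJ = jax.grad(jax.grad(LJ))  # illustrative name for the second derivative
d2LJ(1.5, 1.0, 1.0)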
Note that jax is not doing symbolic derivatives. This can be done e.g. using sympy:
import sympy as sp

def analytical_LJ_derivative(order: int):
    r, sigma, epsilon = sp.symbols(r"r \sigma \epsilon")
    return sp.diff(4*epsilon*((sigma/r)**12 - (sigma/r)**6), r, order)
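Calling analytical_LJ_derivative(2), for instance, returns the second symbolic derivative with respect to \(r\).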
You may inspect what jax is doing using jax.make_jaxpr.
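For example, printing the jaxpr of LJ shows the traced elementary operations:
print(jax.make_jaxpr(LJ)(1.5, 1.0, 1.0))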
- Different elements would be simulated with different LJ parameters. In that situation, the parameters are unified using mixing rules. Implement the Lorentz-Berthelot rules in the following code stub and plot the jax derivative for Ne-Ne, Ar-Ne, and Ar-Ar interactions. Make sure to call the LJ function you had before rather than copying the code into energy. For distances larger than 10, the function should directly return 0.
def energy(element1: str, element2: str, distance: float) -> float:
    table = {'Ar': (1.2, 3.), "Ne": (3., 2.7)}
    sigma1, epsilon1 = table[element1]
    sigma2, epsilon2 = table[element2]
    # continue here
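One possible continuation of the stub (a sketch replacing the # continue here line): Lorentz-Berthelot takes the arithmetic mean of the \(\sigma\) values and the geometric mean of the \(\varepsilon\) values.
    sigma = (sigma1 + sigma2) / 2             # Lorentz rule: arithmetic mean
    epsilon = jnp.sqrt(epsilon1 * epsilon2)   # Berthelot rule: geometric mean
    # jnp.where keeps the cutoff beyond distance 10 traceable by jax
    return jnp.where(distance > 10, 0.0, LJ(distance, sigma, epsilon))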
- Understand what the following code does and obtain the minimum energy configuration, optionally using the derivatives.
import numpy as np

def LJ_many(elements: list[str], positions1d: np.ndarray) -> float:
    table = {'Ar': (1.2, 3.), "Ne": (3., 2.7)}
    params = jnp.array([table[_] for _ in elements])
    epsilons = jnp.sqrt(jnp.outer(params[:, 1], params[:, 1]))
    sigmas = (params[:, 0, np.newaxis] + params[:, 0]) / 2
    rs = jnp.abs(positions1d[:, np.newaxis] - positions1d)
    rs = jnp.fill_diagonal(rs, 1e100, inplace=False)
    rmod = (sigmas / rs)**6
    all = 4 * epsilons * (rmod**2 - rmod)
    return jnp.triu(all, 1).sum()

LJ_many(["Ne", "Ne", "Ar"], jnp.array([0., 2., 4.]))
The solution is 0., 3.365, 5.722 or any translated positions.
Advanced solution
scipy can use the analytical gradient information from jax during optimization to speed up high-dimensional optimisation:
import scipy.optimize as sco

elements = ["Ne", "Ne", "Ar"]
start = jnp.array([0., 2., 4.])
func = lambda xs: LJ_many(elements, xs)
grad = lambda xs: jax.grad(LJ_many, argnums=1)(elements, xs)
sco.minimize(func, start, jac=grad)