
Machine Learning and Quantum Alchemy: Day 4

Part of the course Machine Learning and Quantum Alchemy.

Make sure jax runs in double-precision mode (set this at startup, before creating any arrays):

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
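A quick way to confirm that the setting took effect: with the flag enabled, newly created floating-point arrays default to float64 instead of float32.

```python
import jax

# Enable 64-bit floats; do this at startup, before any jax arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

x = jnp.array(1.0)
print(x.dtype)  # float64
```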

Exercises

  1. Implement the Lennard-Jones potential, which approximately describes the interaction energy of a diatomic system, as a function LJ(r: float, sigma: float, epsilon: float) -> float.

    \[V(r)=4\varepsilon \left[\left(\frac{\sigma}{r}\right)^{12}-\left(\frac{\sigma}{r}\right)^{6}\right]\]

    Here, \(\varepsilon\) sets the depth of the potential well, \(\sigma\) sets its position (the potential crosses zero at \(r=\sigma\) and has its minimum at \(r=2^{1/6}\sigma\)), and \(r\) is the distance at which the potential is evaluated.

  2. Plot the potential for \(\varepsilon=1\) and \(\sigma=1\) over the interval \([0.95, 2.5]\).

  3. Obtain the numerical derivative with respect to the distance \(r\) using finite differences with a displacement of \(\delta = 10^{-5}\). Plot it.

  4. Derive the analytical derivative of the function by hand, implement it as dLJ(r: float, sigma: float, epsilon: float) -> float and plot it together with the finite-difference derivative.

  5. Read the page on automatic differentiation yourself and try some of the examples. Then use jax.grad to differentiate LJ and plot the results.

  6. Obtain the derivatives of LJ with respect to all three arguments \(\varepsilon\), \(\sigma\), and \(r\), as well as the function value, in a single jax call.

  7. How can we obtain higher order derivatives?

    Note that jax does not compute symbolic derivatives; for those you can use, e.g., sympy:

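    import sympy as sp

    def analytical_LJ_derivative(order: int):
        # raw string: "\s" and "\e" are invalid escape sequences in a normal string
        r, sigma, epsilon = sp.symbols(r"r \sigma \epsilon")
        # differentiate V(r) `order` times with respect to r
        return sp.diff(4*epsilon*((sigma/r)**12-(sigma/r)**6), r, order)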
    

    You may inspect what jax is doing using jax.make_jaxpr.

  8. Different elements are simulated with different LJ parameters. For a pair of different elements, the parameters are combined using mixing rules. Implement the Lorentz-Berthelot rules in the following code stub and plot the jax derivative for the Ne-Ne, Ar-Ne, and Ar-Ar interactions. Make sure to call the LJ function you wrote before rather than copying its code into energy. For distances larger than 10, the function should directly return 0.

    def energy(element1: str, element2: str, distance: float) -> float:
        table = {'Ar': (1.2, 3.), "Ne": (3., 2.7)}
        sigma1, epsilon1 = table[element1]
        sigma2, epsilon2 = table[element2]
        # continue here
    
  9. Work out what the following code does and find the minimum-energy configuration, optionally using the derivatives.

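    def LJ_many(elements: list[str], positions1d: jnp.ndarray) -> float:
        table = {'Ar': (1.2, 3.), "Ne": (3., 2.7)}  # element: (sigma, epsilon)
        params = jnp.array([table[_] for _ in elements])
        # Lorentz-Berthelot mixing for every pair (i, j)
        epsilons = jnp.sqrt(jnp.outer(params[:, 1], params[:, 1]))
        sigmas = (params[:, 0, jnp.newaxis] + params[:, 0])/2
    
        # pairwise distances; the diagonal (self-interaction) is pushed to "infinity"
        rs = jnp.abs(positions1d[:, jnp.newaxis] - positions1d)
        rs = jnp.fill_diagonal(rs, 1e100, inplace=False)
        rmod = (sigmas/rs)**6
        pair_energies = 4*epsilons*(rmod**2-rmod)
        # count each pair once: upper triangle without the diagonal
        return jnp.triu(pair_energies, 1).sum()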
    
    LJ_many(["Ne", "Ne", "Ar"], jnp.array([0., 2., 4.]))
    

    For the initial ordering Ne, Ne, Ar, the minimum lies at approximately 0., 3.365, 5.722 (or any translation of these positions).

    Advanced solution

    scipy can use the exact gradients computed by jax to speed up gradient-based optimization, which pays off especially in high dimensions.

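    import scipy.optimize as sco
    elements = ["Ne", "Ne", "Ar"]
    
    start = jnp.array([0., 2., 4.])
    func = lambda xs: LJ_many(elements, xs)                       # energy
    grad = lambda xs: jax.grad(LJ_many, argnums=1)(elements, xs)  # gradient w.r.t. positions
    result = sco.minimize(func, start, jac=grad)  # uses the jax gradient instead of finite differences
    print(result.x)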
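For exercises 1 and 2, a minimal sketch of LJ could look as follows. Because it is written with plain arithmetic and jax.numpy, it works for scalars and, elementwise, for whole arrays of distances, which is convenient for plotting. The plotting call is left as a comment and assumes matplotlib is installed.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r: float, sigma: float, epsilon: float) -> float:
    """Lennard-Jones potential V(r) = 4 eps [(sigma/r)^12 - (sigma/r)^6]."""
    sr6 = (sigma / r) ** 6  # (sigma/r)^6, squared below to get (sigma/r)^12
    return 4 * epsilon * (sr6**2 - sr6)


# Exercise 2: evaluate on [0.95, 2.5] for epsilon = sigma = 1.
rs = jnp.linspace(0.95, 2.5, 1000)
vs = LJ(rs, 1.0, 1.0)
# import matplotlib.pyplot as plt; plt.plot(rs, vs); plt.show()

# The grid minimum should lie close to r = 2**(1/6) (about 1.1225) with V close to -1.
print(rs[jnp.argmin(vs)], vs.min())
```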
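For exercise 3, the exercise does not say which finite-difference scheme to use, so this sketch shows two: the central difference \((V(r+\delta)-V(r-\delta))/(2\delta)\), whose truncation error shrinks like \(\delta^2\), and the forward difference \((V(r+\delta)-V(r))/\delta\), which is only first-order accurate. The helper names dLJ_central and dLJ_forward are illustrative. As a reference, differentiating the formula by hand gives \(V'(1) = -24\) for \(\sigma=\varepsilon=1\).

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r, sigma, epsilon):
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


def dLJ_central(r, sigma, epsilon, delta=1e-5):
    """Illustrative helper: central difference w.r.t. r, error O(delta**2)."""
    return (LJ(r + delta, sigma, epsilon) - LJ(r - delta, sigma, epsilon)) / (2 * delta)


def dLJ_forward(r, sigma, epsilon, delta=1e-5):
    """Illustrative helper: forward difference w.r.t. r, error O(delta)."""
    return (LJ(r + delta, sigma, epsilon) - LJ(r, sigma, epsilon)) / delta


rs = jnp.linspace(0.95, 2.5, 1000)
central = dLJ_central(rs, 1.0, 1.0)  # plot with e.g. plt.plot(rs, central)
forward = dLJ_forward(rs, 1.0, 1.0)

# At r = 1 the exact derivative is -24; print the error of each scheme.
print(dLJ_central(1.0, 1.0, 1.0) + 24, dLJ_forward(1.0, 1.0, 1.0) + 24)
```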
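For exercise 4, differentiating the potential term by term gives

\(V'(r) = 4\varepsilon\left[-\frac{12\sigma^{12}}{r^{13}} + \frac{6\sigma^{6}}{r^{7}}\right] = \frac{24\varepsilon}{r}\left[\left(\frac{\sigma}{r}\right)^{6} - 2\left(\frac{\sigma}{r}\right)^{12}\right]\)

which vanishes at \(r = 2^{1/6}\sigma\). A possible implementation, compared against a central finite difference on the same grid (the plotting line assumes matplotlib):

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r, sigma, epsilon):
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


def dLJ(r: float, sigma: float, epsilon: float) -> float:
    """Analytical dV/dr = (24 eps / r) [(sigma/r)^6 - 2 (sigma/r)^12]."""
    sr6 = (sigma / r) ** 6
    return 24 * epsilon / r * (sr6 - 2 * sr6**2)


rs = jnp.linspace(0.95, 2.5, 1000)
delta = 1e-5
finite_diff = (LJ(rs + delta, 1.0, 1.0) - LJ(rs - delta, 1.0, 1.0)) / (2 * delta)
analytical = dLJ(rs, 1.0, 1.0)
# import matplotlib.pyplot as plt; plt.plot(rs, analytical); plt.plot(rs, finite_diff, "--")
print(jnp.max(jnp.abs(analytical - finite_diff)))  # small, limited by delta and rounding
```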
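For exercise 5, jax.grad turns LJ into a function that returns \(\partial V/\partial r\) (the first argument, by default). jax.grad requires a scalar output, so one way to evaluate it on a whole grid, sketched here, is to map it with jax.vmap while sharing \(\sigma\) and \(\varepsilon\) across grid points.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r, sigma, epsilon):
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


# Derivative w.r.t. the first argument, r.
dLJ_auto = jax.grad(LJ)

# Map over an array of r values; sigma and epsilon are not mapped (in_axes=None).
dLJ_auto_grid = jax.vmap(dLJ_auto, in_axes=(0, None, None))

rs = jnp.linspace(0.95, 2.5, 1000)
auto = dLJ_auto_grid(rs, 1.0, 1.0)
# import matplotlib.pyplot as plt; plt.plot(rs, auto)
print(dLJ_auto(1.0, 1.0, 1.0))  # -24.0
```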
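For exercise 6, jax.value_and_grad with argnums=(0, 1, 2) returns the function value together with the derivatives with respect to r, sigma and epsilon in one call. Wrapping it in jax.jit is optional. All three inputs must be floats, since jax.grad refuses integer inputs.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r, sigma, epsilon):
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


# Value plus gradient w.r.t. r (0), sigma (1) and epsilon (2).
LJ_value_and_grads = jax.jit(jax.value_and_grad(LJ, argnums=(0, 1, 2)))

value, (dV_dr, dV_dsigma, dV_depsilon) = LJ_value_and_grads(1.5, 1.0, 1.0)
print(value, dV_dr, dV_dsigma, dV_depsilon)
```

Two quick consistency checks: \(V\) is linear in \(\varepsilon\), so \(\partial V/\partial\varepsilon = V/\varepsilon\); and \(V\) depends on \(\sigma\) and \(r\) only through \(\sigma/r\), so \(\sigma\,\partial V/\partial\sigma = -r\,\partial V/\partial r\).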
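For exercise 7, the result of jax.grad is again a differentiable function, so higher-order derivatives follow by applying jax.grad repeatedly. The helper nth_derivative is illustrative and not part of jax; for functions of several variables, jax.hessian or jax.jacfwd(jax.grad(f)) are alternatives. The reference values come from differentiating \(V\) by hand at \(r=\sigma=\varepsilon=1\).

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r, sigma, epsilon):
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


def nth_derivative(f, n: int):
    """Illustrative helper: differentiate f n times w.r.t. its first argument."""
    for _ in range(n):
        f = jax.grad(f)
    return f


d2LJ = nth_derivative(LJ, 2)
d3LJ = nth_derivative(LJ, 3)
# By hand: V''(1) = 4 (156 - 42) = 456 and V'''(1) = 4 (-2184 + 336) = -7392.
print(d2LJ(1.0, 1.0, 1.0), d3LJ(1.0, 1.0, 1.0))

# make_jaxpr lists the primitive operations jax traces; there is no symbolic formula.
print(jax.make_jaxpr(d2LJ)(1.0, 1.0, 1.0))
```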
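For exercise 8, a sketch of energy under two assumptions: the table stores (sigma, epsilon) pairs, as the unpacking in the stub implies, and the Lorentz-Berthelot rules are \(\sigma_{12} = (\sigma_1+\sigma_2)/2\) (arithmetic mean) and \(\varepsilon_{12} = \sqrt{\varepsilon_1\varepsilon_2}\) (geometric mean), the same combination LJ_many uses in exercise 9. The cutoff uses jnp.where rather than a Python if, which keeps energy usable inside jax.jit and jax.vmap, where branching on a traced value raises an error. The plotting range is chosen arbitrarily.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ(r, sigma, epsilon):
    sr6 = (sigma / r) ** 6
    return 4 * epsilon * (sr6**2 - sr6)


def energy(element1: str, element2: str, distance: float) -> float:
    table = {"Ar": (1.2, 3.0), "Ne": (3.0, 2.7)}  # element: (sigma, epsilon)
    sigma1, epsilon1 = table[element1]
    sigma2, epsilon2 = table[element2]
    sigma = (sigma1 + sigma2) / 2  # Lorentz rule: arithmetic mean
    epsilon = jnp.sqrt(epsilon1 * epsilon2)  # Berthelot rule: geometric mean
    # Reuse LJ and return exactly 0 beyond the cutoff distance of 10.
    return jnp.where(distance > 10, 0.0, LJ(distance, sigma, epsilon))


distances = jnp.linspace(1.8, 12.0, 500)
derivs = {}
for pair in [("Ne", "Ne"), ("Ar", "Ne"), ("Ar", "Ar")]:
    # Close over the element names so that only the distance is traced.
    d_energy = jax.grad(lambda d, pair=pair: energy(*pair, d))
    derivs[pair] = jax.vmap(d_energy)(distances)
    # import matplotlib.pyplot as plt; plt.plot(distances, derivs[pair], label="-".join(pair))
```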
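For exercise 9, one way to use the derivatives is to check a candidate configuration: at a minimum, the forces \(-\partial E/\partial x_i\) vanish. This sketch repeats LJ_many (with np replaced by jnp so that it runs with only jax imported) and evaluates the forces at the starting positions and at the quoted solution. Because the quoted positions are rounded to three decimals, their forces are small but not exactly zero. The forces also sum to zero, because the energy depends only on the distances.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp


def LJ_many(elements: list[str], positions1d: jnp.ndarray) -> float:
    table = {"Ar": (1.2, 3.0), "Ne": (3.0, 2.7)}  # element: (sigma, epsilon)
    params = jnp.array([table[e] for e in elements])
    epsilons = jnp.sqrt(jnp.outer(params[:, 1], params[:, 1]))
    sigmas = (params[:, 0, jnp.newaxis] + params[:, 0]) / 2
    rs = jnp.abs(positions1d[:, jnp.newaxis] - positions1d)
    rs = jnp.fill_diagonal(rs, 1e100, inplace=False)
    rmod = (sigmas / rs) ** 6
    return jnp.triu(4 * epsilons * (rmod**2 - rmod), 1).sum()


elements = ["Ne", "Ne", "Ar"]
# Forces are the negative gradient of the energy w.r.t. the positions.
forces = jax.jit(lambda xs: -jax.grad(lambda p: LJ_many(elements, p))(xs))

f_start = forces(jnp.array([0.0, 2.0, 4.0]))       # far from equilibrium: large forces
f_quoted = forces(jnp.array([0.0, 3.365, 5.722]))  # near the minimum: small residual forces
print(f_start, f_quoted)
```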
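As a variant of the advanced solution, scipy.optimize.minimize also accepts jac=True, in which case the objective returns the value and the gradient together. Combined with jax.value_and_grad and jax.jit, each evaluation then needs a single compiled call instead of separate energy and gradient evaluations. The names energy_and_grad and objective are illustrative, and converting to a Python float and a NumPy array is a defensive choice so that scipy receives plain types. Since a local optimizer can end in different minima depending on the starting point, it is safer to check the final gradient than to assume a specific configuration.

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import numpy as np
import scipy.optimize as sco


def LJ_many(elements, positions1d):
    table = {"Ar": (1.2, 3.0), "Ne": (3.0, 2.7)}  # element: (sigma, epsilon)
    params = jnp.array([table[e] for e in elements])
    epsilons = jnp.sqrt(jnp.outer(params[:, 1], params[:, 1]))
    sigmas = (params[:, 0, jnp.newaxis] + params[:, 0]) / 2
    rs = jnp.abs(positions1d[:, jnp.newaxis] - positions1d)
    rs = jnp.fill_diagonal(rs, 1e100, inplace=False)
    rmod = (sigmas / rs) ** 6
    return jnp.triu(4 * epsilons * (rmod**2 - rmod), 1).sum()


elements = ["Ne", "Ne", "Ar"]
# Illustrative: one compiled function returning (energy, gradient w.r.t. positions).
energy_and_grad = jax.jit(jax.value_and_grad(lambda xs: LJ_many(elements, xs)))


def objective(xs):
    """Illustrative wrapper returning (value, gradient) as scipy expects for jac=True."""
    value, grad = energy_and_grad(jnp.asarray(xs))
    return float(value), np.array(grad, dtype=float)


start = np.array([0.0, 2.0, 4.0])
result = sco.minimize(objective, start, jac=True, method="BFGS")
print(result.x, result.fun)
```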