Skip to content

Automatische Differenzierung

Normalerweise muss man für die Implementierung der Ableitung einer Funktion die Ableitung erst selbst herleiten und dann erneut implementieren. Das ist fehleranfällig, weil entweder die Herleitung oder die Implementierung fehlerhaft sein können. Zudem könnten Diskrepanzen zwischen der originalen Funktion und deren Ableitung entstehen, wenn die beiden nicht gleichzeitig aktualisiert werden. Darüberhinaus ist das auch zeitaufwändig.

Automatische Differenzierung (AD) erlaubt es, eine Folge von Code-Funktionen als mathematische Funktionen aufzufassen. Wird ein Ergebnis der einen Python-Funktion in eine andere gegeben, ist das formal eine Verkettung der beiden Funktionen, sodass die Kettenregel angewendet werden kann. Ist nun für jeden einzelnen Schritt die Ableitung bekannt, so kann die Ableitung der gesamten Funktion berechnet werden. JAX hat eine solche Bibliothek, die für (nahezu) alle Funktionen in numpy eine Ableitung bereit hält.

Zunächst ein einfaches Beispiel:

def f(x): 
    return x**2

df = jax.grad(f)
float(df(2.)) # 4.0

Die Funktion jax.grad nimmt eine Funktion als Argument und rechnet deren Ableitung nach dem ersten Argument aus. Das lässt sich auch mehrfach anwenden:

def f(x): 
    return x**2

df = jax.grad(f)
ddf = jax.grad(df)
float(ddf(2.)) # 2.0

Die Rückgabe ist ein Array, nicht etwa eine Zahl und muss deshalb mit float in eine Zahl umgewandelt werden.

Die Funktion jax.grad ist in der Lage, auch Funktionen zu differenzieren, die mehrere Argumente haben:

def f(x, y): 
    return x**2 + 2*y**2

df = jax.grad(f, argnums=0)
float(df(2., 3.)) # 4.0

Das Argument argnums gibt an, nach welchem Argument differenziert werden soll. Soll nach beiden Argumenten differenziert werden, müssen wir ein Tupel übergeben:

def f(x, y): 
    return x**2 + 2*y**2

df = jax.grad(f, argnums=(0, 1))
df(2., 3.) # (4.0, 12.0)

Ist das Argument ein Vektor, so muss jnp.array statt np.array verwendet werden:

def norm(x):
    return jnp.sqrt(jnp.dot(x, x))

vector = jnp.array((1.,2.))
df = jax.grad(norm)
norm(vector) # 2.236068
df(vector)   # [0.4472136, 0.8944272]

Bei Funktionen, die einen Vektor zurückgeben, ist statt jax.grad die Funktion jax.jacobian zu verwenden:

def test(x):
    return x

vector = jnp.array((1.,2.))
jax.jacobian(test)(vector) #  [[1., 0.],
                           #   [0., 1.]]

Als Beispiel hier ein Weg, ein harmonisches Potential in JAX zu implementieren und automatisch die Ableitung zu erhalten:

def potential(x):
    distances = jnp.abs(x[jnp.newaxis, :] - x[:, jnp.newaxis])
    return (distances**2).sum() / 2

jax.grad(potential)(vector)

Will man gleichzeit Wert und Gradient berechnen, geht das mit jax.value_and_grad:

jax.value_and_grad(potential)(vector) # 1., [-2.,  2.]