Automatische Differenzierung
Normalerweise muss man für die Implementierung der Ableitung einer Funktion die Ableitung erst selbst herleiten und dann erneut implementieren. Das ist fehleranfällig, weil entweder die Herleitung oder die Implementierung fehlerhaft sein können. Zudem könnten Diskrepanzen zwischen der originalen Funktion und deren Ableitung entstehen, wenn die beiden nicht gleichzeitig aktualisiert werden. Darüberhinaus ist das auch zeitaufwändig.
Automatische Differenzierung (AD) erlaubt es, eine Folge von Code-Funktionen als mathematische Funktionen aufzufassen. Wird ein Ergebnis der einen Python-Funktion in eine andere gegeben, ist das formal eine Verkettung der beiden Funktionen, sodass die Kettenregel angewendet werden kann. Ist nun für jeden einzelnen Schritt die Ableitung bekannt, so kann die Ableitung der gesamten Funktion berechnet werden. JAX hat eine solche Bibliothek, die für (nahezu) alle Funktionen in numpy
eine Ableitung bereit hält.
Zunächst ein einfaches Beispiel:
def f(x):
return x**2
df = jax.grad(f)
float(df(2.)) # 4.0
Die Funktion jax.grad
nimmt eine Funktion als Argument und rechnet deren Ableitung nach dem ersten Argument aus. Das lässt sich auch mehrfach anwenden:
def f(x):
return x**2
df = jax.grad(f)
ddf = jax.grad(df)
float(ddf(2.)) # 2.0
Die Rückgabe ist ein Array
, nicht etwa eine Zahl und muss deshalb mit float
in eine Zahl umgewandelt werden.
Die Funktion jax.grad
ist in der Lage, auch Funktionen zu differenzieren, die mehrere Argumente haben:
def f(x, y):
return x**2 + 2*y**2
df = jax.grad(f, argnums=0)
float(df(2., 3.)) # 4.0
Das Argument argnums
gibt an, nach welchem Argument differenziert werden soll. Soll nach beiden Argumenten differenziert werden, müssen wir ein Tupel übergeben:
def f(x, y):
return x**2 + 2*y**2
df = jax.grad(f, argnums=(0, 1))
df(2., 3.) # (4.0, 12.0)
Ist das Argument ein Vektor, so muss jnp.array
statt np.array
verwendet werden:
def norm(x):
return jnp.sqrt(jnp.dot(x, x))
vector = jnp.array((1.,2.))
df = jax.grad(norm)
norm(vector) # 2.236068
df(vector) # [0.4472136, 0.8944272]
Bei Funktionen, die einen Vektor zurückgeben, ist statt jax.grad
die Funktion jax.jacobian
zu verwenden:
def test(x):
return x
vector = jnp.array((1.,2.))
jax.jacobian(test)(vector) # [[1., 0.],
# [0., 1.]]
Als Beispiel hier ein Weg, ein harmonisches Potential in JAX zu implementieren und automatisch die Ableitung zu erhalten:
def potential(x):
distances = jnp.abs(x[jnp.newaxis, :] - x[:, jnp.newaxis])
return (distances**2).sum() / 2
jax.grad(potential)(vector)
Will man gleichzeit Wert und Gradient berechnen, geht das mit jax.value_and_grad
:
jax.value_and_grad(potential)(vector) # 1., [-2., 2.]