Skip to content

JAX

JAX erlaubt es, numerischen Code so zu schreiben, dass er nahezu unverändert auf CPU (normale Rechner), GPU (Grafikkarten) oder TPU (Cloud) ausgeführt werden kann. Die Syntax ist dabei ähnlich zu numpy. Parallelisierung und Kompilierung erfolgt dabei im Hintergrund. JAX bietet auch eine automatische Differenzierung von Funktionen.

Der Programmierstil, der erst bei größeren Projekten eine Rolle spielt, ist dabei leicht anders: JAX ist funktional statt imperativ. Das bedeutet, dass Funktionen keine Nebeneffekte haben, sodern nur strikt im mathematischen Sinne Eingabe und Ausgabe verknüpfen. Darüberhinaus müssen die Aufrufe reproduzierbar sein. Pseudo-Zufallszahlen sind in JAX allerdings möglich, sodass statistischen Methoden nichts im Weg steht.

Technisch bietet JAX weniger Kontrolle über die tatsächlichen Operationen, die auf der Hardware ausgeführt werden. Das ist nötig, damit der Code auf CPU und Grafikkarten laufen kann, bedeutet aber auch, dass u.U. Code auf CPU langsamer ausgeführt wird, als das mit numpy allein der Fall wäre.

Nicht aller Programmcode kann ohne Änderungen auf JAX portiert werden. Das betrifft v.a. Code mit expliziten Schleifen oder wenn Datenstrukturen in-place (also ohne Kopie) verändert werden.

Die üblichen Importe für JAX sind:

import jax.numpy as jnp
import jax

Sie werden in den Beispielen vorausgesetzt.