Skip to content

Just-In-Time-Kompilierung

Haben wir bei der Einführung in numpy noch gesehen, dass es vorteilhaft ist, Schleifen o.ä. in numpy zu halten statt in Python-Logik, so ist es manchmal unvermeidlich, wenn die Funktionalität nicht in numpy vorhanden ist. In solchen Fällen kann die Kompilierung von Code helfen: hier wird ineffizienter Python-Code einmal analysiert und in äquivalenten aber effizienteren Code umgewandelt. Damit ist der erste Aufruf einer Funktion deutlich langsamer, da die Analyse durchgeführt werden muss, aber die folgenden Aufrufe sind dann substantiell schneller.

Angenommen, wir hätten keine Funktion um die Norm eines Vektors zu berechnen. Dann könnte eine reine Python-Implementierung dessen wie folgt aussehen:

def norm_simple(x):
    total = 0
    for val in x:
        total += val * val
    return np.sqrt(total)

Wenn wir den numpy-Aufruf durch jax.numpy ersetzen, erhalten wir eine neue Funktion, die wir mit jit kompilieren können:

def norm_simple(x):
    total = 0
    for val in x:
        total += val * val
    return jnp.sqrt(total)

norm_simple_jit = jax.jit(norm_simple)

Beim ersten Aufruf von norm_simple_jit wird der Code analysiert und in effizienten Code umgewandelt. Die folgenden Aufrufe sind dann deutlich schneller.

Das kann sogar bei simplem Code zu deutlichen Verbesserungen führen:

def third_power(mat):
    return mat*mat*mat

mat = np.random.uniform(size=(5000, 5000))
third_power_jit = jax.jit(third_power)

Hier benötigt third_power_jit nur noch ca. 1/4 der Zeit im Vergleich zu third_power.