Just-In-Time-Kompilierung
Haben wir bei der Einführung in numpy
noch gesehen, dass es vorteilhaft ist, Schleifen o.ä. in numpy
zu halten statt in Python-Logik, so ist es manchmal unvermeidlich, wenn die Funktionalität nicht in numpy
vorhanden ist. In solchen Fällen kann die Kompilierung von Code helfen: hier wird ineffizienter Python-Code einmal analysiert und in äquivalenten aber effizienteren Code umgewandelt. Damit ist der erste Aufruf einer Funktion deutlich langsamer, da die Analyse durchgeführt werden muss, aber die folgenden Aufrufe sind dann substantiell schneller.
Angenommen, wir hätten keine Funktion um die Norm eines Vektors zu berechnen. Dann könnte eine reine Python-Implementierung dessen wie folgt aussehen:
def norm_simple(x):
total = 0
for val in x:
total += val * val
return np.sqrt(total)
Wenn wir den numpy
-Aufruf durch jax.numpy
ersetzen, erhalten wir eine neue Funktion, die wir mit jit
kompilieren können:
def norm_simple(x):
total = 0
for val in x:
total += val * val
return jnp.sqrt(total)
norm_simple_jit = jax.jit(norm_simple)
Beim ersten Aufruf von norm_simple_jit
wird der Code analysiert und in effizienten Code umgewandelt. Die folgenden Aufrufe sind dann deutlich schneller.
Das kann sogar bei simplem Code zu deutlichen Verbesserungen führen:
def third_power(mat):
return mat*mat*mat
mat = np.random.uniform(size=(5000, 5000))
third_power_jit = jax.jit(third_power)
Hier benötigt third_power_jit
nur noch ca. 1/4 der Zeit im Vergleich zu third_power
.