Differenz der variablen Werte in JAX Nicht-Jit-Laufzeit und JIT-transformierte LaufzeitPython

Python-Programme
Anonymous
 Differenz der variablen Werte in JAX Nicht-Jit-Laufzeit und JIT-transformierte Laufzeit

Post by Anonymous »

Ich habe einen tiefen Lernmodus, den ich in der JIT -transformierten Weise durch: < /p>
verwandeltemy_function_checked = checkify.checkify(model.apply)
model_jitted = jax.jit(my_function_checked)
err, pred = model_jitted({"params": params}, batch, training=training, rng=rng)
err.throw()
< /code>
Der Code kompiliert gut, aber jetzt möchte ich die Zwischenwerte nach allen paar Schritten debuggen, die Arrays speichern und sie dann mit Pytorch -Tensoren vergleichen. Dafür muss ich die Arrays wiederholt speichern. Der einfachste Weg, dies zu tun, besteht darin, den eingebauten Debugger einer IDE zu verwenden und den Save -Ausdruck nach allen paar Schritten zu bewerten. Aber Jax.jit Transformed Code erlaubt externe Debugger nicht. Aber ich kann das tun, nachdem ich die JIT deaktiviert habe. Sollte ich zwischen den beiden Läufen Abweichungen erwarten? Kann ich davon ausgehen, dass die Werte in JIT- und Nichtjit-Läufen gleich bleiben?

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post