Ich habe einen JAX -Code, in dem ich ein Array scannen möchte. In der Körperfunktion des Scans habe ich einen Pytree , um einige Parameter und Funktionen zu speichern, die ich während des Scans anwenden möchte. Für den Scan habe ich Lambda im Objekt/Pytree -Namen "Params " verwendet. Wenn ja, wie kann ich die Neukompilation vermeiden?
Ich habe einen JAX -Code, in dem ich ein Array scannen möchte. In der Körperfunktion des Scans habe ich einen Pytree , um einige Parameter und Funktionen zu speichern, die ich während des Scans anwenden möchte. Für den Scan habe ich Lambda im Objekt/Pytree -Namen "Params " verwendet. Wenn ja, wie kann ich die Neukompilation vermeiden?[code]import jax import jax.numpy as jnp from jax import tree_util
class Params: def __init__(self, x_array, a): self.x_array = x_array self.a = a return
def one_step(self,state, input): x = state y = input next_state = (self.x_array + x + jnp.ones(self.a))*y return next_state
Ich habe einen JAX -Code, in dem ich ein Array scannen möchte. In der Körperfunktion des Scans habe ich einen Pytree , um einige Parameter und Funktionen zu speichern, die ich während des Scans...
Ich bin ein Jax -Anfänger und jemand mit Jax hat mir gesagt, dass wenn wir wiederholt Anrufe zu einem Scan / for Loop (z. B. wenn diese selbst für Loop selbst eingewickelt werden), könnte es besser...
Ich versuche herauszufinden, wie man nnx.split_rngs verwendet. Kann jemand eine Version des folgenden Codes geben, der nnx.split_rngs mit jax.tree.map verwendetimport jax
from flax import nnx
from...