by Anonymous » 24 Aug 2025, 01:52
Ich habe einen JAX-Code, in dem ich über ein Array scannen möchte. In der Body-Funktion des Scans verwende ich einen Pytree, um einige Parameter und Funktionen zu speichern, die ich während des Scans anwenden möchte. Für den Scan habe ich ein Lambda verwendet, um das Objekt/den Pytree namens „params“ einzubinden. Führt das zu einer Neukompilation, wenn example mit einem neuen Params-Objekt aufgerufen wird? Wenn ja, wie kann ich die Neukompilation vermeiden?
Code: Select all
import jax
import jax.numpy as jnp
from jax import tree_util
class Params:
    """Container for the per-step parameters used by the scan.

    ``x_array`` is exposed to JAX as the single pytree child, while ``a`` is
    carried as auxiliary data (see ``_tree_flatten``).
    """

    def __init__(self, x_array, a):
        """Store the array ``x_array`` and the value ``a``.

        ``a`` is later handed to ``jnp.ones`` as its shape argument.
        """
        self.x_array = x_array
        self.a = a

    # NOTE: ``input`` shadows the builtin; the name is kept because it is part
    # of the method's caller-visible signature.
    def one_step(self, state, input):
        """Return ``(self.x_array + state + ones(self.a)) * input``."""
        ones = jnp.ones(self.a)
        shifted = self.x_array + state + ones
        return shifted * input

    def _tree_flatten(self):
        """Split the instance into ``(children, aux_data)`` for JAX."""
        return (self.x_array,), {'a': self.a}

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the pieces produced by ``_tree_flatten``."""
        return cls(*children, **aux_data)


# Make Params usable as a pytree by JAX transformations.
tree_util.register_pytree_node(
    Params, Params._tree_flatten, Params._tree_unflatten
)
def scan_body(params, state, input):
    """Advance ``state`` by one step via ``params.one_step``.

    Returns a pair ``(new_state, [new_state])``: the new state is returned
    both as the first element and wrapped in a one-element list.
    """
    advanced = params.one_step(state, input)
    return advanced, [advanced]
@jax.jit
def example(params):
    """Run ``scan_body`` with ``params`` over a fixed input sequence.

    The scan starts from the state ``[0., 1.]`` and consumes the inputs
    ``[1., 2., 3.]``. Returns the pair produced by ``jax.lax.scan``: the
    final state and the collected per-step outputs.

    NOTE(review): ``params`` is an argument of the jitted function, so it is
    presumably traced as a pytree rather than baked in as a constant; a second
    call with same-shaped leaves and equal aux data is expected to hit the jit
    cache -- confirm, e.g. with the ``jax_log_compiles`` config flag.
    """
    def body_fun(state, input):
        # Bind ``params`` so the step has the two-argument form passed to scan.
        return scan_body(params, state, input)

    start = jnp.array([0., 1.])
    sequence = jnp.array([1., 2., 3.])
    final_state, outputs = jax.lax.scan(body_fun, start, sequence)
    return final_state, outputs
if __name__ == "__main__":
    # Call the jitted function twice with different array values but the same
    # array shape and the same ``a``, printing the final state each time.
    for values in ([1., 2.], [3., 4.]):
        final_state, _ = example(Params(jnp.array(values), 2))
        print(final_state)
Ich habe einen JAX-Code, in dem ich über ein Array scannen möchte. In der Body-Funktion des Scans verwende ich einen Pytree, um einige Parameter und Funktionen zu speichern, die ich während des Scans anwenden möchte. Für den Scan habe ich ein Lambda verwendet, um das Objekt/den Pytree namens „params“ einzubinden. Führt das zu einer Neukompilation, wenn example mit einem neuen Params-Objekt aufgerufen wird? Wenn ja, wie kann ich die Neukompilation vermeiden?[code]import jax
import jax.numpy as jnp
from jax import tree_util
class Params:
    """Container for the per-step parameters used by the scan.

    ``x_array`` is exposed to JAX as the single pytree child, while ``a`` is
    carried as auxiliary data (see ``_tree_flatten``).
    """

    def __init__(self, x_array, a):
        """Store the array ``x_array`` and the value ``a``.

        ``a`` is later handed to ``jnp.ones`` as its shape argument.
        """
        self.x_array = x_array
        self.a = a

    # NOTE: ``input`` shadows the builtin; the name is kept because it is part
    # of the method's caller-visible signature.
    def one_step(self, state, input):
        """Return ``(self.x_array + state + ones(self.a)) * input``."""
        ones = jnp.ones(self.a)
        shifted = self.x_array + state + ones
        return shifted * input

    def _tree_flatten(self):
        """Split the instance into ``(children, aux_data)`` for JAX."""
        return (self.x_array,), {'a': self.a}

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the pieces produced by ``_tree_flatten``."""
        return cls(*children, **aux_data)


# Make Params usable as a pytree by JAX transformations.
tree_util.register_pytree_node(
    Params, Params._tree_flatten, Params._tree_unflatten
)
def scan_body(params, state, input):
    """Advance ``state`` by one step via ``params.one_step``.

    Returns a pair ``(new_state, [new_state])``: the new state is returned
    both as the first element and wrapped in a one-element list.
    """
    advanced = params.one_step(state, input)
    return advanced, [advanced]
@jax.jit
def example(params):
    """Run ``scan_body`` with ``params`` over a fixed input sequence.

    The scan starts from the state ``[0., 1.]`` and consumes the inputs
    ``[1., 2., 3.]``. Returns the pair produced by ``jax.lax.scan``: the
    final state and the collected per-step outputs.

    NOTE(review): ``params`` is an argument of the jitted function, so it is
    presumably traced as a pytree rather than baked in as a constant; a second
    call with same-shaped leaves and equal aux data is expected to hit the jit
    cache -- confirm, e.g. with the ``jax_log_compiles`` config flag.
    """
    def body_fun(state, input):
        # Bind ``params`` so the step has the two-argument form passed to scan.
        return scan_body(params, state, input)

    start = jnp.array([0., 1.])
    sequence = jnp.array([1., 2., 3.])
    final_state, outputs = jax.lax.scan(body_fun, start, sequence)
    return final_state, outputs
if __name__ == "__main__":
    # Call the jitted function twice with different array values but the same
    # array shape and the same ``a``, printing the final state each time.
    for values in ([1., 2.], [3., 4.]):
        final_state, _ = example(Params(jnp.array(values), 2))
        print(final_state)
[/code]