JAX, Neukompilation bei der Verwendung von Verschluss für eine Funktion

Post a reply

Smilies
:) :( :oops: :chelo: :roll: :wink: :muza: :sorry: :angel: :read: *x) :clever:
View more smilies

BBCode is ON
[img] is ON
[flash] is OFF
[url] is ON
Smilies are ON

Topic review
   

Expand view Topic review: JAX, Neukompilation bei der Verwendung von Verschluss für eine Funktion

by Anonymous » 24 Aug 2025, 01:52

Ich habe einen JAX -Code, in dem ich ein Array scannen möchte. In der Körperfunktion des Scans habe ich einen Pytree , um einige Parameter und Funktionen zu speichern, die ich während des Scans anwenden möchte. Für den Scan habe ich Lambda im Objekt/Pytree -Namen "Params " verwendet. Wenn ja, wie kann ich die Neukompilation vermeiden?

Code: Select all

import jax
import jax.numpy as jnp
from jax import tree_util

class Params:
def __init__(self, x_array, a):
self.x_array = x_array
self.a = a
return

def one_step(self,state, input):
x = state
y = input
next_state = (self.x_array + x + jnp.ones(self.a))*y
return next_state

def _tree_flatten(self):
children = (self.x_array,)
aux_data = {'a':self.a}
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children, **aux_data)

tree_util.register_pytree_node(Params,
Params._tree_flatten,
Params._tree_unflatten)

def scan_body(params, state, input):
x = state
y = input
x_new = params.one_step(x, y)
return x_new, [x_new]

@jax.jit
def example(params):
body_fun = lambda state, input: scan_body(params, state, input)
init_state = jnp.array([0.,1.])
input_array = jnp.array([1.,2.,3.])
last_state, result_list = jax.lax.scan(body_fun, init_state, input_array)
return last_state, result_list

if __name__ == "__main__":

params1 = Params(jnp.array([1.,2.]), 2)
last_state, result_list = example(params1)
print(last_state)

params2 = Params(jnp.array([3.,4.]), 2)
last_state, result_list = example(params2)
print(last_state)

Top