JAX, Neukompilation bei der Verwendung von Verschluss für eine FunktionPython

Python-Programme
Anonymous
 JAX, Neukompilation bei der Verwendung von Verschluss für eine Funktion

Post by Anonymous »

Ich habe einen JAX -Code, in dem ich ein Array scannen möchte. In der Körperfunktion des Scans habe ich einen Pytree , um einige Parameter und Funktionen zu speichern, die ich während des Scans anwenden möchte. Für den Scan habe ich Lambda im Objekt/Pytree -Namen "Params " verwendet. Wenn ja, wie kann ich die Neukompilation vermeiden?

Code: Select all

import jax
import jax.numpy as jnp
from jax import tree_util

class Params:
def __init__(self, x_array, a):
self.x_array = x_array
self.a = a
return

def one_step(self,state, input):
x = state
y = input
next_state = (self.x_array + x + jnp.ones(self.a))*y
return next_state

def _tree_flatten(self):
children = (self.x_array,)
aux_data = {'a':self.a}
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children, **aux_data)

tree_util.register_pytree_node(Params,
Params._tree_flatten,
Params._tree_unflatten)

def scan_body(params, state, input):
x = state
y = input
x_new = params.one_step(x, y)
return x_new, [x_new]

@jax.jit
def example(params):
body_fun = lambda state, input: scan_body(params, state, input)
init_state = jnp.array([0.,1.])
input_array = jnp.array([1.,2.,3.])
last_state, result_list = jax.lax.scan(body_fun, init_state, input_array)
return last_state, result_list

if __name__ == "__main__":

params1 = Params(jnp.array([1.,2.]), 2)
last_state, result_list = example(params1)
print(last_state)

params2 = Params(jnp.array([3.,4.]), 2)
last_state, result_list = example(params2)
print(last_state)

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post