Flachs NNX / JAX: Baum.Map für Schichten von inkongruenter GrößePython

Python-Programme
Anonymous
 Flachs NNX / JAX: Baum.Map für Schichten von inkongruenter Größe

Post by Anonymous »

Ich versuche herauszufinden, wie man nnx.split_rngs verwendet. Kann jemand eine Version des folgenden Codes geben, der nnx.split_rngs mit jax.tree.map verwendetimport jax
from flax import nnx
from functools import partial

if __name__ == '__main__':

session_sizes = {
'a':2,
'b':3,
'c':4,
'd':5,
'e':6,
}
dz = 2

rngs = nnx.Rngs(0)

my_linear = partial(
nnx.Linear,
use_bias = False,
in_features = dz,
rngs=rngs )

def my_linear_wrapper(a):
return my_linear( out_features=a )

q_s = jax.tree.map(my_linear_wrapper, session_sizes)

for k in session_sizes.keys():
print(q_s[k].kernel)
< /code>
In diesem Fall würden wir einen Baum mit Schichten benötigen, der unsere 2 IN_Features in Räume von 2, ..., 6 Out_Features bringt. @nnx.split_rngs Funktion Dekorateur.>

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post