Ich versuche herauszufinden, wie man nnx.split_rngs verwendet. Kann jemand eine Version des folgenden Codes geben, der nnx.split_rngs mit jax.tree.map verwendetimport jax
from flax import nnx
from functools import partial
if __name__ == '__main__':
session_sizes = {
'a':2,
'b':3,
'c':4,
'd':5,
'e':6,
}
dz = 2
rngs = nnx.Rngs(0)
my_linear = partial(
nnx.Linear,
use_bias = False,
in_features = dz,
rngs=rngs )
def my_linear_wrapper(a):
return my_linear( out_features=a )
q_s = jax.tree.map(my_linear_wrapper, session_sizes)
for k in session_sizes.keys():
print(q_s[k].kernel)
< /code>
In diesem Fall würden wir einen Baum mit Schichten benötigen, der unsere 2 IN_Features in Räume von 2, ..., 6 Out_Features bringt. @nnx.split_rngs Funktion Dekorateur.>
Flachs NNX / JAX: Baum.Map für Schichten von inkongruenter Größe ⇐ Python
-
- Similar Topics
- Replies
- Views
- Last post
-
-
Konvertieren Sie „Map with Set of Strings“ als Schlüssel in „Map with Strings“.
by Anonymous » » in Java - 0 Replies
- 26 Views
-
Last post by Anonymous
-