JAX LAX.SCAN: Wie kann man über Ebenen und Speicherscheiben gleichzeitig ohne dynamische Indexierung in einer mehrschich
Posted: 17 Aug 2025, 14:36
Ich versuche, ein Framework zu implementieren, in dem RNN -Netzwerke mit einer willkürlichen Anzahl von Ebenen verwaltet werden (es ist Teil einer Bibliothek, die ich basierend auf JAX/Equinox baue). Das Problem ist, dass ich keine Implementierung finden kann, die es mir ermöglicht, mehrere Ebenen modular zu verwalten:
akzeptiert XS mit nicht-homogenen Formen nicht ( und der Bias haben von Natur aus unterschiedliche Formen), daher kann ich das Layer -Tupel nicht als XS zum Scannen übergeben. />[/list]
I would prefer the solution to use lax.scan as it is very optimized for this type of operation from what I know, even multiple nested scans are fine, the only important thing is that it can be compiled in XLA without any strange fallbacks, and I absolutely don't want an hard-coded solution like this:
...
Weil das offensichtlich sehr unpraktisch wäre. Der Punkt ist, dass es möglich sein sollte, die Funktion zu bejagen, da ich alles zur Kompilierzeit bekannt sein sollte, aber ich kann keine Lösung finden.
wobei i ein Indexarray mit Arange ist. Der Fehler dieses Codes lautet:
Dies ist wahrscheinlich mein Favorit, aber wie Sie sehen können, ist das Problem, dass Carry nicht aktualisiert wird, sodass es keinen Fehler erhebt, aber es macht keinen Sinn.
Code: Select all
def __call__(self, x):
def forward(carry, x_t):
def layer_apply(h_t, layer):
return layer(h_t, x_t)
t_memory = lax.scan(layer_apply, carry, self.layers)
return t_memory[0], t_memory[1]
return lax.scan(forward, self.h, x)
< /code>
Dieser Code ist eine Funktion einer Klasse namens 'recurrent_block' (an eqx.module), die als Parameter ein Tupel von Rnn -Schichten akzeptiert und ein Selbst erstellt. jnp.zeros, one for layer: (h(0), h(1)... h(n))
[b]self.layers = tuple of eqx.Module classes, each with a[/b] [b]call[/b] method that implements the RNN forward pass
[b]x = jnp.array with shape (seq_len, batch_size, Funktionen) [/b]
(Wie Sie sehen können, ist alles zum Kompilierzeit bekannt)
Jetzt gibt es zwei große Probleme im Code:
[list]
lax.scan
Code: Select all
w_h, w_h
I would prefer the solution to use lax.scan as it is very optimized for this type of operation from what I know, even multiple nested scans are fine, the only important thing is that it can be compiled in XLA without any strange fallbacks, and I absolutely don't want an hard-coded solution like this:
Code: Select all
x = layer[0](h[0], x),
Code: Select all
x = layer[1](h[1], x)
...
Weil das offensichtlich sehr unpraktisch wäre. Der Punkt ist, dass es möglich sein sollte, die Funktion zu bejagen, da ich alles zur Kompilierzeit bekannt sein sollte, aber ich kann keine Lösung finden.
Code: Select all
def __call__(self, x, i):
def forward(carry, x_t):
def layer_apply(h_t, i):
return self.layers[i](h_t[i], x_t)
t_memory = lax.scan(layer_apply, carry, i)
return t_memory[0], t_memory[1]
return lax.scan(forward, self.h, x)
Code: Select all
**layer_apply
return self.layers[i](h_t[i], x_t)
~~~~~~^^^
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function layer_apply at test.py:37 for scan. This concrete value was not available in Python because it depends on the value of the argument i.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError**
< /code>
Verwenden einer für Schleife nicht abgerollten < /li>
< /ol>
def __call__(self, x):
def forward(carry, x_t):
for layer, h_t in zip(self.layers, carry):
x_t = layer(h_t, x_t)
return carry, x_t
return lax.scan(forward, self.h, x)[1]
Code: Select all
import jax
import jax.numpy as jnp
import jax.lax as lax
import equinox as eqx
class Rnn(eqx.Module): #layer class
dim : int = eqx.static_field()
w_x : jnp.array
w_h : jnp.array
b : jnp.array
def __init__(self, dim):
self.dim = dim
self.w_x = jnp.ones((dim,dim), dtype=jnp.float32)
self.w_h = jnp.ones((dim,dim), dtype=jnp.float32)
self.b = jnp.zeros((dim,), dtype=jnp.float32)
def __call__(self, h, x): #rnn forward
return [email protected]_x + [email protected]_h + self.b
dim = 2
batch_size = 3
seq_len = 3
x = jnp.ones((seq_len, 3, dim)) #input
layers = (Rnn(dim), Rnn(dim)) #tuple of layers
h = (jnp.zeros((batch_size, dim), dtype=jnp.float32), jnp.zeros((batch_size, dim), dtype=jnp.float32)) #temporal memory
def scan_fn():
return #your solution !
print(scan_fn()) #a general result for the iteration in every rnn layer