```python
def __call__(self, x):
    def forward(carry, x_t):
        def layer_apply(h_t, layer):
            return layer(h_t, x_t)
        t_memory = lax.scan(layer_apply, carry, self.layers)
        return t_memory[0], t_memory[1]
    return lax.scan(forward, self.h, x)
```
This code is the `__call__` method of a class called `recurrent_block` (an `eqx.Module`), which takes a tuple of RNN layers as a parameter and creates `self.h`, a tuple of `jnp.zeros` arrays, one per layer: `(h(0), h(1), ..., h(n))`.

- `self.layers`: a tuple of `eqx.Module` instances, each with a `__call__` method that implements the RNN forward pass
- `x`: a `jnp.array` with shape `(seq_len, batch_size, features)`

(As you can see, everything is known at compile time.)
Now there are two big problems with this code:

1. `lax.scan`
2. `w_h, w_h`
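For context, `lax.scan` treats `xs` as a pytree and slices every leaf along its leading axis; it does not iterate over the elements of a tuple as such. For layers whose weights are `(dim, dim)` arrays, passing `self.layers` as `xs` therefore gives each step one row of every weight matrix rather than one whole layer. A minimal illustration, with two plain arrays standing in for layer weights (the shapes are chosen only for the demo):

```python
import jax.numpy as jnp
from jax import lax

# Two "layers", each represented only by a (3, 2) weight matrix.
w_layer0 = jnp.arange(6.0).reshape(3, 2)
w_layer1 = jnp.ones((3, 2))
xs = (w_layer0, w_layer1)  # a tuple of length 2

def step(count, pair):
    # `pair` is NOT one whole layer: it holds one ROW of each matrix.
    row0, row1 = pair
    return count + 1, row0 + row1

count, ys = lax.scan(step, 0, xs)
print(count)     # 3 iterations (leading axis of the leaves), not 2 (tuple length)
print(ys.shape)  # (3, 2)
```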
I would prefer a solution that uses `lax.scan`, since from what I know it is heavily optimized for this type of operation; even multiple nested scans are fine. The only important thing is that it compiles to XLA without any strange fallbacks, and I absolutely don't want a hard-coded solution like this:
```python
x = layer[0](h[0], x)
x = layer[1](h[1], x)
...
```
That would obviously be very impractical. The point is that it should be possible to JIT-compile the function, since everything is known at compile time, but I can't find a solution.
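One common way to keep everything inside `lax.scan` is to require all layers to have the same parameter shapes and stack them along a new leading "layer" axis: an inner scan then iterates over layers while an outer scan iterates over time, with no Python-level unrolling. A minimal pure-JAX sketch, assuming identical layer shapes and the update `x @ w_x + h @ w_h + b`; the names `rnn_cell` and `stacked_rnn` are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

def rnn_cell(params, h, x):
    # One RNN layer: new hidden state from previous state h and input x.
    w_x, w_h, b = params
    return x @ w_x + h @ w_h + b

@jax.jit
def stacked_rnn(params, h0, xs):
    # params: (w_x, w_h, b) with shapes (L, D, D), (L, D, D), (L, D)
    # h0: (L, B, D), one hidden state per layer; xs: (T, B, D)
    def time_step(h, x_t):
        def layer_step(inp, layer):
            p, h_l = layer  # parameters and previous state of ONE layer
            h_new = rnn_cell(p, h_l, inp)
            return h_new, h_new  # carry: input of next layer; y: new state
        out, h_next = lax.scan(layer_step, x_t, (params, h))
        return h_next, out  # new (L, B, D) states; last layer's output
    return lax.scan(time_step, h0, xs)

L, B, D, T = 2, 3, 2, 3
params = (jnp.ones((L, D, D)), jnp.ones((L, D, D)), jnp.zeros((L, D)))
h_final, ys = stacked_rnn(params, jnp.zeros((L, B, D)), jnp.ones((T, B, D)))
print(ys.shape)  # (3, 3, 2)
```

The trade-off is that heterogeneous layers (different sizes) cannot share one stacked array; in that case they would need padding or a different dispatch mechanism.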
```python
def __call__(self, x, i):
    def forward(carry, x_t):
        def layer_apply(h_t, i):
            return self.layers[i](h_t[i], x_t)
        t_memory = lax.scan(layer_apply, carry, i)
        return t_memory[0], t_memory[1]
    return lax.scan(forward, self.h, x)
```
This fails with the following error:

```
layer_apply
return self.layers[i](h_t[i], x_t)
~~~~~~^^^
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function layer_apply at test.py:37 for scan. This concrete value was not available in Python because it depends on the value of the argument i.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError
```
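The error comes from indexing a Python tuple (`self.layers[i]`) with a traced integer: inside `lax.scan`, `i` is an abstract tracer, so Python cannot turn it into a concrete index. Two facts help here: arrays (unlike tuples) can be indexed by a tracer, and `lax.switch(index, branches, *operands)` selects one of several functions using a traced index. A minimal sketch of that swapped-in technique, assuming the hidden states are stacked into one array; `make_layer` and `apply_layers` are hypothetical names:

```python
import jax
import jax.numpy as jnp
from jax import lax

def make_layer(scale):
    # Hypothetical layer: a plain function closing over its own parameters.
    w = scale * jnp.ones((2, 2))
    def layer(h, x):
        return x @ w + h @ w
    return layer

layers = (make_layer(1.0), make_layer(0.5))  # parameters may differ per layer

@jax.jit
def apply_layers(h_stack, x_t):
    # h_stack: (num_layers, batch, dim); x_t: (batch, dim)
    def layer_step(inp, idx_and_h):
        i, h_i = idx_and_h
        h_new = lax.switch(i, layers, h_i, inp)  # traced index picks the branch
        return h_new, h_new
    idx = jnp.arange(len(layers))  # len() is static, known at trace time
    return lax.scan(layer_step, x_t, (idx, h_stack))

out, h_new = apply_layers(jnp.zeros((2, 3, 2)), jnp.ones((3, 2)))
print(out.shape, h_new.shape)  # (3, 2) (2, 3, 2)
```

Every branch must return the same shape and dtype, and all branches are compiled, but unlike stacking this does not require the layers' parameters to share shapes.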
Using a `for` loop (not unrolled):
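```python
def __call__(self, x):
    def forward(carry, x_t):
        new_carry = []
        for layer, h_t in zip(self.layers, carry):
            x_t = layer(h_t, x_t)  # new hidden state of this layer, fed to the next one
            new_carry.append(x_t)
        return tuple(new_carry), x_t  # updated states; output of the last layer
    return lax.scan(forward, self.h, x)[1]
```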
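For reference, a Python `for` loop over a tuple of layers runs while JAX traces the function, so it is unrolled into one copy of the layer computation per layer, while the loop over time stays a single `lax.scan`. That is usually acceptable when the number of layers is small and fixed. A minimal check with `jax.make_jaxpr`, assuming plain functions and a hypothetical `cell` in place of the Equinox layers:

```python
import jax
import jax.numpy as jnp
from jax import lax

def cell(w, h, x):
    # Hypothetical RNN layer; lax.tanh binds the `tanh` primitive directly.
    return lax.tanh(x @ w + h @ w)

weights = (jnp.eye(2), 0.5 * jnp.eye(2))  # two layers in a static tuple

def run(h0, xs):
    def forward(carry, x_t):
        new_carry = []
        for w, h_t in zip(weights, carry):  # runs at trace time, so it is unrolled
            x_t = cell(w, h_t, x_t)
            new_carry.append(x_t)
        return tuple(new_carry), x_t
    return lax.scan(forward, h0, xs)

h0 = (jnp.zeros((3, 2)), jnp.zeros((3, 2)))
xs = jnp.ones((4, 3, 2))

closed = jax.make_jaxpr(run)(h0, xs)
scan_eqns = [e for e in closed.jaxpr.eqns if e.primitive.name == "scan"]
body = scan_eqns[0].params["jaxpr"].jaxpr
n_tanh = sum(e.primitive.name == "tanh" for e in body.eqns)
print(len(scan_eqns), n_tanh)  # one scan over time; one tanh per layer in its body

h_final, ys = jax.jit(run)(h0, xs)
```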
```python
import jax
import jax.numpy as jnp
import jax.lax as lax
import equinox as eqx

class Rnn(eqx.Module):  # layer class
    dim: int = eqx.field(static=True)
    w_x: jax.Array
    w_h: jax.Array
    b: jax.Array

    def __init__(self, dim):
        self.dim = dim
        self.w_x = jnp.ones((dim, dim), dtype=jnp.float32)
        self.w_h = jnp.ones((dim, dim), dtype=jnp.float32)
        self.b = jnp.zeros((dim,), dtype=jnp.float32)

    def __call__(self, h, x):  # RNN forward pass
        return x @ self.w_x + h @ self.w_h + self.b

dim = 2
batch_size = 3
seq_len = 3
x = jnp.ones((seq_len, batch_size, dim))  # input
layers = (Rnn(dim), Rnn(dim))  # tuple of layers
h = (jnp.zeros((batch_size, dim), dtype=jnp.float32),
     jnp.zeros((batch_size, dim), dtype=jnp.float32))  # temporal memory

def scan_fn():
    return  # your solution!

print(scan_fn())  # a general result for the iteration in every RNN layer
```
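A possible answer to the `scan_fn` placeholder, under the assumption that all layers have identical structure (same field shapes and static fields): stack the tuple of layers into a single pytree whose leaves carry a leading layer axis, using `jax.tree_util.tree_map(lambda *ls: jnp.stack(ls), *layers)`, and let an inner `lax.scan` slice one layer out per step. Because `eqx.Module` instances are pytrees, the same stacking should apply to the `Rnn` class; to keep the sketch dependency-free, a hypothetical `RnnParams` NamedTuple with a `__call__` method stands in for it, and `scan_fn` takes its inputs as arguments instead of reading globals.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax import lax


class RnnParams(NamedTuple):
    # Hypothetical stand-in for the eqx.Module `Rnn`: a pytree of arrays
    # that can also be called like a layer.
    w_x: jax.Array
    w_h: jax.Array
    b: jax.Array

    def __call__(self, h, x):
        return x @ self.w_x + h @ self.w_h + self.b


def make_rnn(dim):
    return RnnParams(jnp.ones((dim, dim)), jnp.ones((dim, dim)), jnp.zeros((dim,)))


@jax.jit
def scan_fn(layers, h, x):
    # Turn a tuple of identically structured layers into ONE layer pytree
    # whose leaves have a leading "layer" axis.
    stacked_layers = jax.tree_util.tree_map(lambda *ls: jnp.stack(ls), *layers)
    stacked_h = jnp.stack(h)  # (num_layers, batch, dim)

    def forward(carry, x_t):
        def layer_apply(inp, layer_and_h):
            layer, h_l = layer_and_h  # one unstacked layer and its state
            h_new = layer(h_l, inp)
            return h_new, h_new  # carry: next layer's input; y: new state
        out, new_carry = lax.scan(layer_apply, x_t, (stacked_layers, carry))
        return new_carry, out  # updated states; last layer's output

    return lax.scan(forward, stacked_h, x)


dim, batch_size, seq_len = 2, 3, 3
x = jnp.ones((seq_len, batch_size, dim))
layers = (make_rnn(dim), make_rnn(dim))
h = (jnp.zeros((batch_size, dim)), jnp.zeros((batch_size, dim)))
h_final, outputs = scan_fn(layers, h, x)
print(outputs.shape)  # (3, 3, 2)
```

Both loops (time and layers) stay `lax.scan`s, so the traced program does not grow with the number of layers.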