JAX LAX.SCAN: Wie kann man über Ebenen und Speicherscheiben gleichzeitig ohne dynamische Indexierung in einer mehrschichPython

Python-Programme
Anonymous
 JAX LAX.SCAN: Wie kann man über Ebenen und Speicherscheiben gleichzeitig ohne dynamische Indexierung in einer mehrschich

Post by Anonymous »

Ich versuche, ein Framework zu implementieren, in dem RNN -Netzwerke mit einer willkürlichen Anzahl von Ebenen verwaltet werden (es ist Teil einer Bibliothek, die ich basierend auf JAX/Equinox baue). Das Problem ist, dass ich keine Implementierung finden kann, die es mir ermöglicht, mehrere Ebenen modular zu verwalten:

Code: Select all

    def __call__(self, x):

def forward(carry, x_t):

def layer_apply(h_t, layer):
return layer(h_t, x_t)

t_memory = lax.scan(layer_apply, carry, self.layers)

return t_memory[0], t_memory[1]

return lax.scan(forward, self.h, x)
< /code>
Dieser Code ist eine Funktion einer Klasse namens 'recurrent_block' (an eqx.module), die als Parameter ein Tupel von Rnn -Schichten akzeptiert und ein Selbst erstellt. jnp.zeros, one for layer: (h(0), h(1)... h(n))
[b]self.layers = tuple of eqx.Module classes, each with a[/b] [b]call[/b] method that implements the RNN forward pass
[b]x = jnp.array with shape (seq_len, batch_size, Funktionen) [/b] 
(Wie Sie sehen können, ist alles zum Kompilierzeit bekannt) 
Jetzt gibt es zwei große Probleme  im Code: 
[list]
lax.scan
akzeptiert XS mit nicht-homogenen Formen nicht (

Code: Select all

w_h, w_h
und der Bias haben von Natur aus unterschiedliche Formen), daher kann ich das Layer -Tupel nicht als XS zum Scannen übergeben. />[/list]
I would prefer the solution to use lax.scan as it is very optimized for this type of operation from what I know, even multiple nested scans are fine, the only important thing is that it can be compiled in XLA without any strange fallbacks, and I absolutely don't want an hard-coded solution like this:

Code: Select all

x = layer[0](h[0], x),

Code: Select all

x = layer[1](h[1], x)

...
Weil das offensichtlich sehr unpraktisch wäre. Der Punkt ist, dass es möglich sein sollte, die Funktion zu bejagen, da ich alles zur Kompilierzeit bekannt sein sollte, aber ich kann keine Lösung finden.

Code: Select all

    def __call__(self, x, i):

def forward(carry, x_t):

def layer_apply(h_t, i):
return self.layers[i](h_t[i], x_t)

t_memory = lax.scan(layer_apply, carry, i)

return t_memory[0], t_memory[1]

return lax.scan(forward, self.h, x)
wobei i ein Indexarray mit Arange ist. Der Fehler dieses Codes lautet:

Code: Select all

**layer_apply
return self.layers[i](h_t[i], x_t)
~~~~~~^^^
jax.errors.TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
The error occurred while tracing the function layer_apply at test.py:37 for scan. This concrete value was not available in Python because it depends on the value of the argument i.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerIntegerConversionError**
< /code>

 Verwenden einer für Schleife nicht abgerollten < /li>
< /ol>
    def __call__(self, x):

def forward(carry, x_t):
for layer, h_t in zip(self.layers, carry):
x_t = layer(h_t, x_t)
return carry, x_t

return lax.scan(forward, self.h, x)[1]
Dies ist wahrscheinlich mein Favorit, aber wie Sie sehen können, ist das Problem, dass Carry nicht aktualisiert wird, sodass es keinen Fehler erhebt, aber es macht keinen Sinn.

Code: Select all

import jax
import jax.numpy as jnp
import jax.lax as lax
import equinox as eqx

class Rnn(eqx.Module): #layer class
dim : int = eqx.static_field()
w_x : jnp.array
w_h : jnp.array
b : jnp.array

def __init__(self, dim):
self.dim = dim
self.w_x = jnp.ones((dim,dim), dtype=jnp.float32)
self.w_h = jnp.ones((dim,dim), dtype=jnp.float32)
self.b = jnp.zeros((dim,), dtype=jnp.float32)

def __call__(self, h, x): #rnn forward
return [email protected]_x + [email protected]_h + self.b

dim = 2
batch_size = 3
seq_len = 3

x = jnp.ones((seq_len, 3, dim)) #input

layers = (Rnn(dim), Rnn(dim)) #tuple of layers

h = (jnp.zeros((batch_size, dim), dtype=jnp.float32), jnp.zeros((batch_size, dim), dtype=jnp.float32)) #temporal memory

def scan_fn():
return #your solution !

print(scan_fn()) #a general result for the iteration in every rnn layer

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post