Ich interessiere mich für die Schulung eines neuronalen Netzwerks mit JAX. Ich habe mir tf.data.dataset gesehen, aber es liefert ausschließlich TF -Tensoren. Ich suchte nach einer Möglichkeit, den Datensatz in Jax Numpy Array zu verwandeln, und fand viele Implementierungen, die Dataset verwenden. Ich frage mich jedoch, ob es eine gute Praxis ist, da Numpy -Arrays im CPU -Speicher gespeichert sind und es nicht das ist, was ich für mein Training möchte (ich benutze die GPU). Die letzte Idee, die ich gefunden habe, besteht darin, die Arrays manuell neu zu neu umzusetzen, indem sie JNP.Array aufrufen, aber es ist nicht wirklich elegant (ich habe Angst vor der Kopie im GPU -Speicher). Hat jemand eine bessere Idee dafür?
Code: Select all
import os
import jax.numpy as jnp
import tensorflow as tf
def generator():
for _ in range(2):
yield tf.random.uniform((1, ))
ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
output_shapes=tf.TensorShape([1]))
ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)
for i, batch in enumerate(ds1):
print(type(batch))
for i, batch in enumerate(ds2):
print(type(jnp.array(batch)))
# returns:
# not good
# good but not elegant