Verwenden Sie einen tf.data.dataset zu einem jax.numpy iteratorPython

Python-Programme
Anonymous
 Verwenden Sie einen tf.data.dataset zu einem jax.numpy iterator

Post by Anonymous »

Ich interessiere mich für die Schulung eines neuronalen Netzwerks mit JAX. Ich habe mir tf.data.dataset gesehen, aber es liefert ausschließlich TF -Tensoren. Ich suchte nach einer Möglichkeit, den Datensatz in Jax Numpy Array zu verwandeln, und fand viele Implementierungen, die Dataset verwenden. Ich frage mich jedoch, ob es eine gute Praxis ist, da Numpy -Arrays im CPU -Speicher gespeichert sind und es nicht das ist, was ich für mein Training möchte (ich benutze die GPU). Die letzte Idee, die ich gefunden habe, besteht darin, die Arrays manuell neu zu neu umzusetzen, indem sie JNP.Array aufrufen, aber es ist nicht wirklich elegant (ich habe Angst vor der Kopie im GPU -Speicher). Hat jemand eine bessere Idee dafür?

Code: Select all

import os
import jax.numpy as jnp
import tensorflow as tf

def generator():
for _ in range(2):
yield tf.random.uniform((1, ))

ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
output_shapes=tf.TensorShape([1]))

ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)

for i, batch in enumerate(ds1):
print(type(batch))

for i, batch in enumerate(ds2):
print(type(jnp.array(batch)))

# returns:

 # not good
 # good but not elegant

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post