IterableDataset not supported by GRPOTrainer


Post by Guest »

The following program crashes when run:


from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]

# Infinite generator that cycles through the prompts
def data_generator():
    while True:
        for s in prompts:
            yield {"prompt": s}

dataset = IterableDataset.from_generator(data_generator)

training_args = GRPOConfig(
    output_dir="tmp",
    max_steps=1000,
)

trainer = GRPOTrainer(
    model="facebook/opt-350m",
    # Dummy reward function: always returns a list of eight 1s
    reward_funcs=lambda prompts, completions, **kwargs: [1] * 8,
    train_dataset=dataset,
    args=training_args,
)

trainer.train()

It produces the following traceback:
Traceback (most recent call last):
  File "/home/pietro/Documents/Code/CS234/starter_code/trl_testing.py", line 24, in <module>
    trainer.train()
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2500, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 5180, in get_batch_samples
    batch_samples += [next(epoch_iterator)]
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 856, in __iter__
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 812, in _fetch_batches
    batch = concatenate(batches, dim=0)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in concatenate
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 81, in honor_type
    return type(obj)(generator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in <genexpr>
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in concatenate
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in <dictcomp>
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 619, in concatenate
    raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
TypeError: Can only concatenate tensors but got
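To make the traceback easier to follow, here is a minimal pure-Python sketch of the recursion it shows (a list branch, then a dict branch, then the `TypeError` at the leaf). It is not accelerate's actual code: `concatenate_sketch` and `FakeTensor` are hypothetical names, `FakeTensor` merely stands in for `torch.Tensor`, and the shape of `batches` (lists of `{"prompt": str}` dicts) is an assumption inferred from the order of the frames.

```python
from collections.abc import Mapping


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor so the sketch needs no torch."""

    def __init__(self, values):
        self.values = list(values)


def concatenate_sketch(data):
    """Simplified imitation of the recursion visible in the traceback."""
    first = data[0]
    if isinstance(first, (list, tuple)):
        # Recurse position by position, as in the `honor_type` frame.
        return type(first)(
            concatenate_sketch([d[i] for d in data]) for i in range(len(first))
        )
    if isinstance(first, Mapping):
        # Recurse key by key, as in the dict-comprehension frame.
        return type(first)(
            {k: concatenate_sketch([d[k] for d in data]) for k in first.keys()}
        )
    if not isinstance(first, FakeTensor):
        # Leaves that are not tensors cannot be concatenated.
        raise TypeError(f"Can only concatenate tensors but got {type(first)}")
    return FakeTensor(v for t in data for v in t.values)


# Assumed batch shape: each batch is a list of raw prompt dicts.
batches = [
    [{"prompt": "Hi"}, {"prompt": "Hello"}],
    [{"prompt": "Hi"}, {"prompt": "Hello"}],
]

try:
    concatenate_sketch(batches)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# The same function succeeds when the leaves are tensor-like.
merged = concatenate_sketch([{"x": FakeTensor([1])}, {"x": FakeTensor([2])}])
```

Under these assumptions the string prompts reach the leaf branch and trigger the error, whereas tensor-like leaves are merged without complaint.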
However, replacing the IterableDataset with an equivalent Dataset, as below, fixes the problem:
from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]
dataset = Dataset.from_dict({"prompt": prompts})

training_args = GRPOConfig(
    output_dir="tmp",
    max_steps=1000,
)

trainer = GRPOTrainer(
    model="facebook/opt-350m",
    reward_funcs=lambda prompts, completions, **kwargs: [1] * 8,
    train_dataset=dataset,
    args=training_args,
)

trainer.train()
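If the data really has to come from a generator, one possible stopgap (a sketch only, and only suitable when a finite prefix of the stream is enough) is to collect the first `n` examples into a dict of columns and hand that to `Dataset.from_dict`, as in the working snippet. The helper name `take_columns` and the value `n = 6` are made up for illustration.

```python
from itertools import islice

prompts = ["Hi", "Hello"]


def data_generator():
    # Same infinite generator as in the failing program.
    while True:
        for s in prompts:
            yield {"prompt": s}


def take_columns(generator_fn, n):
    """Collect the first n examples of a generator into a dict of columns."""
    columns = {}
    for example in islice(generator_fn(), n):
        for key, value in example.items():
            columns.setdefault(key, []).append(value)
    return columns


columns = take_columns(data_generator, 6)
# `columns` has the {"prompt": [...]} layout that Dataset.from_dict expects.
```

This sidesteps the streaming path rather than explaining it, so it does not answer whether GRPOTrainer is meant to accept an IterableDataset.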
The crash was reproduced on two very different systems, so the environment is unlikely to be the cause.
Am I missing something?
