IterableDataset not supported on GRPOTrainer
Posted: 24 Feb 2025, 06:08
The following program crashes on execution:
```python
from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]

def data_generator():
    # Cycle through the prompts forever
    while True:
        for s in prompts:
            yield {"prompt": s}

dataset = IterableDataset.from_generator(data_generator)

training_args = GRPOConfig(
    output_dir="tmp",
    max_steps=1000,
)
trainer = GRPOTrainer(
    model="facebook/opt-350m",
    reward_funcs=lambda prompts, completions, **kwargs: [1] * 8,
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
```

It produces the following traceback:

```
Traceback (most recent call last):
  File "/home/pietro/Documents/Code/CS234/starter_code/trl_testing.py", line 24, in
    trainer.train()
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2241, in train
    return inner_training_loop(
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 2500, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/transformers/trainer.py", line 5180, in get_batch_samples
    batch_samples += [next(epoch_iterator)]
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 856, in __iter__
    next_batch, next_batch_info = self._fetch_batches(main_iterator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/data_loader.py", line 812, in _fetch_batches
    batch = concatenate(batches, dim=0)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in concatenate
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 81, in honor_type
    return type(obj)(generator)
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 615, in
    return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0]))))
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in concatenate
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 617, in
    return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()})
  File "/home/pietro/.conda/envs/cs234_3/lib/python3.9/site-packages/accelerate/utils/operations.py", line 619, in concatenate
    raise TypeError(f"Can only concatenate tensors but got {type(data[0])}")
TypeError: Can only concatenate tensors but got
```

However, replacing the IterableDataset with an analogous Dataset, as below, fixes the problem:
```python
from datasets import IterableDataset, Dataset
from trl import GRPOConfig, GRPOTrainer

prompts = ["Hi", "Hello"]
# Map-style Dataset instead of the IterableDataset above
dataset = Dataset.from_dict({"prompt": prompts})

training_args = GRPOConfig(
    output_dir="tmp",
    max_steps=1000,
)
trainer = GRPOTrainer(
    model="facebook/opt-350m",
    reward_funcs=lambda prompts, completions, **kwargs: [1] * 8,
    train_dataset=dataset,
    args=training_args,
)
trainer.train()
```
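Both snippets use a dummy reward function that always returns exactly eight rewards. As far as I understand TRL's reward-function interface, each call should return one float per completion it receives, and how many completions arrive per call depends on the training configuration. A minimal sketch of a dummy reward that does not hardcode the count (the name `constant_reward` is hypothetical) could look like this:

```python
def constant_reward(prompts, completions, **kwargs):
    # Hypothetical dummy reward: one reward of 1.0 per completion,
    # instead of a hardcoded list of eight values.
    return [1.0] * len(completions)


print(constant_reward(prompts=["Hi", "Hello"], completions=["a", "b"]))  # [1.0, 1.0]
```

It would be passed as `reward_funcs=constant_reward` in place of the lambda; this only removes the hardcoded 8 and is not claimed to fix the crash.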
This has been reproduced on 2 very different systems, so the environment is unlikely to be the cause.
Am I missing something?
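The traceback ends in accelerate's `concatenate`, called from `_fetch_batches` in its data loader. Judging from the frames shown, it recurses into lists/tuples and dicts and only accepts tensors at the leaves. The sketch below is a simplified, dependency-free re-implementation of that recursion, not accelerate's actual code: `FakeTensor` is a hypothetical stand-in for `torch.Tensor`, and the batch layout (lists of `{"prompt": ...}` dicts) is an assumption consistent with the list-then-dict frames in the traceback. It shows how string prompts end up in the `TypeError` branch:

```python
from collections.abc import Mapping


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: a 1-D list of numbers."""

    def __init__(self, values):
        self.values = list(values)

    @staticmethod
    def cat(tensors):
        # Concatenate along dim 0 by joining the underlying lists
        out = []
        for t in tensors:
            out.extend(t.values)
        return FakeTensor(out)


def concatenate(data):
    """Simplified mirror of the recursion visible in the traceback."""
    first = data[0]
    if isinstance(first, (tuple, list)):
        # Concatenate element-wise across batches, keeping the container type
        return type(first)(concatenate([d[i] for d in data]) for i in range(len(first)))
    if isinstance(first, Mapping):
        # Concatenate each key's values across batches
        return type(first)({k: concatenate([d[k] for d in data]) for k in first.keys()})
    if not isinstance(first, FakeTensor):
        # Any non-tensor leaf (e.g. a prompt string) ends up here
        raise TypeError(f"Can only concatenate tensors but got {type(first)}")
    return FakeTensor.cat(data)


# Tensor leaves are merged without problems
merged = concatenate([{"ids": FakeTensor([1, 2])}, {"ids": FakeTensor([3])}])
print(merged["ids"].values)  # [1, 2, 3]

# Assumed batch layout: each batch is a list of {"prompt": str} dicts
batches = [[{"prompt": "Hi"}], [{"prompt": "Hello"}]]
try:
    concatenate(batches)
except TypeError as err:
    print(err)  # Can only concatenate tensors but got <class 'str'>
```

With the map-style Dataset, training runs, so this concatenation path is apparently only taken in the iterable case.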