Ich habe verschiedene Methoden zur Feinabstimmung ausprobiert, aber keine davon hat funktioniert.
Ich habe auch einen Ausschnitt meiner Feinabstimmungsmethode für Llama geteilt und derzeit kommt es während des Trainings immer wieder zu Fehlern.
Ich möchte auch klarstellen, dass ich sie verwende eine RTX 4090 mit 64 GB RAM und einem I9-14900k.
Und unten ist das Format für meinen Roboflow-Datensatz:
Yolov8-Format für Datensatz
Code: Select all
def train_llama(images, descriptions):
print("Initializing model and processor...")
model_id = "meta-llama/Llama-Guard-3-11B-Vision"
# Initialize processor
processor = AutoProcessor.from_pretrained(model_id)
# Initialize model
model = AutoModelForImageTextToText.from_pretrained(
model_id,
torch_dtype=torch.float16,
device_map="cuda"
)
model.config.use_cache = False
print("Creating dataset...")
dataset = CustomImageTextDataset(images, descriptions, processor)
# Split dataset
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
print(f"Train size: {train_size}, Validation size: {val_size}")
# Training arguments
training_args = TrainingArguments(
output_dir="./llama_finetuned",
learning_rate=1e-5,
num_train_epochs=3,
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_ratio=0.05,
logging_steps=10,
eval_strategy="steps",
eval_steps=50,
save_strategy="steps",
save_steps=100,
fp16=True,
gradient_checkpointing=True,
remove_unused_columns=False,
report_to="tensorboard",
load_best_model_at_end=True,
metric_for_best_model="loss",
dataloader_num_workers=0
)
# Initialize trainer
trainer = CustomTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset
)
print("Starting Llama training...")
try:
trainer.train()
print("Training completed successfully!")
print("Saving model...")
trainer.save_model("./final_llama_model")
print("Model saved successfully!")
return True
except Exception as e:
print(f"Error during training: {str(e)}")
print(f"Traceback: {traceback.format_exc()}")
return False