How can I properly fine-tune FLUX.1-dev on my own dataset for custom emoji generation?


Post by Anonymous »

I'm building an AI-based emoji generator (EmoGen-Z) and I want to fine-tune the black-forest-labs/FLUX.1-dev model to generate new emojis from custom text prompts. My dataset is laid out like this:

Code: Select all

emoji_dataset/
├── images/        # 512x512 PNGs
│   ├── image_1.png
│   ├── ...
├── captions/      # Matching text captions
│   ├── caption_1.txt
│   ├── ...
└── train.json     # List of {"image": "images/image_1.png", "caption": "captions/caption_1.txt"}
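
The train.json manifest is generated with a small helper along these lines (this is only an illustrative sketch: the pairing-by-numeric-suffix logic is simplified, and build_manifest is just a name I use here):

Code: Select all

#!/usr/bin/env python3
# Illustrative sketch: pair images/image_N.png with captions/caption_N.txt
# by their numeric suffix and write the result to train.json.
import json
import os

def build_manifest(root="emoji_dataset"):
    entries = []
    for name in sorted(os.listdir(os.path.join(root, "images"))):
        stem, ext = os.path.splitext(name)
        if ext.lower() != ".png":
            continue
        idx = stem.split("_")[-1]  # "image_1.png" -> "1"
        entries.append({
            "image": f"images/{name}",
            "caption": f"captions/caption_{idx}.txt",
        })
    with open(os.path.join(root, "train.json"), "w", encoding="utf-8") as f:
        json.dump(entries, f, indent=2)

if __name__ == "__main__":
    build_manifest()
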
I previously collected ~25K high-quality emoji images from various sources (Apple, Google, Microsoft, etc.) and wrote a matching caption for each, so this is a paired emoji-caption dataset.

🔧 What I've tried:

I used ChatGPT for this code:

Code: Select all
#!/usr/bin/env python3
import os
import json
import torch
import argparse
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from tqdm import tqdm
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
# The CLIP text encoder and tokenizer live in transformers, not diffusers
from transformers import CLIPTextModel, CLIPTokenizer

# -------------------------------
# 1. Define the Emoji Dataset
# -------------------------------
class EmojiDataset(Dataset):
    def __init__(self, dataset_dir, transform=None):
        self.dataset_dir = dataset_dir
        with open(os.path.join(dataset_dir, "train.json"), "r") as f:
            self.metadata = json.load(f)
        self.transform = transform

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        sample = self.metadata[idx]
        image_path = os.path.join(self.dataset_dir, sample["image"])
        caption_path = os.path.join(self.dataset_dir, sample["caption"])
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image, caption

def collate_fn(batch):
    images, captions = zip(*batch)
    images = torch.stack(images)
    return images, list(captions)

# -------------------------------------------
# 2. Utility: Save Sample Generated Images
# -------------------------------------------
def save_sample_images(writer, step, unet, vae, text_encoder, tokenizer, scheduler, device):
    prompts = [
        "emoji of smiley face",
        "emoji of a Horse-Man",
        "emoji of a flying-pig",
        ""  # prompt-less generation
    ]
    num_inference_steps = 50
    sample_images = []
    for prompt in prompts:
        # Tokenize prompt
        inputs = tokenizer(
            prompt,
            padding="max_length",
            truncation=True,
            max_length=tokenizer.model_max_length,
            return_tensors="pt"
        )
        input_ids = inputs.input_ids.to(device)
        with torch.no_grad():
            text_embeddings = text_encoder(input_ids)[0]

        # Start with random noise in the latent space (assumed dimensions)
        latents = torch.randn((1, unet.config.in_channels, 64, 64), device=device)
        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            with torch.no_grad():
                noise_pred = unet(latents, t, encoder_hidden_states=text_embeddings).sample
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the latent to image space using the VAE
        # (undo the 0.18215 scaling that is applied when encoding during training)
        with torch.no_grad():
            image = vae.decode(latents / 0.18215).sample
        # Postprocess: scale to [0,1] and convert to PIL
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = (image * 255).astype("uint8")
        pil_image = Image.fromarray(image)
        sample_images.append(pil_image)

    # Create a grid of images and log to TensorBoard
    grid = utils.make_grid([transforms.ToTensor()(img) for img in sample_images], nrow=2)
    writer.add_image("Samples", grid, global_step=step)

# -------------------------------------------
# 3. Utility: Export Model to ONNX
# -------------------------------------------
def export_to_onnx(unet, export_path="flux_unet.onnx", device="cuda"):
    dummy_latents = torch.randn(1, unet.config.in_channels, 64, 64, device=device)
    dummy_timestep = torch.tensor([1], device=device)
    # Typical dimensions for text encoder output (e.g., CLIP) are [batch, sequence_length, hidden_dim]
    dummy_hidden_states = torch.randn(1, 77, 768, device=device)
    torch.onnx.export(
        unet,
        (dummy_latents, dummy_timestep, dummy_hidden_states),
        export_path,
        input_names=["latents", "timestep", "encoder_hidden_states"],
        output_names=["noise_pred"],
        dynamic_axes={
            "latents": {0: "batch_size"},
            "encoder_hidden_states": {0: "batch_size"}
        },
        opset_version=11
    )
    print(f"Model exported to {export_path}")

# -------------------------------
# 4. Main Training Routine
# -------------------------------
def main():
    parser = argparse.ArgumentParser(description="Fine-tune FLUX.1 for Emoji Generation")
    parser.add_argument("--dataset_dir", type=str, default="emoji_dataset", help="Path to the emoji dataset")
    parser.add_argument("--model_id", type=str, default="black-forest-labs/FLUX.1-dev", help="Pretrained model ID")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--num_epochs", type=int, default=5, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--log_interval", type=int, default=50, help="Steps between TensorBoard logging")
    parser.add_argument("--save_interval", type=int, default=500, help="Steps between saving checkpoints")
    parser.add_argument("--output_dir", type=str, default="output", help="Directory for checkpoints and exports")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(args.output_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=os.path.join(args.output_dir, "logs"))

    # Data transformation: ensure images are 512x512 and normalized to [-1, 1]
    transform = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])
    dataset = EmojiDataset(args.dataset_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                            collate_fn=collate_fn, num_workers=4)

    # -------------------------------
    # 5. Load Pretrained Model Components
    # -------------------------------
    print("Loading model components...")
    unet = UNet2DConditionModel.from_pretrained(args.model_id, subfolder="unet").to(device)
    vae = AutoencoderKL.from_pretrained(args.model_id, subfolder="vae").to(device)
    text_encoder = CLIPTextModel.from_pretrained(args.model_id, subfolder="text_encoder").to(device)
    tokenizer = CLIPTokenizer.from_pretrained(args.model_id, subfolder="tokenizer")
    scheduler = DDPMScheduler.from_pretrained(args.model_id, subfolder="scheduler")

    # Set to training mode
    unet.train()
    text_encoder.train()

    optimizer = optim.AdamW(list(unet.parameters()) + list(text_encoder.parameters()), lr=args.learning_rate)
    scaler = torch.cuda.amp.GradScaler()  # Mixed-precision training for efficiency

    global_step = 0
    for epoch in range(args.num_epochs):
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}")
        for images, captions in progress_bar:
            images = images.to(device)
            # Encode images into latent space using the (frozen) VAE
            with torch.no_grad():
                latents = vae.encode(images).latent_dist.sample() * 0.18215

            # Sample noise and random timesteps
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=device).long()
            noisy_latents = scheduler.add_noise(latents, noise, timesteps)

            # Ensure each caption contains the trigger word "emoji"
            processed_captions = []
            for cap in captions:
                if "emoji" not in cap.lower():
                    cap = "emoji, " + cap
                processed_captions.append(cap)
            inputs = tokenizer(
                processed_captions,
                padding="max_length",
                truncation=True,
                max_length=tokenizer.model_max_length,
                return_tensors="pt"
            )
            input_ids = inputs.input_ids.to(device)
            # The text encoder is part of the optimizer, so keep gradients here
            # (wrapping this in torch.no_grad() would silently freeze it)
            encoder_hidden_states = text_encoder(input_ids)[0]

            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
                loss = nn.MSELoss()(noise_pred, noise)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            global_step += 1
            if global_step % args.log_interval == 0:
                writer.add_scalar("Loss/train", loss.item(), global_step)
                progress_bar.set_postfix(loss=loss.item())
                # Log generated samples to TensorBoard
                unet.eval()
                save_sample_images(writer, global_step, unet, vae, text_encoder, tokenizer, scheduler, device)
                unet.train()

            if global_step % args.save_interval == 0:
                checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}.pt")
                torch.save({
                    "unet": unet.state_dict(),
                    "text_encoder": text_encoder.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "global_step": global_step
                }, checkpoint_path)
                print(f"Checkpoint saved at step {global_step}")

        # Save checkpoint at the end of each epoch
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-epoch-{epoch+1}.pt")
        torch.save({
            "unet": unet.state_dict(),
            "text_encoder": text_encoder.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch+1,
            "global_step": global_step
        }, checkpoint_path)
        print(f"Epoch {epoch+1} checkpoint saved.")

    # -------------------------------
    # 6. Export the Final Model
    # -------------------------------
    unet.eval()
    export_to_onnx(unet, export_path=os.path.join(args.output_dir, "flux_unet.onnx"), device=device)
    writer.close()
    print("Training complete. Model is ready for integration into your Flutter app (via ONNX or subsequent conversion to TFLite).")

if __name__ == "__main__":
    main()
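
Before training I sanity-check the data pipeline with a quick snippet roughly like this (EmojiDataset and collate_fn are the classes from the script above, so this assumes they are importable or pasted into the same session):

Code: Select all

# Quick sanity check of the data pipeline (EmojiDataset / collate_fn from the script above)
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataset = EmojiDataset("emoji_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
images, captions = next(iter(loader))
print(images.shape)                              # expect torch.Size([4, 3, 512, 512])
print(images.min().item(), images.max().item())  # expect roughly -1.0 and 1.0
print(captions[0])
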
The training runs fine, but:

- I can't tell whether it is actually learning anything. The loss goes down a bit, but it looks noisy (a smoothing idea I'm toying with is sketched below).
- I can't really interpret the TensorBoard outputs.
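
One thing I've been considering is logging an EMA-smoothed loss next to the raw value so the curve is easier to read; a minimal sketch of what I mean (log_smoothed_loss is just a name I made up, and the 0.98 decay is a guess):

Code: Select all

from torch.utils.tensorboard import SummaryWriter

def log_smoothed_loss(writer: SummaryWriter, raw_loss: float, step: int, state: dict, beta: float = 0.98):
    """Log the raw loss and an exponential moving average of it."""
    ema = state.get("ema")
    ema = raw_loss if ema is None else beta * ema + (1.0 - beta) * raw_loss
    state["ema"] = ema
    writer.add_scalar("Loss/train_raw", raw_loss, step)
    writer.add_scalar("Loss/train_ema", ema, step)

# usage inside the training loop:
#   ema_state = {}   # created once before the loop
#   log_smoothed_loss(writer, loss.item(), global_step, ema_state)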

Training guidance

- Is this the right way to fine-tune FLUX.1-dev on a dataset like mine?

Dataset advice

- Are 25K images too many or too few for a model like this?
- Is 512x512 resolution okay?

Monitoring the training

- How can I visually track real progress (image outputs, quality, overfitting, etc.)? One idea I had is sketched below.
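
My idea is to start the TensorBoard samples from fixed noise so the grids at different steps are directly comparable; roughly like this (fixed_sample_latents is just an illustrative name, the seed is arbitrary, and the 64x64 shape matches what my script already assumes):

Code: Select all

import torch

def fixed_sample_latents(channels: int, batch_size: int = 1, height: int = 64, width: int = 64, seed: int = 1234):
    """Return the same starting noise on every call, so sample grids are comparable across training steps."""
    generator = torch.Generator().manual_seed(seed)
    return torch.randn((batch_size, channels, height, width), generator=generator)

# usage in save_sample_images, instead of a fresh torch.randn each time:
#   latents = fixed_sample_latents(unet.config.in_channels).to(device)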

Best practices

- Are there standard fine-tuning configurations (batch size, learning rate, steps) for FLUX.1-dev?
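
For context, I'm currently running with the script defaults: batch size 8, learning rate 5e-5, 5 epochs. With ~25K images that works out to roughly 25,000 / 8 ≈ 3,125 steps per epoch, so about 15,600 steps in total.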
