Wie zeige ich in jeder Registerkarte einer Gradio-App einen eigenen Fortschrittsbalken an?Python

Python-Programme
Guest
 Wie zeige ich in jeder Registerkarte einer Gradio-App einen eigenen Fortschrittsbalken an?

Post by Guest »

Ich entwickle eine Gradio-App zur Bildgenerierung, die mehrere Modelle wie SD3.5, Flux und andere verwendet, um Bilder aus einer bestimmten Eingabeaufforderung zu generieren.
Die App verfügt jeweils über 7 Registerkarten einem bestimmten Modell entspricht. Auf jeder Registerkarte wird ein vom jeweiligen Modell generiertes Bild angezeigt.
Mein Problem ist, dass ich nicht für jede Registerkarte einzeln einen Fortschrittsbalken anzeigen kann. Derzeit wird der Fortschrittsbalken auf allen Registerkarten gleichzeitig angezeigt. Allerdings benötige ich einen „tab-spezifischen Fortschrittsbalken“.
Unten finden Sie meine Codebasis und Screenshots der App, die den Bildgenerierungsprozess im Anhang nachahmen. Wie kann ich diese Funktion implementieren?

Code: Select all

import random
from time import sleep

import gradio as gr
import threading
import requests
from PIL import Image
from io import BytesIO

# Constants
MAX_IMAGE_SIZE = 1024

# Model configurations
MODEL_CONFIGS = {
"Stable Diffusion 3.5": {
"repo_id": "stabilityai/stable-diffusion-3.5-large",
"pipeline_class": "StableDiffusion3Pipeline"
},
"FLUX": {
"repo_id": "black-forest-labs/FLUX.1-dev",
"pipeline_class": "FluxPipeline"
},
"PixArt": {
"repo_id": "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
"pipeline_class": "PixArtSigmaPipeline"
},
"AuraFlow": {
"repo_id": "fal/AuraFlow",
"pipeline_class": "AuraFlowPipeline"
},
"Kandinsky": {
"repo_id": "kandinsky-community/kandinsky-3",
"pipeline_class": "Kandinsky3Pipeline"
},
"Hunyuan": {
"repo_id": "Tencent-Hunyuan/HunyuanDiT-Diffusers",
"pipeline_class": "HunyuanDiTPipeline"
},
"Lumina": {
"repo_id": "Alpha-VLLM/Lumina-Next-SFT-diffusers",
"pipeline_class": "LuminaText2ImgPipeline"
}
}

# Dictionary to store model pipelines
pipes = {}
model_locks = {model_name: threading.Lock() for model_name in MODEL_CONFIGS.keys()}

def fetch_image_from_url(url):
try:
response = requests.get(url)
response.raise_for_status()
return Image.open(BytesIO(response.content))
except Exception as e:
print(f"Error fetching image from URL {url}: {e}")
return None

def generate_all(prompt, negative_prompt, seed, randomize_seed, width, height,
guidance_scale, num_inference_steps):
# Initialize a list to store all outputs
all_outputs = [None] * (len(MODEL_CONFIGS) * 2)  # Pre-fill with None for each model's image and seed

for idx, model_name in enumerate(MODEL_CONFIGS.keys()):
try:
progress_dict[model_name](0, desc=f"Starting generation for {model_name}...")
print(f"IMAGE GENERATING {model_name}")
generated_seed = seed if not randomize_seed else random.randint(0, 100000)

# Fetch an image from a URL
url = f"https://placehold.co/600x400/000000/FFFFFF.png?text=Hello+{model_name}+ +{generated_seed}"   # Replace with actual URL as needed
image = fetch_image_from_url(url)

progress_dict[model_name](0.9, desc=f"downloaded {model_name}...")
# Update the outputs array with the result and seed, leaving remaining slots as None
all_outputs[idx * 2] = image  # Image slot
all_outputs[idx * 2 + 1] = generated_seed  # Seed slot

# Add intermediate results to progress * (len(all_outputs) - len(all_outputs))
yield all_outputs + [None]
progress_dict[model_name](1, desc=f"generated {model_name}...")
sleep(1)  # Simulate processing time

except Exception as e:
print(f"Error generating with {model_name}: {str(e)}")
# Leave the slots for this model as None
all_outputs[idx * 2] = None
all_outputs[idx * 2 + 1] = None

# Return the final completed array
return all_outputs

# Gradio Interface
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown("# Multi-Model Image Generation")

with gr.Row():
prompt = gr.Text(
label="Prompt",
show_label=False,
max_lines=1,
placeholder="Enter your prompt",
container=False,
)
run_button = gr.Button("Generate", scale=0, variant="primary")

with gr.Accordion("Advanced Settings", open=False):
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=100,
step=1,
value=0,
)
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

memory_indicator = gr.Markdown("Current memory usage: 0 GB")

with gr.Row():
with gr.Column(scale=2):
with gr.Tabs() as tabs:
results = {}
seeds = {}
progress_dict: dict[str, gr.Progress] = {}

for model_name in MODEL_CONFIGS.keys():
with gr.Tab(model_name):
results[model_name] = gr.Image(label=f"{model_name} Result")
seeds[model_name] = gr.Number(label="Seed used", visible=True)
progress_dict[model_name] = gr.Progress()

# Prepare the input and output components
input_components = [
prompt, seed, randomize_seed,
]

output_components = []
for model_name in MODEL_CONFIGS.keys():
output_components.extend([results[model_name], seeds[model_name]])

run_button.click(
fn=generate_all,
inputs=input_components,
outputs=output_components,
)

if __name__ == "__main__":
demo.launch()
Image

Quick Reply

Change Text Case: 
   
  • Similar Topics
    Replies
    Views
    Last post