High-Throughput Diffusion Inference System

As a Machine Learning Engineer intern at GMI Cloud, I designed and shipped a production inference system for a 12-billion-parameter Diffusion Transformer (DiT) model running on NVIDIA H100 GPU clusters. The system generates high-quality 1360×768 images from natural-language prompts in just 4 denoising steps. It is deployed as a long-running worker behind an internal request queue, handling continuous, concurrent traffic from real users.

What made this project challenging wasn't running the model once — it was making it run thousands of times a day without dropping a request, leaking memory, or going silent. I built the entire production wrapper from scratch: distributed multi-GPU coordination, request lifecycle management, cloud storage integration, health monitoring, and prompt safety checking.

Images / Min / GPU: 20–30
Per-Image Latency: 1–3s
Model Parameters: 12B
Denoising Steps: 4
Throughput Gain: 10–15×
PyTorch · NCCL · torchrun · FastAPI · Google Cloud Storage · H100 GPU · Flux-Schnell DiT · VAE · T5 / CLIP

System Design

Stateless Worker Architecture

Each GPU node operates as a stateless worker that registers with a centralized task queue on startup, continuously polls for incoming generation requests, and reports results back via REST. There is no session affinity: any worker can pick up any request, so scaling horizontally only requires launching more nodes.

The worker lifecycle follows a clean state machine:

1. Boot & Validate
2. Register (get worker ID)
3. Load Models to GPU
4. Infinite Serve Loop

All four model components (T5 encoder, CLIP encoder, DiT flow model, VAE decoder) are loaded to the GPU at startup and stay resident permanently — zero cold-start penalty between requests. Compared to a load-generate-unload baseline, this delivered a 10–15× throughput improvement.
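The lifecycle above can be sketched as a small state machine. This is a minimal illustration rather than the production code: `Worker`, `State`, and the injected `register`, `load_models`, `poll`, and `handle` callables are hypothetical names, and `max_requests` exists only so the otherwise infinite loop can be stopped in a demo.

```python
import time
from enum import Enum, auto
from typing import Any, Callable, Optional


class State(Enum):
    BOOT = auto()         # config and environment validation happens here
    REGISTER = auto()
    LOAD_MODELS = auto()
    SERVE = auto()


class Worker:
    def __init__(self, register: Callable[[], str], load_models: Callable[[], Any],
                 poll: Callable[[], Optional[dict]], handle: Callable[[Any, dict], None]):
        self.register, self.load_models = register, load_models
        self.poll, self.handle = poll, handle
        self.state = State.BOOT
        self.worker_id: Optional[str] = None
        self.models: Any = None

    def start(self) -> None:
        self.state = State.REGISTER
        self.worker_id = self.register()       # the queue assigns the worker ID
        self.state = State.LOAD_MODELS
        self.models = self.load_models()       # loaded once, kept resident
        self.state = State.SERVE

    def serve(self, max_requests: Optional[int] = None, idle_sleep: float = 0.5) -> int:
        served = 0
        while max_requests is None or served < max_requests:
            request = self.poll()
            if request is None:
                time.sleep(idle_sleep)         # queue empty: wait briefly, then poll again
                continue
            self.handle(self.models, request)  # every request reuses the resident models
            served += 1
        return served
```

The point of the sketch is that `load_models` runs exactly once in `start`, while `serve` only ever reads `self.models`. That is the difference from a load-generate-unload baseline.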

Multi-GPU Distributed Inference (NCCL)

The system launches via torchrun and initializes an NCCL process group, so adding GPUs is a config change, not a code change. The distributed coordination protocol:

  • Rank-0 handles all I/O. Queue polling, status updates, heartbeats, and cloud uploads all run on the primary process. The other ranks make no HTTP or storage calls; their only communication is through NCCL collectives.
  • Broadcast before inference. Once rank-0 pulls a request, the payload (prompt, dimensions, seed) is broadcast to all workers via NCCL broadcast_object_list.
  • Barrier synchronization. dist.barrier() before and after generation ensures lockstep — no GPU starts early, no GPU races ahead.
  • Seed consistency. Random seed broadcast from rank-0 guarantees deterministic output regardless of GPU count.

The NCCL timeout is set to 120 seconds with async error handling, so a hung or stalled GPU surfaces as an error instead of deadlocking the entire cluster.
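The coordination protocol described above could look roughly like the following. It is a sketch under assumptions: `init_group`, `share_request`, `run_step`, and the `generate` callable are hypothetical names, and only the `torch.distributed` calls themselves (`init_process_group`, `broadcast_object_list`, `barrier`) are real APIs. Async error handling is normally switched on through an environment variable whose name depends on the PyTorch version, so it is not set here.

```python
import datetime
from typing import Any, Callable, Optional

import torch.distributed as dist


def init_group(backend: str = "nccl", timeout_s: int = 120, **kwargs: Any) -> None:
    """Join the process group. Under torchrun, rank and world size come from env vars."""
    dist.init_process_group(
        backend=backend,
        timeout=datetime.timedelta(seconds=timeout_s),  # collectives fail after this long
        **kwargs,
    )


def share_request(request: Optional[dict]) -> dict:
    """Rank 0 passes the request it pulled from the queue; other ranks pass None."""
    box = [request if dist.get_rank() == 0 else None]
    # With the NCCL backend, each rank should have called torch.cuda.set_device first.
    dist.broadcast_object_list(box, src=0)  # afterwards every rank holds rank 0's payload
    return box[0]


def run_step(request: Optional[dict], generate: Callable[[dict], Any]) -> Any:
    payload = share_request(request)  # prompt, dimensions, and seed, identical on all ranks
    dist.barrier()                    # no rank starts generating early
    result = generate(payload)
    dist.barrier()                    # no rank moves on until all have finished
    return result if dist.get_rank() == 0 else None  # only rank 0 does I/O with the result
```

Because the seed travels inside the broadcast payload, every rank seeds its random generator from the same value before sampling noise.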

Health Monitoring (Heartbeat)

A daemon thread sends a heartbeat POST to the queue scheduler every 30 seconds, carrying the worker ID, GPU count, status ("ready" / "occupied" / "initializing"), and a timestamp. Failed heartbeats retry with exponential backoff. Without this, a silently dead node (OOM kill, GPU hang, network partition) would cause in-flight jobs to disappear permanently.
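A heartbeat thread of this kind might be structured as below. This is a sketch, not the deployed code: the function names are hypothetical, `send` stands in for whatever performs the HTTP POST, and the backoff parameters (1 s base, doubling, 30 s cap, 5 retries) are assumptions chosen for illustration.

```python
import threading
import time
from typing import Callable, Iterator


def backoff_delays(base: float = 1.0, factor: float = 2.0,
                   cap: float = 30.0, attempts: int = 5) -> Iterator[float]:
    """Yield the wait before each retry: base, base*factor, ..., never above cap."""
    delay = base
    for _ in range(attempts):
        yield min(delay, cap)
        delay *= factor


def send_with_retry(send: Callable[[dict], None], payload: dict,
                    sleep: Callable[[float], None] = time.sleep) -> bool:
    """Try once immediately, then once after each backoff delay."""
    for delay in [0.0, *backoff_delays()]:
        if delay:
            sleep(delay)
        try:
            send(payload)
            return True
        except Exception:
            continue  # a failed heartbeat must never kill the thread
    return False


def start_heartbeat(send: Callable[[dict], None], worker_id: str, gpu_count: int,
                    get_status: Callable[[], str], interval: float = 30.0) -> threading.Event:
    """Start the daemon thread; set the returned event to stop it."""
    stop = threading.Event()

    def loop() -> None:
        while not stop.is_set():
            payload = {"worker_id": worker_id, "gpu_count": gpu_count,
                       "status": get_status(), "timestamp": time.time()}
            send_with_retry(send, payload)
            stop.wait(interval)  # sleeps, but wakes immediately on stop

    threading.Thread(target=loop, daemon=True, name="heartbeat").start()
    return stop
```

Marking the thread as a daemon means it cannot keep the process alive on shutdown, and waiting on the event rather than calling `time.sleep` lets the worker stop the heartbeat promptly.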

Request Lifecycle

Pull from Queue
Validate (dims mod 16, safety)
Status → Processing
DiT Inference
Upload to GCS
Status → Success / Failed

On any exception — bad prompt, GPU error, upload failure — the system marks the request "failed" and continues to the next request. No single bad input can crash the server loop.
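The per-request flow and its failure handling can be illustrated with a single function. This is a simplified sketch: `validate`, `handle_request`, and the injected `generate`, `upload`, and `set_status` callables are hypothetical names, the request fields are assumed, and the prompt safety check is left out to keep the example short.

```python
from typing import Any, Callable


def validate(request: dict) -> None:
    """Reject requests the model cannot serve before any GPU work starts."""
    if not request.get("prompt", "").strip():
        raise ValueError("prompt is empty")
    if request["width"] % 16 or request["height"] % 16:
        raise ValueError("width and height must be multiples of 16")


def handle_request(request: dict,
                   generate: Callable[[dict], Any],
                   upload: Callable[[Any, str], str],
                   set_status: Callable[..., None]) -> str:
    """Run one request end to end; any failure is recorded, never re-raised."""
    request_id = request["id"]
    try:
        validate(request)
        set_status(request_id, "processing")
        image = generate(request)          # DiT inference
        url = upload(image, request_id)    # e.g. a cloud storage upload
        set_status(request_id, "success", url=url)
        return "success"
    except Exception as exc:
        set_status(request_id, "failed", error=str(exc))
        return "failed"
```

Since `handle_request` swallows every exception after recording it, the surrounding serve loop can call it unconditionally and move straight on to the next request.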


Inference Engine

The generation engine wraps the full Flux-Schnell model stack under @torch.inference_mode(). The pipeline mirrors the flow-matching paper's formulation:

Sample Noise
Text/Image Conditioning (T5+CLIP)
Denoising Schedule
Flow Denoise Loop (4 steps)
Unpack Latents
VAE Decode → Pixels
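To make the denoise loop concrete, here is a toy version of the schedule and update step, written with plain Python lists instead of latent tensors. It assumes a linear schedule from t=1 to t=0 and a first-order Euler update, which is my understanding of how few-step flow-matching samplers typically work. `linear_schedule`, `denoise`, and the `velocity` callable are hypothetical names; in the real pipeline the DiT plays the role of `velocity`.

```python
from typing import Callable, List

Vector = List[float]


def linear_schedule(num_steps: int) -> List[float]:
    """Timesteps from 1.0 (pure noise) down to 0.0 (clean latent)."""
    return [1.0 - i / num_steps for i in range(num_steps + 1)]


def denoise(x: Vector, velocity: Callable[[Vector, float], Vector],
            num_steps: int = 4) -> Vector:
    """Euler-integrate dx/dt = velocity(x, t) from t=1 down to t=0."""
    ts = linear_schedule(num_steps)
    for t_curr, t_next in zip(ts[:-1], ts[1:]):
        v = velocity(x, t_curr)
        # t_next < t_curr, so each step moves x against the predicted velocity.
        x = [xi + (t_next - t_curr) * vi for xi, vi in zip(x, v)]
    return x
```

A quick sanity check: if `velocity` always returns `noise - target` (the ideal straight-line velocity), then starting from `noise` and taking four steps lands on `target`.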

An optional --offload flag enables CPU offload for memory-constrained environments, lowering VRAM usage at the cost of roughly 3–5× higher latency.

Content Safety Pipeline

Every request passes a two-stage safety check, with one stage before generation and one before upload:

  1. Text-level filtering: a toxicity detector scans the prompt, and rejected prompts never consume GPU compute.
  2. Image-level classification: an NSFW classifier runs on every generated image before the upload step.

Both stages run on rank-0 only, with zero distributed overhead.
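The ordering of the two stages matters more than the classifiers themselves, and can be sketched as follows. The names here (`Rejected`, `safe_generate`) and the injected `is_toxic`, `is_nsfw`, `generate`, and `upload` callables are hypothetical stand-ins; the actual detectors are not shown.

```python
from typing import Any, Callable


class Rejected(Exception):
    """Raised when a prompt or a generated image fails a safety check."""


def safe_generate(prompt: str,
                  is_toxic: Callable[[str], bool],
                  generate: Callable[[str], Any],
                  is_nsfw: Callable[[Any], bool],
                  upload: Callable[[Any], str]) -> str:
    # Stage 1: text check runs first, so a rejected prompt costs no GPU time.
    if is_toxic(prompt):
        raise Rejected("prompt failed the text safety check")
    image = generate(prompt)
    # Stage 2: image check runs before upload, so flagged images are never stored.
    if is_nsfw(image):
        raise Rejected("image failed the NSFW check")
    return upload(image)
```

Raising an exception for a rejection lets the request handler treat it like any other failure and mark the request as failed.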

Measured Performance

Metric               Value
Output resolution    1360 × 768
Denoising steps      4 (Flux-Schnell)
Per-image latency    ~1–3 seconds
Throughput per GPU   20–30 images / minute
VRAM usage           ~8–12 GB (model-dependent)
Multi-GPU scaling    Linear (tested up to multi-node)

What I Learned

This was my first time building a system where uptime matters more than accuracy. In production, a crash at 3 AM means real users get error pages. Designing for graceful degradation — catching every exception, retrying uploads, re-queuing failed jobs, never crashing the server loop — was a fundamentally different engineering discipline than anything I'd done in school.

I also learned that distributed GPU programming is 20% NCCL calls and 80% making sure every process agrees on what's happening. The broadcast-barrier-broadcast pattern for request distribution seems simple in retrospect, but getting the synchronization right — especially around error paths — took more debugging than the actual inference code.