High-Throughput Diffusion Inference System
As a Machine Learning Engineer intern at GMI Cloud, I designed and shipped a production inference system for a 12-billion-parameter Diffusion Transformer (DiT) model running on NVIDIA H100 GPU clusters. The system generates high-quality 1360×768 images from natural-language prompts in just 4 denoising steps. It is deployed as a long-running worker behind an internal request queue and handles continuous, concurrent traffic from real users.
What made this project challenging wasn't running the model once — it was making it run thousands of times a day without dropping a request, leaking memory, or going silent. I built the entire production wrapper from scratch: distributed multi-GPU coordination, request lifecycle management, cloud storage integration, health monitoring, and prompt safety checking.
System Design
Stateless Worker Architecture
Each GPU node operates as a stateless worker that registers with a centralized task queue on startup, continuously polls for incoming generation requests, and reports results back via REST. There is no session affinity: any worker can pick up any request, so scaling horizontally means launching more nodes.
The worker lifecycle follows a clean state machine:
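A minimal sketch of such a state machine, assuming its states are the three statuses the heartbeat reports ("initializing", "ready", "occupied") and assuming the transitions listed in the table. The class and method names are illustrative, not the production code.

```python
from enum import Enum


class WorkerState(Enum):
    INITIALIZING = "initializing"  # loading models, registering with the queue
    READY = "ready"                # idle, polling for work
    OCCUPIED = "occupied"          # handling one request


# Assumed transitions: start up once, then alternate between idle and busy.
TRANSITIONS = {
    WorkerState.INITIALIZING: {WorkerState.READY},
    WorkerState.READY: {WorkerState.OCCUPIED},
    WorkerState.OCCUPIED: {WorkerState.READY},
}


class Worker:
    def __init__(self) -> None:
        self.state = WorkerState.INITIALIZING

    def transition(self, new_state: WorkerState) -> None:
        """Move to new_state, refusing any transition not in the table."""
        if new_state not in TRANSITIONS[self.state]:
            raise ValueError(
                f"illegal transition {self.state.value} -> {new_state.value}"
            )
        self.state = new_state
```

Keeping the allowed transitions in one table makes it easy to reject a status update that would skip a step, such as going back to "initializing" after startup.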
All four model components (T5 encoder, CLIP encoder, DiT flow model, VAE decoder) are loaded to the GPU at startup and stay resident permanently — zero cold-start penalty between requests. Compared to a load-generate-unload baseline, this delivered a 10–15× throughput improvement.
Multi-GPU Distributed Inference (NCCL)
The system launches via torchrun and initializes an NCCL process group, so adding GPUs is a config change, not a code change. The distributed coordination protocol works as follows:
- Rank-0 handles all I/O. Queue polling, status updates, heartbeats, and cloud uploads all run on the primary process. The other ranks never contact the queue or cloud storage directly.
- Broadcast before inference. Once rank-0 pulls a request, the payload (prompt, dimensions, seed) is broadcast to all workers via NCCL broadcast_object_list.
- Barrier synchronization. dist.barrier() before and after generation ensures lockstep: no GPU starts early, no GPU races ahead.
- Seed consistency. Random seed broadcast from rank-0 guarantees deterministic output regardless of GPU count.
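The protocol above can be sketched roughly as follows. This is a simplified illustration, not the production code: it assumes the process group has already been initialized under torchrun, `generate` is a hypothetical callable wrapping the model, and the request is a plain dict. Outside a process group the collective calls are skipped, so the sketch also runs in a single process.

```python
import random

try:
    import torch.distributed as dist
except ImportError:  # keeps the sketch importable on machines without PyTorch
    dist = None


def in_process_group() -> bool:
    """True only when a process group has been initialized (e.g. under torchrun)."""
    return dist is not None and dist.is_available() and dist.is_initialized()


def share_request(request, rank):
    """Rank 0 passes the polled request; other ranks pass None.

    Every rank returns the same payload, including the same seed.
    """
    if rank == 0:
        request = dict(request)
        # Rank 0 picks the seed so that every rank draws identical noise.
        request.setdefault("seed", random.randrange(2**32))
    box = [request if rank == 0 else None]
    if in_process_group():
        dist.broadcast_object_list(box, src=0)  # fills box[0] on non-zero ranks
    return box[0]


def generate_in_lockstep(request, rank, generate):
    """Broadcast, then bracket generation with barriers."""
    payload = share_request(request, rank)
    if in_process_group():
        dist.barrier()  # no rank starts before all ranks hold the payload
    image = generate(payload)
    if in_process_group():
        dist.barrier()  # no rank moves on to the next request early
    return image
```

Because every rank must reach each collective call, an exception on one rank has to be communicated to the others rather than simply raised, otherwise the remaining ranks block at the next barrier.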
Health Monitoring (Heartbeat)
A daemon thread sends a heartbeat POST to the queue scheduler every 30 seconds, carrying the worker ID, GPU count, status ("ready" / "occupied" / "initializing"), and a timestamp. Failed heartbeats retry with exponential backoff. Without this, a silently dead node (OOM kill, GPU hang, network partition) would cause in-flight jobs to disappear permanently.
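A minimal sketch of a heartbeat thread with exponential backoff, using only the standard library. The endpoint URL, JSON field names, and backoff parameters (1 s base, doubling, 30 s cap, 5 retries) are assumptions for illustration; only the 30-second interval and the payload contents come from the description above.

```python
import json
import threading
import time
import urllib.request


def backoff_delays(base=1.0, factor=2.0, cap=30.0, attempts=5):
    """Yield exponentially growing retry delays, capped at `cap` seconds."""
    delay = base
    for _ in range(attempts):
        yield min(delay, cap)
        delay *= factor


def post_json(url, payload, timeout=5.0):
    """POST a JSON body; raises an OSError subclass on network or HTTP errors."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return response.status


def send_heartbeat(url, payload, post=post_json, sleep=time.sleep):
    """Try once, then retry after each backoff delay. Returns True on success."""
    for delay in [0.0, *backoff_delays()]:
        if delay:
            sleep(delay)
        try:
            post(url, payload)
            return True
        except OSError:  # URLError, HTTPError and timeouts all derive from OSError
            continue
    return False


def start_heartbeat(url, worker_id, gpu_count, get_status, interval=30.0):
    """Start the daemon thread; set the returned Event to stop it."""
    stop = threading.Event()

    def loop():
        while not stop.is_set():
            payload = {
                "worker_id": worker_id,
                "gpu_count": gpu_count,
                "status": get_status(),  # "initializing", "ready" or "occupied"
                "timestamp": time.time(),
            }
            send_heartbeat(url, payload)
            stop.wait(interval)  # sleeps, but wakes immediately on stop.set()

    threading.Thread(target=loop, daemon=True, name="heartbeat").start()
    return stop
```

Marking the thread as a daemon means it never keeps the process alive on shutdown, and waiting on the Event rather than calling time.sleep lets the worker stop the heartbeat promptly.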
Request Lifecycle
On any exception — bad prompt, GPU error, upload failure — the system marks the request "failed" and continues to the next request. No single bad input can crash the server loop.
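The error handling described here might look like the loop below. It is a sketch under stated assumptions: `poll`, `process`, and `report` are hypothetical callables standing in for the queue client, the generate-and-upload step, and the status update, and the "processing" and "completed" status names are illustrative (only "failed" is taken from the text).

```python
import time


def serve(poll, process, report, should_stop, idle_sleep=1.0, sleep=time.sleep):
    """Handle requests one at a time; a failing request never ends the loop."""
    while not should_stop():
        request = poll()
        if request is None:
            sleep(idle_sleep)  # queue is empty, back off briefly
            continue
        report(request["id"], "processing")
        try:
            result = process(request)
        except Exception as exc:  # bad prompt, GPU error, upload failure, ...
            report(request["id"], "failed", error=str(exc))
            continue  # move on to the next request
        report(request["id"], "completed", result=result)
```

Catching `Exception` rather than using a bare `except` leaves KeyboardInterrupt and SystemExit free to stop the worker deliberately.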
Inference Engine
The generation engine wraps the full Flux-Schnell model stack under @torch.inference_mode(). The pipeline mirrors the flow-matching paper's formulation:
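As a rough illustration of flow-matching sampling, the sketch below integrates a velocity field from t = 1 (pure noise) to t = 0 (image) with plain Euler steps. It assumes a uniform timestep schedule and uses a toy velocity function on Python lists in place of the DiT; text encoding and VAE decoding are omitted, and the function names are hypothetical.

```python
def uniform_schedule(num_steps):
    """Evenly spaced timesteps from 1.0 (noise) down to 0.0 (image)."""
    return [1.0 - i / num_steps for i in range(num_steps + 1)]


def euler_sample(velocity, x, num_steps=4):
    """Integrate dx/dt = velocity(x, t) from t=1 to t=0 with Euler steps."""
    timesteps = uniform_schedule(num_steps)
    for t_curr, t_next in zip(timesteps[:-1], timesteps[1:]):
        v = velocity(x, t_curr)
        # t_next < t_curr, so each step moves against the velocity direction.
        x = [xi + (t_next - t_curr) * vi for xi, vi in zip(x, v)]
    return x


# Toy check: on the straight path x_t = (1 - t) * data + t * noise the
# velocity is the constant (noise - data), so Euler recovers data exactly.
data, noise = [1.0, -2.0], [0.5, 0.5]
constant_velocity = lambda x, t: [n - d for n, d in zip(noise, data)]
print(euler_sample(constant_velocity, noise))  # [1.0, -2.0]
```

With a learned velocity field the path is only approximately straight, which is why the number of steps trades quality against latency.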
An optional --offload flag enables CPU offload for memory-constrained environments, lowering VRAM usage at the cost of roughly 3–5× higher latency.
Content Safety Pipeline
Every request passes through a two-stage safety check, one stage before generation and one after:
- Text-level filtering: a toxicity detector scans the prompt, and flagged prompts are rejected before any GPU compute is spent.
- Image-level classification: an NSFW classifier runs on every generated image before the upload step.

Both stages run on rank-0 only, so they add no distributed overhead.
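A minimal sketch of how the two stages can bracket generation. The detector, classifier, and generator are passed in as hypothetical callables (`is_toxic`, `is_nsfw`, `generate`), and the result type is illustrative rather than the system's actual interface.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class SafetyResult:
    allowed: bool
    rejected_by: Optional[str] = None  # "text" or "image" when rejected
    image: Any = None


def generate_safely(
    prompt: str,
    is_toxic: Callable[[str], bool],
    generate: Callable[[str], Any],
    is_nsfw: Callable[[Any], bool],
) -> SafetyResult:
    # Stage 1: reject before any GPU work is done.
    if is_toxic(prompt):
        return SafetyResult(allowed=False, rejected_by="text")
    image = generate(prompt)
    # Stage 2: reject before the image is uploaded anywhere.
    if is_nsfw(image):
        return SafetyResult(allowed=False, rejected_by="image")
    return SafetyResult(allowed=True, image=image)
```

Ordering the checks this way means a rejected prompt costs only the text check, while the image check guards against outputs that a harmless-looking prompt can still produce.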
Measured Performance
| Metric | Value |
|---|---|
| Output resolution | 1360 × 768 |
| Denoising steps | 4 (Flux-Schnell) |
| Per-image latency | ~1–3 seconds |
| Throughput per GPU | 20–30 images / minute |
| VRAM usage | ~8–12 GB (model-dependent) |
| Multi-GPU scaling | Linear (tested up to multi-node) |
What I Learned
This was my first time building a system where uptime matters more than accuracy. In production, a crash at 3 AM means real users get error pages. Designing for graceful degradation — catching every exception, retrying uploads, re-queuing failed jobs, never crashing the server loop — was a fundamentally different engineering discipline than anything I'd done in school.
I also learned that distributed GPU programming is 20% NCCL calls and 80% making sure every process agrees on what's happening. The broadcast-barrier-broadcast pattern for request distribution seems simple in retrospect, but getting the synchronization right — especially around error paths — took more debugging than the actual inference code.
