From 5af7dc84c27ad4e51b1839b296db12b6695b4a89 Mon Sep 17 00:00:00 2001
From: Charles Frye
Date: Thu, 29 Jan 2026 14:58:17 -0800
Subject: [PATCH] adds a 'kitchen sink' example for sglang + qwen (#1463)

* adds a 'kitchen sink' example for sglang + qwen

* finish renaming

* ruff
---
 .../llm-serving/sglang_kitchen_sink.py | 358 ++++++++++++++++++
 1 file changed, 358 insertions(+)
 create mode 100644 06_gpu_and_ml/llm-serving/sglang_kitchen_sink.py

diff --git a/06_gpu_and_ml/llm-serving/sglang_kitchen_sink.py b/06_gpu_and_ml/llm-serving/sglang_kitchen_sink.py
new file mode 100644
index 00000000..4bd7965e
--- /dev/null
+++ b/06_gpu_and_ml/llm-serving/sglang_kitchen_sink.py
@@ -0,0 +1,358 @@
# ---
# deploy: true
# cmd: ["python", "06_gpu_and_ml/llm-serving/sglang_kitchen_sink.py"]
# ---

# # Fast-booting, low-latency Qwen 3 8B with SGLang, GPU snapshots, and speculative decoding

# This is a bare-bones "kitchen-sink" demo of all of the tips and tricks
# you can use to make Qwen 3 8B go brrt.

# Unlike our other examples, this demo includes limited explanation of the code.
# For a detailed guide to the principles and practices implemented here,
# see [this guide](https://modal.com/docs/guide/high-performance-llm-inference).

# ## Set up the container image

import asyncio
import json
import subprocess
import time

import aiohttp
import modal
import modal.experimental

MINUTES = 60  # seconds

sglang_image = (
    modal.Image.from_registry(
        "modalresearch/sglang:v0.5.7-fa4-dflash-preview"  # bleeding-edge custom SGLang build
    ).entrypoint([])  # silence chatty logs on container start
)

sglang_image = sglang_image.env(
    {  # bleeding-edge SGLang perf opt settings
        "SGLANG_ENABLE_SPEC_V2": "1",
        "SGLANG_ENABLE_DFLASH_SPEC_V2": "1",
        "SGLANG_ENABLE_OVERLAP_PLAN_STREAM": "1",
    }
)

# ## Choose a GPU

GPU_TYPE, N_GPUS = "B200", 1
GPU = f"{GPU_TYPE}:{N_GPUS}"

# ### Loading and caching the model weights

MODEL_NAME = "Qwen/Qwen3-8B-FP8"
MODEL_REVISION = (  # pin the revision id to avoid nasty surprises!
    "220b46e3b2180893580a4454f21f22d3ebb187d3"  # latest commit as of 2026-01-29, from 2025-07-25
)

sglang_image = sglang_image.uv_pip_install("huggingface-hub==0.36.0")

HF_CACHE_VOL = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
HF_CACHE_PATH = "/root/.cache/huggingface"
MODEL_PATH = f"{HF_CACHE_PATH}/{MODEL_NAME}"

sglang_image = sglang_image.env(
    {"HF_HUB_CACHE": HF_CACHE_PATH, "HF_XET_HIGH_PERFORMANCE": "1"}
)

# ### Caching compilation artifacts

# JIT DeepGEMM kernels are on by default, but we explicitly enable them via an environment variable.
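# We compile the kernels ahead of time, during the image build, and persist the
# compiled artifacts in a Modal Volume mounted at the DeepGEMM cache path, so
# that containers don't pay the JIT compilation cost on every cold start.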

DG_CACHE_VOL = modal.Volume.from_name("deepgemm-cache", create_if_missing=True)
DG_CACHE_PATH = "/root/.cache/deepgemm"

sglang_image = sglang_image.env({"SGLANG_ENABLE_JIT_DEEPGEMM": "1"})


def compile_deep_gemm():
    import os

    if int(os.environ.get("SGLANG_ENABLE_JIT_DEEPGEMM", "1")):
        subprocess.run(
            f"python3 -m sglang.compile_deep_gemm --model-path {MODEL_NAME} --revision {MODEL_REVISION} --tp {N_GPUS}",
            shell=True,
        )


sglang_image = sglang_image.run_function(
    compile_deep_gemm,
    volumes={DG_CACHE_PATH: DG_CACHE_VOL, HF_CACHE_PATH: HF_CACHE_VOL},
    gpu=GPU,
)

# ## Configure SGLang for minimal latency

speculative_config = {  # use bleeding-edge speculative decoding method
    "speculative-algorithm": "DFLASH",
    "speculative-draft-model-path": "z-lab/Qwen3-8B-DFlash-b16",
}

# ## Speed up cold starts with GPU snapshotting

sglang_image = sglang_image.env({"TORCHINDUCTOR_COMPILE_THREADS": "1"})

# ### Sleeping and waking an SGLang server

with sglang_image.imports():
    import requests


def warmup():
    payload = {
        "messages": [{"role": "user", "content": "Hello, how are you?"}],
        "max_tokens": 16,
    }
    for _ in range(3):
        requests.post(
            f"http://127.0.0.1:{PORT}/v1/chat/completions", json=payload, timeout=10
        ).raise_for_status()


def sleep():
    requests.post(
        f"http://127.0.0.1:{PORT}/release_memory_occupation", json={}
    ).raise_for_status()


def wake_up():
    requests.post(
        f"http://127.0.0.1:{PORT}/resume_memory_occupation", json={}
    ).raise_for_status()


# ## Define the inference server and infrastructure

# ### Selecting infrastructure to minimize latency

REGION = "us"
PROXY_REGION = "us-west"

MIN_CONTAINERS = 0  # set to 1 to ensure one replica is always ready

# ### Determining autoscaling policy with `@modal.concurrent`

TARGET_INPUTS = 10

# ### Controlling container lifecycles with `modal.Cls`


def wait_ready(process: subprocess.Popen, timeout: int = 5 * MINUTES):
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            check_running(process)
            requests.get(f"http://127.0.0.1:{PORT}/health").raise_for_status()
            return
        except (
            subprocess.CalledProcessError,
            requests.exceptions.ConnectionError,
            requests.exceptions.HTTPError,
        ):
            time.sleep(5)
    raise TimeoutError(f"SGLang server not ready within {timeout} seconds")


def check_running(p: subprocess.Popen):
    if (rc := p.poll()) is not None:
        raise subprocess.CalledProcessError(rc, cmd=p.args)


app = modal.App(name="example-sglang-kitchen-sink")
PORT = 8000


@app.cls(
    image=sglang_image,
    gpu=GPU,
    volumes={DG_CACHE_PATH: DG_CACHE_VOL, HF_CACHE_PATH: HF_CACHE_VOL},
    region=REGION,
    min_containers=MIN_CONTAINERS,
    enable_memory_snapshot=True,
    experimental_options={"enable_gpu_snapshot": True},
)
@modal.experimental.http_server(
    port=PORT,  # wrapped code must listen on this port
    proxy_regions=[PROXY_REGION],  # location of proxies, should be same as Cls region
    exit_grace_period=5,  # seconds, time to finish up requests when closing down
)
@modal.concurrent(target_inputs=TARGET_INPUTS)
class SGLang:
    @modal.enter(snap=True)
    def startup(self):
        """Start the SGLang server and block until it is healthy, then warm it up and put it to sleep."""
        cmd = [
            "python",
            "-m",
            "sglang.launch_server",
            "--model-path",
            MODEL_NAME,
            "--revision",
            MODEL_REVISION,
            "--served-model-name",
            MODEL_NAME,
            "--host",
            "0.0.0.0",
            "--port",
            f"{PORT}",
            "--tp",  # use all GPUs to split up tensor-parallel operations
            f"{N_GPUS}",
            "--cuda-graph-max-bs",  # capture CUDA graphs up to batch sizes we're likely to observe
            f"{TARGET_INPUTS * 2}",
            "--max-running-requests",
            f"{TARGET_INPUTS * 4}",
            "--enable-metrics",  # expose metrics endpoints for telemetry
            "--enable-memory-saver",  # enable offload, for snapshotting
            "--enable-weights-cpu-backup",  # enable offload, for snapshotting
            "--decode-log-interval",  # how often to log during decoding, in tokens
            "100",
            "--mem-fraction",  # leave space for speculative model
            "0.8",
            "--attention-backend",
            "fa4",  # use bleeding-edge attention backend
        ]

        cmd += [  # add speculative config
            item for k, v in speculative_config.items() for item in (f"--{k}", str(v))
        ]

        self.process = subprocess.Popen(cmd, start_new_session=True)
        wait_ready(self.process)
        warmup()
        sleep()  # release GPU memory occupation before snapshot

    @modal.enter(snap=False)
    def restore(self):
        """After snapshot restoration, resume GPU memory occupation."""
        wake_up()

    @modal.exit()
    def stop(self):
        self.process.terminate()


# ## Deploy the server

# ```bash
# modal deploy sglang_kitchen_sink.py
# ```

# ## Test the server

# ```bash
# modal run sglang_kitchen_sink.py
# ```


@app.local_entrypoint()
async def test(test_timeout=10 * MINUTES, prompt=None, twice=True):
    url = SGLang._experimental_get_flash_urls()[0]

    system_prompt = {
        "role": "system",
        "content": "You are a pirate who can't help but drop sly reminders that he went to Harvard.",
    }
    if prompt is None:
        prompt = "Explain the Singular Value Decomposition."

    content = [{"type": "text", "text": prompt}]

    messages = [  # OpenAI chat format
        system_prompt,
        {"role": "user", "content": content},
    ]

    await probe(url, messages, timeout=test_timeout)
    if twice:
        messages[0]["content"] = "You are Jar Jar Binks."
        print(f"Sending messages to {url}:", *messages, sep="\n\t")
        await probe(url, messages, timeout=1 * MINUTES)


async def probe(url, messages=None, timeout=5 * MINUTES):
    if messages is None:
        messages = [{"role": "user", "content": "Tell me a joke."}]

    client_id = str(0)  # set this to some string per multi-turn interaction
    # often a UUID per "conversation"
    headers = {"Modal-Session-ID": client_id}
    deadline = time.time() + timeout
    async with aiohttp.ClientSession(base_url=url, headers=headers) as session:
        while time.time() < deadline:
            try:
                await _send_request_streaming(session, messages)
                return
            except asyncio.TimeoutError:
                await asyncio.sleep(1)
            except aiohttp.client_exceptions.ClientResponseError as e:
                if e.status == 503:
                    await asyncio.sleep(1)
                    continue
                raise e
    raise TimeoutError(f"No response from server within {timeout} seconds")


async def _send_request_streaming(
    session: aiohttp.ClientSession, messages: list, timeout: int | None = None
) -> None:
    payload = {"messages": messages, "stream": True}
    headers = {"Accept": "text/event-stream"}

    async with session.post(
        "/v1/chat/completions", json=payload, headers=headers, timeout=timeout
    ) as resp:
        resp.raise_for_status()
        full_text = ""

        async for raw in resp.content:
            line = raw.decode("utf-8", errors="ignore").strip()
            if not line:
                continue

            # Server-Sent Events format: "data: ..."
            if not line.startswith("data:"):
                continue

            data = line[len("data:") :].strip()
            if data == "[DONE]":
                break

            try:
                evt = json.loads(data)
            except json.JSONDecodeError:
                # ignore any non-JSON keepalive
                continue

            delta = (evt.get("choices") or [{}])[0].get("delta") or {}
            chunk = delta.get("content")

            if chunk:
                print(chunk, end="", flush="\n" in chunk or "." in chunk)
                full_text += chunk
        print()  # newline after stream completes
        print(full_text)


# ### Test memory snapshotting

# ```bash
# python sglang_kitchen_sink.py
# ```

if __name__ == "__main__":
    # after deployment, we can use the class from anywhere
    SGLang = modal.Cls.from_name("example-sglang-kitchen-sink", "SGLang")

    print("calling inference server")
    try:
        asyncio.run(probe(SGLang._experimental_get_flash_urls()[0]))
    except modal.exception.NotFoundError as e:
        raise Exception(
            f"To take advantage of GPU snapshots, deploy first with modal deploy {__file__}"
        ) from e
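
# ## Addenda: querying the server with an OpenAI client

# The SGLang server exposes the OpenAI-compatible `/v1/chat/completions` route,
# so the deployed endpoint can also be hit from any OpenAI-style client, not just
# the test client above. Here is a minimal sketch, assuming the `openai` Python
# package is installed locally and that `$FLASH_URL` is replaced with a URL
# returned by `SGLang._experimental_get_flash_urls()`:

# ```python
# from openai import OpenAI
#
# # the server is launched without --api-key, so any placeholder key works
# client = OpenAI(base_url="$FLASH_URL/v1", api_key="EMPTY")
#
# response = client.chat.completions.create(
#     model="Qwen/Qwen3-8B-FP8",  # matches --served-model-name above
#     messages=[{"role": "user", "content": "Tell me a joke."}],
# )
# print(response.choices[0].message.content)
# ```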