WebRTC and YOLO example (#1176) (#1177)

* WebRTC and YOLO example  (#1176)

webrtc example using aiortc and yolo

* attempt to fix frontmatter

* trying a different pattern for the cmd arg with the -m flag

* address review comments

* fix github links and add suggestions from warnings for huggingface hub and web_url vs get_web_url

* run precommit hooks

* move test to separate file

* move test to separate file

* why did changes from modal_webrtc not commit previously?

* change testing file name to not trigger pytest directly

* change cmd in frontmatter to match new file name

* fix typehint for peer class variable and move so we can use ModalWebRtcPeer as type

* rename test file and ignore in pyproject pytest options

* better signaling server docstring

* update docstrings and change how we specify the modal peer class

* revert change in workflow from local testing

* change cmd to match test file name

* debug checking provided modal peer class is subclass of ModalWebRtcPeer

* change negotiation to signaling in copy/diagrams, remove deploy:true frontmatter

* close revision of copy

* remove print statement and add information to test fail message

* close read and revision of copy

* add region note and increase ws connection timeout

* address Jack's comments

* make diagrams more modal-style

* add tag for lambda-test opt-out reason

* update fastrtc example

* add video and link to demo, edit text

---------

Co-authored-by: Ben Shababo <shababo@Bens-MacBook-Air.local>
Co-authored-by: Charles Frye <charles@modal.com>
Authored by Ben Shababo, 2025-05-23 22:24:06 -07:00; committed by GitHub.
parent d1c1d7439e
commit d089c92650
12 changed files with 1608 additions and 3 deletions

File: 07_web_endpoints/fastrtc_flip_webcam.py

@@ -1,5 +1,6 @@
# ---
# cmd: ["modal", "serve", "07_web_endpoints/fastrtc_flip_webcam.py"]
# deploy: true
# ---
# # Run a FastRTC app on Modal
@@ -7,7 +8,11 @@
# [FastRTC](https://fastrtc.org/) is a Python library for real-time communication on the web.
# This example demonstrates how to run a simple FastRTC app in the cloud on Modal.
# In it, we stream webcam video from a browser to a container on Modal,
# It's intended to help you get up and running with real-time streaming applications on Modal
# as quickly as possible. If you're interested in running a production-grade WebRTC app on Modal,
# see [this example](https://modal.com/docs/examples/webrtc_yolo).
# In this example, we stream webcam video from a browser to a container on Modal,
# where the video is flipped, annotated, and sent back with under 100ms of delay.
# You can try it out [here](https://modal-labs-examples--fastrtc-flip-webcam-ui.modal.run/)
# or just dive straight into the code to run it yourself.

File: 07_web_endpoints/webrtc/frontend/index.html

@@ -0,0 +1,70 @@
<!DOCTYPE html>
<html>
<head>
<title>WebRTC YOLO Demo</title>
<style>
video {
width: 320px;
height: 240px;
margin: 10px;
border: 1px solid black;
}
button {
margin: 10px;
padding: 10px;
}
#videos {
display: flex;
flex-wrap: wrap;
}
.radio-group {
margin: 10px;
padding: 10px;
border: 1px solid #ccc;
border-radius: 4px;
}
.radio-group label {
margin-right: 15px;
}
#statusDisplay {
margin: 10px;
padding: 10px;
border: 1px solid #ccc;
border-radius: 4px;
background-color: #f5f5f5;
min-height: 20px;
max-height: 150px;
overflow-y: auto;
font-family: monospace;
white-space: pre-wrap;
word-wrap: break-word;
}
.status-line {
margin: 2px 0;
padding: 2px;
border-bottom: 1px solid #eee;
}
</style>
</head>
<body>
<div class="radio-group">
<label>
<input type="radio" name="iceServer" value="stun" checked> STUN Server
</label>
<label>
<input type="radio" name="iceServer" value="turn"> TURN Server
</label>
</div>
<div id="videos">
<video id="localVideo" autoplay playsinline muted></video>
<video id="remoteVideo" autoplay playsinline></video>
</div>
<div>
<button id="startWebcamButton">Start Webcam</button>
<button id="startStreamingButton" disabled>Stream YOLO</button>
<button id="stopStreamingButton" disabled>Stop Streaming</button>
</div>
<div id="statusDisplay"></div>
<script type="module" src="/static/webcam_webrtc.js"></script>
</body>
</html>

File: 07_web_endpoints/webrtc/frontend/modal_webrtc.js

@@ -0,0 +1,223 @@
export class ModalWebRtcClient extends EventTarget {
constructor() {
super();
this.ws = null;
this.localStream = null;
this.peerConnection = null;
this.iceServers = null;
this.peerID = null;
this.iceServerType = 'stun';
}
updateStatus(message) {
this.dispatchEvent(new CustomEvent('status', {
detail: { message }
}));
console.log(message);
}
// Get webcam media stream
async startWebcam() {
try {
this.localStream = await navigator.mediaDevices.getUserMedia({
video: {
facingMode: { ideal: "environment" }
},
audio: false
});
this.dispatchEvent(new CustomEvent('localStream', {
detail: { stream: this.localStream }
}));
return this.localStream;
} catch (err) {
console.error('Error accessing media devices:', err);
this.dispatchEvent(new CustomEvent('error', {
detail: { error: err }
}));
throw err;
}
}
// Create and set up peer connection
async startStreaming() {
this.peerID = this.generateShortUUID();
this.updateStatus('Loading YOLO GPU inference in the cloud (this can take up to 20 seconds)...');
await this.negotiate();
}
async negotiate() {
try {
// setup websocket connection
this.ws = new WebSocket(`/ws/${this.peerID}`);
this.ws.onerror = (error) => {
console.error('WebSocket error:', error);
this.dispatchEvent(new CustomEvent('error', {
detail: { error }
}));
};
this.ws.onclose = () => {
console.log('WebSocket connection closed');
this.dispatchEvent(new CustomEvent('websocketClosed'));
};
this.ws.onmessage = (event) => {
const msg = JSON.parse(event.data);
if (msg.type === 'answer') {
this.updateStatus('Establishing WebRTC connection...');
this.peerConnection.setRemoteDescription(msg);
} else if (msg.type === 'turn_servers') {
this.iceServers = msg.ice_servers;
} else {
console.error('Unexpected response from server:', msg);
}
};
console.log('Waiting for websocket to open...');
await new Promise((resolve) => {
if (this.ws.readyState === WebSocket.OPEN) {
resolve();
} else {
this.ws.addEventListener('open', () => resolve(), { once: true });
}
});
if (this.iceServerType === 'turn') {
this.ws.send(JSON.stringify({type: 'get_turn_servers', peer_id: this.peerID}));
} else {
this.iceServers = [
{
urls: ["stun:stun.l.google.com:19302"],
},
];
}
// Wait until we have ICE servers
if (this.iceServerType === 'turn') {
await new Promise((resolve) => {
const checkIceServers = () => {
if (this.iceServers) {
resolve();
} else {
setTimeout(checkIceServers, 100);
}
};
checkIceServers();
});
}
const rtcConfiguration = {
iceServers: this.iceServers,
}
this.peerConnection = new RTCPeerConnection(rtcConfiguration);
// Add local stream to peer connection
this.localStream.getTracks().forEach(track => {
console.log('Adding track:', track);
this.peerConnection.addTrack(track, this.localStream);
});
// Handle remote stream when triggered
this.peerConnection.ontrack = event => {
console.log('Received remote stream:', event.streams[0]);
this.dispatchEvent(new CustomEvent('remoteStream', {
detail: { stream: event.streams[0] }
}));
};
// Handle ICE candidates using Trickle ICE pattern
this.peerConnection.onicecandidate = async (event) => {
if (!event.candidate || !event.candidate.candidate) {
return;
}
const iceCandidate = {
peer_id: this.peerID,
candidate_sdp: event.candidate.candidate,
sdpMid: event.candidate.sdpMid,
sdpMLineIndex: event.candidate.sdpMLineIndex,
usernameFragment: event.candidate.usernameFragment
};
console.log('Sending ICE candidate: ', iceCandidate.candidate_sdp);
// send ice candidate over ws
this.ws.send(JSON.stringify({type: 'ice_candidate', candidate: iceCandidate}));
};
this.peerConnection.onconnectionstatechange = async () => {
const state = this.peerConnection.connectionState;
this.updateStatus(`WebRTC connection state: ${state}`);
this.dispatchEvent(new CustomEvent('connectionStateChange', {
detail: { state }
}));
if (state === 'connected') {
if (this.ws.readyState === WebSocket.OPEN) {
this.ws.close();
}
}
};
// set local description and send as offer to peer
console.log('Setting local description...');
await this.peerConnection.setLocalDescription();
const offer = this.peerConnection.localDescription;
console.log('Sending offer...');
// send offer over ws
this.ws.send(JSON.stringify({peer_id: this.peerID, type: 'offer', sdp: offer.sdp}));
} catch (e) {
console.error('Error negotiating:', e);
this.dispatchEvent(new CustomEvent('error', {
detail: { error: e }
}));
throw e;
}
}
// Stop streaming
async stopStreaming() {
await this.cleanup();
this.updateStatus('Streaming stopped.');
this.dispatchEvent(new CustomEvent('streamingStopped'));
}
// cleanup
async cleanup() {
console.log('Cleaning up...');
this.iceServers = null;
if (this.peerConnection) {
console.log('Peer Connection state:', this.peerConnection.connectionState);
await this.peerConnection.close();
this.peerConnection = null;
}
if (this.ws && this.ws.readyState === WebSocket.OPEN) {
await this.ws.close();
this.ws = null;
}
this.dispatchEvent(new CustomEvent('cleanup'));
}
setIceServerType(type) {
this.iceServerType = type;
console.log('ICE server type changed to:', this.iceServerType);
this.dispatchEvent(new CustomEvent('iceServerTypeChanged', {
detail: { type }
}));
}
// Generate a short, URL-safe UUID
generateShortUUID() {
const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_';
let result = '';
// Generate 22 characters (similar to short-uuid's default length)
for (let i = 0; i < 22; i++) {
result += chars.charAt(Math.floor(Math.random() * chars.length));
}
return result;
}
}

File: 07_web_endpoints/webrtc/frontend/webcam_webrtc.js

@@ -0,0 +1,119 @@
import { ModalWebRtcClient } from './modal_webrtc.js';
// Add status display element
const statusDisplay = document.getElementById('statusDisplay');
const MAX_STATUS_HISTORY = 100;
let statusHistory = [];
// DOM elements
const localVideo = document.getElementById('localVideo');
const remoteVideo = document.getElementById('remoteVideo');
const startWebcamButton = document.getElementById('startWebcamButton');
const startStreamingButton = document.getElementById('startStreamingButton');
const stopStreamingButton = document.getElementById('stopStreamingButton');
// Initialize WebRTC client
const webrtcClient = new ModalWebRtcClient();
// Set up event listeners
webrtcClient.addEventListener('status', (event) => {
// Add timestamp to message
const now = new Date();
const timestamp = now.toLocaleTimeString();
const statusLine = `[${timestamp}] ${event.detail.message}`;
// Add to history
statusHistory.push(statusLine);
// Keep only last MAX_STATUS_HISTORY messages
if (statusHistory.length > MAX_STATUS_HISTORY) {
statusHistory.shift();
}
// Update display
statusDisplay.innerHTML = statusHistory.map(line =>
`<div class="status-line">${line}</div>`
).join('');
// Scroll to bottom
statusDisplay.scrollTop = statusDisplay.scrollHeight;
});
webrtcClient.addEventListener('localStream', (event) => {
localVideo.srcObject = event.detail.stream;
});
webrtcClient.addEventListener('remoteStream', (event) => {
remoteVideo.srcObject = event.detail.stream;
});
webrtcClient.addEventListener('error', (event) => {
console.error('WebRTC error:', event.detail.error);
});
webrtcClient.addEventListener('connectionStateChange', (event) => {
if (event.detail.state === 'connected') {
startStreamingButton.disabled = true;
stopStreamingButton.disabled = false;
}
});
webrtcClient.addEventListener('streamingStopped', () => {
stopStreamingButton.disabled = true;
startStreamingButton.disabled = false;
remoteVideo.srcObject = null;
});
// Initialize button states
startWebcamButton.disabled = false;
startStreamingButton.disabled = true;
stopStreamingButton.disabled = true;
// Event handlers
async function handleStartWebcam() {
try {
await webrtcClient.startWebcam();
startWebcamButton.disabled = true;
startStreamingButton.disabled = false;
} catch (err) {
console.error('Error starting webcam:', err);
}
}
async function handleStartStreaming() {
startWebcamButton.disabled = true;
startStreamingButton.disabled = true;
stopStreamingButton.disabled = false;
await webrtcClient.startStreaming();
}
async function handleStopStreaming() {
await webrtcClient.stopStreaming();
}
// Add event listener for STUN/TURN server radio buttons
document.querySelectorAll('input[name="iceServer"]').forEach(radio => {
radio.addEventListener('change', (e) => {
webrtcClient.setIceServerType(e.target.value);
});
});
// Event listeners
startWebcamButton.addEventListener('click', handleStartWebcam);
startStreamingButton.addEventListener('click', handleStartStreaming);
stopStreamingButton.addEventListener('click', handleStopStreaming);
// Add cleanup handler for when browser tab is closed
window.addEventListener('beforeunload', async () => {
await webrtcClient.cleanup();
// ensure stun/turn radio and iceServerType are reset
document.querySelectorAll('input[name="iceServer"]').forEach(radio => {
if (radio.value == "turn") {
radio.checked = false;
} else {
radio.checked = true;
}
});
webrtcClient.setIceServerType('stun');
});

File: 07_web_endpoints/webrtc/modal_webrtc.py

@@ -0,0 +1,322 @@
# ---
# lambda-test: false # auxiliary-file
# ---
import asyncio
import json
from abc import ABC, abstractmethod
from typing import Optional
import modal
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketState
class ModalWebRtcPeer(ABC):
"""
Base class for implementing WebRTC peer connections in Modal using aiortc.
Implement using the `app.cls` decorator.
This class provides a complete WebRTC peer implementation that handles:
- Peer connection lifecycle management (creation, negotiation, cleanup)
- Signaling via Modal Queue for SDP offer/answer exchange and ICE candidate handling
- Automatic STUN server configuration (defaults to Google's STUN server)
- Stream setup and management
Required methods to override:
- setup_streams(): Implementation for setting up media tracks and streams
Optional methods to override:
- initialize(): Custom initialization logic when peer is created
- run_streams(): Implementation for stream runtime logic
- get_turn_servers(): Implementation to provide custom TURN server configuration
- exit(): Custom cleanup logic when peer is shutting down
The peer connection is established through a ModalWebRtcSignalingServer that manages
the signaling process between this peer and client peers.
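Example (a hypothetical minimal subclass that echoes video back to the client;
decorate it with your own `app.cls`):

    @app.cls(image=image)
    class EchoPeer(ModalWebRtcPeer):
        async def setup_streams(self, peer_id):
            @self.pcs[peer_id].on("track")
            def on_track(track):
                # send the incoming track straight back to the source peer
                self.pcs[peer_id].addTrack(track)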
"""
@modal.enter()
async def _initialize(self):
import shortuuid
self.id = shortuuid.uuid()
self.pcs = {}
self.pending_candidates = {}
# call custom init logic
await self.initialize()
async def initialize(self):
"""Override to add custom logic when creating a peer"""
@abstractmethod
async def setup_streams(self, peer_id):
"""Override to add custom logic when creating a connection and setting up streams"""
raise NotImplementedError
async def run_streams(self, peer_id):
"""Override to add custom logic when running streams"""
async def get_turn_servers(self, peer_id=None, msg=None) -> Optional[dict]:
"""Override to customize TURN servers"""
async def _setup_peer_connection(self, peer_id):
"""Creates an RTC peer connection via an ICE server"""
from aiortc import RTCConfiguration, RTCIceServer, RTCPeerConnection
# aiortc automatically uses google's STUN server,
# but we can also specify our own
config = RTCConfiguration(
iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]
)
self.pcs[peer_id] = RTCPeerConnection(configuration=config)
self.pending_candidates[peer_id] = []
await self.setup_streams(peer_id)
print(f"Created peer connection and setup streams from {self.id} to {peer_id}")
@modal.method()
async def run(self, queue: modal.Queue, peer_id: str):
"""Run the RTC peer after establishing a connection by passing WebSocket messages over a Queue."""
print(f"Running modal peer instance for client peer {peer_id}...")
await self._connect_over_queue(queue, peer_id)
await self._run_streams(peer_id)
async def _connect_over_queue(self, queue, peer_id):
"""Connect this peer to another by passing messages along a Modal Queue."""
msg_handlers = { # message types we need to handle
"offer": self.handle_offer, # SDP offer
"ice_candidate": self.handle_ice_candidate, # trickled ICE candidate
"identify": self.get_identity, # identify challenge
"get_turn_servers": self.get_turn_servers, # TURN server request
}
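# e.g. a client's offer arrives over the queue as JSON like:
# {"type": "offer", "peer_id": "<client id>", "sdp": "v=0..."}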
while True:
try:
if self.pcs.get(peer_id) and (
self.pcs[peer_id].connectionState
in ["connected", "closed", "failed"]
):
await queue.put.aio("close", partition="server")
break
# read and parse websocket message passed over queue
msg = json.loads(
await asyncio.wait_for(
queue.get.aio(partition=peer_id), timeout=0.5
)
)
# dispatch the message to its handler
if handler := msg_handlers.get(msg.get("type")):
response = await handler(peer_id, msg)
else:
print(f"Unknown message type: {msg.get('type')}")
response = None
# pass the message back over the queue to the server
if response is not None:
await queue.put.aio(json.dumps(response), partition="server")
except Exception:
# nothing on the queue yet (or a transient error); poll again
continue
async def _run_streams(self, peer_id):
"""Run WebRTC streaming with a peer."""
print(f"Modal peer {self.id} running streams for {peer_id}...")
await self.run_streams(peer_id)
# run until connection is closed or broken
while self.pcs[peer_id].connectionState == "connected":
await asyncio.sleep(0.1)
print(f"Modal peer {self.id} ending streaming for {peer_id}")
async def handle_offer(self, peer_id, msg):
"""Handles a peers SDP offer message by producing an SDP answer."""
from aiortc import RTCSessionDescription
print(f"Peer {self.id} handling SDP offer from {peer_id}...")
await self._setup_peer_connection(peer_id)
await self.pcs[peer_id].setRemoteDescription(
RTCSessionDescription(msg["sdp"], msg["type"])
)
answer = await self.pcs[peer_id].createAnswer()
await self.pcs[peer_id].setLocalDescription(answer)
sdp = self.pcs[peer_id].localDescription.sdp
return {"sdp": sdp, "type": answer.type, "peer_id": self.id}
async def handle_ice_candidate(self, peer_id, msg):
"""Add an ICE candidate sent by a peer."""
from aiortc import RTCIceCandidate
from aiortc.sdp import candidate_from_sdp
candidate = msg.get("candidate")
if not candidate:
raise ValueError("ICE candidate message is missing the 'candidate' field")
print(
f"Modal peer {self.id} received ice candidate from {peer_id}: {candidate['candidate_sdp']}..."
)
# parse ice candidate
ice_candidate: RTCIceCandidate = candidate_from_sdp(candidate["candidate_sdp"])
ice_candidate.sdpMid = candidate["sdpMid"]
ice_candidate.sdpMLineIndex = candidate["sdpMLineIndex"]
if not self.pcs.get(peer_id):
self.pending_candidates[peer_id].append(ice_candidate)
else:
# flush any candidates that arrived before the peer connection was created
for c in self.pending_candidates[peer_id]:
    await self.pcs[peer_id].addIceCandidate(c)
self.pending_candidates[peer_id] = []
await self.pcs[peer_id].addIceCandidate(ice_candidate)
async def get_identity(self, peer_id=None, msg=None):
"""Reply to an identify message with own id."""
return {"type": "identify", "peer_id": self.id}
async def generate_offer(self, peer_id):
print(f"Peer {self.id} generating offer for {peer_id}...")
await self._setup_peer_connection(peer_id)
offer = await self.pcs[peer_id].createOffer()
await self.pcs[peer_id].setLocalDescription(offer)
sdp = self.pcs[peer_id].localDescription.sdp
return {"sdp": sdp, "type": offer.type, "peer_id": self.id}
async def handle_answer(self, peer_id, answer):
from aiortc import RTCSessionDescription
print(f"Peer {self.id} handling answer from {peer_id}...")
# set remote peer description
await self.pcs[peer_id].setRemoteDescription(
RTCSessionDescription(sdp=answer["sdp"], type=answer["type"])
)
@modal.exit()
async def _exit(self):
print(f"Shutting down peer: {self.id}...")
await self.exit()
if self.pcs:
print(f"Closing peer connections for peer {self.id}...")
await asyncio.gather(*[pc.close() for pc in self.pcs.values()])
self.pcs = {}
async def exit(self):
"""Override with any custom logic when shutting down container."""
class ModalWebRtcSignalingServer:
"""
WebRTC signaling server implementation that mediates connections between client peers
and Modal-based WebRTC peers. Implement using the `app.cls` decorator.
This server:
- Provides a WebSocket endpoint (/ws/{peer_id}) for client connections
- Spawns Modal-based peer instances for each client connection
- Handles the WebRTC signaling process by relaying messages between clients and Modal peers
- Manages the lifecycle of Modal peer instances
To use this class:
1. Create a subclass implementing get_modal_peer_class() to return your ModalWebRtcPeer implementation
2. Optionally override initialize() for custom server setup
3. Optionally add a frontend route to the `web_app` attribute
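Example (hypothetical, assuming a `ModalWebRtcPeer` subclass named `ObjDet`):

    @app.cls(image=image)
    class MyServer(ModalWebRtcSignalingServer):
        def get_modal_peer_class(self):
            return ObjDet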
"""
@modal.enter()
def _initialize(self):
self.web_app = FastAPI()
# handle signaling through websocket endpoint
@self.web_app.websocket("/ws/{peer_id}")
async def ws(client_websocket: WebSocket, peer_id: str):
await client_websocket.accept()
await self._mediate_negotiation(client_websocket, peer_id)
self.initialize()
def initialize(self):
pass
@abstractmethod
def get_modal_peer_class(self) -> type[ModalWebRtcPeer]:
"""
Abstract method to return the `ModalWebRtcPeer` implementation to use.
"""
raise NotImplementedError(
"Implement `get_modal_peer` to use `ModalWebRtcSignalingServer`"
)
@modal.asgi_app()
def web(self):
return self.web_app
async def _mediate_negotiation(self, websocket: WebSocket, peer_id: str):
modal_peer_class = self.get_modal_peer_class()
# compare base-class names rather than using issubclass, which may not
# hold once the class has been wrapped by Modal's decorators
if not any(
base.__name__ == "ModalWebRtcPeer" for base in modal_peer_class.__bases__
):
raise ValueError(
"Modal peer class must be an implementation of `ModalWebRtcPeer`"
)
with modal.Queue.ephemeral() as q:
print(f"Spawning modal peer instance for client peer {peer_id}...")
modal_peer = modal_peer_class()
modal_peer.run.spawn(q, peer_id)
await asyncio.gather(
relay_websocket_to_queue(websocket, q, peer_id),
relay_queue_to_websocket(websocket, q, peer_id),
)
async def relay_websocket_to_queue(websocket: WebSocket, q: modal.Queue, peer_id: str):
while True:
try:
# read a message from the client's websocket and put it on the queue
msg = await asyncio.wait_for(websocket.receive_text(), timeout=0.5)
await q.put.aio(msg, partition=peer_id)
except Exception:
if WebSocketState.DISCONNECTED in [
websocket.application_state,
websocket.client_state,
]:
return
async def relay_queue_to_websocket(websocket: WebSocket, q: modal.Queue, peer_id: str):
while True:
try:
# take the modal peer's message off the queue's server partition and forward it
modal_peer_msg = await asyncio.wait_for(
q.get.aio(partition="server"), timeout=0.5
)
if modal_peer_msg.startswith("close"):
await websocket.close()
return
await websocket.send_text(modal_peer_msg)
except Exception:
if WebSocketState.DISCONNECTED in [
websocket.application_state,
websocket.client_state,
]:
return

File: 07_web_endpoints/webrtc/webrtc_yolo.py

@@ -0,0 +1,401 @@
# ---
# cmd: ["modal", "serve", "-m", "07_web_endpoints.webrtc.webrtc_yolo"]
# deploy: true
# ---
# # Real-time object detection with WebRTC and YOLO
# This example demonstrates how to architect a serverless real-time streaming application with Modal and WebRTC.
# The sample application detects objects in webcam video with YOLO.
# See the clip below from a live demo of this example in a course by [Kwindla Kramer](https://machine-theory.com/), WebRTC OG and co-founder of [Daily](https://www.daily.co/).
# <center>
# <video controls autoplay muted>
# <source src="https://modal-cdn.com/example-webrtc_yolo.mp4" type="video/mp4">
# </video>
# </center>
# You can also try our deployment [here](https://modal-labs-examples--example-webrtc-yolo-webcamobjdet-web.modal.run).
# ## What is WebRTC?
# WebRTC (Web Real-Time Communication) is an [IETF Internet protocol](https://www.rfc-editor.org/rfc/rfc8825) and a [W3C API specification](https://www.w3.org/TR/webrtc/) for real-time media streaming between peers
# over internets or the World Wide Web.
# What makes it so effective and different from other bidirectional web-based communication protocols (e.g. WebSockets) is that it is purpose-built for media streaming in real time.
# It's primarily designed for browser applications using the JavaScript API, but [APIs exist for other languages](https://www.webrtc-developers.com/did-i-choose-the-right-webrtc-stack/).
# We'll build our app using Python's [`aiortc`](https://aiortc.readthedocs.io/en/latest/) package.
# ### What makes up a WebRTC application?
# A simple WebRTC app generally consists of three players:
# 1. a peer that initiates the connection,
# 2. a peer that responds to the connection, and
# 3. a server that passes some initial messages between the two peers.
# First, one peer initiates the connection by offering up a description of itself - its media sources, codec capabilities, Internet Protocol (IP) addressing info, etc. - which is relayed to another peer through the server.
# The other peer then either accepts the offer by providing a compatible description of its own capabilities or rejects it if no compatible configuration is possible.
# This process is called "signaling" or sometimes the "negotiation" in the WebRTC world, and the server that mediates it is usually called the "signaling server".
# Once the peers have agreed on a configuration there's a brief pause to establish communication... and then you're live.
# ![Basic WebRTC architecture](https://modal-cdn.com/cdnbot/just_webrtc-1oic3iems_a4a8e77c.webp)
# <small>A basic WebRTC app architecture</small>
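# To make the offer/answer exchange concrete, here's a minimal sketch of both sides in
# `aiortc` (hypothetical names; it assumes each peer has already added media with `addTrack`):
#
# ```python
# from aiortc import RTCPeerConnection, RTCSessionDescription
#
# # initiating peer: describe our side and create an offer
# pc_a = RTCPeerConnection()
# offer = await pc_a.createOffer()
# await pc_a.setLocalDescription(offer)
# # ... relay pc_a.localDescription.sdp to the other peer via the signaling server ...
#
# # responding peer: accept the offer and reply with a compatible answer
# pc_b = RTCPeerConnection()
# await pc_b.setRemoteDescription(RTCSessionDescription(sdp=offer_sdp, type="offer"))
# answer = await pc_b.createAnswer()
# await pc_b.setLocalDescription(answer)
# # ... relay the answer back so pc_a can call setRemoteDescription ...
# ```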
# Obviously there's more going on under the hood.
# If you want to get into the details, we recommend checking out the [RFCs](https://www.rfc-editor.org/rfc/rfc8825) or a [more-thorough explainer](https://webrtcforthecurious.com/).
# In this document, we'll focus on how to architect a WebRTC application where one or more peers are running on Modal's serverless cloud infrastructure.
# If you just want to quickly get started with WebRTC for a small internal service or a hack project, check out
# [our FastRTC example](https://modal.com/docs/examples/fastrtc_flip_webcam) instead.
# ## How do I run a WebRTC app on Modal?
# Modal turns Python code into scalable cloud services.
# When you call a Modal Function, you get one replica.
# If you call it 999 more times before it returns, you have 1000 replicas.
# When your Functions all return, you spin down to 0 replicas.
# The core constraints of the Modal programming model that make this possible are that Function Calls are stateless and self-contained.
# In other words, correctly-written Modal Functions don't store information in memory between runs (though they might cache data to the ephemeral local disk for efficiency), and they don't need to create processes or tasks that must keep running after the Function Call returns for the application to be correct.
# WebRTC apps, on the other hand, require passing messages back and forth in a multi-step protocol, and WebRTC APIs spawn several "agents" (no, AI is not involved, just processes) which do work behind the scenes - including managing the peer-to-peer (P2P) connection itself.
# This means that streaming may have only just begun when the application logic in our Function has finished.
# ![Modal programming model and WebRTC signaling](https://modal-cdn.com/cdnbot/flow_comparisong6iibzq3_638bdd84.webp)
# <small>Modal's stateless programming model (left) and WebRTC's stateful signaling (right)</small>
# To ensure we properly leverage Modal's autoscaling and concurrency features, we need to align the signaling and streaming lifetimes with Modal Function Call lifetimes.
# The architecture we recommend for this appears below.
# ![WebRTC on Modal](https://modal-cdn.com/cdnbot/webrtc_with_modal-2horb680q_eab69b28.webp)
# <small>A clean architecture for WebRTC on Modal</small>
# It handles passing messages between the client peer and the signaling server using a
# [WebSocket](https://modal.com/docs/guide/webhooks#websockets) for persistent, bidirectional communication over the Web within a single Function Call.
# (Modal's Web layer maps HTTP and WS onto Function Calls, details [here](https://modal.com/blog/serverless-http)).
# We [`.spawn`](https://modal.com/docs/reference/modal.Function#spawn) the cloud peer inside the WebSocket endpoint
# and communicate with it using a [`modal.Queue`](https://modal.com/docs/reference/modal.Queue).
# We can then use the state of the P2P connection to determine when to return from the calls to both the signaling server and the cloud peer.
# When the P2P connection has been _established_, we'll close the WebSocket, which in turn ends the call to the signaling server.
# And when the P2P connection has been _closed_, we'll return from the call to the cloud peer.
# That way, our WebRTC application benefits from all the autoscaling and concurrency logic built into Modal
# that enables users to deliver efficient cloud applications.
# We wrote two classes, `ModalWebRtcPeer` and `ModalWebRtcSignalingServer`, to abstract away that boilerplate as well as a lot of the `aiortc` implementation details.
# They're also decorated with Modal [lifetime hooks](https://modal.com/docs/guide/lifecycle-functions).
# Add the [`app.cls`](https://modal.com/docs/reference/modal.App#cls) decorator and some custom logic, and you're ready to deploy on Modal.
# You can find them in the [`modal_webrtc.py` file](https://github.com/modal-labs/modal-examples/blob/main/07_web_endpoints/webrtc/modal_webrtc.py) provided alongside this example in the [GitHub repo](https://github.com/modal-labs/modal-examples/tree/main/07_web_endpoints/webrtc/modal_webrtc.py).
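# The heart of that pattern is a WebSocket endpoint that spawns the cloud peer and then relays
# signaling messages in both directions until the connection is live - condensed here from the
# implementation in `modal_webrtc.py`:
#
# ```python
# @web_app.websocket("/ws/{peer_id}")
# async def ws(websocket: WebSocket, peer_id: str):
#     await websocket.accept()
#     with modal.Queue.ephemeral() as q:
#         modal_peer_class().run.spawn(q, peer_id)  # start the cloud peer
#         await asyncio.gather(  # relay messages until the P2P connection is up
#             relay_websocket_to_queue(websocket, q, peer_id),
#             relay_queue_to_websocket(websocket, q, peer_id),
#         )
# ```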
# ## Using `modal_webrtc` to detect objects in webcam footage
# For our WebRTC app, we'll take a client's video stream, run a [YOLO](https://docs.ultralytics.com/tasks/detect/) object detector on it with an A100 GPU on Modal, and then stream the annotated video back to the client. Let's get started!
# With this setup, we can achieve inference times of 2 to 4 milliseconds per frame and round trip times below the duration of a single video frame (usually around 30 milliseconds).
# ### Setup
# We'll start with a simple container [Image](https://modal.com/docs/guide/images) and then
# - set it up to properly use TensorRT and the ONNX Runtime, which keep latency minimal,
# - install the necessary libs for processing video, `opencv` and `ffmpeg`, and
# - install the necessary Python packages.
import os
from pathlib import Path
import modal
from .modal_webrtc import ModalWebRtcPeer, ModalWebRtcSignalingServer
py_version = "3.12"
tensorrt_ld_path = f"/usr/local/lib/python{py_version}/site-packages/tensorrt_libs"
video_processing_image = (
modal.Image.debian_slim(python_version=py_version) # matching ld path
# update locale as required by onnx
.apt_install("locales")
.run_commands(
"sed -i '/^#\\s*en_US.UTF-8 UTF-8/ s/^#//' /etc/locale.gen", # use sed to uncomment
"locale-gen en_US.UTF-8", # set locale
"update-locale LANG=en_US.UTF-8",
)
.env({"LD_LIBRARY_PATH": tensorrt_ld_path, "LANG": "en_US.UTF-8"})
# install system dependencies
.apt_install("python3-opencv", "ffmpeg")
# install Python dependencies
.pip_install(
"aiortc==1.11.0",
"fastapi==0.115.12",
"huggingface-hub[hf_xet]==0.30.2",
"onnxruntime-gpu==1.21.0",
"opencv-python==4.11.0.86",
"tensorrt==10.9.0.34",
"torch==2.7.0",
"shortuuid==1.0.13",
)
)
# ### Cache weights and compute graphs on a Volume
# We also need to create a Modal [Volume](https://modal.com/docs/guide/volumes) to store things we need across replicas --
# primarily the model weights and ONNX inference graph, but also a few other artifacts like a video file where
# we'll write out the processed video stream for testing.
# The very first time we run the app, downloading the model and building the ONNX inference graph will take a few minutes.
# After that, we can load the cached weights and graph from the Volume, which reduces the startup time to about 15 seconds per container.
CACHE_VOLUME = modal.Volume.from_name("webrtc-yolo-cache", create_if_missing=True)
CACHE_PATH = Path("/cache")
cache = {CACHE_PATH: CACHE_VOLUME}
app = modal.App("example-webrtc-yolo")
# ### Implement YOLO object detection as a `ModalWebRtcPeer`
# Our application needs to process an incoming video track with YOLO and return an annotated video track to the source peer.
# To implement a `ModalWebRtcPeer`, we need to:
# - Decorate our subclass with `@app.cls`. We provision it with an A100 GPU and a [Secret](https://modal.com/docs/guide/secrets) credential, described below.
# - Implement the method `setup_streams`. This is where we'll use `aiortc` to add the logic for processing the incoming video track with YOLO
# and returning an annotated video track to the source peer.
# `ModalWebRtcPeer` has a few other methods that users can optionally implement:
# - `initialize()`: This contains any custom initialization logic, called when `@modal.enter()` is called.
# - `run_streams()`: Logic for starting streams. This is necessary when the peer is the source of the stream.
# This is where you'd ensure a webcam was running, start playing a video file, or spin up a [video generative model](https://modal.com/docs/examples/image_to_video).
# - `get_turn_servers()`: We haven't talked about [TURN servers](https://datatracker.ietf.org/doc/html/rfc5766),
# but just know that they're necessary if you want to use WebRTC across complex (e.g. carrier-grade) NAT or firewall configurations.
# Free services have tight limits because TURN servers are expensive to run (lots of bandwidth and state management required).
# [STUN](https://datatracker.ietf.org/doc/html/rfc5389) servers, on the other hand, are essentially just echo servers, and so there are many free services available.
# If you don't provide TURN servers you can still serve your app on many networks using any of a number of free STUN servers for NAT traversal.
# - `exit()`: This contains any custom cleanup logic, called when `@modal.exit()` is called.
# In our case, we load the YOLO model in `initialize` and provide server information for the free [Open Relay TURN server](https://www.metered.ca/tools/openrelay/).
# If you want to use it, you'll need to create an account [here](https://dashboard.metered.ca/login?tool=turnserver)
# and then create a Modal [Secret](https://modal.com/docs/guide/secrets) called `turn-credentials` [here](https://modal.com/secrets).
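# (With the Modal CLI, something like `modal secret create turn-credentials TURN_USERNAME=... TURN_CREDENTIAL=...`
# creates that Secret, using the values from your Open Relay dashboard.)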
# We also use the `@modal.concurrent` decorator to allow multiple instances of our peer to run on one GPU.
# **Setting the Region**
# Much of the latency in Internet applications comes from distance between communicating parties --
# the Internet operates within a factor of two of the speed of light, but that's just not that fast.
# To minimize latency under this constraint, the physical distance of the P2P connection
# between the webcam-using peer and the GPU container needs to be kept as short as possible.
# We'll use the `region` parameter of the `cls` decorator to set the region of the GPU container.
# You should set this to the closest region to your users.
# See the [region selection](https://modal.com/docs/guide/region-selection) guide for more information.
@app.cls(
image=video_processing_image,
gpu="A100-40GB",
volumes=cache,
secrets=[modal.Secret.from_name("turn-credentials")],
region="us-east", # set to your region
)
@modal.concurrent(
target_inputs=2, # try to stick to just two peers per GPU container
max_inputs=3, # but allow up to three
)
class ObjDet(ModalWebRtcPeer):
async def initialize(self):
self.yolo_model = get_yolo_model(CACHE_PATH)
async def setup_streams(self, peer_id: str):
from aiortc import MediaStreamTrack
# keep us notified on connection state changes
@self.pcs[peer_id].on("connectionstatechange")
async def on_connectionstatechange() -> None:
if self.pcs[peer_id]:
print(
f"Video Processor, {self.id}, connection state to {peer_id}: {self.pcs[peer_id].connectionState}"
)
# when we receive a track from the source peer
# we create a processed track and add it to our stream
# back to the source peer
@self.pcs[peer_id].on("track")
def on_track(track: MediaStreamTrack) -> None:
print(
f"Video Processor, {self.id}, received {track.kind} track from {peer_id}"
)
output_track = get_yolo_track(track, self.yolo_model) # see Addenda
self.pcs[peer_id].addTrack(output_track)
# keep us notified when the incoming track ends
@track.on("ended")
async def on_ended() -> None:
print(
f"Video Processor, {self.id}, incoming video track from {peer_id} ended"
)
async def get_turn_servers(self, peer_id=None, msg=None) -> dict:
creds = {
"username": os.environ["TURN_USERNAME"],
"credential": os.environ["TURN_CREDENTIAL"],
}
turn_servers = [
{"urls": "stun:stun.relay.metered.ca:80"}, # STUN is free, no creds neeeded
# for TURN, sign up for the free service here: https://www.metered.ca/tools/openrelay/
{"urls": "turn:standard.relay.metered.ca:80"} | creds,
{"urls": "turn:standard.relay.metered.ca:80?transport=tcp"} | creds,
{"urls": "turn:standard.relay.metered.ca:443"} | creds,
{"urls": "turns:standard.relay.metered.ca:443?transport=tcp"} | creds,
]
return {"type": "turn_servers", "ice_servers": turn_servers}
# ### Implement a `SignalingServer`
# The `ModalWebRtcSignalingServer` class is much simpler to implement.
# The only thing we need to do is implement the `get_modal_peer_class` method which will return our implementation of the `ModalWebRtcPeer` class, `ObjDet`.
#
# It also has an `initialize()` method we can optionally override (called at the beginning of the [container lifecycle](https://modal.com/docs/guide/lifecycle-functions))
# as well as a `web_app` property which will be [served by Modal](https://modal.com/docs/guide/webhooks#asgi-apps---fastapi-fasthtml-starlette).
# We'll use these to add a frontend which uses the WebRTC JavaScript API to stream a peer's webcam from the browser.
#
# The JavaScript and HTML files are alongside this example in the [GitHub repo](https://github.com/modal-labs/modal-examples/tree/main/07_web_endpoints/webrtc/frontend).
base_image = (
modal.Image.debian_slim(python_version="3.12")
.apt_install("python3-opencv", "ffmpeg")
.pip_install(
"fastapi[standard]==0.115.4",
"aiortc==1.11.0",
"opencv-python==4.11.0.86",
"shortuuid==1.0.13",
)
)
this_directory = Path(__file__).parent.resolve()
server_image = base_image.add_local_dir(
this_directory / "frontend", remote_path="/frontend"
)
@app.cls(image=server_image)
class WebcamObjDet(ModalWebRtcSignalingServer):
def get_modal_peer_class(self):
return ObjDet
def initialize(self):
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
self.web_app.mount("/static", StaticFiles(directory="/frontend"))
@self.web_app.get("/")
async def root():
html = open("/frontend/index.html").read()
return HTMLResponse(content=html)
# ## Addenda
# The remainder of this page is not central to running a WebRTC application on Modal,
# but is included for completeness.
# ### YOLO helper functions
# The two functions below are used to set up the YOLO model and create our custom [`MediaStreamTrack`](https://aiortc.readthedocs.io/en/latest/api.html#aiortc.MediaStreamTrack).
# The first, `get_yolo_model`, sets up the ONNXRuntime and loads the model weights.
# We call this in the `initialize` method of the `ModalWebRtcPeer` class
# so that it only happens once per container.
def get_yolo_model(cache_path):
import onnxruntime
from .yolo import YOLOv10
onnxruntime.preload_dlls()
return YOLOv10(cache_path)
# The second, `get_yolo_track`, creates a custom `MediaStreamTrack` that performs object detection on the video stream.
# We call this in the `setup_streams` method of the `ModalWebRtcPeer` class
# so it happens once per peer connection.
def get_yolo_track(track, yolo_model=None):
import numpy as np
import onnxruntime
from aiortc import MediaStreamTrack
from aiortc.contrib.media import VideoFrame
from .yolo import YOLOv10
class YOLOTrack(MediaStreamTrack):
"""
Custom media stream track performs object detection
on the video stream and passes it back to the source peer
"""
kind: str = "video"
conf_threshold: float = 0.15
def __init__(self, track: MediaStreamTrack, yolo_model=None) -> None:
super().__init__()
self.track = track
if yolo_model is None:
onnxruntime.preload_dlls()
self.yolo_model = YOLOv10(CACHE_PATH)
else:
self.yolo_model = yolo_model
def detection(self, image: np.ndarray) -> np.ndarray:
import cv2
orig_shape = image.shape[:-1]
image = cv2.resize(
image,
(self.yolo_model.input_width, self.yolo_model.input_height),
)
image = self.yolo_model.detect_objects(image, self.conf_threshold)
image = cv2.resize(image, (orig_shape[1], orig_shape[0]))
return image
# this is the essential method we need to implement
# to create a custom MediaStreamTrack
async def recv(self) -> VideoFrame:
frame = await self.track.recv()
img = frame.to_ndarray(format="bgr24")
processed_img = self.detection(img)
# VideoFrames are from a really nice package called av
# which is a pythonic wrapper around ffmpeg
# and a dependency of aiortc
new_frame = VideoFrame.from_ndarray(processed_img, format="bgr24")
new_frame.pts = frame.pts
new_frame.time_base = frame.time_base
return new_frame
return YOLOTrack(track)
# ### Testing a WebRTC application on Modal
# As any seasoned developer of real-time applications on the Web will tell you,
# testing and ensuring correctness is quite difficult. We spent nearly as much time
# designing and troubleshooting an appropriate testing process for this application as we did writing
# the application itself!
# You can find the testing code in the GitHub repository [here](https://github.com/modal-labs/modal-examples/tree/main/07_web_endpoints/webrtc/webrtc_yolo_test.py).
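# You can run it yourself with `modal run -m 07_web_endpoints.webrtc.webrtc_yolo_test`.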

File: 07_web_endpoints/webrtc/webrtc_yolo_test.py

@@ -0,0 +1,172 @@
# ---
# cmd: ["modal", "run", "-m", "07_web_endpoints.webrtc.webrtc_yolo_test"]
# ---
import modal
from .modal_webrtc import ModalWebRtcPeer
from .webrtc_yolo import (
CACHE_PATH,
WebcamObjDet,
app,
base_image,
cache,
)
# ## Testing WebRTC and Modal
# First we define a `local_entrypoint` to run and evaluate the test.
# Our test will stream an .mp4 file to the cloud peer and record the annotated video to a new file.
# The test itself ensures that the new video is fewer than five frames shorter than the source file.
# The difference is due to frames dropped while the connection starts up.
@app.local_entrypoint()
def test():
input_frames, output_frames = TestPeer().run_video_processing_test.remote()
# allow a few dropped frames from the connection starting up
assert input_frames - output_frames < 5, (
f"Streaming failed. Frame difference: {input_frames} - {output_frames} = {input_frames - output_frames}"
)
# Because our test will require Python dependencies outside the standard library, we'll run the test itself in a container on Modal.
# In fact, this will be another `ModalWebRtcPeer` class. So the test will also demonstrate how to set up WebRTC between Modal containers.
# There are some details regarding the use of `aiortc`'s `MediaPlayer` and `MediaRecorder` classes that we won't cover here.
# Just know that these are `aiortc`-specific classes - not part of WebRTC itself.
# That said, using these classes does require us to manually `start` and `stop` streams.
# For example, we'll need to override the `run_streams` method to start the source stream, and we'll make use of the `on_ended` callback to stop the recording.
@app.cls(image=base_image, volumes=cache)
class TestPeer(ModalWebRtcPeer):
TEST_VIDEO_SOURCE_URL = "https://modal-cdn.com/cliff_jumping.mp4"
TEST_VIDEO_RECORD_FILE = CACHE_PATH / "test_video.mp4"
# extra time to run streams beyond input video duration
VIDEO_DURATION_BUFFER_SECS = 5.0
# allow time for container to spin up (can timeout with default 10)
WS_OPEN_TIMEOUT = 300 # seconds
async def initialize(self) -> None:
import cv2
# get input video duration in seconds
self.input_video = cv2.VideoCapture(self.TEST_VIDEO_SOURCE_URL)
self.input_video_duration_frames = self.input_video.get(
cv2.CAP_PROP_FRAME_COUNT
)
self.input_video_duration_seconds = (
self.input_video_duration_frames / self.input_video.get(cv2.CAP_PROP_FPS)
)
self.input_video.release()
# set streaming duration to input video duration plus a buffer
self.stream_duration = (
self.input_video_duration_seconds + self.VIDEO_DURATION_BUFFER_SECS
)
self.player = None # video stream source
self.recorder = None # processed video stream sink
async def setup_streams(self, peer_id: str) -> None:
import os
from aiortc import MediaStreamTrack
from aiortc.contrib.media import MediaPlayer, MediaRecorder
# set up the video player and add its track to the peer connection
self.video_src = MediaPlayer(self.TEST_VIDEO_SOURCE_URL)
self.pcs[peer_id].addTrack(self.video_src.video)
# setup video recorder
if os.path.exists(self.TEST_VIDEO_RECORD_FILE):
os.remove(self.TEST_VIDEO_RECORD_FILE)
self.recorder = MediaRecorder(self.TEST_VIDEO_RECORD_FILE)
# keep us notified on connection state changes
@self.pcs[peer_id].on("connectionstatechange")
async def on_connectionstatechange() -> None:
print(
f"Video Tester connection state updated: {self.pcs[peer_id].connectionState}"
)
# when we receive a track back from
# the video processing peer we record it
# to the output file
@self.pcs[peer_id].on("track")
def on_track(track: MediaStreamTrack) -> None:
print(f"Video Tester received {track.kind} track from {peer_id}")
# record track to file
self.recorder.addTrack(track)
@track.on("ended")
async def on_ended() -> None:
print("Video Tester's processed video stream ended")
# stop recording when incoming track ends to finish writing video
await self.recorder.stop()
# reset recorder and player
self.recorder = None
self.video_src = None
async def run_streams(self, peer_id: str) -> None:
import asyncio
print(f"Video Tester running streams for {peer_id}...")
# MediaRecorders need to be started manually
# but in most cases the track is already streaming
await self.recorder.start()
# run until sufficient time has passed
await asyncio.sleep(self.stream_duration)
# close peer connection manually
await self.pcs[peer_id].close()
def count_frames(self):
import cv2
# compare output video length to input video length
output_video = cv2.VideoCapture(self.TEST_VIDEO_RECORD_FILE)
output_video_duration_frames = int(output_video.get(cv2.CAP_PROP_FRAME_COUNT))
output_video.release()
return self.input_video_duration_frames, output_video_duration_frames
@modal.method()
async def run_video_processing_test(self) -> tuple:
import json
import websockets
peer_id = None
# connect to server via websocket
ws_uri = (
WebcamObjDet().web.get_web_url().replace("http", "ws") + f"/ws/{self.id}"
)
async with websockets.connect(
ws_uri, open_timeout=self.WS_OPEN_TIMEOUT
) as websocket:
await websocket.send(json.dumps({"type": "identify", "peer_id": self.id}))
peer_id = json.loads(await websocket.recv())["peer_id"]
offer_msg = await self.generate_offer(peer_id)
await websocket.send(json.dumps(offer_msg))
try:
# receive answer
answer = json.loads(await websocket.recv())
if answer.get("type") == "answer":
await self.handle_answer(peer_id, answer)
except websockets.exceptions.ConnectionClosed:
await websocket.close()
# loop until video player is finished
if peer_id:
await self.run_streams(peer_id)
return self.count_frames()

File: 07_web_endpoints/webrtc/yolo/__init__.py

@@ -0,0 +1 @@
from .yolo import YOLOv10 as YOLOv10

File: 07_web_endpoints/webrtc/yolo/yolo.py

@@ -0,0 +1,212 @@
# ---
# lambda-test: false
# ---
from pathlib import Path
import cv2
import numpy as np
import onnxruntime
this_dir = Path(__file__).parent.resolve()
class YOLOv10:
def __init__(self, cache_dir):
from huggingface_hub import hf_hub_download
# Initialize model
self.cache_dir = cache_dir
print(f"Initializing YOLO model from {self.cache_dir}")
model_file = hf_hub_download(
repo_id="onnx-community/yolov10n",
filename="onnx/model.onnx",
cache_dir=self.cache_dir,
)
self.initialize_model(model_file)
print("YOLO model initialized")
def initialize_model(self, model_file):
self.session = onnxruntime.InferenceSession(
model_file,
providers=[
(
"TensorrtExecutionProvider",
{
"trt_engine_cache_enable": True,
"trt_engine_cache_path": self.cache_dir / "onnx.cache",
},
),
"CUDAExecutionProvider",
],
)
# Get model info
self.get_input_details()
self.get_output_details()
# get class names
with open(this_dir / "yolo_classes.txt", "r") as f:
self.class_names = f.read().splitlines()
rng = np.random.default_rng(3)
self.colors = rng.uniform(0, 255, size=(len(self.class_names), 3))
def detect_objects(self, image, conf_threshold=0.3):
input_tensor = self.prepare_input(image)
# Perform inference on the image
new_image = self.inference(image, input_tensor, conf_threshold)
return new_image
def prepare_input(self, image):
self.img_height, self.img_width = image.shape[:2]
input_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Resize input image
input_img = cv2.resize(input_img, (self.input_width, self.input_height))
# Scale input pixel values to 0 to 1
input_img = input_img / 255.0
input_img = input_img.transpose(2, 0, 1)
input_tensor = input_img[np.newaxis, :, :, :].astype(np.float32)
return input_tensor
def inference(self, image, input_tensor, conf_threshold=0.3):
# set seed to potentially create smoother output in RT setting
onnxruntime.set_seed(42)
# start = time.perf_counter()
outputs = self.session.run(
self.output_names, {self.input_names[0]: input_tensor}
)
# print(f"Inference time: {(time.perf_counter() - start) * 1000:.2f} ms")
(
boxes,
scores,
class_ids,
) = self.process_output(outputs, conf_threshold)
return self.draw_detections(image, boxes, scores, class_ids)
def process_output(self, output, conf_threshold=0.3):
predictions = np.squeeze(output[0])
# Filter out object confidence scores below threshold
scores = predictions[:, 4]
predictions = predictions[scores > conf_threshold, :]
scores = scores[scores > conf_threshold]
if len(scores) == 0:
return [], [], []
# Get the class with the highest confidence
class_ids = predictions[:, 5].astype(int)
# Get bounding boxes for each object
boxes = self.extract_boxes(predictions)
return boxes, scores, class_ids
def extract_boxes(self, predictions):
# Extract boxes from predictions
boxes = predictions[:, :4]
# Scale boxes to original image dimensions
boxes = self.rescale_boxes(boxes)
# Convert boxes to xyxy format
# boxes = xywh2xyxy(boxes)
return boxes
def rescale_boxes(self, boxes):
# Rescale boxes to original image dimensions
input_shape = np.array(
[
self.input_width,
self.input_height,
self.input_width,
self.input_height,
]
)
boxes = np.divide(boxes, input_shape, dtype=np.float32)
boxes *= np.array(
[self.img_width, self.img_height, self.img_width, self.img_height]
)
return boxes
def draw_detections(
self, image, boxes, scores, class_ids, draw_scores=True, mask_alpha=0.4
):
det_img = image.copy()
img_height, img_width = image.shape[:2]
font_size = min([img_height, img_width]) * 0.0012
text_thickness = int(min([img_height, img_width]) * 0.004)
# Draw bounding boxes and labels of detections
for class_id, box, score in zip(class_ids, boxes, scores):
color = self.colors[class_id]
self.draw_box(det_img, box, color) # type: ignore
label = self.class_names[class_id]
caption = f"{label} {int(score * 100)}%"
self.draw_text(det_img, caption, box, color, font_size, text_thickness) # type: ignore
return det_img
def get_input_details(self):
model_inputs = self.session.get_inputs()
self.input_names = [model_inputs[i].name for i in range(len(model_inputs))]
self.input_shape = model_inputs[0].shape
self.input_height = self.input_shape[2]
self.input_width = self.input_shape[3]
def get_output_details(self):
model_outputs = self.session.get_outputs()
self.output_names = [model_outputs[i].name for i in range(len(model_outputs))]
def draw_box(
self,
image: np.ndarray,
box: np.ndarray,
color: tuple[int, int, int] = (0, 0, 255),
thickness: int = 5,
) -> np.ndarray:
x1, y1, x2, y2 = box.astype(int)
return cv2.rectangle(image, (x1, y1), (x2, y2), color, thickness)
def draw_text(
self,
image: np.ndarray,
text: str,
box: np.ndarray,
color: tuple[int, int, int] = (0, 0, 255),
font_size: float = 0.100,
text_thickness: int = 5,
box_thickness: int = 5,
) -> np.ndarray:
x1, y1, x2, y2 = box.astype(int)
(tw, th), _ = cv2.getTextSize(
text=text,
fontFace=cv2.FONT_HERSHEY_SIMPLEX,
fontScale=font_size,
thickness=text_thickness,
)
x1 = x1 - box_thickness
th = int(th * 1.2)
cv2.rectangle(image, (x1, y1), (x1 + tw, y1 - th), color, -1)
return cv2.putText(
image,
text,
(x1, y1),
cv2.FONT_HERSHEY_SIMPLEX,
font_size,
(255, 255, 255),
text_thickness,
cv2.LINE_AA,
)

File: 07_web_endpoints/webrtc/yolo/yolo_classes.txt

@@ -0,0 +1,80 @@
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

File: (unidentified)

@@ -168,4 +168,4 @@ def get_examples_json():
if __name__ == "__main__":
for example in get_examples():
print(example.json())
print(example.model_dump_json())

File: pyproject.toml

@@ -4,7 +4,7 @@ filterwarnings = [
"error::modal.exception.DeprecationError",
"ignore::DeprecationWarning:pytest.*:",
]
addopts = "--ignore 06_gpu_and_ml/llm-serving/openai_compatible/load_test.py --ignore 07_web_endpoints/fasthtml-checkboxes/cbx_load_test.py"
addopts = "--ignore 07_web_endpoints/webrtc/webrtc_yolo_test.py --ignore 06_gpu_and_ml/llm-serving/openai_compatible/load_test.py --ignore 07_web_endpoints/fasthtml-checkboxes/cbx_load_test.py"
[tool.mypy]
ignore_missing_imports = true