wifi_densepose_worldmodel/lib.rs
1//! `wifi-densepose-worldmodel` — OccWorld thin-client bridge (ADR-147).
2//!
3//! Bridges [`wifi_densepose_worldgraph`] `PersonTrack` history to the OccWorld
4//! Python inference subprocess and returns [`TrajectoryPrior`]s that can be
5//! injected into the Kalman pose tracker.
6//!
7//! ## Quick start
8//! ```rust,no_run
9//! use wifi_densepose_worldmodel::{
10//! OccWorldBridge, OccupancyWorldModelRequest, OccupancyGrid3D,
11//! SceneBoundsJson, worldgraph_to_occupancy,
12//! };
13//! use wifi_densepose_worldmodel::occupancy::{PersonPosition, SceneBounds};
14//!
15//! # async fn example() -> Result<(), wifi_densepose_worldmodel::WorldModelError> {
16//! let bridge = OccWorldBridge::new("/tmp/occworld.sock");
17//!
18//! let bounds = SceneBounds { min_e: -10.0, min_n: -10.0, max_e: 10.0, max_n: 10.0 };
19//! let persons = vec![
20//! PersonPosition { track_id: 1, east_m: 2.0, north_m: 3.0, up_m: 1.0 },
21//! ];
22//! let frame = worldgraph_to_occupancy(&persons, &bounds, 0.1);
23//!
24//! let request = OccupancyWorldModelRequest {
25//! past_frames: vec![frame],
26//! voxel_resolution_m: 0.1,
27//! scene_bounds: SceneBoundsJson {
28//! min_e: bounds.min_e, min_n: bounds.min_n,
29//! max_e: bounds.max_e, max_n: bounds.max_n,
30//! },
31//! prediction_steps: 15,
32//! };
33//!
34//! let response = bridge.predict(request).await?;
35//! println!("confidence={:.2}", response.confidence);
36//! for prior in &response.trajectory_priors {
37//! println!("track {} has {} waypoints", prior.track_id, prior.waypoints.len());
38//! }
39//! # Ok(())
40//! # }
41//! ```
42
43pub mod bridge;
44pub mod error;
45pub mod occupancy;
46
47// Re-export the bridge type at the crate root for convenience.
48pub use bridge::{default_socket_path, OccWorldBridge};
49pub use error::WorldModelError;
50pub use occupancy::worldgraph_to_occupancy;
51
52use serde::{Deserialize, Serialize};
53
54// ---------------------------------------------------------------------------
55// Voxel grid
56// ---------------------------------------------------------------------------
57
58/// A 3-D occupancy grid whose voxel values are class indices (u8).
59///
60/// Layout: `voxels[z * height * width + y * width + x]` (row-major, depth last).
61/// The grid is always `200 × 200 × 16` when produced by
62/// [`worldgraph_to_occupancy`].
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct OccupancyGrid3D {
65 /// Number of voxels along the east/x axis.
66 pub width: u32,
67 /// Number of voxels along the north/y axis.
68 pub height: u32,
69 /// Number of voxels along the up/z axis.
70 pub depth: u32,
71 /// Flat class-index array, length `width * height * depth`.
72 pub voxels: Vec<u8>,
73}
74
75// ---------------------------------------------------------------------------
76// Trajectory types
77// ---------------------------------------------------------------------------
78
79/// A single point on a predicted trajectory, with a relative timestamp.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct TrajectoryWaypoint {
82 /// East offset from installation origin, in metres.
83 pub e: f64,
84 /// North offset from installation origin, in metres.
85 pub n: f64,
86 /// Up offset (height), in metres.
87 pub u: f64,
88 /// Time offset from "now", in seconds (positive = future).
89 pub t_s: f32,
90}
91
92/// Predicted future trajectory for one tracked person.
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct TrajectoryPrior {
95 /// Stable track identifier (mirrors `WorldNode::PersonTrack::track_id`).
96 pub track_id: u64,
97 /// Ordered sequence of predicted future waypoints.
98 pub waypoints: Vec<TrajectoryWaypoint>,
99}
100
101// ---------------------------------------------------------------------------
102// Scene bounds (JSON wire shape)
103// ---------------------------------------------------------------------------
104
105/// Axis-aligned scene footprint sent to the OccWorld server in the IPC
106/// request. Mirrors [`occupancy::SceneBounds`] but derives `Serialize` /
107/// `Deserialize` for direct inclusion in the JSON payload.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct SceneBoundsJson {
110 /// Western (minimum east) edge of the scene, in metres.
111 pub min_e: f64,
112 /// Southern (minimum north) edge of the scene, in metres.
113 pub min_n: f64,
114 /// Eastern (maximum east) edge of the scene, in metres.
115 pub max_e: f64,
116 /// Northern (maximum north) edge of the scene, in metres.
117 pub max_n: f64,
118}
119
120// ---------------------------------------------------------------------------
121// IPC request / response
122// ---------------------------------------------------------------------------
123
124/// JSON request sent from the Rust bridge to the OccWorld Python server.
125///
126/// Serialised as a single newline-terminated JSON object over the Unix socket.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct OccupancyWorldModelRequest {
129 /// History of occupancy grids (chronological, oldest first).
130 /// OccWorld expects at least one frame; the reference implementation uses
131 /// the most recent 4 frames for temporal context.
132 pub past_frames: Vec<OccupancyGrid3D>,
133
134 /// Physical size of one voxel cell on the ground plane, in metres.
135 pub voxel_resolution_m: f32,
136
137 /// Scene footprint used to build the occupancy grid.
138 pub scene_bounds: SceneBoundsJson,
139
140 /// Number of future time steps to predict (reference: 15 × 0.1 s = 1.5 s).
141 pub prediction_steps: u32,
142}
143
144/// JSON response returned by the OccWorld Python server.
145///
146/// Decoded from a single newline-terminated JSON object on the Unix socket.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct OccupancyWorldModelResponse {
149 /// Predicted future occupancy grids (chronological, `prediction_steps`
150 /// frames in total).
151 pub future_frames: Vec<OccupancyGrid3D>,
152
153 /// Per-person predicted trajectories extracted from `future_frames`.
154 pub trajectory_priors: Vec<TrajectoryPrior>,
155
156 /// Aggregate confidence score in `[0, 1]` for the entire prediction.
157 pub confidence: f32,
158
159 /// Identifier of the model that produced this response.
160 /// The sentinel prefix `"error:vram:"` signals a VRAM error (see ADR-147).
161 pub model_id: String,
162
163 /// Wall-clock time the Python server spent on inference, in milliseconds.
164 pub inference_ms: u64,
165}
166
167// ---------------------------------------------------------------------------
168// WorldGraph helper — extract PersonPosition list from a WorldGraph snapshot
169// ---------------------------------------------------------------------------
170
171use wifi_densepose_worldgraph::WorldGraph;
172
173use crate::occupancy::PersonPosition;
174
175/// Extracts all [`PersonPosition`]s from a [`WorldGraph`] by serialising the
176/// graph to its canonical JSON form (via [`WorldGraph::to_json`]) and scanning
177/// the `nodes` array for `PersonTrack` entries.
178///
179/// This avoids coupling to the private fields of `WorldGraphSnapshot`.
180/// The returned positions are unsorted; callers may sort by `track_id` if
181/// deterministic ordering is required.
182///
183/// # Panics
184/// Does not panic — if serialisation fails the function returns an empty
185/// `Vec` and logs a warning via `eprintln!`. In practice, serialisation of a
186/// valid `WorldGraph` should never fail.
187pub fn persons_from_worldgraph(graph: &WorldGraph) -> Vec<PersonPosition> {
188 let bytes = match graph.to_json() {
189 Ok(b) => b,
190 Err(e) => {
191 eprintln!("[worldmodel] WorldGraph::to_json failed: {e}");
192 return Vec::new();
193 }
194 };
195
196 // Parse as a raw JSON value to avoid depending on the exact shape of the
197 // private `WorldGraphSnapshot` struct fields.
198 let value: serde_json::Value = match serde_json::from_slice(&bytes) {
199 Ok(v) => v,
200 Err(e) => {
201 eprintln!("[worldmodel] failed to parse WorldGraph JSON: {e}");
202 return Vec::new();
203 }
204 };
205
206 let nodes = match value.get("nodes").and_then(|n| n.as_array()) {
207 Some(arr) => arr,
208 None => return Vec::new(),
209 };
210
211 nodes
212 .iter()
213 .filter_map(|node| {
214 // Nodes use a serde-tagged enum; the PersonTrack variant carries a
215 // `kind` discriminator equal to `"person_track"`.
216 if node.get("kind")?.as_str()? != "person_track" {
217 return None;
218 }
219 let track_id = node.get("track_id")?.as_u64()?;
220 let pos = node.get("last_position")?;
221 let east_m = pos.get("east_m")?.as_f64()?;
222 let north_m = pos.get("north_m")?.as_f64()?;
223 let up_m = pos.get("up_m")?.as_f64()?;
224 Some(PersonPosition { track_id, east_m, north_m, up_m })
225 })
226 .collect()
227}
228
229// ---------------------------------------------------------------------------
230// Tests
231// ---------------------------------------------------------------------------
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Compares two floats with a small absolute tolerance.
    ///
    /// serde_json's default float parser is best-effort and is not guaranteed
    /// to reproduce the exact bit pattern of a non-integral value after a
    /// round trip, so exact equality is avoided.
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn occupancy_grid_serde_roundtrip() {
        // Distinct voxel values so that any reordering is detected.
        let grid = OccupancyGrid3D {
            width: 4,
            height: 4,
            depth: 2,
            voxels: (0u8..32).collect(),
        };
        let json = serde_json::to_string(&grid).expect("serialize");
        let decoded: OccupancyGrid3D = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.width, grid.width);
        assert_eq!(decoded.height, grid.height);
        assert_eq!(decoded.depth, grid.depth);
        assert_eq!(decoded.voxels, grid.voxels);
    }

    #[test]
    fn trajectory_prior_serde_roundtrip() {
        let prior = TrajectoryPrior {
            track_id: 42,
            waypoints: vec![
                TrajectoryWaypoint { e: 1.0, n: 2.0, u: 0.0, t_s: 0.1 },
                TrajectoryWaypoint { e: 1.1, n: 2.1, u: 0.0, t_s: 0.2 },
            ],
        };
        let json = serde_json::to_string(&prior).expect("serialize");
        let decoded: TrajectoryPrior = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.track_id, 42);
        assert_eq!(decoded.waypoints.len(), prior.waypoints.len());
        for (got, want) in decoded.waypoints.iter().zip(&prior.waypoints) {
            assert!(approx_eq(got.e, want.e));
            assert!(approx_eq(got.n, want.n));
            assert!(approx_eq(got.u, want.u));
            assert!(approx_eq(f64::from(got.t_s), f64::from(want.t_s)));
        }
    }

    #[test]
    fn request_serde_roundtrip() {
        let req = OccupancyWorldModelRequest {
            past_frames: vec![OccupancyGrid3D {
                width: 200,
                height: 200,
                depth: 16,
                voxels: vec![17u8; 200 * 200 * 16],
            }],
            voxel_resolution_m: 0.1,
            scene_bounds: SceneBoundsJson {
                min_e: -10.0,
                min_n: -10.0,
                max_e: 10.0,
                max_n: 10.0,
            },
            prediction_steps: 15,
        };
        let json = serde_json::to_string(&req).expect("serialize");
        let decoded: OccupancyWorldModelRequest =
            serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.prediction_steps, 15);
        assert_eq!(decoded.past_frames.len(), 1);
        assert_eq!(decoded.past_frames[0].voxels, req.past_frames[0].voxels);
        assert!(approx_eq(
            f64::from(decoded.voxel_resolution_m),
            f64::from(req.voxel_resolution_m)
        ));
        assert!(approx_eq(decoded.scene_bounds.min_e, req.scene_bounds.min_e));
        assert!(approx_eq(decoded.scene_bounds.min_n, req.scene_bounds.min_n));
        assert!(approx_eq(decoded.scene_bounds.max_e, req.scene_bounds.max_e));
        assert!(approx_eq(decoded.scene_bounds.max_n, req.scene_bounds.max_n));
    }

    #[test]
    fn response_serde_roundtrip() {
        let resp = OccupancyWorldModelResponse {
            future_frames: vec![],
            trajectory_priors: vec![TrajectoryPrior {
                track_id: 1,
                waypoints: vec![TrajectoryWaypoint { e: 0.0, n: 0.0, u: 0.0, t_s: 0.0 }],
            }],
            confidence: 0.82,
            model_id: "occworld-dummy-v0".into(),
            inference_ms: 375,
        };
        let json = serde_json::to_string(&resp).expect("serialize");
        let decoded: OccupancyWorldModelResponse =
            serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.inference_ms, 375);
        assert!((decoded.confidence - 0.82).abs() < 1e-5);
        assert_eq!(decoded.model_id, "occworld-dummy-v0");
        assert!(decoded.future_frames.is_empty());
        assert_eq!(decoded.trajectory_priors.len(), 1);
        assert_eq!(decoded.trajectory_priors[0].track_id, 1);
        assert_eq!(decoded.trajectory_priors[0].waypoints.len(), 1);
    }

    #[test]
    fn vram_error_sentinel_roundtrip() {
        let resp = OccupancyWorldModelResponse {
            future_frames: vec![],
            trajectory_priors: vec![],
            confidence: 0.0,
            model_id: "error:vram:out of memory (CUDA)".into(),
            inference_ms: 0,
        };
        // Actually push the sentinel through the wire format so the prefix
        // and the reason text are both shown to survive serialisation.
        let json = serde_json::to_string(&resp).expect("serialize");
        let decoded: OccupancyWorldModelResponse =
            serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.model_id, resp.model_id);
        assert_eq!(
            decoded.model_id.strip_prefix("error:vram:"),
            Some("out of memory (CUDA)")
        );
    }

    #[test]
    fn response_missing_field_is_rejected() {
        // No field carries `#[serde(default)]`, so an incomplete reply from the
        // Python server must fail to decode rather than silently zero-fill.
        let json = r#"{"future_frames":[],"trajectory_priors":[],"confidence":0.5,"model_id":"x"}"#;
        assert!(serde_json::from_str::<OccupancyWorldModelResponse>(json).is_err());
    }
}