Skip to main content

wifi_densepose_worldmodel/
lib.rs

//! `wifi-densepose-worldmodel` — OccWorld thin-client bridge (ADR-147).
//!
//! Bridges [`wifi_densepose_worldgraph`] `PersonTrack` history to the OccWorld
//! Python inference subprocess and returns [`TrajectoryPrior`]s that can be
//! injected into the Kalman pose tracker.
//!
//! ## Quick start
//! ```rust,no_run
//! use wifi_densepose_worldmodel::{
//!     OccWorldBridge, OccupancyWorldModelRequest, OccupancyGrid3D,
//!     SceneBoundsJson, worldgraph_to_occupancy,
//! };
//! use wifi_densepose_worldmodel::occupancy::{PersonPosition, SceneBounds};
//!
//! # async fn example() -> Result<(), wifi_densepose_worldmodel::WorldModelError> {
//! let bridge = OccWorldBridge::new("/tmp/occworld.sock");
//!
//! let bounds = SceneBounds { min_e: -10.0, min_n: -10.0, max_e: 10.0, max_n: 10.0 };
//! let persons = vec![
//!     PersonPosition { track_id: 1, east_m: 2.0, north_m: 3.0, up_m: 1.0 },
//! ];
//! let frame = worldgraph_to_occupancy(&persons, &bounds, 0.1);
//!
//! let request = OccupancyWorldModelRequest {
//!     past_frames: vec![frame],
//!     voxel_resolution_m: 0.1,
//!     scene_bounds: SceneBoundsJson {
//!         min_e: bounds.min_e, min_n: bounds.min_n,
//!         max_e: bounds.max_e, max_n: bounds.max_n,
//!     },
//!     prediction_steps: 15,
//! };
//!
//! let response = bridge.predict(request).await?;
//! println!("confidence={:.2}", response.confidence);
//! for prior in &response.trajectory_priors {
//!     println!("track {} has {} waypoints", prior.track_id, prior.waypoints.len());
//! }
//! # Ok(())
//! # }
//! ```
42
43pub mod bridge;
44pub mod error;
45pub mod occupancy;
46
47// Re-export the bridge type at the crate root for convenience.
48pub use bridge::{default_socket_path, OccWorldBridge};
49pub use error::WorldModelError;
50pub use occupancy::worldgraph_to_occupancy;
51
52use serde::{Deserialize, Serialize};
53
54// ---------------------------------------------------------------------------
55// Voxel grid
56// ---------------------------------------------------------------------------
57
58/// A 3-D occupancy grid whose voxel values are class indices (u8).
59///
60/// Layout: `voxels[z * height * width + y * width + x]` (row-major, depth last).
61/// The grid is always `200 × 200 × 16` when produced by
62/// [`worldgraph_to_occupancy`].
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct OccupancyGrid3D {
65    /// Number of voxels along the east/x axis.
66    pub width: u32,
67    /// Number of voxels along the north/y axis.
68    pub height: u32,
69    /// Number of voxels along the up/z axis.
70    pub depth: u32,
71    /// Flat class-index array, length `width * height * depth`.
72    pub voxels: Vec<u8>,
73}
74
75// ---------------------------------------------------------------------------
76// Trajectory types
77// ---------------------------------------------------------------------------
78
79/// A single point on a predicted trajectory, with a relative timestamp.
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct TrajectoryWaypoint {
82    /// East offset from installation origin, in metres.
83    pub e: f64,
84    /// North offset from installation origin, in metres.
85    pub n: f64,
86    /// Up offset (height), in metres.
87    pub u: f64,
88    /// Time offset from "now", in seconds (positive = future).
89    pub t_s: f32,
90}
91
92/// Predicted future trajectory for one tracked person.
93#[derive(Debug, Clone, Serialize, Deserialize)]
94pub struct TrajectoryPrior {
95    /// Stable track identifier (mirrors `WorldNode::PersonTrack::track_id`).
96    pub track_id: u64,
97    /// Ordered sequence of predicted future waypoints.
98    pub waypoints: Vec<TrajectoryWaypoint>,
99}
100
101// ---------------------------------------------------------------------------
102// Scene bounds (JSON wire shape)
103// ---------------------------------------------------------------------------
104
105/// Axis-aligned scene footprint sent to the OccWorld server in the IPC
106/// request.  Mirrors [`occupancy::SceneBounds`] but derives `Serialize` /
107/// `Deserialize` for direct inclusion in the JSON payload.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109pub struct SceneBoundsJson {
110    /// Western (minimum east) edge of the scene, in metres.
111    pub min_e: f64,
112    /// Southern (minimum north) edge of the scene, in metres.
113    pub min_n: f64,
114    /// Eastern (maximum east) edge of the scene, in metres.
115    pub max_e: f64,
116    /// Northern (maximum north) edge of the scene, in metres.
117    pub max_n: f64,
118}
119
120// ---------------------------------------------------------------------------
121// IPC request / response
122// ---------------------------------------------------------------------------
123
124/// JSON request sent from the Rust bridge to the OccWorld Python server.
125///
126/// Serialised as a single newline-terminated JSON object over the Unix socket.
127#[derive(Debug, Clone, Serialize, Deserialize)]
128pub struct OccupancyWorldModelRequest {
129    /// History of occupancy grids (chronological, oldest first).
130    /// OccWorld expects at least one frame; the reference implementation uses
131    /// the most recent 4 frames for temporal context.
132    pub past_frames: Vec<OccupancyGrid3D>,
133
134    /// Physical size of one voxel cell on the ground plane, in metres.
135    pub voxel_resolution_m: f32,
136
137    /// Scene footprint used to build the occupancy grid.
138    pub scene_bounds: SceneBoundsJson,
139
140    /// Number of future time steps to predict (reference: 15 × 0.1 s = 1.5 s).
141    pub prediction_steps: u32,
142}
143
144/// JSON response returned by the OccWorld Python server.
145///
146/// Decoded from a single newline-terminated JSON object on the Unix socket.
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct OccupancyWorldModelResponse {
149    /// Predicted future occupancy grids (chronological, `prediction_steps`
150    /// frames in total).
151    pub future_frames: Vec<OccupancyGrid3D>,
152
153    /// Per-person predicted trajectories extracted from `future_frames`.
154    pub trajectory_priors: Vec<TrajectoryPrior>,
155
156    /// Aggregate confidence score in `[0, 1]` for the entire prediction.
157    pub confidence: f32,
158
159    /// Identifier of the model that produced this response.
160    /// The sentinel prefix `"error:vram:"` signals a VRAM error (see ADR-147).
161    pub model_id: String,
162
163    /// Wall-clock time the Python server spent on inference, in milliseconds.
164    pub inference_ms: u64,
165}
166
167// ---------------------------------------------------------------------------
168// WorldGraph helper — extract PersonPosition list from a WorldGraph snapshot
169// ---------------------------------------------------------------------------
170
171use wifi_densepose_worldgraph::WorldGraph;
172
173use crate::occupancy::PersonPosition;
174
175/// Extracts all [`PersonPosition`]s from a [`WorldGraph`] by serialising the
176/// graph to its canonical JSON form (via [`WorldGraph::to_json`]) and scanning
177/// the `nodes` array for `PersonTrack` entries.
178///
179/// This avoids coupling to the private fields of `WorldGraphSnapshot`.
180/// The returned positions are unsorted; callers may sort by `track_id` if
181/// deterministic ordering is required.
182///
183/// # Panics
184/// Does not panic — if serialisation fails the function returns an empty
185/// `Vec` and logs a warning via `eprintln!`.  In practice, serialisation of a
186/// valid `WorldGraph` should never fail.
187pub fn persons_from_worldgraph(graph: &WorldGraph) -> Vec<PersonPosition> {
188    let bytes = match graph.to_json() {
189        Ok(b) => b,
190        Err(e) => {
191            eprintln!("[worldmodel] WorldGraph::to_json failed: {e}");
192            return Vec::new();
193        }
194    };
195
196    // Parse as a raw JSON value to avoid depending on the exact shape of the
197    // private `WorldGraphSnapshot` struct fields.
198    let value: serde_json::Value = match serde_json::from_slice(&bytes) {
199        Ok(v) => v,
200        Err(e) => {
201            eprintln!("[worldmodel] failed to parse WorldGraph JSON: {e}");
202            return Vec::new();
203        }
204    };
205
206    let nodes = match value.get("nodes").and_then(|n| n.as_array()) {
207        Some(arr) => arr,
208        None => return Vec::new(),
209    };
210
211    nodes
212        .iter()
213        .filter_map(|node| {
214            // Nodes use a serde-tagged enum; the PersonTrack variant carries a
215            // `kind` discriminator equal to `"person_track"`.
216            if node.get("kind")?.as_str()? != "person_track" {
217                return None;
218            }
219            let track_id = node.get("track_id")?.as_u64()?;
220            let pos = node.get("last_position")?;
221            let east_m = pos.get("east_m")?.as_f64()?;
222            let north_m = pos.get("north_m")?.as_f64()?;
223            let up_m = pos.get("up_m")?.as_f64()?;
224            Some(PersonPosition { track_id, east_m, north_m, up_m })
225        })
226        .collect()
227}
228
229// ---------------------------------------------------------------------------
230// Tests
231// ---------------------------------------------------------------------------
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for float comparisons after a JSON round trip; exact float
    /// equality is avoided on purpose.
    const EPS: f64 = 1e-5;

    /// A grid must come back from JSON with the same dimensions and voxels.
    #[test]
    fn occupancy_grid_serde_roundtrip() {
        let grid = OccupancyGrid3D {
            width: 4,
            height: 4,
            depth: 2,
            voxels: vec![17u8; 32],
        };
        let json = serde_json::to_string(&grid).expect("serialize");
        let decoded: OccupancyGrid3D = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.width, grid.width);
        assert_eq!(decoded.height, grid.height);
        assert_eq!(decoded.depth, grid.depth);
        assert_eq!(decoded.voxels, grid.voxels);
    }

    /// A trajectory prior must keep its track id and every waypoint.
    #[test]
    fn trajectory_prior_serde_roundtrip() {
        let prior = TrajectoryPrior {
            track_id: 42,
            waypoints: vec![
                TrajectoryWaypoint { e: 1.0, n: 2.0, u: 0.0, t_s: 0.1 },
                TrajectoryWaypoint { e: 1.1, n: 2.1, u: 0.0, t_s: 0.2 },
            ],
        };
        let json = serde_json::to_string(&prior).expect("serialize");
        let decoded: TrajectoryPrior = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.track_id, 42);
        assert_eq!(decoded.waypoints.len(), 2);
        for (got, want) in decoded.waypoints.iter().zip(&prior.waypoints) {
            assert!((got.e - want.e).abs() < EPS);
            assert!((got.n - want.n).abs() < EPS);
            assert!((got.u - want.u).abs() < EPS);
            assert!((got.t_s - want.t_s).abs() < 1e-5);
        }
    }

    /// A full-size request must keep its frames, bounds and step count.
    #[test]
    fn request_serde_roundtrip() {
        let req = OccupancyWorldModelRequest {
            past_frames: vec![OccupancyGrid3D {
                width: 200,
                height: 200,
                depth: 16,
                voxels: vec![17u8; 200 * 200 * 16],
            }],
            voxel_resolution_m: 0.1,
            scene_bounds: SceneBoundsJson {
                min_e: -10.0,
                min_n: -10.0,
                max_e: 10.0,
                max_n: 10.0,
            },
            prediction_steps: 15,
        };
        let json = serde_json::to_string(&req).expect("serialize");
        let decoded: OccupancyWorldModelRequest =
            serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.prediction_steps, 15);
        assert_eq!(decoded.past_frames.len(), 1);
        assert_eq!(decoded.past_frames[0].voxels.len(), 200 * 200 * 16);
        assert!((decoded.voxel_resolution_m - 0.1).abs() < 1e-5);
        assert!((decoded.scene_bounds.min_e - req.scene_bounds.min_e).abs() < EPS);
        assert!((decoded.scene_bounds.min_n - req.scene_bounds.min_n).abs() < EPS);
        assert!((decoded.scene_bounds.max_e - req.scene_bounds.max_e).abs() < EPS);
        assert!((decoded.scene_bounds.max_n - req.scene_bounds.max_n).abs() < EPS);
    }

    /// A normal response must keep its metadata and trajectory priors.
    #[test]
    fn response_serde_roundtrip() {
        let resp = OccupancyWorldModelResponse {
            future_frames: vec![],
            trajectory_priors: vec![TrajectoryPrior {
                track_id: 1,
                waypoints: vec![TrajectoryWaypoint { e: 0.0, n: 0.0, u: 0.0, t_s: 0.0 }],
            }],
            confidence: 0.82,
            model_id: "occworld-dummy-v0".into(),
            inference_ms: 375,
        };
        let json = serde_json::to_string(&resp).expect("serialize");
        let decoded: OccupancyWorldModelResponse =
            serde_json::from_str(&json).expect("deserialize");
        assert_eq!(decoded.inference_ms, 375);
        assert!((decoded.confidence - 0.82).abs() < 1e-5);
        assert_eq!(decoded.model_id, resp.model_id);
        assert!(decoded.future_frames.is_empty());
        assert_eq!(decoded.trajectory_priors.len(), 1);
        assert_eq!(decoded.trajectory_priors[0].track_id, 1);
    }

    /// The `"error:vram:"` sentinel must still be recognisable after the
    /// response has actually travelled through JSON, which is how the bridge
    /// receives it from the Python server.
    #[test]
    fn vram_error_sentinel_roundtrip() {
        let resp = OccupancyWorldModelResponse {
            future_frames: vec![],
            trajectory_priors: vec![],
            confidence: 0.0,
            model_id: "error:vram:out of memory (CUDA)".into(),
            inference_ms: 0,
        };
        let json = serde_json::to_string(&resp).expect("serialize");
        let decoded: OccupancyWorldModelResponse =
            serde_json::from_str(&json).expect("deserialize");
        assert!(decoded.model_id.starts_with("error:vram:"));
        assert_eq!(decoded.model_id, resp.model_id);
        assert!(decoded.future_frames.is_empty());
        assert!(decoded.trajectory_priors.is_empty());
        assert_eq!(decoded.inference_ms, 0);
    }
}