Skip to main content

ruvector_robotics/bridge/
gaussian.rs

1//! Gaussian splatting types and point-cloud-to-Gaussian conversion.
2//!
3//! Provides a [`GaussianSplat`] representation that maps each point cloud
4//! cluster to a 3D Gaussian with position, colour, opacity, scale, and
5//! optional temporal trajectory.  The serialised format is compatible with
6//! the `vwm-viewer` Canvas2D renderer.
7
8use crate::bridge::{Point3D, PointCloud};
9use crate::perception::clustering;
10
11use serde::{Deserialize, Serialize};
12
13/// A single 3-D Gaussian suitable for splatting-based rendering.
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct GaussianSplat {
16    /// Centre of the Gaussian in world coordinates.
17    pub center: [f64; 3],
18    /// RGB colour in \[0, 1\].
19    pub color: [f32; 3],
20    /// Opacity in \[0, 1\].
21    pub opacity: f32,
22    /// Anisotropic scale along each axis.
23    pub scale: [f32; 3],
24    /// Number of raw points that contributed to this Gaussian.
25    pub point_count: usize,
26    /// Semantic label (e.g. `"obstacle"`, `"ground"`).
27    pub label: String,
28    /// Temporal trajectory: each entry is a position at a successive timestep.
29    /// Empty for static Gaussians.
30    pub trajectory: Vec<[f64; 3]>,
31}
32
33/// A collection of Gaussians derived from one or more point cloud frames.
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct GaussianSplatCloud {
36    pub gaussians: Vec<GaussianSplat>,
37    pub timestamp_us: i64,
38    pub frame_id: String,
39}
40
41impl GaussianSplatCloud {
42    /// Number of Gaussians.
43    pub fn len(&self) -> usize {
44        self.gaussians.len()
45    }
46
47    pub fn is_empty(&self) -> bool {
48        self.gaussians.is_empty()
49    }
50}
51
52/// Configuration for point-cloud → Gaussian conversion.
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct GaussianConfig {
55    /// Clustering cell size in metres.  Smaller = more Gaussians.
56    pub cell_size: f64,
57    /// Minimum number of points to form a Gaussian.
58    pub min_cluster_size: usize,
59    /// Default colour for unlabelled Gaussians `[R, G, B]`.
60    pub default_color: [f32; 3],
61    /// Base opacity for generated Gaussians.
62    pub base_opacity: f32,
63}
64
65impl Default for GaussianConfig {
66    fn default() -> Self {
67        Self {
68            cell_size: 0.5,
69            min_cluster_size: 2,
70            default_color: [0.3, 0.5, 0.8],
71            base_opacity: 0.7,
72        }
73    }
74}
75
76/// Convert a [`PointCloud`] into a [`GaussianSplatCloud`] by clustering nearby
77/// points and computing per-cluster statistics.
78pub fn gaussians_from_cloud(
79    cloud: &PointCloud,
80    config: &GaussianConfig,
81) -> GaussianSplatCloud {
82    if cloud.is_empty() || config.cell_size <= 0.0 {
83        return GaussianSplatCloud {
84            gaussians: Vec::new(),
85            timestamp_us: cloud.timestamp_us,
86            frame_id: cloud.frame_id.clone(),
87        };
88    }
89
90    let clusters = clustering::cluster_point_cloud(cloud, config.cell_size);
91
92    let gaussians: Vec<GaussianSplat> = clusters
93        .into_iter()
94        .filter(|c| c.len() >= config.min_cluster_size)
95        .map(|pts| cluster_to_gaussian(&pts, config))
96        .collect();
97
98    GaussianSplatCloud {
99        gaussians,
100        timestamp_us: cloud.timestamp_us,
101        frame_id: cloud.frame_id.clone(),
102    }
103}
104
105fn cluster_to_gaussian(points: &[Point3D], config: &GaussianConfig) -> GaussianSplat {
106    let n = points.len() as f64;
107    let (mut sx, mut sy, mut sz) = (0.0_f64, 0.0_f64, 0.0_f64);
108    for p in points {
109        sx += p.x as f64;
110        sy += p.y as f64;
111        sz += p.z as f64;
112    }
113    let center = [sx / n, sy / n, sz / n];
114
115    // Compute per-axis standard deviation as the scale.
116    let (mut vx, mut vy, mut vz) = (0.0_f64, 0.0_f64, 0.0_f64);
117    for p in points {
118        let dx = p.x as f64 - center[0];
119        let dy = p.y as f64 - center[1];
120        let dz = p.z as f64 - center[2];
121        vx += dx * dx;
122        vy += dy * dy;
123        vz += dz * dz;
124    }
125    let scale = [
126        (vx / n).sqrt().max(0.01) as f32,
127        (vy / n).sqrt().max(0.01) as f32,
128        (vz / n).sqrt().max(0.01) as f32,
129    ];
130
131    // Opacity proportional to cluster density.
132    let opacity = (config.base_opacity * (points.len() as f32 / 50.0).min(1.0)).max(0.1);
133
134    GaussianSplat {
135        center,
136        color: config.default_color,
137        opacity,
138        scale,
139        point_count: points.len(),
140        label: String::new(),
141        trajectory: Vec::new(),
142    }
143}
144
145/// Serialise a [`GaussianSplatCloud`] to the JSON format expected by the
146/// `vwm-viewer` Canvas2D renderer.
147pub fn to_viewer_json(cloud: &GaussianSplatCloud) -> serde_json::Value {
148    let gs: Vec<serde_json::Value> = cloud
149        .gaussians
150        .iter()
151        .map(|g| {
152            let positions: Vec<Vec<f64>> = if g.trajectory.is_empty() {
153                vec![g.center.to_vec()]
154            } else {
155                g.trajectory.iter().map(|p| p.to_vec()).collect()
156            };
157            serde_json::json!({
158                "positions": positions,
159                "color": g.color,
160                "opacity": g.opacity,
161                "scale": g.scale,
162                "label": g.label,
163                "point_count": g.point_count,
164            })
165        })
166        .collect();
167
168    serde_json::json!({
169        "gaussians": gs,
170        "timestamp_us": cloud.timestamp_us,
171        "frame_id": cloud.frame_id,
172        "count": cloud.len(),
173    })
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a point cloud from raw `[x, y, z]` triples with the given timestamp.
    fn make_cloud(pts: &[[f32; 3]], ts: i64) -> PointCloud {
        let points: Vec<Point3D> = pts.iter().map(|a| Point3D::new(a[0], a[1], a[2])).collect();
        PointCloud::new(points, ts)
    }

    #[test]
    fn test_empty_cloud() {
        let cloud = PointCloud::default();
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert!(gs.is_empty());
    }

    #[test]
    fn test_single_cluster() {
        let cloud = make_cloud(
            &[[1.0, 0.0, 0.0], [1.1, 0.0, 0.0], [1.0, 0.1, 0.0]],
            1000,
        );
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.len(), 1);
        let g = &gs.gaussians[0];
        assert_eq!(g.point_count, 3);
        assert!(g.center[0] > 0.9 && g.center[0] < 1.2);
    }

    #[test]
    fn test_two_clusters() {
        let cloud = make_cloud(
            &[
                [0.0, 0.0, 0.0], [0.1, 0.0, 0.0],
                [10.0, 10.0, 0.0], [10.1, 10.0, 0.0],
            ],
            2000,
        );
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.len(), 2);
    }

    #[test]
    fn test_min_cluster_size_filtering() {
        let cloud = make_cloud(
            &[[0.0, 0.0, 0.0], [10.0, 10.0, 0.0]],
            0,
        );
        let config = GaussianConfig { min_cluster_size: 3, ..Default::default() };
        let gs = gaussians_from_cloud(&cloud, &config);
        assert!(gs.is_empty());
    }

    #[test]
    fn test_scale_reflects_spread() {
        // The points span 0.3 m along X. A cell size well above that spread
        // keeps all three of them in one cluster.
        let cloud = make_cloud(
            &[[0.0, 0.0, 0.0], [0.3, 0.0, 0.0], [0.15, 0.0, 0.0]],
            0,
        );
        let config = GaussianConfig { cell_size: 1.0, ..Default::default() };
        let gs = gaussians_from_cloud(&cloud, &config);
        assert_eq!(gs.len(), 1);
        let g = &gs.gaussians[0];
        // The X-axis spread exceeds Y/Z. The points have no Y/Z spread, so
        // those axes are clamped to the 0.01 minimum.
        assert!(g.scale[0] > g.scale[1]);
        assert_eq!(g.scale[1], 0.01_f32);
        assert_eq!(g.scale[2], 0.01_f32);
    }

    #[test]
    fn test_metadata_propagated() {
        let cloud = make_cloud(&[[0.0, 0.0, 0.0], [0.1, 0.0, 0.0]], 1234);
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.timestamp_us, 1234);
        assert_eq!(gs.frame_id, cloud.frame_id);
    }

    #[test]
    fn test_opacity_within_unit_range() {
        let cloud = make_cloud(&[[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.2, 0.0, 0.0]], 0);
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        for g in &gs.gaussians {
            assert!((0.0..=1.0).contains(&g.opacity), "opacity {} out of range", g.opacity);
        }
    }

    #[test]
    fn test_viewer_json_format() {
        let cloud = make_cloud(&[[1.0, 2.0, 3.0], [1.1, 2.0, 3.0]], 5000);
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        let json = to_viewer_json(&gs);
        assert_eq!(json["count"], 1);
        assert_eq!(json["timestamp_us"], 5000);
        let arr = json["gaussians"].as_array().unwrap();
        assert_eq!(arr.len(), 1);
        assert!(arr[0]["positions"].is_array());
        assert!(arr[0]["color"].is_array());
    }

    #[test]
    fn test_viewer_json_uses_trajectory() {
        let cloud = make_cloud(&[[1.0, 2.0, 3.0], [1.1, 2.0, 3.0]], 0);
        let mut gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        assert_eq!(gs.len(), 1);
        gs.gaussians[0].trajectory = vec![[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]];
        let json = to_viewer_json(&gs);
        let positions = json["gaussians"][0]["positions"].as_array().unwrap();
        assert_eq!(positions.len(), 3);
        assert_eq!(positions[2][0], 2.0);
    }

    #[test]
    fn test_serde_roundtrip() {
        let cloud = make_cloud(&[[0.0, 0.0, 0.0], [0.1, 0.1, 0.0]], 0);
        let gs = gaussians_from_cloud(&cloud, &GaussianConfig::default());
        let json = serde_json::to_string(&gs).unwrap();
        let restored: GaussianSplatCloud = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.len(), gs.len());
        assert_eq!(restored.gaussians, gs.gaussians);
    }

    #[test]
    fn test_zero_cell_size() {
        let cloud = make_cloud(&[[1.0, 0.0, 0.0]], 0);
        let config = GaussianConfig { cell_size: 0.0, ..Default::default() };
        let gs = gaussians_from_cloud(&cloud, &config);
        assert!(gs.is_empty());
    }

    #[test]
    fn test_negative_cell_size() {
        let cloud = make_cloud(&[[1.0, 0.0, 0.0], [1.1, 0.0, 0.0]], 0);
        let config = GaussianConfig { cell_size: -1.0, ..Default::default() };
        let gs = gaussians_from_cloud(&cloud, &config);
        assert!(gs.is_empty());
    }
}