Skip to main content

threecrate_algorithms/
ground_segmentation.rs

1//! Ground segmentation for outdoor LiDAR point clouds.
2//!
3//! Implements a Patchwork++-style algorithm that splits a point cloud into
4//! ground and non-ground subsets. The approach uses a Concentric Zone Model
5//! (CZM): the XY plane around the sensor is divided into concentric rings and
6//! angular sectors, producing a grid of patches whose size adapts with range.
7//! Each patch is fit with a Region-wise Ground Plane Fit (R-GPF) — seed points
8//! near the patch minimum are selected, a plane is fit by PCA, then inliers
9//! are extracted and the plane is refit iteratively. Each candidate plane is
//! then validated by three criteria: **uprightness** (normal nearly vertical),
//! **elevation** (mean inlier height within `elevation_threshold` of
//! `-sensor_height`), and **flatness** (smallest PCA eigenvalue small relative
//! to the sum of all three eigenvalues).
13//!
14//! Reference: Lee et al., "Patchwork++: Fast and Robust Ground Segmentation
15//! Solving Partial Under-Segmentation Using 3D Point Cloud", IROS 2022.
16
17use nalgebra::{Matrix3, Vector3};
18use rayon::prelude::*;
19use std::f32::consts::PI;
20use threecrate_core::{Error, Point3f, PointCloud, Result, Vector3f};
21
/// Configuration for the Patchwork++ ground segmentation algorithm.
///
/// The per-zone vectors (`zone_radii`, `num_rings_per_zone`,
/// `num_sectors_per_zone`) must have consistent lengths; this is checked
/// before segmentation runs.
#[derive(Debug, Clone)]
pub struct PatchworkConfig {
    /// Height of the LiDAR sensor above the expected ground plane (meters).
    /// Ground is therefore expected near `z = -sensor_height` in sensor frame.
    pub sensor_height: f32,
    /// Concentric zone boundary radii in meters (XY distance from the origin),
    /// strictly increasing, length = num_zones + 1. Zone `i` covers the
    /// half-open interval `[zone_radii[i], zone_radii[i + 1])`; points outside
    /// `[zone_radii[0], zone_radii[last])` are labelled non-ground.
    pub zone_radii: Vec<f32>,
    /// Number of equal-width radial rings per zone (length = num_zones).
    /// Every entry should be at least 1.
    pub num_rings_per_zone: Vec<usize>,
    /// Number of equal-width angular sectors per zone (length = num_zones).
    /// Every entry should be at least 1.
    pub num_sectors_per_zone: Vec<usize>,
    /// Maximum sensing range (meters); points whose XY range exceeds this are
    /// labelled non-ground.
    pub max_range: f32,
    /// Minimum number of points required in a patch to attempt ground fitting;
    /// smaller patches stay non-ground.
    pub min_points_per_patch: usize,
    /// Number of lowest points (by z) whose mean height forms the patch's
    /// lowest-point representative (LPR) used for seed selection.
    pub num_seed_points: usize,
    /// A point is an initial seed if its z is at most the LPR height plus this
    /// value (meters).
    pub seed_selection_threshold: f32,
    /// Maximum absolute point-to-plane distance (meters) for a point to count
    /// as an inlier when refining the plane fit.
    pub dist_threshold: f32,
    /// Maximum number of R-GPF refit iterations; the loop stops early once the
    /// inlier count no longer changes.
    pub num_iterations: usize,
    /// Minimum |n_z| required for a patch's plane to count as ground
    /// (cos of maximum allowed slope; 0.707 ≈ 45°).
    pub uprightness_threshold: f32,
    /// Maximum allowed flatness ratio λ_min / (λ_0 + λ_1 + λ_2) of the inlier
    /// covariance eigenvalues; smaller means the patch must be flatter.
    pub flatness_threshold: f32,
    /// Maximum absolute deviation (meters) of the inlier mean z from
    /// `-sensor_height`. Patches sitting unreasonably high or low are rejected.
    pub elevation_threshold: f32,
}
57
58impl Default for PatchworkConfig {
59    fn default() -> Self {
60        // Defaults follow the Patchwork++ reference implementation.
61        Self {
62            sensor_height: 1.723,
63            zone_radii: vec![0.0, 2.7, 12.3625, 22.025, 80.0],
64            num_rings_per_zone: vec![2, 4, 4, 4],
65            num_sectors_per_zone: vec![16, 32, 54, 32],
66            max_range: 80.0,
67            min_points_per_patch: 10,
68            num_seed_points: 20,
69            seed_selection_threshold: 0.5,
70            dist_threshold: 0.125,
71            num_iterations: 3,
72            uprightness_threshold: 0.707,
73            flatness_threshold: 0.05,
74            elevation_threshold: 1.0,
75        }
76    }
77}
78
79/// Result of ground segmentation.
80#[derive(Debug, Clone)]
81pub struct GroundSegmentationResult {
82    /// Points classified as ground.
83    pub ground: PointCloud<Point3f>,
84    /// Points classified as non-ground (obstacles, vegetation above ground, etc.).
85    pub nonground: PointCloud<Point3f>,
86    /// Per-input-point labels: true if classified as ground.
87    pub labels: Vec<bool>,
88}
89
90fn validate_config(cfg: &PatchworkConfig) -> Result<()> {
91    let nz = cfg.num_rings_per_zone.len();
92    if nz == 0 {
93        return Err(Error::InvalidData("num_rings_per_zone must be non-empty".into()));
94    }
95    if cfg.zone_radii.len() != nz + 1 {
96        return Err(Error::InvalidData(
97            "zone_radii.len() must equal num_rings_per_zone.len() + 1".into(),
98        ));
99    }
100    if cfg.num_sectors_per_zone.len() != nz {
101        return Err(Error::InvalidData(
102            "num_sectors_per_zone.len() must equal num_rings_per_zone.len()".into(),
103        ));
104    }
105    if cfg.zone_radii.windows(2).any(|w| w[0] >= w[1]) {
106        return Err(Error::InvalidData("zone_radii must be strictly increasing".into()));
107    }
108    if cfg.dist_threshold <= 0.0 {
109        return Err(Error::InvalidData("dist_threshold must be positive".into()));
110    }
111    if cfg.num_seed_points == 0 {
112        return Err(Error::InvalidData("num_seed_points must be at least 1".into()));
113    }
114    if cfg.uprightness_threshold <= 0.0 || cfg.uprightness_threshold > 1.0 {
115        return Err(Error::InvalidData("uprightness_threshold must be in (0, 1]".into()));
116    }
117    Ok(())
118}
119
/// Return the index of the concentric zone that contains `radius`.
///
/// Zone `i` covers the half-open interval `[zone_radii[i], zone_radii[i + 1])`.
/// `zone_radii` is expected to be strictly increasing, which `validate_config`
/// enforces. Returns `None` when:
/// * `radius` is NaN,
/// * `radius` lies outside `[zone_radii[0], zone_radii[last])`,
/// * `zone_radii` has fewer than two entries (no zone is defined).
fn find_zone(radius: f32, zone_radii: &[f32]) -> Option<usize> {
    let (&first, &last) = (zone_radii.first()?, zone_radii.last()?);
    if radius.is_nan() || radius < first || radius >= last {
        return None;
    }
    // Count of boundaries at or below `radius`. It is at least 1 because
    // radius >= first, and at most len - 1 because radius < last.
    let boundaries_below = zone_radii.partition_point(|&b| b <= radius);
    Some(boundaries_below - 1)
}
133
134/// Group point indices into CZM patches keyed by (zone, ring, sector).
135fn bucket_points(
136    points: &[Point3f],
137    cfg: &PatchworkConfig,
138) -> (Vec<Vec<Vec<Vec<usize>>>>, Vec<bool>) {
139    // patches[zone][ring][sector] = Vec<point_index>
140    let mut patches: Vec<Vec<Vec<Vec<usize>>>> = (0..cfg.num_rings_per_zone.len())
141        .map(|z| {
142            (0..cfg.num_rings_per_zone[z])
143                .map(|_| (0..cfg.num_sectors_per_zone[z]).map(|_| Vec::new()).collect())
144                .collect()
145        })
146        .collect();
147    let mut out_of_range = vec![false; points.len()];
148
149    for (idx, p) in points.iter().enumerate() {
150        let r = (p.x * p.x + p.y * p.y).sqrt();
151        if r > cfg.max_range {
152            out_of_range[idx] = true;
153            continue;
154        }
155        let zone = match find_zone(r, &cfg.zone_radii) {
156            Some(z) => z,
157            None => {
158                out_of_range[idx] = true;
159                continue;
160            }
161        };
162        let r_inner = cfg.zone_radii[zone];
163        let r_outer = cfg.zone_radii[zone + 1];
164        let ring_width = (r_outer - r_inner) / cfg.num_rings_per_zone[zone] as f32;
165        let ring = (((r - r_inner) / ring_width) as usize)
166            .min(cfg.num_rings_per_zone[zone] - 1);
167
168        let mut theta = p.y.atan2(p.x);
169        if theta < 0.0 {
170            theta += 2.0 * PI;
171        }
172        let sector_width = 2.0 * PI / cfg.num_sectors_per_zone[zone] as f32;
173        let sector = ((theta / sector_width) as usize)
174            .min(cfg.num_sectors_per_zone[zone] - 1);
175
176        patches[zone][ring][sector].push(idx);
177    }
178
179    (patches, out_of_range)
180}
181
182/// PCA on a set of points; returns (mean, eigenvalues sorted ascending,
183/// corresponding eigenvectors as columns of the matrix).
184fn pca(points: &[Point3f], indices: &[usize]) -> Option<(Vector3<f32>, [f32; 3], Matrix3<f32>)> {
185    if indices.len() < 3 {
186        return None;
187    }
188    let n = indices.len() as f32;
189    let mut mean = Vector3::<f32>::zeros();
190    for &i in indices {
191        mean += points[i].coords;
192    }
193    mean /= n;
194
195    let mut cov = Matrix3::<f32>::zeros();
196    for &i in indices {
197        let d = points[i].coords - mean;
198        cov += d * d.transpose();
199    }
200    cov /= n;
201
202    // Symmetric eigendecomposition.
203    let eig = cov.symmetric_eigen();
204    let mut idx = [0usize, 1, 2];
205    idx.sort_by(|&a, &b| eig.eigenvalues[a].partial_cmp(&eig.eigenvalues[b]).unwrap());
206    let vals = [
207        eig.eigenvalues[idx[0]],
208        eig.eigenvalues[idx[1]],
209        eig.eigenvalues[idx[2]],
210    ];
211    let mut vecs = Matrix3::<f32>::zeros();
212    for k in 0..3 {
213        vecs.set_column(k, &eig.eigenvectors.column(idx[k]));
214    }
215    Some((mean, vals, vecs))
216}
217
218/// Fit a ground plane to a single patch using R-GPF; returns plane parameters
219/// (normal, d) and the inlier indices (into the original `points` slice) if a
220/// valid plane was found.
221fn fit_patch(
222    points: &[Point3f],
223    patch: &[usize],
224    cfg: &PatchworkConfig,
225) -> Option<(Vector3<f32>, f32, Vec<usize>)> {
226    if patch.len() < cfg.min_points_per_patch {
227        return None;
228    }
229
230    // Initial seeds: points whose z is within seed_selection_threshold of min z.
231    let mut sorted_by_z: Vec<usize> = patch.to_vec();
232    sorted_by_z.sort_by(|&a, &b| points[a].z.partial_cmp(&points[b].z).unwrap());
233
234    let seed_count = cfg.num_seed_points.min(sorted_by_z.len());
235    let z_min_mean = {
236        let n = seed_count.min(sorted_by_z.len());
237        if n == 0 {
238            return None;
239        }
240        let mut s = 0.0;
241        for &i in &sorted_by_z[..n] {
242            s += points[i].z;
243        }
244        s / n as f32
245    };
246    let cutoff = z_min_mean + cfg.seed_selection_threshold;
247    let mut current: Vec<usize> = sorted_by_z
248        .iter()
249        .copied()
250        .take_while(|&i| points[i].z <= cutoff)
251        .collect();
252    if current.len() < 3 {
253        return None;
254    }
255
256    let mut last: Option<(Vector3<f32>, f32)> = None;
257    for _ in 0..cfg.num_iterations {
258        let (mean, _vals, vecs) = pca(points, &current)?;
259        // Smallest-eigenvalue eigenvector = surface normal.
260        let mut normal = Vector3::new(vecs[(0, 0)], vecs[(1, 0)], vecs[(2, 0)]);
261        if normal.z < 0.0 {
262            normal = -normal;
263        }
264        let d = -normal.dot(&mean);
265
266        // Re-collect inliers from the full patch.
267        let mut new_inliers = Vec::with_capacity(patch.len());
268        for &i in patch {
269            let dist = (normal.dot(&points[i].coords) + d).abs();
270            if dist <= cfg.dist_threshold {
271                new_inliers.push(i);
272            }
273        }
274        if new_inliers.len() < 3 {
275            return None;
276        }
277        last = Some((normal, d));
278        if new_inliers.len() == current.len() {
279            current = new_inliers;
280            break;
281        }
282        current = new_inliers;
283    }
284
285    let (normal, d) = last?;
286    Some((normal, d, current))
287}
288
289/// Validate a fitted patch under uprightness, elevation, and flatness criteria.
290fn validate_patch(
291    points: &[Point3f],
292    inliers: &[usize],
293    normal: Vector3<f32>,
294    cfg: &PatchworkConfig,
295) -> bool {
296    // Uprightness: normal must be close to +z.
297    if normal.z.abs() < cfg.uprightness_threshold {
298        return false;
299    }
300
301    // Mean z elevation: expected near -sensor_height.
302    let mut mean_z = 0.0;
303    for &i in inliers {
304        mean_z += points[i].z;
305    }
306    mean_z /= inliers.len() as f32;
307    if (mean_z + cfg.sensor_height).abs() > cfg.elevation_threshold {
308        return false;
309    }
310
311    // Flatness: smallest eigenvalue should be small compared to the others.
312    if let Some((_mean, vals, _vecs)) = pca(points, inliers) {
313        let sum = vals[0] + vals[1] + vals[2];
314        if sum > 0.0 {
315            let ratio = vals[0] / sum;
316            if ratio > cfg.flatness_threshold {
317                return false;
318            }
319        }
320    }
321
322    true
323}
324
325/// Run Patchwork++-style ground segmentation on a point cloud.
326pub fn patchwork_plus_plus(
327    cloud: &PointCloud<Point3f>,
328    config: PatchworkConfig,
329) -> Result<GroundSegmentationResult> {
330    validate_config(&config)?;
331    let points = &cloud.points;
332    let mut labels = vec![false; points.len()];
333
334    if points.is_empty() {
335        return Ok(GroundSegmentationResult {
336            ground: PointCloud::new(),
337            nonground: PointCloud::new(),
338            labels,
339        });
340    }
341
342    let (patches, out_of_range) = bucket_points(points, &config);
343
344    // Flatten patches so we can process them in parallel.
345    let mut flat: Vec<&Vec<usize>> = Vec::new();
346    for zone in &patches {
347        for ring in zone {
348            for sector in ring {
349                if !sector.is_empty() {
350                    flat.push(sector);
351                }
352            }
353        }
354    }
355
356    let cfg_ref = &config;
357    let ground_index_sets: Vec<Vec<usize>> = flat
358        .par_iter()
359        .filter_map(|patch| {
360            let (normal, _d, inliers) = fit_patch(points, patch, cfg_ref)?;
361            if validate_patch(points, &inliers, normal, cfg_ref) {
362                Some(inliers)
363            } else {
364                None
365            }
366        })
367        .collect();
368
369    for set in ground_index_sets {
370        for i in set {
371            labels[i] = true;
372        }
373    }
374
375    // Anything out of range is non-ground.
376    for (i, oor) in out_of_range.iter().enumerate() {
377        if *oor {
378            labels[i] = false;
379        }
380    }
381
382    let mut ground = PointCloud::with_capacity(labels.iter().filter(|b| **b).count());
383    let mut nonground = PointCloud::with_capacity(labels.iter().filter(|b| !**b).count());
384    for (i, p) in points.iter().enumerate() {
385        if labels[i] {
386            ground.push(*p);
387        } else {
388            nonground.push(*p);
389        }
390    }
391
392    Ok(GroundSegmentationResult { ground, nonground, labels })
393}
394
395/// Convenience wrapper using default configuration with a given sensor height.
396pub fn segment_ground(
397    cloud: &PointCloud<Point3f>,
398    sensor_height: f32,
399) -> Result<GroundSegmentationResult> {
400    let config = PatchworkConfig { sensor_height, ..Default::default() };
401    patchwork_plus_plus(cloud, config)
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Synthetic scene: noisy flat ground at `z = -sensor_height` over a 60×60 m
    /// square, optionally with a wall-like cluster and a pole standing on it.
    /// Uses a fixed RNG seed so results are reproducible.
    fn build_scene(sensor_height: f32, with_obstacles: bool) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        let mut rng = StdRng::seed_from_u64(42);

        // Flat ground at z = -sensor_height over a 60×60 area, with mild noise.
        let z_ground = -sensor_height;
        for _ in 0..8000 {
            let x: f32 = rng.gen_range(-30.0..30.0);
            let y: f32 = rng.gen_range(-30.0..30.0);
            let z = z_ground + rng.gen_range(-0.02..0.02);
            // Skip points right under the sensor (too close to origin).
            if x * x + y * y < 0.25 {
                continue;
            }
            cloud.push(Point3f::new(x, y, z));
        }

        if with_obstacles {
            // Tall vertical "wall" / obstacle cluster.
            for _ in 0..1500 {
                let x = 8.0 + rng.gen_range(-0.4..0.4);
                let y = rng.gen_range(-3.0..3.0);
                let z = z_ground + rng.gen_range(0.5..3.0);
                cloud.push(Point3f::new(x, y, z));
            }
            // A pole.
            for _ in 0..400 {
                let x = -5.0 + rng.gen_range(-0.1..0.1);
                let y = -5.0 + rng.gen_range(-0.1..0.1);
                let z = z_ground + rng.gen_range(0.0..4.0);
                cloud.push(Point3f::new(x, y, z));
            }
        }

        cloud
    }

    fn mean_z(pts: &[Point3f]) -> f32 {
        if pts.is_empty() {
            return 0.0;
        }
        pts.iter().map(|p| p.z).sum::<f32>() / pts.len() as f32
    }

    #[test]
    fn flat_ground_is_mostly_ground() {
        let sensor_h = 1.8;
        let cloud = build_scene(sensor_h, false);
        let n = cloud.len();
        let result = segment_ground(&cloud, sensor_h).unwrap();
        let ground_frac = result.ground.len() as f32 / n as f32;
        assert!(
            ground_frac > 0.85,
            "expected >85% ground on a flat scene, got {:.2}%",
            ground_frac * 100.0
        );
    }

    #[test]
    fn obstacles_are_separated() {
        let sensor_h = 1.8;
        let cloud = build_scene(sensor_h, true);
        let result = segment_ground(&cloud, sensor_h).unwrap();
        let ground_z_mean = mean_z(&result.ground.points);
        let nonground_z_mean = mean_z(&result.nonground.points);
        assert!(
            nonground_z_mean > ground_z_mean + 0.3,
            "obstacles should sit above the ground band: ng={:.3} g={:.3}",
            nonground_z_mean,
            ground_z_mean
        );
        // Both classes should be non-empty.
        assert!(!result.ground.points.is_empty());
        assert!(!result.nonground.points.is_empty());
        // ground + nonground must equal input.
        assert_eq!(result.ground.len() + result.nonground.len(), cloud.len());
    }

    #[test]
    fn empty_cloud_is_handled() {
        let cloud: PointCloud<Point3f> = PointCloud::new();
        let result = segment_ground(&cloud, 1.8).unwrap();
        assert_eq!(result.ground.len(), 0);
        assert_eq!(result.nonground.len(), 0);
    }

    #[test]
    fn invalid_config_is_rejected() {
        let cloud = build_scene(1.8, false);
        let bad = PatchworkConfig {
            zone_radii: vec![0.0, 10.0],
            num_rings_per_zone: vec![2, 2],
            num_sectors_per_zone: vec![8, 8],
            ..Default::default()
        };
        assert!(patchwork_plus_plus(&cloud, bad).is_err());
    }

    #[test]
    fn non_increasing_radii_are_rejected() {
        let cloud = build_scene(1.8, false);
        let bad = PatchworkConfig {
            zone_radii: vec![0.0, 10.0, 5.0],
            num_rings_per_zone: vec![2, 2],
            num_sectors_per_zone: vec![8, 8],
            ..Default::default()
        };
        assert!(patchwork_plus_plus(&cloud, bad).is_err());
    }

    #[test]
    fn non_positive_dist_threshold_is_rejected() {
        let cloud = build_scene(1.8, false);
        let bad = PatchworkConfig { dist_threshold: 0.0, ..Default::default() };
        assert!(patchwork_plus_plus(&cloud, bad).is_err());
    }

    #[test]
    fn points_beyond_max_range_are_nonground() {
        let sensor_h = 1.8;
        let mut cloud = build_scene(sensor_h, false);
        let far_index = cloud.len();
        // Exactly at ground height, but well beyond the default 80 m max_range.
        cloud.push(Point3f::new(100.0, 0.0, -sensor_h));
        let result = segment_ground(&cloud, sensor_h).unwrap();
        assert_eq!(result.labels.len(), cloud.len());
        assert!(!result.labels[far_index]);
    }

    #[test]
    fn labels_match_partition() {
        let cloud = build_scene(1.8, true);
        let result = segment_ground(&cloud, 1.8).unwrap();
        assert_eq!(result.labels.len(), cloud.len());
        let ground_count = result.labels.iter().filter(|b| **b).count();
        assert_eq!(ground_count, result.ground.len());
    }
}