1use nalgebra::{Matrix3, Vector3};
18use rayon::prelude::*;
19use std::f32::consts::PI;
20use threecrate_core::{Error, Point3f, PointCloud, Result, Vector3f};
21
/// Tuning parameters for [`patchwork_plus_plus`] ground segmentation.
///
/// The XY plane around the sensor is divided into concentric zones (`zone_radii`). Each zone
/// is split into equal-width rings and equal-angle sectors, and every (zone, ring, sector)
/// cell is a "patch" that receives its own ground-plane fit. Distances are compared directly
/// against point coordinates, so they share the cloud's units (presumably metres — confirm
/// against the sensor driver).
#[derive(Debug, Clone)]
pub struct PatchworkConfig {
    /// Height of the sensor above the ground. Ground is expected near `z = -sensor_height`.
    /// A patch whose inlier mean z deviates from that by more than `elevation_threshold`
    /// is rejected.
    pub sensor_height: f32,
    /// Zone boundaries as XY radii: zone `i` covers radii in `[zone_radii[i], zone_radii[i + 1])`.
    /// Must be strictly increasing and have exactly one more entry than `num_rings_per_zone`.
    pub zone_radii: Vec<f32>,
    /// Number of equal-width radial rings in each zone.
    pub num_rings_per_zone: Vec<usize>,
    /// Number of equal-angle sectors (covering the full circle) in each zone. Must have the
    /// same length as `num_rings_per_zone`.
    pub num_sectors_per_zone: Vec<usize>,
    /// Points whose XY radius exceeds this are excluded from fitting and labelled non-ground.
    pub max_range: f32,
    /// Patches with fewer points than this are not fitted, so their points stay non-ground.
    pub min_points_per_patch: usize,
    /// How many of a patch's lowest points are averaged to estimate its seed height.
    pub num_seed_points: usize,
    /// Points at most this far above the seed height form the initial plane-fit set.
    pub seed_selection_threshold: f32,
    /// Maximum point-to-plane distance for a point to count as a plane inlier.
    pub dist_threshold: f32,
    /// Maximum number of fit / re-select iterations per patch.
    pub num_iterations: usize,
    /// Minimum `|normal.z|` of the fitted plane. Planes tilted further from horizontal are
    /// rejected.
    pub uprightness_threshold: f32,
    /// Maximum ratio of the smallest covariance eigenvalue to the eigenvalue sum over the
    /// inliers. Less planar patches (larger ratios) are rejected.
    pub flatness_threshold: f32,
    /// Maximum allowed `|mean_z + sensor_height|` for a patch's inliers.
    pub elevation_threshold: f32,
}
57
58impl Default for PatchworkConfig {
59 fn default() -> Self {
60 Self {
62 sensor_height: 1.723,
63 zone_radii: vec![0.0, 2.7, 12.3625, 22.025, 80.0],
64 num_rings_per_zone: vec![2, 4, 4, 4],
65 num_sectors_per_zone: vec![16, 32, 54, 32],
66 max_range: 80.0,
67 min_points_per_patch: 10,
68 num_seed_points: 20,
69 seed_selection_threshold: 0.5,
70 dist_threshold: 0.125,
71 num_iterations: 3,
72 uprightness_threshold: 0.707,
73 flatness_threshold: 0.05,
74 elevation_threshold: 1.0,
75 }
76 }
77}
78
/// Output of ground segmentation: the input split into two clouds plus per-point labels.
#[derive(Debug, Clone)]
pub struct GroundSegmentationResult {
    /// Points labelled ground, in input order.
    pub ground: PointCloud<Point3f>,
    /// All remaining points, including those outside the zone model, in input order.
    pub nonground: PointCloud<Point3f>,
    /// One flag per input point, index-aligned with the input cloud; `true` means ground.
    pub labels: Vec<bool>,
}
89
90fn validate_config(cfg: &PatchworkConfig) -> Result<()> {
91 let nz = cfg.num_rings_per_zone.len();
92 if nz == 0 {
93 return Err(Error::InvalidData("num_rings_per_zone must be non-empty".into()));
94 }
95 if cfg.zone_radii.len() != nz + 1 {
96 return Err(Error::InvalidData(
97 "zone_radii.len() must equal num_rings_per_zone.len() + 1".into(),
98 ));
99 }
100 if cfg.num_sectors_per_zone.len() != nz {
101 return Err(Error::InvalidData(
102 "num_sectors_per_zone.len() must equal num_rings_per_zone.len()".into(),
103 ));
104 }
105 if cfg.zone_radii.windows(2).any(|w| w[0] >= w[1]) {
106 return Err(Error::InvalidData("zone_radii must be strictly increasing".into()));
107 }
108 if cfg.dist_threshold <= 0.0 {
109 return Err(Error::InvalidData("dist_threshold must be positive".into()));
110 }
111 if cfg.num_seed_points == 0 {
112 return Err(Error::InvalidData("num_seed_points must be at least 1".into()));
113 }
114 if cfg.uprightness_threshold <= 0.0 || cfg.uprightness_threshold > 1.0 {
115 return Err(Error::InvalidData("uprightness_threshold must be in (0, 1]".into()));
116 }
117 Ok(())
118}
119
/// Returns the index of the zone containing `radius`.
///
/// Zone `i` is the half-open interval `[zone_radii[i], zone_radii[i + 1])`. Returns `None`
/// when `radius` lies outside `[first, last)` or is NaN. It also returns `None` when
/// `zone_radii` has fewer than two entries, including the empty slice, which previously
/// panicked.
fn find_zone(radius: f32, zone_radii: &[f32]) -> Option<usize> {
    let (&inner, &outer) = (zone_radii.first()?, zone_radii.last()?);
    if radius < inner || radius >= outer {
        return None;
    }
    zone_radii
        .windows(2)
        .position(|w| radius >= w[0] && radius < w[1])
}
133
134fn bucket_points(
136 points: &[Point3f],
137 cfg: &PatchworkConfig,
138) -> (Vec<Vec<Vec<Vec<usize>>>>, Vec<bool>) {
139 let mut patches: Vec<Vec<Vec<Vec<usize>>>> = (0..cfg.num_rings_per_zone.len())
141 .map(|z| {
142 (0..cfg.num_rings_per_zone[z])
143 .map(|_| (0..cfg.num_sectors_per_zone[z]).map(|_| Vec::new()).collect())
144 .collect()
145 })
146 .collect();
147 let mut out_of_range = vec![false; points.len()];
148
149 for (idx, p) in points.iter().enumerate() {
150 let r = (p.x * p.x + p.y * p.y).sqrt();
151 if r > cfg.max_range {
152 out_of_range[idx] = true;
153 continue;
154 }
155 let zone = match find_zone(r, &cfg.zone_radii) {
156 Some(z) => z,
157 None => {
158 out_of_range[idx] = true;
159 continue;
160 }
161 };
162 let r_inner = cfg.zone_radii[zone];
163 let r_outer = cfg.zone_radii[zone + 1];
164 let ring_width = (r_outer - r_inner) / cfg.num_rings_per_zone[zone] as f32;
165 let ring = (((r - r_inner) / ring_width) as usize)
166 .min(cfg.num_rings_per_zone[zone] - 1);
167
168 let mut theta = p.y.atan2(p.x);
169 if theta < 0.0 {
170 theta += 2.0 * PI;
171 }
172 let sector_width = 2.0 * PI / cfg.num_sectors_per_zone[zone] as f32;
173 let sector = ((theta / sector_width) as usize)
174 .min(cfg.num_sectors_per_zone[zone] - 1);
175
176 patches[zone][ring][sector].push(idx);
177 }
178
179 (patches, out_of_range)
180}
181
182fn pca(points: &[Point3f], indices: &[usize]) -> Option<(Vector3<f32>, [f32; 3], Matrix3<f32>)> {
185 if indices.len() < 3 {
186 return None;
187 }
188 let n = indices.len() as f32;
189 let mut mean = Vector3::<f32>::zeros();
190 for &i in indices {
191 mean += points[i].coords;
192 }
193 mean /= n;
194
195 let mut cov = Matrix3::<f32>::zeros();
196 for &i in indices {
197 let d = points[i].coords - mean;
198 cov += d * d.transpose();
199 }
200 cov /= n;
201
202 let eig = cov.symmetric_eigen();
204 let mut idx = [0usize, 1, 2];
205 idx.sort_by(|&a, &b| eig.eigenvalues[a].partial_cmp(&eig.eigenvalues[b]).unwrap());
206 let vals = [
207 eig.eigenvalues[idx[0]],
208 eig.eigenvalues[idx[1]],
209 eig.eigenvalues[idx[2]],
210 ];
211 let mut vecs = Matrix3::<f32>::zeros();
212 for k in 0..3 {
213 vecs.set_column(k, &eig.eigenvectors.column(idx[k]));
214 }
215 Some((mean, vals, vecs))
216}
217
218fn fit_patch(
222 points: &[Point3f],
223 patch: &[usize],
224 cfg: &PatchworkConfig,
225) -> Option<(Vector3<f32>, f32, Vec<usize>)> {
226 if patch.len() < cfg.min_points_per_patch {
227 return None;
228 }
229
230 let mut sorted_by_z: Vec<usize> = patch.to_vec();
232 sorted_by_z.sort_by(|&a, &b| points[a].z.partial_cmp(&points[b].z).unwrap());
233
234 let seed_count = cfg.num_seed_points.min(sorted_by_z.len());
235 let z_min_mean = {
236 let n = seed_count.min(sorted_by_z.len());
237 if n == 0 {
238 return None;
239 }
240 let mut s = 0.0;
241 for &i in &sorted_by_z[..n] {
242 s += points[i].z;
243 }
244 s / n as f32
245 };
246 let cutoff = z_min_mean + cfg.seed_selection_threshold;
247 let mut current: Vec<usize> = sorted_by_z
248 .iter()
249 .copied()
250 .take_while(|&i| points[i].z <= cutoff)
251 .collect();
252 if current.len() < 3 {
253 return None;
254 }
255
256 let mut last: Option<(Vector3<f32>, f32)> = None;
257 for _ in 0..cfg.num_iterations {
258 let (mean, _vals, vecs) = pca(points, ¤t)?;
259 let mut normal = Vector3::new(vecs[(0, 0)], vecs[(1, 0)], vecs[(2, 0)]);
261 if normal.z < 0.0 {
262 normal = -normal;
263 }
264 let d = -normal.dot(&mean);
265
266 let mut new_inliers = Vec::with_capacity(patch.len());
268 for &i in patch {
269 let dist = (normal.dot(&points[i].coords) + d).abs();
270 if dist <= cfg.dist_threshold {
271 new_inliers.push(i);
272 }
273 }
274 if new_inliers.len() < 3 {
275 return None;
276 }
277 last = Some((normal, d));
278 if new_inliers.len() == current.len() {
279 current = new_inliers;
280 break;
281 }
282 current = new_inliers;
283 }
284
285 let (normal, d) = last?;
286 Some((normal, d, current))
287}
288
289fn validate_patch(
291 points: &[Point3f],
292 inliers: &[usize],
293 normal: Vector3<f32>,
294 cfg: &PatchworkConfig,
295) -> bool {
296 if normal.z.abs() < cfg.uprightness_threshold {
298 return false;
299 }
300
301 let mut mean_z = 0.0;
303 for &i in inliers {
304 mean_z += points[i].z;
305 }
306 mean_z /= inliers.len() as f32;
307 if (mean_z + cfg.sensor_height).abs() > cfg.elevation_threshold {
308 return false;
309 }
310
311 if let Some((_mean, vals, _vecs)) = pca(points, inliers) {
313 let sum = vals[0] + vals[1] + vals[2];
314 if sum > 0.0 {
315 let ratio = vals[0] / sum;
316 if ratio > cfg.flatness_threshold {
317 return false;
318 }
319 }
320 }
321
322 true
323}
324
325pub fn patchwork_plus_plus(
327 cloud: &PointCloud<Point3f>,
328 config: PatchworkConfig,
329) -> Result<GroundSegmentationResult> {
330 validate_config(&config)?;
331 let points = &cloud.points;
332 let mut labels = vec![false; points.len()];
333
334 if points.is_empty() {
335 return Ok(GroundSegmentationResult {
336 ground: PointCloud::new(),
337 nonground: PointCloud::new(),
338 labels,
339 });
340 }
341
342 let (patches, out_of_range) = bucket_points(points, &config);
343
344 let mut flat: Vec<&Vec<usize>> = Vec::new();
346 for zone in &patches {
347 for ring in zone {
348 for sector in ring {
349 if !sector.is_empty() {
350 flat.push(sector);
351 }
352 }
353 }
354 }
355
356 let cfg_ref = &config;
357 let ground_index_sets: Vec<Vec<usize>> = flat
358 .par_iter()
359 .filter_map(|patch| {
360 let (normal, _d, inliers) = fit_patch(points, patch, cfg_ref)?;
361 if validate_patch(points, &inliers, normal, cfg_ref) {
362 Some(inliers)
363 } else {
364 None
365 }
366 })
367 .collect();
368
369 for set in ground_index_sets {
370 for i in set {
371 labels[i] = true;
372 }
373 }
374
375 for (i, oor) in out_of_range.iter().enumerate() {
377 if *oor {
378 labels[i] = false;
379 }
380 }
381
382 let mut ground = PointCloud::with_capacity(labels.iter().filter(|b| **b).count());
383 let mut nonground = PointCloud::with_capacity(labels.iter().filter(|b| !**b).count());
384 for (i, p) in points.iter().enumerate() {
385 if labels[i] {
386 ground.push(*p);
387 } else {
388 nonground.push(*p);
389 }
390 }
391
392 Ok(GroundSegmentationResult { ground, nonground, labels })
393}
394
395pub fn segment_ground(
397 cloud: &PointCloud<Point3f>,
398 sensor_height: f32,
399) -> Result<GroundSegmentationResult> {
400 let config = PatchworkConfig { sensor_height, ..Default::default() };
401 patchwork_plus_plus(cloud, config)
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;

    /// Builds a deterministic synthetic scene: a noisy flat ground at `z = -sensor_height`
    /// and, when `with_obstacles` is set, two clusters of points raised above it.
    fn build_scene(sensor_height: f32, with_obstacles: bool) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        // Fixed seed so every test sees the same scene. The order of the `gen_range` calls
        // below determines the generated points, so keep it stable.
        let mut rng = StdRng::seed_from_u64(42);

        let z_ground = -sensor_height;
        // Ground: up to 8000 samples over a 60 x 60 square with +/-0.02 height noise.
        for _ in 0..8000 {
            let x: f32 = rng.gen_range(-30.0..30.0);
            let y: f32 = rng.gen_range(-30.0..30.0);
            let z = z_ground + rng.gen_range(-0.02..0.02);
            // Skip samples within XY radius 0.5 of the origin.
            if x * x + y * y < 0.25 {
                continue;
            }
            cloud.push(Point3f::new(x, y, z));
        }

        if with_obstacles {
            // Slab around x = 8 spanning y in [-3, 3], 0.5 to 3.0 above the ground.
            for _ in 0..1500 {
                let x = 8.0 + rng.gen_range(-0.4..0.4);
                let y = rng.gen_range(-3.0..3.0);
                let z = z_ground + rng.gen_range(0.5..3.0);
                cloud.push(Point3f::new(x, y, z));
            }
            // Narrow column near (-5, -5) rising from ground level to 4.0 above it.
            for _ in 0..400 {
                let x = -5.0 + rng.gen_range(-0.1..0.1);
                let y = -5.0 + rng.gen_range(-0.1..0.1);
                let z = z_ground + rng.gen_range(0.0..4.0);
                cloud.push(Point3f::new(x, y, z));
            }
        }

        cloud
    }

    /// On a ground-only scene, the vast majority of points should be labelled ground.
    #[test]
    fn flat_ground_is_mostly_ground() {
        let sensor_h = 1.8;
        let cloud = build_scene(sensor_h, false);
        let n = cloud.len();
        let result = segment_ground(&cloud, sensor_h).unwrap();
        let ground_frac = result.ground.len() as f32 / n as f32;
        assert!(
            ground_frac > 0.85,
            "expected >85% ground on a flat scene, got {:.2}%",
            ground_frac * 100.0
        );
    }

    /// Raised obstacle points should land in the non-ground set, lifting its mean height
    /// above the ground set's. The two sets must partition the input.
    #[test]
    fn obstacles_are_separated() {
        let sensor_h = 1.8;
        let cloud = build_scene(sensor_h, true);
        let result = segment_ground(&cloud, sensor_h).unwrap();
        let ground_z_mean = mean_z(&result.ground.points);
        let nonground_z_mean = mean_z(&result.nonground.points);
        assert!(
            nonground_z_mean > ground_z_mean + 0.3,
            "obstacles should sit above the ground band: ng={:.3} g={:.3}",
            nonground_z_mean,
            ground_z_mean
        );
        assert!(result.ground.len() > 0);
        assert!(result.nonground.len() > 0);
        assert_eq!(result.ground.len() + result.nonground.len(), cloud.len());
    }

    /// Mean z of `pts`, or 0.0 for an empty slice.
    fn mean_z(pts: &[Point3f]) -> f32 {
        if pts.is_empty() {
            return 0.0;
        }
        pts.iter().map(|p| p.z).sum::<f32>() / pts.len() as f32
    }

    /// An empty input cloud yields two empty outputs rather than an error.
    #[test]
    fn empty_cloud_is_handled() {
        let cloud: PointCloud<Point3f> = PointCloud::new();
        let result = segment_ground(&cloud, 1.8).unwrap();
        assert_eq!(result.ground.len(), 0);
        assert_eq!(result.nonground.len(), 0);
    }

    /// Two zones need three radii. A two-entry `zone_radii` must be rejected.
    #[test]
    fn invalid_config_is_rejected() {
        let cloud = build_scene(1.8, false);
        let bad = PatchworkConfig {
            zone_radii: vec![0.0, 10.0],
            num_rings_per_zone: vec![2, 2],
            num_sectors_per_zone: vec![8, 8],
            ..Default::default()
        };
        assert!(patchwork_plus_plus(&cloud, bad).is_err());
    }

    /// `labels` is index-aligned with the input, and its `true` count matches `ground`.
    #[test]
    fn labels_match_partition() {
        let cloud = build_scene(1.8, true);
        let result = segment_ground(&cloud, 1.8).unwrap();
        assert_eq!(result.labels.len(), cloud.len());
        let ground_count = result.labels.iter().filter(|b| **b).count();
        assert_eq!(ground_count, result.ground.len());
    }
}