Skip to main content

threecrate_algorithms/
segmentation.rs

1//! Segmentation algorithms
2
3use threecrate_core::{PointCloud, Result, Point3f, Vector3f, Error};
4use nalgebra::{Vector4};
5use rayon::prelude::*;
6use rand::prelude::*;
7use rand::thread_rng;
8use std::collections::HashSet;
9
10/// A 3D plane model defined by the equation ax + by + cz + d = 0
11#[derive(Debug, Clone, PartialEq)]
12pub struct PlaneModel {
13    /// Plane coefficients [a, b, c, d] where ax + by + cz + d = 0
14    pub coefficients: Vector4<f32>,
15}
16
17impl PlaneModel {
18    /// Create a new plane model from coefficients
19    pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
20        Self {
21            coefficients: Vector4::new(a, b, c, d),
22        }
23    }
24
25    /// Create a plane model from three points
26    pub fn from_points(p1: &Point3f, p2: &Point3f, p3: &Point3f) -> Option<Self> {
27        // Calculate two vectors in the plane
28        let v1 = p2 - p1;
29        let v2 = p3 - p1;
30        
31        // Calculate normal vector using cross product
32        let normal = v1.cross(&v2);
33        
34        // Check if points are collinear
35        if normal.magnitude() < 1e-8 {
36            return None;
37        }
38        
39        let normal = normal.normalize();
40        
41        // Calculate d coefficient using point p1
42        let d = -normal.dot(&p1.coords);
43        
44        Some(PlaneModel::new(normal.x, normal.y, normal.z, d))
45    }
46
47    /// Get the normal vector of the plane
48    pub fn normal(&self) -> Vector3f {
49        Vector3f::new(
50            self.coefficients.x,
51            self.coefficients.y,
52            self.coefficients.z,
53        )
54    }
55
56    /// Calculate the distance from a point to the plane
57    pub fn distance_to_point(&self, point: &Point3f) -> f32 {
58        let normal = self.normal();
59        let normal_magnitude = normal.magnitude();
60        
61        if normal_magnitude < 1e-8 {
62            return f32::INFINITY;
63        }
64        
65        (self.coefficients.x * point.x + 
66         self.coefficients.y * point.y + 
67         self.coefficients.z * point.z + 
68         self.coefficients.w).abs() / normal_magnitude
69    }
70
71    /// Count inliers within a distance threshold
72    pub fn count_inliers(&self, points: &[Point3f], threshold: f32) -> usize {
73        points.iter()
74            .filter(|point| self.distance_to_point(point) <= threshold)
75            .count()
76    }
77
78    /// Get indices of inlier points within a distance threshold
79    pub fn get_inliers(&self, points: &[Point3f], threshold: f32) -> Vec<usize> {
80        points.iter()
81            .enumerate()
82            .filter(|(_, point)| self.distance_to_point(point) <= threshold)
83            .map(|(i, _)| i)
84            .collect()
85    }
86}
87
88/// RANSAC plane segmentation result
89#[derive(Debug, Clone)]
90pub struct PlaneSegmentationResult {
91    /// The best plane model found
92    pub model: PlaneModel,
93    /// Indices of inlier points
94    pub inliers: Vec<usize>,
95    /// Number of RANSAC iterations performed
96    pub iterations: usize,
97}
98
99/// Plane segmentation using RANSAC algorithm
100/// 
101/// This function finds the best plane that fits the most points in the cloud
102/// using the RANSAC (Random Sample Consensus) algorithm.
103/// 
104/// # Arguments
105/// * `cloud` - Input point cloud
106/// * `threshold` - Maximum distance for a point to be considered an inlier
107/// * `max_iters` - Maximum number of RANSAC iterations
108/// 
109/// # Returns
110/// * `Result<PlaneSegmentationResult>` - The best plane model and inlier indices
111pub fn segment_plane(
112    cloud: &PointCloud<Point3f>, 
113    threshold: f32, 
114    max_iters: usize
115) -> Result<PlaneSegmentationResult> {
116    if cloud.len() < 3 {
117        return Err(Error::InvalidData("Need at least 3 points for plane segmentation".to_string()));
118    }
119
120    if threshold <= 0.0 {
121        return Err(Error::InvalidData("Threshold must be positive".to_string()));
122    }
123
124    if max_iters == 0 {
125        return Err(Error::InvalidData("Max iterations must be positive".to_string()));
126    }
127
128    let points = &cloud.points;
129    let mut rng = thread_rng();
130    let mut best_model: Option<PlaneModel> = None;
131    let mut best_inliers = Vec::new();
132    let mut best_score = 0;
133
134    for _iteration in 0..max_iters {
135        // Randomly sample 3 points
136        let mut indices = HashSet::new();
137        while indices.len() < 3 {
138            indices.insert(rng.gen_range(0..points.len()));
139        }
140        let indices: Vec<usize> = indices.into_iter().collect();
141
142        let p1 = &points[indices[0]];
143        let p2 = &points[indices[1]];
144        let p3 = &points[indices[2]];
145
146        // Try to create a plane model from these points
147        if let Some(model) = PlaneModel::from_points(p1, p2, p3) {
148            // Count inliers
149            let inlier_count = model.count_inliers(points, threshold);
150            
151            // Update best model if this one is better
152            if inlier_count > best_score {
153                best_score = inlier_count;
154                best_inliers = model.get_inliers(points, threshold);
155                best_model = Some(model);
156            }
157        }
158    }
159
160    match best_model {
161        Some(model) => Ok(PlaneSegmentationResult {
162            model,
163            inliers: best_inliers,
164            iterations: max_iters,
165        }),
166        None => Err(Error::Algorithm("Failed to find valid plane model".to_string())),
167    }
168}
169
170/// Parallel RANSAC plane segmentation for better performance on large point clouds
171/// 
172/// This version uses parallel processing to speed up the RANSAC algorithm
173/// by running multiple iterations in parallel.
174/// 
175/// # Arguments
176/// * `cloud` - Input point cloud
177/// * `threshold` - Maximum distance for a point to be considered an inlier
178/// * `max_iters` - Maximum number of RANSAC iterations
179/// 
180/// # Returns
181/// * `Result<PlaneSegmentationResult>` - The best plane model and inlier indices
182pub fn segment_plane_parallel(
183    cloud: &PointCloud<Point3f>, 
184    threshold: f32, 
185    max_iters: usize
186) -> Result<PlaneSegmentationResult> {
187    if cloud.len() < 3 {
188        return Err(Error::InvalidData("Need at least 3 points for plane segmentation".to_string()));
189    }
190
191    if threshold <= 0.0 {
192        return Err(Error::InvalidData("Threshold must be positive".to_string()));
193    }
194
195    if max_iters == 0 {
196        return Err(Error::InvalidData("Max iterations must be positive".to_string()));
197    }
198
199    let points = &cloud.points;
200    
201    // Run RANSAC iterations in parallel
202    let results: Vec<_> = (0..max_iters)
203        .into_par_iter()
204        .filter_map(|_| {
205            let mut rng = thread_rng();
206            
207            // Randomly sample 3 points
208            let mut indices = HashSet::new();
209            while indices.len() < 3 {
210                indices.insert(rng.gen_range(0..points.len()));
211            }
212            let indices: Vec<usize> = indices.into_iter().collect();
213
214            let p1 = &points[indices[0]];
215            let p2 = &points[indices[1]];
216            let p3 = &points[indices[2]];
217
218            // Try to create a plane model from these points
219            PlaneModel::from_points(p1, p2, p3).map(|model| {
220                let inliers = model.get_inliers(points, threshold);
221                let score = inliers.len();
222                (model, inliers, score)
223            })
224        })
225        .collect();
226
227    // Find the best result
228    let best = results.into_iter()
229        .max_by_key(|(_, _, score)| *score);
230
231    match best {
232        Some((model, inliers, _)) => Ok(PlaneSegmentationResult {
233            model,
234            inliers,
235            iterations: max_iters,
236        }),
237        None => Err(Error::Algorithm("Failed to find valid plane model".to_string())),
238    }
239}
240
241/// Legacy function for backward compatibility
242#[deprecated(note = "Use segment_plane instead which returns a complete result")]
243pub fn segment_plane_legacy(cloud: &PointCloud<Point3f>, threshold: f32) -> Result<Vec<usize>> {
244    let result = segment_plane(cloud, threshold, 1000)?;
245    Ok(result.inliers)
246}
247
248/// RANSAC plane segmentation with simplified interface
249/// 
250/// This function provides a simplified interface for RANSAC plane segmentation
251/// that returns plane coefficients and inlier indices directly.
252/// 
253/// # Arguments
254/// * `cloud` - Input point cloud
255/// * `max_iters` - Maximum number of RANSAC iterations
256/// * `threshold` - Maximum distance for a point to be considered an inlier
257/// 
258/// # Returns
259/// * `Result<(Vector4<f32>, Vec<usize>)>` - Plane coefficients and inlier indices
260/// 
261/// # Example
262/// ```rust
263/// use threecrate_algorithms::segment_plane_ransac;
264/// use threecrate_core::{PointCloud, Point3f};
265/// use nalgebra::Vector4;
266/// 
267/// fn main() -> Result<(), Box<dyn std::error::Error>> {
268///     let cloud = PointCloud::from_points(vec![
269///         Point3f::new(0.0, 0.0, 0.0),
270///         Point3f::new(1.0, 0.0, 0.0),
271///         Point3f::new(0.0, 1.0, 0.0),
272///     ]);
273/// 
274///     let (coefficients, inliers) = segment_plane_ransac(&cloud, 1000, 0.01)?;
275///     println!("Plane coefficients: {:?}", coefficients);
276///     println!("Found {} inliers", inliers.len());
277///     Ok(())
278/// }
279/// ```
280pub fn segment_plane_ransac(
281    cloud: &PointCloud<Point3f>,
282    max_iters: usize,
283    threshold: f32,
284) -> Result<(Vector4<f32>, Vec<usize>)> {
285    let result = segment_plane(cloud, threshold, max_iters)?;
286    Ok((result.model.coefficients, result.inliers))
287}
288
289/// RANSAC plane segmentation (alias for segment_plane_ransac)
290/// 
291/// This function is an alias for `segment_plane_ransac` to maintain compatibility
292/// with the README documentation.
293/// 
294/// # Arguments
295/// * `cloud` - Input point cloud
296/// * `max_iters` - Maximum number of RANSAC iterations
297/// * `threshold` - Maximum distance for a point to be considered an inlier
298/// 
299/// # Returns
300/// * `Result<(Vector4<f32>, Vec<usize>)>` - Plane coefficients and inlier indices
301pub fn plane_segmentation_ransac(
302    cloud: &PointCloud<Point3f>,
303    max_iters: usize,
304    threshold: f32,
305) -> Result<(Vector4<f32>, Vec<usize>)> {
306    segment_plane_ransac(cloud, max_iters, threshold)
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Seed for the random fixture data. Fixing it makes the generated clouds
    /// identical on every run, so a failure can be reproduced. The RANSAC
    /// sampling inside the functions under test is still randomized.
    const FIXTURE_SEED: u64 = 0x5EED_2024;

    /// Builds a `side` x `side` grid of unit-spaced points on the z = 0 plane.
    fn xy_grid(side: usize) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for i in 0..side {
            for j in 0..side {
                cloud.push(Point3f::new(i as f32, j as f32, 0.0));
            }
        }
        cloud
    }

    /// Builds a cloud containing exactly the given points, in order.
    fn cloud_of(points: &[Point3f]) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for point in points {
            cloud.push(*point);
        }
        cloud
    }

    /// Three non-collinear points: the smallest cloud segmentation accepts.
    fn unit_triangle() -> PointCloud<Point3f> {
        cloud_of(&[
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
        ])
    }

    /// Two points: one too few for any plane fit.
    fn two_points() -> PointCloud<Point3f> {
        cloud_of(&[Point3f::new(0.0, 0.0, 0.0), Point3f::new(1.0, 0.0, 0.0)])
    }

    #[test]
    fn test_plane_model_from_points() {
        // Create a plane in XY plane (z=0)
        let p1 = Point3f::new(0.0, 0.0, 0.0);
        let p2 = Point3f::new(1.0, 0.0, 0.0);
        let p3 = Point3f::new(0.0, 1.0, 0.0);

        let model = PlaneModel::from_points(&p1, &p2, &p3).unwrap();

        // Normal should be close to (0, 0, 1) or (0, 0, -1)
        let normal = model.normal();
        assert!(normal.z.abs() > 0.9, "Normal should be primarily in Z direction: {:?}", normal);

        // Distance to points on the plane should be ~0
        assert!(model.distance_to_point(&p1) < 1e-6);
        assert!(model.distance_to_point(&p2) < 1e-6);
        assert!(model.distance_to_point(&p3) < 1e-6);
    }

    #[test]
    fn test_plane_model_collinear_points() {
        let p1 = Point3f::new(0.0, 0.0, 0.0);
        let p2 = Point3f::new(1.0, 0.0, 0.0);
        let p3 = Point3f::new(2.0, 0.0, 0.0);

        let model = PlaneModel::from_points(&p1, &p2, &p3);
        assert!(model.is_none(), "Should return None for collinear points");
    }

    #[test]
    fn test_plane_model_coincident_points() {
        // Two identical points leave only a line's worth of information.
        let p1 = Point3f::new(1.0, 2.0, 3.0);
        let p3 = Point3f::new(4.0, 0.0, 1.0);

        let model = PlaneModel::from_points(&p1, &p1, &p3);
        assert!(model.is_none(), "Should return None when two points coincide");
    }

    #[test]
    fn test_plane_distance_calculation() {
        // Create a plane at z=1
        let model = PlaneModel::new(0.0, 0.0, 1.0, -1.0);

        let point_on_plane = Point3f::new(0.0, 0.0, 1.0);
        let point_above_plane = Point3f::new(0.0, 0.0, 2.0);
        let point_below_plane = Point3f::new(0.0, 0.0, 0.0);

        assert_relative_eq!(model.distance_to_point(&point_on_plane), 0.0, epsilon = 1e-6);
        assert_relative_eq!(model.distance_to_point(&point_above_plane), 1.0, epsilon = 1e-6);
        assert_relative_eq!(model.distance_to_point(&point_below_plane), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_plane_distance_unnormalized_coefficients() {
        // 2z - 2 = 0 is the plane z = 1 with a normal of length 2; distances
        // must still be Euclidean.
        let model = PlaneModel::new(0.0, 0.0, 2.0, -2.0);
        assert_relative_eq!(model.distance_to_point(&Point3f::new(5.0, -3.0, 3.0)), 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_inlier_helpers_agree() {
        let model = PlaneModel::new(0.0, 0.0, 1.0, 0.0);
        let points = [
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 1.0, 0.05),
            Point3f::new(2.0, 2.0, 0.5),
            Point3f::new(3.0, 3.0, -0.09),
        ];

        let inliers = model.get_inliers(&points, 0.1);
        assert_eq!(inliers, vec![0, 1, 3]);
        assert_eq!(model.count_inliers(&points, 0.1), inliers.len());
    }

    #[test]
    fn test_segment_plane_simple() {
        let mut cloud = xy_grid(10);

        // Add a few outliers
        cloud.push(Point3f::new(5.0, 5.0, 10.0));
        cloud.push(Point3f::new(5.0, 5.0, -10.0));

        let result = segment_plane(&cloud, 0.1, 100).unwrap();

        // Should find most of the points as inliers
        assert!(result.inliers.len() >= 95, "Should find most points as inliers");

        // Normal should be close to (0, 0, 1) or (0, 0, -1)
        let normal = result.model.normal();
        assert!(normal.z.abs() > 0.9, "Normal should be primarily in Z direction");
    }

    #[test]
    fn test_segment_plane_insufficient_points() {
        let result = segment_plane(&two_points(), 0.1, 100);
        assert!(result.is_err(), "Should fail with insufficient points");
    }

    #[test]
    fn test_segment_plane_invalid_threshold() {
        let result = segment_plane(&unit_triangle(), -0.1, 100);
        assert!(result.is_err(), "Should fail with negative threshold");
    }

    #[test]
    fn test_segment_plane_nan_threshold() {
        let result = segment_plane(&unit_triangle(), f32::NAN, 100);
        assert!(result.is_err(), "Should fail with NaN threshold");
    }

    #[test]
    fn test_segment_plane_zero_iterations() {
        let result = segment_plane(&unit_triangle(), 0.1, 0);
        assert!(result.is_err(), "Should fail with zero iterations");
    }

    #[test]
    fn test_segment_plane_parallel() {
        let cloud = xy_grid(10);

        let result = segment_plane_parallel(&cloud, 0.1, 100).unwrap();

        // Should find most of the points as inliers
        assert!(result.inliers.len() >= 95, "Should find most points as inliers");
    }

    #[test]
    fn test_segment_plane_parallel_invalid_arguments() {
        assert!(
            segment_plane_parallel(&two_points(), 0.1, 100).is_err(),
            "Should fail with insufficient points"
        );
        assert!(
            segment_plane_parallel(&unit_triangle(), -0.1, 100).is_err(),
            "Should fail with negative threshold"
        );
        assert!(
            segment_plane_parallel(&unit_triangle(), 0.1, 0).is_err(),
            "Should fail with zero iterations"
        );
    }

    #[test]
    fn test_segment_plane_ransac_simple() {
        let mut cloud = xy_grid(10);

        // Add a few outliers
        cloud.push(Point3f::new(5.0, 5.0, 10.0));
        cloud.push(Point3f::new(5.0, 5.0, -10.0));

        let (coefficients, inliers) = segment_plane_ransac(&cloud, 100, 0.1).unwrap();

        // Should find most of the points as inliers
        assert!(inliers.len() >= 95, "Should find most points as inliers");

        // Normal should be close to (0, 0, 1) or (0, 0, -1)
        let normal = Vector3f::new(coefficients.x, coefficients.y, coefficients.z);
        assert!(normal.z.abs() > 0.9, "Normal should be primarily in Z direction: {:?}", normal);
    }

    #[test]
    fn test_segment_plane_ransac_noisy() {
        const SIDE: usize = 20;
        const PLANAR_COUNT: usize = SIDE * SIDE;
        const OUTLIER_COUNT: usize = 20;

        let mut cloud = PointCloud::new();
        let mut rng = StdRng::seed_from_u64(FIXTURE_SEED);

        // Points on the XY plane (z=0) with noise in z.
        for i in 0..SIDE {
            for j in 0..SIDE {
                let x = i as f32;
                let y = j as f32;
                let z = rng.gen_range(-0.05..0.05);
                cloud.push(Point3f::new(x, y, z));
            }
        }

        // Outliers well above the plane, appended after the planar points.
        for _ in 0..OUTLIER_COUNT {
            let x = rng.gen_range(0.0..20.0);
            let y = rng.gen_range(0.0..20.0);
            let z = rng.gen_range(1.0..5.0);
            cloud.push(Point3f::new(x, y, z));
        }

        let (coefficients, inliers) = segment_plane_ransac(&cloud, 1000, 0.1).unwrap();

        // Should find most of the planar points as inliers
        assert!(inliers.len() >= 350, "Should find most planar points as inliers");

        // Normal should be close to (0, 0, 1) or (0, 0, -1)
        let normal = Vector3f::new(coefficients.x, coefficients.y, coefficients.z);
        assert!(normal.z.abs() > 0.8, "Normal should be primarily in Z direction: {:?}", normal);

        // The outliers occupy the index range directly after the planar points.
        let outlier_range = PLANAR_COUNT..PLANAR_COUNT + OUTLIER_COUNT;
        let outliers_accepted = inliers
            .iter()
            .filter(|idx| outlier_range.contains(*idx))
            .count();
        assert!(outliers_accepted <= 2, "Should not include many outliers in inliers");
    }

    #[test]
    fn test_segment_plane_ransac_tilted_plane() {
        let mut cloud = PointCloud::new();
        let mut rng = StdRng::seed_from_u64(FIXTURE_SEED);

        // Tilted plane x + y + z = 0 with small noise on every coordinate.
        for i in 0..15 {
            for j in 0..15 {
                let x = i as f32;
                let y = j as f32;
                let z = -(x + y);

                let noise_x = rng.gen_range(-0.02..0.02);
                let noise_y = rng.gen_range(-0.02..0.02);
                let noise_z = rng.gen_range(-0.02..0.02);

                cloud.push(Point3f::new(x + noise_x, y + noise_y, z + noise_z));
            }
        }

        // Outliers above the plane
        for _ in 0..30 {
            let x = rng.gen_range(0.0..15.0);
            let y = rng.gen_range(0.0..15.0);
            let z = rng.gen_range(5.0..10.0);
            cloud.push(Point3f::new(x, y, z));
        }

        let (coefficients, inliers) = segment_plane_ransac(&cloud, 1000, 0.1).unwrap();

        // Should find most of the planar points as inliers
        assert!(inliers.len() >= 200, "Should find most planar points as inliers");

        // Normal should be close to (1, 1, 1) normalized
        let normal = Vector3f::new(coefficients.x, coefficients.y, coefficients.z);
        let expected_normal = Vector3f::new(1.0, 1.0, 1.0).normalize();
        let dot_product = normal.dot(&expected_normal).abs();
        assert!(dot_product > 0.8, "Normal should be close to expected direction: {:?}", normal);
    }

    #[test]
    fn test_plane_segmentation_ransac_alias() {
        // plane_segmentation_ransac should behave like segment_plane_ransac
        let cloud = xy_grid(5);

        let result1 = segment_plane_ransac(&cloud, 100, 0.1).unwrap();
        let result2 = plane_segmentation_ransac(&cloud, 100, 0.1).unwrap();

        // Both should return valid results (RANSAC is stochastic, so exact values may differ)
        assert!(result1.1.len() >= 20, "Should find most points as inliers");
        assert!(result2.1.len() >= 20, "Should find most points as inliers");

        // Both should have similar inlier counts (within reasonable bounds)
        let diff = (result1.1.len() as i32 - result2.1.len() as i32).abs();
        assert!(diff <= 5, "Inlier counts should be similar: {} vs {}", result1.1.len(), result2.1.len());
    }

    #[test]
    fn test_segment_plane_ransac_insufficient_points() {
        let result = segment_plane_ransac(&two_points(), 100, 0.1);
        assert!(result.is_err(), "Should fail with insufficient points");
    }

    #[test]
    fn test_segment_plane_ransac_invalid_threshold() {
        let result = segment_plane_ransac(&unit_triangle(), 100, -0.1);
        assert!(result.is_err(), "Should fail with negative threshold");
    }

    #[test]
    fn test_segment_plane_ransac_zero_iterations() {
        let result = segment_plane_ransac(&unit_triangle(), 0, 0.1);
        assert!(result.is_err(), "Should fail with zero iterations");
    }
}