threecrate_algorithms/
segmentation.rs

1//! Segmentation algorithms
2
3use threecrate_core::{PointCloud, Result, Point3f, Vector3f, Error};
4use nalgebra::{Vector4};
5use rayon::prelude::*;
6use rand::prelude::*;
7use std::collections::HashSet;
8
9/// A 3D plane model defined by the equation ax + by + cz + d = 0
10#[derive(Debug, Clone, PartialEq)]
11pub struct PlaneModel {
12    /// Plane coefficients [a, b, c, d] where ax + by + cz + d = 0
13    pub coefficients: Vector4<f32>,
14}
15
16impl PlaneModel {
17    /// Create a new plane model from coefficients
18    pub fn new(a: f32, b: f32, c: f32, d: f32) -> Self {
19        Self {
20            coefficients: Vector4::new(a, b, c, d),
21        }
22    }
23
24    /// Create a plane model from three points
25    pub fn from_points(p1: &Point3f, p2: &Point3f, p3: &Point3f) -> Option<Self> {
26        // Calculate two vectors in the plane
27        let v1 = p2 - p1;
28        let v2 = p3 - p1;
29        
30        // Calculate normal vector using cross product
31        let normal = v1.cross(&v2);
32        
33        // Check if points are collinear
34        if normal.magnitude() < 1e-8 {
35            return None;
36        }
37        
38        let normal = normal.normalize();
39        
40        // Calculate d coefficient using point p1
41        let d = -normal.dot(&p1.coords);
42        
43        Some(PlaneModel::new(normal.x, normal.y, normal.z, d))
44    }
45
46    /// Get the normal vector of the plane
47    pub fn normal(&self) -> Vector3f {
48        Vector3f::new(
49            self.coefficients.x,
50            self.coefficients.y,
51            self.coefficients.z,
52        )
53    }
54
55    /// Calculate the distance from a point to the plane
56    pub fn distance_to_point(&self, point: &Point3f) -> f32 {
57        let normal = self.normal();
58        let normal_magnitude = normal.magnitude();
59        
60        if normal_magnitude < 1e-8 {
61            return f32::INFINITY;
62        }
63        
64        (self.coefficients.x * point.x + 
65         self.coefficients.y * point.y + 
66         self.coefficients.z * point.z + 
67         self.coefficients.w).abs() / normal_magnitude
68    }
69
70    /// Count inliers within a distance threshold
71    pub fn count_inliers(&self, points: &[Point3f], threshold: f32) -> usize {
72        points.iter()
73            .filter(|point| self.distance_to_point(point) <= threshold)
74            .count()
75    }
76
77    /// Get indices of inlier points within a distance threshold
78    pub fn get_inliers(&self, points: &[Point3f], threshold: f32) -> Vec<usize> {
79        points.iter()
80            .enumerate()
81            .filter(|(_, point)| self.distance_to_point(point) <= threshold)
82            .map(|(i, _)| i)
83            .collect()
84    }
85}
86
87/// RANSAC plane segmentation result
88#[derive(Debug, Clone)]
89pub struct PlaneSegmentationResult {
90    /// The best plane model found
91    pub model: PlaneModel,
92    /// Indices of inlier points
93    pub inliers: Vec<usize>,
94    /// Number of RANSAC iterations performed
95    pub iterations: usize,
96}
97
98/// Plane segmentation using RANSAC algorithm
99/// 
100/// This function finds the best plane that fits the most points in the cloud
101/// using the RANSAC (Random Sample Consensus) algorithm.
102/// 
103/// # Arguments
104/// * `cloud` - Input point cloud
105/// * `threshold` - Maximum distance for a point to be considered an inlier
106/// * `max_iters` - Maximum number of RANSAC iterations
107/// 
108/// # Returns
109/// * `Result<PlaneSegmentationResult>` - The best plane model and inlier indices
110pub fn segment_plane(
111    cloud: &PointCloud<Point3f>, 
112    threshold: f32, 
113    max_iters: usize
114) -> Result<PlaneSegmentationResult> {
115    if cloud.len() < 3 {
116        return Err(Error::InvalidData("Need at least 3 points for plane segmentation".to_string()));
117    }
118
119    if threshold <= 0.0 {
120        return Err(Error::InvalidData("Threshold must be positive".to_string()));
121    }
122
123    if max_iters == 0 {
124        return Err(Error::InvalidData("Max iterations must be positive".to_string()));
125    }
126
127    let points = &cloud.points;
128    let mut rng = thread_rng();
129    let mut best_model: Option<PlaneModel> = None;
130    let mut best_inliers = Vec::new();
131    let mut best_score = 0;
132
133    for _iteration in 0..max_iters {
134        // Randomly sample 3 points
135        let mut indices = HashSet::new();
136        while indices.len() < 3 {
137            indices.insert(rng.gen_range(0..points.len()));
138        }
139        let indices: Vec<usize> = indices.into_iter().collect();
140
141        let p1 = &points[indices[0]];
142        let p2 = &points[indices[1]];
143        let p3 = &points[indices[2]];
144
145        // Try to create a plane model from these points
146        if let Some(model) = PlaneModel::from_points(p1, p2, p3) {
147            // Count inliers
148            let inlier_count = model.count_inliers(points, threshold);
149            
150            // Update best model if this one is better
151            if inlier_count > best_score {
152                best_score = inlier_count;
153                best_inliers = model.get_inliers(points, threshold);
154                best_model = Some(model);
155            }
156        }
157    }
158
159    match best_model {
160        Some(model) => Ok(PlaneSegmentationResult {
161            model,
162            inliers: best_inliers,
163            iterations: max_iters,
164        }),
165        None => Err(Error::Algorithm("Failed to find valid plane model".to_string())),
166    }
167}
168
169/// Parallel RANSAC plane segmentation for better performance on large point clouds
170/// 
171/// This version uses parallel processing to speed up the RANSAC algorithm
172/// by running multiple iterations in parallel.
173/// 
174/// # Arguments
175/// * `cloud` - Input point cloud
176/// * `threshold` - Maximum distance for a point to be considered an inlier
177/// * `max_iters` - Maximum number of RANSAC iterations
178/// 
179/// # Returns
180/// * `Result<PlaneSegmentationResult>` - The best plane model and inlier indices
181pub fn segment_plane_parallel(
182    cloud: &PointCloud<Point3f>, 
183    threshold: f32, 
184    max_iters: usize
185) -> Result<PlaneSegmentationResult> {
186    if cloud.len() < 3 {
187        return Err(Error::InvalidData("Need at least 3 points for plane segmentation".to_string()));
188    }
189
190    if threshold <= 0.0 {
191        return Err(Error::InvalidData("Threshold must be positive".to_string()));
192    }
193
194    if max_iters == 0 {
195        return Err(Error::InvalidData("Max iterations must be positive".to_string()));
196    }
197
198    let points = &cloud.points;
199    
200    // Run RANSAC iterations in parallel
201    let results: Vec<_> = (0..max_iters)
202        .into_par_iter()
203        .filter_map(|_| {
204            let mut rng = thread_rng();
205            
206            // Randomly sample 3 points
207            let mut indices = HashSet::new();
208            while indices.len() < 3 {
209                indices.insert(rng.gen_range(0..points.len()));
210            }
211            let indices: Vec<usize> = indices.into_iter().collect();
212
213            let p1 = &points[indices[0]];
214            let p2 = &points[indices[1]];
215            let p3 = &points[indices[2]];
216
217            // Try to create a plane model from these points
218            PlaneModel::from_points(p1, p2, p3).map(|model| {
219                let inliers = model.get_inliers(points, threshold);
220                let score = inliers.len();
221                (model, inliers, score)
222            })
223        })
224        .collect();
225
226    // Find the best result
227    let best = results.into_iter()
228        .max_by_key(|(_, _, score)| *score);
229
230    match best {
231        Some((model, inliers, _)) => Ok(PlaneSegmentationResult {
232            model,
233            inliers,
234            iterations: max_iters,
235        }),
236        None => Err(Error::Algorithm("Failed to find valid plane model".to_string())),
237    }
238}
239
240/// Legacy function for backward compatibility
241#[deprecated(note = "Use segment_plane instead which returns a complete result")]
242pub fn segment_plane_legacy(cloud: &PointCloud<Point3f>, threshold: f32) -> Result<Vec<usize>> {
243    let result = segment_plane(cloud, threshold, 1000)?;
244    Ok(result.inliers)
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds a 10x10 grid of points on the z = 0 plane (indices 0..100).
    fn xy_grid() -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for i in 0..10 {
            for j in 0..10 {
                cloud.push(Point3f::new(i as f32, j as f32, 0.0));
            }
        }
        cloud
    }

    #[test]
    fn test_plane_model_from_points() {
        // Create a plane in XY plane (z=0)
        let p1 = Point3f::new(0.0, 0.0, 0.0);
        let p2 = Point3f::new(1.0, 0.0, 0.0);
        let p3 = Point3f::new(0.0, 1.0, 0.0);

        let model = PlaneModel::from_points(&p1, &p2, &p3).unwrap();

        // Normal should be close to (0, 0, 1) or (0, 0, -1)
        let normal = model.normal();
        assert!(normal.z.abs() > 0.9, "Normal should be primarily in Z direction: {:?}", normal);

        // Distance to points on the plane should be ~0
        assert!(model.distance_to_point(&p1) < 1e-6);
        assert!(model.distance_to_point(&p2) < 1e-6);
        assert!(model.distance_to_point(&p3) < 1e-6);
    }

    #[test]
    fn test_plane_model_collinear_points() {
        // Create collinear points
        let p1 = Point3f::new(0.0, 0.0, 0.0);
        let p2 = Point3f::new(1.0, 0.0, 0.0);
        let p3 = Point3f::new(2.0, 0.0, 0.0);

        let model = PlaneModel::from_points(&p1, &p2, &p3);
        assert!(model.is_none(), "Should return None for collinear points");
    }

    #[test]
    fn test_plane_distance_calculation() {
        // Create a plane at z=1
        let model = PlaneModel::new(0.0, 0.0, 1.0, -1.0);

        let point_on_plane = Point3f::new(0.0, 0.0, 1.0);
        let point_above_plane = Point3f::new(0.0, 0.0, 2.0);
        let point_below_plane = Point3f::new(0.0, 0.0, 0.0);

        assert_relative_eq!(model.distance_to_point(&point_on_plane), 0.0, epsilon = 1e-6);
        assert_relative_eq!(model.distance_to_point(&point_above_plane), 1.0, epsilon = 1e-6);
        assert_relative_eq!(model.distance_to_point(&point_below_plane), 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_plane_inlier_queries() {
        // Plane z = 0.
        let model = PlaneModel::new(0.0, 0.0, 1.0, 0.0);
        let points = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.05),
        ];

        assert_eq!(model.count_inliers(&points, 0.1), 2);
        assert_eq!(model.get_inliers(&points, 0.1), vec![0, 2]);
    }

    #[test]
    fn test_segment_plane_simple() {
        // Most points on the z = 0 plane, plus two outliers at indices 100 and 101.
        let mut cloud = xy_grid();
        cloud.push(Point3f::new(5.0, 5.0, 10.0));
        cloud.push(Point3f::new(5.0, 5.0, -10.0));

        let result = segment_plane(&cloud, 0.1, 100).unwrap();

        // Should find most of the points as inliers
        assert!(result.inliers.len() >= 95, "Should find most points as inliers");

        // The outliers are 10 units off the plane and must be excluded.
        assert!(
            !result.inliers.contains(&100) && !result.inliers.contains(&101),
            "Outliers must not be reported as inliers"
        );

        // Normal should be close to (0, 0, 1) or (0, 0, -1)
        let normal = result.model.normal();
        assert!(normal.z.abs() > 0.9, "Normal should be primarily in Z direction");
    }

    #[test]
    fn test_segment_plane_insufficient_points() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));

        let result = segment_plane(&cloud, 0.1, 100);
        assert!(result.is_err(), "Should fail with insufficient points");
    }

    #[test]
    fn test_segment_plane_invalid_threshold() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));
        cloud.push(Point3f::new(0.0, 1.0, 0.0));

        let result = segment_plane(&cloud, -0.1, 100);
        assert!(result.is_err(), "Should fail with negative threshold");
    }

    #[test]
    fn test_segment_plane_zero_iterations() {
        let cloud = xy_grid();

        assert!(segment_plane(&cloud, 0.1, 0).is_err(), "Should fail with zero iterations");
        assert!(
            segment_plane_parallel(&cloud, 0.1, 0).is_err(),
            "Parallel version should fail with zero iterations"
        );
    }

    #[test]
    fn test_segment_plane_all_collinear() {
        // Every sampled triple is exactly collinear, so no plane can be formed.
        let mut cloud = PointCloud::new();
        for i in 0..10 {
            cloud.push(Point3f::new(i as f32, 0.0, 0.0));
        }

        assert!(segment_plane(&cloud, 0.1, 50).is_err(), "Should fail on collinear cloud");
        assert!(
            segment_plane_parallel(&cloud, 0.1, 50).is_err(),
            "Parallel version should fail on collinear cloud"
        );
    }

    #[test]
    fn test_segment_plane_parallel() {
        let cloud = xy_grid();

        let result = segment_plane_parallel(&cloud, 0.1, 100).unwrap();

        // Should find most of the points as inliers
        assert!(result.inliers.len() >= 95, "Should find most points as inliers");
    }

    #[test]
    fn test_segment_plane_parallel_with_outliers() {
        let mut cloud = xy_grid();
        cloud.push(Point3f::new(5.0, 5.0, 10.0));
        cloud.push(Point3f::new(5.0, 5.0, -10.0));

        let result = segment_plane_parallel(&cloud, 0.1, 100).unwrap();

        assert!(result.inliers.len() >= 95, "Should find most points as inliers");
        assert!(
            !result.inliers.contains(&100) && !result.inliers.contains(&101),
            "Outliers must not be reported as inliers"
        );
        assert_eq!(result.iterations, 100);

        let normal = result.model.normal();
        assert!(normal.z.abs() > 0.9, "Normal should be primarily in Z direction");
    }

    #[test]
    fn test_segment_plane_parallel_insufficient_points() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));

        assert!(
            segment_plane_parallel(&cloud, 0.1, 100).is_err(),
            "Parallel version should fail with insufficient points"
        );
    }

    #[test]
    #[allow(deprecated)] // deliberately exercising the deprecated wrapper
    fn test_segment_plane_legacy() {
        let inliers = segment_plane_legacy(&xy_grid(), 0.1).unwrap();
        assert!(inliers.len() >= 95, "Legacy wrapper should find most points as inliers");
    }
}