threecrate_algorithms/
filtering.rs

//! Point cloud filtering algorithms: voxel-grid downsampling, radius outlier removal,
//! and statistical outlier removal.
2
3use threecrate_core::{PointCloud, Result, Point3f, NearestNeighborSearch};
4use crate::nearest_neighbor::BruteForceSearch;
5use rayon::prelude::*;
6
7/// Voxel grid filtering
8/// 
9/// This algorithm reduces the density of a point cloud by grouping points into voxels
10/// and keeping only one representative point per voxel. This is useful for downsampling
11/// large point clouds while preserving their overall structure.
12/// 
13/// # Arguments
14/// * `cloud` - Input point cloud
15/// * `voxel_size` - Size of each voxel cube
16/// 
17/// # Returns
18/// * `Result<PointCloud<Point3f>>` - Downsampled point cloud
19/// 
20/// # Example
21/// ```rust
22/// use threecrate_core::{PointCloud, Point3f};
23/// use threecrate_algorithms::voxel_grid_filter;
24/// 
25/// fn main() -> threecrate_core::Result<()> {
26///     let cloud = PointCloud::from_points(vec![
27///         Point3f::new(0.0, 0.0, 0.0),
28///         Point3f::new(0.1, 0.0, 0.0),
29///         Point3f::new(0.0, 0.1, 0.0),
30///         Point3f::new(0.0, 0.0, 0.1),
31///     ]);
32/// 
33///     let filtered = voxel_grid_filter(&cloud, 0.2)?;
34///     println!("Filtered cloud has {} points", filtered.len());
35///     Ok(())
36/// }
37/// ```
38pub fn voxel_grid_filter(cloud: &PointCloud<Point3f>, voxel_size: f32) -> Result<PointCloud<Point3f>> {
39    if cloud.is_empty() {
40        return Ok(PointCloud::new());
41    }
42    
43    if voxel_size <= 0.0 {
44        return Err(threecrate_core::Error::InvalidData(
45            "voxel_size must be positive".to_string()
46        ));
47    }
48    
49    // Compute bounding box
50    let min_x = cloud.points.iter().map(|p| p.x).min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
51    let min_y = cloud.points.iter().map(|p| p.y).min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
52    let min_z = cloud.points.iter().map(|p| p.z).min_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
53    let max_x = cloud.points.iter().map(|p| p.x).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
54    let max_y = cloud.points.iter().map(|p| p.y).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
55    let max_z = cloud.points.iter().map(|p| p.z).max_by(|a, b| a.partial_cmp(b).unwrap()).unwrap();
56    
57    // Compute grid dimensions (for reference, not used in hash-based approach)
58    let _grid_dim_x = ((max_x - min_x) / voxel_size).ceil() as i32 + 1;
59    let _grid_dim_y = ((max_y - min_y) / voxel_size).ceil() as i32 + 1;
60    let _grid_dim_z = ((max_z - min_z) / voxel_size).ceil() as i32 + 1;
61    
62    // Function to get voxel coordinates for a point
63    let get_voxel_coords = |point: &Point3f| -> (i32, i32, i32) {
64        let x = ((point.x - min_x) / voxel_size).floor() as i32;
65        let y = ((point.y - min_y) / voxel_size).floor() as i32;
66        let z = ((point.z - min_z) / voxel_size).floor() as i32;
67        (x, y, z)
68    };
69    
70    // Group points by voxel
71    let mut voxel_map = std::collections::HashMap::new();
72    
73    for (idx, point) in cloud.points.iter().enumerate() {
74        let voxel_coords = get_voxel_coords(point);
75        voxel_map.entry(voxel_coords).or_insert_with(Vec::new).push(idx);
76    }
77    
78    // Keep one point per voxel (the first one)
79    let filtered_points: Vec<Point3f> = voxel_map
80        .values()
81        .map(|indices| cloud.points[indices[0]])
82        .collect();
83    
84    Ok(PointCloud::from_points(filtered_points))
85}
86
87/// Radius outlier removal filter
88/// 
89/// This algorithm removes points that have fewer than `min_neighbors` neighbors
90/// within a specified radius. This is useful for removing isolated noise points
91/// that are far from the main point cloud.
92/// 
93/// # Arguments
94/// * `cloud` - Input point cloud
95/// * `radius` - Search radius for neighbor counting
96/// * `min_neighbors` - Minimum number of neighbors required to keep a point
97/// 
98/// # Returns
99/// * `Result<PointCloud<Point3f>>` - Filtered point cloud with outliers removed
100/// 
101/// # Example
102/// ```rust
103/// use threecrate_core::{PointCloud, Point3f};
104/// use threecrate_algorithms::radius_outlier_removal;
105/// 
106/// fn main() -> threecrate_core::Result<()> {
107///     let cloud = PointCloud::from_points(vec![
108///         Point3f::new(0.0, 0.0, 0.0),
109///         Point3f::new(0.1, 0.0, 0.0),
110///         Point3f::new(0.0, 0.1, 0.0),
111///         Point3f::new(10.0, 10.0, 10.0), // outlier
112///     ]);
113/// 
114///     let filtered = radius_outlier_removal(&cloud, 0.5, 2)?;
115///     println!("Filtered cloud has {} points", filtered.len());
116///     Ok(())
117/// }
118/// ```
119pub fn radius_outlier_removal(
120    cloud: &PointCloud<Point3f>,
121    radius: f32,
122    min_neighbors: usize,
123) -> Result<PointCloud<Point3f>> {
124    if cloud.is_empty() {
125        return Ok(PointCloud::new());
126    }
127    
128    if radius <= 0.0 {
129        return Err(threecrate_core::Error::InvalidData(
130            "radius must be positive".to_string()
131        ));
132    }
133    
134    if min_neighbors == 0 {
135        return Err(threecrate_core::Error::InvalidData(
136            "min_neighbors must be greater than 0".to_string()
137        ));
138    }
139    
140    // Create nearest neighbor search structure
141    let nn_search = BruteForceSearch::new(&cloud.points);
142    
143    // Count neighbors within radius for each point
144    let neighbor_counts: Vec<usize> = cloud.points
145        .par_iter()
146        .map(|point| {
147            let neighbors = nn_search.find_radius_neighbors(point, radius);
148            // Subtract 1 to exclude the point itself
149            neighbors.len().saturating_sub(1)
150        })
151        .collect();
152    
153    // Filter out points with insufficient neighbors
154    let filtered_points: Vec<Point3f> = cloud.points
155        .iter()
156        .zip(neighbor_counts.iter())
157        .filter(|(_, &count)| count >= min_neighbors)
158        .map(|(point, _)| *point)
159        .collect();
160    
161    Ok(PointCloud::from_points(filtered_points))
162}
163
164/// Statistical outlier removal filter
165/// 
166/// This algorithm removes points that are statistical outliers based on the distance
167/// to their k-nearest neighbors. For each point, it computes the mean distance to
168/// its k nearest neighbors. Points with mean distances that deviate more than
169/// `std_dev_multiplier` standard deviations from the global mean are considered
170/// outliers and removed.
171/// 
172/// # Arguments
173/// * `cloud` - Input point cloud
174/// * `k_neighbors` - Number of nearest neighbors to consider for each point
175/// * `std_dev_multiplier` - Standard deviation multiplier for outlier detection
176/// 
177/// # Returns
178/// * `Result<PointCloud<Point3f>>` - Filtered point cloud with outliers removed
179/// 
180/// # Example
181/// ```rust
182/// use threecrate_core::{PointCloud, Point3f};
183/// use threecrate_algorithms::statistical_outlier_removal;
184/// 
185/// fn main() -> threecrate_core::Result<()> {
186///     let cloud = PointCloud::from_points(vec![
187///         Point3f::new(0.0, 0.0, 0.0),
188///         Point3f::new(1.0, 0.0, 0.0),
189///         Point3f::new(0.0, 1.0, 0.0),
190///         Point3f::new(10.0, 10.0, 10.0), // outlier
191///     ]);
192/// 
193///     let filtered = statistical_outlier_removal(&cloud, 3, 1.0)?;
194///     println!("Filtered cloud has {} points", filtered.len());
195///     Ok(())
196/// }
197/// ```
198pub fn statistical_outlier_removal(
199    cloud: &PointCloud<Point3f>,
200    k_neighbors: usize,
201    std_dev_multiplier: f32,
202) -> Result<PointCloud<Point3f>> {
203    if cloud.is_empty() {
204        return Ok(PointCloud::new());
205    }
206    
207    if k_neighbors == 0 {
208        return Err(threecrate_core::Error::InvalidData(
209            "k_neighbors must be greater than 0".to_string()
210        ));
211    }
212    
213    if std_dev_multiplier <= 0.0 {
214        return Err(threecrate_core::Error::InvalidData(
215            "std_dev_multiplier must be positive".to_string()
216        ));
217    }
218    
219    // Create nearest neighbor search structure
220    let nn_search = BruteForceSearch::new(&cloud.points);
221    
222    // Compute mean distances for all points
223    let mean_distances: Vec<f32> = cloud.points
224        .par_iter()
225        .map(|point| {
226            let neighbors = nn_search.find_k_nearest(point, k_neighbors + 1); // +1 to exclude self
227            if neighbors.is_empty() {
228                return 0.0;
229            }
230            
231            // Calculate mean distance to neighbors (skip first neighbor if it's the point itself)
232            let distances: Vec<f32> = neighbors
233                .iter()
234                .filter(|(idx, _)| cloud.points[*idx] != *point) // Skip self
235                .map(|(_, distance)| *distance)
236                .collect();
237            
238            if distances.is_empty() {
239                return 0.0;
240            }
241            
242            distances.iter().sum::<f32>() / distances.len() as f32
243        })
244        .collect();
245    
246    // Compute global statistics
247    let global_mean = mean_distances.iter().sum::<f32>() / mean_distances.len() as f32;
248    
249    let variance = mean_distances
250        .iter()
251        .map(|&d| (d - global_mean).powi(2))
252        .sum::<f32>() / mean_distances.len() as f32;
253    
254    let global_std_dev = variance.sqrt();
255    let threshold = global_mean + std_dev_multiplier * global_std_dev;
256    
257    // Filter out outliers
258    let filtered_points: Vec<Point3f> = cloud.points
259        .iter()
260        .zip(mean_distances.iter())
261        .filter(|(_, &mean_dist)| mean_dist <= threshold)
262        .map(|(point, _)| *point)
263        .collect();
264    
265    Ok(PointCloud::from_points(filtered_points))
266}
267
268/// Statistical outlier removal with custom threshold
269/// 
270/// This variant allows you to specify a custom threshold instead of using
271/// the automatic standard deviation calculation.
272/// 
273/// # Arguments
274/// * `cloud` - Input point cloud
275/// * `k_neighbors` - Number of nearest neighbors to consider for each point
276/// * `threshold` - Custom threshold for outlier detection
277/// 
278/// # Returns
279/// * `Result<PointCloud<Point3f>>` - Filtered point cloud with outliers removed
280pub fn statistical_outlier_removal_with_threshold(
281    cloud: &PointCloud<Point3f>,
282    k_neighbors: usize,
283    threshold: f32,
284) -> Result<PointCloud<Point3f>> {
285    if cloud.is_empty() {
286        return Ok(PointCloud::new());
287    }
288    
289    if k_neighbors == 0 {
290        return Err(threecrate_core::Error::InvalidData(
291            "k_neighbors must be greater than 0".to_string()
292        ));
293    }
294    
295    if threshold <= 0.0 {
296        return Err(threecrate_core::Error::InvalidData(
297            "threshold must be positive".to_string()
298        ));
299    }
300    
301    // Create nearest neighbor search structure
302    let nn_search = BruteForceSearch::new(&cloud.points);
303    
304    // Compute mean distances for all points
305    let mean_distances: Vec<f32> = cloud.points
306        .par_iter()
307        .map(|point| {
308            let neighbors = nn_search.find_k_nearest(point, k_neighbors + 1); // +1 to exclude self
309            if neighbors.is_empty() {
310                return 0.0;
311            }
312            
313            // Calculate mean distance to neighbors (skip first neighbor if it's the point itself)
314            let distances: Vec<f32> = neighbors
315                .iter()
316                .filter(|(idx, _)| cloud.points[*idx] != *point) // Skip self
317                .map(|(_, distance)| *distance)
318                .collect();
319            
320            if distances.is_empty() {
321                return 0.0;
322            }
323            
324            distances.iter().sum::<f32>() / distances.len() as f32
325        })
326        .collect();
327    
328    // Filter out outliers using custom threshold
329    let filtered_points: Vec<Point3f> = cloud.points
330        .iter()
331        .zip(mean_distances.iter())
332        .filter(|(_, &mean_dist)| mean_dist <= threshold)
333        .map(|(point, _)| *point)
334        .collect();
335    
336    Ok(PointCloud::from_points(filtered_points))
337} 
338
#[cfg(test)]
mod tests {
    // `super::*` already brings `PointCloud` and `Point3f` into scope via the parent's imports.
    use super::*;

    /// Regular grid of `nx * ny * nz` points spaced 0.1 apart along each axis, starting
    /// at the origin (use `nz = 1` for a planar grid at z = 0).
    fn grid(nx: usize, ny: usize, nz: usize) -> Vec<Point3f> {
        let mut points = Vec::with_capacity(nx * ny * nz);
        for i in 0..nx {
            for j in 0..ny {
                for k in 0..nz {
                    points.push(Point3f::new(i as f32 * 0.1, j as f32 * 0.1, k as f32 * 0.1));
                }
            }
        }
        points
    }

    /// Whether `cloud` contains a point within 0.1 of `target` along every axis.
    fn contains_near(cloud: &PointCloud<Point3f>, target: Point3f) -> bool {
        cloud.points.iter().any(|p| {
            (p.x - target.x).abs() < 0.1
                && (p.y - target.y).abs() < 0.1
                && (p.z - target.z).abs() < 0.1
        })
    }

    #[test]
    fn test_statistical_outlier_removal_empty_cloud() {
        let cloud = PointCloud::<Point3f>::new();
        let filtered = statistical_outlier_removal(&cloud, 5, 1.0).expect("empty input is valid");
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_statistical_outlier_removal_single_point() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        let filtered = statistical_outlier_removal(&cloud, 1, 1.0).expect("valid parameters");
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn test_statistical_outlier_removal_with_outliers() {
        // Dense 10x10x10 cluster plus a few far-away points.
        let mut points = grid(10, 10, 10);
        let outliers = [
            Point3f::new(10.0, 10.0, 10.0),
            Point3f::new(-10.0, -10.0, -10.0),
            Point3f::new(5.0, 5.0, 5.0),
        ];
        points.extend_from_slice(&outliers);
        let cloud = PointCloud::from_points(points);

        let filtered = statistical_outlier_removal(&cloud, 5, 1.0).expect("valid parameters");
        assert!(filtered.len() < cloud.len());
        assert!(!filtered.is_empty());

        assert!(!contains_near(&filtered, outliers[0]));
        assert!(!contains_near(&filtered, outliers[1]));
    }

    #[test]
    fn test_statistical_outlier_removal_no_outliers() {
        let cloud = PointCloud::from_points(grid(5, 5, 5));

        let filtered = statistical_outlier_removal(&cloud, 5, 1.0).expect("valid parameters");
        // A uniform grid has no real outliers, so most points must survive.
        assert!(filtered.len() > cloud.len() * 8 / 10);
    }

    #[test]
    fn test_statistical_outlier_removal_invalid_k() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(statistical_outlier_removal(&cloud, 0, 1.0).is_err());
    }

    #[test]
    fn test_statistical_outlier_removal_invalid_std_dev() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(statistical_outlier_removal(&cloud, 5, 0.0).is_err());
        assert!(statistical_outlier_removal(&cloud, 5, -1.0).is_err());
    }

    #[test]
    fn test_statistical_outlier_removal_with_threshold() {
        let outlier = Point3f::new(10.0, 10.0, 10.0);
        let cloud = PointCloud::from_points(vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(0.1, 0.0, 0.0),
            Point3f::new(0.0, 0.1, 0.0),
            Point3f::new(0.0, 0.0, 0.1),
            outlier,
        ]);

        // A low threshold keeps the tight cluster and drops the distant point.
        let filtered =
            statistical_outlier_removal_with_threshold(&cloud, 3, 0.5).expect("valid parameters");
        assert_eq!(filtered.len(), 4);
        assert!(!contains_near(&filtered, outlier));
    }

    #[test]
    fn test_statistical_outlier_removal_with_threshold_invalid() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(statistical_outlier_removal_with_threshold(&cloud, 0, 1.0).is_err());
        assert!(statistical_outlier_removal_with_threshold(&cloud, 5, 0.0).is_err());
    }

    #[test]
    fn test_voxel_grid_filter_empty_cloud() {
        let cloud = PointCloud::<Point3f>::new();
        let filtered = voxel_grid_filter(&cloud, 0.1).expect("empty input is valid");
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_voxel_grid_filter_single_point() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        let filtered = voxel_grid_filter(&cloud, 0.1).expect("valid parameters");
        assert_eq!(filtered.len(), 1);
    }

    #[test]
    fn test_voxel_grid_filter_with_duplicates() {
        let cloud = PointCloud::from_points(vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(0.0, 0.0, 0.0), // duplicate
            Point3f::new(0.1, 0.0, 0.0),
            Point3f::new(0.1, 0.0, 0.0), // duplicate
            Point3f::new(0.0, 0.1, 0.0),
        ]);

        let filtered = voxel_grid_filter(&cloud, 0.05).expect("valid parameters");
        // Each duplicate pair shares a voxel, leaving three distinct points.
        assert_eq!(filtered.len(), 3);
    }

    #[test]
    fn test_voxel_grid_filter_invalid_voxel_size() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(voxel_grid_filter(&cloud, 0.0).is_err());
        assert!(voxel_grid_filter(&cloud, -1.0).is_err());
    }

    #[test]
    fn test_radius_outlier_removal_empty_cloud() {
        let cloud = PointCloud::<Point3f>::new();
        let filtered = radius_outlier_removal(&cloud, 0.5, 3).expect("empty input is valid");
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_radius_outlier_removal_single_point() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        let filtered = radius_outlier_removal(&cloud, 0.5, 1).expect("valid parameters");
        // A lone point has no neighbors, so it is removed.
        assert!(filtered.is_empty());
    }

    #[test]
    fn test_radius_outlier_removal_with_outliers() {
        // Planar 5x5 cluster (spacing 0.1) plus two far-away points.
        let mut points = grid(5, 5, 1);
        let outliers = [
            Point3f::new(10.0, 10.0, 10.0),
            Point3f::new(-10.0, -10.0, -10.0),
        ];
        points.extend_from_slice(&outliers);
        let cloud = PointCloud::from_points(points);

        let filtered = radius_outlier_removal(&cloud, 0.5, 2).expect("valid parameters");
        // Every cluster point has at least two others 0.1 away; both outliers are isolated.
        assert_eq!(filtered.len(), 25);

        assert!(!contains_near(&filtered, outliers[0]));
        assert!(!contains_near(&filtered, outliers[1]));
    }

    #[test]
    fn test_radius_outlier_removal_invalid_parameters() {
        let cloud = PointCloud::from_points(vec![Point3f::new(0.0, 0.0, 0.0)]);
        assert!(radius_outlier_removal(&cloud, 0.0, 3).is_err());
        assert!(radius_outlier_removal(&cloud, -1.0, 3).is_err());
        assert!(radius_outlier_removal(&cloud, 0.5, 0).is_err());
    }

    #[test]
    fn test_radius_outlier_removal_different_parameters() {
        // Planar 3x3 cluster (spacing 0.1) plus one isolated point ~1.5 away.
        let mut points = grid(3, 3, 1);
        let outlier = Point3f::new(1.0, 1.0, 1.0);
        points.push(outlier);
        let cloud = PointCloud::from_points(points);

        let small = radius_outlier_removal(&cloud, 0.2, 2).expect("valid parameters");
        let large = radius_outlier_removal(&cloud, 0.5, 2).expect("valid parameters");

        // A larger radius should never keep fewer points.
        assert!(large.len() >= small.len());
        // Each cluster point has at least two others 0.1 away; the outlier has none.
        assert_eq!(small.len(), 9);
        assert_eq!(large.len(), 9);
        assert!(!contains_near(&small, outlier));
        assert!(!contains_near(&large, outlier));
    }
}