threecrate_algorithms/
normals.rs

1//! Normal estimation algorithms
2
3use threecrate_core::{PointCloud, Result, Point3f, Vector3f, NormalPoint3f, Error};
4use nalgebra::Matrix3;
5use rayon::prelude::*;
6use std::collections::BinaryHeap;
7use std::cmp::Ordering;
8
9/// Configuration for normal estimation
10#[derive(Debug, Clone)]
11pub struct NormalEstimationConfig {
12    /// Number of nearest neighbors to use (k-NN)
13    pub k_neighbors: usize,
14    /// Optional radius for radius-based neighbor search
15    pub radius: Option<f32>,
16    /// Whether to enforce orientation consistency
17    pub consistent_orientation: bool,
18    /// Viewpoint for orientation consistency (if None, uses positive Z direction)
19    pub viewpoint: Option<Point3f>,
20}
21
22impl Default for NormalEstimationConfig {
23    fn default() -> Self {
24        Self {
25            k_neighbors: 10,
26            radius: None,
27            consistent_orientation: true,
28            viewpoint: None,
29        }
30    }
31}
32
33/// A simple distance-based neighbor for priority queue
34#[derive(Debug, Clone)]
35struct Neighbor {
36    index: usize,
37    distance: f32,
38}
39
40impl PartialEq for Neighbor {
41    fn eq(&self, other: &Self) -> bool {
42        self.distance == other.distance
43    }
44}
45
46impl Eq for Neighbor {}
47
48impl PartialOrd for Neighbor {
49    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
50        // Reverse ordering for min-heap behavior
51        other.distance.partial_cmp(&self.distance)
52    }
53}
54
55impl Ord for Neighbor {
56    fn cmp(&self, other: &Self) -> Ordering {
57        self.partial_cmp(other).unwrap_or(Ordering::Equal)
58    }
59}
60
61/// Find k-nearest neighbors using brute force search
62fn find_k_nearest_neighbors(points: &[Point3f], query_idx: usize, k: usize) -> Vec<usize> {
63    let query = &points[query_idx];
64    let mut heap = BinaryHeap::with_capacity(k + 1);
65    
66    for (i, point) in points.iter().enumerate() {
67        if i == query_idx {
68            continue; // Skip the query point itself
69        }
70        
71        let distance = (point - query).magnitude_squared();
72        let neighbor = Neighbor { index: i, distance };
73        
74        if heap.len() < k {
75            heap.push(neighbor);
76        } else if let Some(farthest) = heap.peek() {
77            if neighbor.distance < farthest.distance {
78                heap.pop();
79                heap.push(neighbor);
80            }
81        }
82    }
83    
84    heap.into_iter().map(|n| n.index).collect()
85}
86
87/// Find neighbors within a radius
88fn find_radius_neighbors(points: &[Point3f], query_idx: usize, radius: f32) -> Vec<usize> {
89    let query = &points[query_idx];
90    let radius_squared = radius * radius;
91    
92    points.iter()
93        .enumerate()
94        .filter(|(i, point)| {
95            *i != query_idx && (**point - query).magnitude_squared() <= radius_squared
96        })
97        .map(|(i, _)| i)
98        .collect()
99}
100
101/// Find neighbors using either k-NN or radius-based search
102fn find_neighbors(points: &[Point3f], query_idx: usize, config: &NormalEstimationConfig) -> Vec<usize> {
103    if let Some(radius) = config.radius {
104        // Use radius-based search
105        find_radius_neighbors(points, query_idx, radius)
106    } else {
107        // Use k-NN search
108        find_k_nearest_neighbors(points, query_idx, config.k_neighbors)
109    }
110}
111
112/// Compute normal using PCA on the neighborhood points
113fn compute_normal_pca(points: &[Point3f], indices: &[usize]) -> Vector3f {
114    if indices.len() < 3 {
115        // Default normal if not enough points
116        return Vector3f::new(0.0, 0.0, 1.0);
117    }
118    
119    // Compute centroid
120    let mut centroid = Point3f::origin();
121    for &idx in indices {
122        centroid += points[idx].coords;
123    }
124    centroid /= indices.len() as f32;
125    
126    // Build covariance matrix
127    let mut covariance = Matrix3::zeros();
128    for &idx in indices {
129        let diff = points[idx] - centroid;
130        covariance += diff * diff.transpose();
131    }
132    covariance /= indices.len() as f32;
133    
134    // Find eigenvector corresponding to smallest eigenvalue
135    // This is the normal direction
136    let eigen = covariance.symmetric_eigen();
137    let eigenvalues = eigen.eigenvalues;
138    let eigenvectors = eigen.eigenvectors;
139    
140    // Find index of smallest eigenvalue
141    let mut min_idx = 0;
142    for i in 1..3 {
143        if eigenvalues[i] < eigenvalues[min_idx] {
144            min_idx = i;
145        }
146    }
147    
148
149    
150    // Return the corresponding eigenvector as normal
151    let mut normal: Vector3f = eigenvectors.column(min_idx).into();
152    
153    // Ensure the normal is normalized
154    let magnitude = normal.magnitude();
155    if magnitude > 1e-6 {
156        normal /= magnitude;
157    } else {
158        normal = Vector3f::new(0.0, 0.0, 1.0);
159    }
160    
161    normal
162}
163
164/// Orient normal towards viewpoint for consistency
165fn orient_normal_towards_viewpoint(normal: Vector3f, point: Point3f, viewpoint: Point3f) -> Vector3f {
166    let to_viewpoint = (viewpoint - point).normalize();
167    let dot_product = normal.dot(&to_viewpoint);
168    
169    // If the angle between normal and viewpoint direction is > 90 degrees, flip the normal
170    if dot_product < 0.0 {
171        -normal
172    } else {
173        normal
174    }
175}
176
177/// Estimate normals for a point cloud using k-nearest neighbors
178/// 
179/// This function computes surface normals for each point in the cloud by:
180/// 1. Finding k nearest neighbors for each point (or neighbors within radius)
181/// 2. Computing the normal using Principal Component Analysis (PCA)
182/// 3. The normal is the eigenvector corresponding to the smallest eigenvalue
183/// 4. Optionally enforcing orientation consistency towards a viewpoint
184/// 
185/// # Arguments
186/// * `cloud` - Reference to the point cloud
187/// * `k` - Number of nearest neighbors to use (typically 10-30)
188/// 
189/// # Returns
190/// * `Result<PointCloud<NormalPoint3f>>` - A new point cloud with normals
191pub fn estimate_normals(cloud: &PointCloud<Point3f>, k: usize) -> Result<PointCloud<NormalPoint3f>> {
192    let config = NormalEstimationConfig {
193        k_neighbors: k,
194        ..Default::default()
195    };
196    estimate_normals_with_config(cloud, &config)
197}
198
199/// Estimate normals with advanced configuration
200/// 
201/// # Arguments
202/// * `cloud` - Reference to the point cloud
203/// * `config` - Configuration for normal estimation
204/// 
205/// # Returns
206/// * `Result<PointCloud<NormalPoint3f>>` - A new point cloud with normals
207pub fn estimate_normals_with_config(
208    cloud: &PointCloud<Point3f>, 
209    config: &NormalEstimationConfig
210) -> Result<PointCloud<NormalPoint3f>> {
211    if cloud.is_empty() {
212        return Ok(PointCloud::new());
213    }
214    
215    if config.k_neighbors < 3 {
216        return Err(Error::InvalidData("k_neighbors must be at least 3".to_string()));
217    }
218    
219    let points = &cloud.points;
220    
221    // Determine viewpoint for orientation consistency
222    let viewpoint = config.viewpoint.unwrap_or_else(|| {
223        // Default viewpoint: compute a good viewpoint based on the point cloud bounds
224        let mut min_x = points[0].x;
225        let mut min_y = points[0].y;
226        let mut min_z = points[0].z;
227        let mut max_x = points[0].x;
228        let mut max_y = points[0].y;
229        let mut max_z = points[0].z;
230        
231        for point in points {
232            min_x = min_x.min(point.x);
233            min_y = min_y.min(point.y);
234            min_z = min_z.min(point.z);
235            max_x = max_x.max(point.x);
236            max_y = max_y.max(point.y);
237            max_z = max_z.max(point.z);
238        }
239        
240        let center = Point3f::new(
241            (min_x + max_x) / 2.0,
242            (min_y + max_y) / 2.0,
243            (min_z + max_z) / 2.0,
244        );
245        let extent = ((max_x - min_x).powi(2) + (max_y - min_y).powi(2) + (max_z - min_z).powi(2)).sqrt();
246        
247        // Viewpoint is above the center of the point cloud
248        center + Vector3f::new(0.0, 0.0, extent)
249    });
250    
251    // Compute normals in parallel
252    let normals: Vec<NormalPoint3f> = (0..points.len())
253        .into_par_iter()
254        .map(|i| {
255            let neighbors = find_neighbors(points, i, config);
256            
257            // Use only the neighbors for PCA, not the query point itself
258            let mut neighborhood = neighbors;
259            
260            // If radius-based search didn't find enough neighbors, fall back to k-NN
261            if config.radius.is_some() && neighborhood.len() < config.k_neighbors {
262                neighborhood = find_k_nearest_neighbors(points, i, config.k_neighbors);
263            }
264            
265            // Ensure we have enough neighbors for PCA
266            if neighborhood.len() < 3 {
267                // If we still don't have enough neighbors, use a larger k
268                neighborhood = find_k_nearest_neighbors(points, i, config.k_neighbors.max(5));
269            }
270            
271            let mut normal = compute_normal_pca(points, &neighborhood);
272            
273            // Apply orientation consistency if requested
274            if config.consistent_orientation {
275                normal = orient_normal_towards_viewpoint(normal, points[i], viewpoint);
276            }
277            
278            NormalPoint3f {
279                position: points[i],
280                normal,
281            }
282        })
283        .collect();
284    
285    Ok(PointCloud::from_points(normals))
286}
287
288/// Estimate normals using radius-based neighbor search
289/// 
290/// # Arguments
291/// * `cloud` - Reference to the point cloud
292/// * `radius` - Search radius for neighbors
293/// * `consistent_orientation` - Whether to enforce orientation consistency
294/// 
295/// # Returns
296/// * `Result<PointCloud<NormalPoint3f>>` - A new point cloud with normals
297pub fn estimate_normals_radius(
298    cloud: &PointCloud<Point3f>, 
299    radius: f32, 
300    consistent_orientation: bool
301) -> Result<PointCloud<NormalPoint3f>> {
302    let config = NormalEstimationConfig {
303        k_neighbors: 10, // Fallback value
304        radius: Some(radius),
305        consistent_orientation,
306        viewpoint: None,
307    };
308    estimate_normals_with_config(cloud, &config)
309}
310
311/// Estimate normals and modify the input cloud in-place (legacy API)
312/// This function is deprecated in favor of the version that returns a new cloud
313#[deprecated(note = "Use estimate_normals instead which returns a new point cloud")]
314pub fn estimate_normals_inplace(_cloud: &mut PointCloud<Point3f>, k: usize) -> Result<()> {
315    // This would require converting the point cloud type, which isn't straightforward
316    // with the current type system. The new API is cleaner.
317    let _ = k;
318    Err(Error::Unsupported("Use estimate_normals instead".to_string()))
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    /// The four corners of the unit square in the XY plane.
    const UNIT_SQUARE: [[f32; 3]; 4] = [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [1.0, 1.0, 0.0],
    ];

    /// Build a point cloud from `[x, y, z]` triples.
    fn cloud_from(coords: &[[f32; 3]]) -> PointCloud<Point3f> {
        let mut cloud = PointCloud::new();
        for &[x, y, z] in coords {
            cloud.push(Point3f::new(x, y, z));
        }
        cloud
    }

    /// Assert that `normal` has (approximately) unit length.
    fn assert_unit_length(normal: &Vector3f) {
        let magnitude = normal.magnitude();
        assert!(
            (magnitude - 1.0).abs() < 0.1,
            "Normal should be unit vector: magnitude={}",
            magnitude
        );
    }

    #[test]
    fn test_estimate_normals_simple() {
        // A simple planar point cloud (XY plane): the unit square plus its center
        let mut coords = UNIT_SQUARE.to_vec();
        coords.push([0.5, 0.5, 0.0]);
        let cloud = cloud_from(&coords);

        let result = estimate_normals(&cloud, 3).unwrap();

        assert_eq!(result.len(), 5);

        // For a planar surface in XY plane, normals should point along Z axis
        for point in result.iter() {
            let normal = point.normal;
            // Normal should be close to (0, 0, 1) or (0, 0, -1)
            assert!(normal.z.abs() > 0.8, "Normal should be primarily in Z direction: {:?}", normal);
        }
    }

    #[test]
    fn test_estimate_normals_empty() {
        let cloud = PointCloud::<Point3f>::new();
        let result = estimate_normals(&cloud, 5).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_estimate_normals_insufficient_k() {
        let cloud = cloud_from(&[[0.0, 0.0, 0.0]]);

        let result = estimate_normals(&cloud, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_estimate_normals_radius() {
        // A 20x20 planar grid with 0.1 spacing for testing radius-based search
        let mut cloud = PointCloud::new();
        for i in 0..20 {
            for j in 0..20 {
                cloud.push(Point3f::new((i as f32) * 0.1, (j as f32) * 0.1, 0.0));
            }
        }

        let result = estimate_normals_radius(&cloud, 0.2, true).unwrap();
        assert_eq!(result.len(), 400);

        // Check that normals are computed and have reasonable values
        let mut z_direction_count = 0;
        for point in result.iter() {
            assert_unit_length(&point.normal);

            // For a planar surface, normals should be primarily in Z direction
            if point.normal.z.abs() > 0.8 {
                z_direction_count += 1;
            }
        }

        // At least 80% of normals should be in Z direction for a planar surface
        let percentage = (z_direction_count as f32 / result.len() as f32) * 100.0;
        assert!(percentage > 80.0, "Only {:.1}% of normals are in Z direction", percentage);
    }

    #[test]
    fn test_estimate_normals_cylinder() {
        // Points on the surface of a unit-radius cylinder around the Z axis
        let mut cloud = PointCloud::new();
        for i in 0..10 {
            for j in 0..10 {
                let angle = (i as f32) * 0.6;
                let height = (j as f32) * 0.2 - 1.0;
                cloud.push(Point3f::new(angle.cos(), angle.sin(), height));
            }
        }

        let config = NormalEstimationConfig {
            k_neighbors: 8, // Increase k for better results
            radius: None,
            consistent_orientation: true,
            viewpoint: Some(Point3f::new(0.0, 0.0, 2.0)), // View from above
        };

        let result = estimate_normals_with_config(&cloud, &config).unwrap();
        assert_eq!(result.len(), 100);

        let mut perpendicular_count = 0;
        let mut toward_axis_count = 0;
        for point in result.iter() {
            assert_unit_length(&point.normal);

            // For a cylinder, normals should be roughly perpendicular to the cylinder axis (Z-axis)
            if point.normal.z.abs() < 0.8 {
                perpendicular_count += 1;
            }

            // Informational only: how many normals face the cylinder axis. The
            // reference direction points from the surface point to the axis.
            let to_axis = Vector3f::new(-point.position.x, -point.position.y, 0.0).normalize();
            if point.normal.dot(&to_axis) > 0.5 {
                toward_axis_count += 1;
            }
        }

        let percentage_perpendicular = (perpendicular_count as f32 / result.len() as f32) * 100.0;
        let percentage_toward_axis = (toward_axis_count as f32 / result.len() as f32) * 100.0;

        println!(
            "Cylinder test: {:.1}% perpendicular to Z, {:.1}% facing the axis",
            percentage_perpendicular, percentage_toward_axis
        );

        // At least 60% of normals should be perpendicular to Z-axis for a cylinder
        assert!(
            percentage_perpendicular > 60.0,
            "Only {:.1}% of normals are perpendicular to Z-axis",
            percentage_perpendicular
        );
    }

    #[test]
    fn test_estimate_normals_orientation_consistency() {
        // A simple planar point cloud
        let cloud = cloud_from(&UNIT_SQUARE);

        // With orientation enabled and the viewpoint on the +Z side of the plane,
        // every normal must face +Z.
        let config_consistent = NormalEstimationConfig {
            k_neighbors: 3,
            radius: None,
            consistent_orientation: true,
            viewpoint: Some(Point3f::new(0.0, 0.0, 1.0)), // View from positive Z
        };

        let result_consistent = estimate_normals_with_config(&cloud, &config_consistent).unwrap();
        assert_eq!(result_consistent.len(), 4);
        for point in result_consistent.iter() {
            assert!(
                point.normal.z > 0.8,
                "Normals should have consistent orientation towards the viewpoint: {:?}",
                point.normal
            );
        }

        // With orientation disabled the sign is arbitrary, but the normals must
        // still be unit vectors along the Z axis.
        let config_inconsistent = NormalEstimationConfig {
            k_neighbors: 3,
            radius: None,
            consistent_orientation: false,
            viewpoint: None,
        };

        let result_inconsistent = estimate_normals_with_config(&cloud, &config_inconsistent).unwrap();
        assert_eq!(result_inconsistent.len(), 4);
        for point in result_inconsistent.iter() {
            assert_unit_length(&point.normal);
            assert!(
                point.normal.z.abs() > 0.8,
                "Normal should be primarily in Z direction: {:?}",
                point.normal
            );
        }
    }

    #[test]
    fn test_find_neighbors() {
        let points = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(2.0, 0.0, 0.0),
        ];

        // Test k-NN: the two points at distance 1 are the nearest to the origin
        let config_knn = NormalEstimationConfig {
            k_neighbors: 2,
            radius: None,
            consistent_orientation: false,
            viewpoint: None,
        };

        let mut neighbors_knn = find_neighbors(&points, 0, &config_knn);
        neighbors_knn.sort_unstable();
        assert_eq!(neighbors_knn, vec![1, 2]);

        // Test radius-based: points at (1,0,0) and (0,1,0) are within radius 1.5
        let config_radius = NormalEstimationConfig {
            k_neighbors: 10,
            radius: Some(1.5),
            consistent_orientation: false,
            viewpoint: None,
        };

        let neighbors_radius = find_neighbors(&points, 0, &config_radius);
        assert_eq!(neighbors_radius, vec![1, 2]);
    }
}