threecrate_algorithms/
normals.rs

1//! Normal estimation algorithms
2
3use threecrate_core::{PointCloud, Result, Point3f, Vector3f, NormalPoint3f, Error};
4use nalgebra::Matrix3;
5use rayon::prelude::*;
6use std::collections::BinaryHeap;
7use std::cmp::Ordering;
8
/// A candidate neighbor: a point index together with its distance to the query.
///
/// Equality and ordering consider only `distance`; `index` is ignored.
///
/// The ordering is deliberately *reversed*: the neighbor with the smaller
/// distance compares as the greater value, so a `BinaryHeap<Neighbor>` keeps
/// the closest neighbor at the top (min-heap on distance).
///
/// Comparisons go through [`f32::total_cmp`], so `Eq` and `Ord` form a genuine
/// total order even when a distance is NaN. As a consequence `0.0` and `-0.0`
/// are treated as distinct distances.
#[derive(Debug, Clone)]
struct Neighbor {
    /// Index of the point within the slice being searched.
    index: usize,
    /// Distance from the query point used for ranking.
    distance: f32,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so that `==` can never disagree with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: smaller distance => greater value.
        other.distance.total_cmp(&self.distance)
    }
}
36
37/// Find k-nearest neighbors using brute force search
38fn find_k_nearest_neighbors(points: &[Point3f], query_idx: usize, k: usize) -> Vec<usize> {
39    let query = &points[query_idx];
40    let mut heap = BinaryHeap::with_capacity(k + 1);
41    
42    for (i, point) in points.iter().enumerate() {
43        if i == query_idx {
44            continue; // Skip the query point itself
45        }
46        
47        let distance = (point - query).magnitude_squared();
48        let neighbor = Neighbor { index: i, distance };
49        
50        if heap.len() < k {
51            heap.push(neighbor);
52        } else if let Some(farthest) = heap.peek() {
53            if neighbor.distance < farthest.distance {
54                heap.pop();
55                heap.push(neighbor);
56            }
57        }
58    }
59    
60    heap.into_iter().map(|n| n.index).collect()
61}
62
63/// Compute normal using PCA on the neighborhood points
64fn compute_normal_pca(points: &[Point3f], indices: &[usize]) -> Vector3f {
65    if indices.len() < 3 {
66        // Default normal if not enough points
67        return Vector3f::new(0.0, 0.0, 1.0);
68    }
69    
70    // Compute centroid
71    let mut centroid = Point3f::origin();
72    for &idx in indices {
73        centroid += points[idx].coords;
74    }
75    centroid /= indices.len() as f32;
76    
77    // Build covariance matrix
78    let mut covariance = Matrix3::zeros();
79    for &idx in indices {
80        let diff = points[idx] - centroid;
81        covariance += diff * diff.transpose();
82    }
83    covariance /= indices.len() as f32;
84    
85    // Find eigenvector corresponding to smallest eigenvalue
86    // This is the normal direction
87    let eigen = covariance.symmetric_eigen();
88    let eigenvalues = eigen.eigenvalues;
89    let eigenvectors = eigen.eigenvectors;
90    
91    // Find index of smallest eigenvalue
92    let mut min_idx = 0;
93    for i in 1..3 {
94        if eigenvalues[i] < eigenvalues[min_idx] {
95            min_idx = i;
96        }
97    }
98    
99    // Return the corresponding eigenvector as normal
100    let normal = eigenvectors.column(min_idx).into();
101    
102    // Ensure consistent orientation (optional: could orient towards viewpoint)
103    normal
104}
105
106/// Estimate normals for a point cloud using k-nearest neighbors
107/// 
108/// This function computes surface normals for each point in the cloud by:
109/// 1. Finding k nearest neighbors for each point
110/// 2. Computing the normal using Principal Component Analysis (PCA)
111/// 3. The normal is the eigenvector corresponding to the smallest eigenvalue
112/// 
113/// # Arguments
114/// * `cloud` - Mutable reference to the point cloud
115/// * `k` - Number of nearest neighbors to use (typically 10-30)
116/// 
117/// # Returns
118/// * `Result<PointCloud<NormalPoint3f>>` - A new point cloud with normals
119pub fn estimate_normals(cloud: &PointCloud<Point3f>, k: usize) -> Result<PointCloud<NormalPoint3f>> {
120    if cloud.is_empty() {
121        return Ok(PointCloud::new());
122    }
123    
124    if k < 3 {
125        return Err(Error::InvalidData("k must be at least 3".to_string()));
126    }
127    
128    let points = &cloud.points;
129    
130    // Compute normals in parallel
131    let normals: Vec<NormalPoint3f> = (0..points.len())
132        .into_par_iter()
133        .map(|i| {
134            let neighbors = find_k_nearest_neighbors(points, i, k);
135            let mut neighborhood = vec![i]; // Include the point itself
136            neighborhood.extend(neighbors);
137            
138            let normal = compute_normal_pca(points, &neighborhood);
139            
140            NormalPoint3f {
141                position: points[i],
142                normal,
143            }
144        })
145        .collect();
146    
147    Ok(PointCloud::from_points(normals))
148}
149
150/// Estimate normals and modify the input cloud in-place (legacy API)
151/// This function is deprecated in favor of the version that returns a new cloud
152#[deprecated(note = "Use estimate_normals instead which returns a new point cloud")]
153pub fn estimate_normals_inplace(_cloud: &mut PointCloud<Point3f>, k: usize) -> Result<()> {
154    // This would require converting the point cloud type, which isn't straightforward
155    // with the current type system. The new API is cleaner.
156    let _ = k;
157    Err(Error::Unsupported("Use estimate_normals instead".to_string()))
158}
159
#[cfg(test)]
mod tests {
    use super::*;

    /// Coplanar points in the XY plane must produce normals along the Z axis.
    #[test]
    fn test_estimate_normals_simple() {
        // Four corners of the unit square plus its center, all at z = 0.
        let xy_samples: [(f32, f32); 5] = [
            (0.0, 0.0),
            (1.0, 0.0),
            (0.0, 1.0),
            (1.0, 1.0),
            (0.5, 0.5),
        ];

        let mut cloud = PointCloud::new();
        for &(x, y) in xy_samples.iter() {
            cloud.push(Point3f::new(x, y, 0.0));
        }

        let result = estimate_normals(&cloud, 3).unwrap();
        assert_eq!(result.len(), 5);

        // Either orientation is acceptable: (0, 0, 1) or (0, 0, -1).
        for point in result.iter() {
            let normal = point.normal;
            assert!(normal.z.abs() > 0.8, "Normal should be primarily in Z direction: {:?}", normal);
        }
    }

    /// An empty cloud is not an error; it yields an empty result.
    #[test]
    fn test_estimate_normals_empty() {
        let cloud = PointCloud::<Point3f>::new();
        let result = estimate_normals(&cloud, 5).unwrap();
        assert!(result.is_empty());
    }

    /// A neighbor count below three is rejected for a non-empty cloud.
    #[test]
    fn test_estimate_normals_insufficient_k() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));

        let result = estimate_normals(&cloud, 2);
        assert!(result.is_err());
    }
}