1use threecrate_core::{PointCloud, Result, Point3f, Vector3f, NormalPoint3f, Error};
4use nalgebra::Matrix3;
5use rayon::prelude::*;
6use std::collections::BinaryHeap;
7use std::cmp::Ordering;
8
9#[derive(Debug, Clone)]
11pub struct NormalEstimationConfig {
12 pub k_neighbors: usize,
14 pub radius: Option<f32>,
16 pub consistent_orientation: bool,
18 pub viewpoint: Option<Point3f>,
20}
21
22impl Default for NormalEstimationConfig {
23 fn default() -> Self {
24 Self {
25 k_neighbors: 10,
26 radius: None,
27 consistent_orientation: true,
28 viewpoint: None,
29 }
30 }
31}
32
/// A candidate neighbour found while scanning a point cloud for a query point.
///
/// Ordering is by `distance` ascending, using [`f32::total_cmp`] so that it is a lawful
/// total order even with NaN, with ties broken by `index`. As a consequence a
/// `BinaryHeap<Neighbor>` is a max-heap whose top (`peek()`) is the farthest candidate,
/// which is what a bounded k-nearest-neighbour selection needs in order to evict it.
#[derive(Debug, Clone)]
struct Neighbor {
    /// Index of the candidate in the scanned point slice.
    index: usize,
    /// Distance to the query point (squared Euclidean where `find_k_nearest_neighbors`
    /// fills it in).
    distance: f32,
}

impl PartialEq for Neighbor {
    /// Equality agrees with `Ord::cmp`, as required for a lawful `Eq`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Larger distance compares greater, so heap tops are the farthest candidates.
    /// The previous implementation reversed this, which turned `BinaryHeap` into a
    /// min-heap and made the k-nearest selection keep the wrong points.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.index.cmp(&other.index))
    }
}
60
61fn find_k_nearest_neighbors(points: &[Point3f], query_idx: usize, k: usize) -> Vec<usize> {
63 let query = &points[query_idx];
64 let mut heap = BinaryHeap::with_capacity(k + 1);
65
66 for (i, point) in points.iter().enumerate() {
67 if i == query_idx {
68 continue; }
70
71 let distance = (point - query).magnitude_squared();
72 let neighbor = Neighbor { index: i, distance };
73
74 if heap.len() < k {
75 heap.push(neighbor);
76 } else if let Some(farthest) = heap.peek() {
77 if neighbor.distance < farthest.distance {
78 heap.pop();
79 heap.push(neighbor);
80 }
81 }
82 }
83
84 heap.into_iter().map(|n| n.index).collect()
85}
86
87fn find_radius_neighbors(points: &[Point3f], query_idx: usize, radius: f32) -> Vec<usize> {
89 let query = &points[query_idx];
90 let radius_squared = radius * radius;
91
92 points.iter()
93 .enumerate()
94 .filter(|(i, point)| {
95 *i != query_idx && (**point - query).magnitude_squared() <= radius_squared
96 })
97 .map(|(i, _)| i)
98 .collect()
99}
100
101fn find_neighbors(points: &[Point3f], query_idx: usize, config: &NormalEstimationConfig) -> Vec<usize> {
103 if let Some(radius) = config.radius {
104 find_radius_neighbors(points, query_idx, radius)
106 } else {
107 find_k_nearest_neighbors(points, query_idx, config.k_neighbors)
109 }
110}
111
112fn compute_normal_pca(points: &[Point3f], indices: &[usize]) -> Vector3f {
114 if indices.len() < 3 {
115 return Vector3f::new(0.0, 0.0, 1.0);
117 }
118
119 let mut centroid = Point3f::origin();
121 for &idx in indices {
122 centroid += points[idx].coords;
123 }
124 centroid /= indices.len() as f32;
125
126 let mut covariance = Matrix3::zeros();
128 for &idx in indices {
129 let diff = points[idx] - centroid;
130 covariance += diff * diff.transpose();
131 }
132 covariance /= indices.len() as f32;
133
134 let eigen = covariance.symmetric_eigen();
137 let eigenvalues = eigen.eigenvalues;
138 let eigenvectors = eigen.eigenvectors;
139
140 let mut min_idx = 0;
142 for i in 1..3 {
143 if eigenvalues[i] < eigenvalues[min_idx] {
144 min_idx = i;
145 }
146 }
147
148
149
150 let mut normal: Vector3f = eigenvectors.column(min_idx).into();
152
153 let magnitude = normal.magnitude();
155 if magnitude > 1e-6 {
156 normal /= magnitude;
157 } else {
158 normal = Vector3f::new(0.0, 0.0, 1.0);
159 }
160
161 normal
162}
163
164fn orient_normal_towards_viewpoint(normal: Vector3f, point: Point3f, viewpoint: Point3f) -> Vector3f {
166 let to_viewpoint = (viewpoint - point).normalize();
167 let dot_product = normal.dot(&to_viewpoint);
168
169 if dot_product < 0.0 {
171 -normal
172 } else {
173 normal
174 }
175}
176
177pub fn estimate_normals(cloud: &PointCloud<Point3f>, k: usize) -> Result<PointCloud<NormalPoint3f>> {
192 let config = NormalEstimationConfig {
193 k_neighbors: k,
194 ..Default::default()
195 };
196 estimate_normals_with_config(cloud, &config)
197}
198
199pub fn estimate_normals_with_config(
208 cloud: &PointCloud<Point3f>,
209 config: &NormalEstimationConfig
210) -> Result<PointCloud<NormalPoint3f>> {
211 if cloud.is_empty() {
212 return Ok(PointCloud::new());
213 }
214
215 if config.k_neighbors < 3 {
216 return Err(Error::InvalidData("k_neighbors must be at least 3".to_string()));
217 }
218
219 let points = &cloud.points;
220
221 let viewpoint = config.viewpoint.unwrap_or_else(|| {
223 let mut min_x = points[0].x;
225 let mut min_y = points[0].y;
226 let mut min_z = points[0].z;
227 let mut max_x = points[0].x;
228 let mut max_y = points[0].y;
229 let mut max_z = points[0].z;
230
231 for point in points {
232 min_x = min_x.min(point.x);
233 min_y = min_y.min(point.y);
234 min_z = min_z.min(point.z);
235 max_x = max_x.max(point.x);
236 max_y = max_y.max(point.y);
237 max_z = max_z.max(point.z);
238 }
239
240 let center = Point3f::new(
241 (min_x + max_x) / 2.0,
242 (min_y + max_y) / 2.0,
243 (min_z + max_z) / 2.0,
244 );
245 let extent = ((max_x - min_x).powi(2) + (max_y - min_y).powi(2) + (max_z - min_z).powi(2)).sqrt();
246
247 center + Vector3f::new(0.0, 0.0, extent)
249 });
250
251 let normals: Vec<NormalPoint3f> = (0..points.len())
253 .into_par_iter()
254 .map(|i| {
255 let neighbors = find_neighbors(points, i, config);
256
257 let mut neighborhood = neighbors;
259
260 if config.radius.is_some() && neighborhood.len() < config.k_neighbors {
262 neighborhood = find_k_nearest_neighbors(points, i, config.k_neighbors);
263 }
264
265 if neighborhood.len() < 3 {
267 neighborhood = find_k_nearest_neighbors(points, i, config.k_neighbors.max(5));
269 }
270
271 let mut normal = compute_normal_pca(points, &neighborhood);
272
273 if config.consistent_orientation {
275 normal = orient_normal_towards_viewpoint(normal, points[i], viewpoint);
276 }
277
278 NormalPoint3f {
279 position: points[i],
280 normal,
281 }
282 })
283 .collect();
284
285 Ok(PointCloud::from_points(normals))
286}
287
288pub fn estimate_normals_radius(
298 cloud: &PointCloud<Point3f>,
299 radius: f32,
300 consistent_orientation: bool
301) -> Result<PointCloud<NormalPoint3f>> {
302 let config = NormalEstimationConfig {
303 k_neighbors: 10, radius: Some(radius),
305 consistent_orientation,
306 viewpoint: None,
307 };
308 estimate_normals_with_config(cloud, &config)
309}
310
311#[deprecated(note = "Use estimate_normals instead which returns a new point cloud")]
314pub fn estimate_normals_inplace(_cloud: &mut PointCloud<Point3f>, k: usize) -> Result<()> {
315 let _ = k;
318 Err(Error::Unsupported("Use estimate_normals instead".to_string()))
319}
320
#[cfg(test)]
mod tests {
    use super::*;

    // Five coplanar points on z = 0: every estimated normal should be (close to) ±Z.
    #[test]
    fn test_estimate_normals_simple() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));
        cloud.push(Point3f::new(0.0, 1.0, 0.0));
        cloud.push(Point3f::new(1.0, 1.0, 0.0));
        cloud.push(Point3f::new(0.5, 0.5, 0.0));

        let result = estimate_normals(&cloud, 3).unwrap();

        assert_eq!(result.len(), 5);

        for point in result.iter() {
            let normal = point.normal;
            assert!(normal.z.abs() > 0.8, "Normal should be primarily in Z direction: {:?}", normal);
        }
    }

    // An empty cloud short-circuits to an empty result.
    #[test]
    fn test_estimate_normals_empty() {
        let cloud = PointCloud::<Point3f>::new();
        let result = estimate_normals(&cloud, 5).unwrap();
        assert!(result.is_empty());
    }

    // k < 3 is rejected for a non-empty cloud.
    #[test]
    fn test_estimate_normals_insufficient_k() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));

        let result = estimate_normals(&cloud, 2);
        assert!(result.is_err());
    }

    // 20x20 planar grid with 0.1 spacing, radius 0.2: normals must be unit length and
    // mostly along Z.
    #[test]
    fn test_estimate_normals_radius() {
        let mut cloud = PointCloud::new();
        for i in 0..20 {
            for j in 0..20 {
                let x = (i as f32) * 0.1;
                let y = (j as f32) * 0.1;
                let z = 0.0;
                cloud.push(Point3f::new(x, y, z));
            }
        }

        let result = estimate_normals_radius(&cloud, 0.2, true).unwrap();
        assert_eq!(result.len(), 400);

        let mut z_direction_count = 0;
        for point in result.iter() {
            let normal_magnitude = point.normal.magnitude();
            assert!((normal_magnitude - 1.0).abs() < 0.1, "Normal should be unit vector: magnitude={}", normal_magnitude);

            if point.normal.z.abs() > 0.8 {
                z_direction_count += 1;
            }
        }

        let percentage = (z_direction_count as f32 / result.len() as f32) * 100.0;
        assert!(percentage > 80.0, "Only {:.1}% of normals are in Z direction", percentage);
    }

    // Unit-radius cylinder around the Z axis (10 angles x 10 heights): most normals should
    // be roughly perpendicular to Z.
    #[test]
    fn test_estimate_normals_cylinder() {
        let mut cloud = PointCloud::new();
        for i in 0..10 {
            for j in 0..10 {
                let angle = (i as f32) * 0.6;
                let height = (j as f32) * 0.2 - 1.0;
                let x = angle.cos();
                let y = angle.sin();
                let z = height;
                cloud.push(Point3f::new(x, y, z));
            }
        }

        // The viewpoint lies on the cylinder axis, so orientation flips the (roughly radial)
        // normals to point toward the axis, i.e. inward.
        let config = NormalEstimationConfig {
            k_neighbors: 8,
            radius: None,
            consistent_orientation: true,
            viewpoint: Some(Point3f::new(0.0, 0.0, 2.0)),
        };

        let result = estimate_normals_with_config(&cloud, &config).unwrap();
        assert_eq!(result.len(), 100);

        let mut perpendicular_count = 0;
        let mut outward_count = 0;
        for point in result.iter() {
            let normal_magnitude = point.normal.magnitude();
            assert!((normal_magnitude - 1.0).abs() < 0.1, "Normal should be unit vector: magnitude={}", normal_magnitude);

            let dot_with_z = point.normal.z.abs();
            if dot_with_z < 0.8 {
                perpendicular_count += 1;
            }

            // NOTE(review): `to_center` points from the surface toward the Z axis, so
            // `outward_count` actually counts normals pointing *inward*. The value is only
            // printed, never asserted; consider renaming.
            let to_center = Vector3f::new(-point.position.x, -point.position.y, 0.0).normalize();
            let dot_outward = point.normal.dot(&to_center);
            if dot_outward > 0.5 {
                outward_count += 1;
            }
        }

        let percentage_perpendicular = (perpendicular_count as f32 / result.len() as f32) * 100.0;
        let percentage_outward = (outward_count as f32 / result.len() as f32) * 100.0;

        println!("Cylinder test: {:.1}% perpendicular to Z, {:.1}% pointing outward",
            percentage_perpendicular, percentage_outward);

        assert!(percentage_perpendicular > 60.0, "Only {:.1}% of normals are perpendicular to Z-axis", percentage_perpendicular);
    }

    // Four coplanar points with a viewpoint above the plane: all oriented normals must
    // share the sign of their Z component. The non-oriented run only checks that it succeeds.
    #[test]
    fn test_estimate_normals_orientation_consistency() {
        let mut cloud = PointCloud::new();
        cloud.push(Point3f::new(0.0, 0.0, 0.0));
        cloud.push(Point3f::new(1.0, 0.0, 0.0));
        cloud.push(Point3f::new(0.0, 1.0, 0.0));
        cloud.push(Point3f::new(1.0, 1.0, 0.0));

        let config_consistent = NormalEstimationConfig {
            k_neighbors: 3,
            radius: None,
            consistent_orientation: true,
            viewpoint: Some(Point3f::new(0.0, 0.0, 1.0)),
        };

        let result_consistent = estimate_normals_with_config(&cloud, &config_consistent).unwrap();

        let config_inconsistent = NormalEstimationConfig {
            k_neighbors: 3,
            radius: None,
            consistent_orientation: false,
            viewpoint: None,
        };

        let _result_inconsistent = estimate_normals_with_config(&cloud, &config_inconsistent).unwrap();

        let first_normal_consistent = result_consistent.points[0].normal.z;
        for point in result_consistent.iter() {
            assert!((point.normal.z * first_normal_consistent) > 0.0,
                "Normals should have consistent orientation");
        }

        println!("Consistent orientation test completed");
    }

    // Neighbour counts: KNN returns exactly k; a radius of 1.5 around the origin picks up
    // the two points at distance 1 but not the one at distance 2.
    #[test]
    fn test_find_neighbors() {
        let points = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(2.0, 0.0, 0.0),
        ];

        let config_knn = NormalEstimationConfig {
            k_neighbors: 2,
            radius: None,
            consistent_orientation: false,
            viewpoint: None,
        };

        let neighbors_knn = find_neighbors(&points, 0, &config_knn);
        assert_eq!(neighbors_knn.len(), 2);

        let config_radius = NormalEstimationConfig {
            k_neighbors: 10,
            radius: Some(1.5),
            consistent_orientation: false,
            viewpoint: None,
        };

        let neighbors_radius = find_neighbors(&points, 0, &config_radius);
        assert_eq!(neighbors_radius.len(), 2);
    }
}