1use threecrate_core::{PointCloud, Result, Point3f, Vector3f, Error, Isometry3};
4use nalgebra::{Matrix3, UnitQuaternion, Translation3};
5use rayon::prelude::*;
6
7
8
9
/// Outcome of an ICP registration run, as produced by the ICP functions in this module.
#[derive(Debug, Clone)]
pub struct ICPResult {
    /// Rigid transform estimated by the run; applying it to source points aligns them with
    /// the target cloud.
    pub transformation: Isometry3<f32>,
    /// Mean squared Euclidean distance between matched source/target point pairs.
    pub mse: f32,
    /// Number of iterations performed.
    pub iterations: usize,
    /// `true` if the run stopped because its convergence criterion was met, `false` if it
    /// stopped after exhausting the iteration limit.
    pub converged: bool,
    /// Matched `(source_index, target_index)` pairs from the last iteration.
    pub correspondences: Vec<(usize, usize)>,
}
24
25fn find_correspondences(
27 source: &[Point3f],
28 target: &[Point3f],
29 max_distance: Option<f32>,
30) -> Vec<Option<(usize, f32)>> {
31 source
32 .par_iter()
33 .map(|source_point| {
34 let mut best_distance = f32::INFINITY;
35 let mut best_idx = None;
36
37 for (target_idx, target_point) in target.iter().enumerate() {
38 let distance = (source_point - target_point).magnitude();
39
40 if distance < best_distance {
41 best_distance = distance;
42 best_idx = Some(target_idx);
43 }
44 }
45
46 if let Some(max_dist) = max_distance {
48 if best_distance > max_dist {
49 return None;
50 }
51 }
52
53 best_idx.map(|idx| (idx, best_distance))
54 })
55 .collect()
56}
57
58fn compute_transformation(
60 source_points: &[Point3f],
61 target_points: &[Point3f],
62) -> Result<Isometry3<f32>> {
63 if source_points.len() != target_points.len() || source_points.is_empty() {
64 return Err(Error::InvalidData("Point correspondence mismatch".to_string()));
65 }
66
67 let n = source_points.len() as f32;
68
69 let source_centroid = source_points.iter().fold(Point3f::origin(), |acc, p| acc + p.coords) / n;
71 let target_centroid = target_points.iter().fold(Point3f::origin(), |acc, p| acc + p.coords) / n;
72
73 let mut h = Matrix3::zeros();
75 for (src, tgt) in source_points.iter().zip(target_points.iter()) {
76 let p = src - source_centroid;
77 let q = tgt - target_centroid;
78 h += p * q.transpose();
79 }
80
81 let svd = h.svd(true, true);
83 let u = svd.u.ok_or_else(|| Error::Algorithm("SVD U matrix not available".to_string()))?;
84 let v_t = svd.v_t.ok_or_else(|| Error::Algorithm("SVD V^T matrix not available".to_string()))?;
85
86 let mut r = v_t.transpose() * u.transpose();
88
89 if r.determinant() < 0.0 {
91 let mut v_t_corrected = v_t;
92 v_t_corrected.set_row(2, &(-v_t.row(2)));
93 r = v_t_corrected.transpose() * u.transpose();
94 }
95
96 let rotation = UnitQuaternion::from_matrix(&r);
98
99 let translation = target_centroid - rotation * source_centroid;
101
102 Ok(Isometry3::from_parts(
103 Translation3::new(translation.x, translation.y, translation.z),
104 rotation,
105 ))
106}
107
108fn compute_mse(
110 source_points: &[Point3f],
111 target_points: &[Point3f],
112) -> f32 {
113 if source_points.is_empty() {
114 return 0.0;
115 }
116
117 let sum_squared_error: f32 = source_points
118 .iter()
119 .zip(target_points.iter())
120 .map(|(src, tgt)| (src - tgt).magnitude_squared())
121 .sum();
122
123 sum_squared_error / source_points.len() as f32
124}
125
126pub fn icp(
139 source: &PointCloud<Point3f>,
140 target: &PointCloud<Point3f>,
141 init: Isometry3<f32>,
142 max_iters: usize,
143) -> Isometry3<f32> {
144 match icp_detailed(source, target, init, max_iters, None, 1e-6) {
145 Ok(result) => result.transformation,
146 Err(_) => init, }
148}
149
150pub fn icp_detailed(
165 source: &PointCloud<Point3f>,
166 target: &PointCloud<Point3f>,
167 init: Isometry3<f32>,
168 max_iters: usize,
169 max_correspondence_distance: Option<f32>,
170 convergence_threshold: f32,
171) -> Result<ICPResult> {
172 if source.is_empty() || target.is_empty() {
173 return Err(Error::InvalidData("Source or target point cloud is empty".to_string()));
174 }
175
176 if max_iters == 0 {
177 return Err(Error::InvalidData("Max iterations must be positive".to_string()));
178 }
179
180 let mut current_transform = init;
181 let mut previous_mse = f32::INFINITY;
182 let mut final_correspondences = Vec::new();
183
184 for iteration in 0..max_iters {
185 let transformed_source: Vec<Point3f> = source
187 .points
188 .iter()
189 .map(|point| current_transform * point)
190 .collect();
191
192 let correspondences = find_correspondences(
194 &transformed_source,
195 &target.points,
196 max_correspondence_distance,
197 );
198
199 let mut valid_source_points = Vec::new();
201 let mut valid_target_points = Vec::new();
202 let mut corr_pairs = Vec::new();
203
204 for (src_idx, correspondence) in correspondences.iter().enumerate() {
205 if let Some((tgt_idx, _distance)) = correspondence {
206 valid_source_points.push(transformed_source[src_idx]);
207 valid_target_points.push(target.points[*tgt_idx]);
208 corr_pairs.push((src_idx, *tgt_idx));
209 }
210 }
211
212 if valid_source_points.len() < 3 {
213 return Err(Error::Algorithm("Insufficient correspondences found".to_string()));
214 }
215
216 let delta_transform = compute_transformation(&valid_source_points, &valid_target_points)?;
218
219 current_transform = delta_transform * current_transform;
221
222 let current_mse = compute_mse(&valid_source_points, &valid_target_points);
224
225 let mse_change = (previous_mse - current_mse).abs();
227 if mse_change < convergence_threshold {
228 return Ok(ICPResult {
229 transformation: current_transform,
230 mse: current_mse,
231 iterations: iteration + 1,
232 converged: true,
233 correspondences: corr_pairs,
234 });
235 }
236
237 previous_mse = current_mse;
238 final_correspondences = corr_pairs;
239 }
240
241 let transformed_source: Vec<Point3f> = source
243 .points
244 .iter()
245 .map(|point| current_transform * point)
246 .collect();
247
248 let final_mse = if !final_correspondences.is_empty() {
249 let valid_source: Vec<Point3f> = final_correspondences
250 .iter()
251 .map(|(src_idx, _)| transformed_source[*src_idx])
252 .collect();
253 let valid_target: Vec<Point3f> = final_correspondences
254 .iter()
255 .map(|(_, tgt_idx)| target.points[*tgt_idx])
256 .collect();
257 compute_mse(&valid_source, &valid_target)
258 } else {
259 previous_mse
260 };
261
262 Ok(ICPResult {
263 transformation: current_transform,
264 mse: final_mse,
265 iterations: max_iters,
266 converged: false,
267 correspondences: final_correspondences,
268 })
269}
270
271#[deprecated(note = "Use icp instead which matches the standard API")]
273pub fn icp_legacy(
274 source: &PointCloud<Point3f>,
275 target: &PointCloud<Point3f>,
276 max_iterations: usize,
277 threshold: f32,
278) -> Result<(threecrate_core::Transform3D, f32)> {
279 let init = Isometry3::identity();
280 let result = icp_detailed(source, target, init, max_iterations, Some(threshold), 1e-6)?;
281
282 let transform = threecrate_core::Transform3D::from(result.transformation);
284
285 Ok((transform, result.mse))
286}
287
288pub fn icp_point_to_plane(
290 source: &PointCloud<Point3f>,
291 target: &PointCloud<Point3f>,
292 _target_normals: &[Vector3f],
293 init: Isometry3<f32>,
294 max_iters: usize,
295) -> Result<ICPResult> {
296 icp_detailed(source, target, init, max_iters, None, 1e-6)
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    use nalgebra::UnitQuaternion;

    /// Identical clouds with an identity initial guess should converge within a few
    /// iterations with (near-)zero error.
    #[test]
    fn test_icp_identity_transformation() {
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        for i in 0..10 {
            let point = Point3f::new(i as f32, (i * 2) as f32, (i * 3) as f32);
            source.push(point);
            target.push(point);
        }

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 10, None, 1e-6).unwrap();

        assert!(result.converged);
        assert!(result.mse < 1e-6);
        assert!(result.iterations <= 3);
    }

    /// Pure translation case.
    ///
    /// NOTE: the points are collinear along (1, 2, 3) and the translation is exactly one step
    /// along that same line, so target point i coincides with source point i + 1. Nearest-
    /// neighbour matching is therefore ambiguous, which is why only a non-trivial translation
    /// and a loose MSE bound are asserted, rather than the exact offset.
    #[test]
    fn test_icp_translation() {
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        let translation = Vector3f::new(1.0, 2.0, 3.0);

        for i in 0..10 {
            let source_point = Point3f::new(i as f32, (i * 2) as f32, (i * 3) as f32);
            let target_point = source_point + translation;
            source.push(source_point);
            target.push(target_point);
        }

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 50, None, 1e-6).unwrap();

        let computed_translation = result.transformation.translation.vector;
        assert!(computed_translation.magnitude() > 0.05, "Translation magnitude too small: {}", computed_translation.magnitude());

        assert!(result.mse < 2.0);
    }

    /// 45° rotation about z of a planar (z = 0) cloud. Only a loose MSE bound is checked;
    /// the recovered rotation itself is not asserted.
    #[test]
    fn test_icp_rotation() {
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        let rotation = UnitQuaternion::from_axis_angle(&Vector3f::z_axis(), std::f32::consts::FRAC_PI_4);

        for i in 0..20 {
            let source_point = Point3f::new(i as f32, (i % 5) as f32, 0.0);
            let target_point = rotation * source_point;
            source.push(source_point);
            target.push(target_point);
        }

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 100, None, 1e-6).unwrap();

        assert!(result.mse < 1.0, "MSE too high: {}", result.mse);
    }

    /// Single-point clouds yield fewer than three correspondences, which must be an error.
    #[test]
    fn test_icp_insufficient_points() {
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        source.push(Point3f::new(0.0, 0.0, 0.0));
        target.push(Point3f::new(1.0, 1.0, 1.0));

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 10, None, 1e-6);

        assert!(result.is_err());
    }

    /// Exercises the `icp` wrapper.
    ///
    /// If `icp_detailed` failed, the wrapper would return `init` (the identity, zero
    /// translation), so this assertion also implicitly checks that no error occurred.
    #[test]
    fn test_icp_api_compatibility() {
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        for i in 0..5 {
            let point = Point3f::new(i as f32, i as f32, 0.0);
            source.push(point);
            target.push(point + Vector3f::new(1.0, 0.0, 0.0));
        }

        let init = Isometry3::identity();
        let transform = icp(&source, &target, init, 20);

        assert!(transform.translation.vector.magnitude() > 0.5);
    }

    /// Each target point is its source point offset by (0.1, 0.1, 0), so with no distance cap
    /// every source point must receive a match.
    #[test]
    fn test_correspondence_finding() {
        let source = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
        ];

        let target = vec![
            Point3f::new(0.1, 0.1, 0.0),
            Point3f::new(1.1, 0.1, 0.0),
            Point3f::new(0.1, 1.1, 0.0),
        ];

        let correspondences = find_correspondences(&source, &target, None);

        assert_eq!(correspondences.len(), 3);
        assert!(correspondences[0].is_some());
        assert!(correspondences[1].is_some());
        assert!(correspondences[2].is_some());
    }
}