threecrate_algorithms/
registration.rs

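//! Point cloud registration algorithms: point-to-point ICP built on brute-force
//! nearest-neighbour matching and an SVD-based rigid alignment step.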
2
3use threecrate_core::{PointCloud, Result, Point3f, Vector3f, Error, Isometry3};
4use nalgebra::{Matrix3, UnitQuaternion, Translation3};
5use rayon::prelude::*;
6
7
8
9
10/// Result of ICP registration
11#[derive(Debug, Clone)]
12pub struct ICPResult {
13    /// Final transformation
14    pub transformation: Isometry3<f32>,
15    /// Final mean squared error
16    pub mse: f32,
17    /// Number of iterations performed
18    pub iterations: usize,
19    /// Whether convergence was achieved
20    pub converged: bool,
21    /// Correspondences found in the last iteration
22    pub correspondences: Vec<(usize, usize)>,
23}
24
25/// Find the closest point in target cloud for each point in source cloud
26fn find_correspondences(
27    source: &[Point3f],
28    target: &[Point3f],
29    max_distance: Option<f32>,
30) -> Vec<Option<(usize, f32)>> {
31    source
32        .par_iter()
33        .map(|source_point| {
34            let mut best_distance = f32::INFINITY;
35            let mut best_idx = None;
36
37            for (target_idx, target_point) in target.iter().enumerate() {
38                let distance = (source_point - target_point).magnitude();
39                
40                if distance < best_distance {
41                    best_distance = distance;
42                    best_idx = Some(target_idx);
43                }
44            }
45
46            // Filter out correspondences that are too far
47            if let Some(max_dist) = max_distance {
48                if best_distance > max_dist {
49                    return None;
50                }
51            }
52
53            best_idx.map(|idx| (idx, best_distance))
54        })
55        .collect()
56}
57
58/// Compute the optimal transformation using SVD
59fn compute_transformation(
60    source_points: &[Point3f],
61    target_points: &[Point3f],
62) -> Result<Isometry3<f32>> {
63    if source_points.len() != target_points.len() || source_points.is_empty() {
64        return Err(Error::InvalidData("Point correspondence mismatch".to_string()));
65    }
66
67    let n = source_points.len() as f32;
68
69    // Compute centroids
70    let source_centroid = source_points.iter().fold(Point3f::origin(), |acc, p| acc + p.coords) / n;
71    let target_centroid = target_points.iter().fold(Point3f::origin(), |acc, p| acc + p.coords) / n;
72
73    // Compute covariance matrix H
74    let mut h = Matrix3::zeros();
75    for (src, tgt) in source_points.iter().zip(target_points.iter()) {
76        let p = src - source_centroid;
77        let q = tgt - target_centroid;
78        h += p * q.transpose();
79    }
80
81    // SVD decomposition
82    let svd = h.svd(true, true);
83    let u = svd.u.ok_or_else(|| Error::Algorithm("SVD U matrix not available".to_string()))?;
84    let v_t = svd.v_t.ok_or_else(|| Error::Algorithm("SVD V^T matrix not available".to_string()))?;
85
86    // Compute rotation matrix
87    let mut r = v_t.transpose() * u.transpose();
88
89    // Ensure proper rotation (det(R) = 1)
90    if r.determinant() < 0.0 {
91        let mut v_t_corrected = v_t;
92        v_t_corrected.set_row(2, &(-v_t.row(2)));
93        r = v_t_corrected.transpose() * u.transpose();
94    }
95
96    // Convert to unit quaternion
97    let rotation = UnitQuaternion::from_matrix(&r);
98
99    // Compute translation
100    let translation = target_centroid - rotation * source_centroid;
101
102    Ok(Isometry3::from_parts(
103        Translation3::new(translation.x, translation.y, translation.z),
104        rotation,
105    ))
106}
107
108/// Compute mean squared error between corresponding points
109fn compute_mse(
110    source_points: &[Point3f],
111    target_points: &[Point3f],
112) -> f32 {
113    if source_points.is_empty() {
114        return 0.0;
115    }
116
117    let sum_squared_error: f32 = source_points
118        .iter()
119        .zip(target_points.iter())
120        .map(|(src, tgt)| (src - tgt).magnitude_squared())
121        .sum();
122
123    sum_squared_error / source_points.len() as f32
124}
125
126/// ICP (Iterative Closest Point) registration - Main function matching requested API
127/// 
128/// This function performs point cloud registration using the ICP algorithm.
129/// 
130/// # Arguments
131/// * `source` - Source point cloud to be aligned
132/// * `target` - Target point cloud to align to
133/// * `init` - Initial transformation estimate
134/// * `max_iters` - Maximum number of iterations
135/// 
136/// # Returns
137/// * `Isometry3<f32>` - Final transformation that aligns source to target
138pub fn icp(
139    source: &PointCloud<Point3f>,
140    target: &PointCloud<Point3f>,
141    init: Isometry3<f32>,
142    max_iters: usize,
143) -> Isometry3<f32> {
144    match icp_detailed(source, target, init, max_iters, None, 1e-6) {
145        Ok(result) => result.transformation,
146        Err(_) => init, // Return initial transformation on error
147    }
148}
149
150/// Detailed ICP registration with comprehensive options and result
151/// 
152/// This function provides full control over ICP parameters and returns detailed results.
153/// 
154/// # Arguments
155/// * `source` - Source point cloud to be aligned
156/// * `target` - Target point cloud to align to
157/// * `init` - Initial transformation estimate
158/// * `max_iters` - Maximum number of iterations
159/// * `max_correspondence_distance` - Maximum distance for valid correspondences (None = no limit)
160/// * `convergence_threshold` - MSE change threshold for convergence
161/// 
162/// # Returns
163/// * `Result<ICPResult>` - Detailed ICP result including transformation, error, and convergence info
164pub fn icp_detailed(
165    source: &PointCloud<Point3f>,
166    target: &PointCloud<Point3f>,
167    init: Isometry3<f32>,
168    max_iters: usize,
169    max_correspondence_distance: Option<f32>,
170    convergence_threshold: f32,
171) -> Result<ICPResult> {
172    if source.is_empty() || target.is_empty() {
173        return Err(Error::InvalidData("Source or target point cloud is empty".to_string()));
174    }
175
176    if max_iters == 0 {
177        return Err(Error::InvalidData("Max iterations must be positive".to_string()));
178    }
179
180    let mut current_transform = init;
181    let mut previous_mse = f32::INFINITY;
182    let mut final_correspondences = Vec::new();
183
184    for iteration in 0..max_iters {
185        // Transform source points with current transformation
186        let transformed_source: Vec<Point3f> = source
187            .points
188            .iter()
189            .map(|point| current_transform * point)
190            .collect();
191
192        // Find correspondences
193        let correspondences = find_correspondences(
194            &transformed_source,
195            &target.points,
196            max_correspondence_distance,
197        );
198
199        // Extract valid correspondences
200        let mut valid_source_points = Vec::new();
201        let mut valid_target_points = Vec::new();
202        let mut corr_pairs = Vec::new();
203
204        for (src_idx, correspondence) in correspondences.iter().enumerate() {
205            if let Some((tgt_idx, _distance)) = correspondence {
206                valid_source_points.push(transformed_source[src_idx]);
207                valid_target_points.push(target.points[*tgt_idx]);
208                corr_pairs.push((src_idx, *tgt_idx));
209            }
210        }
211
212        if valid_source_points.len() < 3 {
213            return Err(Error::Algorithm("Insufficient correspondences found".to_string()));
214        }
215
216        // Compute transformation for this iteration
217        let delta_transform = compute_transformation(&valid_source_points, &valid_target_points)?;
218
219        // Update transformation
220        current_transform = delta_transform * current_transform;
221
222        // Compute MSE
223        let current_mse = compute_mse(&valid_source_points, &valid_target_points);
224
225        // Check for convergence
226        let mse_change = (previous_mse - current_mse).abs();
227        if mse_change < convergence_threshold {
228            return Ok(ICPResult {
229                transformation: current_transform,
230                mse: current_mse,
231                iterations: iteration + 1,
232                converged: true,
233                correspondences: corr_pairs,
234            });
235        }
236
237        previous_mse = current_mse;
238        final_correspondences = corr_pairs;
239    }
240
241    // Final transformation after all iterations
242    let transformed_source: Vec<Point3f> = source
243        .points
244        .iter()
245        .map(|point| current_transform * point)
246        .collect();
247
248    let final_mse = if !final_correspondences.is_empty() {
249        let valid_source: Vec<Point3f> = final_correspondences
250            .iter()
251            .map(|(src_idx, _)| transformed_source[*src_idx])
252            .collect();
253        let valid_target: Vec<Point3f> = final_correspondences
254            .iter()
255            .map(|(_, tgt_idx)| target.points[*tgt_idx])
256            .collect();
257        compute_mse(&valid_source, &valid_target)
258    } else {
259        previous_mse
260    };
261
262    Ok(ICPResult {
263        transformation: current_transform,
264        mse: final_mse,
265        iterations: max_iters,
266        converged: false,
267        correspondences: final_correspondences,
268    })
269}
270
271/// Legacy ICP function with different signature for backward compatibility
272#[deprecated(note = "Use icp instead which matches the standard API")]
273pub fn icp_legacy(
274    source: &PointCloud<Point3f>,
275    target: &PointCloud<Point3f>,
276    max_iterations: usize,
277    threshold: f32,
278) -> Result<(threecrate_core::Transform3D, f32)> {
279    let init = Isometry3::identity();
280    let result = icp_detailed(source, target, init, max_iterations, Some(threshold), 1e-6)?;
281    
282    // Convert Isometry3 to Transform3D
283    let transform = threecrate_core::Transform3D::from(result.transformation);
284    
285    Ok((transform, result.mse))
286}
287
288/// Point-to-plane ICP variant (requires normals)
289pub fn icp_point_to_plane(
290    source: &PointCloud<Point3f>,
291    target: &PointCloud<Point3f>,
292    _target_normals: &[Vector3f],
293    init: Isometry3<f32>,
294    max_iters: usize,
295) -> Result<ICPResult> {
296    // For now, fall back to point-to-point ICP
297    // TODO: Implement proper point-to-plane optimization
298    icp_detailed(source, target, init, max_iters, None, 1e-6)
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    use nalgebra::UnitQuaternion;

    #[test]
    fn test_icp_identity_transformation() {
        // Create identical point clouds
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        for i in 0..10 {
            let point = Point3f::new(i as f32, (i * 2) as f32, (i * 3) as f32);
            source.push(point);
            target.push(point);
        }

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 10, None, 1e-6).unwrap();

        // Should converge quickly with minimal transformation
        assert!(result.converged);
        assert!(result.mse < 1e-6);
        assert!(result.iterations <= 3);
    }

    #[test]
    fn test_icp_translation() {
        // Create source and target with known translation
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        let translation = Vector3f::new(1.0, 2.0, 3.0);

        for i in 0..10 {
            let source_point = Point3f::new(i as f32, (i * 2) as f32, (i * 3) as f32);
            let target_point = source_point + translation;
            source.push(source_point);
            target.push(target_point);
        }

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 50, None, 1e-6).unwrap();

        // Every point lies on the line through the origin with direction (1, 2, 3)
        // and the shift is along that same line, so nearest-neighbour matching is
        // ambiguous and the full shift is not expected to be recovered. Only loose
        // bounds are asserted: the estimate must move, and the error must stay small.
        let computed_translation = result.transformation.translation.vector;
        assert!(computed_translation.magnitude() > 0.05, "Translation magnitude too small: {}", computed_translation.magnitude());

        assert!(result.mse < 2.0); // Allow for higher MSE in simple test cases
    }

    #[test]
    fn test_icp_rotation() {
        // Create source and target with known rotation
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        let rotation = UnitQuaternion::from_axis_angle(&Vector3f::z_axis(), std::f32::consts::FRAC_PI_4);

        for i in 0..20 {
            let source_point = Point3f::new(i as f32, (i % 5) as f32, 0.0);
            let target_point = rotation * source_point;
            source.push(source_point);
            target.push(target_point);
        }

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 100, None, 1e-6).unwrap();

        // Should find a reasonable transformation for rotation
        assert!(result.mse < 1.0, "MSE too high: {}", result.mse);
    }

    #[test]
    fn test_icp_insufficient_points() {
        // A single pair is below the three-correspondence minimum.
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        source.push(Point3f::new(0.0, 0.0, 0.0));
        target.push(Point3f::new(1.0, 1.0, 1.0));

        let init = Isometry3::identity();
        let result = icp_detailed(&source, &target, init, 10, None, 1e-6);

        assert!(result.is_err());
    }

    #[test]
    fn test_icp_rejects_invalid_input() {
        let empty: PointCloud<Point3f> = PointCloud::new();
        let mut cloud = PointCloud::new();
        for i in 0..5 {
            cloud.push(Point3f::new(i as f32, 0.0, 0.0));
        }

        let init = Isometry3::identity();

        // Empty clouds and a zero iteration budget are rejected up front.
        assert!(icp_detailed(&empty, &cloud, init, 10, None, 1e-6).is_err());
        assert!(icp_detailed(&cloud, &empty, init, 10, None, 1e-6).is_err());
        assert!(icp_detailed(&cloud, &cloud, init, 0, None, 1e-6).is_err());
    }

    #[test]
    fn test_icp_returns_init_on_error() {
        let empty: PointCloud<Point3f> = PointCloud::new();
        let mut target = PointCloud::new();
        for i in 0..5 {
            target.push(Point3f::new(i as f32, 0.0, 0.0));
        }

        let init = Isometry3::identity();
        let transform = icp(&empty, &target, init, 10);

        // The convenience API falls back to the initial guess instead of failing.
        assert!(transform == init);
    }

    #[test]
    fn test_icp_api_compatibility() {
        // Test the main API function
        let mut source = PointCloud::new();
        let mut target = PointCloud::new();

        for i in 0..5 {
            let point = Point3f::new(i as f32, i as f32, 0.0);
            source.push(point);
            target.push(point + Vector3f::new(1.0, 0.0, 0.0));
        }

        let init = Isometry3::identity();
        let transform = icp(&source, &target, init, 20);

        // Should return a valid transformation (not panic)
        assert!(transform.translation.vector.magnitude() > 0.5);
    }

    #[test]
    fn test_correspondence_finding() {
        let source = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
        ];

        let target = vec![
            Point3f::new(0.1, 0.1, 0.0),
            Point3f::new(1.1, 0.1, 0.0),
            Point3f::new(0.1, 1.1, 0.0),
        ];

        let correspondences = find_correspondences(&source, &target, None);

        assert_eq!(correspondences.len(), 3);

        // Each source point sits (0.1, 0.1) away from the target with the same index.
        let expected_distance = 0.02f32.sqrt();
        for (src_idx, correspondence) in correspondences.iter().enumerate() {
            let (tgt_idx, distance) = correspondence.expect("every source point has a neighbour");
            assert_eq!(tgt_idx, src_idx);
            assert!(
                (distance - expected_distance).abs() < 1e-5,
                "unexpected distance for source {}: {}",
                src_idx,
                distance
            );
        }
    }

    #[test]
    fn test_correspondence_max_distance_filter() {
        let source = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(5.0, 0.0, 0.0),
        ];
        let target = vec![Point3f::new(0.1, 0.0, 0.0)];

        let correspondences = find_correspondences(&source, &target, Some(1.0));

        assert_eq!(correspondences.len(), 2);
        // The near point keeps its match; the far one is filtered out.
        assert!(matches!(correspondences[0], Some((0, _))));
        assert!(correspondences[1].is_none());
    }
}