Skip to main content

threecrate_algorithms/
global_registration.rs

1//! Global point cloud registration via FPFH feature matching and RANSAC.
2//!
3//! Implements a coarse-to-fine pipeline:
4//! 1. Estimate surface normals (if not pre-computed)
5//! 2. Extract FPFH descriptors from source and target
6//! 3. Match descriptors to build putative correspondences
7//! 4. RANSAC to find the best rigid transformation consistent with those correspondences
8//! 5. Optional ICP refinement to polish the coarse alignment
9//!
10//! Expected use: call `global_registration()` to get a good initial pose, then hand off to ICP
11//! or NDT for sub-millimetre refinement.
12
13use threecrate_core::{PointCloud, Result, Point3f, Error, Isometry3};
14use nalgebra::{Matrix3, Vector3, Rotation3, UnitQuaternion, Translation3};
15use rand::Rng;
16use rayon::prelude::*;
17use crate::features::{extract_fpfh_features_with_normals, FpfhConfig, FPFH_DIM};
18use crate::normals::estimate_normals;
19use crate::registration::{icp_point_to_point, ICPResult};
20
21// ---------------------------------------------------------------------------
22// Config
23// ---------------------------------------------------------------------------
24
/// Configuration for the global registration pipeline.
///
/// Start from [`GlobalRegistrationConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `GlobalRegistrationConfig { fpfh_radius: 3.0, ..Default::default() }`.
///
/// Distances (`distance_threshold`, `fpfh_radius`, `icp_distance_threshold`) are expressed in
/// the same units as the point coordinates.
#[derive(Debug, Clone)]
pub struct GlobalRegistrationConfig {
    /// Maximum number of RANSAC hypotheses to evaluate. Higher = more robust, slower.
    pub ransac_iterations: usize,
    /// Maximum Euclidean distance between a transformed source point and its matched target
    /// point for the correspondence to count as an inlier (the comparison is inclusive).
    pub distance_threshold: f32,
    /// RANSAC stops early once the best hypothesis has at least
    /// `ceil(inlier_ratio * number_of_correspondences)` inliers (and never fewer than 3).
    /// Values above 1.0 effectively disable early exit.
    pub inlier_ratio: f32,
    /// Search radius forwarded to FPFH feature extraction (`FpfhConfig::search_radius`).
    pub fpfh_radius: f32,
    /// Neighbour count forwarded to FPFH feature extraction (`FpfhConfig::k_neighbors`).
    /// Per the features module, this is the minimum neighbourhood size below which the radius
    /// search falls back to k-NN.
    pub fpfh_k_neighbors: usize,
    /// Number of nearest neighbours used for surface normal estimation
    /// (only used by `global_registration`, which estimates normals itself).
    pub normal_k_neighbors: usize,
    /// Run ICP after RANSAC to refine the coarse alignment.
    pub refine_with_icp: bool,
    /// Maximum ICP iterations (only used when `refine_with_icp` is true).
    pub icp_max_iterations: usize,
    /// Maximum point-to-point correspondence distance forwarded to ICP (`None` = unlimited).
    pub icp_distance_threshold: Option<f32>,
}

impl Default for GlobalRegistrationConfig {
    /// Defaults: 50 000 RANSAC iterations, inlier threshold 0.05, early exit at 25 % inliers,
    /// FPFH radius 0.25 with 10 neighbours, 10-NN normals, and up to 50 unbounded ICP
    /// iterations.
    fn default() -> Self {
        Self {
            ransac_iterations: 50_000,
            distance_threshold: 0.05,
            inlier_ratio: 0.25,
            fpfh_radius: 0.25,
            fpfh_k_neighbors: 10,
            normal_k_neighbors: 10,
            refine_with_icp: true,
            icp_max_iterations: 50,
            icp_distance_threshold: None,
        }
    }
}
63
64// ---------------------------------------------------------------------------
65// Result
66// ---------------------------------------------------------------------------
67
68/// Result of the global registration pipeline.
69#[derive(Debug, Clone)]
70pub struct GlobalRegistrationResult {
71    /// Best-found rigid transformation that maps `source` onto `target`.
72    pub transformation: Isometry3<f32>,
73    /// Number of correspondences classified as inliers under `transformation`.
74    pub inlier_count: usize,
75    /// Fraction of total correspondences that are inliers (0.0 – 1.0).
76    pub inlier_ratio: f32,
77    /// ICP refinement result, present when `config.refine_with_icp` is `true`.
78    pub icp_result: Option<ICPResult>,
79}
80
81// ---------------------------------------------------------------------------
82// Internal helpers
83// ---------------------------------------------------------------------------
84
85/// Squared L2 distance between two FPFH descriptors.
86#[inline]
87fn fpfh_dist_sq(a: &[f32; FPFH_DIM], b: &[f32; FPFH_DIM]) -> f32 {
88    a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
89}
90
91/// Match source descriptors to target descriptors (nearest-neighbour in descriptor space).
92///
93/// Returns a list of `(source_idx, target_idx)` pairs.
94fn find_feature_correspondences(
95    src_descs: &[[f32; FPFH_DIM]],
96    tgt_descs: &[[f32; FPFH_DIM]],
97) -> Vec<(usize, usize)> {
98    src_descs
99        .par_iter()
100        .enumerate()
101        .filter_map(|(i, sd)| {
102            let best = tgt_descs
103                .iter()
104                .enumerate()
105                .min_by(|(_, ta), (_, tb)| {
106                    fpfh_dist_sq(sd, ta)
107                        .partial_cmp(&fpfh_dist_sq(sd, tb))
108                        .unwrap_or(std::cmp::Ordering::Equal)
109                })?;
110            Some((i, best.0))
111        })
112        .collect()
113}
114
115/// Estimate a rigid transformation from ≥ 3 point pairs using SVD (same method as ICP).
116fn estimate_transform_svd(src_pts: &[Point3f], tgt_pts: &[Point3f]) -> Option<Isometry3<f32>> {
117    let n = src_pts.len();
118    if n < 3 {
119        return None;
120    }
121    let scale = 1.0 / n as f32;
122
123    let src_centroid = src_pts.iter().fold(Vector3::zeros(), |a, p| a + p.coords) * scale;
124    let tgt_centroid = tgt_pts.iter().fold(Vector3::zeros(), |a, p| a + p.coords) * scale;
125
126    let mut h = Matrix3::<f32>::zeros();
127    for (s, t) in src_pts.iter().zip(tgt_pts.iter()) {
128        let ds = s.coords - src_centroid;
129        let dt = t.coords - tgt_centroid;
130        h += ds * dt.transpose();
131    }
132
133    let svd = h.svd(true, true);
134    let u = svd.u?;
135    let vt = svd.v_t?;
136    let mut r = vt.transpose() * u.transpose();
137
138    // Fix reflection
139    if r.determinant() < 0.0 {
140        let mut vt_fix = vt;
141        vt_fix.row_mut(2).neg_mut();
142        r = vt_fix.transpose() * u.transpose();
143    }
144
145    let rotation = UnitQuaternion::from_rotation_matrix(&Rotation3::from_matrix_unchecked(r));
146    let t = tgt_centroid - rotation * src_centroid;
147    Some(Isometry3::from_parts(Translation3::from(t), rotation))
148}
149
150/// Count correspondences that are inliers under `transform`.
151fn count_inliers(
152    corrs: &[(usize, usize)],
153    src_pts: &[Point3f],
154    tgt_pts: &[Point3f],
155    transform: &Isometry3<f32>,
156    threshold: f32,
157) -> usize {
158    let thr_sq = threshold * threshold;
159    corrs
160        .iter()
161        .filter(|&&(si, ti)| {
162            let tp = transform * src_pts[si];
163            (tp - tgt_pts[ti]).magnitude_squared() <= thr_sq
164        })
165        .count()
166}
167
168// ---------------------------------------------------------------------------
169// Public API
170// ---------------------------------------------------------------------------
171
172/// Run the full global registration pipeline on raw (un-normalised) point clouds.
173///
174/// This function estimates normals internally. For more control (e.g. when normals are
175/// already computed or cloud density is irregular) call `global_registration_with_normals`.
176///
177/// # Workflow
178/// 1. Normal estimation → FPFH extraction → feature matching → RANSAC → optional ICP
179///
180/// # Arguments
181/// * `source` – Point cloud to align (the "model")
182/// * `target` – Reference point cloud (the "scene")
183/// * `config` – Pipeline parameters
184///
185/// # Returns
186/// [`GlobalRegistrationResult`] containing the best transformation found, inlier statistics,
187/// and (optionally) the ICP refinement result.
188pub fn global_registration(
189    source: &PointCloud<Point3f>,
190    target: &PointCloud<Point3f>,
191    config: &GlobalRegistrationConfig,
192) -> Result<GlobalRegistrationResult> {
193    if source.is_empty() {
194        return Err(Error::Algorithm("Source point cloud is empty".into()));
195    }
196    if target.is_empty() {
197        return Err(Error::Algorithm("Target point cloud is empty".into()));
198    }
199
200    let src_normals = estimate_normals(source, config.normal_k_neighbors)?;
201    let tgt_normals = estimate_normals(target, config.normal_k_neighbors)?;
202
203    global_registration_with_normals(&src_normals, &tgt_normals, source, target, config)
204}
205
206/// Global registration when surface normals are already available.
207///
208/// Skips normal estimation; otherwise identical to [`global_registration`].
209///
210/// # Arguments
211/// * `source_n` – Source cloud with normals (for FPFH)
212/// * `target_n` – Target cloud with normals (for FPFH)
213/// * `source`   – Raw source positions (for ICP refinement)
214/// * `target`   – Raw target positions (for ICP refinement)
215/// * `config`   – Pipeline parameters
216pub fn global_registration_with_normals(
217    source_n: &PointCloud<threecrate_core::NormalPoint3f>,
218    target_n: &PointCloud<threecrate_core::NormalPoint3f>,
219    source: &PointCloud<Point3f>,
220    target: &PointCloud<Point3f>,
221    config: &GlobalRegistrationConfig,
222) -> Result<GlobalRegistrationResult> {
223    if source_n.is_empty() || target_n.is_empty() {
224        return Err(Error::Algorithm("Source or target cloud is empty".into()));
225    }
226
227    // --- FPFH extraction ---
228    let fpfh_cfg = FpfhConfig {
229        search_radius: config.fpfh_radius,
230        k_neighbors: config.fpfh_k_neighbors,
231    };
232    let src_descs = extract_fpfh_features_with_normals(source_n, &fpfh_cfg)?;
233    let tgt_descs = extract_fpfh_features_with_normals(target_n, &fpfh_cfg)?;
234
235    // --- Feature correspondences ---
236    let corrs = find_feature_correspondences(&src_descs, &tgt_descs);
237
238    if corrs.len() < 3 {
239        return Err(Error::Algorithm(
240            "Too few feature correspondences for RANSAC (need ≥ 3)".into(),
241        ));
242    }
243
244    let src_pts = &source_n.points.iter().map(|p| p.position).collect::<Vec<_>>();
245    let tgt_pts = &target_n.points.iter().map(|p| p.position).collect::<Vec<_>>();
246
247    // --- RANSAC ---
248    let mut best_transform = Isometry3::identity();
249    let mut best_inliers = 0usize;
250    let early_exit_count =
251        ((config.inlier_ratio * corrs.len() as f32).ceil() as usize).max(3);
252    let mut rng = rand::rng();
253    let n_corrs = corrs.len();
254
255    for _ in 0..config.ransac_iterations {
256        // Pick 3 unique random correspondence indices
257        let i0 = rng.random_range(0..n_corrs);
258        let mut i1 = rng.random_range(0..n_corrs - 1);
259        if i1 >= i0 { i1 += 1; }
260        let mut i2 = rng.random_range(0..n_corrs - 2);
261        if i2 >= i0.min(i1) { i2 += 1; }
262        if i2 >= i0.max(i1) { i2 += 1; }
263        let sample = [i0, i1, i2];
264
265        let s_pts: Vec<Point3f> = sample.iter().map(|&i| src_pts[corrs[i].0]).collect();
266        let t_pts: Vec<Point3f> = sample.iter().map(|&i| tgt_pts[corrs[i].1]).collect();
267
268        let transform = match estimate_transform_svd(&s_pts, &t_pts) {
269            Some(t) => t,
270            None => continue,
271        };
272
273        let inliers = count_inliers(&corrs, src_pts, tgt_pts, &transform, config.distance_threshold);
274
275        if inliers > best_inliers {
276            best_inliers = inliers;
277            best_transform = transform;
278
279            if inliers >= early_exit_count {
280                break;
281            }
282        }
283    }
284
285    let total_corrs = corrs.len();
286    let inlier_ratio = if total_corrs > 0 {
287        best_inliers as f32 / total_corrs as f32
288    } else {
289        0.0
290    };
291
292    // --- ICP refinement ---
293    let icp_result = if config.refine_with_icp {
294        Some(icp_point_to_point(
295            source,
296            target,
297            best_transform,
298            config.icp_max_iterations,
299            1e-6,
300            config.icp_distance_threshold,
301        )?)
302    } else {
303        None
304    };
305
306    let final_transform = icp_result
307        .as_ref()
308        .map(|r| r.transformation)
309        .unwrap_or(best_transform);
310
311    Ok(GlobalRegistrationResult {
312        transformation: final_transform,
313        inlier_count: best_inliers,
314        inlier_ratio,
315        icp_result,
316    })
317}
318
319// ---------------------------------------------------------------------------
320// Tests
321// ---------------------------------------------------------------------------
322
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::{PointCloud, Point3f};
    use nalgebra::{Isometry3, Translation3, UnitQuaternion};

    fn grid_cloud(nx: usize, ny: usize, nz: usize, scale: f32) -> PointCloud<Point3f> {
        let mut pts = Vec::new();
        for ix in 0..nx {
            for iy in 0..ny {
                for iz in 0..nz {
                    pts.push(Point3f::new(ix as f32 * scale, iy as f32 * scale, iz as f32 * scale));
                }
            }
        }
        PointCloud { points: pts }
    }

    fn apply(cloud: &PointCloud<Point3f>, iso: &Isometry3<f32>) -> PointCloud<Point3f> {
        PointCloud { points: cloud.points.iter().map(|p| iso * p).collect() }
    }

    /// Assert that `iso` maps every `src[i]` to within `tol` of `tgt[i]`.
    fn assert_maps(iso: &Isometry3<f32>, src: &[Point3f], tgt: &[Point3f], tol: f32) {
        assert_eq!(src.len(), tgt.len());
        for (s, t) in src.iter().zip(tgt) {
            let err = (iso * s - *t).norm();
            assert!(err < tol, "residual {err} exceeds {tol} for {s:?} -> {t:?}");
        }
    }

    fn test_pose() -> Isometry3<f32> {
        Isometry3::from_parts(
            Translation3::new(0.5f32, -1.0, 2.0),
            UnitQuaternion::from_euler_angles(0.3f32, -0.2, 0.9),
        )
    }

    #[test]
    fn test_global_reg_empty_source() {
        let empty: PointCloud<Point3f> = PointCloud { points: vec![] };
        let target = grid_cloud(4, 4, 4, 1.0);
        let cfg = GlobalRegistrationConfig::default();
        assert!(global_registration(&empty, &target, &cfg).is_err());
    }

    #[test]
    fn test_global_reg_empty_target() {
        let source = grid_cloud(4, 4, 4, 1.0);
        let empty: PointCloud<Point3f> = PointCloud { points: vec![] };
        let cfg = GlobalRegistrationConfig::default();
        assert!(global_registration(&source, &empty, &cfg).is_err());
    }

    #[test]
    fn test_global_reg_identity() {
        let cloud = grid_cloud(4, 4, 4, 1.0);
        let cfg = GlobalRegistrationConfig {
            ransac_iterations: 200,
            distance_threshold: 0.5,
            fpfh_radius: 3.0,
            refine_with_icp: false,
            ..Default::default()
        };
        let result = global_registration(&cloud, &cloud, &cfg).unwrap();
        // The regular grid produces many identical descriptors, so only a lower bound is
        // reliable: at least one correspondence must be an inlier.
        assert!(result.inlier_count > 0);
        assert!(result.inlier_ratio > 0.0);
    }

    #[test]
    fn test_global_reg_returns_valid_isometry() {
        let cloud = grid_cloud(4, 4, 4, 1.0);
        let cfg = GlobalRegistrationConfig {
            ransac_iterations: 100,
            distance_threshold: 0.5,
            fpfh_radius: 3.0,
            refine_with_icp: false,
            ..Default::default()
        };
        let result = global_registration(&cloud, &cloud, &cfg).unwrap();
        // Rotation must have unit quaternion norm
        let qnorm = result.transformation.rotation.norm();
        assert!((qnorm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_global_reg_with_icp() {
        let target = grid_cloud(4, 4, 4, 1.0);
        let t = Isometry3::from_parts(
            Translation3::new(0.3, 0.2, 0.1),
            UnitQuaternion::identity(),
        );
        let source = apply(&target, &t);
        let cfg = GlobalRegistrationConfig {
            ransac_iterations: 500,
            distance_threshold: 1.0,
            fpfh_radius: 3.0,
            refine_with_icp: true,
            icp_max_iterations: 30,
            icp_distance_threshold: Some(2.0),
            ..Default::default()
        };
        let result = global_registration(&source, &target, &cfg).unwrap();
        assert!(result.icp_result.is_some());
    }

    #[test]
    fn test_estimate_transform_svd_three_points() {
        let src = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
        ];
        let shift = Vector3::new(1.0f32, 2.0, 3.0);
        let tgt: Vec<Point3f> = src.iter().map(|p| Point3f::from(p.coords + shift)).collect();
        let iso = estimate_transform_svd(&src, &tgt).unwrap();
        let t = iso.translation.vector;
        assert!((t.x - 1.0).abs() < 1e-4);
        assert!((t.y - 2.0).abs() < 1e-4);
        assert!((t.z - 3.0).abs() < 1e-4);
    }

    #[test]
    fn test_estimate_transform_svd_recovers_rotation() {
        let pose = test_pose();
        let src = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 2.0, 0.0),
            Point3f::new(0.0, 0.0, 3.0),
            Point3f::new(1.0, 1.0, 1.0),
        ];
        let tgt: Vec<Point3f> = src.iter().map(|p| &pose * p).collect();
        let est = estimate_transform_svd(&src, &tgt).unwrap();
        assert_maps(&est, &src, &tgt, 1e-3);
    }

    #[test]
    fn test_estimate_transform_svd_three_points_rotation() {
        // Three points are always coplanar (rank-deficient covariance); this exercises the
        // reflection fix used by every RANSAC hypothesis.
        let pose = test_pose();
        let src = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 2.0, 0.0),
        ];
        let tgt: Vec<Point3f> = src.iter().map(|p| &pose * p).collect();
        let est = estimate_transform_svd(&src, &tgt).unwrap();
        assert_maps(&est, &src, &tgt, 1e-3);
    }

    #[test]
    fn test_estimate_transform_svd_too_few_points() {
        let src = vec![Point3f::new(0.0, 0.0, 0.0), Point3f::new(1.0, 0.0, 0.0)];
        assert!(estimate_transform_svd(&src, &src).is_none());
    }

    #[test]
    fn test_count_inliers() {
        let src = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
        ];
        let mut tgt = src.clone();
        tgt[2] = Point3f::new(5.0, 5.0, 5.0);
        let corrs = [(0, 0), (1, 1), (2, 2)];
        assert_eq!(count_inliers(&corrs, &src, &tgt, &Isometry3::identity(), 0.1), 2);

        // The threshold is inclusive: residuals of exactly 1.0 count at threshold 1.0.
        let shifted: Vec<Point3f> =
            src.iter().map(|p| Point3f::new(p.x + 1.0, p.y, p.z)).collect();
        assert_eq!(count_inliers(&corrs, &src, &shifted, &Isometry3::identity(), 1.0), 3);
    }

    #[test]
    fn test_fpfh_dist_sq() {
        let zeros = [0.0f32; FPFH_DIM];
        let twos = [2.0f32; FPFH_DIM];
        assert_eq!(fpfh_dist_sq(&zeros, &zeros), 0.0);
        assert_eq!(fpfh_dist_sq(&zeros, &twos), 4.0 * FPFH_DIM as f32);
    }

    #[test]
    fn test_find_feature_correspondences_exact_matches() {
        let descs: Vec<[f32; FPFH_DIM]> = (0..4).map(|i| [i as f32; FPFH_DIM]).collect();
        let reversed: Vec<[f32; FPFH_DIM]> = descs.iter().rev().copied().collect();
        let corrs = find_feature_correspondences(&descs, &reversed);
        assert_eq!(corrs, vec![(0, 3), (1, 2), (2, 1), (3, 0)]);
        assert!(find_feature_correspondences(&descs, &[]).is_empty());
    }

    #[test]
    fn test_config_defaults() {
        let cfg = GlobalRegistrationConfig::default();
        assert!(cfg.ransac_iterations > 0);
        assert!(cfg.distance_threshold > 0.0);
        assert!(cfg.inlier_ratio > 0.0 && cfg.inlier_ratio < 1.0);
        assert!(cfg.fpfh_radius > 0.0);
    }
}