1use threecrate_core::{PointCloud, Result, Point3f, Error, Isometry3};
14use nalgebra::{Matrix3, Vector3, Rotation3, UnitQuaternion, Translation3};
15use rand::Rng;
16use rayon::prelude::*;
17use crate::features::{extract_fpfh_features_with_normals, FpfhConfig, FPFH_DIM};
18use crate::normals::estimate_normals;
19use crate::registration::{icp_point_to_point, ICPResult};
20
/// Parameters for FPFH + RANSAC global registration with optional ICP refinement.
///
/// Distance-like values are compared against differences of input point coordinates,
/// so they are expressed in the same units as the clouds.
#[derive(Debug, Clone)]
pub struct GlobalRegistrationConfig {
    /// Maximum number of RANSAC hypotheses to evaluate; the search may stop earlier
    /// (see `inlier_ratio`).
    pub ransac_iterations: usize,
    /// A feature correspondence counts as an inlier when the transformed source point
    /// lies within this Euclidean distance of its matched target point.
    pub distance_threshold: f32,
    /// RANSAC stops early once a hypothesis reaches
    /// `ceil(inlier_ratio * number_of_correspondences)` inliers (never fewer than 3).
    pub inlier_ratio: f32,
    /// Search radius passed to FPFH feature extraction.
    pub fpfh_radius: f32,
    /// Neighbour count passed to FPFH feature extraction.
    pub fpfh_k_neighbors: usize,
    /// Neighbour count used for normal estimation in `global_registration`.
    pub normal_k_neighbors: usize,
    /// Whether to refine the RANSAC estimate with point-to-point ICP.
    pub refine_with_icp: bool,
    /// Maximum ICP iterations (only used when `refine_with_icp` is set).
    pub icp_max_iterations: usize,
    /// Distance threshold forwarded unchanged to `icp_point_to_point`.
    /// NOTE(review): the meaning of `None` is defined by the ICP routine
    /// (presumably "no rejection") — confirm.
    pub icp_distance_threshold: Option<f32>,
}
47
48impl Default for GlobalRegistrationConfig {
49 fn default() -> Self {
50 Self {
51 ransac_iterations: 50_000,
52 distance_threshold: 0.05,
53 inlier_ratio: 0.25,
54 fpfh_radius: 0.25,
55 fpfh_k_neighbors: 10,
56 normal_k_neighbors: 10,
57 refine_with_icp: true,
58 icp_max_iterations: 50,
59 icp_distance_threshold: None,
60 }
61 }
62}
63
/// Outcome of global registration.
#[derive(Debug, Clone)]
pub struct GlobalRegistrationResult {
    /// Transform mapping source points onto the target: the ICP result when refinement
    /// ran, otherwise the best RANSAC hypothesis.
    pub transformation: Isometry3<f32>,
    /// Number of feature correspondences within `distance_threshold` under the best
    /// RANSAC hypothesis (measured before any ICP refinement).
    pub inlier_count: usize,
    /// `inlier_count` divided by the number of feature correspondences.
    pub inlier_ratio: f32,
    /// Full ICP output when `refine_with_icp` was enabled, otherwise `None`.
    pub icp_result: Option<ICPResult>,
}
80
81#[inline]
87fn fpfh_dist_sq(a: &[f32; FPFH_DIM], b: &[f32; FPFH_DIM]) -> f32 {
88 a.iter().zip(b.iter()).map(|(x, y)| (x - y) * (x - y)).sum()
89}
90
91fn find_feature_correspondences(
95 src_descs: &[[f32; FPFH_DIM]],
96 tgt_descs: &[[f32; FPFH_DIM]],
97) -> Vec<(usize, usize)> {
98 src_descs
99 .par_iter()
100 .enumerate()
101 .filter_map(|(i, sd)| {
102 let best = tgt_descs
103 .iter()
104 .enumerate()
105 .min_by(|(_, ta), (_, tb)| {
106 fpfh_dist_sq(sd, ta)
107 .partial_cmp(&fpfh_dist_sq(sd, tb))
108 .unwrap_or(std::cmp::Ordering::Equal)
109 })?;
110 Some((i, best.0))
111 })
112 .collect()
113}
114
115fn estimate_transform_svd(src_pts: &[Point3f], tgt_pts: &[Point3f]) -> Option<Isometry3<f32>> {
117 let n = src_pts.len();
118 if n < 3 {
119 return None;
120 }
121 let scale = 1.0 / n as f32;
122
123 let src_centroid = src_pts.iter().fold(Vector3::zeros(), |a, p| a + p.coords) * scale;
124 let tgt_centroid = tgt_pts.iter().fold(Vector3::zeros(), |a, p| a + p.coords) * scale;
125
126 let mut h = Matrix3::<f32>::zeros();
127 for (s, t) in src_pts.iter().zip(tgt_pts.iter()) {
128 let ds = s.coords - src_centroid;
129 let dt = t.coords - tgt_centroid;
130 h += ds * dt.transpose();
131 }
132
133 let svd = h.svd(true, true);
134 let u = svd.u?;
135 let vt = svd.v_t?;
136 let mut r = vt.transpose() * u.transpose();
137
138 if r.determinant() < 0.0 {
140 let mut vt_fix = vt;
141 vt_fix.row_mut(2).neg_mut();
142 r = vt_fix.transpose() * u.transpose();
143 }
144
145 let rotation = UnitQuaternion::from_rotation_matrix(&Rotation3::from_matrix_unchecked(r));
146 let t = tgt_centroid - rotation * src_centroid;
147 Some(Isometry3::from_parts(Translation3::from(t), rotation))
148}
149
150fn count_inliers(
152 corrs: &[(usize, usize)],
153 src_pts: &[Point3f],
154 tgt_pts: &[Point3f],
155 transform: &Isometry3<f32>,
156 threshold: f32,
157) -> usize {
158 let thr_sq = threshold * threshold;
159 corrs
160 .iter()
161 .filter(|&&(si, ti)| {
162 let tp = transform * src_pts[si];
163 (tp - tgt_pts[ti]).magnitude_squared() <= thr_sq
164 })
165 .count()
166}
167
168pub fn global_registration(
189 source: &PointCloud<Point3f>,
190 target: &PointCloud<Point3f>,
191 config: &GlobalRegistrationConfig,
192) -> Result<GlobalRegistrationResult> {
193 if source.is_empty() {
194 return Err(Error::Algorithm("Source point cloud is empty".into()));
195 }
196 if target.is_empty() {
197 return Err(Error::Algorithm("Target point cloud is empty".into()));
198 }
199
200 let src_normals = estimate_normals(source, config.normal_k_neighbors)?;
201 let tgt_normals = estimate_normals(target, config.normal_k_neighbors)?;
202
203 global_registration_with_normals(&src_normals, &tgt_normals, source, target, config)
204}
205
206pub fn global_registration_with_normals(
217 source_n: &PointCloud<threecrate_core::NormalPoint3f>,
218 target_n: &PointCloud<threecrate_core::NormalPoint3f>,
219 source: &PointCloud<Point3f>,
220 target: &PointCloud<Point3f>,
221 config: &GlobalRegistrationConfig,
222) -> Result<GlobalRegistrationResult> {
223 if source_n.is_empty() || target_n.is_empty() {
224 return Err(Error::Algorithm("Source or target cloud is empty".into()));
225 }
226
227 let fpfh_cfg = FpfhConfig {
229 search_radius: config.fpfh_radius,
230 k_neighbors: config.fpfh_k_neighbors,
231 };
232 let src_descs = extract_fpfh_features_with_normals(source_n, &fpfh_cfg)?;
233 let tgt_descs = extract_fpfh_features_with_normals(target_n, &fpfh_cfg)?;
234
235 let corrs = find_feature_correspondences(&src_descs, &tgt_descs);
237
238 if corrs.len() < 3 {
239 return Err(Error::Algorithm(
240 "Too few feature correspondences for RANSAC (need ≥ 3)".into(),
241 ));
242 }
243
244 let src_pts = &source_n.points.iter().map(|p| p.position).collect::<Vec<_>>();
245 let tgt_pts = &target_n.points.iter().map(|p| p.position).collect::<Vec<_>>();
246
247 let mut best_transform = Isometry3::identity();
249 let mut best_inliers = 0usize;
250 let early_exit_count =
251 ((config.inlier_ratio * corrs.len() as f32).ceil() as usize).max(3);
252 let mut rng = rand::rng();
253 let n_corrs = corrs.len();
254
255 for _ in 0..config.ransac_iterations {
256 let i0 = rng.random_range(0..n_corrs);
258 let mut i1 = rng.random_range(0..n_corrs - 1);
259 if i1 >= i0 { i1 += 1; }
260 let mut i2 = rng.random_range(0..n_corrs - 2);
261 if i2 >= i0.min(i1) { i2 += 1; }
262 if i2 >= i0.max(i1) { i2 += 1; }
263 let sample = [i0, i1, i2];
264
265 let s_pts: Vec<Point3f> = sample.iter().map(|&i| src_pts[corrs[i].0]).collect();
266 let t_pts: Vec<Point3f> = sample.iter().map(|&i| tgt_pts[corrs[i].1]).collect();
267
268 let transform = match estimate_transform_svd(&s_pts, &t_pts) {
269 Some(t) => t,
270 None => continue,
271 };
272
273 let inliers = count_inliers(&corrs, src_pts, tgt_pts, &transform, config.distance_threshold);
274
275 if inliers > best_inliers {
276 best_inliers = inliers;
277 best_transform = transform;
278
279 if inliers >= early_exit_count {
280 break;
281 }
282 }
283 }
284
285 let total_corrs = corrs.len();
286 let inlier_ratio = if total_corrs > 0 {
287 best_inliers as f32 / total_corrs as f32
288 } else {
289 0.0
290 };
291
292 let icp_result = if config.refine_with_icp {
294 Some(icp_point_to_point(
295 source,
296 target,
297 best_transform,
298 config.icp_max_iterations,
299 1e-6,
300 config.icp_distance_threshold,
301 )?)
302 } else {
303 None
304 };
305
306 let final_transform = icp_result
307 .as_ref()
308 .map(|r| r.transformation)
309 .unwrap_or(best_transform);
310
311 Ok(GlobalRegistrationResult {
312 transformation: final_transform,
313 inlier_count: best_inliers,
314 inlier_ratio,
315 icp_result,
316 })
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use threecrate_core::{PointCloud, Point3f};
    use nalgebra::{Isometry3, Translation3, UnitQuaternion};

    /// Axis-aligned `nx × ny × nz` lattice with spacing `scale`, starting at the origin.
    fn grid_cloud(nx: usize, ny: usize, nz: usize, scale: f32) -> PointCloud<Point3f> {
        let mut pts = Vec::new();
        for ix in 0..nx {
            for iy in 0..ny {
                for iz in 0..nz {
                    pts.push(Point3f::new(ix as f32 * scale, iy as f32 * scale, iz as f32 * scale));
                }
            }
        }
        PointCloud { points: pts }
    }

    /// Applies `iso` to every point of `cloud`.
    fn apply(cloud: &PointCloud<Point3f>, iso: &Isometry3<f32>) -> PointCloud<Point3f> {
        PointCloud { points: cloud.points.iter().map(|p| iso * p).collect() }
    }

    #[test]
    fn test_global_reg_empty_source() {
        let empty: PointCloud<Point3f> = PointCloud { points: vec![] };
        let target = grid_cloud(4, 4, 4, 1.0);
        let cfg = GlobalRegistrationConfig::default();
        assert!(global_registration(&empty, &target, &cfg).is_err());
    }

    #[test]
    fn test_global_reg_empty_target() {
        let source = grid_cloud(4, 4, 4, 1.0);
        let empty: PointCloud<Point3f> = PointCloud { points: vec![] };
        let cfg = GlobalRegistrationConfig::default();
        assert!(global_registration(&source, &empty, &cfg).is_err());
    }

    #[test]
    fn test_global_reg_identity() {
        let cloud = grid_cloud(4, 4, 4, 1.0);
        let cfg = GlobalRegistrationConfig {
            ransac_iterations: 200,
            distance_threshold: 0.5,
            fpfh_radius: 3.0,
            refine_with_icp: false,
            ..Default::default()
        };
        let result = global_registration(&cloud, &cloud, &cfg).unwrap();
        assert!(result.inlier_count > 0);
        assert!(result.inlier_ratio > 0.0);
    }

    #[test]
    fn test_global_reg_returns_valid_isometry() {
        let cloud = grid_cloud(4, 4, 4, 1.0);
        let cfg = GlobalRegistrationConfig {
            ransac_iterations: 100,
            distance_threshold: 0.5,
            fpfh_radius: 3.0,
            refine_with_icp: false,
            ..Default::default()
        };
        let result = global_registration(&cloud, &cloud, &cfg).unwrap();
        let qnorm = result.transformation.rotation.norm();
        assert!((qnorm - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_global_reg_with_icp() {
        let target = grid_cloud(4, 4, 4, 1.0);
        let t = Isometry3::from_parts(
            Translation3::new(0.3, 0.2, 0.1),
            UnitQuaternion::identity(),
        );
        let source = apply(&target, &t);
        let cfg = GlobalRegistrationConfig {
            ransac_iterations: 500,
            distance_threshold: 1.0,
            fpfh_radius: 3.0,
            refine_with_icp: true,
            icp_max_iterations: 30,
            icp_distance_threshold: Some(2.0),
            ..Default::default()
        };
        let result = global_registration(&source, &target, &cfg).unwrap();
        assert!(result.icp_result.is_some());
    }

    #[test]
    fn test_estimate_transform_svd_three_points() {
        let src = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
        ];
        let shift = Vector3::new(1.0f32, 2.0, 3.0);
        let tgt: Vec<Point3f> = src.iter().map(|p| Point3f::from(p.coords + shift)).collect();
        let iso = estimate_transform_svd(&src, &tgt).unwrap();
        let t = iso.translation.vector;
        assert!((t.x - 1.0).abs() < 1e-4);
        assert!((t.y - 2.0).abs() < 1e-4);
        assert!((t.z - 3.0).abs() < 1e-4);
        // The estimate must also map every sample point onto its target.
        for (s, t) in src.iter().zip(&tgt) {
            assert!((iso * s - t).norm() < 1e-3);
        }
    }

    #[test]
    fn test_estimate_transform_svd_recovers_rotation() {
        // Non-coplanar points so the cross-covariance has full rank.
        let src = vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
        ];
        let truth = Isometry3::from_parts(
            Translation3::new(0.5, -1.0, 2.0),
            UnitQuaternion::from_axis_angle(&Vector3::z_axis(), std::f32::consts::FRAC_PI_2),
        );
        let tgt: Vec<Point3f> = src.iter().map(|p| truth * p).collect();
        let est = estimate_transform_svd(&src, &tgt).unwrap();
        for (s, t) in src.iter().zip(&tgt) {
            assert!((est * s - t).norm() < 1e-3);
        }
    }

    #[test]
    fn test_estimate_transform_svd_too_few_points() {
        let pts = vec![Point3f::new(0.0, 0.0, 0.0), Point3f::new(1.0, 0.0, 0.0)];
        assert!(estimate_transform_svd(&pts, &pts).is_none());
    }

    #[test]
    fn test_count_inliers_threshold() {
        let src = vec![Point3f::new(0.0, 0.0, 0.0), Point3f::new(1.0, 0.0, 0.0)];
        let tgt = vec![Point3f::new(0.0, 0.0, 0.1), Point3f::new(1.0, 0.0, 2.0)];
        let corrs = vec![(0, 0), (1, 1)];
        let id = Isometry3::identity();
        assert_eq!(count_inliers(&corrs, &src, &tgt, &id, 0.5), 1);
        assert_eq!(count_inliers(&corrs, &src, &tgt, &id, 3.0), 2);
        assert_eq!(count_inliers(&[], &src, &tgt, &id, 3.0), 0);
    }

    #[test]
    fn test_fpfh_dist_sq() {
        let zero = [0.0f32; FPFH_DIM];
        let one = [1.0f32; FPFH_DIM];
        assert_eq!(fpfh_dist_sq(&zero, &zero), 0.0);
        assert_eq!(fpfh_dist_sq(&zero, &one), FPFH_DIM as f32);
    }

    #[test]
    fn test_feature_correspondences_nearest() {
        let mut a = [0.0f32; FPFH_DIM];
        let mut b = [0.0f32; FPFH_DIM];
        a[0] = 1.0;
        b[0] = 5.0;
        let src = vec![a, b];
        let tgt = vec![b, a];
        assert_eq!(find_feature_correspondences(&src, &tgt), vec![(0, 1), (1, 0)]);
        assert!(find_feature_correspondences(&src, &[]).is_empty());
    }

    #[test]
    fn test_config_defaults() {
        let cfg = GlobalRegistrationConfig::default();
        assert!(cfg.ransac_iterations > 0);
        assert!(cfg.distance_threshold > 0.0);
        assert!(cfg.inlier_ratio > 0.0 && cfg.inlier_ratio < 1.0);
        assert!(cfg.fpfh_radius > 0.0);
    }
}