1use threecrate_core::{PointCloud, Result, Point3f, Vector3f};
4use crate::GpuContext;
5use nalgebra::{Isometry3, Matrix6, Vector6, UnitQuaternion, Translation3};
6
/// WGSL compute shader that performs the brute-force ICP correspondence search.
///
/// One invocation handles one source point (workgroups of 64 invocations) and linearly scans
/// all `params.num_target` target points, keeping the closest one. The winning target index is
/// written to `correspondences[index]` and its Euclidean distance to `distances[index]`.
///
/// The running minimum starts at `params.max_distance`, so a source point with no target
/// strictly closer than that is reported as target index `0` with a distance equal to
/// `max_distance`; the host has to filter such entries out.
///
/// Both point buffers are declared as `array<vec4<f32>>` and only `.xyz` is read, so the host
/// must upload one 16-byte element per point.
const ICP_NEAREST_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> distances: array<f32>;
@group(0) @binding(4) var<uniform> params: ICPParams;

struct ICPParams {
 num_source: u32,
 num_target: u32,
 max_distance: f32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 let index = global_id.x;
 if (index >= params.num_source) {
 return;
 }

 let source_point = source_points[index].xyz;
 var min_distance = params.max_distance;
 var best_match = 0u;

 // Find nearest neighbor in target
 for (var i = 0u; i < params.num_target; i++) {
 let target_point = target_points[i].xyz;
 let diff = source_point - target_point;
 let distance = length(diff);

 if (distance < min_distance) {
 min_distance = distance;
 best_match = i;
 }
 }

 correspondences[index] = best_match;
 distances[index] = min_distance;
}
"#;
47
/// WGSL compute shader that averages the matched source and target points.
///
/// A single invocation (everything except `global_id.x == 0` returns immediately) loops over
/// `params.num_correspondences` entries. For entry `i` it adds `source_points[i]` and
/// `target_points[correspondences[i]]` to two running sums, then scales both sums by
/// `1 / num_correspondences` and writes them to `centroids[0]` (source) and `centroids[1]`
/// (target) with `w = 0`.
///
/// Host-side requirements that follow from the code:
/// - Source points are indexed by the correspondence position `i`, not by an explicit source
///   index, so `source_points[i]` must be the point that belongs to `correspondences[i]`.
/// - `centroids` holds two `vec4<f32>` values, so its buffer needs room for 32 bytes.
/// - `num_correspondences` must be non-zero because the shader divides by it.
const ICP_CENTROID_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> centroids: array<vec4<f32>>; // [source_centroid, target_centroid]
@group(0) @binding(4) var<uniform> params: CentroidParams;

struct CentroidParams {
 num_correspondences: u32,
}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 if (global_id.x != 0u) {
 return;
 }

 var source_centroid = vec3<f32>(0.0);
 var target_centroid = vec3<f32>(0.0);

 for (var i = 0u; i < params.num_correspondences; i++) {
 let target_idx = correspondences[i];
 source_centroid += source_points[i].xyz;
 target_centroid += target_points[target_idx].xyz;
 }

 let scale = 1.0 / f32(params.num_correspondences);
 centroids[0] = vec4<f32>(source_centroid * scale, 0.0);
 centroids[1] = vec4<f32>(target_centroid * scale, 0.0);
}
"#;
79
/// WGSL compute shader intended to accumulate the 3x3 cross-covariance matrix on the GPU.
///
/// One invocation per correspondence centers the matched source and target points on
/// `centroids[0]` / `centroids[1]`, forms the nine products of their outer product and tries to
/// add them into the nine-element `covariance` buffer (row-major: index `3 * row + col`).
///
/// Nothing in this file dispatches this shader; the covariance is computed on the CPU by
/// `compute_covariance_cpu`, hence the `dead_code` allowance.
///
/// NOTE(review): as written this shader is not expected to pass WGSL validation. `atomicAdd`
/// only accepts pointers to `atomic<i32>` / `atomic<u32>`, but `covariance` is declared as
/// `array<f32>`. Even with an atomic integer buffer, adding `bitcast<i32>` float bit patterns
/// does not produce a floating-point sum. It needs a different accumulation scheme (for example
/// fixed-point atomics or a per-invocation partial buffer that is reduced afterwards) before it
/// can be wired up.
#[allow(dead_code)]
const ICP_COVARIANCE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read> centroids: array<vec4<f32>>;
@group(0) @binding(4) var<storage, read_write> covariance: array<f32>; // 9 elements for 3x3 matrix
@group(0) @binding(5) var<uniform> params: CovarianceParams;

struct CovarianceParams {
 num_correspondences: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 let index = global_id.x;
 if (index >= params.num_correspondences) {
 return;
 }

 let source_centroid = centroids[0].xyz;
 let target_centroid = centroids[1].xyz;

 let target_idx = correspondences[index];
 let source_centered = source_points[index].xyz - source_centroid;
 let target_centered = target_points[target_idx].xyz - target_centroid;

 // Compute outer product contribution
 let h00 = source_centered.x * target_centered.x;
 let h01 = source_centered.x * target_centered.y;
 let h02 = source_centered.x * target_centered.z;
 let h10 = source_centered.y * target_centered.x;
 let h11 = source_centered.y * target_centered.y;
 let h12 = source_centered.y * target_centered.z;
 let h20 = source_centered.z * target_centered.x;
 let h21 = source_centered.z * target_centered.y;
 let h22 = source_centered.z * target_centered.z;

 // Atomic add to covariance matrix (approximated with individual element updates)
 atomicAdd(&covariance[0], bitcast<i32>(h00));
 atomicAdd(&covariance[1], bitcast<i32>(h01));
 atomicAdd(&covariance[2], bitcast<i32>(h02));
 atomicAdd(&covariance[3], bitcast<i32>(h10));
 atomicAdd(&covariance[4], bitcast<i32>(h11));
 atomicAdd(&covariance[5], bitcast<i32>(h12));
 atomicAdd(&covariance[6], bitcast<i32>(h20));
 atomicAdd(&covariance[7], bitcast<i32>(h21));
 atomicAdd(&covariance[8], bitcast<i32>(h22));
}
"#;
130
/// One source/target pair to align, together with its ICP settings.
#[derive(Debug, Clone)]
pub struct BatchICPJob {
    /// Cloud that is moved onto `target`.
    pub source: PointCloud<Point3f>,
    /// Cloud that stays fixed.
    pub target: PointCloud<Point3f>,
    /// Upper bound on the number of ICP iterations.
    pub max_iterations: usize,
    /// Iteration stops once both the translation norm and the rotation angle of the latest
    /// update fall below this value.
    pub convergence_threshold: f32,
    /// Matches whose nearest-neighbour distance is not strictly below this value are discarded.
    pub max_correspondence_distance: f32,
}
140
/// Outcome of a single point-to-point ICP job.
#[derive(Debug, Clone)]
pub struct BatchICPResult {
    /// Accumulated rigid transform that maps the original source cloud towards the target.
    pub transformation: Isometry3<f32>,
    /// Mean correspondence distance reported by the last iteration that produced an update;
    /// stays at `f32::INFINITY` when no iteration did.
    pub final_error: f32,
    /// Number of iterations that were started.
    pub iterations: usize,
}
148
149impl GpuContext {
150 pub async fn batch_icp_align(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
152 let mut results = Vec::with_capacity(jobs.len());
153
154 const BATCH_SIZE: usize = 4; for batch in jobs.chunks(BATCH_SIZE) {
158 let batch_results = self.process_icp_batch(batch).await?;
159 results.extend(batch_results);
160 }
161
162 Ok(results)
163 }
164
165 async fn process_icp_batch(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
167 let mut results = Vec::with_capacity(jobs.len());
168
169 for job in jobs {
172 let result = self.optimized_icp_align(
173 &job.source,
174 &job.target,
175 job.max_iterations,
176 job.convergence_threshold,
177 job.max_correspondence_distance,
178 ).await?;
179 results.push(result);
180 }
181
182 Ok(results)
183 }
184
185 async fn optimized_icp_align(
187 &self,
188 source: &PointCloud<Point3f>,
189 target: &PointCloud<Point3f>,
190 max_iterations: usize,
191 convergence_threshold: f32,
192 max_correspondence_distance: f32,
193 ) -> Result<BatchICPResult> {
194 if source.is_empty() || target.is_empty() {
195 return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
196 }
197
198 let mut current_transform = Isometry3::identity();
199 let mut transformed_source = source.clone();
200 let mut final_error = f32::INFINITY;
201 let mut iterations_used = 0;
202
203 for iteration in 0..max_iterations {
204 iterations_used = iteration + 1;
205
206 let correspondences = self.find_correspondences(
208 &transformed_source.points,
209 &target.points,
210 max_correspondence_distance,
211 ).await?;
212
213 if correspondences.is_empty() {
214 break;
215 }
216
217 let (transform_delta, error) = self.compute_transformation_gpu(
219 &transformed_source.points,
220 &target.points,
221 &correspondences
222 ).await?;
223
224 final_error = error;
225
226 current_transform = transform_delta * current_transform;
228
229 transformed_source = source.clone();
231 for point in &mut transformed_source.points {
232 *point = current_transform.transform_point(point);
233 }
234
235 let translation_norm = transform_delta.translation.vector.norm();
237 let rotation_angle = transform_delta.rotation.angle();
238
239 if translation_norm < convergence_threshold && rotation_angle < convergence_threshold {
240 break;
241 }
242 }
243
244 Ok(BatchICPResult {
245 transformation: current_transform,
246 final_error,
247 iterations: iterations_used,
248 })
249 }
250
251 async fn compute_transformation_gpu(
253 &self,
254 source_points: &[Point3f],
255 target_points: &[Point3f],
256 correspondences: &[(usize, usize, f32)],
257 ) -> Result<(Isometry3<f32>, f32)> {
258 if correspondences.is_empty() {
259 return Ok((Isometry3::identity(), 0.0));
260 }
261
262 let source_data: Vec<[f32; 4]> = source_points
264 .iter()
265 .map(|p| [p.x, p.y, p.z, 0.0])
266 .collect();
267
268 let target_data: Vec<[f32; 4]> = target_points
269 .iter()
270 .map(|p| [p.x, p.y, p.z, 0.0])
271 .collect();
272
273 let correspondence_indices: Vec<u32> = correspondences
274 .iter()
275 .map(|(_, target_idx, _)| *target_idx as u32)
276 .collect();
277
278 let source_buffer = self.create_buffer_init("Source Points", &source_data, wgpu::BufferUsages::STORAGE);
280 let target_buffer = self.create_buffer_init("Target Points", &target_data, wgpu::BufferUsages::STORAGE);
281 let correspondence_buffer = self.create_buffer_init("Correspondences", &correspondence_indices, wgpu::BufferUsages::STORAGE);
282
283 let centroids = self.compute_centroids_gpu(
285 &source_buffer,
286 &target_buffer,
287 &correspondence_buffer,
288 correspondences.len(),
289 ).await?;
290
291 let covariance = self.compute_covariance_cpu(
293 source_points,
294 target_points,
295 correspondences,
296 ¢roids,
297 )?;
298
299 let transformation = self.svd_to_transformation(&covariance, ¢roids)?;
301
302 let error = correspondences.iter().map(|(_, _, dist)| dist).sum::<f32>() / correspondences.len() as f32;
304
305 Ok((transformation, error))
306 }
307
308 async fn compute_centroids_gpu(
310 &self,
311 source_buffer: &wgpu::Buffer,
312 target_buffer: &wgpu::Buffer,
313 correspondence_buffer: &wgpu::Buffer,
314 num_correspondences: usize,
315 ) -> Result<Vec<[f32; 3]>> {
316 #[repr(C)]
317 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
318 struct CentroidParams {
319 num_correspondences: u32,
320 }
321
322 let params = CentroidParams {
323 num_correspondences: num_correspondences as u32,
324 };
325
326 let params_buffer = self.create_buffer_init("Centroid Params", &[params], wgpu::BufferUsages::UNIFORM);
327 let centroids_buffer = self.create_buffer("Centroids", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC);
328
329 let shader = self.create_shader_module("Centroid Computation", ICP_CENTROID_SHADER);
330
331 let bind_group_layout = self.create_bind_group_layout("Centroid Layout", &[
332 wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
333 wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
334 wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
335 wgpu::BindGroupLayoutEntry { binding: 3, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, min_binding_size: None }, count: None },
336 wgpu::BindGroupLayoutEntry { binding: 4, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None }, count: None },
337 ]);
338
339 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
340 label: Some("Centroid Pipeline"),
341 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
342 label: Some("Centroid Layout"),
343 bind_group_layouts: &[&bind_group_layout],
344 immediate_size: 0,
345 })),
346 module: &shader,
347 entry_point: Some("main"),
348 compilation_options: wgpu::PipelineCompilationOptions::default(),
349 cache: None,
350 });
351
352 let bind_group = self.create_bind_group("Centroid Bind Group", &bind_group_layout, &[
353 wgpu::BindGroupEntry { binding: 0, resource: source_buffer.as_entire_binding() },
354 wgpu::BindGroupEntry { binding: 1, resource: target_buffer.as_entire_binding() },
355 wgpu::BindGroupEntry { binding: 2, resource: correspondence_buffer.as_entire_binding() },
356 wgpu::BindGroupEntry { binding: 3, resource: centroids_buffer.as_entire_binding() },
357 wgpu::BindGroupEntry { binding: 4, resource: params_buffer.as_entire_binding() },
358 ]);
359
360 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Centroid Computation") });
361 {
362 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Centroid Pass"), timestamp_writes: None });
363 compute_pass.set_pipeline(&pipeline);
364 compute_pass.set_bind_group(0, &bind_group, &[]);
365 compute_pass.dispatch_workgroups(1, 1, 1);
366 }
367
368 let staging_buffer = self.create_buffer("Centroid Staging", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ);
369 encoder.copy_buffer_to_buffer(¢roids_buffer, 0, &staging_buffer, 0, 2 * std::mem::size_of::<[f32; 3]>() as u64);
370 self.queue.submit(std::iter::once(encoder.finish()));
371
372 let buffer_slice = staging_buffer.slice(..);
373 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
374 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
375 self.device.poll(wgpu::PollType::Wait {
376 submission_index: None,
377 timeout: None,
378 });
379
380 if let Some(Ok(())) = receiver.receive().await {
381 let data = buffer_slice.get_mapped_range();
382 let centroids: Vec<[f32; 3]> = bytemuck::cast_slice(&data).to_vec();
383 drop(data);
384 staging_buffer.unmap();
385 Ok(centroids)
386 } else {
387 Err(threecrate_core::Error::Gpu("Failed to read centroid results".to_string()))
388 }
389 }
390
391 fn compute_covariance_cpu(
394 &self,
395 source_points: &[Point3f],
396 target_points: &[Point3f],
397 correspondences: &[(usize, usize, f32)],
398 centroids: &[[f32; 3]],
399 ) -> Result<nalgebra::Matrix3<f32>> {
400 let mut h = nalgebra::Matrix3::zeros();
401 let sc = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
402 let tc = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
403 for (si, ti, _dist) in correspondences.iter().copied() {
404 let s = source_points[si];
405 let t = target_points[ti];
406 let sv = s.coords - sc;
407 let tv = t.coords - tc;
408 h += sv * tv.transpose();
409 }
410 Ok(h)
411 }
412
413 fn svd_to_transformation(&self, covariance: &nalgebra::Matrix3<f32>, centroids: &[[f32; 3]]) -> Result<Isometry3<f32>> {
415 let source_centroid = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
416 let target_centroid = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
417
418 let svd = covariance.svd(true, true);
420 let u = svd.u.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
421 let v_t = svd.v_t.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
422
423 let mut rotation = v_t.transpose() * u.transpose();
424
425 if rotation.determinant() < 0.0 {
427 let mut v_corrected = v_t.transpose();
428 v_corrected.column_mut(2).scale_mut(-1.0);
429 rotation = v_corrected * u.transpose();
430 }
431
432 let translation = target_centroid - rotation * source_centroid;
433
434 Ok(Isometry3::from_parts(
435 nalgebra::Translation3::from(translation),
436 nalgebra::UnitQuaternion::from_matrix(&rotation),
437 ))
438 }
439
440 pub async fn icp_align(
442 &self,
443 source: &PointCloud<Point3f>,
444 target: &PointCloud<Point3f>,
445 max_iterations: usize,
446 convergence_threshold: f32,
447 max_correspondence_distance: f32,
448 ) -> Result<Isometry3<f32>> {
449 let result = self.optimized_icp_align(
450 source,
451 target,
452 max_iterations,
453 convergence_threshold,
454 max_correspondence_distance
455 ).await?;
456 Ok(result.transformation)
457 }
458
459 async fn find_correspondences(
461 &self,
462 source_points: &[Point3f],
463 target_points: &[Point3f],
464 max_distance: f32,
465 ) -> Result<Vec<(usize, usize, f32)>> {
466 let source_data: Vec<[f32; 3]> = source_points
468 .iter()
469 .map(|p| [p.x, p.y, p.z])
470 .collect();
471
472 let target_data: Vec<[f32; 3]> = target_points
473 .iter()
474 .map(|p| [p.x, p.y, p.z])
475 .collect();
476
477 let source_buffer = self.create_buffer_init(
479 "Source Points",
480 &source_data,
481 wgpu::BufferUsages::STORAGE,
482 );
483
484 let target_buffer = self.create_buffer_init(
485 "Target Points",
486 &target_data,
487 wgpu::BufferUsages::STORAGE,
488 );
489
490 let correspondences_buffer = self.create_buffer(
491 "Correspondences",
492 (source_data.len() * std::mem::size_of::<u32>()) as u64,
493 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
494 );
495
496 let distances_buffer = self.create_buffer(
497 "Distances",
498 (source_data.len() * std::mem::size_of::<f32>()) as u64,
499 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
500 );
501
502 #[repr(C)]
503 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
504 struct ICPParams {
505 num_source: u32,
506 num_target: u32,
507 max_distance: f32,
508 }
509
510 let params = ICPParams {
511 num_source: source_data.len() as u32,
512 num_target: target_data.len() as u32,
513 max_distance,
514 };
515
516 let params_buffer = self.create_buffer_init(
517 "ICP Params",
518 &[params],
519 wgpu::BufferUsages::UNIFORM,
520 );
521
522 let shader = self.create_shader_module("ICP Nearest Neighbor", ICP_NEAREST_NEIGHBOR_SHADER);
524
525 let bind_group_layout = self.create_bind_group_layout(
526 "ICP Correspondence",
527 &[
528 wgpu::BindGroupLayoutEntry {
529 binding: 0,
530 visibility: wgpu::ShaderStages::COMPUTE,
531 ty: wgpu::BindingType::Buffer {
532 ty: wgpu::BufferBindingType::Storage { read_only: true },
533 has_dynamic_offset: false,
534 min_binding_size: None,
535 },
536 count: None,
537 },
538 wgpu::BindGroupLayoutEntry {
539 binding: 1,
540 visibility: wgpu::ShaderStages::COMPUTE,
541 ty: wgpu::BindingType::Buffer {
542 ty: wgpu::BufferBindingType::Storage { read_only: true },
543 has_dynamic_offset: false,
544 min_binding_size: None,
545 },
546 count: None,
547 },
548 wgpu::BindGroupLayoutEntry {
549 binding: 2,
550 visibility: wgpu::ShaderStages::COMPUTE,
551 ty: wgpu::BindingType::Buffer {
552 ty: wgpu::BufferBindingType::Storage { read_only: false },
553 has_dynamic_offset: false,
554 min_binding_size: None,
555 },
556 count: None,
557 },
558 wgpu::BindGroupLayoutEntry {
559 binding: 3,
560 visibility: wgpu::ShaderStages::COMPUTE,
561 ty: wgpu::BindingType::Buffer {
562 ty: wgpu::BufferBindingType::Storage { read_only: false },
563 has_dynamic_offset: false,
564 min_binding_size: None,
565 },
566 count: None,
567 },
568 wgpu::BindGroupLayoutEntry {
569 binding: 4,
570 visibility: wgpu::ShaderStages::COMPUTE,
571 ty: wgpu::BindingType::Buffer {
572 ty: wgpu::BufferBindingType::Uniform,
573 has_dynamic_offset: false,
574 min_binding_size: None,
575 },
576 count: None,
577 },
578 ],
579 );
580
581 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
582 label: Some("ICP Correspondence"),
583 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
584 label: Some("ICP Pipeline Layout"),
585 bind_group_layouts: &[&bind_group_layout],
586 immediate_size: 0,
587 })),
588 module: &shader,
589 entry_point: Some("main"),
590 compilation_options: wgpu::PipelineCompilationOptions::default(),
591 cache: None,
592 });
593
594 let bind_group = self.create_bind_group(
595 "ICP Correspondence",
596 &bind_group_layout,
597 &[
598 wgpu::BindGroupEntry {
599 binding: 0,
600 resource: source_buffer.as_entire_binding(),
601 },
602 wgpu::BindGroupEntry {
603 binding: 1,
604 resource: target_buffer.as_entire_binding(),
605 },
606 wgpu::BindGroupEntry {
607 binding: 2,
608 resource: correspondences_buffer.as_entire_binding(),
609 },
610 wgpu::BindGroupEntry {
611 binding: 3,
612 resource: distances_buffer.as_entire_binding(),
613 },
614 wgpu::BindGroupEntry {
615 binding: 4,
616 resource: params_buffer.as_entire_binding(),
617 },
618 ],
619 );
620
621 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
623 label: Some("ICP Correspondence"),
624 });
625
626 {
627 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
628 label: Some("ICP Correspondence Pass"),
629 timestamp_writes: None,
630 });
631 compute_pass.set_pipeline(&pipeline);
632 compute_pass.set_bind_group(0, &bind_group, &[]);
633 let workgroup_count = (source_data.len() + 63) / 64;
634 compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
635 }
636
637 let correspondences_staging = self.create_buffer(
639 "Correspondences Staging",
640 (source_data.len() * std::mem::size_of::<u32>()) as u64,
641 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
642 );
643
644 let distances_staging = self.create_buffer(
645 "Distances Staging",
646 (source_data.len() * std::mem::size_of::<f32>()) as u64,
647 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
648 );
649
650 encoder.copy_buffer_to_buffer(
651 &correspondences_buffer,
652 0,
653 &correspondences_staging,
654 0,
655 (source_data.len() * std::mem::size_of::<u32>()) as u64,
656 );
657
658 encoder.copy_buffer_to_buffer(
659 &distances_buffer,
660 0,
661 &distances_staging,
662 0,
663 (source_data.len() * std::mem::size_of::<f32>()) as u64,
664 );
665
666 self.queue.submit(std::iter::once(encoder.finish()));
667
668 let corr_slice = correspondences_staging.slice(..);
670 let dist_slice = distances_staging.slice(..);
671
672 let (corr_sender, corr_receiver) = futures_intrusive::channel::shared::oneshot_channel();
673 let (dist_sender, dist_receiver) = futures_intrusive::channel::shared::oneshot_channel();
674
675 corr_slice.map_async(wgpu::MapMode::Read, move |v| corr_sender.send(v).unwrap());
676 dist_slice.map_async(wgpu::MapMode::Read, move |v| dist_sender.send(v).unwrap());
677
678 self.device.poll(wgpu::PollType::Wait {
679 submission_index: None,
680 timeout: None,
681 });
682
683 if let (Some(Ok(())), Some(Ok(()))) = (corr_receiver.receive().await, dist_receiver.receive().await) {
684 let corr_data = corr_slice.get_mapped_range();
685 let dist_data = dist_slice.get_mapped_range();
686
687 let correspondences: Vec<u32> = bytemuck::cast_slice(&corr_data).to_vec();
688 let distances: Vec<f32> = bytemuck::cast_slice(&dist_data).to_vec();
689
690 let result: Vec<(usize, usize, f32)> = correspondences
691 .into_iter()
692 .zip(distances.into_iter())
693 .enumerate()
694 .filter(|(_, (_, distance))| *distance < max_distance)
695 .map(|(i, (target_idx, distance))| (i, target_idx as usize, distance))
696 .collect();
697
698 drop(corr_data);
699 drop(dist_data);
700 correspondences_staging.unmap();
701 distances_staging.unmap();
702
703 Ok(result)
704 } else {
705 Err(threecrate_core::Error::Gpu("Failed to read GPU correspondence results".to_string()))
706 }
707 }
708}
709
/// Outcome of a point-to-plane ICP run.
#[derive(Debug, Clone)]
pub struct GpuPointToPlaneICPResult {
    /// Accumulated rigid transform that maps the original source cloud towards the target.
    pub transformation: Isometry3<f32>,
    /// Mean squared point-to-plane residual of the last iteration that solved an update;
    /// stays at `f32::INFINITY` when no iteration did.
    pub final_error: f32,
    /// Number of iterations that were started.
    pub iterations: usize,
    /// `true` when the last update fell below the convergence threshold in both translation
    /// and rotation.
    pub converged: bool,
}
718
719impl GpuContext {
720 pub async fn icp_point_to_plane_align(
725 &self,
726 source: &PointCloud<Point3f>,
727 target: &PointCloud<Point3f>,
728 target_normals: &[Vector3f],
729 max_iterations: usize,
730 convergence_threshold: f32,
731 max_correspondence_distance: f32,
732 ) -> Result<GpuPointToPlaneICPResult> {
733 if source.is_empty() || target.is_empty() {
734 return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
735 }
736 if target_normals.len() != target.points.len() {
737 return Err(threecrate_core::Error::InvalidData(
738 "target_normals length must equal target point count".to_string(),
739 ));
740 }
741
742 let mut current_transform = Isometry3::identity();
743 let mut final_error = f32::INFINITY;
744 let mut iterations_used = 0;
745 let mut converged = false;
746
747 for iteration in 0..max_iterations {
748 iterations_used = iteration + 1;
749
750 let mut transformed_source = source.clone();
752 for p in &mut transformed_source.points {
753 *p = current_transform.transform_point(p);
754 }
755
756 let correspondences = self
758 .find_correspondences(
759 &transformed_source.points,
760 &target.points,
761 max_correspondence_distance,
762 )
763 .await?;
764
765 if correspondences.len() < 6 {
766 break;
767 }
768
769 let valid_source: Vec<Point3f> =
771 correspondences.iter().map(|(si, _, _)| transformed_source.points[*si]).collect();
772 let valid_target: Vec<Point3f> =
773 correspondences.iter().map(|(_, ti, _)| target.points[*ti]).collect();
774 let valid_normals: Vec<Vector3f> =
775 correspondences.iter().map(|(_, ti, _)| target_normals[*ti]).collect();
776
777 let delta = Self::solve_point_to_plane_cpu(&valid_source, &valid_target, &valid_normals)?;
779
780 current_transform = delta * current_transform;
781
782 let mse: f32 = valid_source
784 .iter()
785 .zip(valid_target.iter())
786 .zip(valid_normals.iter())
787 .map(|((s, d), n)| {
788 let v = n.dot(&(d.coords - s.coords));
789 v * v
790 })
791 .sum::<f32>()
792 / valid_source.len() as f32;
793
794 let translation_change = delta.translation.vector.norm();
795 let rotation_change = delta.rotation.angle();
796
797 final_error = mse;
798
799 if translation_change < convergence_threshold && rotation_change < convergence_threshold {
800 converged = true;
801 break;
802 }
803 }
804
805 Ok(GpuPointToPlaneICPResult {
806 transformation: current_transform,
807 final_error,
808 iterations: iterations_used,
809 converged,
810 })
811 }
812
813 fn solve_point_to_plane_cpu(
815 source_points: &[Point3f],
816 target_points: &[Point3f],
817 normals: &[Vector3f],
818 ) -> Result<Isometry3<f32>> {
819 let mut ata = Matrix6::<f32>::zeros();
820 let mut atb = Vector6::<f32>::zeros();
821
822 for ((src, tgt), n) in source_points
823 .iter()
824 .zip(target_points.iter())
825 .zip(normals.iter())
826 {
827 let c = src.coords.cross(n);
828 let a_row = Vector6::new(c.x, c.y, c.z, n.x, n.y, n.z);
829 let b_i = n.dot(&(tgt.coords - src.coords));
830 ata += a_row * a_row.transpose();
831 atb += a_row * b_i;
832 }
833
834 let x = if let Some(chol) = ata.cholesky() {
835 chol.solve(&atb)
836 } else {
837 ata.lu()
838 .solve(&atb)
839 .ok_or_else(|| threecrate_core::Error::Algorithm(
840 "Point-to-plane GPU system is ill-conditioned".to_string(),
841 ))?
842 };
843
844 let rot_x = UnitQuaternion::from_axis_angle(&nalgebra::Vector3::x_axis(), x[0]);
845 let rot_y = UnitQuaternion::from_axis_angle(&nalgebra::Vector3::y_axis(), x[1]);
846 let rot_z = UnitQuaternion::from_axis_angle(&nalgebra::Vector3::z_axis(), x[2]);
847
848 Ok(Isometry3::from_parts(
849 Translation3::new(x[3], x[4], x[5]),
850 rot_z * rot_y * rot_x,
851 ))
852 }
853}
854
855pub async fn gpu_icp(
857 gpu_context: &GpuContext,
858 source: &PointCloud<Point3f>,
859 target: &PointCloud<Point3f>,
860 max_iterations: usize,
861 convergence_threshold: f32,
862 max_correspondence_distance: f32,
863) -> Result<Isometry3<f32>> {
864 gpu_context.icp_align(source, target, max_iterations, convergence_threshold, max_correspondence_distance).await
865}
866
867pub async fn gpu_batch_icp(
869 gpu_context: &GpuContext,
870 jobs: &[BatchICPJob],
871) -> Result<Vec<BatchICPResult>> {
872 gpu_context.batch_icp_align(jobs).await
873}
874
875pub async fn gpu_icp_point_to_plane(
889 gpu_context: &GpuContext,
890 source: &PointCloud<Point3f>,
891 target: &PointCloud<Point3f>,
892 target_normals: &[Vector3f],
893 max_iterations: usize,
894 convergence_threshold: f32,
895 max_correspondence_distance: f32,
896) -> Result<GpuPointToPlaneICPResult> {
897 gpu_context
898 .icp_point_to_plane_align(
899 source,
900 target,
901 target_normals,
902 max_iterations,
903 convergence_threshold,
904 max_correspondence_distance,
905 )
906 .await
907}