1use threecrate_core::{PointCloud, Result, Point3f};
4use crate::GpuContext;
5use nalgebra::Isometry3;
6
/// WGSL compute shader performing a brute-force nearest-neighbor search.
///
/// One invocation handles one source point, selected by `global_invocation_id.x`;
/// invocations whose index is at or beyond `params.num_source` return without writing.
/// Each invocation scans all `params.num_target` target points and writes the index of the
/// closest one to `correspondences[index]` and its Euclidean distance to `distances[index]`.
///
/// The scan starts from `min_distance = params.max_distance` and `best_match = 0`, so a
/// source point with no target strictly closer than `max_distance` is reported as
/// `(0, max_distance)`.
///
/// Both point buffers are declared as `array<vec4<f32>>` and only `.xyz` is read, so the
/// host has to supply four `f32` values per point.
const ICP_NEAREST_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> distances: array<f32>;
@group(0) @binding(4) var<uniform> params: ICPParams;

struct ICPParams {
    num_source: u32,
    num_target: u32,
    max_distance: f32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_source) {
        return;
    }

    let source_point = source_points[index].xyz;
    var min_distance = params.max_distance;
    var best_match = 0u;

    // Find nearest neighbor in target
    for (var i = 0u; i < params.num_target; i++) {
        let target_point = target_points[i].xyz;
        let diff = source_point - target_point;
        let distance = length(diff);

        if (distance < min_distance) {
            min_distance = distance;
            best_match = i;
        }
    }

    correspondences[index] = best_match;
    distances[index] = min_distance;
}
"#;
47
/// WGSL compute shader that averages the matched source and target points.
///
/// Only the invocation with `global_invocation_id.x == 0` does any work. It loops over
/// `i in 0..params.num_correspondences`, summing `source_points[i].xyz` and
/// `target_points[correspondences[i]].xyz`, then writes the two means to `centroids[0]`
/// (source) and `centroids[1]` (target) as `vec4<f32>` values with `w = 0`.
///
/// Source point `i` is paired with correspondence `i`, so the source buffer has to be laid
/// out in correspondence order. The shader divides by `f32(params.num_correspondences)`
/// without checking for zero.
const ICP_CENTROID_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> centroids: array<vec4<f32>>; // [source_centroid, target_centroid]
@group(0) @binding(4) var<uniform> params: CentroidParams;

struct CentroidParams {
    num_correspondences: u32,
}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    if (global_id.x != 0u) {
        return;
    }

    var source_centroid = vec3<f32>(0.0);
    var target_centroid = vec3<f32>(0.0);

    for (var i = 0u; i < params.num_correspondences; i++) {
        let target_idx = correspondences[i];
        source_centroid += source_points[i].xyz;
        target_centroid += target_points[target_idx].xyz;
    }

    let scale = 1.0 / f32(params.num_correspondences);
    centroids[0] = vec4<f32>(source_centroid * scale, 0.0);
    centroids[1] = vec4<f32>(target_centroid * scale, 0.0);
}
"#;
79
/// WGSL compute shader that accumulates the 3x3 cross-covariance of the matched points.
///
/// One invocation handles one correspondence: it centers `source_points[index]` and
/// `target_points[correspondences[index]]` on `centroids[0]` / `centroids[1]` and adds the
/// outer product `source_centered * target_centered^T` into `covariance` (row-major,
/// nine elements).
///
/// WGSL atomics exist only for `i32`/`u32`, so `covariance` is an `array<atomic<i32>>`
/// holding fixed-point values: every contribution is multiplied by `FIXED_POINT_SCALE`
/// (1024) and rounded before the atomic add. The host must zero the buffer before the
/// dispatch and divide each element by the same scale after reading it back. The
/// accumulated magnitude per element must stay below `2^31 / 1024`.
// Not wired into a pipeline in this file; the covariance is computed by
// `compute_covariance_cpu`, hence the lint allowance.
#[allow(dead_code)]
const ICP_COVARIANCE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read> centroids: array<vec4<f32>>;
@group(0) @binding(4) var<storage, read_write> covariance: array<atomic<i32>>; // 9 fixed-point elements for 3x3 matrix
@group(0) @binding(5) var<uniform> params: CovarianceParams;

struct CovarianceParams {
    num_correspondences: u32,
}

// Atomic adds are integer-only, so float contributions are stored as fixed point.
const FIXED_POINT_SCALE: f32 = 1024.0;

fn to_fixed(value: f32) -> i32 {
    return i32(round(value * FIXED_POINT_SCALE));
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_correspondences) {
        return;
    }

    let source_centroid = centroids[0].xyz;
    let target_centroid = centroids[1].xyz;

    let target_idx = correspondences[index];
    let source_centered = source_points[index].xyz - source_centroid;
    let target_centered = target_points[target_idx].xyz - target_centroid;

    // Compute outer product contribution
    let h00 = source_centered.x * target_centered.x;
    let h01 = source_centered.x * target_centered.y;
    let h02 = source_centered.x * target_centered.z;
    let h10 = source_centered.y * target_centered.x;
    let h11 = source_centered.y * target_centered.y;
    let h12 = source_centered.y * target_centered.z;
    let h20 = source_centered.z * target_centered.x;
    let h21 = source_centered.z * target_centered.y;
    let h22 = source_centered.z * target_centered.z;

    // Accumulate each element atomically in fixed point
    atomicAdd(&covariance[0], to_fixed(h00));
    atomicAdd(&covariance[1], to_fixed(h01));
    atomicAdd(&covariance[2], to_fixed(h02));
    atomicAdd(&covariance[3], to_fixed(h10));
    atomicAdd(&covariance[4], to_fixed(h11));
    atomicAdd(&covariance[5], to_fixed(h12));
    atomicAdd(&covariance[6], to_fixed(h20));
    atomicAdd(&covariance[7], to_fixed(h21));
    atomicAdd(&covariance[8], to_fixed(h22));
}
"#;
130
/// One alignment request for [`GpuContext::batch_icp_align`].
#[derive(Debug, Clone)]
pub struct BatchICPJob {
    /// Point cloud that is transformed towards `target`.
    pub source: PointCloud<Point3f>,
    /// Point cloud the source is matched against; it is never transformed.
    pub target: PointCloud<Point3f>,
    /// Upper bound on the number of ICP iterations.
    pub max_iterations: usize,
    /// Iteration stops once both the translation norm and the rotation angle of the latest
    /// incremental transform are below this value.
    pub convergence_threshold: f32,
    /// Matches whose distance is not strictly below this value are discarded.
    pub max_correspondence_distance: f32,
}
140
/// Outcome of a single ICP alignment.
#[derive(Debug, Clone)]
pub struct BatchICPResult {
    /// Accumulated transform that maps the source cloud onto the target.
    pub transformation: Isometry3<f32>,
    /// Mean correspondence distance reported by the last iteration that estimated a
    /// transform; `f32::INFINITY` if no iteration got that far.
    pub final_error: f32,
    /// Number of iterations that were started.
    pub iterations: usize,
}
148
149impl GpuContext {
150 pub async fn batch_icp_align(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
152 let mut results = Vec::with_capacity(jobs.len());
153
154 const BATCH_SIZE: usize = 4; for batch in jobs.chunks(BATCH_SIZE) {
158 let batch_results = self.process_icp_batch(batch).await?;
159 results.extend(batch_results);
160 }
161
162 Ok(results)
163 }
164
165 async fn process_icp_batch(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
167 let mut results = Vec::with_capacity(jobs.len());
168
169 for job in jobs {
172 let result = self.optimized_icp_align(
173 &job.source,
174 &job.target,
175 job.max_iterations,
176 job.convergence_threshold,
177 job.max_correspondence_distance,
178 ).await?;
179 results.push(result);
180 }
181
182 Ok(results)
183 }
184
185 async fn optimized_icp_align(
187 &self,
188 source: &PointCloud<Point3f>,
189 target: &PointCloud<Point3f>,
190 max_iterations: usize,
191 convergence_threshold: f32,
192 max_correspondence_distance: f32,
193 ) -> Result<BatchICPResult> {
194 if source.is_empty() || target.is_empty() {
195 return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
196 }
197
198 let mut current_transform = Isometry3::identity();
199 let mut transformed_source = source.clone();
200 let mut final_error = f32::INFINITY;
201 let mut iterations_used = 0;
202
203 for iteration in 0..max_iterations {
204 iterations_used = iteration + 1;
205
206 let correspondences = self.find_correspondences(
208 &transformed_source.points,
209 &target.points,
210 max_correspondence_distance,
211 ).await?;
212
213 if correspondences.is_empty() {
214 break;
215 }
216
217 let (transform_delta, error) = self.compute_transformation_gpu(
219 &transformed_source.points,
220 &target.points,
221 &correspondences
222 ).await?;
223
224 final_error = error;
225
226 current_transform = transform_delta * current_transform;
228
229 transformed_source = source.clone();
231 for point in &mut transformed_source.points {
232 *point = current_transform.transform_point(point);
233 }
234
235 let translation_norm = transform_delta.translation.vector.norm();
237 let rotation_angle = transform_delta.rotation.angle();
238
239 if translation_norm < convergence_threshold && rotation_angle < convergence_threshold {
240 break;
241 }
242 }
243
244 Ok(BatchICPResult {
245 transformation: current_transform,
246 final_error,
247 iterations: iterations_used,
248 })
249 }
250
251 async fn compute_transformation_gpu(
253 &self,
254 source_points: &[Point3f],
255 target_points: &[Point3f],
256 correspondences: &[(usize, usize, f32)],
257 ) -> Result<(Isometry3<f32>, f32)> {
258 if correspondences.is_empty() {
259 return Ok((Isometry3::identity(), 0.0));
260 }
261
262 let source_data: Vec<[f32; 4]> = source_points
264 .iter()
265 .map(|p| [p.x, p.y, p.z, 0.0])
266 .collect();
267
268 let target_data: Vec<[f32; 4]> = target_points
269 .iter()
270 .map(|p| [p.x, p.y, p.z, 0.0])
271 .collect();
272
273 let correspondence_indices: Vec<u32> = correspondences
274 .iter()
275 .map(|(_, target_idx, _)| *target_idx as u32)
276 .collect();
277
278 let source_buffer = self.create_buffer_init("Source Points", &source_data, wgpu::BufferUsages::STORAGE);
280 let target_buffer = self.create_buffer_init("Target Points", &target_data, wgpu::BufferUsages::STORAGE);
281 let correspondence_buffer = self.create_buffer_init("Correspondences", &correspondence_indices, wgpu::BufferUsages::STORAGE);
282
283 let centroids = self.compute_centroids_gpu(
285 &source_buffer,
286 &target_buffer,
287 &correspondence_buffer,
288 correspondences.len(),
289 ).await?;
290
291 let covariance = self.compute_covariance_cpu(
293 source_points,
294 target_points,
295 correspondences,
296 ¢roids,
297 )?;
298
299 let transformation = self.svd_to_transformation(&covariance, ¢roids)?;
301
302 let error = correspondences.iter().map(|(_, _, dist)| dist).sum::<f32>() / correspondences.len() as f32;
304
305 Ok((transformation, error))
306 }
307
308 async fn compute_centroids_gpu(
310 &self,
311 source_buffer: &wgpu::Buffer,
312 target_buffer: &wgpu::Buffer,
313 correspondence_buffer: &wgpu::Buffer,
314 num_correspondences: usize,
315 ) -> Result<Vec<[f32; 3]>> {
316 #[repr(C)]
317 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
318 struct CentroidParams {
319 num_correspondences: u32,
320 }
321
322 let params = CentroidParams {
323 num_correspondences: num_correspondences as u32,
324 };
325
326 let params_buffer = self.create_buffer_init("Centroid Params", &[params], wgpu::BufferUsages::UNIFORM);
327 let centroids_buffer = self.create_buffer("Centroids", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC);
328
329 let shader = self.create_shader_module("Centroid Computation", ICP_CENTROID_SHADER);
330
331 let bind_group_layout = self.create_bind_group_layout("Centroid Layout", &[
332 wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
333 wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
334 wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
335 wgpu::BindGroupLayoutEntry { binding: 3, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, min_binding_size: None }, count: None },
336 wgpu::BindGroupLayoutEntry { binding: 4, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None }, count: None },
337 ]);
338
339 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
340 label: Some("Centroid Pipeline"),
341 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
342 label: Some("Centroid Layout"),
343 bind_group_layouts: &[&bind_group_layout],
344 push_constant_ranges: &[],
345 })),
346 module: &shader,
347 entry_point: Some("main"),
348 compilation_options: wgpu::PipelineCompilationOptions::default(),
349 cache: None,
350 });
351
352 let bind_group = self.create_bind_group("Centroid Bind Group", &bind_group_layout, &[
353 wgpu::BindGroupEntry { binding: 0, resource: source_buffer.as_entire_binding() },
354 wgpu::BindGroupEntry { binding: 1, resource: target_buffer.as_entire_binding() },
355 wgpu::BindGroupEntry { binding: 2, resource: correspondence_buffer.as_entire_binding() },
356 wgpu::BindGroupEntry { binding: 3, resource: centroids_buffer.as_entire_binding() },
357 wgpu::BindGroupEntry { binding: 4, resource: params_buffer.as_entire_binding() },
358 ]);
359
360 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Centroid Computation") });
361 {
362 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Centroid Pass"), timestamp_writes: None });
363 compute_pass.set_pipeline(&pipeline);
364 compute_pass.set_bind_group(0, &bind_group, &[]);
365 compute_pass.dispatch_workgroups(1, 1, 1);
366 }
367
368 let staging_buffer = self.create_buffer("Centroid Staging", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ);
369 encoder.copy_buffer_to_buffer(¢roids_buffer, 0, &staging_buffer, 0, 2 * std::mem::size_of::<[f32; 3]>() as u64);
370 self.queue.submit(std::iter::once(encoder.finish()));
371
372 let buffer_slice = staging_buffer.slice(..);
373 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
374 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
375 let _ = self.device.poll(wgpu::PollType::Wait);
376
377 if let Some(Ok(())) = receiver.receive().await {
378 let data = buffer_slice.get_mapped_range();
379 let centroids: Vec<[f32; 3]> = bytemuck::cast_slice(&data).to_vec();
380 drop(data);
381 staging_buffer.unmap();
382 Ok(centroids)
383 } else {
384 Err(threecrate_core::Error::Gpu("Failed to read centroid results".to_string()))
385 }
386 }
387
388 fn compute_covariance_cpu(
391 &self,
392 source_points: &[Point3f],
393 target_points: &[Point3f],
394 correspondences: &[(usize, usize, f32)],
395 centroids: &[[f32; 3]],
396 ) -> Result<nalgebra::Matrix3<f32>> {
397 let mut h = nalgebra::Matrix3::zeros();
398 let sc = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
399 let tc = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
400 for (si, ti, _dist) in correspondences.iter().copied() {
401 let s = source_points[si];
402 let t = target_points[ti];
403 let sv = s.coords - sc;
404 let tv = t.coords - tc;
405 h += sv * tv.transpose();
406 }
407 Ok(h)
408 }
409
410 fn svd_to_transformation(&self, covariance: &nalgebra::Matrix3<f32>, centroids: &[[f32; 3]]) -> Result<Isometry3<f32>> {
412 let source_centroid = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
413 let target_centroid = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
414
415 let svd = covariance.svd(true, true);
417 let u = svd.u.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
418 let v_t = svd.v_t.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
419
420 let mut rotation = v_t.transpose() * u.transpose();
421
422 if rotation.determinant() < 0.0 {
424 let mut v_corrected = v_t.transpose();
425 v_corrected.column_mut(2).scale_mut(-1.0);
426 rotation = v_corrected * u.transpose();
427 }
428
429 let translation = target_centroid - rotation * source_centroid;
430
431 Ok(Isometry3::from_parts(
432 nalgebra::Translation3::from(translation),
433 nalgebra::UnitQuaternion::from_matrix(&rotation),
434 ))
435 }
436
437 pub async fn icp_align(
439 &self,
440 source: &PointCloud<Point3f>,
441 target: &PointCloud<Point3f>,
442 max_iterations: usize,
443 convergence_threshold: f32,
444 max_correspondence_distance: f32,
445 ) -> Result<Isometry3<f32>> {
446 let result = self.optimized_icp_align(
447 source,
448 target,
449 max_iterations,
450 convergence_threshold,
451 max_correspondence_distance
452 ).await?;
453 Ok(result.transformation)
454 }
455
456 async fn find_correspondences(
458 &self,
459 source_points: &[Point3f],
460 target_points: &[Point3f],
461 max_distance: f32,
462 ) -> Result<Vec<(usize, usize, f32)>> {
463 let source_data: Vec<[f32; 3]> = source_points
465 .iter()
466 .map(|p| [p.x, p.y, p.z])
467 .collect();
468
469 let target_data: Vec<[f32; 3]> = target_points
470 .iter()
471 .map(|p| [p.x, p.y, p.z])
472 .collect();
473
474 let source_buffer = self.create_buffer_init(
476 "Source Points",
477 &source_data,
478 wgpu::BufferUsages::STORAGE,
479 );
480
481 let target_buffer = self.create_buffer_init(
482 "Target Points",
483 &target_data,
484 wgpu::BufferUsages::STORAGE,
485 );
486
487 let correspondences_buffer = self.create_buffer(
488 "Correspondences",
489 (source_data.len() * std::mem::size_of::<u32>()) as u64,
490 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
491 );
492
493 let distances_buffer = self.create_buffer(
494 "Distances",
495 (source_data.len() * std::mem::size_of::<f32>()) as u64,
496 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
497 );
498
499 #[repr(C)]
500 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
501 struct ICPParams {
502 num_source: u32,
503 num_target: u32,
504 max_distance: f32,
505 }
506
507 let params = ICPParams {
508 num_source: source_data.len() as u32,
509 num_target: target_data.len() as u32,
510 max_distance,
511 };
512
513 let params_buffer = self.create_buffer_init(
514 "ICP Params",
515 &[params],
516 wgpu::BufferUsages::UNIFORM,
517 );
518
519 let shader = self.create_shader_module("ICP Nearest Neighbor", ICP_NEAREST_NEIGHBOR_SHADER);
521
522 let bind_group_layout = self.create_bind_group_layout(
523 "ICP Correspondence",
524 &[
525 wgpu::BindGroupLayoutEntry {
526 binding: 0,
527 visibility: wgpu::ShaderStages::COMPUTE,
528 ty: wgpu::BindingType::Buffer {
529 ty: wgpu::BufferBindingType::Storage { read_only: true },
530 has_dynamic_offset: false,
531 min_binding_size: None,
532 },
533 count: None,
534 },
535 wgpu::BindGroupLayoutEntry {
536 binding: 1,
537 visibility: wgpu::ShaderStages::COMPUTE,
538 ty: wgpu::BindingType::Buffer {
539 ty: wgpu::BufferBindingType::Storage { read_only: true },
540 has_dynamic_offset: false,
541 min_binding_size: None,
542 },
543 count: None,
544 },
545 wgpu::BindGroupLayoutEntry {
546 binding: 2,
547 visibility: wgpu::ShaderStages::COMPUTE,
548 ty: wgpu::BindingType::Buffer {
549 ty: wgpu::BufferBindingType::Storage { read_only: false },
550 has_dynamic_offset: false,
551 min_binding_size: None,
552 },
553 count: None,
554 },
555 wgpu::BindGroupLayoutEntry {
556 binding: 3,
557 visibility: wgpu::ShaderStages::COMPUTE,
558 ty: wgpu::BindingType::Buffer {
559 ty: wgpu::BufferBindingType::Storage { read_only: false },
560 has_dynamic_offset: false,
561 min_binding_size: None,
562 },
563 count: None,
564 },
565 wgpu::BindGroupLayoutEntry {
566 binding: 4,
567 visibility: wgpu::ShaderStages::COMPUTE,
568 ty: wgpu::BindingType::Buffer {
569 ty: wgpu::BufferBindingType::Uniform,
570 has_dynamic_offset: false,
571 min_binding_size: None,
572 },
573 count: None,
574 },
575 ],
576 );
577
578 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
579 label: Some("ICP Correspondence"),
580 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
581 label: Some("ICP Pipeline Layout"),
582 bind_group_layouts: &[&bind_group_layout],
583 push_constant_ranges: &[],
584 })),
585 module: &shader,
586 entry_point: Some("main"),
587 compilation_options: wgpu::PipelineCompilationOptions::default(),
588 cache: None,
589 });
590
591 let bind_group = self.create_bind_group(
592 "ICP Correspondence",
593 &bind_group_layout,
594 &[
595 wgpu::BindGroupEntry {
596 binding: 0,
597 resource: source_buffer.as_entire_binding(),
598 },
599 wgpu::BindGroupEntry {
600 binding: 1,
601 resource: target_buffer.as_entire_binding(),
602 },
603 wgpu::BindGroupEntry {
604 binding: 2,
605 resource: correspondences_buffer.as_entire_binding(),
606 },
607 wgpu::BindGroupEntry {
608 binding: 3,
609 resource: distances_buffer.as_entire_binding(),
610 },
611 wgpu::BindGroupEntry {
612 binding: 4,
613 resource: params_buffer.as_entire_binding(),
614 },
615 ],
616 );
617
618 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
620 label: Some("ICP Correspondence"),
621 });
622
623 {
624 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
625 label: Some("ICP Correspondence Pass"),
626 timestamp_writes: None,
627 });
628 compute_pass.set_pipeline(&pipeline);
629 compute_pass.set_bind_group(0, &bind_group, &[]);
630 let workgroup_count = (source_data.len() + 63) / 64;
631 compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
632 }
633
634 let correspondences_staging = self.create_buffer(
636 "Correspondences Staging",
637 (source_data.len() * std::mem::size_of::<u32>()) as u64,
638 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
639 );
640
641 let distances_staging = self.create_buffer(
642 "Distances Staging",
643 (source_data.len() * std::mem::size_of::<f32>()) as u64,
644 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
645 );
646
647 encoder.copy_buffer_to_buffer(
648 &correspondences_buffer,
649 0,
650 &correspondences_staging,
651 0,
652 (source_data.len() * std::mem::size_of::<u32>()) as u64,
653 );
654
655 encoder.copy_buffer_to_buffer(
656 &distances_buffer,
657 0,
658 &distances_staging,
659 0,
660 (source_data.len() * std::mem::size_of::<f32>()) as u64,
661 );
662
663 self.queue.submit(std::iter::once(encoder.finish()));
664
665 let corr_slice = correspondences_staging.slice(..);
667 let dist_slice = distances_staging.slice(..);
668
669 let (corr_sender, corr_receiver) = futures_intrusive::channel::shared::oneshot_channel();
670 let (dist_sender, dist_receiver) = futures_intrusive::channel::shared::oneshot_channel();
671
672 corr_slice.map_async(wgpu::MapMode::Read, move |v| corr_sender.send(v).unwrap());
673 dist_slice.map_async(wgpu::MapMode::Read, move |v| dist_sender.send(v).unwrap());
674
675 let _ = self.device.poll(wgpu::PollType::Wait);
676
677 if let (Some(Ok(())), Some(Ok(()))) = (corr_receiver.receive().await, dist_receiver.receive().await) {
678 let corr_data = corr_slice.get_mapped_range();
679 let dist_data = dist_slice.get_mapped_range();
680
681 let correspondences: Vec<u32> = bytemuck::cast_slice(&corr_data).to_vec();
682 let distances: Vec<f32> = bytemuck::cast_slice(&dist_data).to_vec();
683
684 let result: Vec<(usize, usize, f32)> = correspondences
685 .into_iter()
686 .zip(distances.into_iter())
687 .enumerate()
688 .filter(|(_, (_, distance))| *distance < max_distance)
689 .map(|(i, (target_idx, distance))| (i, target_idx as usize, distance))
690 .collect();
691
692 drop(corr_data);
693 drop(dist_data);
694 correspondences_staging.unmap();
695 distances_staging.unmap();
696
697 Ok(result)
698 } else {
699 Err(threecrate_core::Error::Gpu("Failed to read GPU correspondence results".to_string()))
700 }
701 }
702}
703
704pub async fn gpu_icp(
706 gpu_context: &GpuContext,
707 source: &PointCloud<Point3f>,
708 target: &PointCloud<Point3f>,
709 max_iterations: usize,
710 convergence_threshold: f32,
711 max_correspondence_distance: f32,
712) -> Result<Isometry3<f32>> {
713 gpu_context.icp_align(source, target, max_iterations, convergence_threshold, max_correspondence_distance).await
714}
715
716pub async fn gpu_batch_icp(
718 gpu_context: &GpuContext,
719 jobs: &[BatchICPJob],
720) -> Result<Vec<BatchICPResult>> {
721 gpu_context.batch_icp_align(jobs).await
722}