use threecrate_core::{PointCloud, Result, Point3f};
use crate::GpuContext;
use nalgebra::Isometry3;

const ICP_NEAREST_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> distances: array<f32>;
@group(0) @binding(4) var<uniform> params: ICPParams;

struct ICPParams {
    num_source: u32,
    num_target: u32,
    max_distance: f32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_source) {
        return;
    }

    let source_point = source_points[index].xyz;
    var min_distance = params.max_distance;
    var best_match = 0u;

    // Find nearest neighbor in target
    for (var i = 0u; i < params.num_target; i++) {
        let target_point = target_points[i].xyz;
        let diff = source_point - target_point;
        let distance = length(diff);

        if (distance < min_distance) {
            min_distance = distance;
            best_match = i;
        }
    }

    correspondences[index] = best_match;
    distances[index] = min_distance;
}
"#;

const ICP_CENTROID_SHADER: &str = r#"
// Note: this shader assumes source_points[i] is the source point of correspondence i,
// i.e. the source buffer is uploaded in correspondence order.
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> centroids: array<vec4<f32>>; // [source_centroid, target_centroid]
@group(0) @binding(4) var<uniform> params: CentroidParams;

struct CentroidParams {
    num_correspondences: u32,
}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    if (global_id.x != 0u) {
        return;
    }

    var source_centroid = vec3<f32>(0.0);
    var target_centroid = vec3<f32>(0.0);

    for (var i = 0u; i < params.num_correspondences; i++) {
        let target_idx = correspondences[i];
        source_centroid += source_points[i].xyz;
        target_centroid += target_points[target_idx].xyz;
    }

    let scale = 1.0 / f32(params.num_correspondences);
    centroids[0] = vec4<f32>(source_centroid * scale, 0.0);
    centroids[1] = vec4<f32>(target_centroid * scale, 0.0);
}
"#;

// Currently unused: the covariance accumulation runs on the CPU (see
// `compute_covariance_cpu`). Kept for a possible fully GPU-resident path.
#[allow(dead_code)]
const ICP_COVARIANCE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read> centroids: array<vec4<f32>>;
@group(0) @binding(4) var<storage, read_write> covariance: array<atomic<u32>>; // 9 elements for 3x3 matrix, stored as f32 bit patterns
@group(0) @binding(5) var<uniform> params: CovarianceParams;

struct CovarianceParams {
    num_correspondences: u32,
}

// WGSL has no floating-point atomics, so each element is accumulated with a
// compare-and-swap loop on the f32 bit pattern held in an atomic<u32>.
fn atomic_add_f32(index: u32, value: f32) {
    loop {
        let old_bits = atomicLoad(&covariance[index]);
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let result = atomicCompareExchangeWeak(&covariance[index], old_bits, new_bits);
        if (result.exchanged) {
            break;
        }
    }
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_correspondences) {
        return;
    }

    let source_centroid = centroids[0].xyz;
    let target_centroid = centroids[1].xyz;

    let target_idx = correspondences[index];
    let source_centered = source_points[index].xyz - source_centroid;
    let target_centered = target_points[target_idx].xyz - target_centroid;

    // Compute this correspondence's outer-product contribution
    let h00 = source_centered.x * target_centered.x;
    let h01 = source_centered.x * target_centered.y;
    let h02 = source_centered.x * target_centered.z;
    let h10 = source_centered.y * target_centered.x;
    let h11 = source_centered.y * target_centered.y;
    let h12 = source_centered.y * target_centered.z;
    let h20 = source_centered.z * target_centered.x;
    let h21 = source_centered.z * target_centered.y;
    let h22 = source_centered.z * target_centered.z;

    // Accumulate into the shared 3x3 covariance matrix
    atomic_add_f32(0u, h00);
    atomic_add_f32(1u, h01);
    atomic_add_f32(2u, h02);
    atomic_add_f32(3u, h10);
    atomic_add_f32(4u, h11);
    atomic_add_f32(5u, h12);
    atomic_add_f32(6u, h20);
    atomic_add_f32(7u, h21);
    atomic_add_f32(8u, h22);
}
"#;

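/// One ICP alignment request: a source/target cloud pair together with the
/// iteration limit, convergence threshold, and maximum correspondence distance.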
#[derive(Debug, Clone)]
pub struct BatchICPJob {
    pub source: PointCloud<Point3f>,
    pub target: PointCloud<Point3f>,
    pub max_iterations: usize,
    pub convergence_threshold: f32,
    pub max_correspondence_distance: f32,
}

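/// Result of one ICP job: the estimated rigid transformation, the mean
/// correspondence distance from the final iteration, and the number of iterations run.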
#[derive(Debug, Clone)]
pub struct BatchICPResult {
    pub transformation: Isometry3<f32>,
    pub final_error: f32,
    pub iterations: usize,
}

impl GpuContext {
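    /// Aligns every job in `jobs` and returns one result per job, processing the
    /// jobs in small fixed-size batches.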
    pub async fn batch_icp_align(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
        let mut results = Vec::with_capacity(jobs.len());

        const BATCH_SIZE: usize = 4;
        for batch in jobs.chunks(BATCH_SIZE) {
            let batch_results = self.process_icp_batch(batch).await?;
            results.extend(batch_results);
        }

        Ok(results)
    }

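    /// Processes the jobs of a single batch; each job is an independent ICP alignment.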
    async fn process_icp_batch(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
        let mut results = Vec::with_capacity(jobs.len());

        for job in jobs {
            let result = self.optimized_icp_align(
                &job.source,
                &job.target,
                job.max_iterations,
                job.convergence_threshold,
                job.max_correspondence_distance,
            ).await?;
            results.push(result);
        }

        Ok(results)
    }

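    /// Core ICP loop: find GPU nearest-neighbor correspondences, estimate the rigid
    /// transform for them, and stop once the incremental motion falls below the threshold.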
    async fn optimized_icp_align(
        &self,
        source: &PointCloud<Point3f>,
        target: &PointCloud<Point3f>,
        max_iterations: usize,
        convergence_threshold: f32,
        max_correspondence_distance: f32,
    ) -> Result<BatchICPResult> {
        if source.is_empty() || target.is_empty() {
            return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
        }

        let mut current_transform = Isometry3::identity();
        let mut transformed_source = source.clone();
        let mut final_error = f32::INFINITY;
        let mut iterations_used = 0;

        for iteration in 0..max_iterations {
            iterations_used = iteration + 1;

            // Nearest-neighbor correspondences between the current source estimate and the target
            let correspondences = self.find_correspondences(
                &transformed_source.points,
                &target.points,
                max_correspondence_distance,
            ).await?;

            if correspondences.is_empty() {
                break;
            }

            // Incremental rigid transform that best aligns the current correspondences
            let (transform_delta, error) = self.compute_transformation_gpu(
                &transformed_source.points,
                &target.points,
                &correspondences
            ).await?;

            final_error = error;

            current_transform = transform_delta * current_transform;

            // Re-apply the accumulated transform to the original source cloud
            transformed_source = source.clone();
            for point in &mut transformed_source.points {
                *point = current_transform.transform_point(point);
            }

            // Converged once the incremental translation and rotation are both small
            let translation_norm = transform_delta.translation.vector.norm();
            let rotation_angle = transform_delta.rotation.angle();

            if translation_norm < convergence_threshold && rotation_angle < convergence_threshold {
                break;
            }
        }

        Ok(BatchICPResult {
            transformation: current_transform,
            final_error,
            iterations: iterations_used,
        })
    }

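    /// Estimates the rigid transform for the given correspondences: centroids are
    /// computed on the GPU, while the cross-covariance and its SVD run on the CPU.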
    async fn compute_transformation_gpu(
        &self,
        source_points: &[Point3f],
        target_points: &[Point3f],
        correspondences: &[(usize, usize, f32)],
    ) -> Result<(Isometry3<f32>, f32)> {
        if correspondences.is_empty() {
            return Ok((Isometry3::identity(), 0.0));
        }

        // Upload the source points in correspondence order so that the centroid shader's
        // `source_points[i]` matches correspondence i.
        let source_data: Vec<[f32; 4]> = correspondences
            .iter()
            .map(|(source_idx, _, _)| {
                let p = &source_points[*source_idx];
                [p.x, p.y, p.z, 0.0]
            })
            .collect();

        let target_data: Vec<[f32; 4]> = target_points
            .iter()
            .map(|p| [p.x, p.y, p.z, 0.0])
            .collect();

        let correspondence_indices: Vec<u32> = correspondences
            .iter()
            .map(|(_, target_idx, _)| *target_idx as u32)
            .collect();

        let source_buffer = self.create_buffer_init("Source Points", &source_data, wgpu::BufferUsages::STORAGE);
        let target_buffer = self.create_buffer_init("Target Points", &target_data, wgpu::BufferUsages::STORAGE);
        let correspondence_buffer = self.create_buffer_init("Correspondences", &correspondence_indices, wgpu::BufferUsages::STORAGE);

        let centroids = self.compute_centroids_gpu(
            &source_buffer,
            &target_buffer,
            &correspondence_buffer,
            correspondences.len(),
        ).await?;

        let covariance = self.compute_covariance_cpu(
            source_points,
            target_points,
            correspondences,
            &centroids,
        )?;

        let transformation = self.svd_to_transformation(&covariance, &centroids)?;

        // Mean correspondence distance as the alignment error
        let error = correspondences.iter().map(|(_, _, dist)| dist).sum::<f32>() / correspondences.len() as f32;

        Ok((transformation, error))
    }

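    /// Computes the source and target centroids of the correspondence set on the GPU
    /// and reads them back as `[x, y, z]` triples.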
    async fn compute_centroids_gpu(
        &self,
        source_buffer: &wgpu::Buffer,
        target_buffer: &wgpu::Buffer,
        correspondence_buffer: &wgpu::Buffer,
        num_correspondences: usize,
    ) -> Result<Vec<[f32; 3]>> {
        #[repr(C)]
        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
        struct CentroidParams {
            num_correspondences: u32,
        }

        let params = CentroidParams {
            num_correspondences: num_correspondences as u32,
        };

        // The shader writes two vec4<f32> values, so size the buffers for [f32; 4] entries.
        let centroids_size = (2 * std::mem::size_of::<[f32; 4]>()) as u64;
        let params_buffer = self.create_buffer_init("Centroid Params", &[params], wgpu::BufferUsages::UNIFORM);
        let centroids_buffer = self.create_buffer("Centroids", centroids_size, wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC);

        let shader = self.create_shader_module("Centroid Computation", ICP_CENTROID_SHADER);

        let bind_group_layout = self.create_bind_group_layout("Centroid Layout", &[
            wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
            wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
            wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
            wgpu::BindGroupLayoutEntry { binding: 3, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, min_binding_size: None }, count: None },
            wgpu::BindGroupLayoutEntry { binding: 4, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None }, count: None },
        ]);

        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Centroid Pipeline"),
            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("Centroid Layout"),
                bind_group_layouts: &[&bind_group_layout],
                push_constant_ranges: &[],
            })),
            module: &shader,
            entry_point: "main",
            compilation_options: wgpu::PipelineCompilationOptions::default(),
        });

        let bind_group = self.create_bind_group("Centroid Bind Group", &bind_group_layout, &[
            wgpu::BindGroupEntry { binding: 0, resource: source_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 1, resource: target_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 2, resource: correspondence_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 3, resource: centroids_buffer.as_entire_binding() },
            wgpu::BindGroupEntry { binding: 4, resource: params_buffer.as_entire_binding() },
        ]);

        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Centroid Computation") });
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Centroid Pass"), timestamp_writes: None });
            compute_pass.set_pipeline(&pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            compute_pass.dispatch_workgroups(1, 1, 1);
        }

        let staging_buffer = self.create_buffer("Centroid Staging", centroids_size, wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ);
        encoder.copy_buffer_to_buffer(&centroids_buffer, 0, &staging_buffer, 0, centroids_size);
        self.queue.submit(std::iter::once(encoder.finish()));

        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
        self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();

        if let Some(Ok(())) = receiver.receive().await {
            let data = buffer_slice.get_mapped_range();
            // Drop the padding w component of each vec4 centroid.
            let raw: Vec<[f32; 4]> = bytemuck::cast_slice(&data).to_vec();
            let centroids: Vec<[f32; 3]> = raw.iter().map(|c| [c[0], c[1], c[2]]).collect();
            drop(data);
            staging_buffer.unmap();
            Ok(centroids)
        } else {
            Err(threecrate_core::Error::Gpu("Failed to read centroid results".to_string()))
        }
    }

    /// Accumulates the 3x3 cross-covariance matrix of the centered correspondence pairs on the CPU.
    fn compute_covariance_cpu(
        &self,
        source_points: &[Point3f],
        target_points: &[Point3f],
        correspondences: &[(usize, usize, f32)],
        centroids: &[[f32; 3]],
    ) -> Result<nalgebra::Matrix3<f32>> {
        let mut h = nalgebra::Matrix3::zeros();
        let sc = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
        let tc = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
        for (si, ti, _dist) in correspondences.iter().copied() {
            let s = source_points[si];
            let t = target_points[ti];
            let sv = s.coords - sc;
            let tv = t.coords - tc;
            h += sv * tv.transpose();
        }
        Ok(h)
    }

    /// Recovers the rigid transform from the cross-covariance matrix via SVD (Kabsch algorithm).
    fn svd_to_transformation(&self, covariance: &nalgebra::Matrix3<f32>, centroids: &[[f32; 3]]) -> Result<Isometry3<f32>> {
        let source_centroid = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
        let target_centroid = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);

        let svd = covariance.svd(true, true);
        let u = svd.u.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
        let v_t = svd.v_t.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;

        let mut rotation = v_t.transpose() * u.transpose();

        // Guard against a reflection: if det(R) < 0, flip the sign of V's last column.
        if rotation.determinant() < 0.0 {
            let mut v_corrected = v_t.transpose();
            v_corrected.column_mut(2).scale_mut(-1.0);
            rotation = v_corrected * u.transpose();
        }

        let translation = target_centroid - rotation * source_centroid;

        Ok(Isometry3::from_parts(
            nalgebra::Translation3::from(translation),
            nalgebra::UnitQuaternion::from_matrix(&rotation),
        ))
    }

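    /// Runs ICP between `source` and `target` and returns only the final transformation
    /// that maps `source` onto `target`.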
    pub async fn icp_align(
        &self,
        source: &PointCloud<Point3f>,
        target: &PointCloud<Point3f>,
        max_iterations: usize,
        convergence_threshold: f32,
        max_correspondence_distance: f32,
    ) -> Result<Isometry3<f32>> {
        let result = self.optimized_icp_align(
            source,
            target,
            max_iterations,
            convergence_threshold,
            max_correspondence_distance
        ).await?;
        Ok(result.transformation)
    }

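    /// Finds the nearest target point for every source point on the GPU (brute force)
    /// and returns `(source_index, target_index, distance)` triples within `max_distance`.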
    async fn find_correspondences(
        &self,
        source_points: &[Point3f],
        target_points: &[Point3f],
        max_distance: f32,
    ) -> Result<Vec<(usize, usize, f32)>> {
        // The shader binds the points as array<vec4<f32>>, so pad each point to 16 bytes.
        let source_data: Vec<[f32; 4]> = source_points
            .iter()
            .map(|p| [p.x, p.y, p.z, 0.0])
            .collect();

        let target_data: Vec<[f32; 4]> = target_points
            .iter()
            .map(|p| [p.x, p.y, p.z, 0.0])
            .collect();

        let source_buffer = self.create_buffer_init(
            "Source Points",
            &source_data,
            wgpu::BufferUsages::STORAGE,
        );

        let target_buffer = self.create_buffer_init(
            "Target Points",
            &target_data,
            wgpu::BufferUsages::STORAGE,
        );

        let correspondences_buffer = self.create_buffer(
            "Correspondences",
            (source_data.len() * std::mem::size_of::<u32>()) as u64,
            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        );

        let distances_buffer = self.create_buffer(
            "Distances",
            (source_data.len() * std::mem::size_of::<f32>()) as u64,
            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        );

        #[repr(C)]
        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
        struct ICPParams {
            num_source: u32,
            num_target: u32,
            max_distance: f32,
        }

        let params = ICPParams {
            num_source: source_data.len() as u32,
            num_target: target_data.len() as u32,
            max_distance,
        };

        let params_buffer = self.create_buffer_init(
            "ICP Params",
            &[params],
            wgpu::BufferUsages::UNIFORM,
        );

        let shader = self.create_shader_module("ICP Nearest Neighbor", ICP_NEAREST_NEIGHBOR_SHADER);

        let bind_group_layout = self.create_bind_group_layout(
            "ICP Correspondence",
            &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 4,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        );

        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("ICP Correspondence"),
            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("ICP Pipeline Layout"),
                bind_group_layouts: &[&bind_group_layout],
                push_constant_ranges: &[],
            })),
            module: &shader,
            entry_point: "main",
            compilation_options: wgpu::PipelineCompilationOptions::default(),
        });

        let bind_group = self.create_bind_group(
            "ICP Correspondence",
            &bind_group_layout,
            &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: source_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: target_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: correspondences_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: distances_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        );

        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("ICP Correspondence"),
        });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("ICP Correspondence Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(&pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            let workgroup_count = (source_data.len() + 63) / 64;
            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
        }

        let correspondences_staging = self.create_buffer(
            "Correspondences Staging",
            (source_data.len() * std::mem::size_of::<u32>()) as u64,
            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        );

        let distances_staging = self.create_buffer(
            "Distances Staging",
            (source_data.len() * std::mem::size_of::<f32>()) as u64,
            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        );

        encoder.copy_buffer_to_buffer(
            &correspondences_buffer,
            0,
            &correspondences_staging,
            0,
            (source_data.len() * std::mem::size_of::<u32>()) as u64,
        );

        encoder.copy_buffer_to_buffer(
            &distances_buffer,
            0,
            &distances_staging,
            0,
            (source_data.len() * std::mem::size_of::<f32>()) as u64,
        );

        self.queue.submit(std::iter::once(encoder.finish()));

        let corr_slice = correspondences_staging.slice(..);
        let dist_slice = distances_staging.slice(..);

        let (corr_sender, corr_receiver) = futures_intrusive::channel::shared::oneshot_channel();
        let (dist_sender, dist_receiver) = futures_intrusive::channel::shared::oneshot_channel();

        corr_slice.map_async(wgpu::MapMode::Read, move |v| corr_sender.send(v).unwrap());
        dist_slice.map_async(wgpu::MapMode::Read, move |v| dist_sender.send(v).unwrap());

        self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();

        if let (Some(Ok(())), Some(Ok(()))) = (corr_receiver.receive().await, dist_receiver.receive().await) {
            let corr_data = corr_slice.get_mapped_range();
            let dist_data = dist_slice.get_mapped_range();

            let correspondences: Vec<u32> = bytemuck::cast_slice(&corr_data).to_vec();
            let distances: Vec<f32> = bytemuck::cast_slice(&dist_data).to_vec();

            // Keep only correspondences strictly closer than the cutoff; entries left at
            // max_distance mean no neighbor was found for that source point.
            let result: Vec<(usize, usize, f32)> = correspondences
                .into_iter()
                .zip(distances.into_iter())
                .enumerate()
                .filter(|(_, (_, distance))| *distance < max_distance)
                .map(|(i, (target_idx, distance))| (i, target_idx as usize, distance))
                .collect();

            drop(corr_data);
            drop(dist_data);
            correspondences_staging.unmap();
            distances_staging.unmap();

            Ok(result)
        } else {
            Err(threecrate_core::Error::Gpu("Failed to read GPU correspondence results".to_string()))
        }
    }
}

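/// GPU-accelerated ICP: aligns `source` to `target` and returns the rigid transform
/// that maps `source` onto `target`.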
pub async fn gpu_icp(
    gpu_context: &GpuContext,
    source: &PointCloud<Point3f>,
    target: &PointCloud<Point3f>,
    max_iterations: usize,
    convergence_threshold: f32,
    max_correspondence_distance: f32,
) -> Result<Isometry3<f32>> {
    gpu_context.icp_align(source, target, max_iterations, convergence_threshold, max_correspondence_distance).await
}

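/// GPU-accelerated batch ICP: runs one alignment per job and returns the results in order.
///
/// The sketch below shows the intended call shape. It assumes this module is exposed from a
/// `threecrate_gpu` crate and that a `GpuContext` plus two point clouds already exist
/// (their construction is not shown here), so it is illustrative rather than compile-tested.
///
/// ```ignore
/// use threecrate_gpu::{BatchICPJob, gpu_batch_icp};
///
/// let jobs = vec![BatchICPJob {
///     source: source_cloud,
///     target: target_cloud,
///     max_iterations: 50,
///     convergence_threshold: 1e-4,
///     max_correspondence_distance: 0.5,
/// }];
///
/// let results = gpu_batch_icp(&gpu_context, &jobs).await?;
/// for r in &results {
///     println!("error = {}, iterations = {}", r.final_error, r.iterations);
/// }
/// ```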
pub async fn gpu_batch_icp(
    gpu_context: &GpuContext,
    jobs: &[BatchICPJob],
) -> Result<Vec<BatchICPResult>> {
    gpu_context.batch_icp_align(jobs).await
}