// threecrate_gpu/icp.rs

1//! GPU-accelerated ICP
2
3use threecrate_core::{PointCloud, Result, Point3f};
4use crate::GpuContext;
5use nalgebra::Isometry3;
6
/// WGSL compute shader: brute-force nearest-neighbor search for ICP.
///
/// One invocation per source point; each scans every target point and writes
/// the index of the closest one to `correspondences` and its distance to
/// `distances`. Points are packed as `vec4<f32>` (xyz + one padding float)
/// for 16-byte storage alignment. `min_distance` starts at
/// `params.max_distance`, so a source point with no neighbor strictly inside
/// that radius reports index 0 and `distance == max_distance`; the host side
/// must filter results with a strict `<` comparison.
const ICP_NEAREST_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> distances: array<f32>;
@group(0) @binding(4) var<uniform> params: ICPParams;

struct ICPParams {
    num_source: u32,
    num_target: u32,
    max_distance: f32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_source) {
        return;
    }
    
    let source_point = source_points[index].xyz;
    var min_distance = params.max_distance;
    var best_match = 0u;
    
    // Find nearest neighbor in target
    for (var i = 0u; i < params.num_target; i++) {
        let target_point = target_points[i].xyz;
        let diff = source_point - target_point;
        let distance = length(diff);
        
        if (distance < min_distance) {
            min_distance = distance;
            best_match = i;
        }
    }
    
    correspondences[index] = best_match;
    distances[index] = min_distance;
}
"#;
47
/// WGSL compute shader: single-invocation computation of the source and
/// target centroids over the current correspondence set.
///
/// Runs as exactly one invocation (`workgroup_size(1)`, dispatch of 1) so no
/// atomics are needed; adequate for moderate correspondence counts.
/// NOTE(review): the loop pairs `source_points[i]` with `correspondences[i]`,
/// so the bound source buffer must contain the matched source points in
/// correspondence order — confirm that callers compact the buffer that way.
/// Output layout: centroids[0] = source centroid, centroids[1] = target
/// centroid, each a vec4 with an unused w component.
const ICP_CENTROID_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> centroids: array<vec4<f32>>; // [source_centroid, target_centroid]
@group(0) @binding(4) var<uniform> params: CentroidParams;

struct CentroidParams {
    num_correspondences: u32,
}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    if (global_id.x != 0u) {
        return;
    }
    
    var source_centroid = vec3<f32>(0.0);
    var target_centroid = vec3<f32>(0.0);
    
    for (var i = 0u; i < params.num_correspondences; i++) {
        let target_idx = correspondences[i];
        source_centroid += source_points[i].xyz;
        target_centroid += target_points[target_idx].xyz;
    }
    
    let scale = 1.0 / f32(params.num_correspondences);
    centroids[0] = vec4<f32>(source_centroid * scale, 0.0);
    centroids[1] = vec4<f32>(target_centroid * scale, 0.0);
}
"#;
79
/// WGSL compute shader: parallel accumulation of the 3x3 cross-covariance
/// matrix H = sum_i (s_i - c_s)(t_i - c_t)^T over the correspondence set.
///
/// WGSL has no native floating-point atomics, so each of the nine matrix
/// elements is stored as the raw bit pattern of an f32 inside an
/// `atomic<u32>` and accumulated with a compare-exchange loop. (The previous
/// version called `atomicAdd` on a plain `array<f32>` with `bitcast<i32>`
/// arguments, which is invalid WGSL and, even if it compiled, would add
/// integer bit patterns rather than floats.) The host must zero-initialize
/// the buffer and reinterpret the bits as f32 on readback.
/// Currently unused: the CPU path (`compute_covariance_cpu`) is preferred
/// for numerical stability.
#[allow(dead_code)]
const ICP_COVARIANCE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read> centroids: array<vec4<f32>>;
@group(0) @binding(4) var<storage, read_write> covariance: array<atomic<u32>>; // 9 elements for 3x3 matrix (f32 bits)
@group(0) @binding(5) var<uniform> params: CovarianceParams;

struct CovarianceParams {
    num_correspondences: u32,
}

// Atomic f32 add emulated via compare-exchange on the stored bit pattern.
fn atomic_add_f32(index: u32, value: f32) {
    var old_bits = atomicLoad(&covariance[index]);
    loop {
        let new_bits = bitcast<u32>(bitcast<f32>(old_bits) + value);
        let result = atomicCompareExchangeWeak(&covariance[index], old_bits, new_bits);
        if (result.exchanged) {
            break;
        }
        old_bits = result.old_value;
    }
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_correspondences) {
        return;
    }
    
    let source_centroid = centroids[0].xyz;
    let target_centroid = centroids[1].xyz;
    
    let target_idx = correspondences[index];
    let source_centered = source_points[index].xyz - source_centroid;
    let target_centered = target_points[target_idx].xyz - target_centroid;
    
    // Outer product contribution of this correspondence, row-major h[row][col].
    let h00 = source_centered.x * target_centered.x;
    let h01 = source_centered.x * target_centered.y;
    let h02 = source_centered.x * target_centered.z;
    let h10 = source_centered.y * target_centered.x;
    let h11 = source_centered.y * target_centered.y;
    let h12 = source_centered.y * target_centered.z;
    let h20 = source_centered.z * target_centered.x;
    let h21 = source_centered.z * target_centered.y;
    let h22 = source_centered.z * target_centered.z;
    
    atomic_add_f32(0u, h00);
    atomic_add_f32(1u, h01);
    atomic_add_f32(2u, h02);
    atomic_add_f32(3u, h10);
    atomic_add_f32(4u, h11);
    atomic_add_f32(5u, h12);
    atomic_add_f32(6u, h20);
    atomic_add_f32(7u, h21);
    atomic_add_f32(8u, h22);
}
"#;
130
/// Batch ICP operation for multiple point cloud pairs
///
/// One alignment job: `source` is iteratively registered onto `target`.
#[derive(Debug, Clone)]
pub struct BatchICPJob {
    /// Cloud that will be moved/aligned onto `target`.
    pub source: PointCloud<Point3f>,
    /// Fixed reference cloud.
    pub target: PointCloud<Point3f>,
    /// Upper bound on the number of ICP iterations.
    pub max_iterations: usize,
    /// Iteration stops early once both the incremental translation norm and
    /// the incremental rotation angle fall below this threshold.
    pub convergence_threshold: f32,
    /// Correspondences at or beyond this distance are rejected.
    pub max_correspondence_distance: f32,
}
140
/// Result of a batch ICP operation
#[derive(Debug, Clone)]
pub struct BatchICPResult {
    /// Accumulated rigid transform mapping the source cloud onto the target.
    pub transformation: Isometry3<f32>,
    /// Mean correspondence distance of the last iteration that produced
    /// correspondences (`f32::INFINITY` if none did).
    pub final_error: f32,
    /// Number of ICP iterations actually executed.
    pub iterations: usize,
}
148
149impl GpuContext {
150    /// Execute multiple ICP operations in parallel batches
151    pub async fn batch_icp_align(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
152        let mut results = Vec::with_capacity(jobs.len());
153        
154        // Process jobs in parallel batches to optimize GPU utilization
155        const BATCH_SIZE: usize = 4; // Adjust based on GPU memory
156        
157        for batch in jobs.chunks(BATCH_SIZE) {
158            let batch_results = self.process_icp_batch(batch).await?;
159            results.extend(batch_results);
160        }
161        
162        Ok(results)
163    }
164    
165    /// Process a batch of ICP jobs simultaneously
166    async fn process_icp_batch(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
167        let mut results = Vec::with_capacity(jobs.len());
168        
169        // For now, process sequentially with optimized GPU operations
170        // Future enhancement: true parallel execution with multiple command buffers
171        for job in jobs {
172            let result = self.optimized_icp_align(
173                &job.source,
174                &job.target,
175                job.max_iterations,
176                job.convergence_threshold,
177                job.max_correspondence_distance,
178            ).await?;
179            results.push(result);
180        }
181        
182        Ok(results)
183    }
184    
    /// Optimized ICP with GPU-accelerated SVD approximation
    ///
    /// Iteratively refines a rigid transform aligning `source` onto `target`:
    /// each pass finds nearest-neighbor correspondences on the GPU, estimates
    /// an incremental transform from them, composes it into the running
    /// transform, and re-transforms the source cloud. Stops early when both
    /// the incremental translation norm and rotation angle fall below
    /// `convergence_threshold`, or when no correspondences remain.
    ///
    /// Returns the accumulated transform, the mean correspondence distance of
    /// the last productive iteration (`f32::INFINITY` if none), and the
    /// number of iterations run. Errors on empty input clouds.
    async fn optimized_icp_align(
        &self,
        source: &PointCloud<Point3f>,
        target: &PointCloud<Point3f>,
        max_iterations: usize,
        convergence_threshold: f32,
        max_correspondence_distance: f32,
    ) -> Result<BatchICPResult> {
        if source.is_empty() || target.is_empty() {
            return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
        }

        let mut current_transform = Isometry3::identity();
        let mut transformed_source = source.clone();
        let mut final_error = f32::INFINITY;
        let mut iterations_used = 0;
        
        for iteration in 0..max_iterations {
            iterations_used = iteration + 1;
            
            // Find correspondences using GPU
            let correspondences = self.find_correspondences(
                &transformed_source.points,
                &target.points,
                max_correspondence_distance,
            ).await?;
            
            // No point pair within range: nothing left to estimate from.
            if correspondences.is_empty() {
                break;
            }
            
            // Compute transformation using GPU-accelerated methods
            let (transform_delta, error) = self.compute_transformation_gpu(
                &transformed_source.points, 
                &target.points, 
                &correspondences
            ).await?;
            
            final_error = error;
            
            // Compose the incremental delta on top of the running transform
            current_transform = transform_delta * current_transform;
            
            // Re-transform from the ORIGINAL cloud each iteration so
            // floating-point error does not accumulate across passes
            transformed_source = source.clone();
            for point in &mut transformed_source.points {
                *point = current_transform.transform_point(point);
            }
            
            // Check convergence: the DELTA (not the total transform) must be
            // small in both translation and rotation
            let translation_norm = transform_delta.translation.vector.norm();
            let rotation_angle = transform_delta.rotation.angle();
            
            if translation_norm < convergence_threshold && rotation_angle < convergence_threshold {
                break;
            }
        }
        
        Ok(BatchICPResult {
            transformation: current_transform,
            final_error,
            iterations: iterations_used,
        })
    }
250    
251    /// GPU-accelerated transformation computation with centroid and covariance calculation
252    async fn compute_transformation_gpu(
253        &self,
254        source_points: &[Point3f],
255        target_points: &[Point3f],
256        correspondences: &[(usize, usize, f32)],
257    ) -> Result<(Isometry3<f32>, f32)> {
258        if correspondences.is_empty() {
259            return Ok((Isometry3::identity(), 0.0));
260        }
261        
262        // Convert to GPU format (vec4 alignment)
263        let source_data: Vec<[f32; 4]> = source_points
264            .iter()
265            .map(|p| [p.x, p.y, p.z, 0.0])
266            .collect();
267            
268        let target_data: Vec<[f32; 4]> = target_points
269            .iter()
270            .map(|p| [p.x, p.y, p.z, 0.0])
271            .collect();
272            
273        let correspondence_indices: Vec<u32> = correspondences
274            .iter()
275            .map(|(_, target_idx, _)| *target_idx as u32)
276            .collect();
277        
278        // Create GPU buffers
279        let source_buffer = self.create_buffer_init("Source Points", &source_data, wgpu::BufferUsages::STORAGE);
280        let target_buffer = self.create_buffer_init("Target Points", &target_data, wgpu::BufferUsages::STORAGE);
281        let correspondence_buffer = self.create_buffer_init("Correspondences", &correspondence_indices, wgpu::BufferUsages::STORAGE);
282        
283        // Step 1: Compute centroids on GPU
284        let centroids = self.compute_centroids_gpu(
285            &source_buffer,
286            &target_buffer,
287            &correspondence_buffer,
288            correspondences.len(),
289        ).await?;
290        
291        // Step 2: Compute covariance matrix on CPU for stability
292        let covariance = self.compute_covariance_cpu(
293            source_points,
294            target_points,
295            correspondences,
296            &centroids,
297        )?;
298        
299        // Step 3: Perform SVD on CPU (could be moved to GPU with more complex implementation)
300        let transformation = self.svd_to_transformation(&covariance, &centroids)?;
301        
302        // Compute final error
303        let error = correspondences.iter().map(|(_, _, dist)| dist).sum::<f32>() / correspondences.len() as f32;
304        
305        Ok((transformation, error))
306    }
307    
308    /// Compute centroids using GPU
309    async fn compute_centroids_gpu(
310        &self,
311        source_buffer: &wgpu::Buffer,
312        target_buffer: &wgpu::Buffer,
313        correspondence_buffer: &wgpu::Buffer,
314        num_correspondences: usize,
315    ) -> Result<Vec<[f32; 3]>> {
316        #[repr(C)]
317        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
318        struct CentroidParams {
319            num_correspondences: u32,
320        }
321
322        let params = CentroidParams {
323            num_correspondences: num_correspondences as u32,
324        };
325
326        let params_buffer = self.create_buffer_init("Centroid Params", &[params], wgpu::BufferUsages::UNIFORM);
327        let centroids_buffer = self.create_buffer("Centroids", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC);
328
329        let shader = self.create_shader_module("Centroid Computation", ICP_CENTROID_SHADER);
330        
331        let bind_group_layout = self.create_bind_group_layout("Centroid Layout", &[
332            wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
333            wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
334            wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
335            wgpu::BindGroupLayoutEntry { binding: 3, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, min_binding_size: None }, count: None },
336            wgpu::BindGroupLayoutEntry { binding: 4, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None }, count: None },
337        ]);
338
339        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
340            label: Some("Centroid Pipeline"),
341            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
342                label: Some("Centroid Layout"),
343                bind_group_layouts: &[&bind_group_layout],
344                push_constant_ranges: &[],
345            })),
346            module: &shader,
347            entry_point: Some("main"),
348            compilation_options: wgpu::PipelineCompilationOptions::default(),
349            cache: None,
350        });
351
352        let bind_group = self.create_bind_group("Centroid Bind Group", &bind_group_layout, &[
353            wgpu::BindGroupEntry { binding: 0, resource: source_buffer.as_entire_binding() },
354            wgpu::BindGroupEntry { binding: 1, resource: target_buffer.as_entire_binding() },
355            wgpu::BindGroupEntry { binding: 2, resource: correspondence_buffer.as_entire_binding() },
356            wgpu::BindGroupEntry { binding: 3, resource: centroids_buffer.as_entire_binding() },
357            wgpu::BindGroupEntry { binding: 4, resource: params_buffer.as_entire_binding() },
358        ]);
359
360        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Centroid Computation") });
361        {
362            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Centroid Pass"), timestamp_writes: None });
363            compute_pass.set_pipeline(&pipeline);
364            compute_pass.set_bind_group(0, &bind_group, &[]);
365            compute_pass.dispatch_workgroups(1, 1, 1);
366        }
367
368        let staging_buffer = self.create_buffer("Centroid Staging", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ);
369        encoder.copy_buffer_to_buffer(&centroids_buffer, 0, &staging_buffer, 0, 2 * std::mem::size_of::<[f32; 3]>() as u64);
370        self.queue.submit(std::iter::once(encoder.finish()));
371
372        let buffer_slice = staging_buffer.slice(..);
373        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
374        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
375        let _ = self.device.poll(wgpu::PollType::Wait);
376
377        if let Some(Ok(())) = receiver.receive().await {
378            let data = buffer_slice.get_mapped_range();
379            let centroids: Vec<[f32; 3]> = bytemuck::cast_slice(&data).to_vec();
380            drop(data);
381            staging_buffer.unmap();
382            Ok(centroids)
383        } else {
384            Err(threecrate_core::Error::Gpu("Failed to read centroid results".to_string()))
385        }
386    }
387    
388    /// Compute covariance matrix using GPU (simplified version without atomics for now)
389    // CPU covariance as stable fallback
390    fn compute_covariance_cpu(
391        &self,
392        source_points: &[Point3f],
393        target_points: &[Point3f],
394        correspondences: &[(usize, usize, f32)],
395        centroids: &[[f32; 3]],
396    ) -> Result<nalgebra::Matrix3<f32>> {
397        let mut h = nalgebra::Matrix3::zeros();
398        let sc = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
399        let tc = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
400        for (si, ti, _dist) in correspondences.iter().copied() {
401            let s = source_points[si];
402            let t = target_points[ti];
403            let sv = s.coords - sc;
404            let tv = t.coords - tc;
405            h += sv * tv.transpose();
406        }
407        Ok(h)
408    }
409    
    /// Convert covariance matrix to transformation using SVD
    ///
    /// Closed-form Kabsch step: given the cross-covariance
    /// H = sum_i (s_i - c_s)(t_i - c_t)^T with SVD H = U S V^T, the optimal
    /// rotation is R = V U^T. When det(R) < 0 the result is a reflection,
    /// repaired by negating the last column of V. The translation then maps
    /// the rotated source centroid onto the target centroid.
    ///
    /// `centroids[0]` is the source centroid, `centroids[1]` the target
    /// centroid; panics if fewer than two entries are supplied.
    fn svd_to_transformation(&self, covariance: &nalgebra::Matrix3<f32>, centroids: &[[f32; 3]]) -> Result<Isometry3<f32>> {
        let source_centroid = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
        let target_centroid = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
        
        // Compute SVD; u/v_t are Options because nalgebra's SVD can omit them
        let svd = covariance.svd(true, true);
        let u = svd.u.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
        let v_t = svd.v_t.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
        
        // R = V * U^T
        let mut rotation = v_t.transpose() * u.transpose();
        
        // Ensure proper rotation (det = 1): flip the sign of V's last column
        // to turn a reflection into the nearest proper rotation
        if rotation.determinant() < 0.0 {
            let mut v_corrected = v_t.transpose();
            v_corrected.column_mut(2).scale_mut(-1.0);
            rotation = v_corrected * u.transpose();
        }
        
        // t = c_t - R * c_s aligns the centroids under the recovered rotation
        let translation = target_centroid - rotation * source_centroid;
        
        Ok(Isometry3::from_parts(
            nalgebra::Translation3::from(translation),
            // from_matrix extracts the nearest unit quaternion from R
            nalgebra::UnitQuaternion::from_matrix(&rotation),
        ))
    }
436
437    /// GPU-accelerated ICP alignment (original single implementation)
438    pub async fn icp_align(
439        &self,
440        source: &PointCloud<Point3f>,
441        target: &PointCloud<Point3f>,
442        max_iterations: usize,
443        convergence_threshold: f32,
444        max_correspondence_distance: f32,
445    ) -> Result<Isometry3<f32>> {
446        let result = self.optimized_icp_align(
447            source, 
448            target, 
449            max_iterations, 
450            convergence_threshold, 
451            max_correspondence_distance
452        ).await?;
453        Ok(result.transformation)
454    }
455    
456    /// Find nearest neighbor correspondences using GPU
457    async fn find_correspondences(
458        &self,
459        source_points: &[Point3f],
460        target_points: &[Point3f],
461        max_distance: f32,
462    ) -> Result<Vec<(usize, usize, f32)>> {
463        // Convert points to GPU format
464        let source_data: Vec<[f32; 3]> = source_points
465            .iter()
466            .map(|p| [p.x, p.y, p.z])
467            .collect();
468            
469        let target_data: Vec<[f32; 3]> = target_points
470            .iter()
471            .map(|p| [p.x, p.y, p.z])
472            .collect();
473
474        // Create buffers
475        let source_buffer = self.create_buffer_init(
476            "Source Points",
477            &source_data,
478            wgpu::BufferUsages::STORAGE,
479        );
480
481        let target_buffer = self.create_buffer_init(
482            "Target Points", 
483            &target_data,
484            wgpu::BufferUsages::STORAGE,
485        );
486
487        let correspondences_buffer = self.create_buffer(
488            "Correspondences",
489            (source_data.len() * std::mem::size_of::<u32>()) as u64,
490            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
491        );
492
493        let distances_buffer = self.create_buffer(
494            "Distances",
495            (source_data.len() * std::mem::size_of::<f32>()) as u64,
496            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
497        );
498
499        #[repr(C)]
500        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
501        struct ICPParams {
502            num_source: u32,
503            num_target: u32,
504            max_distance: f32,
505        }
506
507        let params = ICPParams {
508            num_source: source_data.len() as u32,
509            num_target: target_data.len() as u32,
510            max_distance,
511        };
512
513        let params_buffer = self.create_buffer_init(
514            "ICP Params",
515            &[params],
516            wgpu::BufferUsages::UNIFORM,
517        );
518
519        // Create shader and pipeline
520        let shader = self.create_shader_module("ICP Nearest Neighbor", ICP_NEAREST_NEIGHBOR_SHADER);
521        
522        let bind_group_layout = self.create_bind_group_layout(
523            "ICP Correspondence",
524            &[
525                wgpu::BindGroupLayoutEntry {
526                    binding: 0,
527                    visibility: wgpu::ShaderStages::COMPUTE,
528                    ty: wgpu::BindingType::Buffer {
529                        ty: wgpu::BufferBindingType::Storage { read_only: true },
530                        has_dynamic_offset: false,
531                        min_binding_size: None,
532                    },
533                    count: None,
534                },
535                wgpu::BindGroupLayoutEntry {
536                    binding: 1,
537                    visibility: wgpu::ShaderStages::COMPUTE,
538                    ty: wgpu::BindingType::Buffer {
539                        ty: wgpu::BufferBindingType::Storage { read_only: true },
540                        has_dynamic_offset: false,
541                        min_binding_size: None,
542                    },
543                    count: None,
544                },
545                wgpu::BindGroupLayoutEntry {
546                    binding: 2,
547                    visibility: wgpu::ShaderStages::COMPUTE,
548                    ty: wgpu::BindingType::Buffer {
549                        ty: wgpu::BufferBindingType::Storage { read_only: false },
550                        has_dynamic_offset: false,
551                        min_binding_size: None,
552                    },
553                    count: None,
554                },
555                wgpu::BindGroupLayoutEntry {
556                    binding: 3,
557                    visibility: wgpu::ShaderStages::COMPUTE,
558                    ty: wgpu::BindingType::Buffer {
559                        ty: wgpu::BufferBindingType::Storage { read_only: false },
560                        has_dynamic_offset: false,
561                        min_binding_size: None,
562                    },
563                    count: None,
564                },
565                wgpu::BindGroupLayoutEntry {
566                    binding: 4,
567                    visibility: wgpu::ShaderStages::COMPUTE,
568                    ty: wgpu::BindingType::Buffer {
569                        ty: wgpu::BufferBindingType::Uniform,
570                        has_dynamic_offset: false,
571                        min_binding_size: None,
572                    },
573                    count: None,
574                },
575            ],
576        );
577
578        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
579            label: Some("ICP Correspondence"),
580            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
581                label: Some("ICP Pipeline Layout"),
582                bind_group_layouts: &[&bind_group_layout],
583                push_constant_ranges: &[],
584            })),
585            module: &shader,
586            entry_point: Some("main"),
587            compilation_options: wgpu::PipelineCompilationOptions::default(),
588            cache: None,
589        });
590
591        let bind_group = self.create_bind_group(
592            "ICP Correspondence",
593            &bind_group_layout,
594            &[
595                wgpu::BindGroupEntry {
596                    binding: 0,
597                    resource: source_buffer.as_entire_binding(),
598                },
599                wgpu::BindGroupEntry {
600                    binding: 1,
601                    resource: target_buffer.as_entire_binding(),
602                },
603                wgpu::BindGroupEntry {
604                    binding: 2,
605                    resource: correspondences_buffer.as_entire_binding(),
606                },
607                wgpu::BindGroupEntry {
608                    binding: 3,
609                    resource: distances_buffer.as_entire_binding(),
610                },
611                wgpu::BindGroupEntry {
612                    binding: 4,
613                    resource: params_buffer.as_entire_binding(),
614                },
615            ],
616        );
617
618        // Execute compute shader
619        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
620            label: Some("ICP Correspondence"),
621        });
622
623        {
624            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
625                label: Some("ICP Correspondence Pass"),
626                timestamp_writes: None,
627            });
628            compute_pass.set_pipeline(&pipeline);
629            compute_pass.set_bind_group(0, &bind_group, &[]);
630            let workgroup_count = (source_data.len() + 63) / 64;
631            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
632        }
633
634        // Read back results
635        let correspondences_staging = self.create_buffer(
636            "Correspondences Staging",
637            (source_data.len() * std::mem::size_of::<u32>()) as u64,
638            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
639        );
640
641        let distances_staging = self.create_buffer(
642            "Distances Staging",
643            (source_data.len() * std::mem::size_of::<f32>()) as u64,
644            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
645        );
646
647        encoder.copy_buffer_to_buffer(
648            &correspondences_buffer,
649            0,
650            &correspondences_staging,
651            0,
652            (source_data.len() * std::mem::size_of::<u32>()) as u64,
653        );
654
655        encoder.copy_buffer_to_buffer(
656            &distances_buffer,
657            0,
658            &distances_staging,
659            0,
660            (source_data.len() * std::mem::size_of::<f32>()) as u64,
661        );
662
663        self.queue.submit(std::iter::once(encoder.finish()));
664
665        // Map and read correspondences
666        let corr_slice = correspondences_staging.slice(..);
667        let dist_slice = distances_staging.slice(..);
668        
669        let (corr_sender, corr_receiver) = futures_intrusive::channel::shared::oneshot_channel();
670        let (dist_sender, dist_receiver) = futures_intrusive::channel::shared::oneshot_channel();
671        
672        corr_slice.map_async(wgpu::MapMode::Read, move |v| corr_sender.send(v).unwrap());
673        dist_slice.map_async(wgpu::MapMode::Read, move |v| dist_sender.send(v).unwrap());
674
675        let _ = self.device.poll(wgpu::PollType::Wait);
676
677        if let (Some(Ok(())), Some(Ok(()))) = (corr_receiver.receive().await, dist_receiver.receive().await) {
678            let corr_data = corr_slice.get_mapped_range();
679            let dist_data = dist_slice.get_mapped_range();
680            
681            let correspondences: Vec<u32> = bytemuck::cast_slice(&corr_data).to_vec();
682            let distances: Vec<f32> = bytemuck::cast_slice(&dist_data).to_vec();
683            
684            let result: Vec<(usize, usize, f32)> = correspondences
685                .into_iter()
686                .zip(distances.into_iter())
687                .enumerate()
688                .filter(|(_, (_, distance))| *distance < max_distance)
689                .map(|(i, (target_idx, distance))| (i, target_idx as usize, distance))
690                .collect();
691            
692            drop(corr_data);
693            drop(dist_data);
694            correspondences_staging.unmap();
695            distances_staging.unmap();
696            
697            Ok(result)
698        } else {
699            Err(threecrate_core::Error::Gpu("Failed to read GPU correspondence results".to_string()))
700        }
701    }
702}
703
704/// GPU-accelerated ICP registration
705pub async fn gpu_icp(
706    gpu_context: &GpuContext,
707    source: &PointCloud<Point3f>,
708    target: &PointCloud<Point3f>,
709    max_iterations: usize,
710    convergence_threshold: f32,
711    max_correspondence_distance: f32,
712) -> Result<Isometry3<f32>> {
713    gpu_context.icp_align(source, target, max_iterations, convergence_threshold, max_correspondence_distance).await
714}
715
716/// Execute batch ICP operations on multiple point cloud pairs
717pub async fn gpu_batch_icp(
718    gpu_context: &GpuContext,
719    jobs: &[BatchICPJob],
720) -> Result<Vec<BatchICPResult>> {
721    gpu_context.batch_icp_align(jobs).await
722}