Skip to main content

threecrate_gpu/
icp.rs

1//! GPU-accelerated ICP
2
3use threecrate_core::{PointCloud, Result, Point3f, Vector3f};
4use crate::GpuContext;
5use nalgebra::{Isometry3, Matrix6, Vector6, UnitQuaternion, Translation3};
6
/// WGSL brute-force nearest-neighbour search, one invocation per source point.
///
/// Each invocation scans every target point and keeps the closest one by
/// Euclidean `length`. It writes that index to `correspondences[index]` and the
/// distance to `distances[index]`. Only targets strictly closer than
/// `params.max_distance` are accepted. When none qualifies, the outputs remain
/// `0u` / `max_distance`, so the host must filter on the distance.
///
/// Both point arrays are declared `array<vec4<f32>>`, i.e. a 16-byte stride.
/// The host must therefore upload 4-float padded points.
/// Cost is O(num_source * num_target).
const ICP_NEAREST_NEIGHBOR_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read_write> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> distances: array<f32>;
@group(0) @binding(4) var<uniform> params: ICPParams;

struct ICPParams {
    num_source: u32,
    num_target: u32,
    max_distance: f32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_source) {
        return;
    }
    
    let source_point = source_points[index].xyz;
    var min_distance = params.max_distance;
    var best_match = 0u;
    
    // Find nearest neighbor in target
    for (var i = 0u; i < params.num_target; i++) {
        let target_point = target_points[i].xyz;
        let diff = source_point - target_point;
        let distance = length(diff);
        
        if (distance < min_distance) {
            min_distance = distance;
            best_match = i;
        }
    }
    
    correspondences[index] = best_match;
    distances[index] = min_distance;
}
"#;
47
/// WGSL centroid computation for matched point pairs. Only invocation 0 does
/// any work, serially over all pairs.
///
/// The shader averages `source_points[i]` and `target_points[correspondences[i]]`
/// for `i in 0..num_correspondences`. It writes the source centroid to
/// `centroids[0]` and the target centroid to `centroids[1]`, each as a
/// `vec4<f32>` with `w = 0`.
///
/// Host-side consequences:
/// - `source_points[i]` is paired positionally with correspondence `i`. The
///   source array must therefore be ordered to match the correspondence list,
///   not the raw source cloud (unless every source point is matched).
/// - Points are read with a 16-byte (`vec4<f32>`) stride.
/// - The output occupies two `vec4<f32>`, i.e. 32 bytes.
/// - `num_correspondences` must be non-zero; otherwise the scale is `1.0 / 0.0`.
const ICP_CENTROID_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read_write> centroids: array<vec4<f32>>; // [source_centroid, target_centroid]
@group(0) @binding(4) var<uniform> params: CentroidParams;

struct CentroidParams {
    num_correspondences: u32,
}

@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    if (global_id.x != 0u) {
        return;
    }
    
    var source_centroid = vec3<f32>(0.0);
    var target_centroid = vec3<f32>(0.0);
    
    for (var i = 0u; i < params.num_correspondences; i++) {
        let target_idx = correspondences[i];
        source_centroid += source_points[i].xyz;
        target_centroid += target_points[target_idx].xyz;
    }
    
    let scale = 1.0 / f32(params.num_correspondences);
    centroids[0] = vec4<f32>(source_centroid * scale, 0.0);
    centroids[1] = vec4<f32>(target_centroid * scale, 0.0);
}
"#;
79
/// WGSL per-correspondence cross-covariance terms.
///
/// This shader is not wired up yet: the CPU covariance path is used instead,
/// hence the `dead_code` allowance.
///
/// Invocation `i` writes the row-major outer product
/// `(s_i - c_s)(t_i - c_t)^T` into `covariance[9 * i .. 9 * i + 9]`. The host
/// then sums the 9-element slots. Each invocation owns its own slot, so no
/// atomics are needed. The previous version called `atomicAdd` on a plain
/// `array<f32>` and added bit-cast floats as integers, which is not valid
/// WGSL: `atomicAdd` is only defined for `atomic<i32>` / `atomic<u32>`, and it
/// would not sum floats anyway.
///
/// Binding requirements:
/// - `source_points[i]` is the source point of correspondence `i`, so the
///   source array must be compacted into correspondence order.
/// - `covariance` must hold at least `9 * num_correspondences` floats.
#[allow(dead_code)]
const ICP_COVARIANCE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> source_points: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> target_points: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> correspondences: array<u32>;
@group(0) @binding(3) var<storage, read> centroids: array<vec4<f32>>;
@group(0) @binding(4) var<storage, read_write> covariance: array<f32>; // 9 * num_correspondences elements
@group(0) @binding(5) var<uniform> params: CovarianceParams;

struct CovarianceParams {
    num_correspondences: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let index = global_id.x;
    if (index >= params.num_correspondences) {
        return;
    }

    let source_centroid = centroids[0].xyz;
    let target_centroid = centroids[1].xyz;

    let target_idx = correspondences[index];
    let s = source_points[index].xyz - source_centroid;
    let t = target_points[target_idx].xyz - target_centroid;

    // Row-major outer product s * t^T into this correspondence's own slot;
    // no two invocations write the same element.
    let base = index * 9u;
    covariance[base + 0u] = s.x * t.x;
    covariance[base + 1u] = s.x * t.y;
    covariance[base + 2u] = s.x * t.z;
    covariance[base + 3u] = s.y * t.x;
    covariance[base + 4u] = s.y * t.y;
    covariance[base + 5u] = s.y * t.z;
    covariance[base + 6u] = s.z * t.x;
    covariance[base + 7u] = s.z * t.y;
    covariance[base + 8u] = s.z * t.z;
}
"#;
130
/// One source/target pair to align with point-to-point ICP, plus its parameters.
#[derive(Debug, Clone)]
pub struct BatchICPJob {
    /// Cloud that is moved; the resulting transform is applied to its points.
    pub source: PointCloud<Point3f>,
    /// Reference cloud, held fixed.
    pub target: PointCloud<Point3f>,
    /// Upper bound on ICP iterations.
    pub max_iterations: usize,
    /// Iteration stops once an iteration's incremental translation norm and
    /// rotation angle are both below this value. The same value is used for
    /// both, so it mixes length units with the angle returned by `angle()`.
    pub convergence_threshold: f32,
    /// Matches whose distance is not strictly below this value are discarded.
    pub max_correspondence_distance: f32,
}
140
/// Outcome of one [`BatchICPJob`].
#[derive(Debug, Clone)]
pub struct BatchICPResult {
    /// Accumulated rigid transform applied to the source cloud.
    pub transformation: Isometry3<f32>,
    /// Mean correspondence distance from the last iteration that found
    /// correspondences. It is measured before that iteration's update was
    /// applied. The value is `f32::INFINITY` if no iteration found any.
    pub final_error: f32,
    /// Number of iterations started (at most `max_iterations`).
    pub iterations: usize,
}
148
149impl GpuContext {
150    /// Execute multiple ICP operations in parallel batches
151    pub async fn batch_icp_align(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
152        let mut results = Vec::with_capacity(jobs.len());
153        
154        // Process jobs in parallel batches to optimize GPU utilization
155        const BATCH_SIZE: usize = 4; // Adjust based on GPU memory
156        
157        for batch in jobs.chunks(BATCH_SIZE) {
158            let batch_results = self.process_icp_batch(batch).await?;
159            results.extend(batch_results);
160        }
161        
162        Ok(results)
163    }
164    
165    /// Process a batch of ICP jobs simultaneously
166    async fn process_icp_batch(&self, jobs: &[BatchICPJob]) -> Result<Vec<BatchICPResult>> {
167        let mut results = Vec::with_capacity(jobs.len());
168        
169        // For now, process sequentially with optimized GPU operations
170        // Future enhancement: true parallel execution with multiple command buffers
171        for job in jobs {
172            let result = self.optimized_icp_align(
173                &job.source,
174                &job.target,
175                job.max_iterations,
176                job.convergence_threshold,
177                job.max_correspondence_distance,
178            ).await?;
179            results.push(result);
180        }
181        
182        Ok(results)
183    }
184    
185    /// Optimized ICP with GPU-accelerated SVD approximation
186    async fn optimized_icp_align(
187        &self,
188        source: &PointCloud<Point3f>,
189        target: &PointCloud<Point3f>,
190        max_iterations: usize,
191        convergence_threshold: f32,
192        max_correspondence_distance: f32,
193    ) -> Result<BatchICPResult> {
194        if source.is_empty() || target.is_empty() {
195            return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
196        }
197
198        let mut current_transform = Isometry3::identity();
199        let mut transformed_source = source.clone();
200        let mut final_error = f32::INFINITY;
201        let mut iterations_used = 0;
202        
203        for iteration in 0..max_iterations {
204            iterations_used = iteration + 1;
205            
206            // Find correspondences using GPU
207            let correspondences = self.find_correspondences(
208                &transformed_source.points,
209                &target.points,
210                max_correspondence_distance,
211            ).await?;
212            
213            if correspondences.is_empty() {
214                break;
215            }
216            
217            // Compute transformation using GPU-accelerated methods
218            let (transform_delta, error) = self.compute_transformation_gpu(
219                &transformed_source.points, 
220                &target.points, 
221                &correspondences
222            ).await?;
223            
224            final_error = error;
225            
226            // Update current transform
227            current_transform = transform_delta * current_transform;
228            
229            // Transform source points
230            transformed_source = source.clone();
231            for point in &mut transformed_source.points {
232                *point = current_transform.transform_point(point);
233            }
234            
235            // Check convergence
236            let translation_norm = transform_delta.translation.vector.norm();
237            let rotation_angle = transform_delta.rotation.angle();
238            
239            if translation_norm < convergence_threshold && rotation_angle < convergence_threshold {
240                break;
241            }
242        }
243        
244        Ok(BatchICPResult {
245            transformation: current_transform,
246            final_error,
247            iterations: iterations_used,
248        })
249    }
250    
251    /// GPU-accelerated transformation computation with centroid and covariance calculation
252    async fn compute_transformation_gpu(
253        &self,
254        source_points: &[Point3f],
255        target_points: &[Point3f],
256        correspondences: &[(usize, usize, f32)],
257    ) -> Result<(Isometry3<f32>, f32)> {
258        if correspondences.is_empty() {
259            return Ok((Isometry3::identity(), 0.0));
260        }
261        
262        // Convert to GPU format (vec4 alignment)
263        let source_data: Vec<[f32; 4]> = source_points
264            .iter()
265            .map(|p| [p.x, p.y, p.z, 0.0])
266            .collect();
267            
268        let target_data: Vec<[f32; 4]> = target_points
269            .iter()
270            .map(|p| [p.x, p.y, p.z, 0.0])
271            .collect();
272            
273        let correspondence_indices: Vec<u32> = correspondences
274            .iter()
275            .map(|(_, target_idx, _)| *target_idx as u32)
276            .collect();
277        
278        // Create GPU buffers
279        let source_buffer = self.create_buffer_init("Source Points", &source_data, wgpu::BufferUsages::STORAGE);
280        let target_buffer = self.create_buffer_init("Target Points", &target_data, wgpu::BufferUsages::STORAGE);
281        let correspondence_buffer = self.create_buffer_init("Correspondences", &correspondence_indices, wgpu::BufferUsages::STORAGE);
282        
283        // Step 1: Compute centroids on GPU
284        let centroids = self.compute_centroids_gpu(
285            &source_buffer,
286            &target_buffer,
287            &correspondence_buffer,
288            correspondences.len(),
289        ).await?;
290        
291        // Step 2: Compute covariance matrix on CPU for stability
292        let covariance = self.compute_covariance_cpu(
293            source_points,
294            target_points,
295            correspondences,
296            &centroids,
297        )?;
298        
299        // Step 3: Perform SVD on CPU (could be moved to GPU with more complex implementation)
300        let transformation = self.svd_to_transformation(&covariance, &centroids)?;
301        
302        // Compute final error
303        let error = correspondences.iter().map(|(_, _, dist)| dist).sum::<f32>() / correspondences.len() as f32;
304        
305        Ok((transformation, error))
306    }
307    
308    /// Compute centroids using GPU
309    async fn compute_centroids_gpu(
310        &self,
311        source_buffer: &wgpu::Buffer,
312        target_buffer: &wgpu::Buffer,
313        correspondence_buffer: &wgpu::Buffer,
314        num_correspondences: usize,
315    ) -> Result<Vec<[f32; 3]>> {
316        #[repr(C)]
317        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
318        struct CentroidParams {
319            num_correspondences: u32,
320        }
321
322        let params = CentroidParams {
323            num_correspondences: num_correspondences as u32,
324        };
325
326        let params_buffer = self.create_buffer_init("Centroid Params", &[params], wgpu::BufferUsages::UNIFORM);
327        let centroids_buffer = self.create_buffer("Centroids", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC);
328
329        let shader = self.create_shader_module("Centroid Computation", ICP_CENTROID_SHADER);
330        
331        let bind_group_layout = self.create_bind_group_layout("Centroid Layout", &[
332            wgpu::BindGroupLayoutEntry { binding: 0, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
333            wgpu::BindGroupLayoutEntry { binding: 1, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
334            wgpu::BindGroupLayoutEntry { binding: 2, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: true }, has_dynamic_offset: false, min_binding_size: None }, count: None },
335            wgpu::BindGroupLayoutEntry { binding: 3, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Storage { read_only: false }, has_dynamic_offset: false, min_binding_size: None }, count: None },
336            wgpu::BindGroupLayoutEntry { binding: 4, visibility: wgpu::ShaderStages::COMPUTE, ty: wgpu::BindingType::Buffer { ty: wgpu::BufferBindingType::Uniform, has_dynamic_offset: false, min_binding_size: None }, count: None },
337        ]);
338
339        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
340            label: Some("Centroid Pipeline"),
341            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
342                label: Some("Centroid Layout"),
343                bind_group_layouts: &[&bind_group_layout],
344                immediate_size: 0,
345            })),
346            module: &shader,
347            entry_point: Some("main"),
348            compilation_options: wgpu::PipelineCompilationOptions::default(),
349            cache: None,
350        });
351
352        let bind_group = self.create_bind_group("Centroid Bind Group", &bind_group_layout, &[
353            wgpu::BindGroupEntry { binding: 0, resource: source_buffer.as_entire_binding() },
354            wgpu::BindGroupEntry { binding: 1, resource: target_buffer.as_entire_binding() },
355            wgpu::BindGroupEntry { binding: 2, resource: correspondence_buffer.as_entire_binding() },
356            wgpu::BindGroupEntry { binding: 3, resource: centroids_buffer.as_entire_binding() },
357            wgpu::BindGroupEntry { binding: 4, resource: params_buffer.as_entire_binding() },
358        ]);
359
360        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Centroid Computation") });
361        {
362            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: Some("Centroid Pass"), timestamp_writes: None });
363            compute_pass.set_pipeline(&pipeline);
364            compute_pass.set_bind_group(0, &bind_group, &[]);
365            compute_pass.dispatch_workgroups(1, 1, 1);
366        }
367
368        let staging_buffer = self.create_buffer("Centroid Staging", 2 * std::mem::size_of::<[f32; 3]>() as u64, wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ);
369        encoder.copy_buffer_to_buffer(&centroids_buffer, 0, &staging_buffer, 0, 2 * std::mem::size_of::<[f32; 3]>() as u64);
370        self.queue.submit(std::iter::once(encoder.finish()));
371
372        let buffer_slice = staging_buffer.slice(..);
373        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
374        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
375        self.device.poll(wgpu::PollType::Wait {
376            submission_index: None,
377            timeout: None,
378        });
379
380        if let Some(Ok(())) = receiver.receive().await {
381            let data = buffer_slice.get_mapped_range();
382            let centroids: Vec<[f32; 3]> = bytemuck::cast_slice(&data).to_vec();
383            drop(data);
384            staging_buffer.unmap();
385            Ok(centroids)
386        } else {
387            Err(threecrate_core::Error::Gpu("Failed to read centroid results".to_string()))
388        }
389    }
390    
391    /// Compute covariance matrix using GPU (simplified version without atomics for now)
392    // CPU covariance as stable fallback
393    fn compute_covariance_cpu(
394        &self,
395        source_points: &[Point3f],
396        target_points: &[Point3f],
397        correspondences: &[(usize, usize, f32)],
398        centroids: &[[f32; 3]],
399    ) -> Result<nalgebra::Matrix3<f32>> {
400        let mut h = nalgebra::Matrix3::zeros();
401        let sc = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
402        let tc = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
403        for (si, ti, _dist) in correspondences.iter().copied() {
404            let s = source_points[si];
405            let t = target_points[ti];
406            let sv = s.coords - sc;
407            let tv = t.coords - tc;
408            h += sv * tv.transpose();
409        }
410        Ok(h)
411    }
412    
413    /// Convert covariance matrix to transformation using SVD
414    fn svd_to_transformation(&self, covariance: &nalgebra::Matrix3<f32>, centroids: &[[f32; 3]]) -> Result<Isometry3<f32>> {
415        let source_centroid = nalgebra::Vector3::new(centroids[0][0], centroids[0][1], centroids[0][2]);
416        let target_centroid = nalgebra::Vector3::new(centroids[1][0], centroids[1][1], centroids[1][2]);
417        
418        // Compute SVD
419        let svd = covariance.svd(true, true);
420        let u = svd.u.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
421        let v_t = svd.v_t.ok_or_else(|| threecrate_core::Error::Algorithm("SVD failed".to_string()))?;
422        
423        let mut rotation = v_t.transpose() * u.transpose();
424        
425        // Ensure proper rotation (det = 1)
426        if rotation.determinant() < 0.0 {
427            let mut v_corrected = v_t.transpose();
428            v_corrected.column_mut(2).scale_mut(-1.0);
429            rotation = v_corrected * u.transpose();
430        }
431        
432        let translation = target_centroid - rotation * source_centroid;
433        
434        Ok(Isometry3::from_parts(
435            nalgebra::Translation3::from(translation),
436            nalgebra::UnitQuaternion::from_matrix(&rotation),
437        ))
438    }
439
440    /// GPU-accelerated ICP alignment (original single implementation)
441    pub async fn icp_align(
442        &self,
443        source: &PointCloud<Point3f>,
444        target: &PointCloud<Point3f>,
445        max_iterations: usize,
446        convergence_threshold: f32,
447        max_correspondence_distance: f32,
448    ) -> Result<Isometry3<f32>> {
449        let result = self.optimized_icp_align(
450            source, 
451            target, 
452            max_iterations, 
453            convergence_threshold, 
454            max_correspondence_distance
455        ).await?;
456        Ok(result.transformation)
457    }
458    
459    /// Find nearest neighbor correspondences using GPU
460    async fn find_correspondences(
461        &self,
462        source_points: &[Point3f],
463        target_points: &[Point3f],
464        max_distance: f32,
465    ) -> Result<Vec<(usize, usize, f32)>> {
466        // Convert points to GPU format
467        let source_data: Vec<[f32; 3]> = source_points
468            .iter()
469            .map(|p| [p.x, p.y, p.z])
470            .collect();
471            
472        let target_data: Vec<[f32; 3]> = target_points
473            .iter()
474            .map(|p| [p.x, p.y, p.z])
475            .collect();
476
477        // Create buffers
478        let source_buffer = self.create_buffer_init(
479            "Source Points",
480            &source_data,
481            wgpu::BufferUsages::STORAGE,
482        );
483
484        let target_buffer = self.create_buffer_init(
485            "Target Points", 
486            &target_data,
487            wgpu::BufferUsages::STORAGE,
488        );
489
490        let correspondences_buffer = self.create_buffer(
491            "Correspondences",
492            (source_data.len() * std::mem::size_of::<u32>()) as u64,
493            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
494        );
495
496        let distances_buffer = self.create_buffer(
497            "Distances",
498            (source_data.len() * std::mem::size_of::<f32>()) as u64,
499            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
500        );
501
502        #[repr(C)]
503        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
504        struct ICPParams {
505            num_source: u32,
506            num_target: u32,
507            max_distance: f32,
508        }
509
510        let params = ICPParams {
511            num_source: source_data.len() as u32,
512            num_target: target_data.len() as u32,
513            max_distance,
514        };
515
516        let params_buffer = self.create_buffer_init(
517            "ICP Params",
518            &[params],
519            wgpu::BufferUsages::UNIFORM,
520        );
521
522        // Create shader and pipeline
523        let shader = self.create_shader_module("ICP Nearest Neighbor", ICP_NEAREST_NEIGHBOR_SHADER);
524        
525        let bind_group_layout = self.create_bind_group_layout(
526            "ICP Correspondence",
527            &[
528                wgpu::BindGroupLayoutEntry {
529                    binding: 0,
530                    visibility: wgpu::ShaderStages::COMPUTE,
531                    ty: wgpu::BindingType::Buffer {
532                        ty: wgpu::BufferBindingType::Storage { read_only: true },
533                        has_dynamic_offset: false,
534                        min_binding_size: None,
535                    },
536                    count: None,
537                },
538                wgpu::BindGroupLayoutEntry {
539                    binding: 1,
540                    visibility: wgpu::ShaderStages::COMPUTE,
541                    ty: wgpu::BindingType::Buffer {
542                        ty: wgpu::BufferBindingType::Storage { read_only: true },
543                        has_dynamic_offset: false,
544                        min_binding_size: None,
545                    },
546                    count: None,
547                },
548                wgpu::BindGroupLayoutEntry {
549                    binding: 2,
550                    visibility: wgpu::ShaderStages::COMPUTE,
551                    ty: wgpu::BindingType::Buffer {
552                        ty: wgpu::BufferBindingType::Storage { read_only: false },
553                        has_dynamic_offset: false,
554                        min_binding_size: None,
555                    },
556                    count: None,
557                },
558                wgpu::BindGroupLayoutEntry {
559                    binding: 3,
560                    visibility: wgpu::ShaderStages::COMPUTE,
561                    ty: wgpu::BindingType::Buffer {
562                        ty: wgpu::BufferBindingType::Storage { read_only: false },
563                        has_dynamic_offset: false,
564                        min_binding_size: None,
565                    },
566                    count: None,
567                },
568                wgpu::BindGroupLayoutEntry {
569                    binding: 4,
570                    visibility: wgpu::ShaderStages::COMPUTE,
571                    ty: wgpu::BindingType::Buffer {
572                        ty: wgpu::BufferBindingType::Uniform,
573                        has_dynamic_offset: false,
574                        min_binding_size: None,
575                    },
576                    count: None,
577                },
578            ],
579        );
580
581        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
582            label: Some("ICP Correspondence"),
583            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
584                label: Some("ICP Pipeline Layout"),
585                bind_group_layouts: &[&bind_group_layout],
586                immediate_size: 0,
587            })),
588            module: &shader,
589            entry_point: Some("main"),
590            compilation_options: wgpu::PipelineCompilationOptions::default(),
591            cache: None,
592        });
593
594        let bind_group = self.create_bind_group(
595            "ICP Correspondence",
596            &bind_group_layout,
597            &[
598                wgpu::BindGroupEntry {
599                    binding: 0,
600                    resource: source_buffer.as_entire_binding(),
601                },
602                wgpu::BindGroupEntry {
603                    binding: 1,
604                    resource: target_buffer.as_entire_binding(),
605                },
606                wgpu::BindGroupEntry {
607                    binding: 2,
608                    resource: correspondences_buffer.as_entire_binding(),
609                },
610                wgpu::BindGroupEntry {
611                    binding: 3,
612                    resource: distances_buffer.as_entire_binding(),
613                },
614                wgpu::BindGroupEntry {
615                    binding: 4,
616                    resource: params_buffer.as_entire_binding(),
617                },
618            ],
619        );
620
621        // Execute compute shader
622        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
623            label: Some("ICP Correspondence"),
624        });
625
626        {
627            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
628                label: Some("ICP Correspondence Pass"),
629                timestamp_writes: None,
630            });
631            compute_pass.set_pipeline(&pipeline);
632            compute_pass.set_bind_group(0, &bind_group, &[]);
633            let workgroup_count = (source_data.len() + 63) / 64;
634            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
635        }
636
637        // Read back results
638        let correspondences_staging = self.create_buffer(
639            "Correspondences Staging",
640            (source_data.len() * std::mem::size_of::<u32>()) as u64,
641            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
642        );
643
644        let distances_staging = self.create_buffer(
645            "Distances Staging",
646            (source_data.len() * std::mem::size_of::<f32>()) as u64,
647            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
648        );
649
650        encoder.copy_buffer_to_buffer(
651            &correspondences_buffer,
652            0,
653            &correspondences_staging,
654            0,
655            (source_data.len() * std::mem::size_of::<u32>()) as u64,
656        );
657
658        encoder.copy_buffer_to_buffer(
659            &distances_buffer,
660            0,
661            &distances_staging,
662            0,
663            (source_data.len() * std::mem::size_of::<f32>()) as u64,
664        );
665
666        self.queue.submit(std::iter::once(encoder.finish()));
667
668        // Map and read correspondences
669        let corr_slice = correspondences_staging.slice(..);
670        let dist_slice = distances_staging.slice(..);
671        
672        let (corr_sender, corr_receiver) = futures_intrusive::channel::shared::oneshot_channel();
673        let (dist_sender, dist_receiver) = futures_intrusive::channel::shared::oneshot_channel();
674        
675        corr_slice.map_async(wgpu::MapMode::Read, move |v| corr_sender.send(v).unwrap());
676        dist_slice.map_async(wgpu::MapMode::Read, move |v| dist_sender.send(v).unwrap());
677
678        self.device.poll(wgpu::PollType::Wait {
679            submission_index: None,
680            timeout: None,
681        });
682
683        if let (Some(Ok(())), Some(Ok(()))) = (corr_receiver.receive().await, dist_receiver.receive().await) {
684            let corr_data = corr_slice.get_mapped_range();
685            let dist_data = dist_slice.get_mapped_range();
686            
687            let correspondences: Vec<u32> = bytemuck::cast_slice(&corr_data).to_vec();
688            let distances: Vec<f32> = bytemuck::cast_slice(&dist_data).to_vec();
689            
690            let result: Vec<(usize, usize, f32)> = correspondences
691                .into_iter()
692                .zip(distances.into_iter())
693                .enumerate()
694                .filter(|(_, (_, distance))| *distance < max_distance)
695                .map(|(i, (target_idx, distance))| (i, target_idx as usize, distance))
696                .collect();
697            
698            drop(corr_data);
699            drop(dist_data);
700            correspondences_staging.unmap();
701            distances_staging.unmap();
702            
703            Ok(result)
704        } else {
705            Err(threecrate_core::Error::Gpu("Failed to read GPU correspondence results".to_string()))
706        }
707    }
708}
709
/// Result of GPU point-to-plane ICP ([`GpuContext::icp_point_to_plane_align`]).
#[derive(Debug, Clone)]
pub struct GpuPointToPlaneICPResult {
    /// Accumulated rigid transform applied to the source cloud.
    pub transformation: Isometry3<f32>,
    /// Mean squared point-to-plane residual from the last completed iteration.
    /// It is evaluated on the matched points before that iteration's update is
    /// applied. It stays `f32::INFINITY` if no iteration got that far.
    pub final_error: f32,
    /// Number of iterations started.
    pub iterations: usize,
    /// `true` if an iteration's incremental translation norm and rotation
    /// angle both fell below the convergence threshold.
    pub converged: bool,
}
718
719impl GpuContext {
720    /// GPU-accelerated point-to-plane ICP.
721    ///
722    /// Uses GPU for nearest-neighbor correspondence finding and CPU for the
723    /// 6×6 linearized optimization step (Chen & Medioni 1992).
724    pub async fn icp_point_to_plane_align(
725        &self,
726        source: &PointCloud<Point3f>,
727        target: &PointCloud<Point3f>,
728        target_normals: &[Vector3f],
729        max_iterations: usize,
730        convergence_threshold: f32,
731        max_correspondence_distance: f32,
732    ) -> Result<GpuPointToPlaneICPResult> {
733        if source.is_empty() || target.is_empty() {
734            return Err(threecrate_core::Error::InvalidData("Empty point clouds".to_string()));
735        }
736        if target_normals.len() != target.points.len() {
737            return Err(threecrate_core::Error::InvalidData(
738                "target_normals length must equal target point count".to_string(),
739            ));
740        }
741
742        let mut current_transform = Isometry3::identity();
743        let mut final_error = f32::INFINITY;
744        let mut iterations_used = 0;
745        let mut converged = false;
746
747        for iteration in 0..max_iterations {
748            iterations_used = iteration + 1;
749
750            // Transform source using current estimate
751            let mut transformed_source = source.clone();
752            for p in &mut transformed_source.points {
753                *p = current_transform.transform_point(p);
754            }
755
756            // GPU nearest-neighbor correspondences
757            let correspondences = self
758                .find_correspondences(
759                    &transformed_source.points,
760                    &target.points,
761                    max_correspondence_distance,
762                )
763                .await?;
764
765            if correspondences.len() < 6 {
766                break;
767            }
768
769            // Gather matched points and normals
770            let valid_source: Vec<Point3f> =
771                correspondences.iter().map(|(si, _, _)| transformed_source.points[*si]).collect();
772            let valid_target: Vec<Point3f> =
773                correspondences.iter().map(|(_, ti, _)| target.points[*ti]).collect();
774            let valid_normals: Vec<Vector3f> =
775                correspondences.iter().map(|(_, ti, _)| target_normals[*ti]).collect();
776
777            // Linearized point-to-plane optimization (6×6 system on CPU)
778            let delta = Self::solve_point_to_plane_cpu(&valid_source, &valid_target, &valid_normals)?;
779
780            current_transform = delta * current_transform;
781
782            // Point-to-plane MSE
783            let mse: f32 = valid_source
784                .iter()
785                .zip(valid_target.iter())
786                .zip(valid_normals.iter())
787                .map(|((s, d), n)| {
788                    let v = n.dot(&(d.coords - s.coords));
789                    v * v
790                })
791                .sum::<f32>()
792                / valid_source.len() as f32;
793
794            let translation_change = delta.translation.vector.norm();
795            let rotation_change = delta.rotation.angle();
796
797            final_error = mse;
798
799            if translation_change < convergence_threshold && rotation_change < convergence_threshold {
800                converged = true;
801                break;
802            }
803        }
804
805        Ok(GpuPointToPlaneICPResult {
806            transformation: current_transform,
807            final_error,
808            iterations: iterations_used,
809            converged,
810        })
811    }
812
813    /// CPU-side 6×6 linear solve for one point-to-plane ICP iteration.
814    fn solve_point_to_plane_cpu(
815        source_points: &[Point3f],
816        target_points: &[Point3f],
817        normals: &[Vector3f],
818    ) -> Result<Isometry3<f32>> {
819        let mut ata = Matrix6::<f32>::zeros();
820        let mut atb = Vector6::<f32>::zeros();
821
822        for ((src, tgt), n) in source_points
823            .iter()
824            .zip(target_points.iter())
825            .zip(normals.iter())
826        {
827            let c = src.coords.cross(n);
828            let a_row = Vector6::new(c.x, c.y, c.z, n.x, n.y, n.z);
829            let b_i = n.dot(&(tgt.coords - src.coords));
830            ata += a_row * a_row.transpose();
831            atb += a_row * b_i;
832        }
833
834        let x = if let Some(chol) = ata.cholesky() {
835            chol.solve(&atb)
836        } else {
837            ata.lu()
838                .solve(&atb)
839                .ok_or_else(|| threecrate_core::Error::Algorithm(
840                    "Point-to-plane GPU system is ill-conditioned".to_string(),
841                ))?
842        };
843
844        let rot_x = UnitQuaternion::from_axis_angle(&nalgebra::Vector3::x_axis(), x[0]);
845        let rot_y = UnitQuaternion::from_axis_angle(&nalgebra::Vector3::y_axis(), x[1]);
846        let rot_z = UnitQuaternion::from_axis_angle(&nalgebra::Vector3::z_axis(), x[2]);
847
848        Ok(Isometry3::from_parts(
849            Translation3::new(x[3], x[4], x[5]),
850            rot_z * rot_y * rot_x,
851        ))
852    }
853}
854
855/// GPU-accelerated ICP registration
856pub async fn gpu_icp(
857    gpu_context: &GpuContext,
858    source: &PointCloud<Point3f>,
859    target: &PointCloud<Point3f>,
860    max_iterations: usize,
861    convergence_threshold: f32,
862    max_correspondence_distance: f32,
863) -> Result<Isometry3<f32>> {
864    gpu_context.icp_align(source, target, max_iterations, convergence_threshold, max_correspondence_distance).await
865}
866
867/// Execute batch ICP operations on multiple point cloud pairs
868pub async fn gpu_batch_icp(
869    gpu_context: &GpuContext,
870    jobs: &[BatchICPJob],
871) -> Result<Vec<BatchICPResult>> {
872    gpu_context.batch_icp_align(jobs).await
873}
874
875/// GPU-accelerated point-to-plane ICP registration.
876///
877/// Uses the GPU for nearest-neighbor correspondence finding and the CPU for
878/// the 6×6 linearized optimization step (Chen & Medioni 1992).
879///
880/// # Arguments
881/// * `gpu_context`                  - Initialized GPU context
882/// * `source`                       - Source point cloud
883/// * `target`                       - Target point cloud
884/// * `target_normals`               - Surface normals at each target point
885/// * `max_iterations`               - Maximum number of iterations
886/// * `convergence_threshold`        - Delta-transform norm threshold for convergence
887/// * `max_correspondence_distance`  - Maximum distance for valid correspondences
888pub async fn gpu_icp_point_to_plane(
889    gpu_context: &GpuContext,
890    source: &PointCloud<Point3f>,
891    target: &PointCloud<Point3f>,
892    target_normals: &[Vector3f],
893    max_iterations: usize,
894    convergence_threshold: f32,
895    max_correspondence_distance: f32,
896) -> Result<GpuPointToPlaneICPResult> {
897    gpu_context
898        .icp_point_to_plane_align(
899            source,
900            target,
901            target_normals,
902            max_iterations,
903            convergence_threshold,
904            max_correspondence_distance,
905        )
906        .await
907}