threecrate_gpu/
nearest_neighbor.rs

1//! GPU-accelerated nearest neighbor search
2
3use threecrate_core::{Point3f, Result, Error};
4use crate::GpuContext;
5use bytemuck::{Pod, Zeroable};
6
7/// Parameters for nearest neighbor search
8#[repr(C)]
9#[derive(Copy, Clone, Pod, Zeroable)]
10pub struct NearestNeighborParams {
11    pub num_points: u32,
12    pub k_neighbors: u32,
13    pub max_distance: f32,
14    pub _padding: u32,
15}
16
17/// GPU representation of a point for nearest neighbor search
18#[repr(C)]
19#[derive(Copy, Clone, Pod, Zeroable)]
20#[repr(align(16))]
21pub struct GpuPoint {
22    pub position: [f32; 3],
23    pub _padding: f32,
24}
25
26/// Result of nearest neighbor search
27#[repr(C)]
28#[derive(Copy, Clone, Pod, Zeroable)]
29pub struct NeighborResult {
30    pub index: u32,
31    pub distance: f32,
32    pub _padding: [u32; 2],
33}
34
35const NEAREST_NEIGHBOR_SHADER: &str = r#"
36struct GpuPoint {
37    position: vec3<f32>,
38    _padding: f32,
39}
40
41@group(0) @binding(0) var<storage, read> input_points: array<GpuPoint>;
42@group(0) @binding(1) var<storage, read> query_points: array<GpuPoint>;
43@group(0) @binding(2) var<storage, read_write> output_neighbors: array<array<vec2<f32>, MAX_K>>;
44@group(0) @binding(3) var<uniform> params: NearestNeighborParams;
45
46struct NearestNeighborParams {
47    num_points: u32,
48    k_neighbors: u32,
49    max_distance: f32,
50    _padding: u32,
51}
52
53@compute @workgroup_size(64)
54fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
55    let query_idx = global_id.x;
56    if (query_idx >= arrayLength(&query_points)) {
57        return;
58    }
59    
60    let query_point = query_points[query_idx].position;
61    
62    // Initialize neighbors with maximum distance
63    var neighbors: array<vec2<f32>, MAX_K>;
64    for (var i = 0u; i < params.k_neighbors; i++) {
65        neighbors[i] = vec2<f32>(f32(params.num_points), params.max_distance);
66    }
67    
68    // Find k nearest neighbors
69    for (var i = 0u; i < params.num_points; i++) {
70        let diff = input_points[i].position - query_point;
71        let distance = length(diff);
72        
73        if (distance < params.max_distance) {
74            // Insert into sorted neighbors array
75            let neighbor = vec2<f32>(f32(i), distance);
76            
77            // Find insertion point
78            var insert_idx = params.k_neighbors;
79            for (var j = 0u; j < params.k_neighbors; j++) {
80                if (distance < neighbors[j].y) {
81                    insert_idx = j;
82                    break;
83                }
84            }
85            
86            // Shift and insert
87            if (insert_idx < params.k_neighbors) {
88                for (var j = params.k_neighbors - 1u; j > insert_idx; j--) {
89                    neighbors[j] = neighbors[j - 1u];
90                }
91                neighbors[insert_idx] = neighbor;
92            }
93        }
94    }
95    
96    // Write results
97    for (var i = 0u; i < params.k_neighbors; i++) {
98        output_neighbors[query_idx][i] = neighbors[i];
99    }
100}
101"#;
102
103impl GpuContext {
104    /// GPU-accelerated k-nearest neighbor search
105    pub async fn find_k_nearest_neighbors(
106        &self,
107        points: &[Point3f],
108        query_points: &[Point3f],
109        k: usize,
110        max_distance: f32,
111    ) -> Result<Vec<Vec<(usize, f32)>>> {
112        if points.is_empty() || query_points.is_empty() {
113            return Ok(vec![Vec::new(); query_points.len()]);
114        }
115        
116        let k = k.min(32).max(1); // Limit k to reasonable bounds
117        
118        // Convert points to GPU format with proper alignment
119        let gpu_points: Vec<GpuPoint> = points
120            .iter()
121            .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
122            .collect();
123            
124        let gpu_query_points: Vec<GpuPoint> = query_points
125            .iter()
126            .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
127            .collect();
128
129        // Create buffers
130        let points_buffer = self.create_buffer_init(
131            "Points Buffer",
132            &gpu_points,
133            wgpu::BufferUsages::STORAGE,
134        );
135
136        let query_buffer = self.create_buffer_init(
137            "Query Points Buffer",
138            &gpu_query_points,
139            wgpu::BufferUsages::STORAGE,
140        );
141
142        let output_buffer = self.create_buffer(
143            "Output Buffer",
144            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
145            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146        );
147
148        let params = NearestNeighborParams {
149            num_points: points.len() as u32,
150            k_neighbors: k as u32,
151            max_distance,
152            _padding: 0,
153        };
154
155        let params_buffer = self.create_buffer_init(
156            "Params Buffer",
157            &[params],
158            wgpu::BufferUsages::UNIFORM,
159        );
160
161        // Create shader with MAX_K constant
162        let shader_source = NEAREST_NEIGHBOR_SHADER.replace("MAX_K", &k.to_string());
163        let shader = self.create_shader_module("Nearest Neighbor Shader", &shader_source);
164
165        // Create bind group layout
166        let bind_group_layout = self.create_bind_group_layout(
167            "Nearest Neighbor Layout",
168            &[
169                wgpu::BindGroupLayoutEntry {
170                    binding: 0,
171                    visibility: wgpu::ShaderStages::COMPUTE,
172                    ty: wgpu::BindingType::Buffer {
173                        ty: wgpu::BufferBindingType::Storage { read_only: true },
174                        has_dynamic_offset: false,
175                        min_binding_size: None,
176                    },
177                    count: None,
178                },
179                wgpu::BindGroupLayoutEntry {
180                    binding: 1,
181                    visibility: wgpu::ShaderStages::COMPUTE,
182                    ty: wgpu::BindingType::Buffer {
183                        ty: wgpu::BufferBindingType::Storage { read_only: true },
184                        has_dynamic_offset: false,
185                        min_binding_size: None,
186                    },
187                    count: None,
188                },
189                wgpu::BindGroupLayoutEntry {
190                    binding: 2,
191                    visibility: wgpu::ShaderStages::COMPUTE,
192                    ty: wgpu::BindingType::Buffer {
193                        ty: wgpu::BufferBindingType::Storage { read_only: false },
194                        has_dynamic_offset: false,
195                        min_binding_size: None,
196                    },
197                    count: None,
198                },
199                wgpu::BindGroupLayoutEntry {
200                    binding: 3,
201                    visibility: wgpu::ShaderStages::COMPUTE,
202                    ty: wgpu::BindingType::Buffer {
203                        ty: wgpu::BufferBindingType::Uniform,
204                        has_dynamic_offset: false,
205                        min_binding_size: None,
206                    },
207                    count: None,
208                },
209            ],
210        );
211
212        // Create compute pipeline
213        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
214            label: Some("Nearest Neighbor Pipeline"),
215            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
216                label: Some("Nearest Neighbor Pipeline Layout"),
217                bind_group_layouts: &[&bind_group_layout],
218                push_constant_ranges: &[],
219            })),
220            module: &shader,
221            entry_point: "main",
222            compilation_options: wgpu::PipelineCompilationOptions::default(),
223        });
224
225        // Create bind group
226        let bind_group = self.create_bind_group(
227            "Nearest Neighbor Bind Group",
228            &bind_group_layout,
229            &[
230                wgpu::BindGroupEntry {
231                    binding: 0,
232                    resource: points_buffer.as_entire_binding(),
233                },
234                wgpu::BindGroupEntry {
235                    binding: 1,
236                    resource: query_buffer.as_entire_binding(),
237                },
238                wgpu::BindGroupEntry {
239                    binding: 2,
240                    resource: output_buffer.as_entire_binding(),
241                },
242                wgpu::BindGroupEntry {
243                    binding: 3,
244                    resource: params_buffer.as_entire_binding(),
245                },
246            ],
247        );
248
249        // Execute compute shader
250        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
251            label: Some("Nearest Neighbor Encoder"),
252        });
253
254        {
255            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
256                label: Some("Nearest Neighbor Pass"),
257                timestamp_writes: None,
258            });
259            compute_pass.set_pipeline(&pipeline);
260            compute_pass.set_bind_group(0, &bind_group, &[]);
261            let workgroup_count = (query_points.len() + 63) / 64;
262            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
263        }
264
265        // Read back results
266        let staging_buffer = self.create_buffer(
267            "Staging Buffer",
268            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
269            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
270        );
271
272        encoder.copy_buffer_to_buffer(
273            &output_buffer,
274            0,
275            &staging_buffer,
276            0,
277            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
278        );
279
280        self.queue.submit(std::iter::once(encoder.finish()));
281
282        // Map and read results
283        let buffer_slice = staging_buffer.slice(..);
284        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
285        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
286
287        self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
288
289        if let Some(Ok(())) = receiver.receive().await {
290            let data = buffer_slice.get_mapped_range();
291            let raw_neighbors: Vec<[f32; 2]> = bytemuck::cast_slice(&data).to_vec();
292            
293            let mut results = Vec::with_capacity(query_points.len());
294            for i in 0..query_points.len() {
295                let mut neighbors = Vec::with_capacity(k);
296                for j in 0..k {
297                    let idx = i * k + j;
298                    if idx < raw_neighbors.len() {
299                        let neighbor = raw_neighbors[idx];
300                        let point_idx = neighbor[0] as usize;
301                        let distance = neighbor[1];
302                        
303                        if point_idx < points.len() && distance < max_distance {
304                            neighbors.push((point_idx, distance));
305                        }
306                    }
307                }
308                results.push(neighbors);
309            }
310            
311            drop(data);
312            staging_buffer.unmap();
313            
314            Ok(results)
315        } else {
316            Err(Error::Gpu("Failed to read GPU results".to_string()))
317        }
318    }
319}
320
321/// GPU-accelerated nearest neighbor search for single query point
322pub async fn gpu_find_k_nearest(
323    gpu_context: &GpuContext,
324    points: &[Point3f],
325    query: &Point3f,
326    k: usize,
327) -> Result<Vec<(usize, f32)>> {
328    let results = gpu_context.find_k_nearest_neighbors(points, &[*query], k, f32::INFINITY).await?;
329    Ok(results.into_iter().next().unwrap_or_default())
330}
331
332/// GPU-accelerated nearest neighbor search for multiple query points
333pub async fn gpu_find_k_nearest_batch(
334    gpu_context: &GpuContext,
335    points: &[Point3f],
336    query_points: &[Point3f],
337    k: usize,
338) -> Result<Vec<Vec<(usize, f32)>>> {
339    gpu_context.find_k_nearest_neighbors(points, query_points, k, f32::INFINITY).await
340}
341
342/// GPU-accelerated radius-based nearest neighbor search
343pub async fn gpu_find_radius_neighbors(
344    gpu_context: &GpuContext,
345    points: &[Point3f],
346    query: &Point3f,
347    radius: f32,
348) -> Result<Vec<(usize, f32)>> {
349    let results = gpu_context.find_k_nearest_neighbors(points, &[*query], 32, radius).await?;
350    Ok(results.into_iter().next().unwrap_or_default())
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuContext;
    use threecrate_core::Point3f;
    use approx::assert_relative_eq;

    /// Attempts to acquire a GPU; prints a skip notice and yields `None` on failure.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        let context = GpuContext::new().await.ok();
        if context.is_none() {
            println!("⚠️  GPU not available, skipping GPU-dependent test");
        }
        context
    }

    /// Five-point fixture: the origin, the three unit axis points, and (1, 1, 1).
    fn create_test_points() -> Vec<Point3f> {
        [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ]
        .iter()
        .map(|&[x, y, z]| Point3f::new(x, y, z))
        .collect()
    }

    #[test]
    fn test_gpu_nearest_neighbor_single() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.1, 0.1, 0.1);
            let found = gpu_find_k_nearest(&gpu, &cloud, &probe, 3).await.unwrap();

            assert_eq!(found.len(), 3);
            // The origin is the closest fixture point, about 0.173 away.
            let (closest_idx, closest_dist) = found[0];
            assert_eq!(closest_idx, 0);
            assert!(closest_dist < 0.2);

            println!("✓ GPU single nearest neighbor test passed");
        });
    }

    #[test]
    fn test_gpu_nearest_neighbor_batch() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probes = [Point3f::new(0.1, 0.1, 0.1), Point3f::new(0.9, 0.1, 0.1)];
            let per_probe = gpu_find_k_nearest_batch(&gpu, &cloud, &probes, 2)
                .await
                .unwrap();

            assert_eq!(per_probe.len(), 2);
            for hits in &per_probe {
                assert_eq!(hits.len(), 2);
            }

            // Near the origin -> index 0; near (1, 0, 0) -> index 1.
            assert_eq!(per_probe[0][0].0, 0);
            assert_eq!(per_probe[1][0].0, 1);

            println!("✓ GPU batch nearest neighbor test passed");
        });
    }

    #[test]
    fn test_gpu_radius_neighbors() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let center = Point3f::new(0.0, 0.0, 0.0);
            let radius = 1.5;
            let hits = gpu_find_radius_neighbors(&gpu, &cloud, &center, radius)
                .await
                .unwrap();

            assert!(!hits.is_empty());
            assert!(hits.iter().all(|&(_, dist)| dist <= radius));

            println!("✓ GPU radius neighbors test passed: {} neighbors found", hits.len());
        });
    }

    #[test]
    fn test_gpu_nearest_neighbor_accuracy() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.5, 0.5, 0.5);
            let found = gpu_find_k_nearest(&gpu, &cloud, &probe, 1).await.unwrap();

            assert_eq!(found.len(), 1);

            // CPU brute force: first point with the strictly smallest distance.
            let (expected_idx, expected_dist) = cloud.iter().enumerate().fold(
                (0usize, f32::INFINITY),
                |best, (i, point)| {
                    let dist = (probe - *point).magnitude();
                    if dist < best.1 {
                        (i, dist)
                    } else {
                        best
                    }
                },
            );

            assert_eq!(found[0].0, expected_idx);
            assert_relative_eq!(found[0].1, expected_dist, epsilon = 0.001);

            println!("✓ GPU nearest neighbor accuracy test passed");
        });
    }
}