Skip to main content

threecrate_gpu/
nearest_neighbor.rs

1//! GPU-accelerated nearest neighbor search
2
3use threecrate_core::{Point3f, Result, Error};
4use crate::GpuContext;
5use bytemuck::{Pod, Zeroable};
6
7/// Parameters for nearest neighbor search
8#[repr(C)]
9#[derive(Copy, Clone, Pod, Zeroable)]
10pub struct NearestNeighborParams {
11    pub num_points: u32,
12    pub k_neighbors: u32,
13    pub max_distance: f32,
14    pub _padding: u32,
15}
16
17/// GPU representation of a point for nearest neighbor search
18#[repr(C)]
19#[derive(Copy, Clone, Pod, Zeroable)]
20#[repr(align(16))]
21pub struct GpuPoint {
22    pub position: [f32; 3],
23    pub _padding: f32,
24}
25
26/// Result of nearest neighbor search
27#[repr(C)]
28#[derive(Copy, Clone, Pod, Zeroable)]
29pub struct NeighborResult {
30    pub index: u32,
31    pub distance: f32,
32    pub _padding: [u32; 2],
33}
34
/// WGSL source template for the brute-force k-nearest-neighbor compute shader.
///
/// `MAX_K` is a textual placeholder rather than a WGSL constant: the host code
/// in this file substitutes the (clamped) `k` value for every occurrence
/// before the module is compiled.
///
/// Bindings (group 0): 0 = candidate points, 1 = query points, 2 = per-query
/// result rows, 3 = uniform `NearestNeighborParams`.
///
/// Each invocation handles one query point. It keeps a list of at most
/// `k_neighbors` candidates ordered by ascending distance, admitting only
/// candidates strictly closer than `max_distance`, and finally writes the list
/// as one output row of `vec2<f32>(index, distance)` entries. Slots that were
/// never filled keep the initial sentinel `(num_points, max_distance)`.
///
/// NOTE(review): the point index is stored as `f32`, so indices above 2^24
/// cannot be represented exactly.
const NEAREST_NEIGHBOR_SHADER: &str = r#"
struct GpuPoint {
    position: vec3<f32>,
    _padding: f32,
}

@group(0) @binding(0) var<storage, read> input_points: array<GpuPoint>;
@group(0) @binding(1) var<storage, read> query_points: array<GpuPoint>;
@group(0) @binding(2) var<storage, read_write> output_neighbors: array<array<vec2<f32>, MAX_K>>;
@group(0) @binding(3) var<uniform> params: NearestNeighborParams;

struct NearestNeighborParams {
    num_points: u32,
    k_neighbors: u32,
    max_distance: f32,
    _padding: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let query_idx = global_id.x;
    if (query_idx >= arrayLength(&query_points)) {
        return;
    }
    
    let query_point = query_points[query_idx].position;
    
    // Initialize neighbors with maximum distance
    var neighbors: array<vec2<f32>, MAX_K>;
    for (var i = 0u; i < params.k_neighbors; i++) {
        neighbors[i] = vec2<f32>(f32(params.num_points), params.max_distance);
    }
    
    // Find k nearest neighbors
    for (var i = 0u; i < params.num_points; i++) {
        let diff = input_points[i].position - query_point;
        let distance = length(diff);
        
        if (distance < params.max_distance) {
            // Insert into sorted neighbors array
            let neighbor = vec2<f32>(f32(i), distance);
            
            // Find insertion point
            var insert_idx = params.k_neighbors;
            for (var j = 0u; j < params.k_neighbors; j++) {
                if (distance < neighbors[j].y) {
                    insert_idx = j;
                    break;
                }
            }
            
            // Shift and insert
            if (insert_idx < params.k_neighbors) {
                for (var j = params.k_neighbors - 1u; j > insert_idx; j--) {
                    neighbors[j] = neighbors[j - 1u];
                }
                neighbors[insert_idx] = neighbor;
            }
        }
    }
    
    // Write results
    for (var i = 0u; i < params.k_neighbors; i++) {
        output_neighbors[query_idx][i] = neighbors[i];
    }
}
"#;
102
103impl GpuContext {
104    /// GPU-accelerated k-nearest neighbor search
105    pub async fn find_k_nearest_neighbors(
106        &self,
107        points: &[Point3f],
108        query_points: &[Point3f],
109        k: usize,
110        max_distance: f32,
111    ) -> Result<Vec<Vec<(usize, f32)>>> {
112        if points.is_empty() || query_points.is_empty() {
113            return Ok(vec![Vec::new(); query_points.len()]);
114        }
115        
116        let k = k.min(32).max(1); // Limit k to reasonable bounds
117        
118        // Convert points to GPU format with proper alignment
119        let gpu_points: Vec<GpuPoint> = points
120            .iter()
121            .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
122            .collect();
123            
124        let gpu_query_points: Vec<GpuPoint> = query_points
125            .iter()
126            .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
127            .collect();
128
129        // Create buffers
130        let points_buffer = self.create_buffer_init(
131            "Points Buffer",
132            &gpu_points,
133            wgpu::BufferUsages::STORAGE,
134        );
135
136        let query_buffer = self.create_buffer_init(
137            "Query Points Buffer",
138            &gpu_query_points,
139            wgpu::BufferUsages::STORAGE,
140        );
141
142        let output_buffer = self.create_buffer(
143            "Output Buffer",
144            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
145            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146        );
147
148        let params = NearestNeighborParams {
149            num_points: points.len() as u32,
150            k_neighbors: k as u32,
151            max_distance,
152            _padding: 0,
153        };
154
155        let params_buffer = self.create_buffer_init(
156            "Params Buffer",
157            &[params],
158            wgpu::BufferUsages::UNIFORM,
159        );
160
161        // Create shader with MAX_K constant
162        let shader_source = NEAREST_NEIGHBOR_SHADER.replace("MAX_K", &k.to_string());
163        let shader = self.create_shader_module("Nearest Neighbor Shader", &shader_source);
164
165        // Create bind group layout
166        let bind_group_layout = self.create_bind_group_layout(
167            "Nearest Neighbor Layout",
168            &[
169                wgpu::BindGroupLayoutEntry {
170                    binding: 0,
171                    visibility: wgpu::ShaderStages::COMPUTE,
172                    ty: wgpu::BindingType::Buffer {
173                        ty: wgpu::BufferBindingType::Storage { read_only: true },
174                        has_dynamic_offset: false,
175                        min_binding_size: None,
176                    },
177                    count: None,
178                },
179                wgpu::BindGroupLayoutEntry {
180                    binding: 1,
181                    visibility: wgpu::ShaderStages::COMPUTE,
182                    ty: wgpu::BindingType::Buffer {
183                        ty: wgpu::BufferBindingType::Storage { read_only: true },
184                        has_dynamic_offset: false,
185                        min_binding_size: None,
186                    },
187                    count: None,
188                },
189                wgpu::BindGroupLayoutEntry {
190                    binding: 2,
191                    visibility: wgpu::ShaderStages::COMPUTE,
192                    ty: wgpu::BindingType::Buffer {
193                        ty: wgpu::BufferBindingType::Storage { read_only: false },
194                        has_dynamic_offset: false,
195                        min_binding_size: None,
196                    },
197                    count: None,
198                },
199                wgpu::BindGroupLayoutEntry {
200                    binding: 3,
201                    visibility: wgpu::ShaderStages::COMPUTE,
202                    ty: wgpu::BindingType::Buffer {
203                        ty: wgpu::BufferBindingType::Uniform,
204                        has_dynamic_offset: false,
205                        min_binding_size: None,
206                    },
207                    count: None,
208                },
209            ],
210        );
211
212        // Create compute pipeline
213        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
214            label: Some("Nearest Neighbor Pipeline"),
215            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
216                label: Some("Nearest Neighbor Pipeline Layout"),
217                bind_group_layouts: &[&bind_group_layout],
218                immediate_size: 0,
219            })),
220            module: &shader,
221            entry_point: Some("main"),
222            compilation_options: wgpu::PipelineCompilationOptions::default(),
223            cache: None,
224        });
225
226        // Create bind group
227        let bind_group = self.create_bind_group(
228            "Nearest Neighbor Bind Group",
229            &bind_group_layout,
230            &[
231                wgpu::BindGroupEntry {
232                    binding: 0,
233                    resource: points_buffer.as_entire_binding(),
234                },
235                wgpu::BindGroupEntry {
236                    binding: 1,
237                    resource: query_buffer.as_entire_binding(),
238                },
239                wgpu::BindGroupEntry {
240                    binding: 2,
241                    resource: output_buffer.as_entire_binding(),
242                },
243                wgpu::BindGroupEntry {
244                    binding: 3,
245                    resource: params_buffer.as_entire_binding(),
246                },
247            ],
248        );
249
250        // Execute compute shader
251        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
252            label: Some("Nearest Neighbor Encoder"),
253        });
254
255        {
256            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
257                label: Some("Nearest Neighbor Pass"),
258                timestamp_writes: None,
259            });
260            compute_pass.set_pipeline(&pipeline);
261            compute_pass.set_bind_group(0, &bind_group, &[]);
262            let workgroup_count = (query_points.len() + 63) / 64;
263            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
264        }
265
266        // Read back results
267        let staging_buffer = self.create_buffer(
268            "Staging Buffer",
269            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
270            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
271        );
272
273        encoder.copy_buffer_to_buffer(
274            &output_buffer,
275            0,
276            &staging_buffer,
277            0,
278            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
279        );
280
281        self.queue.submit(std::iter::once(encoder.finish()));
282
283        // Map and read results
284        let buffer_slice = staging_buffer.slice(..);
285        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
286        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
287
288        self.device.poll(wgpu::PollType::Wait {
289            submission_index: None,
290            timeout: None,
291        });
292
293        if let Some(Ok(())) = receiver.receive().await {
294            let data = buffer_slice.get_mapped_range();
295            let raw_neighbors: Vec<[f32; 2]> = bytemuck::cast_slice(&data).to_vec();
296            
297            let mut results = Vec::with_capacity(query_points.len());
298            for i in 0..query_points.len() {
299                let mut neighbors = Vec::with_capacity(k);
300                for j in 0..k {
301                    let idx = i * k + j;
302                    if idx < raw_neighbors.len() {
303                        let neighbor = raw_neighbors[idx];
304                        let point_idx = neighbor[0] as usize;
305                        let distance = neighbor[1];
306                        
307                        if point_idx < points.len() && distance < max_distance {
308                            neighbors.push((point_idx, distance));
309                        }
310                    }
311                }
312                results.push(neighbors);
313            }
314            
315            drop(data);
316            staging_buffer.unmap();
317            
318            Ok(results)
319        } else {
320            Err(Error::Gpu("Failed to read GPU results".to_string()))
321        }
322    }
323}
324
325/// GPU-accelerated nearest neighbor search for single query point
326pub async fn gpu_find_k_nearest(
327    gpu_context: &GpuContext,
328    points: &[Point3f],
329    query: &Point3f,
330    k: usize,
331) -> Result<Vec<(usize, f32)>> {
332    let results = gpu_context.find_k_nearest_neighbors(points, &[*query], k, f32::INFINITY).await?;
333    Ok(results.into_iter().next().unwrap_or_default())
334}
335
336/// GPU-accelerated nearest neighbor search for multiple query points
337pub async fn gpu_find_k_nearest_batch(
338    gpu_context: &GpuContext,
339    points: &[Point3f],
340    query_points: &[Point3f],
341    k: usize,
342) -> Result<Vec<Vec<(usize, f32)>>> {
343    gpu_context.find_k_nearest_neighbors(points, query_points, k, f32::INFINITY).await
344}
345
346/// GPU-accelerated radius-based nearest neighbor search
347pub async fn gpu_find_radius_neighbors(
348    gpu_context: &GpuContext,
349    points: &[Point3f],
350    query: &Point3f,
351    radius: f32,
352) -> Result<Vec<(usize, f32)>> {
353    let results = gpu_context.find_k_nearest_neighbors(points, &[*query], 32, radius).await?;
354    Ok(results.into_iter().next().unwrap_or_default())
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuContext;
    use threecrate_core::Point3f;
    use approx::assert_relative_eq;

    /// Coordinates of the shared five-point fixture: the origin, the three
    /// unit-axis points and the (1, 1, 1) corner.
    const TEST_COORDS: [[f32; 3]; 5] = [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
        [1.0, 1.0, 1.0],
    ];

    /// Attempts to build a GPU context; logs a notice and yields `None` when
    /// creation fails so GPU-dependent tests can bail out early.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        let context = GpuContext::new().await.ok();
        if context.is_none() {
            println!("⚠️  GPU not available, skipping GPU-dependent test");
        }
        context
    }

    /// Builds the small point cloud shared by all tests.
    fn create_test_points() -> Vec<Point3f> {
        TEST_COORDS
            .iter()
            .map(|&[x, y, z]| Point3f::new(x, y, z))
            .collect()
    }

    /// A query near the origin must report the origin as its closest point.
    #[test]
    fn test_gpu_nearest_neighbor_single() {
        pollster::block_on(async {
            if let Some(ctx) = try_create_gpu_context().await {
                let cloud = create_test_points();
                let probe = Point3f::new(0.1, 0.1, 0.1);

                let neighbors = gpu_find_k_nearest(&ctx, &cloud, &probe, 3).await.unwrap();

                assert_eq!(neighbors.len(), 3);
                assert_eq!(neighbors[0].0, 0); // Closest should be origin
                assert!(neighbors[0].1 < 0.2); // Distance should be small

                println!("✓ GPU single nearest neighbor test passed");
            }
        });
    }

    /// Each query in a batch gets its own result list with the right winner.
    #[test]
    fn test_gpu_nearest_neighbor_batch() {
        pollster::block_on(async {
            if let Some(ctx) = try_create_gpu_context().await {
                let cloud = create_test_points();
                let probes = vec![
                    Point3f::new(0.1, 0.1, 0.1),
                    Point3f::new(0.9, 0.1, 0.1),
                ];

                let per_query = gpu_find_k_nearest_batch(&ctx, &cloud, &probes, 2)
                    .await
                    .unwrap();

                assert_eq!(per_query.len(), 2);
                assert_eq!(per_query[0].len(), 2);
                assert_eq!(per_query[1].len(), 2);

                // First query should find origin as closest
                assert_eq!(per_query[0][0].0, 0);

                // Second query should find (1,0,0) as closest
                assert_eq!(per_query[1][0].0, 1);

                println!("✓ GPU batch nearest neighbor test passed");
            }
        });
    }

    /// A radius query returns at least one point and nothing beyond the radius.
    #[test]
    fn test_gpu_radius_neighbors() {
        pollster::block_on(async {
            if let Some(ctx) = try_create_gpu_context().await {
                let cloud = create_test_points();
                let probe = Point3f::new(0.0, 0.0, 0.0);
                let radius = 1.5;

                let neighbors = gpu_find_radius_neighbors(&ctx, &cloud, &probe, radius)
                    .await
                    .unwrap();

                // Should find points within radius
                assert!(!neighbors.is_empty());

                // All distances should be within radius
                for (_, distance) in neighbors.iter() {
                    assert!(*distance <= radius);
                }

                println!("✓ GPU radius neighbors test passed: {} neighbors found", neighbors.len());
            }
        });
    }

    /// The GPU answer must agree with a CPU brute-force scan of the fixture.
    #[test]
    fn test_gpu_nearest_neighbor_accuracy() {
        pollster::block_on(async {
            if let Some(ctx) = try_create_gpu_context().await {
                let cloud = create_test_points();
                let probe = Point3f::new(0.5, 0.5, 0.5);

                let neighbors = gpu_find_k_nearest(&ctx, &cloud, &probe, 1).await.unwrap();

                assert_eq!(neighbors.len(), 1);

                // Reference answer: the first point with the strictly smallest
                // distance wins, so ties resolve to the lowest index.
                let (expected_idx, expected_dist) = cloud.iter().enumerate().fold(
                    (0usize, f32::INFINITY),
                    |best, (idx, candidate)| {
                        let dist = (probe - *candidate).magnitude();
                        if dist < best.1 {
                            (idx, dist)
                        } else {
                            best
                        }
                    },
                );

                assert_eq!(neighbors[0].0, expected_idx);
                assert_relative_eq!(neighbors[0].1, expected_dist, epsilon = 0.001);

                println!("✓ GPU nearest neighbor accuracy test passed");
            }
        });
    }
}