threecrate_gpu/
nearest_neighbor.rs

1//! GPU-accelerated nearest neighbor search
2
3use threecrate_core::{Point3f, Result, Error};
4use crate::GpuContext;
5use bytemuck::{Pod, Zeroable};
6
/// Parameters for nearest neighbor search.
///
/// Uploaded as a uniform buffer. The field order and `#[repr(C)]` layout must match
/// the `NearestNeighborParams` struct declared in the WGSL shader in this file
/// (four 4-byte fields, 16 bytes total).
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct NearestNeighborParams {
    /// Number of points in the searched point set.
    pub num_points: u32,
    /// Number of neighbors kept per query point.
    pub k_neighbors: u32,
    /// Exclusive upper bound on reported distances (the shader tests `distance < max_distance`).
    pub max_distance: f32,
    /// Explicit padding to 16 bytes; written as 0 by the code in this file.
    pub _padding: u32,
}
16
/// GPU representation of a point for nearest neighbor search.
///
/// Mirrors the WGSL `GpuPoint` struct in this file: a 3-component position followed
/// by one `f32` of padding, 16-byte aligned, so every array element occupies 16 bytes.
/// The padding presumably matches the 16-byte alignment WGSL gives `vec3<f32>`.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
#[repr(align(16))]
pub struct GpuPoint {
    /// Point coordinates as `[x, y, z]`.
    pub position: [f32; 3],
    /// Unused; written as 0.0 by the code in this file.
    pub _padding: f32,
}
25
/// Result of nearest neighbor search.
///
/// NOTE(review): this type is not referenced by the shader or the readback code in
/// this file, which exchange results as 8-byte entries rather than this 16-byte
/// layout. It is presumably kept for external users; confirm before relying on it.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct NeighborResult {
    /// Index of the neighbor in the searched point set.
    pub index: u32,
    /// Distance from the query point to the neighbor.
    pub distance: f32,
    /// Padding to 16 bytes.
    pub _padding: [u32; 2],
}
34
35const NEAREST_NEIGHBOR_SHADER: &str = r#"
36struct GpuPoint {
37    position: vec3<f32>,
38    _padding: f32,
39}
40
41@group(0) @binding(0) var<storage, read> input_points: array<GpuPoint>;
42@group(0) @binding(1) var<storage, read> query_points: array<GpuPoint>;
43@group(0) @binding(2) var<storage, read_write> output_neighbors: array<array<vec2<f32>, MAX_K>>;
44@group(0) @binding(3) var<uniform> params: NearestNeighborParams;
45
46struct NearestNeighborParams {
47    num_points: u32,
48    k_neighbors: u32,
49    max_distance: f32,
50    _padding: u32,
51}
52
53@compute @workgroup_size(64)
54fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
55    let query_idx = global_id.x;
56    if (query_idx >= arrayLength(&query_points)) {
57        return;
58    }
59    
60    let query_point = query_points[query_idx].position;
61    
62    // Initialize neighbors with maximum distance
63    var neighbors: array<vec2<f32>, MAX_K>;
64    for (var i = 0u; i < params.k_neighbors; i++) {
65        neighbors[i] = vec2<f32>(f32(params.num_points), params.max_distance);
66    }
67    
68    // Find k nearest neighbors
69    for (var i = 0u; i < params.num_points; i++) {
70        let diff = input_points[i].position - query_point;
71        let distance = length(diff);
72        
73        if (distance < params.max_distance) {
74            // Insert into sorted neighbors array
75            let neighbor = vec2<f32>(f32(i), distance);
76            
77            // Find insertion point
78            var insert_idx = params.k_neighbors;
79            for (var j = 0u; j < params.k_neighbors; j++) {
80                if (distance < neighbors[j].y) {
81                    insert_idx = j;
82                    break;
83                }
84            }
85            
86            // Shift and insert
87            if (insert_idx < params.k_neighbors) {
88                for (var j = params.k_neighbors - 1u; j > insert_idx; j--) {
89                    neighbors[j] = neighbors[j - 1u];
90                }
91                neighbors[insert_idx] = neighbor;
92            }
93        }
94    }
95    
96    // Write results
97    for (var i = 0u; i < params.k_neighbors; i++) {
98        output_neighbors[query_idx][i] = neighbors[i];
99    }
100}
101"#;
102
103impl GpuContext {
104    /// GPU-accelerated k-nearest neighbor search
105    pub async fn find_k_nearest_neighbors(
106        &self,
107        points: &[Point3f],
108        query_points: &[Point3f],
109        k: usize,
110        max_distance: f32,
111    ) -> Result<Vec<Vec<(usize, f32)>>> {
112        if points.is_empty() || query_points.is_empty() {
113            return Ok(vec![Vec::new(); query_points.len()]);
114        }
115        
116        let k = k.min(32).max(1); // Limit k to reasonable bounds
117        
118        // Convert points to GPU format with proper alignment
119        let gpu_points: Vec<GpuPoint> = points
120            .iter()
121            .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
122            .collect();
123            
124        let gpu_query_points: Vec<GpuPoint> = query_points
125            .iter()
126            .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
127            .collect();
128
129        // Create buffers
130        let points_buffer = self.create_buffer_init(
131            "Points Buffer",
132            &gpu_points,
133            wgpu::BufferUsages::STORAGE,
134        );
135
136        let query_buffer = self.create_buffer_init(
137            "Query Points Buffer",
138            &gpu_query_points,
139            wgpu::BufferUsages::STORAGE,
140        );
141
142        let output_buffer = self.create_buffer(
143            "Output Buffer",
144            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
145            wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146        );
147
148        let params = NearestNeighborParams {
149            num_points: points.len() as u32,
150            k_neighbors: k as u32,
151            max_distance,
152            _padding: 0,
153        };
154
155        let params_buffer = self.create_buffer_init(
156            "Params Buffer",
157            &[params],
158            wgpu::BufferUsages::UNIFORM,
159        );
160
161        // Create shader with MAX_K constant
162        let shader_source = NEAREST_NEIGHBOR_SHADER.replace("MAX_K", &k.to_string());
163        let shader = self.create_shader_module("Nearest Neighbor Shader", &shader_source);
164
165        // Create bind group layout
166        let bind_group_layout = self.create_bind_group_layout(
167            "Nearest Neighbor Layout",
168            &[
169                wgpu::BindGroupLayoutEntry {
170                    binding: 0,
171                    visibility: wgpu::ShaderStages::COMPUTE,
172                    ty: wgpu::BindingType::Buffer {
173                        ty: wgpu::BufferBindingType::Storage { read_only: true },
174                        has_dynamic_offset: false,
175                        min_binding_size: None,
176                    },
177                    count: None,
178                },
179                wgpu::BindGroupLayoutEntry {
180                    binding: 1,
181                    visibility: wgpu::ShaderStages::COMPUTE,
182                    ty: wgpu::BindingType::Buffer {
183                        ty: wgpu::BufferBindingType::Storage { read_only: true },
184                        has_dynamic_offset: false,
185                        min_binding_size: None,
186                    },
187                    count: None,
188                },
189                wgpu::BindGroupLayoutEntry {
190                    binding: 2,
191                    visibility: wgpu::ShaderStages::COMPUTE,
192                    ty: wgpu::BindingType::Buffer {
193                        ty: wgpu::BufferBindingType::Storage { read_only: false },
194                        has_dynamic_offset: false,
195                        min_binding_size: None,
196                    },
197                    count: None,
198                },
199                wgpu::BindGroupLayoutEntry {
200                    binding: 3,
201                    visibility: wgpu::ShaderStages::COMPUTE,
202                    ty: wgpu::BindingType::Buffer {
203                        ty: wgpu::BufferBindingType::Uniform,
204                        has_dynamic_offset: false,
205                        min_binding_size: None,
206                    },
207                    count: None,
208                },
209            ],
210        );
211
212        // Create compute pipeline
213        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
214            label: Some("Nearest Neighbor Pipeline"),
215            layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
216                label: Some("Nearest Neighbor Pipeline Layout"),
217                bind_group_layouts: &[&bind_group_layout],
218                push_constant_ranges: &[],
219            })),
220            module: &shader,
221            entry_point: Some("main"),
222            compilation_options: wgpu::PipelineCompilationOptions::default(),
223            cache: None,
224        });
225
226        // Create bind group
227        let bind_group = self.create_bind_group(
228            "Nearest Neighbor Bind Group",
229            &bind_group_layout,
230            &[
231                wgpu::BindGroupEntry {
232                    binding: 0,
233                    resource: points_buffer.as_entire_binding(),
234                },
235                wgpu::BindGroupEntry {
236                    binding: 1,
237                    resource: query_buffer.as_entire_binding(),
238                },
239                wgpu::BindGroupEntry {
240                    binding: 2,
241                    resource: output_buffer.as_entire_binding(),
242                },
243                wgpu::BindGroupEntry {
244                    binding: 3,
245                    resource: params_buffer.as_entire_binding(),
246                },
247            ],
248        );
249
250        // Execute compute shader
251        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
252            label: Some("Nearest Neighbor Encoder"),
253        });
254
255        {
256            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
257                label: Some("Nearest Neighbor Pass"),
258                timestamp_writes: None,
259            });
260            compute_pass.set_pipeline(&pipeline);
261            compute_pass.set_bind_group(0, &bind_group, &[]);
262            let workgroup_count = (query_points.len() + 63) / 64;
263            compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
264        }
265
266        // Read back results
267        let staging_buffer = self.create_buffer(
268            "Staging Buffer",
269            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
270            wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
271        );
272
273        encoder.copy_buffer_to_buffer(
274            &output_buffer,
275            0,
276            &staging_buffer,
277            0,
278            (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
279        );
280
281        self.queue.submit(std::iter::once(encoder.finish()));
282
283        // Map and read results
284        let buffer_slice = staging_buffer.slice(..);
285        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
286        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
287
288        let _ = self.device.poll(wgpu::PollType::Wait);
289
290        if let Some(Ok(())) = receiver.receive().await {
291            let data = buffer_slice.get_mapped_range();
292            let raw_neighbors: Vec<[f32; 2]> = bytemuck::cast_slice(&data).to_vec();
293            
294            let mut results = Vec::with_capacity(query_points.len());
295            for i in 0..query_points.len() {
296                let mut neighbors = Vec::with_capacity(k);
297                for j in 0..k {
298                    let idx = i * k + j;
299                    if idx < raw_neighbors.len() {
300                        let neighbor = raw_neighbors[idx];
301                        let point_idx = neighbor[0] as usize;
302                        let distance = neighbor[1];
303                        
304                        if point_idx < points.len() && distance < max_distance {
305                            neighbors.push((point_idx, distance));
306                        }
307                    }
308                }
309                results.push(neighbors);
310            }
311            
312            drop(data);
313            staging_buffer.unmap();
314            
315            Ok(results)
316        } else {
317            Err(Error::Gpu("Failed to read GPU results".to_string()))
318        }
319    }
320}
321
322/// GPU-accelerated nearest neighbor search for single query point
323pub async fn gpu_find_k_nearest(
324    gpu_context: &GpuContext,
325    points: &[Point3f],
326    query: &Point3f,
327    k: usize,
328) -> Result<Vec<(usize, f32)>> {
329    let results = gpu_context.find_k_nearest_neighbors(points, &[*query], k, f32::INFINITY).await?;
330    Ok(results.into_iter().next().unwrap_or_default())
331}
332
333/// GPU-accelerated nearest neighbor search for multiple query points
334pub async fn gpu_find_k_nearest_batch(
335    gpu_context: &GpuContext,
336    points: &[Point3f],
337    query_points: &[Point3f],
338    k: usize,
339) -> Result<Vec<Vec<(usize, f32)>>> {
340    gpu_context.find_k_nearest_neighbors(points, query_points, k, f32::INFINITY).await
341}
342
343/// GPU-accelerated radius-based nearest neighbor search
344pub async fn gpu_find_radius_neighbors(
345    gpu_context: &GpuContext,
346    points: &[Point3f],
347    query: &Point3f,
348    radius: f32,
349) -> Result<Vec<(usize, f32)>> {
350    let results = gpu_context.find_k_nearest_neighbors(points, &[*query], 32, radius).await?;
351    Ok(results.into_iter().next().unwrap_or_default())
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuContext;
    use threecrate_core::Point3f;
    use approx::assert_relative_eq;

    /// Try to create a GPU context, returning `None` (after printing a notice) when no
    /// GPU is available, so the calling test can return early instead of failing.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        match GpuContext::new().await {
            Ok(gpu) => Some(gpu),
            Err(_) => {
                println!("⚠️  GPU not available, skipping GPU-dependent test");
                None
            }
        }
    }

    /// Create a simple test point cloud: the origin, the three unit axis points and
    /// (1, 1, 1), at indices 0..=4 in that order.
    fn create_test_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    #[test]
    fn test_gpu_nearest_neighbor_single() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let query = Point3f::new(0.1, 0.1, 0.1);

            let neighbors = gpu_find_k_nearest(&gpu, &points, &query, 3).await.unwrap();

            assert_eq!(neighbors.len(), 3);
            assert_eq!(neighbors[0].0, 0); // Closest should be origin
            assert!(neighbors[0].1 < 0.2); // Distance should be small

            println!("✓ GPU single nearest neighbor test passed");
        });
    }

    #[test]
    fn test_gpu_nearest_neighbor_batch() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let queries = vec![
                Point3f::new(0.1, 0.1, 0.1),
                Point3f::new(0.9, 0.1, 0.1),
            ];

            let results = gpu_find_k_nearest_batch(&gpu, &points, &queries, 2).await.unwrap();

            assert_eq!(results.len(), 2);
            assert_eq!(results[0].len(), 2);
            assert_eq!(results[1].len(), 2);

            // First query should find origin as closest
            assert_eq!(results[0][0].0, 0);

            // Second query should find (1,0,0) as closest
            assert_eq!(results[1][0].0, 1);

            println!("✓ GPU batch nearest neighbor test passed");
        });
    }

    #[test]
    fn test_gpu_radius_neighbors() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let query = Point3f::new(0.0, 0.0, 0.0);
            let radius = 1.5;

            let neighbors = gpu_find_radius_neighbors(&gpu, &points, &query, radius).await.unwrap();

            // Distances from the origin are 0, 1, 1, 1 and sqrt(3) ~= 1.73, so exactly
            // the first four fixture points lie inside the radius.
            assert_eq!(neighbors.len(), 4);
            let mut found: Vec<usize> = neighbors.iter().map(|&(idx, _)| idx).collect();
            found.sort_unstable();
            assert_eq!(found, vec![0, 1, 2, 3]);

            // All distances should be strictly within radius and ordered nearest first
            for (_, distance) in &neighbors {
                assert!(*distance < radius);
            }
            assert!(neighbors.windows(2).all(|w| w[0].1 <= w[1].1));

            println!("✓ GPU radius neighbors test passed: {} neighbors found", neighbors.len());
        });
    }

    #[test]
    fn test_gpu_nearest_neighbor_accuracy() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            // A query with a unique nearest point: (1, 1, 1) at sqrt(0.14) ~= 0.37, while
            // every other fixture point is farther than 1.0. A query equidistant from
            // all fixture points (such as the cube centre) would only exercise
            // tie-breaking, not nearest-neighbor selection.
            let query = Point3f::new(0.8, 0.9, 0.7);

            let neighbors = gpu_find_k_nearest(&gpu, &points, &query, 1).await.unwrap();

            assert_eq!(neighbors.len(), 1);

            // Manually verify the nearest neighbor
            let mut min_dist = f32::INFINITY;
            let mut min_idx = 0;

            for (i, point) in points.iter().enumerate() {
                let dist = (query - *point).magnitude();
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                }
            }

            // Sanity-check the fixture itself before comparing against the GPU.
            assert_eq!(min_idx, 4);

            assert_eq!(neighbors[0].0, min_idx);
            assert_relative_eq!(neighbors[0].1, min_dist, epsilon = 0.001);

            println!("✓ GPU nearest neighbor accuracy test passed");
        });
    }
}