1use threecrate_core::{Point3f, Result, Error};
4use crate::GpuContext;
5use bytemuck::{Pod, Zeroable};
6
7#[repr(C)]
9#[derive(Copy, Clone, Pod, Zeroable)]
10pub struct NearestNeighborParams {
11 pub num_points: u32,
12 pub k_neighbors: u32,
13 pub max_distance: f32,
14 pub _padding: u32,
15}
16
17#[repr(C)]
19#[derive(Copy, Clone, Pod, Zeroable)]
20#[repr(align(16))]
21pub struct GpuPoint {
22 pub position: [f32; 3],
23 pub _padding: f32,
24}
25
26#[repr(C)]
28#[derive(Copy, Clone, Pod, Zeroable)]
29pub struct NeighborResult {
30 pub index: u32,
31 pub distance: f32,
32 pub _padding: [u32; 2],
33}
34
/// WGSL compute shader for brute-force k-nearest-neighbour search.
///
/// Each invocation handles one query point (`@workgroup_size(64)`, indexed by
/// `global_invocation_id.x`, with out-of-range invocations returning early). It:
///
/// 1. Scans all `params.num_points` input points.
/// 2. Keeps the `params.k_neighbors` closest ones with `distance < params.max_distance` in a
///    local array sorted by ascending distance, maintained by insertion.
/// 3. Writes that array to `output_neighbors[query_idx]`.
///
/// Each output slot is a `vec2<f32>` of `(point index as f32, distance)`. Slots that were never
/// filled keep the sentinel `(f32(num_points), max_distance)`, which the host must filter out.
/// Because indices are stored as `f32`, they are only exact up to 2^24.
///
/// `MAX_K` is a textual placeholder, not valid WGSL. It must be replaced with a concrete array
/// length (which should be at least `k_neighbors`) before the module is compiled.
///
/// Bindings (group 0):
/// - 0: input points (read)
/// - 1: query points (read)
/// - 2: output neighbours (read_write)
/// - 3: `NearestNeighborParams` uniform
const NEAREST_NEIGHBOR_SHADER: &str = r#"
struct GpuPoint {
    position: vec3<f32>,
    _padding: f32,
}

@group(0) @binding(0) var<storage, read> input_points: array<GpuPoint>;
@group(0) @binding(1) var<storage, read> query_points: array<GpuPoint>;
@group(0) @binding(2) var<storage, read_write> output_neighbors: array<array<vec2<f32>, MAX_K>>;
@group(0) @binding(3) var<uniform> params: NearestNeighborParams;

struct NearestNeighborParams {
    num_points: u32,
    k_neighbors: u32,
    max_distance: f32,
    _padding: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let query_idx = global_id.x;
    if (query_idx >= arrayLength(&query_points)) {
        return;
    }

    let query_point = query_points[query_idx].position;

    // Initialize neighbors with maximum distance
    var neighbors: array<vec2<f32>, MAX_K>;
    for (var i = 0u; i < params.k_neighbors; i++) {
        neighbors[i] = vec2<f32>(f32(params.num_points), params.max_distance);
    }

    // Find k nearest neighbors
    for (var i = 0u; i < params.num_points; i++) {
        let diff = input_points[i].position - query_point;
        let distance = length(diff);

        if (distance < params.max_distance) {
            // Insert into sorted neighbors array
            let neighbor = vec2<f32>(f32(i), distance);

            // Find insertion point
            var insert_idx = params.k_neighbors;
            for (var j = 0u; j < params.k_neighbors; j++) {
                if (distance < neighbors[j].y) {
                    insert_idx = j;
                    break;
                }
            }

            // Shift and insert
            if (insert_idx < params.k_neighbors) {
                for (var j = params.k_neighbors - 1u; j > insert_idx; j--) {
                    neighbors[j] = neighbors[j - 1u];
                }
                neighbors[insert_idx] = neighbor;
            }
        }
    }

    // Write results
    for (var i = 0u; i < params.k_neighbors; i++) {
        output_neighbors[query_idx][i] = neighbors[i];
    }
}
"#;
102
103impl GpuContext {
104 pub async fn find_k_nearest_neighbors(
106 &self,
107 points: &[Point3f],
108 query_points: &[Point3f],
109 k: usize,
110 max_distance: f32,
111 ) -> Result<Vec<Vec<(usize, f32)>>> {
112 if points.is_empty() || query_points.is_empty() {
113 return Ok(vec![Vec::new(); query_points.len()]);
114 }
115
116 let k = k.min(32).max(1); let gpu_points: Vec<GpuPoint> = points
120 .iter()
121 .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
122 .collect();
123
124 let gpu_query_points: Vec<GpuPoint> = query_points
125 .iter()
126 .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
127 .collect();
128
129 let points_buffer = self.create_buffer_init(
131 "Points Buffer",
132 &gpu_points,
133 wgpu::BufferUsages::STORAGE,
134 );
135
136 let query_buffer = self.create_buffer_init(
137 "Query Points Buffer",
138 &gpu_query_points,
139 wgpu::BufferUsages::STORAGE,
140 );
141
142 let output_buffer = self.create_buffer(
143 "Output Buffer",
144 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
145 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146 );
147
148 let params = NearestNeighborParams {
149 num_points: points.len() as u32,
150 k_neighbors: k as u32,
151 max_distance,
152 _padding: 0,
153 };
154
155 let params_buffer = self.create_buffer_init(
156 "Params Buffer",
157 &[params],
158 wgpu::BufferUsages::UNIFORM,
159 );
160
161 let shader_source = NEAREST_NEIGHBOR_SHADER.replace("MAX_K", &k.to_string());
163 let shader = self.create_shader_module("Nearest Neighbor Shader", &shader_source);
164
165 let bind_group_layout = self.create_bind_group_layout(
167 "Nearest Neighbor Layout",
168 &[
169 wgpu::BindGroupLayoutEntry {
170 binding: 0,
171 visibility: wgpu::ShaderStages::COMPUTE,
172 ty: wgpu::BindingType::Buffer {
173 ty: wgpu::BufferBindingType::Storage { read_only: true },
174 has_dynamic_offset: false,
175 min_binding_size: None,
176 },
177 count: None,
178 },
179 wgpu::BindGroupLayoutEntry {
180 binding: 1,
181 visibility: wgpu::ShaderStages::COMPUTE,
182 ty: wgpu::BindingType::Buffer {
183 ty: wgpu::BufferBindingType::Storage { read_only: true },
184 has_dynamic_offset: false,
185 min_binding_size: None,
186 },
187 count: None,
188 },
189 wgpu::BindGroupLayoutEntry {
190 binding: 2,
191 visibility: wgpu::ShaderStages::COMPUTE,
192 ty: wgpu::BindingType::Buffer {
193 ty: wgpu::BufferBindingType::Storage { read_only: false },
194 has_dynamic_offset: false,
195 min_binding_size: None,
196 },
197 count: None,
198 },
199 wgpu::BindGroupLayoutEntry {
200 binding: 3,
201 visibility: wgpu::ShaderStages::COMPUTE,
202 ty: wgpu::BindingType::Buffer {
203 ty: wgpu::BufferBindingType::Uniform,
204 has_dynamic_offset: false,
205 min_binding_size: None,
206 },
207 count: None,
208 },
209 ],
210 );
211
212 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
214 label: Some("Nearest Neighbor Pipeline"),
215 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
216 label: Some("Nearest Neighbor Pipeline Layout"),
217 bind_group_layouts: &[&bind_group_layout],
218 push_constant_ranges: &[],
219 })),
220 module: &shader,
221 entry_point: Some("main"),
222 compilation_options: wgpu::PipelineCompilationOptions::default(),
223 cache: None,
224 });
225
226 let bind_group = self.create_bind_group(
228 "Nearest Neighbor Bind Group",
229 &bind_group_layout,
230 &[
231 wgpu::BindGroupEntry {
232 binding: 0,
233 resource: points_buffer.as_entire_binding(),
234 },
235 wgpu::BindGroupEntry {
236 binding: 1,
237 resource: query_buffer.as_entire_binding(),
238 },
239 wgpu::BindGroupEntry {
240 binding: 2,
241 resource: output_buffer.as_entire_binding(),
242 },
243 wgpu::BindGroupEntry {
244 binding: 3,
245 resource: params_buffer.as_entire_binding(),
246 },
247 ],
248 );
249
250 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
252 label: Some("Nearest Neighbor Encoder"),
253 });
254
255 {
256 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
257 label: Some("Nearest Neighbor Pass"),
258 timestamp_writes: None,
259 });
260 compute_pass.set_pipeline(&pipeline);
261 compute_pass.set_bind_group(0, &bind_group, &[]);
262 let workgroup_count = (query_points.len() + 63) / 64;
263 compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
264 }
265
266 let staging_buffer = self.create_buffer(
268 "Staging Buffer",
269 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
270 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
271 );
272
273 encoder.copy_buffer_to_buffer(
274 &output_buffer,
275 0,
276 &staging_buffer,
277 0,
278 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
279 );
280
281 self.queue.submit(std::iter::once(encoder.finish()));
282
283 let buffer_slice = staging_buffer.slice(..);
285 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
286 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
287
288 let _ = self.device.poll(wgpu::PollType::Wait);
289
290 if let Some(Ok(())) = receiver.receive().await {
291 let data = buffer_slice.get_mapped_range();
292 let raw_neighbors: Vec<[f32; 2]> = bytemuck::cast_slice(&data).to_vec();
293
294 let mut results = Vec::with_capacity(query_points.len());
295 for i in 0..query_points.len() {
296 let mut neighbors = Vec::with_capacity(k);
297 for j in 0..k {
298 let idx = i * k + j;
299 if idx < raw_neighbors.len() {
300 let neighbor = raw_neighbors[idx];
301 let point_idx = neighbor[0] as usize;
302 let distance = neighbor[1];
303
304 if point_idx < points.len() && distance < max_distance {
305 neighbors.push((point_idx, distance));
306 }
307 }
308 }
309 results.push(neighbors);
310 }
311
312 drop(data);
313 staging_buffer.unmap();
314
315 Ok(results)
316 } else {
317 Err(Error::Gpu("Failed to read GPU results".to_string()))
318 }
319 }
320}
321
322pub async fn gpu_find_k_nearest(
324 gpu_context: &GpuContext,
325 points: &[Point3f],
326 query: &Point3f,
327 k: usize,
328) -> Result<Vec<(usize, f32)>> {
329 let results = gpu_context.find_k_nearest_neighbors(points, &[*query], k, f32::INFINITY).await?;
330 Ok(results.into_iter().next().unwrap_or_default())
331}
332
333pub async fn gpu_find_k_nearest_batch(
335 gpu_context: &GpuContext,
336 points: &[Point3f],
337 query_points: &[Point3f],
338 k: usize,
339) -> Result<Vec<Vec<(usize, f32)>>> {
340 gpu_context.find_k_nearest_neighbors(points, query_points, k, f32::INFINITY).await
341}
342
343pub async fn gpu_find_radius_neighbors(
345 gpu_context: &GpuContext,
346 points: &[Point3f],
347 query: &Point3f,
348 radius: f32,
349) -> Result<Vec<(usize, f32)>> {
350 let results = gpu_context.find_k_nearest_neighbors(points, &[*query], 32, radius).await?;
351 Ok(results.into_iter().next().unwrap_or_default())
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuContext;
    use threecrate_core::Point3f;
    use approx::assert_relative_eq;

    /// Creates a GPU context, or prints a notice and returns `None` when no GPU is available
    /// so GPU-dependent tests skip instead of failing.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        match GpuContext::new().await {
            Ok(gpu) => Some(gpu),
            Err(_) => {
                println!("⚠️ GPU not available, skipping GPU-dependent test");
                None
            }
        }
    }

    /// Origin, the three unit axis points, and (1, 1, 1).
    fn create_test_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    /// CPU reference: all points sorted by distance to `query`, truncated to `k`.
    fn brute_force_knn(points: &[Point3f], query: &Point3f, k: usize) -> Vec<(usize, f32)> {
        let mut all: Vec<(usize, f32)> = points
            .iter()
            .enumerate()
            .map(|(i, p)| (i, (*query - *p).magnitude()))
            .collect();
        all.sort_by(|a, b| a.1.total_cmp(&b.1));
        all.truncate(k);
        all
    }

    #[test]
    fn test_gpu_nearest_neighbor_single() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let query = Point3f::new(0.1, 0.1, 0.1);

            let neighbors = gpu_find_k_nearest(&gpu, &points, &query, 3).await.unwrap();

            assert_eq!(neighbors.len(), 3);
            assert_eq!(neighbors[0].0, 0);
            assert!(neighbors[0].1 < 0.2);
            println!("✓ GPU single nearest neighbor test passed");
        });
    }

    #[test]
    fn test_gpu_nearest_neighbor_batch() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let queries = vec![
                Point3f::new(0.1, 0.1, 0.1),
                Point3f::new(0.9, 0.1, 0.1),
            ];

            let results = gpu_find_k_nearest_batch(&gpu, &points, &queries, 2).await.unwrap();

            assert_eq!(results.len(), 2);
            assert_eq!(results[0].len(), 2);
            assert_eq!(results[1].len(), 2);

            assert_eq!(results[0][0].0, 0);

            assert_eq!(results[1][0].0, 1);

            println!("✓ GPU batch nearest neighbor test passed");
        });
    }

    #[test]
    fn test_gpu_radius_neighbors() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let query = Point3f::new(0.0, 0.0, 0.0);
            let radius = 1.5;

            let neighbors = gpu_find_radius_neighbors(&gpu, &points, &query, radius).await.unwrap();

            assert!(!neighbors.is_empty());

            for (_, distance) in &neighbors {
                assert!(*distance <= radius);
            }

            println!("✓ GPU radius neighbors test passed: {} neighbors found", neighbors.len());
        });
    }

    #[test]
    fn test_gpu_nearest_neighbor_accuracy() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let query = Point3f::new(0.5, 0.5, 0.5);

            let neighbors = gpu_find_k_nearest(&gpu, &points, &query, 1).await.unwrap();

            assert_eq!(neighbors.len(), 1);

            let mut min_dist = f32::INFINITY;
            let mut min_idx = 0;

            for (i, point) in points.iter().enumerate() {
                let dist = (query - *point).magnitude();
                if dist < min_dist {
                    min_dist = dist;
                    min_idx = i;
                }
            }

            assert_eq!(neighbors[0].0, min_idx);
            assert_relative_eq!(neighbors[0].1, min_dist, epsilon = 0.001);

            println!("✓ GPU nearest neighbor accuracy test passed");
        });
    }

    /// Asking for more neighbours than there are points returns every point, sorted by
    /// distance exactly as a CPU brute-force search would.
    #[test]
    fn test_gpu_k_larger_than_point_count_matches_brute_force() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            // Chosen so that all five distances are well separated (no ties).
            let query = Point3f::new(0.2, 0.3, 0.4);

            let neighbors = gpu_find_k_nearest(&gpu, &points, &query, 10).await.unwrap();
            let expected = brute_force_knn(&points, &query, 10);

            assert_eq!(neighbors.len(), points.len());
            for ((index, distance), (expected_index, expected_distance)) in
                neighbors.iter().zip(&expected)
            {
                assert_eq!(index, expected_index);
                assert_relative_eq!(*distance, *expected_distance, epsilon = 0.001);
            }
        });
    }

    /// Empty inputs short-circuit: no queries yields no lists, no points yields an empty list.
    #[test]
    fn test_gpu_nearest_neighbor_empty_inputs() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let points = create_test_points();
            let no_queries = gpu_find_k_nearest_batch(&gpu, &points, &[], 3).await.unwrap();
            assert!(no_queries.is_empty());

            let query = Point3f::new(0.0, 0.0, 0.0);
            let no_points = gpu_find_k_nearest(&gpu, &[], &query, 3).await.unwrap();
            assert!(no_points.is_empty());
        });
    }
}