1use threecrate_core::{Point3f, Result, Error};
4use crate::GpuContext;
5use bytemuck::{Pod, Zeroable};
6
7#[repr(C)]
9#[derive(Copy, Clone, Pod, Zeroable)]
10pub struct NearestNeighborParams {
11 pub num_points: u32,
12 pub k_neighbors: u32,
13 pub max_distance: f32,
14 pub _padding: u32,
15}
16
17#[repr(C)]
19#[derive(Copy, Clone, Pod, Zeroable)]
20#[repr(align(16))]
21pub struct GpuPoint {
22 pub position: [f32; 3],
23 pub _padding: f32,
24}
25
26#[repr(C)]
28#[derive(Copy, Clone, Pod, Zeroable)]
29pub struct NeighborResult {
30 pub index: u32,
31 pub distance: f32,
32 pub _padding: [u32; 2],
33}
34
/// WGSL source of the brute-force k-nearest-neighbor compute shader.
///
/// `MAX_K` is a textual placeholder, not a WGSL constant: `find_k_nearest_neighbors`
/// replaces every occurrence with the effective `k` before the module is compiled.
///
/// Bindings (group 0):
/// * `0` - input points (read-only storage)
/// * `1` - query points (read-only storage)
/// * `2` - output, one `array<vec2<f32>, MAX_K>` row per query point (read-write storage)
/// * `3` - `NearestNeighborParams` uniform
///
/// Each invocation (workgroup size 64) handles the query point at `global_invocation_id.x`:
/// it scans all `num_points` input points and keeps a list of `k_neighbors` entries sorted
/// by ascending distance. Every entry is `vec2(point index as f32, distance)`; slots that
/// were never filled keep the initial value `(num_points, max_distance)`.
const NEAREST_NEIGHBOR_SHADER: &str = r#"
struct GpuPoint {
 position: vec3<f32>,
 _padding: f32,
}

@group(0) @binding(0) var<storage, read> input_points: array<GpuPoint>;
@group(0) @binding(1) var<storage, read> query_points: array<GpuPoint>;
@group(0) @binding(2) var<storage, read_write> output_neighbors: array<array<vec2<f32>, MAX_K>>;
@group(0) @binding(3) var<uniform> params: NearestNeighborParams;

struct NearestNeighborParams {
 num_points: u32,
 k_neighbors: u32,
 max_distance: f32,
 _padding: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 let query_idx = global_id.x;
 if (query_idx >= arrayLength(&query_points)) {
 return;
 }

 let query_point = query_points[query_idx].position;

 // Initialize neighbors with maximum distance
 var neighbors: array<vec2<f32>, MAX_K>;
 for (var i = 0u; i < params.k_neighbors; i++) {
 neighbors[i] = vec2<f32>(f32(params.num_points), params.max_distance);
 }

 // Find k nearest neighbors
 for (var i = 0u; i < params.num_points; i++) {
 let diff = input_points[i].position - query_point;
 let distance = length(diff);

 if (distance < params.max_distance) {
 // Insert into sorted neighbors array
 let neighbor = vec2<f32>(f32(i), distance);

 // Find insertion point
 var insert_idx = params.k_neighbors;
 for (var j = 0u; j < params.k_neighbors; j++) {
 if (distance < neighbors[j].y) {
 insert_idx = j;
 break;
 }
 }

 // Shift and insert
 if (insert_idx < params.k_neighbors) {
 for (var j = params.k_neighbors - 1u; j > insert_idx; j--) {
 neighbors[j] = neighbors[j - 1u];
 }
 neighbors[insert_idx] = neighbor;
 }
 }
 }

 // Write results
 for (var i = 0u; i < params.k_neighbors; i++) {
 output_neighbors[query_idx][i] = neighbors[i];
 }
}
"#;
102
103impl GpuContext {
104 pub async fn find_k_nearest_neighbors(
106 &self,
107 points: &[Point3f],
108 query_points: &[Point3f],
109 k: usize,
110 max_distance: f32,
111 ) -> Result<Vec<Vec<(usize, f32)>>> {
112 if points.is_empty() || query_points.is_empty() {
113 return Ok(vec![Vec::new(); query_points.len()]);
114 }
115
116 let k = k.min(32).max(1); let gpu_points: Vec<GpuPoint> = points
120 .iter()
121 .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
122 .collect();
123
124 let gpu_query_points: Vec<GpuPoint> = query_points
125 .iter()
126 .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
127 .collect();
128
129 let points_buffer = self.create_buffer_init(
131 "Points Buffer",
132 &gpu_points,
133 wgpu::BufferUsages::STORAGE,
134 );
135
136 let query_buffer = self.create_buffer_init(
137 "Query Points Buffer",
138 &gpu_query_points,
139 wgpu::BufferUsages::STORAGE,
140 );
141
142 let output_buffer = self.create_buffer(
143 "Output Buffer",
144 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
145 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146 );
147
148 let params = NearestNeighborParams {
149 num_points: points.len() as u32,
150 k_neighbors: k as u32,
151 max_distance,
152 _padding: 0,
153 };
154
155 let params_buffer = self.create_buffer_init(
156 "Params Buffer",
157 &[params],
158 wgpu::BufferUsages::UNIFORM,
159 );
160
161 let shader_source = NEAREST_NEIGHBOR_SHADER.replace("MAX_K", &k.to_string());
163 let shader = self.create_shader_module("Nearest Neighbor Shader", &shader_source);
164
165 let bind_group_layout = self.create_bind_group_layout(
167 "Nearest Neighbor Layout",
168 &[
169 wgpu::BindGroupLayoutEntry {
170 binding: 0,
171 visibility: wgpu::ShaderStages::COMPUTE,
172 ty: wgpu::BindingType::Buffer {
173 ty: wgpu::BufferBindingType::Storage { read_only: true },
174 has_dynamic_offset: false,
175 min_binding_size: None,
176 },
177 count: None,
178 },
179 wgpu::BindGroupLayoutEntry {
180 binding: 1,
181 visibility: wgpu::ShaderStages::COMPUTE,
182 ty: wgpu::BindingType::Buffer {
183 ty: wgpu::BufferBindingType::Storage { read_only: true },
184 has_dynamic_offset: false,
185 min_binding_size: None,
186 },
187 count: None,
188 },
189 wgpu::BindGroupLayoutEntry {
190 binding: 2,
191 visibility: wgpu::ShaderStages::COMPUTE,
192 ty: wgpu::BindingType::Buffer {
193 ty: wgpu::BufferBindingType::Storage { read_only: false },
194 has_dynamic_offset: false,
195 min_binding_size: None,
196 },
197 count: None,
198 },
199 wgpu::BindGroupLayoutEntry {
200 binding: 3,
201 visibility: wgpu::ShaderStages::COMPUTE,
202 ty: wgpu::BindingType::Buffer {
203 ty: wgpu::BufferBindingType::Uniform,
204 has_dynamic_offset: false,
205 min_binding_size: None,
206 },
207 count: None,
208 },
209 ],
210 );
211
212 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
214 label: Some("Nearest Neighbor Pipeline"),
215 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
216 label: Some("Nearest Neighbor Pipeline Layout"),
217 bind_group_layouts: &[&bind_group_layout],
218 immediate_size: 0,
219 })),
220 module: &shader,
221 entry_point: Some("main"),
222 compilation_options: wgpu::PipelineCompilationOptions::default(),
223 cache: None,
224 });
225
226 let bind_group = self.create_bind_group(
228 "Nearest Neighbor Bind Group",
229 &bind_group_layout,
230 &[
231 wgpu::BindGroupEntry {
232 binding: 0,
233 resource: points_buffer.as_entire_binding(),
234 },
235 wgpu::BindGroupEntry {
236 binding: 1,
237 resource: query_buffer.as_entire_binding(),
238 },
239 wgpu::BindGroupEntry {
240 binding: 2,
241 resource: output_buffer.as_entire_binding(),
242 },
243 wgpu::BindGroupEntry {
244 binding: 3,
245 resource: params_buffer.as_entire_binding(),
246 },
247 ],
248 );
249
250 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
252 label: Some("Nearest Neighbor Encoder"),
253 });
254
255 {
256 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
257 label: Some("Nearest Neighbor Pass"),
258 timestamp_writes: None,
259 });
260 compute_pass.set_pipeline(&pipeline);
261 compute_pass.set_bind_group(0, &bind_group, &[]);
262 let workgroup_count = (query_points.len() + 63) / 64;
263 compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
264 }
265
266 let staging_buffer = self.create_buffer(
268 "Staging Buffer",
269 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
270 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
271 );
272
273 encoder.copy_buffer_to_buffer(
274 &output_buffer,
275 0,
276 &staging_buffer,
277 0,
278 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
279 );
280
281 self.queue.submit(std::iter::once(encoder.finish()));
282
283 let buffer_slice = staging_buffer.slice(..);
285 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
286 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
287
288 self.device.poll(wgpu::PollType::Wait {
289 submission_index: None,
290 timeout: None,
291 });
292
293 if let Some(Ok(())) = receiver.receive().await {
294 let data = buffer_slice.get_mapped_range();
295 let raw_neighbors: Vec<[f32; 2]> = bytemuck::cast_slice(&data).to_vec();
296
297 let mut results = Vec::with_capacity(query_points.len());
298 for i in 0..query_points.len() {
299 let mut neighbors = Vec::with_capacity(k);
300 for j in 0..k {
301 let idx = i * k + j;
302 if idx < raw_neighbors.len() {
303 let neighbor = raw_neighbors[idx];
304 let point_idx = neighbor[0] as usize;
305 let distance = neighbor[1];
306
307 if point_idx < points.len() && distance < max_distance {
308 neighbors.push((point_idx, distance));
309 }
310 }
311 }
312 results.push(neighbors);
313 }
314
315 drop(data);
316 staging_buffer.unmap();
317
318 Ok(results)
319 } else {
320 Err(Error::Gpu("Failed to read GPU results".to_string()))
321 }
322 }
323}
324
325pub async fn gpu_find_k_nearest(
327 gpu_context: &GpuContext,
328 points: &[Point3f],
329 query: &Point3f,
330 k: usize,
331) -> Result<Vec<(usize, f32)>> {
332 let results = gpu_context.find_k_nearest_neighbors(points, &[*query], k, f32::INFINITY).await?;
333 Ok(results.into_iter().next().unwrap_or_default())
334}
335
336pub async fn gpu_find_k_nearest_batch(
338 gpu_context: &GpuContext,
339 points: &[Point3f],
340 query_points: &[Point3f],
341 k: usize,
342) -> Result<Vec<Vec<(usize, f32)>>> {
343 gpu_context.find_k_nearest_neighbors(points, query_points, k, f32::INFINITY).await
344}
345
346pub async fn gpu_find_radius_neighbors(
348 gpu_context: &GpuContext,
349 points: &[Point3f],
350 query: &Point3f,
351 radius: f32,
352) -> Result<Vec<(usize, f32)>> {
353 let results = gpu_context.find_k_nearest_neighbors(points, &[*query], 32, radius).await?;
354 Ok(results.into_iter().next().unwrap_or_default())
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuContext;
    use threecrate_core::Point3f;
    use approx::assert_relative_eq;

    /// Creates a GPU context, or prints a notice and returns `None` when creation fails.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        let context = GpuContext::new().await;
        if context.is_err() {
            println!("⚠️ GPU not available, skipping GPU-dependent test");
        }
        context.ok()
    }

    /// Five fixed points: the origin, the three unit-axis points and `(1, 1, 1)`.
    fn create_test_points() -> Vec<Point3f> {
        const COORDS: [[f32; 3]; 5] = [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ];
        COORDS
            .iter()
            .map(|&[x, y, z]| Point3f::new(x, y, z))
            .collect()
    }

    /// A query near the origin must report the origin as its closest of three neighbors.
    #[test]
    fn test_gpu_nearest_neighbor_single() {
        pollster::block_on(async {
            let gpu = match try_create_gpu_context().await {
                Some(context) => context,
                None => return,
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.1, 0.1, 0.1);

            let found = gpu_find_k_nearest(&gpu, &cloud, &probe, 3).await.unwrap();

            assert_eq!(found.len(), 3);
            let (nearest_idx, nearest_dist) = found[0];
            assert_eq!(nearest_idx, 0);
            assert!(nearest_dist < 0.2);
            println!("✓ GPU single nearest neighbor test passed");
        });
    }

    /// Two queries in one batch must each get two neighbors with the expected closest point.
    #[test]
    fn test_gpu_nearest_neighbor_batch() {
        pollster::block_on(async {
            let gpu = match try_create_gpu_context().await {
                Some(context) => context,
                None => return,
            };

            let cloud = create_test_points();
            let probes = vec![
                Point3f::new(0.1, 0.1, 0.1),
                Point3f::new(0.9, 0.1, 0.1),
            ];

            let per_query = gpu_find_k_nearest_batch(&gpu, &cloud, &probes, 2)
                .await
                .unwrap();

            assert_eq!(per_query.len(), 2);
            for found in &per_query {
                assert_eq!(found.len(), 2);
            }

            // First query sits next to the origin, second next to (1, 0, 0).
            assert_eq!(per_query[0][0].0, 0);
            assert_eq!(per_query[1][0].0, 1);

            println!("✓ GPU batch nearest neighbor test passed");
        });
    }

    /// A radius search must return at least one neighbor and none beyond the radius.
    #[test]
    fn test_gpu_radius_neighbors() {
        pollster::block_on(async {
            let gpu = match try_create_gpu_context().await {
                Some(context) => context,
                None => return,
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.0, 0.0, 0.0);
            let radius = 1.5;

            let neighbors = gpu_find_radius_neighbors(&gpu, &cloud, &probe, radius)
                .await
                .unwrap();

            assert!(!neighbors.is_empty());
            assert!(neighbors.iter().all(|&(_, distance)| distance <= radius));

            println!("✓ GPU radius neighbors test passed: {} neighbors found", neighbors.len());
        });
    }

    /// The single nearest neighbor must agree with a brute-force CPU search.
    #[test]
    fn test_gpu_nearest_neighbor_accuracy() {
        pollster::block_on(async {
            let gpu = match try_create_gpu_context().await {
                Some(context) => context,
                None => return,
            };

            let points = create_test_points();
            let query = Point3f::new(0.5, 0.5, 0.5);

            let neighbors = gpu_find_k_nearest(&gpu, &points, &query, 1).await.unwrap();

            assert_eq!(neighbors.len(), 1);

            // CPU reference: the first point with the strictly smallest distance wins.
            let (min_idx, min_dist) = points.iter().enumerate().fold(
                (0usize, f32::INFINITY),
                |best, (i, point)| {
                    let dist = (query - *point).magnitude();
                    if dist < best.1 {
                        (i, dist)
                    } else {
                        best
                    }
                },
            );

            assert_eq!(neighbors[0].0, min_idx);
            assert_relative_eq!(neighbors[0].1, min_dist, epsilon = 0.001);

            println!("✓ GPU nearest neighbor accuracy test passed");
        });
    }
}