1use threecrate_core::{Point3f, Result, Error};
4use crate::GpuContext;
5use bytemuck::{Pod, Zeroable};
6
/// Parameters handed to the nearest-neighbour compute shader as its uniform block.
///
/// Field order and types match the `NearestNeighborParams` struct declared in the
/// WGSL source; the four 4-byte fields add up to 16 bytes.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct NearestNeighborParams {
    /// Number of input points the shader scans for every query.
    pub num_points: u32,
    /// Number of neighbour slots the shader fills per query.
    pub k_neighbors: u32,
    /// Search cut-off: the shader only keeps points strictly closer than this.
    pub max_distance: f32,
    /// Unused filler; the host writes 0 here.
    pub _padding: u32,
}
16
17#[repr(C)]
19#[derive(Copy, Clone, Pod, Zeroable)]
20#[repr(align(16))]
21pub struct GpuPoint {
22 pub position: [f32; 3],
23 pub _padding: f32,
24}
25
/// Index/distance record padded to 16 bytes.
///
/// NOTE(review): nothing else in the visible part of this file references this
/// type; the shader emits `vec2<f32>` pairs instead. Confirm whether it is still
/// needed by other modules.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
pub struct NeighborResult {
    /// Presumably the index of the neighbouring point; verify against users.
    pub index: u32,
    /// Presumably the distance to that neighbour; verify against users.
    pub distance: f32,
    /// Two filler words bringing the record to 16 bytes.
    pub _padding: [u32; 2],
}
34
/// WGSL source of the brute-force k-nearest-neighbour compute kernel.
///
/// The text is a template: every `MAX_K` is textually replaced with the neighbour
/// count before the module is compiled (see `find_k_nearest_neighbors`), which fixes
/// the length of the per-query neighbour arrays.
///
/// Bindings (group 0): 0 = searched points and 1 = query points (read-only storage),
/// 2 = output (read-write storage), 3 = `NearestNeighborParams` uniform.
///
/// Each invocation (workgroup size 64) handles the query at
/// `global_invocation_id.x`, scans the first `params.num_points` input points and
/// keeps the `params.k_neighbors` closest ones that are strictly nearer than
/// `params.max_distance`, sorted by ascending distance. Every slot is written as
/// `vec2<f32>(point index, distance)`; slots that were never filled keep the initial
/// value `(num_points, max_distance)`.
///
/// NOTE(review): the point index is carried as `f32`, so it is only exact for
/// indices up to 2^24.
const NEAREST_NEIGHBOR_SHADER: &str = r#"
struct GpuPoint {
 position: vec3<f32>,
 _padding: f32,
}

@group(0) @binding(0) var<storage, read> input_points: array<GpuPoint>;
@group(0) @binding(1) var<storage, read> query_points: array<GpuPoint>;
@group(0) @binding(2) var<storage, read_write> output_neighbors: array<array<vec2<f32>, MAX_K>>;
@group(0) @binding(3) var<uniform> params: NearestNeighborParams;

struct NearestNeighborParams {
 num_points: u32,
 k_neighbors: u32,
 max_distance: f32,
 _padding: u32,
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
 let query_idx = global_id.x;
 if (query_idx >= arrayLength(&query_points)) {
 return;
 }

 let query_point = query_points[query_idx].position;

 // Initialize neighbors with maximum distance
 var neighbors: array<vec2<f32>, MAX_K>;
 for (var i = 0u; i < params.k_neighbors; i++) {
 neighbors[i] = vec2<f32>(f32(params.num_points), params.max_distance);
 }

 // Find k nearest neighbors
 for (var i = 0u; i < params.num_points; i++) {
 let diff = input_points[i].position - query_point;
 let distance = length(diff);

 if (distance < params.max_distance) {
 // Insert into sorted neighbors array
 let neighbor = vec2<f32>(f32(i), distance);

 // Find insertion point
 var insert_idx = params.k_neighbors;
 for (var j = 0u; j < params.k_neighbors; j++) {
 if (distance < neighbors[j].y) {
 insert_idx = j;
 break;
 }
 }

 // Shift and insert
 if (insert_idx < params.k_neighbors) {
 for (var j = params.k_neighbors - 1u; j > insert_idx; j--) {
 neighbors[j] = neighbors[j - 1u];
 }
 neighbors[insert_idx] = neighbor;
 }
 }
 }

 // Write results
 for (var i = 0u; i < params.k_neighbors; i++) {
 output_neighbors[query_idx][i] = neighbors[i];
 }
}
"#;
102
103impl GpuContext {
104 pub async fn find_k_nearest_neighbors(
106 &self,
107 points: &[Point3f],
108 query_points: &[Point3f],
109 k: usize,
110 max_distance: f32,
111 ) -> Result<Vec<Vec<(usize, f32)>>> {
112 if points.is_empty() || query_points.is_empty() {
113 return Ok(vec![Vec::new(); query_points.len()]);
114 }
115
116 let k = k.min(32).max(1); let gpu_points: Vec<GpuPoint> = points
120 .iter()
121 .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
122 .collect();
123
124 let gpu_query_points: Vec<GpuPoint> = query_points
125 .iter()
126 .map(|p| GpuPoint { position: [p.x, p.y, p.z], _padding: 0.0 })
127 .collect();
128
129 let points_buffer = self.create_buffer_init(
131 "Points Buffer",
132 &gpu_points,
133 wgpu::BufferUsages::STORAGE,
134 );
135
136 let query_buffer = self.create_buffer_init(
137 "Query Points Buffer",
138 &gpu_query_points,
139 wgpu::BufferUsages::STORAGE,
140 );
141
142 let output_buffer = self.create_buffer(
143 "Output Buffer",
144 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
145 wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
146 );
147
148 let params = NearestNeighborParams {
149 num_points: points.len() as u32,
150 k_neighbors: k as u32,
151 max_distance,
152 _padding: 0,
153 };
154
155 let params_buffer = self.create_buffer_init(
156 "Params Buffer",
157 &[params],
158 wgpu::BufferUsages::UNIFORM,
159 );
160
161 let shader_source = NEAREST_NEIGHBOR_SHADER.replace("MAX_K", &k.to_string());
163 let shader = self.create_shader_module("Nearest Neighbor Shader", &shader_source);
164
165 let bind_group_layout = self.create_bind_group_layout(
167 "Nearest Neighbor Layout",
168 &[
169 wgpu::BindGroupLayoutEntry {
170 binding: 0,
171 visibility: wgpu::ShaderStages::COMPUTE,
172 ty: wgpu::BindingType::Buffer {
173 ty: wgpu::BufferBindingType::Storage { read_only: true },
174 has_dynamic_offset: false,
175 min_binding_size: None,
176 },
177 count: None,
178 },
179 wgpu::BindGroupLayoutEntry {
180 binding: 1,
181 visibility: wgpu::ShaderStages::COMPUTE,
182 ty: wgpu::BindingType::Buffer {
183 ty: wgpu::BufferBindingType::Storage { read_only: true },
184 has_dynamic_offset: false,
185 min_binding_size: None,
186 },
187 count: None,
188 },
189 wgpu::BindGroupLayoutEntry {
190 binding: 2,
191 visibility: wgpu::ShaderStages::COMPUTE,
192 ty: wgpu::BindingType::Buffer {
193 ty: wgpu::BufferBindingType::Storage { read_only: false },
194 has_dynamic_offset: false,
195 min_binding_size: None,
196 },
197 count: None,
198 },
199 wgpu::BindGroupLayoutEntry {
200 binding: 3,
201 visibility: wgpu::ShaderStages::COMPUTE,
202 ty: wgpu::BindingType::Buffer {
203 ty: wgpu::BufferBindingType::Uniform,
204 has_dynamic_offset: false,
205 min_binding_size: None,
206 },
207 count: None,
208 },
209 ],
210 );
211
212 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
214 label: Some("Nearest Neighbor Pipeline"),
215 layout: Some(&self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
216 label: Some("Nearest Neighbor Pipeline Layout"),
217 bind_group_layouts: &[&bind_group_layout],
218 push_constant_ranges: &[],
219 })),
220 module: &shader,
221 entry_point: "main",
222 compilation_options: wgpu::PipelineCompilationOptions::default(),
223 });
224
225 let bind_group = self.create_bind_group(
227 "Nearest Neighbor Bind Group",
228 &bind_group_layout,
229 &[
230 wgpu::BindGroupEntry {
231 binding: 0,
232 resource: points_buffer.as_entire_binding(),
233 },
234 wgpu::BindGroupEntry {
235 binding: 1,
236 resource: query_buffer.as_entire_binding(),
237 },
238 wgpu::BindGroupEntry {
239 binding: 2,
240 resource: output_buffer.as_entire_binding(),
241 },
242 wgpu::BindGroupEntry {
243 binding: 3,
244 resource: params_buffer.as_entire_binding(),
245 },
246 ],
247 );
248
249 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
251 label: Some("Nearest Neighbor Encoder"),
252 });
253
254 {
255 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
256 label: Some("Nearest Neighbor Pass"),
257 timestamp_writes: None,
258 });
259 compute_pass.set_pipeline(&pipeline);
260 compute_pass.set_bind_group(0, &bind_group, &[]);
261 let workgroup_count = (query_points.len() + 63) / 64;
262 compute_pass.dispatch_workgroups(workgroup_count as u32, 1, 1);
263 }
264
265 let staging_buffer = self.create_buffer(
267 "Staging Buffer",
268 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
269 wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
270 );
271
272 encoder.copy_buffer_to_buffer(
273 &output_buffer,
274 0,
275 &staging_buffer,
276 0,
277 (query_points.len() * k * std::mem::size_of::<[f32; 2]>()) as u64,
278 );
279
280 self.queue.submit(std::iter::once(encoder.finish()));
281
282 let buffer_slice = staging_buffer.slice(..);
284 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
285 buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
286
287 self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
288
289 if let Some(Ok(())) = receiver.receive().await {
290 let data = buffer_slice.get_mapped_range();
291 let raw_neighbors: Vec<[f32; 2]> = bytemuck::cast_slice(&data).to_vec();
292
293 let mut results = Vec::with_capacity(query_points.len());
294 for i in 0..query_points.len() {
295 let mut neighbors = Vec::with_capacity(k);
296 for j in 0..k {
297 let idx = i * k + j;
298 if idx < raw_neighbors.len() {
299 let neighbor = raw_neighbors[idx];
300 let point_idx = neighbor[0] as usize;
301 let distance = neighbor[1];
302
303 if point_idx < points.len() && distance < max_distance {
304 neighbors.push((point_idx, distance));
305 }
306 }
307 }
308 results.push(neighbors);
309 }
310
311 drop(data);
312 staging_buffer.unmap();
313
314 Ok(results)
315 } else {
316 Err(Error::Gpu("Failed to read GPU results".to_string()))
317 }
318 }
319}
320
321pub async fn gpu_find_k_nearest(
323 gpu_context: &GpuContext,
324 points: &[Point3f],
325 query: &Point3f,
326 k: usize,
327) -> Result<Vec<(usize, f32)>> {
328 let results = gpu_context.find_k_nearest_neighbors(points, &[*query], k, f32::INFINITY).await?;
329 Ok(results.into_iter().next().unwrap_or_default())
330}
331
332pub async fn gpu_find_k_nearest_batch(
334 gpu_context: &GpuContext,
335 points: &[Point3f],
336 query_points: &[Point3f],
337 k: usize,
338) -> Result<Vec<Vec<(usize, f32)>>> {
339 gpu_context.find_k_nearest_neighbors(points, query_points, k, f32::INFINITY).await
340}
341
342pub async fn gpu_find_radius_neighbors(
344 gpu_context: &GpuContext,
345 points: &[Point3f],
346 query: &Point3f,
347 radius: f32,
348) -> Result<Vec<(usize, f32)>> {
349 let results = gpu_context.find_k_nearest_neighbors(points, &[*query], 32, radius).await?;
350 Ok(results.into_iter().next().unwrap_or_default())
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::GpuContext;
    use threecrate_core::Point3f;
    use approx::assert_relative_eq;

    /// Tries to bring up a GPU context. On failure a notice is printed and `None`
    /// is returned so the calling test can return early instead of failing.
    async fn try_create_gpu_context() -> Option<GpuContext> {
        let context = GpuContext::new().await.ok();
        if context.is_none() {
            println!("⚠️ GPU not available, skipping GPU-dependent test");
        }
        context
    }

    /// Fixture: the origin, the three unit-axis points and (1, 1, 1).
    fn create_test_points() -> Vec<Point3f> {
        const COORDS: [[f32; 3]; 5] = [
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 1.0],
        ];
        COORDS
            .iter()
            .map(|&[x, y, z]| Point3f::new(x, y, z))
            .collect()
    }

    /// A query next to the origin must return three neighbours, origin first.
    #[test]
    fn test_gpu_nearest_neighbor_single() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.1, 0.1, 0.1);

            let found = gpu_find_k_nearest(&gpu, &cloud, &probe, 3).await.unwrap();

            assert_eq!(found.len(), 3);
            let (closest_idx, closest_dist) = found[0];
            assert_eq!(closest_idx, 0);
            assert!(closest_dist < 0.2);
            println!("✓ GPU single nearest neighbor test passed");
        });
    }

    /// Two queries in one batch must each get two neighbours with the expected
    /// closest point.
    #[test]
    fn test_gpu_nearest_neighbor_batch() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probes = vec![
                Point3f::new(0.1, 0.1, 0.1),
                Point3f::new(0.9, 0.1, 0.1),
            ];

            let per_query = gpu_find_k_nearest_batch(&gpu, &cloud, &probes, 2)
                .await
                .unwrap();

            assert_eq!(per_query.len(), 2);
            for found in &per_query {
                assert_eq!(found.len(), 2);
            }

            // First probe sits next to the origin, second next to (1, 0, 0).
            assert_eq!(per_query[0][0].0, 0);
            assert_eq!(per_query[1][0].0, 1);

            println!("✓ GPU batch nearest neighbor test passed");
        });
    }

    /// A radius query around the origin must return something, all within range.
    #[test]
    fn test_gpu_radius_neighbors() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.0, 0.0, 0.0);
            let radius = 1.5;

            let found = gpu_find_radius_neighbors(&gpu, &cloud, &probe, radius)
                .await
                .unwrap();

            assert!(!found.is_empty());
            assert!(found.iter().all(|&(_, distance)| distance <= radius));

            println!("✓ GPU radius neighbors test passed: {} neighbors found", found.len());
        });
    }

    /// The GPU answer for k = 1 must agree with a CPU brute-force scan.
    #[test]
    fn test_gpu_nearest_neighbor_accuracy() {
        pollster::block_on(async {
            let Some(gpu) = try_create_gpu_context().await else {
                return;
            };

            let cloud = create_test_points();
            let probe = Point3f::new(0.5, 0.5, 0.5);

            let found = gpu_find_k_nearest(&gpu, &cloud, &probe, 1).await.unwrap();

            assert_eq!(found.len(), 1);

            // CPU reference: the strict `<` keeps the first point among equals.
            let (expected_idx, expected_dist) = cloud.iter().enumerate().fold(
                (0usize, f32::INFINITY),
                |(best_idx, best_dist), (idx, point)| {
                    let dist = (probe - *point).magnitude();
                    if dist < best_dist {
                        (idx, dist)
                    } else {
                        (best_idx, best_dist)
                    }
                },
            );

            assert_eq!(found[0].0, expected_idx);
            assert_relative_eq!(found[0].1, expected_dist, epsilon = 0.001);

            println!("✓ GPU nearest neighbor accuracy test passed");
        });
    }
}