1use crate::device::DeviceInfo;
2use crate::error::{GpuError, Result};
3use crate::helpers::{bgl_entry, div_ceil, top_k_cpu, Params4, Params4U, WORKGROUP_SIZE};
4use crate::shaders;
5
6use bytemuck::Pod;
7use wgpu::util::DeviceExt;
8
/// GPU-backed vector search accelerator built on wgpu compute shaders.
///
/// Stored vectors are split into chunks so that each chunk fits within the
/// device's maximum storage-buffer binding size.
pub struct GpuAccelerator {
    device: wgpu::Device,
    queue: wgpu::Queue,
    // Adapter description captured at construction time.
    info: DeviceInfo,
    // Device limit for a single storage-buffer binding, in bytes; drives chunking.
    max_binding_size: u32,
    // Uploaded vector data, one buffer per chunk.
    vectors_bufs: Vec<wgpu::Buffer>,
    // Precomputed per-vector norms, chunked in parallel with `vectors_bufs`.
    norms_bufs: Vec<wgpu::Buffer>,
    // Number of vectors held by each chunk (parallel to `vectors_bufs`).
    chunk_counts: Vec<usize>,
    // Dimensionality of every stored vector.
    dim: usize,
    // Total vectors across all chunks.
    n_vectors: usize,
}
28
29impl GpuAccelerator {
30 pub fn is_available() -> bool {
32 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
33 backends: wgpu::Backends::all(),
34 ..Default::default()
35 });
36 let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
37 power_preference: wgpu::PowerPreference::HighPerformance,
38 compatible_surface: None,
39 force_fallback_adapter: false,
40 }));
41 adapter.is_ok()
42 }
43
    /// Creates an accelerator on the best available high-performance adapter.
    ///
    /// Device limits for buffer size and storage-binding size are raised to
    /// the adapter's maximums so datasets can be chunked as coarsely as the
    /// hardware allows.
    ///
    /// # Errors
    /// `GpuError::NoDevice` if no adapter is found; `GpuError::DeviceRequest`
    /// if the logical device cannot be created.
    pub fn new() -> Result<Self> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let adapter =
            pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            }))
            .map_err(|_| GpuError::NoDevice)?;

        // Snapshot adapter identity and limits before requesting the device.
        let adapter_info = adapter.get_info();
        let adapter_limits = adapter.limits();

        let info = DeviceInfo {
            name: adapter_info.name.clone(),
            backend: format!("{:?}", adapter_info.backend),
            device_type: format!("{:?}", adapter_info.device_type),
            max_buffer_size: adapter_limits.max_buffer_size,
            max_storage_buffer_binding_size: adapter_limits.max_storage_buffer_binding_size,
        };

        // Request the adapter's full buffer/binding limits rather than wgpu's
        // conservative defaults; everything else stays at default.
        let (device, queue) = pollster::block_on(adapter.request_device(
            &wgpu::DeviceDescriptor {
                label: Some("rustyhdf5-gpu"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits {
                    max_storage_buffer_binding_size: adapter_limits
                        .max_storage_buffer_binding_size,
                    max_buffer_size: adapter_limits.max_buffer_size,
                    ..wgpu::Limits::default()
                },
                ..Default::default()
            },
        ))
        .map_err(|e: wgpu::RequestDeviceError| GpuError::DeviceRequest(e.to_string()))?;

        Ok(Self {
            device,
            queue,
            info,
            max_binding_size: adapter_limits.max_storage_buffer_binding_size,
            vectors_bufs: Vec::new(),
            norms_bufs: Vec::new(),
            chunk_counts: Vec::new(),
            dim: 0,
            n_vectors: 0,
        })
    }
100
    /// Adapter name/backend/limits captured when the accelerator was created.
    pub fn device_info(&self) -> &DeviceInfo {
        &self.info
    }
104
    /// Maximum size in bytes of a single storage-buffer binding on this device.
    pub fn max_storage_buffer_binding_size(&self) -> u32 {
        self.max_binding_size
    }
108
109 pub fn upload_vectors(&mut self, vectors: &[f32], dim: usize) -> Result<()> {
112 if vectors.is_empty() || dim == 0 {
113 return Err(GpuError::DimensionMismatch {
114 expected: 1,
115 got: 0,
116 });
117 }
118 let n = vectors.len() / dim;
119 if vectors.len() != n * dim {
120 return Err(GpuError::DimensionMismatch {
121 expected: n * dim,
122 got: vectors.len(),
123 });
124 }
125
126 let max_vecs_per_chunk = self.max_binding_size as usize / (dim * 4);
127 if max_vecs_per_chunk == 0 {
128 return Err(GpuError::OutOfMemory {
129 need_mb: (dim as u64 * 4) / (1024 * 1024),
130 avail_mb: self.max_binding_size as u64 / (1024 * 1024),
131 });
132 }
133
134 let mut bufs = Vec::new();
135 let mut counts = Vec::new();
136 let mut offset = 0;
137 while offset < n {
138 let chunk_n = (n - offset).min(max_vecs_per_chunk);
139 let start = offset * dim;
140 let end = start + chunk_n * dim;
141 bufs.push(self.make_storage_buf("vectors_chunk", &vectors[start..end]));
142 counts.push(chunk_n);
143 offset += chunk_n;
144 }
145
146 self.vectors_bufs = bufs;
147 self.chunk_counts = counts;
148 self.norms_bufs.clear();
149 self.dim = dim;
150 self.n_vectors = n;
151 Ok(())
152 }
153
154 pub fn upload_norms(&mut self, norms: &[f32]) -> Result<()> {
156 if norms.len() != self.n_vectors {
157 return Err(GpuError::DimensionMismatch {
158 expected: self.n_vectors,
159 got: norms.len(),
160 });
161 }
162 let mut bufs = Vec::new();
163 let mut offset = 0;
164 for &chunk_n in &self.chunk_counts {
165 bufs.push(self.make_storage_buf("norms_chunk", &norms[offset..offset + chunk_n]));
166 offset += chunk_n;
167 }
168 self.norms_bufs = bufs;
169 Ok(())
170 }
171
172 pub fn cosine_search(&self, query: &[f32], k: usize) -> Result<Vec<(usize, f32)>> {
175 self.check_ready(query.len())?;
176 if k > self.n_vectors {
177 return Err(GpuError::KExceedsN {
178 k,
179 n: self.n_vectors,
180 });
181 }
182 let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
183 let dim = self.dim as u32;
184 let mut all_results: Vec<(usize, f32)> = Vec::new();
185 let mut offset = 0usize;
186
187 for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
188 let chunk_n = self.chunk_counts[ci];
189 let params = Params4 {
190 a: dim,
191 b: chunk_n as u32,
192 c: query_norm,
193 d: 0,
194 };
195 let scores = self.run_4bind_shader(
196 shaders::COSINE_SIMILARITY,
197 ¶ms,
198 query,
199 vecs_buf,
200 Some(&self.norms_bufs[ci]),
201 chunk_n,
202 )?;
203 let chunk_topk = top_k_cpu(&scores, k.min(chunk_n), true);
204 for (idx, score) in chunk_topk {
205 all_results.push((idx + offset, score));
206 }
207 offset += chunk_n;
208 }
209
210 all_results
211 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
212 all_results.truncate(k);
213 Ok(all_results)
214 }
215
216 pub fn batch_cosine_search(
218 &self,
219 queries: &[Vec<f32>],
220 k: usize,
221 ) -> Result<Vec<Vec<(usize, f32)>>> {
222 queries.iter().map(|q| self.cosine_search(q, k)).collect()
223 }
224
225 pub fn l2_search(&self, query: &[f32], k: usize) -> Result<Vec<(usize, f32)>> {
227 self.check_vectors(query.len())?;
228 if k > self.n_vectors {
229 return Err(GpuError::KExceedsN {
230 k,
231 n: self.n_vectors,
232 });
233 }
234 let dim = self.dim as u32;
235 let mut all_results: Vec<(usize, f32)> = Vec::new();
236 let mut offset = 0usize;
237
238 for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
239 let chunk_n = self.chunk_counts[ci];
240 let params = Params4U {
241 a: dim,
242 b: chunk_n as u32,
243 c: 0,
244 d: 0,
245 };
246 let scores =
247 self.run_3bind_shader(shaders::L2_DISTANCE, ¶ms, query, vecs_buf, chunk_n)?;
248 let chunk_topk = top_k_cpu(&scores, k.min(chunk_n), false);
249 for (idx, dist) in chunk_topk {
250 all_results.push((idx + offset, dist));
251 }
252 offset += chunk_n;
253 }
254
255 all_results
256 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
257 all_results.truncate(k);
258 Ok(all_results)
259 }
260
261 pub fn compute_norms(&self) -> Result<Vec<f32>> {
263 if self.vectors_bufs.is_empty() {
264 return Err(GpuError::NoVectors);
265 }
266 let mut all_norms = Vec::with_capacity(self.n_vectors);
267 for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
268 let norms = self.run_norms_shader(vecs_buf, self.chunk_counts[ci], self.dim)?;
269 all_norms.extend_from_slice(&norms);
270 }
271 Ok(all_norms)
272 }
273
274 pub fn compute_norms_gpu(&self, vectors: &[f32], dim: usize) -> Result<Vec<f32>> {
277 if vectors.is_empty() || dim == 0 {
278 return Err(GpuError::DimensionMismatch {
279 expected: 1,
280 got: 0,
281 });
282 }
283 let n = vectors.len() / dim;
284 if vectors.len() != n * dim {
285 return Err(GpuError::DimensionMismatch {
286 expected: n * dim,
287 got: vectors.len(),
288 });
289 }
290 let max_vecs = self.max_binding_size as usize / (dim * 4);
291 let mut all_norms = Vec::with_capacity(n);
292 let mut offset = 0;
293 while offset < n {
294 let chunk_n = (n - offset).min(max_vecs);
295 let start = offset * dim;
296 let end = start + chunk_n * dim;
297 let vecs_buf = self.make_storage_buf("temp_vectors", &vectors[start..end]);
298 let norms = self.run_norms_shader(&vecs_buf, chunk_n, dim)?;
299 all_norms.extend_from_slice(&norms);
300 offset += chunk_n;
301 }
302 Ok(all_norms)
303 }
304
305 pub fn batch_dot_product(
307 &self,
308 queries_flat: &[f32],
309 num_queries: usize,
310 ) -> Result<Vec<f32>> {
311 if self.vectors_bufs.is_empty() {
312 return Err(GpuError::NoVectors);
313 }
314 if queries_flat.len() != num_queries * self.dim {
315 return Err(GpuError::DimensionMismatch {
316 expected: num_queries * self.dim,
317 got: queries_flat.len(),
318 });
319 }
320 let queries_buf = self.make_storage_buf("queries", queries_flat);
321 let mut output = vec![0.0f32; num_queries * self.n_vectors];
322 let mut col_offset = 0usize;
323
324 for (ci, vecs_buf) in self.vectors_bufs.iter().enumerate() {
325 let chunk_n = self.chunk_counts[ci];
326 let total = (num_queries * chunk_n) as u32;
327 let params = Params4U {
328 a: self.dim as u32,
329 b: chunk_n as u32,
330 c: num_queries as u32,
331 d: 0,
332 };
333 let chunk_scores = self.run_batch_shader(
334 shaders::BATCH_DOT_PRODUCT,
335 ¶ms,
336 &queries_buf,
337 vecs_buf,
338 total,
339 num_queries * chunk_n,
340 )?;
341 for qi in 0..num_queries {
342 let src = qi * chunk_n;
343 let dst = qi * self.n_vectors + col_offset;
344 output[dst..dst + chunk_n].copy_from_slice(&chunk_scores[src..src + chunk_n]);
345 }
346 col_offset += chunk_n;
347 }
348 Ok(output)
349 }
350
351 pub fn distance_matrix(
354 &self,
355 queries: &[f32],
356 vectors: &[f32],
357 dim: usize,
358 ) -> Result<Vec<Vec<f32>>> {
359 if queries.is_empty() || vectors.is_empty() || dim == 0 {
360 return Err(GpuError::DimensionMismatch {
361 expected: 1,
362 got: 0,
363 });
364 }
365 let num_queries = queries.len() / dim;
366 let n = vectors.len() / dim;
367 if queries.len() != num_queries * dim || vectors.len() != n * dim {
368 return Err(GpuError::DimensionMismatch {
369 expected: num_queries * dim,
370 got: queries.len(),
371 });
372 }
373 let queries_buf = self.make_storage_buf("dm_queries", queries);
374
375 let max_vecs_input = self.max_binding_size as usize / (dim * 4);
376 let max_vecs_output = if num_queries > 0 {
377 self.max_binding_size as usize / (num_queries * 4)
378 } else {
379 max_vecs_input
380 };
381 let max_vecs = max_vecs_input.min(max_vecs_output).max(1);
382
383 let mut flat_output = vec![0.0f32; num_queries * n];
384 let mut col_offset = 0usize;
385 let mut vec_offset = 0usize;
386
387 while vec_offset < n {
388 let chunk_n = (n - vec_offset).min(max_vecs);
389 let start = vec_offset * dim;
390 let end = start + chunk_n * dim;
391 let vecs_buf = self.make_storage_buf("dm_vectors", &vectors[start..end]);
392 let params = Params4U {
393 a: dim as u32,
394 b: chunk_n as u32,
395 c: num_queries as u32,
396 d: 0,
397 };
398 let chunk_dists = self.run_distance_matrix_shader(
399 ¶ms,
400 &queries_buf,
401 &vecs_buf,
402 num_queries,
403 chunk_n,
404 )?;
405 for qi in 0..num_queries {
406 let src = qi * chunk_n;
407 let dst = qi * n + col_offset;
408 flat_output[dst..dst + chunk_n]
409 .copy_from_slice(&chunk_dists[src..src + chunk_n]);
410 }
411 col_offset += chunk_n;
412 vec_offset += chunk_n;
413 }
414
415 Ok((0..num_queries)
416 .map(|qi| flat_output[qi * n..(qi + 1) * n].to_vec())
417 .collect())
418 }
419
420 pub fn f16_to_f32_batch(&self, f16_bits: &[u16]) -> Result<Vec<f32>> {
422 let total = f16_bits.len() as u32;
423 let params = Params4U { a: total, b: 0, c: 0, d: 0 };
424 let packed: Vec<u32> = f16_bits
425 .chunks(2)
426 .map(|c| {
427 let lo = c[0] as u32;
428 let hi = if c.len() > 1 { c[1] as u32 } else { 0 };
429 lo | (hi << 16)
430 })
431 .collect();
432
433 let pair_count = f16_bits.len().div_ceil(2);
434 let (_params_buf, _input_buf, output_buf, bgl, bind_group) = self.make_3bind_group(
435 ¶ms,
436 bytemuck::cast_slice(&packed),
437 (f16_bits.len() * 4) as u64,
438 );
439 let module = self.make_module("f16_to_f32", shaders::F16_TO_F32);
440 let pipeline = self.create_pipeline(&module, &bgl);
441 self.dispatch(&pipeline, &bind_group, div_ceil(pair_count as u32, WORKGROUP_SIZE));
442 self.read_buffer::<f32>(&output_buf, f16_bits.len())
443 }
444
445 pub fn f32_to_f16_batch(&self, values: &[f32]) -> Result<Vec<u16>> {
447 let total = values.len() as u32;
448 let params = Params4U { a: total, b: 0, c: 0, d: 0 };
449 let pair_count = values.len().div_ceil(2);
450
451 let (_params_buf, _input_buf, output_buf, bgl, bind_group) = self.make_3bind_group(
452 ¶ms,
453 bytemuck::cast_slice(values),
454 (pair_count * 4) as u64,
455 );
456 let module = self.make_module("f32_to_f16", shaders::F32_TO_F16);
457 let pipeline = self.create_pipeline(&module, &bgl);
458 self.dispatch(&pipeline, &bind_group, div_ceil(pair_count as u32, WORKGROUP_SIZE));
459
460 let packed = self.read_buffer::<u32>(&output_buf, pair_count)?;
461 let mut result = Vec::with_capacity(values.len());
462 for (i, &word) in packed.iter().enumerate() {
463 result.push((word & 0xFFFF) as u16);
464 if i * 2 + 1 < values.len() {
465 result.push((word >> 16) as u16);
466 }
467 }
468 Ok(result)
469 }
470
    /// Total number of vectors currently uploaded.
    pub fn vector_count(&self) -> usize {
        self.n_vectors
    }
474
    /// Dimensionality of the uploaded vectors (0 before any upload).
    pub fn dimension(&self) -> usize {
        self.dim
    }
478
    /// Number of GPU chunks the uploaded vectors were split into.
    pub fn chunk_count(&self) -> usize {
        self.chunk_counts.len()
    }
482
483 fn check_ready(&self, query_dim: usize) -> Result<()> {
486 self.check_vectors(query_dim)?;
487 if self.norms_bufs.is_empty() {
488 return Err(GpuError::NoNorms);
489 }
490 Ok(())
491 }
492
493 fn check_vectors(&self, query_dim: usize) -> Result<()> {
494 if self.vectors_bufs.is_empty() {
495 return Err(GpuError::NoVectors);
496 }
497 if query_dim != self.dim {
498 return Err(GpuError::DimensionMismatch {
499 expected: self.dim,
500 got: query_dim,
501 });
502 }
503 Ok(())
504 }
505
    /// Creates a read-only GPU storage buffer initialized with `data`.
    fn make_storage_buf(&self, label: &str, data: &[f32]) -> wgpu::Buffer {
        self.device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some(label),
                contents: bytemuck::cast_slice(data),
                usage: wgpu::BufferUsages::STORAGE,
            })
    }
514
    /// Compiles a WGSL shader module from source with the given debug label.
    fn make_module(&self, label: &str, src: &str) -> wgpu::ShaderModule {
        self.device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some(label),
                source: wgpu::ShaderSource::Wgsl(src.into()),
            })
    }
522
523 #[allow(clippy::type_complexity)]
525 fn make_3bind_group(
526 &self,
527 params: &Params4U,
528 input_data: &[u8],
529 output_size: u64,
530 ) -> (
531 wgpu::Buffer,
532 wgpu::Buffer,
533 wgpu::Buffer,
534 wgpu::BindGroupLayout,
535 wgpu::BindGroup,
536 ) {
537 let params_buf = self
538 .device
539 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
540 label: Some("params"),
541 contents: bytemuck::bytes_of(params),
542 usage: wgpu::BufferUsages::UNIFORM,
543 });
544 let input_buf = self
545 .device
546 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
547 label: Some("input"),
548 contents: input_data,
549 usage: wgpu::BufferUsages::STORAGE,
550 });
551 let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
552 label: Some("output"),
553 size: output_size,
554 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
555 mapped_at_creation: false,
556 });
557 let bgl = self
558 .device
559 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
560 label: None,
561 entries: &[
562 bgl_entry(0, wgpu::BufferBindingType::Uniform),
563 bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
564 bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
565 ],
566 });
567 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
568 label: None,
569 layout: &bgl,
570 entries: &[
571 wgpu::BindGroupEntry {
572 binding: 0,
573 resource: params_buf.as_entire_binding(),
574 },
575 wgpu::BindGroupEntry {
576 binding: 1,
577 resource: input_buf.as_entire_binding(),
578 },
579 wgpu::BindGroupEntry {
580 binding: 2,
581 resource: output_buf.as_entire_binding(),
582 },
583 ],
584 });
585 (params_buf, input_buf, output_buf, bgl, bind_group)
586 }
587
588 fn run_norms_shader(
589 &self,
590 vecs_buf: &wgpu::Buffer,
591 n: usize,
592 dim: usize,
593 ) -> Result<Vec<f32>> {
594 let params = Params4U {
595 a: dim as u32,
596 b: n as u32,
597 c: 0,
598 d: 0,
599 };
600 let params_buf = self
601 .device
602 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
603 label: Some("params"),
604 contents: bytemuck::bytes_of(¶ms),
605 usage: wgpu::BufferUsages::UNIFORM,
606 });
607 let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
608 label: Some("norms_out"),
609 size: (n * 4) as u64,
610 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
611 mapped_at_creation: false,
612 });
613 let module = self.make_module("batch_norms", shaders::BATCH_NORMS);
614 let bgl = self
615 .device
616 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
617 label: None,
618 entries: &[
619 bgl_entry(0, wgpu::BufferBindingType::Uniform),
620 bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
621 bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: false }),
622 ],
623 });
624 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
625 label: None,
626 layout: &bgl,
627 entries: &[
628 wgpu::BindGroupEntry {
629 binding: 0,
630 resource: params_buf.as_entire_binding(),
631 },
632 wgpu::BindGroupEntry {
633 binding: 1,
634 resource: vecs_buf.as_entire_binding(),
635 },
636 wgpu::BindGroupEntry {
637 binding: 2,
638 resource: output_buf.as_entire_binding(),
639 },
640 ],
641 });
642 let pipeline = self.create_pipeline(&module, &bgl);
643 self.dispatch(&pipeline, &bind_group, div_ceil(n as u32, WORKGROUP_SIZE));
644 self.read_buffer::<f32>(&output_buf, n)
645 }
646
647 fn run_batch_shader(
648 &self,
649 shader_src: &str,
650 params: &Params4U,
651 queries_buf: &wgpu::Buffer,
652 vecs_buf: &wgpu::Buffer,
653 total_threads: u32,
654 output_len: usize,
655 ) -> Result<Vec<f32>> {
656 let params_buf = self
657 .device
658 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
659 label: Some("params"),
660 contents: bytemuck::bytes_of(params),
661 usage: wgpu::BufferUsages::UNIFORM,
662 });
663 let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
664 label: Some("scores"),
665 size: (output_len * 4) as u64,
666 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
667 mapped_at_creation: false,
668 });
669 let module = self
670 .device
671 .create_shader_module(wgpu::ShaderModuleDescriptor {
672 label: None,
673 source: wgpu::ShaderSource::Wgsl(shader_src.into()),
674 });
675 let bgl = self
676 .device
677 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
678 label: None,
679 entries: &[
680 bgl_entry(0, wgpu::BufferBindingType::Uniform),
681 bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
682 bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
683 bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }),
684 ],
685 });
686 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
687 label: None,
688 layout: &bgl,
689 entries: &[
690 wgpu::BindGroupEntry {
691 binding: 0,
692 resource: params_buf.as_entire_binding(),
693 },
694 wgpu::BindGroupEntry {
695 binding: 1,
696 resource: queries_buf.as_entire_binding(),
697 },
698 wgpu::BindGroupEntry {
699 binding: 2,
700 resource: vecs_buf.as_entire_binding(),
701 },
702 wgpu::BindGroupEntry {
703 binding: 3,
704 resource: output_buf.as_entire_binding(),
705 },
706 ],
707 });
708 let pipeline = self.create_pipeline(&module, &bgl);
709 self.dispatch(&pipeline, &bind_group, div_ceil(total_threads, WORKGROUP_SIZE));
710 self.read_buffer::<f32>(&output_buf, output_len)
711 }
712
713 fn run_distance_matrix_shader(
714 &self,
715 params: &Params4U,
716 queries_buf: &wgpu::Buffer,
717 vecs_buf: &wgpu::Buffer,
718 num_queries: usize,
719 chunk_n: usize,
720 ) -> Result<Vec<f32>> {
721 let params_buf = self
722 .device
723 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
724 label: Some("params"),
725 contents: bytemuck::bytes_of(params),
726 usage: wgpu::BufferUsages::UNIFORM,
727 });
728 let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
729 label: Some("distances"),
730 size: (num_queries * chunk_n * 4) as u64,
731 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
732 mapped_at_creation: false,
733 });
734 let module = self.make_module("distance_matrix", shaders::DISTANCE_MATRIX);
735 let bgl = self
736 .device
737 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
738 label: None,
739 entries: &[
740 bgl_entry(0, wgpu::BufferBindingType::Uniform),
741 bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
742 bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
743 bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }),
744 ],
745 });
746 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
747 label: None,
748 layout: &bgl,
749 entries: &[
750 wgpu::BindGroupEntry {
751 binding: 0,
752 resource: params_buf.as_entire_binding(),
753 },
754 wgpu::BindGroupEntry {
755 binding: 1,
756 resource: queries_buf.as_entire_binding(),
757 },
758 wgpu::BindGroupEntry {
759 binding: 2,
760 resource: vecs_buf.as_entire_binding(),
761 },
762 wgpu::BindGroupEntry {
763 binding: 3,
764 resource: output_buf.as_entire_binding(),
765 },
766 ],
767 });
768 let pipeline = self.create_pipeline(&module, &bgl);
769 let wg_x = div_ceil(chunk_n as u32, 16);
770 let wg_y = div_ceil(num_queries as u32, 16);
771 self.dispatch_2d(&pipeline, &bind_group, wg_x, wg_y);
772 self.read_buffer::<f32>(&output_buf, num_queries * chunk_n)
773 }
774
775 fn run_4bind_shader(
776 &self,
777 shader_src: &str,
778 params: &Params4,
779 query: &[f32],
780 vectors_buf: &wgpu::Buffer,
781 extra_buf: Option<&wgpu::Buffer>,
782 output_len: usize,
783 ) -> Result<Vec<f32>> {
784 let params_buf = self
785 .device
786 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
787 label: Some("params"),
788 contents: bytemuck::bytes_of(params),
789 usage: wgpu::BufferUsages::UNIFORM,
790 });
791 let query_buf = self.make_storage_buf("query", query);
792 let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
793 label: Some("scores"),
794 size: (output_len * 4) as u64,
795 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
796 mapped_at_creation: false,
797 });
798 let module = self
799 .device
800 .create_shader_module(wgpu::ShaderModuleDescriptor {
801 label: None,
802 source: wgpu::ShaderSource::Wgsl(shader_src.into()),
803 });
804
805 let mut entries_desc = vec![
806 bgl_entry(0, wgpu::BufferBindingType::Uniform),
807 bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
808 bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
809 ];
810 let mut bind_entries = vec![
811 wgpu::BindGroupEntry {
812 binding: 0,
813 resource: params_buf.as_entire_binding(),
814 },
815 wgpu::BindGroupEntry {
816 binding: 1,
817 resource: query_buf.as_entire_binding(),
818 },
819 wgpu::BindGroupEntry {
820 binding: 2,
821 resource: vectors_buf.as_entire_binding(),
822 },
823 ];
824
825 if let Some(eb) = extra_buf {
826 entries_desc.push(bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: true }));
827 entries_desc.push(bgl_entry(4, wgpu::BufferBindingType::Storage { read_only: false }));
828 bind_entries.push(wgpu::BindGroupEntry {
829 binding: 3,
830 resource: eb.as_entire_binding(),
831 });
832 bind_entries.push(wgpu::BindGroupEntry {
833 binding: 4,
834 resource: output_buf.as_entire_binding(),
835 });
836 } else {
837 entries_desc.push(bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }));
838 bind_entries.push(wgpu::BindGroupEntry {
839 binding: 3,
840 resource: output_buf.as_entire_binding(),
841 });
842 }
843
844 let bgl = self
845 .device
846 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
847 label: None,
848 entries: &entries_desc,
849 });
850 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
851 label: None,
852 layout: &bgl,
853 entries: &bind_entries,
854 });
855
856 let pipeline = self.create_pipeline(&module, &bgl);
857 self.dispatch(
858 &pipeline,
859 &bind_group,
860 div_ceil(output_len as u32, WORKGROUP_SIZE),
861 );
862 self.read_buffer::<f32>(&output_buf, output_len)
863 }
864
865 fn run_3bind_shader(
866 &self,
867 shader_src: &str,
868 params: &Params4U,
869 query: &[f32],
870 vectors_buf: &wgpu::Buffer,
871 output_len: usize,
872 ) -> Result<Vec<f32>> {
873 let params_buf = self
874 .device
875 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
876 label: Some("params"),
877 contents: bytemuck::bytes_of(params),
878 usage: wgpu::BufferUsages::UNIFORM,
879 });
880 let query_buf = self.make_storage_buf("query", query);
881 let output_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
882 label: Some("scores"),
883 size: (output_len * 4) as u64,
884 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
885 mapped_at_creation: false,
886 });
887 let module = self
888 .device
889 .create_shader_module(wgpu::ShaderModuleDescriptor {
890 label: None,
891 source: wgpu::ShaderSource::Wgsl(shader_src.into()),
892 });
893 let bgl = self
894 .device
895 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
896 label: None,
897 entries: &[
898 bgl_entry(0, wgpu::BufferBindingType::Uniform),
899 bgl_entry(1, wgpu::BufferBindingType::Storage { read_only: true }),
900 bgl_entry(2, wgpu::BufferBindingType::Storage { read_only: true }),
901 bgl_entry(3, wgpu::BufferBindingType::Storage { read_only: false }),
902 ],
903 });
904 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
905 label: None,
906 layout: &bgl,
907 entries: &[
908 wgpu::BindGroupEntry {
909 binding: 0,
910 resource: params_buf.as_entire_binding(),
911 },
912 wgpu::BindGroupEntry {
913 binding: 1,
914 resource: query_buf.as_entire_binding(),
915 },
916 wgpu::BindGroupEntry {
917 binding: 2,
918 resource: vectors_buf.as_entire_binding(),
919 },
920 wgpu::BindGroupEntry {
921 binding: 3,
922 resource: output_buf.as_entire_binding(),
923 },
924 ],
925 });
926 let pipeline = self.create_pipeline(&module, &bgl);
927 self.dispatch(
928 &pipeline,
929 &bind_group,
930 div_ceil(output_len as u32, WORKGROUP_SIZE),
931 );
932 self.read_buffer::<f32>(&output_buf, output_len)
933 }
934
    /// Builds a compute pipeline for `module` (entry point "main") with a
    /// pipeline layout derived from the single bind group layout `bgl`.
    fn create_pipeline(
        &self,
        module: &wgpu::ShaderModule,
        bgl: &wgpu::BindGroupLayout,
    ) -> wgpu::ComputePipeline {
        let layout = self
            .device
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: None,
                bind_group_layouts: &[bgl],
                // No immediate (push-constant-style) data is used.
                immediate_size: 0,
            });
        self.device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: None,
                layout: Some(&layout),
                module,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            })
    }
957
958 fn dispatch(
959 &self,
960 pipeline: &wgpu::ComputePipeline,
961 bind_group: &wgpu::BindGroup,
962 workgroups: u32,
963 ) {
964 let mut encoder = self
965 .device
966 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
967 {
968 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
969 label: None,
970 timestamp_writes: None,
971 });
972 pass.set_pipeline(pipeline);
973 pass.set_bind_group(0, bind_group, &[]);
974 pass.dispatch_workgroups(workgroups, 1, 1);
975 }
976 self.queue.submit(std::iter::once(encoder.finish()));
977 }
978
979 fn dispatch_2d(
980 &self,
981 pipeline: &wgpu::ComputePipeline,
982 bind_group: &wgpu::BindGroup,
983 wg_x: u32,
984 wg_y: u32,
985 ) {
986 let mut encoder = self
987 .device
988 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
989 {
990 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
991 label: None,
992 timestamp_writes: None,
993 });
994 pass.set_pipeline(pipeline);
995 pass.set_bind_group(0, bind_group, &[]);
996 pass.dispatch_workgroups(wg_x, wg_y, 1);
997 }
998 self.queue.submit(std::iter::once(encoder.finish()));
999 }
1000
    /// Copies `count` elements of `T` out of a GPU buffer into host memory.
    ///
    /// Copies through a MAP_READ staging buffer, then blocks until the map
    /// completes.
    ///
    /// # Errors
    /// `GpuError::BufferMap` if the map callback is dropped without firing or
    /// if mapping itself fails.
    fn read_buffer<T: Pod>(&self, buffer: &wgpu::Buffer, count: usize) -> Result<Vec<T>> {
        let elem_size = std::mem::size_of::<T>();
        let byte_len = (count * elem_size) as u64;
        // Storage buffers are not host-mappable; go through a staging buffer.
        let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("staging"),
            size: byte_len,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        encoder.copy_buffer_to_buffer(buffer, 0, &staging, 0, byte_len);
        self.queue.submit(std::iter::once(encoder.finish()));

        // map_async resolves during poll(); the channel hands the result back
        // to this blocking thread.
        let slice = staging.slice(..);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = tx.send(result);
        });
        // Block until all submitted work (including the copy above) finishes.
        let _ = self.device.poll(wgpu::PollType::Wait {
            submission_index: None,
            timeout: None,
        });
        // First `?`: sender dropped without sending; second `?`: map failed.
        rx.recv()
            .map_err(|e| GpuError::BufferMap(e.to_string()))?
            .map_err(|e| GpuError::BufferMap(e.to_string()))?;

        let data = slice.get_mapped_range();
        let result: Vec<T> = bytemuck::cast_slice(&data).to_vec();
        // The mapped range must be dropped before unmapping.
        drop(data);
        staging.unmap();
        Ok(result)
    }
1035}