wgcore/shapes.rs
1//! Tensor shape definition.
2
3use dashmap::DashMap;
4use std::sync::Arc;
5use wgpu::util::{BufferInitDescriptor, DeviceExt};
6use wgpu::{Buffer, BufferUsages, Device};
7
8#[derive(Copy, Clone, PartialEq, Eq, Hash, bytemuck::Pod, bytemuck::Zeroable)]
9#[repr(C, align(16))]
10/// The shape of a matrix view over a GPU tensor.
11pub struct ViewShape {
12 /// The view’s number of rows and columns.
13 pub size: [u32; 2],
14 /// The view’s column stride (number of elements between two columns).
15 pub stride: u32,
16 /// Index of the first element of the view on the underlying buffer.
17 pub offset: u32,
18}
19
20/// A map between a `ViewShape` and an uniform storage `Buffer` containing its value on the gpu.
21///
22/// Ideally, we should use push-constants for view shapes. Unfortunately, push-constants is an
23/// optional extension, so we have to emulate them with uniforms for maximum portability.
24#[derive(Default)]
25pub struct ViewShapeBuffers {
26 buffers: DashMap<ViewShape, Arc<Buffer>>,
27}
28
29impl ViewShapeBuffers {
30 /// Creates an empty map.
31 pub fn new() -> Self {
32 Self {
33 buffers: DashMap::new(),
34 }
35 }
36
37 /// Gets of insert the gpu uniform storage `Buffer` containing the value of `shape`.
38 pub fn get(&self, device: &Device, shape: ViewShape) -> Arc<Buffer> {
39 self.buffers
40 .entry(shape)
41 .or_insert_with(|| {
42 Arc::new(device.create_buffer_init(&BufferInitDescriptor {
43 label: None,
44 contents: bytemuck::cast_slice(&[shape]),
45 usage: BufferUsages::UNIFORM,
46 }))
47 })
48 .clone()
49 }
50}