wgcore/
shapes.rs

1//! Tensor shape definition.
2
3use dashmap::DashMap;
4use std::sync::Arc;
5use wgpu::util::{BufferInitDescriptor, DeviceExt};
6use wgpu::{Buffer, BufferUsages, Device};
7
8#[derive(Copy, Clone, PartialEq, Eq, Hash, bytemuck::Pod, bytemuck::Zeroable)]
9#[repr(C, align(16))]
10/// The shape of a matrix view over a GPU tensor.
11pub struct ViewShape {
12    /// The view’s number of rows and columns.
13    pub size: [u32; 2],
14    /// The view’s column stride (number of elements between two columns).
15    pub stride: u32,
16    /// Index of the first element of the view on the underlying buffer.
17    pub offset: u32,
18}
19
20/// A map between a `ViewShape` and an uniform storage `Buffer` containing its value on the gpu.
21///
22/// Ideally, we should use push-constants for view shapes. Unfortunately, push-constants is an
23/// optional extension, so we have to emulate them with uniforms for maximum portability.
24#[derive(Default)]
25pub struct ViewShapeBuffers {
26    buffers: DashMap<ViewShape, Arc<Buffer>>,
27}
28
29impl ViewShapeBuffers {
30    /// Creates an empty map.
31    pub fn new() -> Self {
32        Self {
33            buffers: DashMap::new(),
34        }
35    }
36
37    /// Gets of insert the gpu uniform storage `Buffer` containing the value of `shape`.
38    pub fn get(&self, device: &Device, shape: ViewShape) -> Arc<Buffer> {
39        self.buffers
40            .entry(shape)
41            .or_insert_with(|| {
42                Arc::new(device.create_buffer_init(&BufferInitDescriptor {
43                    label: None,
44                    contents: bytemuck::cast_slice(&[shape]),
45                    usage: BufferUsages::UNIFORM,
46                }))
47            })
48            .clone()
49    }
50}