Skip to main content

vyre_wgpu/runtime/cache/
buffer_pool.rs

//! Reusable GPU buffer pool keyed by device, size class, and usage flags.
2
3use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::{Arc, LazyLock, Mutex, Weak};
6use vyre::error::{Error, Result};
7
8const MAX_BUFFERS_PER_CLASS: usize = 8;
9
10#[derive(Clone, Hash, PartialEq, Eq)]
11struct BufferKey {
12    device: wgpu::Device,
13    size_class: u64,
14    usage_bits: u32,
15}
16
17#[derive(Default)]
18struct BufferPoolInner {
19    buffers: Mutex<HashMap<BufferKey, Vec<wgpu::Buffer>>>,
20}
21
22/// Device-aware reusable GPU buffer pool.
23#[derive(Clone, Default)]
24pub struct BufferPool {
25    inner: Arc<BufferPoolInner>,
26}
27
28/// Buffer handle that returns to its originating [`BufferPool`] on drop.
29pub struct PooledBuffer {
30    key: BufferKey,
31    buffer: Option<wgpu::Buffer>,
32    pool: Weak<BufferPoolInner>,
33}
34
35impl BufferPool {
36    /// Return the process-wide buffer pool.
37    #[must_use]
38    #[inline]
39    pub fn global() -> &'static Self {
40        static POOL: LazyLock<BufferPool> = LazyLock::new(BufferPool::new);
41        &POOL
42    }
43
44    /// Create an empty buffer pool.
45    #[must_use]
46    #[inline]
47    pub fn new() -> Self {
48        Self {
49            inner: Arc::new(BufferPoolInner::default()),
50        }
51    }
52
53    /// Acquire a reusable buffer with at least `size` bytes and exactly `usage`.
54    ///
55    /// # Errors
56    ///
57    /// Returns [`Error::Gpu`] when pool metadata cannot be locked.
58    #[inline]
59    pub fn acquire(
60        &self,
61        device: &wgpu::Device,
62        label: &str,
63        size: u64,
64        usage: wgpu::BufferUsages,
65    ) -> Result<PooledBuffer> {
66        let key = BufferKey {
67            device: device.clone(),
68            size_class: size_class(size),
69            usage_bits: usage.bits(),
70        };
71        let mut buffers = self.inner.buffers.lock().map_err(|source| Error::Gpu {
72            message: format!(
73                "GPU buffer pool mutex poisoned: {source}. Fix: restart the process or avoid panicking while holding the buffer pool lock."
74            ),
75        })?;
76        let buffer = buffers
77            .entry(key.clone())
78            .or_default()
79            .pop()
80            .unwrap_or_else(|| {
81                device.create_buffer(&wgpu::BufferDescriptor {
82                    label: Some(label),
83                    size: key.size_class,
84                    usage,
85                    mapped_at_creation: false,
86                })
87            });
88        Ok(PooledBuffer {
89            key,
90            buffer: Some(buffer),
91            pool: Arc::downgrade(&self.inner),
92        })
93    }
94
95    /// Release a buffer to the pool immediately.
96    #[inline]
97    pub fn release(&self, buffer: PooledBuffer) {
98        drop(buffer);
99    }
100
101    /// Acquire a buffer for the duration of `f` and release it afterward.
102    ///
103    /// # Errors
104    ///
105    /// Returns [`Error::Gpu`] when buffer acquisition fails.
106    #[inline]
107    pub fn with_buffer<R>(
108        &self,
109        device: &wgpu::Device,
110        label: &str,
111        size: u64,
112        usage: wgpu::BufferUsages,
113        f: impl FnOnce(&wgpu::Buffer) -> R,
114    ) -> Result<R> {
115        let buffer = self.acquire(device, label, size, usage)?;
116        let result = f(&buffer);
117        self.release(buffer);
118        Ok(result)
119    }
120}
121
122impl PooledBuffer {
123    /// Return the size-class allocation backing this pooled buffer.
124    #[must_use]
125    #[inline]
126    pub fn size(&self) -> u64 {
127        self.key.size_class
128    }
129
130    /// Return the inner `wgpu::Buffer`.
131    #[must_use]
132    #[inline]
133    pub fn buffer(&self) -> &wgpu::Buffer {
134        self.buffer
135            .as_ref()
136            .expect("pooled buffer missing inner buffer. Fix: do not use PooledBuffer after drop.")
137    }
138}
139
140impl Deref for PooledBuffer {
141    type Target = wgpu::Buffer;
142
143    fn deref(&self) -> &Self::Target {
144        self.buffer()
145    }
146}
147
148impl Drop for PooledBuffer {
149    fn drop(&mut self) {
150        let Some(buffer) = self.buffer.take() else {
151            return;
152        };
153        let Some(pool) = self.pool.upgrade() else {
154            return;
155        };
156        let Ok(mut buffers) = pool.buffers.lock() else {
157            return;
158        };
159        let class = buffers.entry(self.key.clone()).or_default();
160        if class.len() < MAX_BUFFERS_PER_CLASS {
161            class.push(buffer);
162        }
163    }
164}
165
166fn size_class(size: u64) -> u64 {
167    size.max(wgpu::COPY_BUFFER_ALIGNMENT)
168        .next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT)
169}
170
#[cfg(test)]
mod tests {
    use super::size_class;

    const ALIGN: u64 = wgpu::COPY_BUFFER_ALIGNMENT;

    #[test]
    fn size_class_is_copy_aligned_and_nonzero() {
        assert_eq!(size_class(0), ALIGN);
        assert_eq!(size_class(1), ALIGN);
        // Compare against a concrete value instead of re-deriving it with
        // `next_multiple_of`, which would merely mirror the implementation.
        assert_eq!(size_class(ALIGN + 1), 2 * ALIGN);
    }

    #[test]
    fn size_class_keeps_already_aligned_sizes() {
        assert_eq!(size_class(ALIGN), ALIGN);
        assert_eq!(size_class(2 * ALIGN), 2 * ALIGN);
        assert_eq!(size_class(1024 * ALIGN), 1024 * ALIGN);
    }
}