vyre_wgpu/runtime/cache/
buffer_pool.rs1use std::collections::HashMap;
4use std::ops::Deref;
5use std::sync::{Arc, LazyLock, Mutex, Weak};
6use vyre::error::{Error, Result};
7
/// Maximum number of idle buffers kept per bucket (device, size class, usage).
///
/// A buffer released into a bucket that already holds this many idle buffers is dropped
/// instead of being pooled.
const MAX_BUFFERS_PER_CLASS: usize = 8;
9
10#[derive(Clone, Hash, PartialEq, Eq)]
11struct BufferKey {
12 device: wgpu::Device,
13 size_class: u64,
14 usage_bits: u32,
15}
16
17#[derive(Default)]
18struct BufferPoolInner {
19 buffers: Mutex<HashMap<BufferKey, Vec<wgpu::Buffer>>>,
20}
21
22#[derive(Clone, Default)]
24pub struct BufferPool {
25 inner: Arc<BufferPoolInner>,
26}
27
28pub struct PooledBuffer {
30 key: BufferKey,
31 buffer: Option<wgpu::Buffer>,
32 pool: Weak<BufferPoolInner>,
33}
34
35impl BufferPool {
36 #[must_use]
38 #[inline]
39 pub fn global() -> &'static Self {
40 static POOL: LazyLock<BufferPool> = LazyLock::new(BufferPool::new);
41 &POOL
42 }
43
44 #[must_use]
46 #[inline]
47 pub fn new() -> Self {
48 Self {
49 inner: Arc::new(BufferPoolInner::default()),
50 }
51 }
52
53 #[inline]
59 pub fn acquire(
60 &self,
61 device: &wgpu::Device,
62 label: &str,
63 size: u64,
64 usage: wgpu::BufferUsages,
65 ) -> Result<PooledBuffer> {
66 let key = BufferKey {
67 device: device.clone(),
68 size_class: size_class(size),
69 usage_bits: usage.bits(),
70 };
71 let mut buffers = self.inner.buffers.lock().map_err(|source| Error::Gpu {
72 message: format!(
73 "GPU buffer pool mutex poisoned: {source}. Fix: restart the process or avoid panicking while holding the buffer pool lock."
74 ),
75 })?;
76 let buffer = buffers
77 .entry(key.clone())
78 .or_default()
79 .pop()
80 .unwrap_or_else(|| {
81 device.create_buffer(&wgpu::BufferDescriptor {
82 label: Some(label),
83 size: key.size_class,
84 usage,
85 mapped_at_creation: false,
86 })
87 });
88 Ok(PooledBuffer {
89 key,
90 buffer: Some(buffer),
91 pool: Arc::downgrade(&self.inner),
92 })
93 }
94
95 #[inline]
97 pub fn release(&self, buffer: PooledBuffer) {
98 drop(buffer);
99 }
100
101 #[inline]
107 pub fn with_buffer<R>(
108 &self,
109 device: &wgpu::Device,
110 label: &str,
111 size: u64,
112 usage: wgpu::BufferUsages,
113 f: impl FnOnce(&wgpu::Buffer) -> R,
114 ) -> Result<R> {
115 let buffer = self.acquire(device, label, size, usage)?;
116 let result = f(&buffer);
117 self.release(buffer);
118 Ok(result)
119 }
120}
121
122impl PooledBuffer {
123 #[must_use]
125 #[inline]
126 pub fn size(&self) -> u64 {
127 self.key.size_class
128 }
129
130 #[must_use]
132 #[inline]
133 pub fn buffer(&self) -> &wgpu::Buffer {
134 self.buffer
135 .as_ref()
136 .expect("pooled buffer missing inner buffer. Fix: do not use PooledBuffer after drop.")
137 }
138}
139
140impl Deref for PooledBuffer {
141 type Target = wgpu::Buffer;
142
143 fn deref(&self) -> &Self::Target {
144 self.buffer()
145 }
146}
147
148impl Drop for PooledBuffer {
149 fn drop(&mut self) {
150 let Some(buffer) = self.buffer.take() else {
151 return;
152 };
153 let Some(pool) = self.pool.upgrade() else {
154 return;
155 };
156 let Ok(mut buffers) = pool.buffers.lock() else {
157 return;
158 };
159 let class = buffers.entry(self.key.clone()).or_default();
160 if class.len() < MAX_BUFFERS_PER_CLASS {
161 class.push(buffer);
162 }
163 }
164}
165
166fn size_class(size: u64) -> u64 {
167 size.max(wgpu::COPY_BUFFER_ALIGNMENT)
168 .next_multiple_of(wgpu::COPY_BUFFER_ALIGNMENT)
169}
170
#[cfg(test)]
mod tests {
    use super::size_class;

    const ALIGN: u64 = wgpu::COPY_BUFFER_ALIGNMENT;

    #[test]
    fn size_class_is_copy_aligned_and_nonzero() {
        // Zero-sized and sub-alignment requests are bumped to one full alignment unit.
        assert_eq!(size_class(0), ALIGN);
        assert_eq!(size_class(1), ALIGN);
        // Exact multiples are kept as-is.
        assert_eq!(size_class(ALIGN), ALIGN);
        assert_eq!(size_class(ALIGN * 3), ALIGN * 3);
        // Anything past a multiple rounds up to the next one. These expected values are
        // spelled out rather than recomputed with the implementation's own expression.
        assert_eq!(size_class(ALIGN + 1), ALIGN * 2);
        assert_eq!(size_class(ALIGN * 3 - 1), ALIGN * 3);
    }

    #[test]
    fn size_class_is_the_tightest_aligned_cover() {
        for size in 0..=ALIGN * 8 {
            let class = size_class(size);
            assert_eq!(class % ALIGN, 0, "size {size} -> {class} is misaligned");
            assert!(class >= size.max(1), "size {size} -> {class} is too small");
            assert!(class - size.max(1) < ALIGN, "size {size} -> {class} overshoots");
        }
    }
}