vyre_driver/backend/device_buffer.rs
1//! Backend-owned device-buffer abstraction (SEED-6).
2//!
3//! Today every `VyreBackend::dispatch` round-trips inputs and outputs as
4//! owned `Vec<u8>` buffers - uploaded to the device per call, downloaded
5//! back per call. For workloads that issue thousands of small dispatches
6//! against the same logical buffers (the C parser preprocessor pipeline
7//! is the canonical case), the host-device copies dominate wall time.
8//!
9//! `DeviceBuffer` is the substrate-neutral handle to a backend-owned
10//! allocation. Backends that opt in implement
//! [`crate::backend::VyreBackend::allocate_device_buffer`] and
//! [`crate::backend::VyreBackend::dispatch_with_device_buffers`]; consumers that hold a
13//! `Box<dyn DeviceBuffer>` can re-bind it across dispatches instead of
14//! re-uploading bytes. Backends that don't opt in return
15//! [`BackendError::UnsupportedFeature`]; production callers that select
16//! this API must treat that as a hard capability miss, not as permission
17//! to hide a host-buffer dispatch fallback.
18
19use crate::backend::{BackendError, DispatchConfig};
20use vyre_foundation::ir::Program;
21
/// Backend-owned, device-resident allocation behind an opaque handle.
///
/// Handles are `Send + Sync`, so they may be held across awaits and shared
/// between worker threads. They are NOT portable across backends: only the
/// backend that allocated a buffer may dispatch it. Moving data to another
/// backend requires an explicit download → re-upload through the
/// substrate-neutral host path.
///
/// `Any` is a supertrait so that concrete backends and tests can downcast
/// to their own allocation type instead of growing this public trait with
/// substrate-specific methods.
pub trait DeviceBuffer: std::fmt::Debug + std::any::Any + Send + Sync {
    /// Stable identifier of the owning backend. Matches
    /// [`crate::backend::VyreBackend::id`] of the backend that allocated
    /// the buffer.
    fn backend_id(&self) -> &'static str;

    /// Allocation size in bytes. The kernel sees this as the declared
    /// `BufferDecl::count * element_size`.
    fn byte_len(&self) -> usize;

    /// Human-readable label, intended for debug output only.
    ///
    /// The provided implementation reports no label; implementors that
    /// track one override this method.
    fn debug_label(&self) -> Option<&str> {
        None
    }

    /// Type-erased shared view of the buffer, so callers can downcast
    /// without naming the concrete type up front. Implementors return
    /// `self`.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Type-erased exclusive view of the buffer; the mutable counterpart
    /// of [`Self::as_any`]. Implementors return `self`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
55
/// Feature name carried by the [`BackendError::UnsupportedFeature`] that
/// backends report when they have not opted in to [`DeviceBuffer`] yet.
///
/// The default
/// [`crate::backend::VyreBackend::allocate_device_buffer`] returns an
/// error with this name, so the consumer path is the same shape across
/// all backends - opt-in detection is one `Result::is_err` check, no
/// separate trait.
pub const DEVICE_BUFFER_FEATURE: &str = "DeviceBuffer";
63
64/// Convenience helper for default `VyreBackend::allocate_device_buffer`
65/// impls - every shipped backend returns this variant until they
66/// implement persistent device-buffer allocation.
67pub(crate) fn unsupported_device_buffer(backend_id: &'static str) -> BackendError {
68 BackendError::UnsupportedFeature {
69 name: DEVICE_BUFFER_FEATURE.to_string(),
70 backend: backend_id.to_string(),
71 }
72}
73
/// Host-memory implementor of [`DeviceBuffer`] for compatibility tests and
/// explicit host-resident fixtures - stores raw bytes on the host and
/// identifies as the backend id it was constructed with.
///
/// This is not a production substitute for real device allocation. Real
/// device backends override [`crate::backend::VyreBackend::allocate_device_buffer`]
/// to return their own concrete buffer type wrapped in `Box<dyn DeviceBuffer>`.
#[derive(Debug)]
pub struct HostShimBuffer {
    /// Backend identifier supplied at construction; returned verbatim by
    /// `DeviceBuffer::backend_id`.
    backend_id: &'static str,
    /// Host-resident contents; their length is what `byte_len` reports.
    bytes: Vec<u8>,
    /// Debug label; `None` until `set_label` is called.
    label: Option<String>,
}
87
88impl HostShimBuffer {
89 /// Allocate a zero-filled host-resident buffer. The bytes live in
90 /// process memory; backends that use this still pay the upload
91 /// cost on every dispatch but the consumer-side API is the same as
92 /// for true device buffers.
93 #[must_use]
94 pub fn allocate(backend_id: &'static str, byte_len: usize) -> Box<dyn DeviceBuffer> {
95 Box::new(Self {
96 backend_id,
97 bytes: vec![0; byte_len],
98 label: None,
99 })
100 }
101
102 /// Allocate from existing bytes. The buffer takes ownership.
103 #[must_use]
104 pub fn from_bytes(backend_id: &'static str, bytes: Vec<u8>) -> Box<dyn DeviceBuffer> {
105 Box::new(Self {
106 backend_id,
107 bytes,
108 label: None,
109 })
110 }
111
112 /// Borrow the underlying bytes. Only `HostShimBuffer` exposes this -
113 /// real device buffers cannot be byte-borrowed without a download.
114 #[must_use]
115 pub fn as_slice(&self) -> &[u8] {
116 &self.bytes
117 }
118
119 /// Mutably borrow the underlying bytes. Only valid on host-shim
120 /// buffers; real device buffers panic.
121 #[must_use]
122 pub fn as_mut_slice(&mut self) -> &mut [u8] {
123 &mut self.bytes
124 }
125
126 /// Attach a debug label after allocation.
127 pub fn set_label(&mut self, label: impl Into<String>) {
128 self.label = Some(label.into());
129 }
130}
131
132impl DeviceBuffer for HostShimBuffer {
133 fn backend_id(&self) -> &'static str {
134 self.backend_id
135 }
136
137 fn byte_len(&self) -> usize {
138 self.bytes.len()
139 }
140
141 fn debug_label(&self) -> Option<&str> {
142 self.label.as_deref()
143 }
144
145 fn as_any(&self) -> &dyn std::any::Any {
146 self
147 }
148
149 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
150 self
151 }
152}
153
154/// Default `VyreBackend::dispatch_with_device_buffers` implementation
155/// shape: validate every input/output buffer belongs to the same
156/// backend, then delegate. Real device backends override this to bind
157/// their concrete buffer type without the host round-trip.
158///
159/// Consumers that don't care about backend identity can pass any
160/// `&dyn DeviceBuffer` and trust the validation here to surface the
161/// mismatch with an actionable error.
162///
163/// # Errors
164///
165/// Returns [`BackendError::UnsupportedFeature`] when the buffer's
166/// `backend_id` does not match `self_backend_id`, OR when the backend
167/// has not opted in to device-buffer dispatch.
168pub fn validate_buffer_ownership<'a>(
169 self_backend_id: &str,
170 buffers: impl IntoIterator<Item = &'a dyn DeviceBuffer>,
171) -> Result<(), BackendError> {
172 for (idx, buffer) in buffers.into_iter().enumerate() {
173 if buffer.backend_id() != self_backend_id {
174 return Err(BackendError::UnsupportedFeature {
175 name: format!(
176 "DeviceBuffer cross-backend dispatch (buffer {idx} owned by `{}`)",
177 buffer.backend_id()
178 ),
179 backend: self_backend_id.to_string(),
180 });
181 }
182 }
183 Ok(())
184}
185
186/// Default `dispatch_with_device_buffers` body. Backends that have not
187/// implemented their concrete persistent-buffer path fail loudly after
188/// ownership validation.
189///
190/// Earlier versions mirrored device-buffer dispatch through
191/// [`HostShimBuffer`] and regular `dispatch`. That hid host copies behind
192/// the resident-buffer API and made performance regressions look like
193/// working functionality. Backends with real device buffers MUST override
194/// this method.
195///
196/// # Errors
197///
198/// Returns [`BackendError::UnsupportedFeature`] when the backend has not
199/// provided a real resident-buffer implementation.
200pub fn default_dispatch_with_device_buffers(
201 backend: &dyn crate::backend::VyreBackend,
202 program: &Program,
203 inputs: &[&dyn DeviceBuffer],
204 outputs: &mut [&mut dyn DeviceBuffer],
205 config: &DispatchConfig,
206) -> Result<(), BackendError> {
207 let _ = (program, config);
208 validate_buffer_ownership(backend.id(), inputs.iter().copied())?;
209 validate_buffer_ownership(
210 backend.id(),
211 outputs.iter().map(|b| &**b as &dyn DeviceBuffer),
212 )?;
213 Err(BackendError::UnsupportedFeature {
214 name: "DeviceBuffer dispatch requires a backend-native resident-buffer implementation; host-shim dispatch fallback is forbidden".to_string(),
215 backend: backend.id().to_string(),
216 })
217}
218
/// Compile-time confirmation that [`DeviceBuffer`] is dyn-safe: the type of
/// this constant names `&dyn DeviceBuffer`, which only compiles while the
/// trait can be used as a trait object. Nothing in this file reads the
/// value.
const _ASSERT_DYN_SAFE: Option<&dyn DeviceBuffer> = None;
221
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly allocated shim reports the requested backend id and size
    /// and starts without a debug label.
    #[test]
    fn host_shim_buffer_reports_size_and_backend() {
        let handle = HostShimBuffer::allocate("test-backend", 64);

        assert_eq!(handle.backend_id(), "test-backend");
        assert_eq!(handle.byte_len(), 64);
        assert!(handle.debug_label().is_none());
    }

    /// Bytes written through the downcast mutable view are visible through
    /// the downcast shared view.
    #[test]
    fn host_shim_buffer_round_trips_bytes() {
        const PATTERN: [u8; 8] = [1, 2, 3, 4, 5, 6, 7, 8];
        let mut handle = HostShimBuffer::allocate("test-backend", 8);

        handle
            .as_any_mut()
            .downcast_mut::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer")
            .as_mut_slice()
            .copy_from_slice(&PATTERN);

        let readback = handle
            .as_any()
            .downcast_ref::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer")
            .as_slice();
        assert_eq!(readback, &PATTERN);
    }

    /// A buffer owned by another backend makes validation fail with
    /// `UnsupportedFeature`.
    #[test]
    fn validate_buffer_ownership_rejects_cross_backend() {
        let native = HostShimBuffer::allocate("cuda", 4);
        let foreign = HostShimBuffer::allocate("wgpu", 4);

        let outcome = validate_buffer_ownership("cuda", [&*native, &*foreign]);

        assert!(matches!(
            outcome,
            Err(BackendError::UnsupportedFeature { .. })
        ));
    }

    /// Buffers that all belong to the validating backend pass.
    #[test]
    fn validate_buffer_ownership_accepts_same_backend() {
        let first = HostShimBuffer::allocate("cuda", 4);
        let second = HostShimBuffer::allocate("cuda", 8);

        validate_buffer_ownership("cuda", [&*first, &*second])
            .expect("Fix: same-backend buffers must validate");
    }

    /// The helper produces `UnsupportedFeature` carrying the feature name
    /// constant and the backend id it was given.
    #[test]
    fn unsupported_device_buffer_marks_feature_correctly() {
        match unsupported_device_buffer("test-backend") {
            BackendError::UnsupportedFeature { name, backend } => {
                assert_eq!(name, DEVICE_BUFFER_FEATURE);
                assert_eq!(backend, "test-backend");
            }
            other => panic!("unexpected variant: {other:?}"),
        }
    }
}