Skip to main content

vyre_driver/backend/
device_buffer.rs

//! Backend-owned device-buffer abstraction (SEED-6).
//!
//! Today every `VyreBackend::dispatch` round-trips inputs and outputs as
//! owned `Vec<u8>` buffers — uploaded to the device on every call and
//! downloaded back on every call. For workloads that issue thousands of
//! small dispatches against the same logical buffers (the C parser
//! preprocessor pipeline is the canonical case), these host-device copies
//! dominate wall time.
//!
//! `DeviceBuffer` is the substrate-neutral handle to a backend-owned
//! allocation. Backends that opt in implement
//! [`VyreBackend::allocate_device_buffer`](crate::backend::VyreBackend::allocate_device_buffer)
//! and
//! [`VyreBackend::dispatch_with_device_buffers`](crate::backend::VyreBackend::dispatch_with_device_buffers);
//! consumers that hold a `Box<dyn DeviceBuffer>` can re-bind it across
//! dispatches instead of re-uploading bytes. Backends that don't opt in
//! return [`BackendError::UnsupportedFeature`]. Production callers that
//! select this API must treat that error as a hard capability miss, not as
//! permission to hide a host-buffer dispatch fallback.

19use crate::backend::{BackendError, DispatchConfig};
20use vyre_foundation::ir::Program;
21
/// Type-erased handle to an allocation that lives on, and belongs to, a
/// single backend.
///
/// Handles are `Send + Sync`, so they may be held across `.await` points
/// and shared between worker threads. They are *not* interchangeable
/// between backends: only the backend that produced a handle may dispatch
/// with it. Moving data to another backend means an explicit download
/// followed by a re-upload through the substrate-neutral host path.
///
/// The `Any` supertrait lets concrete backends (and tests) recover their
/// own allocation type by downcasting, so this public trait never has to
/// grow substrate-specific methods.
pub trait DeviceBuffer: std::fmt::Debug + std::any::Any + Send + Sync {
    /// Identifier of the backend that owns this buffer; equal to
    /// [`crate::backend::VyreBackend::id`] of the allocating backend.
    fn backend_id(&self) -> &'static str;

    /// Allocation size in bytes, which the kernel observes as the declared
    /// `BufferDecl::count * element_size`.
    fn byte_len(&self) -> usize;

    /// Human-readable label, for debugging only. Unless overridden, no
    /// label is reported.
    fn debug_label(&self) -> Option<&str> {
        None
    }

    /// View `self` as `&dyn Any` so callers can downcast without naming
    /// the concrete buffer type. Implementations should just return `self`.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Mutable counterpart of [`Self::as_any`]. Implementations should just
    /// return `self`.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
55
/// Feature name reported inside [`BackendError::UnsupportedFeature`] by
/// backends that have not opted in to [`DeviceBuffer`] yet.
///
/// The default [`crate::backend::VyreBackend::allocate_device_buffer`]
/// fails with `UnsupportedFeature { name: DEVICE_BUFFER_FEATURE, .. }`.
/// This gives the consumer path the same shape on every backend: opt-in
/// detection is a single `Result::is_err` check, with no separate
/// capability trait.
pub const DEVICE_BUFFER_FEATURE: &str = "DeviceBuffer";
63
64/// Convenience helper for default `VyreBackend::allocate_device_buffer`
65/// impls  -  every shipped backend returns this variant until they
66/// implement persistent device-buffer allocation.
67pub(crate) fn unsupported_device_buffer(backend_id: &'static str) -> BackendError {
68    BackendError::UnsupportedFeature {
69        name: DEVICE_BUFFER_FEATURE.to_string(),
70        backend: backend_id.to_string(),
71    }
72}
73
/// Host-resident [`DeviceBuffer`] implementation for compatibility tests
/// and explicit host-side fixtures. It keeps its bytes in process memory
/// and reports whichever backend id it was created with.
///
/// It is not a production stand-in for a real device allocation. Device
/// backends override [`crate::backend::VyreBackend::allocate_device_buffer`]
/// and hand back their own concrete buffer type as `Box<dyn DeviceBuffer>`.
#[derive(Debug)]
pub struct HostShimBuffer {
    // Backend id reported through `DeviceBuffer::backend_id`.
    backend_id: &'static str,
    // Buffer contents; their length is the reported byte length.
    bytes: Vec<u8>,
    // Optional debug label; unset until `set_label` is called.
    label: Option<String>,
}
87
88impl HostShimBuffer {
89    /// Allocate a zero-filled host-resident buffer. The bytes live in
90    /// process memory; backends that use this still pay the upload
91    /// cost on every dispatch but the consumer-side API is the same as
92    /// for true device buffers.
93    #[must_use]
94    pub fn allocate(backend_id: &'static str, byte_len: usize) -> Box<dyn DeviceBuffer> {
95        Box::new(Self {
96            backend_id,
97            bytes: vec![0; byte_len],
98            label: None,
99        })
100    }
101
102    /// Allocate from existing bytes. The buffer takes ownership.
103    #[must_use]
104    pub fn from_bytes(backend_id: &'static str, bytes: Vec<u8>) -> Box<dyn DeviceBuffer> {
105        Box::new(Self {
106            backend_id,
107            bytes,
108            label: None,
109        })
110    }
111
112    /// Borrow the underlying bytes. Only `HostShimBuffer` exposes this  -
113    /// real device buffers cannot be byte-borrowed without a download.
114    #[must_use]
115    pub fn as_slice(&self) -> &[u8] {
116        &self.bytes
117    }
118
119    /// Mutably borrow the underlying bytes. Only valid on host-shim
120    /// buffers; real device buffers panic.
121    #[must_use]
122    pub fn as_mut_slice(&mut self) -> &mut [u8] {
123        &mut self.bytes
124    }
125
126    /// Attach a debug label after allocation.
127    pub fn set_label(&mut self, label: impl Into<String>) {
128        self.label = Some(label.into());
129    }
130}
131
132impl DeviceBuffer for HostShimBuffer {
133    fn backend_id(&self) -> &'static str {
134        self.backend_id
135    }
136
137    fn byte_len(&self) -> usize {
138        self.bytes.len()
139    }
140
141    fn debug_label(&self) -> Option<&str> {
142        self.label.as_deref()
143    }
144
145    fn as_any(&self) -> &dyn std::any::Any {
146        self
147    }
148
149    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
150        self
151    }
152}
153
154/// Default `VyreBackend::dispatch_with_device_buffers` implementation
155/// shape: validate every input/output buffer belongs to the same
156/// backend, then delegate. Real device backends override this to bind
157/// their concrete buffer type without the host round-trip.
158///
159/// Consumers that don't care about backend identity can pass any
160/// `&dyn DeviceBuffer` and trust the validation here to surface the
161/// mismatch with an actionable error.
162///
163/// # Errors
164///
165/// Returns [`BackendError::UnsupportedFeature`] when the buffer's
166/// `backend_id` does not match `self_backend_id`, OR when the backend
167/// has not opted in to device-buffer dispatch.
168pub fn validate_buffer_ownership<'a>(
169    self_backend_id: &str,
170    buffers: impl IntoIterator<Item = &'a dyn DeviceBuffer>,
171) -> Result<(), BackendError> {
172    for (idx, buffer) in buffers.into_iter().enumerate() {
173        if buffer.backend_id() != self_backend_id {
174            return Err(BackendError::UnsupportedFeature {
175                name: format!(
176                    "DeviceBuffer cross-backend dispatch (buffer {idx} owned by `{}`)",
177                    buffer.backend_id()
178                ),
179                backend: self_backend_id.to_string(),
180            });
181        }
182    }
183    Ok(())
184}
185
186/// Default `dispatch_with_device_buffers` body. Backends that have not
187/// implemented their concrete persistent-buffer path fail loudly after
188/// ownership validation.
189///
190/// Earlier versions mirrored device-buffer dispatch through
191/// [`HostShimBuffer`] and regular `dispatch`. That hid host copies behind
192/// the resident-buffer API and made performance regressions look like
193/// working functionality. Backends with real device buffers MUST override
194/// this method.
195///
196/// # Errors
197///
198/// Returns [`BackendError::UnsupportedFeature`] when the backend has not
199/// provided a real resident-buffer implementation.
200pub fn default_dispatch_with_device_buffers(
201    backend: &dyn crate::backend::VyreBackend,
202    program: &Program,
203    inputs: &[&dyn DeviceBuffer],
204    outputs: &mut [&mut dyn DeviceBuffer],
205    config: &DispatchConfig,
206) -> Result<(), BackendError> {
207    let _ = (program, config);
208    validate_buffer_ownership(backend.id(), inputs.iter().copied())?;
209    validate_buffer_ownership(
210        backend.id(),
211        outputs.iter().map(|b| &**b as &dyn DeviceBuffer),
212    )?;
213    Err(BackendError::UnsupportedFeature {
214        name: "DeviceBuffer dispatch requires a backend-native resident-buffer implementation; host-shim dispatch fallback is forbidden".to_string(),
215        backend: backend.id().to_string(),
216    })
217}
218
219/// Compile-time confirmation that the trait is dyn-safe.
220const _ASSERT_DYN_SAFE: Option<&dyn DeviceBuffer> = None;
221
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn host_shim_buffer_reports_size_and_backend() {
        let buf = HostShimBuffer::allocate("test-backend", 64);
        assert_eq!(buf.backend_id(), "test-backend");
        assert_eq!(buf.byte_len(), 64);
        assert!(buf.debug_label().is_none());
    }

    #[test]
    fn host_shim_buffer_round_trips_bytes() {
        let mut buf = HostShimBuffer::allocate("test-backend", 8);
        let shim = buf
            .as_any_mut()
            .downcast_mut::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        shim.as_mut_slice()
            .copy_from_slice(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let shim_ref = buf
            .as_any()
            .downcast_ref::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        assert_eq!(shim_ref.as_slice(), &[1, 2, 3, 4, 5, 6, 7, 8]);
    }

    #[test]
    fn host_shim_from_bytes_preserves_contents() {
        let buf = HostShimBuffer::from_bytes("test-backend", vec![7, 8, 9]);
        assert_eq!(buf.backend_id(), "test-backend");
        assert_eq!(buf.byte_len(), 3);
        let shim = buf
            .as_any()
            .downcast_ref::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer");
        assert_eq!(shim.as_slice(), &[7, 8, 9]);
    }

    #[test]
    fn host_shim_label_is_reported_through_trait() {
        let mut buf = HostShimBuffer::allocate("test-backend", 4);
        buf.as_any_mut()
            .downcast_mut::<HostShimBuffer>()
            .expect("Fix: HostShimBuffer")
            .set_label("scratch");
        assert_eq!(buf.debug_label(), Some("scratch"));
    }

    #[test]
    fn validate_buffer_ownership_rejects_cross_backend() {
        let cuda_buf = HostShimBuffer::allocate("cuda", 4);
        let wgpu_buf = HostShimBuffer::allocate("wgpu", 4);
        let result = validate_buffer_ownership("cuda", [cuda_buf.as_ref(), wgpu_buf.as_ref()]);
        match result {
            Err(BackendError::UnsupportedFeature { name, backend }) => {
                // The offending buffer is the second one, owned by `wgpu`.
                assert!(name.contains("buffer 1"), "unexpected feature name: {name}");
                assert!(name.contains("`wgpu`"), "unexpected feature name: {name}");
                assert_eq!(backend, "cuda");
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn validate_buffer_ownership_accepts_same_backend() {
        let a = HostShimBuffer::allocate("cuda", 4);
        let b = HostShimBuffer::allocate("cuda", 8);
        validate_buffer_ownership("cuda", [a.as_ref(), b.as_ref()])
            .expect("Fix: same-backend buffers must validate");
    }

    #[test]
    fn validate_buffer_ownership_accepts_no_buffers() {
        validate_buffer_ownership("cuda", std::iter::empty::<&dyn DeviceBuffer>())
            .expect("Fix: an empty buffer list must validate");
    }

    #[test]
    fn unsupported_device_buffer_marks_feature_correctly() {
        let err = unsupported_device_buffer("test-backend");
        match err {
            BackendError::UnsupportedFeature { name, backend } => {
                assert_eq!(name, DEVICE_BUFFER_FEATURE);
                assert_eq!(backend, "test-backend");
            }
            other => panic!("unexpected variant: {other:?}"),
        }
    }
}