Skip to main content

vyre_driver/backend/
typed_dispatch.rs

1//! Typed dispatch helpers layered over the frozen backend contract.
2
3use std::mem;
4
5use bytemuck::Pod;
6use smallvec::SmallVec;
7use vyre_foundation::ir::Program;
8
9use crate::backend::{BackendError, DispatchConfig, OutputBuffers, VyreBackend};
10
11/// Extension methods for callers that work with typed POD buffers instead of
12/// manually packing and unpacking byte vectors.
13pub trait TypedDispatchExt: VyreBackend {
14    /// Dispatch borrowed byte slices.
15    ///
16    /// This is a naming convenience over [`VyreBackend::dispatch_borrowed`]
17    /// for call sites that are migrating away from owned `Vec<u8>` inputs.
18    ///
19    /// # Errors
20    ///
21    /// Returns [`BackendError`] when the backend rejects the program, inputs,
22    /// or dispatch.
23    fn dispatch_bytes(
24        &self,
25        program: &Program,
26        inputs: &[&[u8]],
27        config: &DispatchConfig,
28    ) -> Result<Vec<Vec<u8>>, BackendError> {
29        self.dispatch_borrowed(program, inputs, config)
30    }
31
32    /// Dispatch borrowed typed POD inputs and decode each output as `T`.
33    ///
34    /// # Errors
35    ///
36    /// Returns [`BackendError`] when an output byte length is not a whole
37    /// number of `T` values or when the backend dispatch fails.
38    fn dispatch_pod<T: Pod>(
39        &self,
40        program: &Program,
41        inputs: &[&[T]],
42        config: &DispatchConfig,
43    ) -> Result<Vec<Vec<T>>, BackendError> {
44        let byte_inputs = pod_input_byte_slices(inputs)?;
45        let outputs = self.dispatch_borrowed(program, &byte_inputs, config)?;
46        decode_pod_outputs(outputs)
47    }
48
49    /// Dispatch borrowed typed POD inputs and decode each output as `T` into
50    /// caller-owned storage.
51    ///
52    /// `raw_outputs` retains the backend byte buffers between calls and
53    /// `typed_outputs` retains decoded POD slots. Hot loops should use this
54    /// instead of [`TypedDispatchExt::dispatch_pod`] to avoid rebuilding both
55    /// output shells on every launch.
56    ///
57    /// # Errors
58    ///
59    /// Returns [`BackendError`] when an output byte length is not a whole
60    /// number of `T` values or when the backend dispatch fails.
61    fn dispatch_pod_into<T: Pod>(
62        &self,
63        program: &Program,
64        inputs: &[&[T]],
65        config: &DispatchConfig,
66        raw_outputs: &mut OutputBuffers,
67        typed_outputs: &mut Vec<Vec<T>>,
68    ) -> Result<(), BackendError> {
69        let byte_inputs = pod_input_byte_slices(inputs)?;
70        self.dispatch_borrowed_into(program, &byte_inputs, config, raw_outputs)?;
71        decode_pod_outputs_into(raw_outputs, typed_outputs)
72    }
73
74    /// Dispatch borrowed `u32` inputs and decode each output as `u32`.
75    ///
76    /// # Errors
77    ///
78    /// Returns [`BackendError`] on backend failure or malformed output length.
79    fn dispatch_u32(
80        &self,
81        program: &Program,
82        inputs: &[&[u32]],
83        config: &DispatchConfig,
84    ) -> Result<Vec<Vec<u32>>, BackendError> {
85        self.dispatch_pod(program, inputs, config)
86    }
87
88    /// Dispatch borrowed `u32` inputs and decode outputs into caller-owned
89    /// typed storage.
90    ///
91    /// # Errors
92    ///
93    /// Returns [`BackendError`] on backend failure or malformed output length.
94    fn dispatch_u32_into(
95        &self,
96        program: &Program,
97        inputs: &[&[u32]],
98        config: &DispatchConfig,
99        raw_outputs: &mut OutputBuffers,
100        typed_outputs: &mut Vec<Vec<u32>>,
101    ) -> Result<(), BackendError> {
102        self.dispatch_pod_into(program, inputs, config, raw_outputs, typed_outputs)
103    }
104
105    /// Dispatch borrowed `f32` inputs and decode each output as `f32`.
106    ///
107    /// # Errors
108    ///
109    /// Returns [`BackendError`] on backend failure or malformed output length.
110    fn dispatch_f32(
111        &self,
112        program: &Program,
113        inputs: &[&[f32]],
114        config: &DispatchConfig,
115    ) -> Result<Vec<Vec<f32>>, BackendError> {
116        self.dispatch_pod(program, inputs, config)
117    }
118
119    /// Dispatch borrowed `f32` inputs and decode outputs into caller-owned
120    /// typed storage.
121    ///
122    /// # Errors
123    ///
124    /// Returns [`BackendError`] on backend failure or malformed output length.
125    fn dispatch_f32_into(
126        &self,
127        program: &Program,
128        inputs: &[&[f32]],
129        config: &DispatchConfig,
130        raw_outputs: &mut OutputBuffers,
131        typed_outputs: &mut Vec<Vec<f32>>,
132    ) -> Result<(), BackendError> {
133        self.dispatch_pod_into(program, inputs, config, raw_outputs, typed_outputs)
134    }
135}
136
137impl<T: VyreBackend + ?Sized> TypedDispatchExt for T {}
138
139fn pod_input_byte_slices<'a, T: Pod>(
140    inputs: &'a [&'a [T]],
141) -> Result<SmallVec<[&'a [u8]; 8]>, BackendError> {
142    let mut byte_inputs = SmallVec::<[&[u8]; 8]>::new();
143    byte_inputs.try_reserve(inputs.len()).map_err(|error| {
144        BackendError::InvalidProgram {
145            fix: format!(
146                "Fix: typed dispatch could not reserve {} POD input byte slice(s): {error}. Reuse caller-owned byte slices or shard the typed dispatch.",
147                inputs.len()
148            ),
149        }
150    })?;
151    byte_inputs.extend(
152        inputs
153            .iter()
154            .map(|input| bytemuck::cast_slice::<T, u8>(input)),
155    );
156    Ok(byte_inputs)
157}
158
159fn decode_pod_outputs<T: Pod>(outputs: Vec<Vec<u8>>) -> Result<Vec<Vec<T>>, BackendError> {
160    let width = mem::size_of::<T>();
161    if width == 0 {
162        return Err(BackendError::InvalidProgram {
163            fix: "Fix: typed dispatch does not support zero-sized POD outputs.".to_string(),
164        });
165    }
166    let mut typed_outputs = Vec::new();
167    crate::backend::resize_typed_output_slots(
168        &mut typed_outputs,
169        outputs.len(),
170        "typed POD output",
171    )?;
172    for (index, (bytes, slot)) in outputs
173        .into_iter()
174        .zip(typed_outputs.iter_mut())
175        .enumerate()
176    {
177        decode_pod_output_into(index, &bytes, width, slot)?;
178    }
179    Ok(typed_outputs)
180}
181
182fn decode_pod_outputs_into<T: Pod>(
183    raw_outputs: &[Vec<u8>],
184    typed_outputs: &mut Vec<Vec<T>>,
185) -> Result<(), BackendError> {
186    let width = mem::size_of::<T>();
187    if width == 0 {
188        return Err(BackendError::InvalidProgram {
189            fix: "Fix: typed dispatch does not support zero-sized POD outputs.".to_string(),
190        });
191    }
192    crate::backend::resize_typed_output_slots(
193        typed_outputs,
194        raw_outputs.len(),
195        "typed POD output",
196    )?;
197    for (index, (bytes, slot)) in raw_outputs.iter().zip(typed_outputs.iter_mut()).enumerate() {
198        decode_pod_output_into(index, bytes, width, slot)?;
199    }
200    Ok(())
201}
202
203fn decode_pod_output_into<T: Pod>(
204    index: usize,
205    bytes: &[u8],
206    width: usize,
207    output: &mut Vec<T>,
208) -> Result<(), BackendError> {
209    let remainder = bytes.len() % width;
210    if remainder != 0 {
211        return Err(BackendError::InvalidProgram {
212            fix: format!(
213                "Fix: output buffer {index} has {} bytes, which is not a whole number of {}-byte typed values.",
214                bytes.len(),
215                width
216            ),
217        });
218    }
219    output.clear();
220    let value_count = bytes.len() / width;
221    crate::allocation::try_reserve_vec_to_capacity(output, value_count).map_err(|error| {
222        BackendError::InvalidProgram {
223            fix: format!(
224                "Fix: typed dispatch could not reserve {value_count} decoded POD value(s) for output buffer {index}: {error}. Decode into caller-owned output storage or shard the dispatch output."
225            ),
226        }
227    })?;
228    output.extend(
229        bytes
230            .chunks_exact(width)
231            .map(bytemuck::pod_read_unaligned::<T>),
232    );
233    Ok(())
234}
235
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use vyre_foundation::ir::{OpId, Program};

    use super::*;
    use crate::backend::private;

    /// Test backend whose `dispatch` returns a copy of its inputs as outputs.
    struct EchoBackend;

    impl private::Sealed for EchoBackend {}

    impl VyreBackend for EchoBackend {
        fn id(&self) -> &'static str {
            "typed-dispatch-test"
        }

        fn supported_ops(&self) -> &HashSet<OpId> {
            // Lazily-built empty set; these tests never consult the op list.
            static OPS: std::sync::OnceLock<HashSet<OpId>> = std::sync::OnceLock::new();
            OPS.get_or_init(HashSet::new)
        }

        fn dispatch(
            &self,
            _program: &Program,
            inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            let echoed = inputs.to_vec();
            Ok(echoed)
        }
    }

    /// A `u32` slice must survive the bytes round trip through the backend.
    #[test]
    fn dispatch_u32_packs_inputs_and_decodes_outputs() {
        let backend = EchoBackend;
        let words = [1u32, 2, 0x0102_0304];

        let echoed =
            match backend.dispatch_u32(&Program::empty(), &[&words], &DispatchConfig::default()) {
                Ok(echoed) => echoed,
                Err(error) => panic!("typed u32 dispatch must succeed: {error}"),
            };

        assert_eq!(echoed, vec![words.to_vec()]);
    }

    /// Three bytes cannot hold a whole `u32`, so decoding must fail with the
    /// width-specific message.
    #[test]
    fn typed_decode_rejects_partial_words() {
        let truncated = vec![vec![1, 2, 3]];
        let error =
            decode_pod_outputs::<u32>(truncated).expect_err("partial u32 output must fail");

        let rendered = error.to_string();
        assert!(
            rendered.contains("whole number of 4-byte"),
            "malformed typed output must produce actionable width error: {error}"
        );
    }

    /// Two consecutive `_into` dispatches must keep using the caller's
    /// original outer and inner allocations for both raw and typed outputs.
    #[test]
    fn dispatch_u32_into_reuses_raw_and_typed_output_slots() {
        let backend = EchoBackend;
        let input = [1u32, 2, 0x0102_0304];
        let mut raw_outputs = vec![Vec::with_capacity(16)];
        let mut typed_outputs = vec![Vec::with_capacity(3)];

        // Remember where every allocation lives before the first launch.
        let raw_outer = raw_outputs.as_ptr();
        let raw_slot = raw_outputs[0].as_ptr();
        let typed_outer = typed_outputs.as_ptr();
        let typed_slot = typed_outputs[0].as_ptr();

        let assert_reused = |raw: &[Vec<u8>], typed: &[Vec<u32>]| {
            assert_eq!(typed, vec![input.to_vec()]);
            assert_eq!(raw.as_ptr(), raw_outer);
            assert_eq!(raw[0].as_ptr(), raw_slot);
            assert_eq!(typed.as_ptr(), typed_outer);
            assert_eq!(typed[0].as_ptr(), typed_slot);
        };

        if let Err(error) = backend.dispatch_u32_into(
            &Program::empty(),
            &[&input],
            &DispatchConfig::default(),
            &mut raw_outputs,
            &mut typed_outputs,
        ) {
            panic!("typed u32 into dispatch must succeed: {error}");
        }
        assert_reused(&raw_outputs, &typed_outputs);

        if let Err(error) = backend.dispatch_u32_into(
            &Program::empty(),
            &[&input],
            &DispatchConfig::default(),
            &mut raw_outputs,
            &mut typed_outputs,
        ) {
            panic!("second typed u32 into dispatch must succeed: {error}");
        }
        assert_reused(&raw_outputs, &typed_outputs);
    }
}
332}