1use std::mem;
4
5use bytemuck::Pod;
6use smallvec::SmallVec;
7use vyre_foundation::ir::Program;
8
9use crate::backend::{BackendError, DispatchConfig, OutputBuffers, VyreBackend};
10
11pub trait TypedDispatchExt: VyreBackend {
14 fn dispatch_bytes(
24 &self,
25 program: &Program,
26 inputs: &[&[u8]],
27 config: &DispatchConfig,
28 ) -> Result<Vec<Vec<u8>>, BackendError> {
29 self.dispatch_borrowed(program, inputs, config)
30 }
31
32 fn dispatch_pod<T: Pod>(
39 &self,
40 program: &Program,
41 inputs: &[&[T]],
42 config: &DispatchConfig,
43 ) -> Result<Vec<Vec<T>>, BackendError> {
44 let byte_inputs = pod_input_byte_slices(inputs)?;
45 let outputs = self.dispatch_borrowed(program, &byte_inputs, config)?;
46 decode_pod_outputs(outputs)
47 }
48
49 fn dispatch_pod_into<T: Pod>(
62 &self,
63 program: &Program,
64 inputs: &[&[T]],
65 config: &DispatchConfig,
66 raw_outputs: &mut OutputBuffers,
67 typed_outputs: &mut Vec<Vec<T>>,
68 ) -> Result<(), BackendError> {
69 let byte_inputs = pod_input_byte_slices(inputs)?;
70 self.dispatch_borrowed_into(program, &byte_inputs, config, raw_outputs)?;
71 decode_pod_outputs_into(raw_outputs, typed_outputs)
72 }
73
74 fn dispatch_u32(
80 &self,
81 program: &Program,
82 inputs: &[&[u32]],
83 config: &DispatchConfig,
84 ) -> Result<Vec<Vec<u32>>, BackendError> {
85 self.dispatch_pod(program, inputs, config)
86 }
87
88 fn dispatch_u32_into(
95 &self,
96 program: &Program,
97 inputs: &[&[u32]],
98 config: &DispatchConfig,
99 raw_outputs: &mut OutputBuffers,
100 typed_outputs: &mut Vec<Vec<u32>>,
101 ) -> Result<(), BackendError> {
102 self.dispatch_pod_into(program, inputs, config, raw_outputs, typed_outputs)
103 }
104
105 fn dispatch_f32(
111 &self,
112 program: &Program,
113 inputs: &[&[f32]],
114 config: &DispatchConfig,
115 ) -> Result<Vec<Vec<f32>>, BackendError> {
116 self.dispatch_pod(program, inputs, config)
117 }
118
119 fn dispatch_f32_into(
126 &self,
127 program: &Program,
128 inputs: &[&[f32]],
129 config: &DispatchConfig,
130 raw_outputs: &mut OutputBuffers,
131 typed_outputs: &mut Vec<Vec<f32>>,
132 ) -> Result<(), BackendError> {
133 self.dispatch_pod_into(program, inputs, config, raw_outputs, typed_outputs)
134 }
135}
136
137impl<T: VyreBackend + ?Sized> TypedDispatchExt for T {}
138
139fn pod_input_byte_slices<'a, T: Pod>(
140 inputs: &'a [&'a [T]],
141) -> Result<SmallVec<[&'a [u8]; 8]>, BackendError> {
142 let mut byte_inputs = SmallVec::<[&[u8]; 8]>::new();
143 byte_inputs.try_reserve(inputs.len()).map_err(|error| {
144 BackendError::InvalidProgram {
145 fix: format!(
146 "Fix: typed dispatch could not reserve {} POD input byte slice(s): {error}. Reuse caller-owned byte slices or shard the typed dispatch.",
147 inputs.len()
148 ),
149 }
150 })?;
151 byte_inputs.extend(
152 inputs
153 .iter()
154 .map(|input| bytemuck::cast_slice::<T, u8>(input)),
155 );
156 Ok(byte_inputs)
157}
158
159fn decode_pod_outputs<T: Pod>(outputs: Vec<Vec<u8>>) -> Result<Vec<Vec<T>>, BackendError> {
160 let width = mem::size_of::<T>();
161 if width == 0 {
162 return Err(BackendError::InvalidProgram {
163 fix: "Fix: typed dispatch does not support zero-sized POD outputs.".to_string(),
164 });
165 }
166 let mut typed_outputs = Vec::new();
167 crate::backend::resize_typed_output_slots(
168 &mut typed_outputs,
169 outputs.len(),
170 "typed POD output",
171 )?;
172 for (index, (bytes, slot)) in outputs
173 .into_iter()
174 .zip(typed_outputs.iter_mut())
175 .enumerate()
176 {
177 decode_pod_output_into(index, &bytes, width, slot)?;
178 }
179 Ok(typed_outputs)
180}
181
182fn decode_pod_outputs_into<T: Pod>(
183 raw_outputs: &[Vec<u8>],
184 typed_outputs: &mut Vec<Vec<T>>,
185) -> Result<(), BackendError> {
186 let width = mem::size_of::<T>();
187 if width == 0 {
188 return Err(BackendError::InvalidProgram {
189 fix: "Fix: typed dispatch does not support zero-sized POD outputs.".to_string(),
190 });
191 }
192 crate::backend::resize_typed_output_slots(
193 typed_outputs,
194 raw_outputs.len(),
195 "typed POD output",
196 )?;
197 for (index, (bytes, slot)) in raw_outputs.iter().zip(typed_outputs.iter_mut()).enumerate() {
198 decode_pod_output_into(index, bytes, width, slot)?;
199 }
200 Ok(())
201}
202
203fn decode_pod_output_into<T: Pod>(
204 index: usize,
205 bytes: &[u8],
206 width: usize,
207 output: &mut Vec<T>,
208) -> Result<(), BackendError> {
209 let remainder = bytes.len() % width;
210 if remainder != 0 {
211 return Err(BackendError::InvalidProgram {
212 fix: format!(
213 "Fix: output buffer {index} has {} bytes, which is not a whole number of {}-byte typed values.",
214 bytes.len(),
215 width
216 ),
217 });
218 }
219 output.clear();
220 let value_count = bytes.len() / width;
221 crate::allocation::try_reserve_vec_to_capacity(output, value_count).map_err(|error| {
222 BackendError::InvalidProgram {
223 fix: format!(
224 "Fix: typed dispatch could not reserve {value_count} decoded POD value(s) for output buffer {index}: {error}. Decode into caller-owned output storage or shard the dispatch output."
225 ),
226 }
227 })?;
228 output.extend(
229 bytes
230 .chunks_exact(width)
231 .map(bytemuck::pod_read_unaligned::<T>),
232 );
233 Ok(())
234}
235
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use vyre_foundation::ir::{OpId, Program};

    use super::*;
    use crate::backend::private;

    /// Test backend whose `dispatch` returns its inputs unchanged, so typed
    /// encode/decode round trips can be checked without a real device.
    struct EchoBackend;

    impl private::Sealed for EchoBackend {}

    impl VyreBackend for EchoBackend {
        fn id(&self) -> &'static str {
            "typed-dispatch-test"
        }

        fn supported_ops(&self) -> &HashSet<OpId> {
            static OPS: std::sync::OnceLock<HashSet<OpId>> = std::sync::OnceLock::new();
            OPS.get_or_init(HashSet::new)
        }

        fn dispatch(
            &self,
            _program: &Program,
            inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            Ok(inputs.to_vec())
        }
    }

    #[test]
    fn dispatch_u32_packs_inputs_and_decodes_outputs() {
        let backend = EchoBackend;
        let input = [1u32, 2, 0x0102_0304];
        let outputs = backend
            .dispatch_u32(&Program::empty(), &[&input], &DispatchConfig::default())
            .unwrap_or_else(|error| panic!("typed u32 dispatch must succeed: {error}"));

        assert_eq!(outputs, vec![input.to_vec()]);
    }

    #[test]
    fn dispatch_f32_round_trips_values() {
        let backend = EchoBackend;
        let input = [1.5f32, -2.25, 0.0];
        let outputs = backend
            .dispatch_f32(&Program::empty(), &[&input], &DispatchConfig::default())
            .unwrap_or_else(|error| panic!("typed f32 dispatch must succeed: {error}"));

        assert_eq!(outputs, vec![input.to_vec()]);
    }

    #[test]
    fn typed_decode_rejects_partial_words() {
        let error = decode_pod_outputs::<u32>(vec![vec![1, 2, 3]])
            .expect_err("partial u32 output must fail");

        assert!(
            error.to_string().contains("whole number of 4-byte"),
            "malformed typed output must produce actionable width error: {error}"
        );
    }

    #[test]
    fn typed_decode_into_rejects_partial_words() {
        let mut typed_outputs: Vec<Vec<u32>> = Vec::new();
        let error = decode_pod_outputs_into::<u32>(&[vec![1, 2, 3, 4, 5]], &mut typed_outputs)
            .expect_err("partial u32 output must fail in the reusing decoder too");

        assert!(
            error.to_string().contains("whole number of 4-byte"),
            "malformed typed output must produce actionable width error: {error}"
        );
    }

    #[test]
    fn typed_decode_rejects_zero_sized_pod() {
        let error = decode_pod_outputs::<()>(Vec::new())
            .expect_err("zero-sized POD output must fail");
        assert!(
            error.to_string().contains("zero-sized POD outputs"),
            "zero-sized typed output must produce actionable error: {error}"
        );

        let mut typed_outputs: Vec<Vec<()>> = Vec::new();
        let error = decode_pod_outputs_into::<()>(&[], &mut typed_outputs)
            .expect_err("zero-sized POD output must fail in the reusing decoder too");
        assert!(
            error.to_string().contains("zero-sized POD outputs"),
            "zero-sized typed output must produce actionable error: {error}"
        );
    }

    #[test]
    fn dispatch_u32_into_reuses_raw_and_typed_output_slots() {
        let backend = EchoBackend;
        let input = [1u32, 2, 0x0102_0304];
        let mut raw_outputs = vec![Vec::with_capacity(16)];
        let mut typed_outputs = vec![Vec::with_capacity(3)];
        let raw_outer = raw_outputs.as_ptr();
        let raw_slot = raw_outputs[0].as_ptr();
        let typed_outer = typed_outputs.as_ptr();
        let typed_slot = typed_outputs[0].as_ptr();

        backend
            .dispatch_u32_into(
                &Program::empty(),
                &[&input],
                &DispatchConfig::default(),
                &mut raw_outputs,
                &mut typed_outputs,
            )
            .unwrap_or_else(|error| panic!("typed u32 into dispatch must succeed: {error}"));
        assert_eq!(typed_outputs, vec![input.to_vec()]);
        assert_eq!(raw_outputs.as_ptr(), raw_outer);
        assert_eq!(raw_outputs[0].as_ptr(), raw_slot);
        assert_eq!(typed_outputs.as_ptr(), typed_outer);
        assert_eq!(typed_outputs[0].as_ptr(), typed_slot);

        backend
            .dispatch_u32_into(
                &Program::empty(),
                &[&input],
                &DispatchConfig::default(),
                &mut raw_outputs,
                &mut typed_outputs,
            )
            .unwrap_or_else(|error| panic!("second typed u32 into dispatch must succeed: {error}"));
        assert_eq!(typed_outputs, vec![input.to_vec()]);
        assert_eq!(raw_outputs.as_ptr(), raw_outer);
        assert_eq!(raw_outputs[0].as_ptr(), raw_slot);
        assert_eq!(typed_outputs.as_ptr(), typed_outer);
        assert_eq!(typed_outputs[0].as_ptr(), typed_slot);
    }
}