vyre_driver/backend/
capability.rs1use super::{BackendError, DispatchConfig, VyreBackend};
4use std::collections::HashSet;
5use vyre_foundation::ir::OpId;
6use vyre_foundation::ir::Program;
7
/// Borrowed, read-only view over a byte buffer.
///
/// This is a plain alias for `&'a [u8]`; it adds no invariants of its own.
pub type MemoryRef<'a> = &'a [u8];
10
/// Owned byte buffer.
///
/// This is a plain alias for `Vec<u8>`; it adds no invariants of its own.
pub type Memory = Vec<u8>;
13
14pub trait Backend: Send + Sync {
16 fn id(&self) -> &'static str;
18 fn version(&self) -> &'static str;
20 fn supported_ops(&self) -> &HashSet<OpId>;
22}
23
24impl<T: VyreBackend + ?Sized> Backend for T {
25 fn id(&self) -> &'static str {
26 VyreBackend::id(self)
27 }
28
29 fn version(&self) -> &'static str {
30 VyreBackend::version(self)
31 }
32
33 fn supported_ops(&self) -> &HashSet<OpId> {
34 VyreBackend::supported_ops(self)
35 }
36}
37
38pub trait Executable: Backend {
40 fn dispatch(
42 &self,
43 program: &Program,
44 inputs: &[MemoryRef<'_>],
45 config: &DispatchConfig,
46 ) -> Result<Vec<Memory>, BackendError>;
47}
48
49pub trait Streamable: Backend {
51 fn stream(
53 &self,
54 program: &Program,
55 chunks: &mut dyn Iterator<Item = MemoryRef<'_>>,
56 config: &DispatchConfig,
57 ) -> Result<Box<dyn Iterator<Item = Result<Memory, BackendError>>>, BackendError>;
58}
59
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use vyre_foundation::ir::Program;

    /// Test double whose `stream` copies every input chunk to an output buffer.
    struct EchoStreamable {
        /// Set returned by `supported_ops`; always constructed empty.
        ops: HashSet<OpId>,
    }

    impl EchoStreamable {
        /// Builds the test double with an empty op set.
        fn new() -> Self {
            let ops = HashSet::new();
            Self { ops }
        }
    }

    impl Backend for EchoStreamable {
        fn id(&self) -> &'static str {
            "echo-streamable"
        }

        fn version(&self) -> &'static str {
            "0.4.1-test"
        }

        fn supported_ops(&self) -> &HashSet<OpId> {
            &self.ops
        }
    }

    impl Streamable for EchoStreamable {
        /// Echoes each chunk back as an owned buffer; `_program` and
        /// `_config` are ignored.
        fn stream(
            &self,
            _program: &Program,
            chunks: &mut dyn Iterator<Item = MemoryRef<'_>>,
            _config: &DispatchConfig,
        ) -> Result<Box<dyn Iterator<Item = Result<Memory, BackendError>>>, BackendError> {
            // Copy everything up front: the boxed iterator we return cannot
            // borrow from `chunks`, so the data must be owned before returning.
            let mut echoed: Vec<Result<Memory, BackendError>> = Vec::new();
            for chunk in chunks {
                echoed.push(Ok(chunk.to_vec()));
            }
            Ok(Box::new(echoed.into_iter()))
        }
    }

    /// Calls `stream` through a `Box<dyn Streamable>` and checks that the
    /// echoed buffers match the inputs, in order.
    #[test]
    fn streamable_is_object_safe() {
        let backend: Box<dyn Streamable> = Box::new(EchoStreamable::new());
        let program = Program::empty();
        let config = DispatchConfig::default();

        let inputs = [&b"ab"[..], &b"cd"[..]];
        let mut chunk_iter = inputs.into_iter();

        let streamed = backend
            .stream(&program, &mut chunk_iter, &config)
            .expect("Fix: object-safe Streamable dispatch must succeed");
        let collected = streamed
            .collect::<Result<Vec<_>, _>>()
            .expect("Fix: object-safe Streamable iterator must yield owned buffers");

        assert_eq!(collected, vec![b"ab".to_vec(), b"cd".to_vec()]);
    }
}