1use std::any::Any;
2use std::fmt::Debug;
3
4use downcast_rs::Downcast;
5use dyn_clone::DynClone;
6use lazy_static::lazy_static;
7use tract_linalg::multithread::Executor;
8
9use crate::internal::*;
10
/// Options passed to [`Runtime::prepare_with_options`] to control how a model is prepared.
///
/// The derived `Default` gives `skip_order_opt_ram = false` and `None` for every optional field.
#[derive(Clone, Debug, Default)]
pub struct RunOptions {
    /// NOTE(review): presumably skips the RAM-oriented node ordering optimization when the
    /// plan is built — confirm against the plan builder.
    pub skip_order_opt_ram: bool,

    /// Executor (`tract_linalg::multithread::Executor`) to use for multithreaded work.
    /// NOTE(review): `None` presumably means "use the default executor" — confirm.
    pub executor: Option<Executor>,

    /// Known symbol values.
    /// NOTE(review): presumably used to size memory ahead of execution — confirm.
    pub memory_sizing_hints: Option<SymbolValues>,
}
22
23pub trait Runtime: Debug + Send + Sync + 'static {
24 fn name(&self) -> StaticName;
25 fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
26 self.prepare_with_options(model, &Default::default())
27 }
28 fn prepare_with_options(
29 &self,
30 model: TypedModel,
31 options: &RunOptions,
32 ) -> TractResult<Box<dyn Runnable>>;
33}
34
35pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
36 fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
37 self.spawn()?.run(inputs)
38 }
39 fn spawn(&self) -> TractResult<Box<dyn State>>;
40 fn input_count(&self) -> usize {
41 self.typed_model().context("Fallback implementation on typed_model()").unwrap().inputs.len()
42 }
43 fn output_count(&self) -> usize {
44 self.typed_model()
45 .context("Fallback implementation on typed_model()")
46 .unwrap()
47 .outputs
48 .len()
49 }
50 fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
51 self.typed_model()
52 .context("Fallback implementation on typed_model()")
53 .unwrap()
54 .input_fact(ix)
55 }
56 fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
57 self.typed_model()
58 .context("Fallback implementation on typed_model()")
59 .unwrap()
60 .output_fact(ix)
61 }
62 fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
63 lazy_static! {
64 static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
65 };
66 self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
67 }
68
69 fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;
70 fn typed_model(&self) -> Option<&Arc<TypedModel>>;
71}
72impl_downcast!(Runnable);
73
74pub trait State: Any + Downcast + Debug + 'static {
75 fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;
76
77 fn runnable(&self) -> &dyn Runnable;
78
79 fn initializable_states_count(&self) -> usize;
80 fn get_states_facts(&self) -> Vec<TypedFact>;
81 fn init_state(&mut self, states: &[TValue]) -> TractResult<()>;
82 fn get_states(&self) -> TractResult<Vec<TValue>>;
83 fn input_count(&self) -> usize {
84 self.runnable().input_count()
85 }
86
87 fn output_count(&self) -> usize {
88 self.runnable().input_count()
89 }
90
91 fn freeze(&self) -> Box<dyn FrozenState>;
92}
93impl_downcast!(State);
94
/// A frozen form of a [`State`], produced by [`State::freeze`].
///
/// Unlike [`State`], implementors must be `Send` and `DynClone`, so frozen states can be moved
/// across threads and cloned as trait objects.
pub trait FrozenState: Any + Debug + DynClone + Send {
    /// Builds a live [`State`] back from this frozen state.
    fn unfreeze(&self) -> Box<dyn State>;
}
// Implements `Clone` for `Box<dyn FrozenState>` via `DynClone`.
dyn_clone::clone_trait_object!(FrozenState);
99
/// The built-in runtime, registered under the name `"default"`.
///
/// It optimizes the model and runs it as a `TypedSimplePlan`.
#[derive(Debug)]
pub struct DefaultRuntime;
102
103impl Runtime for DefaultRuntime {
104 fn name(&self) -> StaticName {
105 Cow::Borrowed("default")
106 }
107
108 fn prepare_with_options(
109 &self,
110 model: TypedModel,
111 options: &RunOptions,
112 ) -> TractResult<Box<dyn Runnable>> {
113 let model = model.into_optimized()?;
114 Ok(Box::new(TypedSimplePlan::new_with_options(model, options)?))
115 }
116}
117
/// [`Runnable`] for the default plan type. The model accessors read directly from the plan's
/// model instead of going through the trait's `typed_model()` fallbacks.
impl Runnable for Arc<TypedRunnableModel> {
    fn spawn(&self) -> TractResult<Box<dyn State>> {
        // NOTE(review): relies on method resolution picking an inherent `spawn` on the plan
        // over this trait method; if that inherent method disappeared this would recurse
        // forever — confirm before refactoring.
        Ok(Box::new(self.spawn()?))
    }

    // The plan itself is the typed plan.
    fn typed_plan(&self) -> Option<&Self> {
        Some(self)
    }

    fn typed_model(&self) -> Option<&Arc<TypedModel>> {
        Some(&self.model)
    }

    fn input_count(&self) -> usize {
        self.model.inputs.len()
    }

    fn output_count(&self) -> usize {
        self.model.outputs.len()
    }

    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        self.model.input_fact(ix)
    }
    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        self.model.output_fact(ix)
    }
}
146
/// [`State`] for the default plan's execution state.
impl State for TypedSimpleState {
    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // NOTE(review): relies on an inherent `run` being chosen over this trait method by
        // method resolution; otherwise this would recurse — confirm before refactoring.
        self.run(inputs)
    }

    fn runnable(&self) -> &dyn Runnable {
        &self.plan
    }

    // Counts operator states that expose an init tensor fact.
    fn initializable_states_count(&self) -> usize {
        self.op_states
            .iter()
            .filter_map(Option::as_ref)
            .filter(|s| s.init_tensor_fact().is_some())
            .count()
    }

    // One fact per initializable operator state, in `op_states` order.
    fn get_states_facts(&self) -> Vec<TypedFact> {
        self.op_states
            .iter()
            .filter_map(|s| s.as_ref().and_then(|s| s.init_tensor_fact().map(|(_, fact)| fact)))
            .collect()
    }

    // Delegates to the inherent `init_states` (plural).
    fn init_state(&mut self, states: &[TValue]) -> TractResult<()> {
        self.init_states(states)
    }

    // Saves only the states that expose an init tensor fact, matching the selection made by
    // `initializable_states_count` and `get_states_facts`.
    fn get_states(&self) -> TractResult<Vec<TValue>> {
        let mut states = vec![];
        for op_state in self.op_states.iter().flatten() {
            if op_state.init_tensor_fact().is_some() {
                op_state.save_to(&mut states)?;
            }
        }
        Ok(states)
    }

    fn freeze(&self) -> Box<dyn FrozenState> {
        // Path call to the type's `freeze`; NOTE(review): expected to resolve to the inherent
        // method rather than this trait method — confirm.
        Box::new(TypedSimpleState::freeze(self))
    }
}
189
/// [`FrozenState`] for the frozen form of the default plan's execution state.
impl FrozenState for TypedFrozenSimpleState {
    fn unfreeze(&self) -> Box<dyn State> {
        // NOTE(review): expected to resolve to an inherent `unfreeze` rather than this trait
        // method (which would recurse) — confirm before refactoring.
        Box::new(TypedFrozenSimpleState::unfreeze(self))
    }
}
195
/// Registration record collected through `inventory`: wraps a `'static` runtime so that it can
/// be enumerated by `runtimes()` and looked up by `runtime_for_name()`.
pub struct InventorizedRuntime(pub &'static dyn Runtime);
197
198impl Runtime for InventorizedRuntime {
199 fn name(&self) -> StaticName {
200 self.0.name()
201 }
202
203 fn prepare_with_options(
204 &self,
205 model: TypedModel,
206 options: &RunOptions,
207 ) -> TractResult<Box<dyn Runnable>> {
208 self.0.prepare_with_options(model, options)
209 }
210}
211
212impl Debug for InventorizedRuntime {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 self.0.fmt(f)
215 }
216}
217
// Declares `InventorizedRuntime` as collectable by `inventory`, so `inventory::submit!`
// (used by `register_runtime!`) can register instances and `inventory::iter` can list them.
inventory::collect!(InventorizedRuntime);
219
220pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
221 inventory::iter::<InventorizedRuntime>().map(|ir| ir.0)
222}
223
224pub fn runtime_for_name(s: &str) -> Option<&'static dyn Runtime> {
225 runtimes().find(|rt| rt.name() == s)
226}
227
/// Registers a runtime so it appears in `runtimes()` and can be found by `runtime_for_name()`.
///
/// Syntax: `register_runtime!(Type = value);`, e.g. `register_runtime!(DefaultRuntime =
/// DefaultRuntime);`. `value` initializes a `static`, so it must be a constant expression.
///
/// NOTE(review): the expansion always declares a `static` named `D`, so the macro can be
/// invoked at most once per module. It also names `inventory::submit!` without `$crate::`,
/// so the invoking crate presumably needs `inventory` as a dependency — confirm.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        static D: $type = $val;
        inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
    };
}
235
// Makes the built-in `DefaultRuntime` discoverable under the name "default".
register_runtime!(DefaultRuntime = DefaultRuntime);