Skip to main content

tract_core/
runtime.rs

1use std::any::Any;
2use std::fmt::Debug;
3
4use downcast_rs::Downcast;
5use dyn_clone::DynClone;
6use lazy_static::lazy_static;
7use tract_linalg::multithread::Executor;
8
9use crate::internal::*;
10
11#[derive(Clone, Debug, Default)]
12pub struct RunOptions {
13    /// Use the simple ordering instead of the newer memory friendly one
14    pub skip_order_opt_ram: bool,
15
16    /// Override default global executor
17    pub executor: Option<Executor>,
18
19    /// Memory sizing hints
20    pub memory_sizing_hints: Option<SymbolValues>,
21}
22
23pub trait Runtime: Debug + Send + Sync + 'static {
24    fn name(&self) -> StaticName;
25    fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
26        self.prepare_with_options(model, &Default::default())
27    }
28    fn prepare_with_options(
29        &self,
30        model: TypedModel,
31        options: &RunOptions,
32    ) -> TractResult<Box<dyn Runnable>>;
33}
34
35pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
36    fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
37        self.spawn()?.run(inputs)
38    }
39    fn spawn(&self) -> TractResult<Box<dyn State>>;
40    fn input_count(&self) -> usize {
41        self.typed_model().context("Fallback implementation on typed_model()").unwrap().inputs.len()
42    }
43    fn output_count(&self) -> usize {
44        self.typed_model()
45            .context("Fallback implementation on typed_model()")
46            .unwrap()
47            .outputs
48            .len()
49    }
50    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
51        self.typed_model()
52            .context("Fallback implementation on typed_model()")
53            .unwrap()
54            .input_fact(ix)
55    }
56    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
57        self.typed_model()
58            .context("Fallback implementation on typed_model()")
59            .unwrap()
60            .output_fact(ix)
61    }
62    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
63        lazy_static! {
64            static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
65        };
66        self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
67    }
68
69    fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;
70    fn typed_model(&self) -> Option<&Arc<TypedModel>>;
71}
72impl_downcast!(Runnable);
73
74pub trait State: Any + Downcast + Debug + 'static {
75    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;
76
77    fn runnable(&self) -> &dyn Runnable;
78
79    fn initializable_states_count(&self) -> usize;
80    fn get_states_facts(&self) -> Vec<TypedFact>;
81    fn init_state(&mut self, states: &[TValue]) -> TractResult<()>;
82    fn get_states(&self) -> TractResult<Vec<TValue>>;
83    fn input_count(&self) -> usize {
84        self.runnable().input_count()
85    }
86
87    fn output_count(&self) -> usize {
88        self.runnable().input_count()
89    }
90
91    fn freeze(&self) -> Box<dyn FrozenState>;
92}
93impl_downcast!(State);
94
95pub trait FrozenState: Any + Debug + DynClone + Send {
96    fn unfreeze(&self) -> Box<dyn State>;
97}
98dyn_clone::clone_trait_object!(FrozenState);
99
100#[derive(Debug)]
101pub struct DefaultRuntime;
102
103impl Runtime for DefaultRuntime {
104    fn name(&self) -> StaticName {
105        Cow::Borrowed("default")
106    }
107
108    fn prepare_with_options(
109        &self,
110        model: TypedModel,
111        options: &RunOptions,
112    ) -> TractResult<Box<dyn Runnable>> {
113        let model = model.into_optimized()?;
114        Ok(Box::new(TypedSimplePlan::new_with_options(model, options)?))
115    }
116}
117
118impl Runnable for Arc<TypedRunnableModel> {
119    fn spawn(&self) -> TractResult<Box<dyn State>> {
120        Ok(Box::new(self.spawn()?))
121    }
122
123    fn typed_plan(&self) -> Option<&Self> {
124        Some(self)
125    }
126
127    fn typed_model(&self) -> Option<&Arc<TypedModel>> {
128        Some(&self.model)
129    }
130
131    fn input_count(&self) -> usize {
132        self.model.inputs.len()
133    }
134
135    fn output_count(&self) -> usize {
136        self.model.outputs.len()
137    }
138
139    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
140        self.model.input_fact(ix)
141    }
142    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
143        self.model.output_fact(ix)
144    }
145}
146
147impl State for TypedSimpleState {
148    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
149        self.run(inputs)
150    }
151
152    fn runnable(&self) -> &dyn Runnable {
153        &self.plan
154    }
155
156    fn initializable_states_count(&self) -> usize {
157        self.op_states
158            .iter()
159            .filter_map(Option::as_ref)
160            .filter(|s| s.init_tensor_fact().is_some())
161            .count()
162    }
163
164    fn get_states_facts(&self) -> Vec<TypedFact> {
165        self.op_states
166            .iter()
167            .filter_map(|s| s.as_ref().and_then(|s| s.init_tensor_fact().map(|(_, fact)| fact)))
168            .collect()
169    }
170
171    fn init_state(&mut self, states: &[TValue]) -> TractResult<()> {
172        self.init_states(states)
173    }
174
175    fn get_states(&self) -> TractResult<Vec<TValue>> {
176        let mut states = vec![];
177        for op_state in self.op_states.iter().flatten() {
178            if op_state.init_tensor_fact().is_some() {
179                op_state.save_to(&mut states)?;
180            }
181        }
182        Ok(states)
183    }
184
185    fn freeze(&self) -> Box<dyn FrozenState> {
186        Box::new(TypedSimpleState::freeze(self))
187    }
188}
189
190impl FrozenState for TypedFrozenSimpleState {
191    fn unfreeze(&self) -> Box<dyn State> {
192        Box::new(TypedFrozenSimpleState::unfreeze(self))
193    }
194}
195
196pub struct InventorizedRuntime(pub &'static dyn Runtime);
197
198impl Runtime for InventorizedRuntime {
199    fn name(&self) -> StaticName {
200        self.0.name()
201    }
202
203    fn prepare_with_options(
204        &self,
205        model: TypedModel,
206        options: &RunOptions,
207    ) -> TractResult<Box<dyn Runnable>> {
208        self.0.prepare_with_options(model, options)
209    }
210}
211
212impl Debug for InventorizedRuntime {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        self.0.fmt(f)
215    }
216}
217
218inventory::collect!(InventorizedRuntime);
219
220pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
221    inventory::iter::<InventorizedRuntime>().map(|ir| ir.0)
222}
223
224pub fn runtime_for_name(s: &str) -> Option<&'static dyn Runtime> {
225    runtimes().find(|rt| rt.name() == s)
226}
227
228#[macro_export]
229macro_rules! register_runtime {
230    ($type: ty= $val:expr) => {
231        static D: $type = $val;
232        inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
233    };
234}
235
236register_runtime!(DefaultRuntime = DefaultRuntime);