Skip to main content

tract_core/
runtime.rs

1use std::any::Any;
2use std::fmt::Debug;
3
4use downcast_rs::Downcast;
5use dyn_clone::DynClone;
6use lazy_static::lazy_static;
7use tract_linalg::multithread::Executor;
8
9use crate::internal::*;
10
11#[derive(Clone, Debug, Default)]
12pub struct RunOptions {
13    /// Use the simple ordering instead of the newer memory friendly one
14    pub skip_order_opt_ram: bool,
15
16    /// Override default global executor
17    pub executor: Option<Executor>,
18
19    /// Memory sizing hints
20    pub memory_sizing_hints: Option<SymbolValues>,
21}
22
23pub trait Runtime: Debug + Send + Sync + 'static {
24    fn name(&self) -> StaticName;
25    fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
26        self.prepare_with_options(model, &Default::default())
27    }
28    fn check(&self) -> TractResult<()>;
29    fn prepare_with_options(
30        &self,
31        model: TypedModel,
32        options: &RunOptions,
33    ) -> TractResult<Box<dyn Runnable>>;
34}
35
36pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
37    fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
38        self.spawn()?.run(inputs)
39    }
40    fn spawn(&self) -> TractResult<Box<dyn State>>;
41    fn input_count(&self) -> usize {
42        self.typed_model().context("Fallback implementation on typed_model()").unwrap().inputs.len()
43    }
44    fn output_count(&self) -> usize {
45        self.typed_model()
46            .context("Fallback implementation on typed_model()")
47            .unwrap()
48            .outputs
49            .len()
50    }
51    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
52        self.typed_model()
53            .context("Fallback implementation on typed_model()")
54            .unwrap()
55            .input_fact(ix)
56    }
57    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
58        self.typed_model()
59            .context("Fallback implementation on typed_model()")
60            .unwrap()
61            .output_fact(ix)
62    }
63    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
64        lazy_static! {
65            static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
66        };
67        self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
68    }
69
70    fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;
71    fn typed_model(&self) -> Option<&Arc<TypedModel>>;
72}
73impl_downcast!(Runnable);
74
75pub trait State: Any + Downcast + Debug + 'static {
76    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;
77
78    fn runnable(&self) -> &dyn Runnable;
79
80    fn input_count(&self) -> usize {
81        self.runnable().input_count()
82    }
83
84    fn output_count(&self) -> usize {
85        self.runnable().output_count()
86    }
87
88    fn freeze(&self) -> Box<dyn FrozenState>;
89    /// Consuming freeze: moves data instead of cloning.
90    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenState> {
91        self.freeze()
92    }
93}
94impl_downcast!(State);
95
96pub trait FrozenState: Any + Debug + DynClone + Send {
97    fn unfreeze(&self) -> Box<dyn State>;
98    fn input_count(&self) -> usize;
99    fn output_count(&self) -> usize;
100}
101dyn_clone::clone_trait_object!(FrozenState);
102
103#[derive(Debug)]
104pub struct DefaultRuntime;
105
106impl Runtime for DefaultRuntime {
107    fn name(&self) -> StaticName {
108        Cow::Borrowed("default")
109    }
110
111    fn prepare_with_options(
112        &self,
113        model: TypedModel,
114        options: &RunOptions,
115    ) -> TractResult<Box<dyn Runnable>> {
116        let model = model.into_optimized()?;
117        Ok(Box::new(TypedSimplePlan::new_with_options(model, options)?))
118    }
119
120    fn check(&self) -> TractResult<()> {
121        Ok(())
122    }
123}
124
125impl Runnable for Arc<TypedRunnableModel> {
126    fn spawn(&self) -> TractResult<Box<dyn State>> {
127        Ok(Box::new(self.spawn()?))
128    }
129
130    fn typed_plan(&self) -> Option<&Self> {
131        Some(self)
132    }
133
134    fn typed_model(&self) -> Option<&Arc<TypedModel>> {
135        Some(&self.model)
136    }
137
138    fn input_count(&self) -> usize {
139        self.model.inputs.len()
140    }
141
142    fn output_count(&self) -> usize {
143        self.model.outputs.len()
144    }
145
146    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
147        self.model.input_fact(ix)
148    }
149    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
150        self.model.output_fact(ix)
151    }
152}
153
154impl State for TypedSimpleState {
155    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
156        self.run(inputs)
157    }
158
159    fn runnable(&self) -> &dyn Runnable {
160        &self.plan
161    }
162
163    fn freeze(&self) -> Box<dyn FrozenState> {
164        Box::new(TypedSimpleState::freeze(self))
165    }
166
167    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenState> {
168        Box::new(TypedSimpleState::freeze_into(*self))
169    }
170}
171
172impl FrozenState for TypedFrozenSimpleState {
173    fn unfreeze(&self) -> Box<dyn State> {
174        Box::new(TypedFrozenSimpleState::unfreeze(self))
175    }
176
177    fn input_count(&self) -> usize {
178        self.plan().model().input_outlets().unwrap().len()
179    }
180
181    fn output_count(&self) -> usize {
182        self.plan().model().output_outlets().unwrap().len()
183    }
184}
185
186pub struct InventorizedRuntime(pub &'static dyn Runtime);
187
188impl Runtime for InventorizedRuntime {
189    fn name(&self) -> StaticName {
190        self.0.name()
191    }
192
193    fn prepare_with_options(
194        &self,
195        model: TypedModel,
196        options: &RunOptions,
197    ) -> TractResult<Box<dyn Runnable>> {
198        self.0.prepare_with_options(model, options)
199    }
200
201    fn check(&self) -> TractResult<()> {
202        self.0.check()
203    }
204}
205
206impl Debug for InventorizedRuntime {
207    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
208        self.0.fmt(f)
209    }
210}
211
212inventory::collect!(InventorizedRuntime);
213
214pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
215    inventory::iter::<InventorizedRuntime>().filter(|rt| rt.check().is_ok()).map(|ir| ir.0)
216}
217
218pub fn runtime_for_name(s: &str) -> TractResult<Option<&'static dyn Runtime>> {
219    let Some(rt) = inventory::iter::<InventorizedRuntime>().find(|rt| rt.name() == s) else {
220        return Ok(None);
221    };
222    rt.check()?;
223    Ok(Some(rt.0))
224}
225
/// Registers a runtime in the global registry read by `runtimes()` and `runtime_for_name()`.
///
/// Usage: `register_runtime!(MyRuntime = MyRuntime);`. The expression must be a valid
/// `static` initializer. The expansion uses an unqualified `inventory::submit!`, so
/// `inventory` must be resolvable at the call site.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        // The anonymous const scopes the backing static. Without it, a second invocation in
        // the same module (or any existing item named `D`) would fail to compile.
        const _: () = {
            static D: $type = $val;
            inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
        };
    };
}
233
// Register the built-in runtime so it is returned by `runtimes()` and found by
// `runtime_for_name("default")`.
register_runtime!(DefaultRuntime = DefaultRuntime);