// tract_core/runtime.rs

1use std::any::Any;
2use std::fmt::Debug;
3
4use downcast_rs::Downcast;
5use dyn_clone::DynClone;
6use lazy_static::lazy_static;
7use tract_linalg::multithread::Executor;
8
9use crate::internal::*;
10
11#[derive(Clone, Debug, Default)]
12pub struct RunOptions {
13    /// Use the simple ordering instead of the newer memory friendly one
14    pub skip_order_opt_ram: bool,
15
16    /// Override default global executor
17    pub executor: Option<Executor>,
18
19    /// Memory sizing hints
20    pub memory_sizing_hints: Option<SymbolValues>,
21}
22
23pub trait Runtime: Debug + Send + Sync + 'static {
24    fn name(&self) -> StaticName;
25    fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
26        self.prepare_with_options(model, &Default::default())
27    }
28    fn check(&self) -> TractResult<()>;
29    fn prepare_with_options(
30        &self,
31        model: TypedModel,
32        options: &RunOptions,
33    ) -> TractResult<Box<dyn Runnable>>;
34}
35
36pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
37    fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
38        self.spawn()?.run(inputs)
39    }
40    fn spawn(&self) -> TractResult<Box<dyn State>>;
41    fn input_count(&self) -> usize {
42        self.typed_model().context("Fallback implementation on typed_model()").unwrap().inputs.len()
43    }
44    fn output_count(&self) -> usize {
45        self.typed_model()
46            .context("Fallback implementation on typed_model()")
47            .unwrap()
48            .outputs
49            .len()
50    }
51    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
52        self.typed_model()
53            .context("Fallback implementation on typed_model()")
54            .unwrap()
55            .input_fact(ix)
56    }
57    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
58        self.typed_model()
59            .context("Fallback implementation on typed_model()")
60            .unwrap()
61            .output_fact(ix)
62    }
63    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
64        lazy_static! {
65            static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
66        };
67        self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
68    }
69
70    fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;
71    fn typed_model(&self) -> Option<&Arc<TypedModel>>;
72}
73impl_downcast!(Runnable);
74
75pub trait State: Any + Downcast + Debug + 'static {
76    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;
77
78    fn runnable(&self) -> &dyn Runnable;
79
80    fn input_count(&self) -> usize {
81        self.runnable().input_count()
82    }
83
84    fn output_count(&self) -> usize {
85        self.runnable().output_count()
86    }
87
88    fn freeze(&self) -> Box<dyn FrozenState>;
89}
90impl_downcast!(State);
91
92pub trait FrozenState: Any + Debug + DynClone + Send {
93    fn unfreeze(&self) -> Box<dyn State>;
94}
95dyn_clone::clone_trait_object!(FrozenState);
96
97#[derive(Debug)]
98pub struct DefaultRuntime;
99
100impl Runtime for DefaultRuntime {
101    fn name(&self) -> StaticName {
102        Cow::Borrowed("default")
103    }
104
105    fn prepare_with_options(
106        &self,
107        model: TypedModel,
108        options: &RunOptions,
109    ) -> TractResult<Box<dyn Runnable>> {
110        let model = model.into_optimized()?;
111        Ok(Box::new(TypedSimplePlan::new_with_options(model, options)?))
112    }
113
114    fn check(&self) -> TractResult<()> {
115        Ok(())
116    }
117}
118
119impl Runnable for Arc<TypedRunnableModel> {
120    fn spawn(&self) -> TractResult<Box<dyn State>> {
121        Ok(Box::new(self.spawn()?))
122    }
123
124    fn typed_plan(&self) -> Option<&Self> {
125        Some(self)
126    }
127
128    fn typed_model(&self) -> Option<&Arc<TypedModel>> {
129        Some(&self.model)
130    }
131
132    fn input_count(&self) -> usize {
133        self.model.inputs.len()
134    }
135
136    fn output_count(&self) -> usize {
137        self.model.outputs.len()
138    }
139
140    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
141        self.model.input_fact(ix)
142    }
143    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
144        self.model.output_fact(ix)
145    }
146}
147
148impl State for TypedSimpleState {
149    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
150        self.run(inputs)
151    }
152
153    fn runnable(&self) -> &dyn Runnable {
154        &self.plan
155    }
156
157    fn freeze(&self) -> Box<dyn FrozenState> {
158        Box::new(TypedSimpleState::freeze(self))
159    }
160}
161
162impl FrozenState for TypedFrozenSimpleState {
163    fn unfreeze(&self) -> Box<dyn State> {
164        Box::new(TypedFrozenSimpleState::unfreeze(self))
165    }
166}
167
168pub struct InventorizedRuntime(pub &'static dyn Runtime);
169
170impl Runtime for InventorizedRuntime {
171    fn name(&self) -> StaticName {
172        self.0.name()
173    }
174
175    fn prepare_with_options(
176        &self,
177        model: TypedModel,
178        options: &RunOptions,
179    ) -> TractResult<Box<dyn Runnable>> {
180        self.0.prepare_with_options(model, options)
181    }
182
183    fn check(&self) -> TractResult<()> {
184        self.0.check()
185    }
186}
187
188impl Debug for InventorizedRuntime {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        self.0.fmt(f)
191    }
192}
193
194inventory::collect!(InventorizedRuntime);
195
196pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
197    inventory::iter::<InventorizedRuntime>().filter(|rt| rt.check().is_ok()).map(|ir| ir.0)
198}
199
200pub fn runtime_for_name(s: &str) -> Option<&'static dyn Runtime> {
201    runtimes().find(|rt| rt.name() == s)
202}
203
/// Registers a runtime instance in the global runtime inventory.
///
/// Usage: `register_runtime!(Type = expr);`. Here `expr` initializes a
/// `static` of type `Type`, so it must be a constant expression.
///
/// The generated `static` lives inside an anonymous `const` block. `macro_rules!`
/// items are not hygienic, so this keeps two invocations in the same module
/// (or a caller item named `D`) from clashing on the internal name.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        const _: () = {
            static D: $type = $val;
            inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
        };
    };
}

212register_runtime!(DefaultRuntime = DefaultRuntime);