1use std::any::Any;
2use std::fmt::Debug;
3
4use downcast_rs::Downcast;
5use dyn_clone::DynClone;
6use lazy_static::lazy_static;
7use tract_linalg::multithread::Executor;
8
9use crate::internal::*;
10
/// Options controlling how a [`Runtime`] prepares a model (see
/// [`Runtime::prepare_with_options`]).
///
/// `RunOptions::default()` leaves the flag `false` and both options `None`; it is what the
/// default [`Runtime::prepare`] passes along.
#[derive(Clone, Debug, Default)]
pub struct RunOptions {
    /// NOTE(review): presumably selects the plain node ordering instead of the RAM-optimizing
    /// one when the plan is built — confirm in `TypedSimplePlan::new_with_options`.
    pub skip_order_opt_ram: bool,

    /// Executor override for the prepared model; `None` presumably means "use the default
    /// executor" — TODO confirm.
    pub executor: Option<Executor>,

    /// Optional symbol values, presumably used as hints to size memory ahead of time —
    /// TODO confirm how the plan consumes them.
    pub memory_sizing_hints: Option<SymbolValues>,
}
22
23pub trait Runtime: Debug + Send + Sync + 'static {
24 fn name(&self) -> StaticName;
25 fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
26 self.prepare_with_options(model, &Default::default())
27 }
28 fn check(&self) -> TractResult<()>;
29 fn prepare_with_options(
30 &self,
31 model: TypedModel,
32 options: &RunOptions,
33 ) -> TractResult<Box<dyn Runnable>>;
34}
35
36pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
37 fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
38 self.spawn()?.run(inputs)
39 }
40 fn spawn(&self) -> TractResult<Box<dyn State>>;
41 fn input_count(&self) -> usize {
42 self.typed_model().context("Fallback implementation on typed_model()").unwrap().inputs.len()
43 }
44 fn output_count(&self) -> usize {
45 self.typed_model()
46 .context("Fallback implementation on typed_model()")
47 .unwrap()
48 .outputs
49 .len()
50 }
51 fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
52 self.typed_model()
53 .context("Fallback implementation on typed_model()")
54 .unwrap()
55 .input_fact(ix)
56 }
57 fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
58 self.typed_model()
59 .context("Fallback implementation on typed_model()")
60 .unwrap()
61 .output_fact(ix)
62 }
63 fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
64 lazy_static! {
65 static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
66 };
67 self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
68 }
69
70 fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;
71 fn typed_model(&self) -> Option<&Arc<TypedModel>>;
72}
73impl_downcast!(Runnable);
74
75pub trait State: Any + Downcast + Debug + 'static {
76 fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;
77
78 fn runnable(&self) -> &dyn Runnable;
79
80 fn input_count(&self) -> usize {
81 self.runnable().input_count()
82 }
83
84 fn output_count(&self) -> usize {
85 self.runnable().output_count()
86 }
87
88 fn freeze(&self) -> Box<dyn FrozenState>;
89}
90impl_downcast!(State);
91
/// A snapshot of a [`State`], as produced by [`State::freeze`].
///
/// Unlike `State`, a frozen state is `Send`. It can be cloned through `dyn_clone`, because
/// `Clone` itself is not object-safe.
pub trait FrozenState: Any + Debug + DynClone + Send {
    /// Creates a live [`State`] from this snapshot. It takes `&self`, so the same snapshot
    /// can be unfrozen more than once.
    fn unfreeze(&self) -> Box<dyn State>;
}
// Implements `Clone` for `Box<dyn FrozenState>` on top of the `DynClone` supertrait.
dyn_clone::clone_trait_object!(FrozenState);
96
97#[derive(Debug)]
98pub struct DefaultRuntime;
99
100impl Runtime for DefaultRuntime {
101 fn name(&self) -> StaticName {
102 Cow::Borrowed("default")
103 }
104
105 fn prepare_with_options(
106 &self,
107 model: TypedModel,
108 options: &RunOptions,
109 ) -> TractResult<Box<dyn Runnable>> {
110 let model = model.into_optimized()?;
111 Ok(Box::new(TypedSimplePlan::new_with_options(model, options)?))
112 }
113
114 fn check(&self) -> TractResult<()> {
115 Ok(())
116 }
117}
118
119impl Runnable for Arc<TypedRunnableModel> {
120 fn spawn(&self) -> TractResult<Box<dyn State>> {
121 Ok(Box::new(self.spawn()?))
122 }
123
124 fn typed_plan(&self) -> Option<&Self> {
125 Some(self)
126 }
127
128 fn typed_model(&self) -> Option<&Arc<TypedModel>> {
129 Some(&self.model)
130 }
131
132 fn input_count(&self) -> usize {
133 self.model.inputs.len()
134 }
135
136 fn output_count(&self) -> usize {
137 self.model.outputs.len()
138 }
139
140 fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
141 self.model.input_fact(ix)
142 }
143 fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
144 self.model.output_fact(ix)
145 }
146}
147
148impl State for TypedSimpleState {
149 fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
150 self.run(inputs)
151 }
152
153 fn runnable(&self) -> &dyn Runnable {
154 &self.plan
155 }
156
157 fn freeze(&self) -> Box<dyn FrozenState> {
158 Box::new(TypedSimpleState::freeze(self))
159 }
160}
161
162impl FrozenState for TypedFrozenSimpleState {
163 fn unfreeze(&self) -> Box<dyn State> {
164 Box::new(TypedFrozenSimpleState::unfreeze(self))
165 }
166}
167
168pub struct InventorizedRuntime(pub &'static dyn Runtime);
169
170impl Runtime for InventorizedRuntime {
171 fn name(&self) -> StaticName {
172 self.0.name()
173 }
174
175 fn prepare_with_options(
176 &self,
177 model: TypedModel,
178 options: &RunOptions,
179 ) -> TractResult<Box<dyn Runnable>> {
180 self.0.prepare_with_options(model, options)
181 }
182
183 fn check(&self) -> TractResult<()> {
184 self.0.check()
185 }
186}
187
188impl Debug for InventorizedRuntime {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 self.0.fmt(f)
191 }
192}
193
194inventory::collect!(InventorizedRuntime);
195
196pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
197 inventory::iter::<InventorizedRuntime>().filter(|rt| rt.check().is_ok()).map(|ir| ir.0)
198}
199
200pub fn runtime_for_name(s: &str) -> Option<&'static dyn Runtime> {
201 runtimes().find(|rt| rt.name() == s)
202}
203
/// Registers a runtime so that `runtimes()` and `runtime_for_name()` can find it.
///
/// Usage: `register_runtime!(MyRuntime = MyRuntime);`. The value is stored in a `static`,
/// so it must be a constant expression of the given type, and that type must implement
/// [`Runtime`]. The invoking crate must have the `inventory` crate available under that name.
///
/// The generated items live inside an anonymous `const _: () = { ... };` block. Items
/// created by `macro_rules!` are not hygienic, so without that block the internal `static`
/// would clash with a second invocation (or any item named `D`) in the same module.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        const _: () = {
            static D: $type = $val;
            inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
        };
    };
}
211
// Registers the built-in runtime. Its `check` always succeeds, so `runtimes()` always yields
// it and `runtime_for_name("default")` finds it.
register_runtime!(DefaultRuntime = DefaultRuntime);