tract_core/ops/
submodel.rs1use std::fmt::Debug;
2
3use tract_downcast_rs::Downcast;
4
5use crate::{internal::*, ops::OpStateFreeze};
6
7#[derive(Debug, Clone)]
8pub struct SubmodelOp {
9 pub model: Box<dyn InnerModel>,
10 label: String,
11 decluttered: bool,
12 codegen: bool,
13}
14
15impl SubmodelOp {
16 pub fn new(model: Box<dyn InnerModel>, label: &str) -> TractResult<Self> {
17 Ok(Self { model, label: label.to_string(), decluttered: false, codegen: false })
18 }
19
20 pub fn iteration_count(&self, _inputs: &[&TypedFact]) -> Option<TDim> {
21 None
22 }
23
24 pub fn model(&self) -> &TypedModel {
25 self.model.as_typed()
26 }
27
28 pub fn label(&self) -> &str {
29 self.label.as_str()
30 }
31}
32
33impl Op for SubmodelOp {
34 fn name(&self) -> Cow<str> {
35 "SubmodelOp".into()
36 }
37
38 op_as_typed_op!();
39}
40
41impl EvalOp for SubmodelOp {
42 fn is_stateless(&self) -> bool {
43 false
44 }
45
46 fn state(
47 &self,
48 session: &mut SessionState,
49 node_id: usize,
50 ) -> TractResult<Option<Box<dyn OpState>>> {
51 self.model.state(session, node_id)
52 }
53}
54
55impl TypedOp for SubmodelOp {
56 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
57 let facts = self.model.output_facts(inputs)?;
58 Ok(facts)
59 }
60
61 fn declutter(
62 &self,
63 model: &TypedModel,
64 node: &TypedNode,
65 ) -> TractResult<Option<TypedModelPatch>> {
66 if !self.decluttered {
67 let mut new = self.clone();
68 new.model.declutter()?;
69 new.decluttered = true;
70 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
71 } else {
72 Ok(None)
73 }
74 }
75
76 fn codegen(
77 &self,
78 model: &TypedModel,
79 node: &TypedNode,
80 ) -> TractResult<Option<TypedModelPatch>> {
81 if !self.codegen {
82 let mut new = self.clone();
83 new.model.codegen()?;
84 new.codegen = true;
85 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
86 } else {
87 Ok(None)
88 }
89 }
90
91 as_op!();
92}
93
94pub trait InnerModel: Debug + dyn_clone::DynClone + Downcast + Sync + Send + 'static {
95 #[allow(unused_variables)]
96 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
97 fn is_stateless(&self) -> bool;
98
99 #[allow(unused_variables)]
100 fn state(
101 &self,
102 session: &mut SessionState,
103 node_id: usize,
104 ) -> TractResult<Option<Box<dyn OpState>>> {
105 Ok(None)
106 }
107
108 #[allow(unused_variables)]
109 fn declutter(&mut self) -> TractResult<()>;
110
111 fn codegen(&mut self) -> TractResult<()>;
112
113 fn as_typed(&self) -> &TypedModel;
114}
115
116dyn_clone::clone_trait_object!(InnerModel);
117downcast_rs::impl_downcast!(InnerModel);
118
119impl InnerModel for TypedModel {
120 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
121 let facts = self
122 .output_outlets()?
123 .iter()
124 .map(|outlet| self.outlet_fact(*outlet).cloned())
125 .collect::<TractResult<TVec<_>>>()?;
126 Ok(facts)
127 }
128 fn is_stateless(&self) -> bool {
129 false
130 }
131
132 #[allow(unused_variables)]
133 fn state(
134 &self,
135 session: &mut SessionState,
136 node_id: usize,
137 ) -> TractResult<Option<Box<dyn OpState>>> {
138 let plan = SimplePlan::new(self.clone())?;
139 let state = SimpleState::new(Arc::new(plan))?;
140 Ok(Some(Box::new(state)))
141 }
142
143 #[allow(unused_variables)]
144 fn declutter(&mut self) -> TractResult<()> {
145 self.declutter()
146 }
147
148 fn codegen(&mut self) -> TractResult<()> {
149 self.optimize()
150 }
151
152 fn as_typed(&self) -> &TypedModel {
153 self
154 }
155}
156
157pub type TypedModelOpState = TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>;
158
159impl OpState for TypedModelOpState {
160 fn eval(
161 &mut self,
162 _session: &mut SessionState,
163 _op: &dyn Op,
164 inputs: TVec<TValue>,
165 ) -> TractResult<TVec<TValue>> {
166 let inference_out = self.run(inputs)?;
167 Ok(inference_out)
168 }
169}
170
171pub type FrozenSubmodelOpState =
172 TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>;
173
174impl FrozenOpState for FrozenSubmodelOpState {
175 fn unfreeze(&self) -> Box<dyn OpState> {
176 Box::new(self.unfreeze())
177 }
178}
179
180impl OpStateFreeze for TypedModelOpState {
181 fn freeze(&self) -> Box<dyn FrozenOpState> {
182 Box::new(self.freeze())
183 }
184}