tract_core/ops/
submodel.rs1use std::fmt::Debug;
2
3use tract_downcast_rs::Downcast;
4
5use crate::{internal::*, ops::OpStateFreeze};
6
7#[derive(Debug, Clone)]
8pub struct SubmodelOp {
9 pub model: Box<dyn InnerModel>,
10 label: String,
11 decluttered: bool,
12 codegen: bool,
13}
14
15impl PartialEq for SubmodelOp {
16 fn eq(&self, _other: &Self) -> bool {
17 false
18 }
19}
20impl Eq for SubmodelOp {}
21
22impl SubmodelOp {
23 pub fn new(model: Box<dyn InnerModel>, label: &str) -> TractResult<Self> {
24 Ok(Self { model, label: label.to_string(), decluttered: false, codegen: false })
25 }
26
27 pub fn iteration_count(&self, _inputs: &[&TypedFact]) -> Option<TDim> {
28 None
29 }
30
31 pub fn model(&self) -> &TypedModel {
32 self.model.as_typed()
33 }
34
35 pub fn label(&self) -> &str {
36 self.label.as_str()
37 }
38}
39
40impl Op for SubmodelOp {
41 fn name(&self) -> StaticName {
42 "SubmodelOp".into()
43 }
44
45 op_as_typed_op!();
46}
47
48impl EvalOp for SubmodelOp {
49 fn is_stateless(&self) -> bool {
50 false
51 }
52
53 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
54 self.model.state(session, node_id)
55 }
56}
57
58impl TypedOp for SubmodelOp {
59 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
60 let facts = self.model.output_facts(inputs)?;
61 Ok(facts)
62 }
63
64 fn declutter(
65 &self,
66 model: &TypedModel,
67 node: &TypedNode,
68 ) -> TractResult<Option<TypedModelPatch>> {
69 if !self.decluttered {
70 let mut new = self.clone();
71 new.model.declutter()?;
72 new.decluttered = true;
73 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
74 } else {
75 Ok(None)
76 }
77 }
78
79 fn codegen(
80 &self,
81 model: &TypedModel,
82 node: &TypedNode,
83 ) -> TractResult<Option<TypedModelPatch>> {
84 if !self.codegen {
85 let mut new = self.clone();
86 new.model.codegen()?;
87 new.codegen = true;
88 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
89 } else {
90 Ok(None)
91 }
92 }
93
94 as_op!();
95}
96
97pub trait InnerModel: Debug + dyn_clone::DynClone + Downcast + Sync + Send + 'static {
98 #[allow(unused_variables)]
99 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
100 fn is_stateless(&self) -> bool;
101
102 #[allow(unused_variables)]
103 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
104 Ok(None)
105 }
106
107 #[allow(unused_variables)]
108 fn declutter(&mut self) -> TractResult<()>;
109
110 fn codegen(&mut self) -> TractResult<()>;
111
112 fn as_typed(&self) -> &TypedModel;
113}
114
115dyn_clone::clone_trait_object!(InnerModel);
116downcast_rs::impl_downcast!(InnerModel);
117
118impl InnerModel for TypedModel {
119 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
120 let facts = self
121 .output_outlets()?
122 .iter()
123 .map(|outlet| self.outlet_fact(*outlet).cloned())
124 .collect::<TractResult<TVec<_>>>()?;
125 Ok(facts)
126 }
127 fn is_stateless(&self) -> bool {
128 false
129 }
130
131 #[allow(unused_variables)]
132 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
133 let plan = self.clone().into_runnable()?;
134 let state = plan.spawn()?;
135 Ok(Some(Box::new(state)))
136 }
137
138 #[allow(unused_variables)]
139 fn declutter(&mut self) -> TractResult<()> {
140 self.declutter()
141 }
142
143 fn codegen(&mut self) -> TractResult<()> {
144 self.optimize()
145 }
146
147 fn as_typed(&self) -> &TypedModel {
148 self
149 }
150}
151
152pub type TypedModelOpState = TypedSimpleState;
153
154impl OpState for TypedModelOpState {
155 fn eval(
156 &mut self,
157 _session: &mut TurnState,
158 _op: &dyn Op,
159 inputs: TVec<TValue>,
160 ) -> TractResult<TVec<TValue>> {
161 let inference_out = self.run(inputs)?;
162 Ok(inference_out)
163 }
164}
165
166pub type FrozenSubmodelOpState = TypedFrozenSimpleState;
167
168impl FrozenOpState for FrozenSubmodelOpState {
169 fn unfreeze(&self) -> Box<dyn OpState> {
170 Box::new(self.unfreeze())
171 }
172}
173
174impl OpStateFreeze for TypedModelOpState {
175 fn freeze(&self) -> Box<dyn FrozenOpState> {
176 Box::new(self.freeze())
177 }
178}