tract_core/ops/
submodel.rs1use std::fmt::Debug;
2
3use tract_downcast_rs::Downcast;
4
5use crate::{internal::*, ops::OpStateFreeze};
6
/// An op whose computation is a whole wrapped model (see [`InnerModel`]).
///
/// Evaluation is delegated to the wrapped model through `EvalOp::state`.
#[derive(Debug, Clone)]
pub struct SubmodelOp {
    /// The wrapped model.
    pub model: Box<dyn InnerModel>,
    /// Label given at construction (see `SubmodelOp::new`), exposed via `SubmodelOp::label`.
    label: String,
    /// Set once `TypedOp::declutter` has decluttered the inner model, so that the declutter
    /// patch is produced only once.
    decluttered: bool,
    /// Set once `TypedOp::codegen` has run codegen on the inner model, so that the codegen
    /// patch is produced only once.
    codegen: bool,
}
14
15impl SubmodelOp {
16 pub fn new(model: Box<dyn InnerModel>, label: &str) -> TractResult<Self> {
17 Ok(Self { model, label: label.to_string(), decluttered: false, codegen: false })
18 }
19
20 pub fn iteration_count(&self, _inputs: &[&TypedFact]) -> Option<TDim> {
21 None
22 }
23
24 pub fn model(&self) -> &TypedModel {
25 self.model.as_typed()
26 }
27
28 pub fn label(&self) -> &str {
29 self.label.as_str()
30 }
31}
32
33impl Op for SubmodelOp {
34 fn name(&self) -> StaticName {
35 "SubmodelOp".into()
36 }
37
38 op_as_typed_op!();
39}
40
41impl EvalOp for SubmodelOp {
42 fn is_stateless(&self) -> bool {
43 false
44 }
45
46 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
47 self.model.state(session, node_id)
48 }
49}
50
51impl TypedOp for SubmodelOp {
52 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
53 let facts = self.model.output_facts(inputs)?;
54 Ok(facts)
55 }
56
57 fn declutter(
58 &self,
59 model: &TypedModel,
60 node: &TypedNode,
61 ) -> TractResult<Option<TypedModelPatch>> {
62 if !self.decluttered {
63 let mut new = self.clone();
64 new.model.declutter()?;
65 new.decluttered = true;
66 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
67 } else {
68 Ok(None)
69 }
70 }
71
72 fn codegen(
73 &self,
74 model: &TypedModel,
75 node: &TypedNode,
76 ) -> TractResult<Option<TypedModelPatch>> {
77 if !self.codegen {
78 let mut new = self.clone();
79 new.model.codegen()?;
80 new.codegen = true;
81 Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
82 } else {
83 Ok(None)
84 }
85 }
86
87 as_op!();
88}
89
90pub trait InnerModel: Debug + dyn_clone::DynClone + Downcast + Sync + Send + 'static {
91 #[allow(unused_variables)]
92 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
93 fn is_stateless(&self) -> bool;
94
95 #[allow(unused_variables)]
96 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
97 Ok(None)
98 }
99
100 #[allow(unused_variables)]
101 fn declutter(&mut self) -> TractResult<()>;
102
103 fn codegen(&mut self) -> TractResult<()>;
104
105 fn as_typed(&self) -> &TypedModel;
106}
107
// Implements `Clone` for `Box<dyn InnerModel>` (through the `DynClone` supertrait).
// `#[derive(Clone)]` on `SubmodelOp` relies on this.
dyn_clone::clone_trait_object!(InnerModel);
// Adds downcasting helpers (e.g. `downcast_ref`) on `dyn InnerModel` (through the `Downcast`
// supertrait), so the concrete inner-model type can be recovered from the trait object.
downcast_rs::impl_downcast!(InnerModel);
110
111impl InnerModel for TypedModel {
112 fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
113 let facts = self
114 .output_outlets()?
115 .iter()
116 .map(|outlet| self.outlet_fact(*outlet).cloned())
117 .collect::<TractResult<TVec<_>>>()?;
118 Ok(facts)
119 }
120 fn is_stateless(&self) -> bool {
121 false
122 }
123
124 #[allow(unused_variables)]
125 fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
126 let plan = self.clone().into_runnable()?;
127 let state = plan.spawn()?;
128 Ok(Some(Box::new(state)))
129 }
130
131 #[allow(unused_variables)]
132 fn declutter(&mut self) -> TractResult<()> {
133 self.declutter()
134 }
135
136 fn codegen(&mut self) -> TractResult<()> {
137 self.optimize()
138 }
139
140 fn as_typed(&self) -> &TypedModel {
141 self
142 }
143}
144
/// Per-session state of a `TypedModel` wrapped in a `SubmodelOp`: an alias of
/// `TypedSimpleState`, the type produced by spawning the runnable plan.
pub type TypedModelOpState = TypedSimpleState;
146
147impl OpState for TypedModelOpState {
148 fn eval(
149 &mut self,
150 _session: &mut TurnState,
151 _op: &dyn Op,
152 inputs: TVec<TValue>,
153 ) -> TractResult<TVec<TValue>> {
154 let inference_out = self.run(inputs)?;
155 Ok(inference_out)
156 }
157}
158
/// Frozen form of [`TypedModelOpState`]: an alias of `TypedFrozenSimpleState`.
pub type FrozenSubmodelOpState = TypedFrozenSimpleState;
160
161impl FrozenOpState for FrozenSubmodelOpState {
162 fn unfreeze(&self) -> Box<dyn OpState> {
163 Box::new(self.unfreeze())
164 }
165}
166
167impl OpStateFreeze for TypedModelOpState {
168 fn freeze(&self) -> Box<dyn FrozenOpState> {
169 Box::new(self.freeze())
170 }
171}