tract_core/ops/submodel.rs

1use std::fmt::Debug;
2
3use tract_downcast_rs::Downcast;
4
5use crate::{internal::*, ops::OpStateFreeze};
6
7#[derive(Debug, Clone)]
8pub struct SubmodelOp {
9    pub model: Box<dyn InnerModel>,
10    label: String,
11    decluttered: bool,
12    codegen: bool,
13}
14
15impl SubmodelOp {
16    pub fn new(model: Box<dyn InnerModel>, label: &str) -> TractResult<Self> {
17        Ok(Self { model, label: label.to_string(), decluttered: false, codegen: false })
18    }
19
20    pub fn iteration_count(&self, _inputs: &[&TypedFact]) -> Option<TDim> {
21        None
22    }
23
24    pub fn model(&self) -> &TypedModel {
25        self.model.as_typed()
26    }
27
28    pub fn label(&self) -> &str {
29        self.label.as_str()
30    }
31}
32
33impl Op for SubmodelOp {
34    fn name(&self) -> Cow<str> {
35        "SubmodelOp".into()
36    }
37
38    op_as_typed_op!();
39}
40
41impl EvalOp for SubmodelOp {
42    fn is_stateless(&self) -> bool {
43        false
44    }
45
46    fn state(
47        &self,
48        session: &mut SessionState,
49        node_id: usize,
50    ) -> TractResult<Option<Box<dyn OpState>>> {
51        self.model.state(session, node_id)
52    }
53}
54
55impl TypedOp for SubmodelOp {
56    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
57        let facts = self.model.output_facts(inputs)?;
58        Ok(facts)
59    }
60
61    fn declutter(
62        &self,
63        model: &TypedModel,
64        node: &TypedNode,
65    ) -> TractResult<Option<TypedModelPatch>> {
66        if !self.decluttered {
67            let mut new = self.clone();
68            new.model.declutter()?;
69            new.decluttered = true;
70            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
71        } else {
72            Ok(None)
73        }
74    }
75
76    fn codegen(
77        &self,
78        model: &TypedModel,
79        node: &TypedNode,
80    ) -> TractResult<Option<TypedModelPatch>> {
81        if !self.codegen {
82            let mut new = self.clone();
83            new.model.codegen()?;
84            new.codegen = true;
85            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
86        } else {
87            Ok(None)
88        }
89    }
90
91    as_op!();
92}
93
94pub trait InnerModel: Debug + dyn_clone::DynClone + Downcast + Sync + Send + 'static {
95    #[allow(unused_variables)]
96    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
97    fn is_stateless(&self) -> bool;
98
99    #[allow(unused_variables)]
100    fn state(
101        &self,
102        session: &mut SessionState,
103        node_id: usize,
104    ) -> TractResult<Option<Box<dyn OpState>>> {
105        Ok(None)
106    }
107
108    #[allow(unused_variables)]
109    fn declutter(&mut self) -> TractResult<()>;
110
111    fn codegen(&mut self) -> TractResult<()>;
112
113    fn as_typed(&self) -> &TypedModel;
114}
115
116dyn_clone::clone_trait_object!(InnerModel);
117downcast_rs::impl_downcast!(InnerModel);
118
119impl InnerModel for TypedModel {
120    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
121        let facts = self
122            .output_outlets()?
123            .iter()
124            .map(|outlet| self.outlet_fact(*outlet).cloned())
125            .collect::<TractResult<TVec<_>>>()?;
126        Ok(facts)
127    }
128    fn is_stateless(&self) -> bool {
129        false
130    }
131
132    #[allow(unused_variables)]
133    fn state(
134        &self,
135        session: &mut SessionState,
136        node_id: usize,
137    ) -> TractResult<Option<Box<dyn OpState>>> {
138        let plan = SimplePlan::new(self.clone())?;
139        let state = SimpleState::new(Arc::new(plan))?;
140        Ok(Some(Box::new(state)))
141    }
142
143    #[allow(unused_variables)]
144    fn declutter(&mut self) -> TractResult<()> {
145        self.declutter()
146    }
147
148    fn codegen(&mut self) -> TractResult<()> {
149        self.optimize()
150    }
151
152    fn as_typed(&self) -> &TypedModel {
153        self
154    }
155}
156
157pub type TypedModelOpState = TypedSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>;
158
159impl OpState for TypedModelOpState {
160    fn eval(
161        &mut self,
162        _session: &mut SessionState,
163        _op: &dyn Op,
164        inputs: TVec<TValue>,
165    ) -> TractResult<TVec<TValue>> {
166        let inference_out = self.run(inputs)?;
167        Ok(inference_out)
168    }
169}
170
171pub type FrozenSubmodelOpState =
172    TypedFrozenSimpleState<TypedModel, Arc<TypedSimplePlan<TypedModel>>>;
173
174impl FrozenOpState for FrozenSubmodelOpState {
175    fn unfreeze(&self) -> Box<dyn OpState> {
176        Box::new(self.unfreeze())
177    }
178}
179
180impl OpStateFreeze for TypedModelOpState {
181    fn freeze(&self) -> Box<dyn FrozenOpState> {
182        Box::new(self.freeze())
183    }
184}