Skip to main content

tract_core/ops/
submodel.rs

1use std::fmt::Debug;
2
3use tract_downcast_rs::Downcast;
4
5use crate::{internal::*, ops::OpStateFreeze};
6
/// Operator that wraps a whole inner model (any [`InnerModel`]) as a single node.
///
/// The two boolean flags record which one-shot transformations have already been applied to
/// the inner model, so the outer optimizer does not re-apply them forever.
#[derive(Debug, Clone)]
pub struct SubmodelOp {
    /// The wrapped model (cloneable through `dyn_clone`).
    pub model: Box<dyn InnerModel>,
    /// Free-form label, exposed through `SubmodelOp::label`.
    label: String,
    /// Set once the inner model has been decluttered by the `TypedOp::declutter` implementation.
    decluttered: bool,
    /// Set once the inner model has gone through the `TypedOp::codegen` implementation.
    codegen: bool,
}
14
15impl SubmodelOp {
16    pub fn new(model: Box<dyn InnerModel>, label: &str) -> TractResult<Self> {
17        Ok(Self { model, label: label.to_string(), decluttered: false, codegen: false })
18    }
19
20    pub fn iteration_count(&self, _inputs: &[&TypedFact]) -> Option<TDim> {
21        None
22    }
23
24    pub fn model(&self) -> &TypedModel {
25        self.model.as_typed()
26    }
27
28    pub fn label(&self) -> &str {
29        self.label.as_str()
30    }
31}
32
33impl Op for SubmodelOp {
34    fn name(&self) -> StaticName {
35        "SubmodelOp".into()
36    }
37
38    op_as_typed_op!();
39}
40
41impl EvalOp for SubmodelOp {
42    fn is_stateless(&self) -> bool {
43        false
44    }
45
46    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
47        self.model.state(session, node_id)
48    }
49}
50
51impl TypedOp for SubmodelOp {
52    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
53        let facts = self.model.output_facts(inputs)?;
54        Ok(facts)
55    }
56
57    fn declutter(
58        &self,
59        model: &TypedModel,
60        node: &TypedNode,
61    ) -> TractResult<Option<TypedModelPatch>> {
62        if !self.decluttered {
63            let mut new = self.clone();
64            new.model.declutter()?;
65            new.decluttered = true;
66            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
67        } else {
68            Ok(None)
69        }
70    }
71
72    fn codegen(
73        &self,
74        model: &TypedModel,
75        node: &TypedNode,
76    ) -> TractResult<Option<TypedModelPatch>> {
77        if !self.codegen {
78            let mut new = self.clone();
79            new.model.codegen()?;
80            new.codegen = true;
81            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
82        } else {
83            Ok(None)
84        }
85    }
86
87    as_op!();
88}
89
90pub trait InnerModel: Debug + dyn_clone::DynClone + Downcast + Sync + Send + 'static {
91    #[allow(unused_variables)]
92    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
93    fn is_stateless(&self) -> bool;
94
95    #[allow(unused_variables)]
96    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
97        Ok(None)
98    }
99
100    #[allow(unused_variables)]
101    fn declutter(&mut self) -> TractResult<()>;
102
103    fn codegen(&mut self) -> TractResult<()>;
104
105    fn as_typed(&self) -> &TypedModel;
106}
107
// Implements `Clone` for `Box<dyn InnerModel>` through the `DynClone` supertrait.
// `SubmodelOp`'s `#[derive(Clone)]` needs this for its `model` field.
dyn_clone::clone_trait_object!(InnerModel);
// Adds downcasting helpers on `dyn InnerModel`, so callers can recover the concrete model type.
// NOTE(review): the `Downcast` supertrait is imported from `tract_downcast_rs`, while this macro
// comes from `downcast_rs`. Confirm both names refer to the same crate.
downcast_rs::impl_downcast!(InnerModel);
110
111impl InnerModel for TypedModel {
112    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
113        let facts = self
114            .output_outlets()?
115            .iter()
116            .map(|outlet| self.outlet_fact(*outlet).cloned())
117            .collect::<TractResult<TVec<_>>>()?;
118        Ok(facts)
119    }
120    fn is_stateless(&self) -> bool {
121        false
122    }
123
124    #[allow(unused_variables)]
125    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
126        let plan = self.clone().into_runnable()?;
127        let state = plan.spawn()?;
128        Ok(Some(Box::new(state)))
129    }
130
131    #[allow(unused_variables)]
132    fn declutter(&mut self) -> TractResult<()> {
133        self.declutter()
134    }
135
136    fn codegen(&mut self) -> TractResult<()> {
137        self.optimize()
138    }
139
140    fn as_typed(&self) -> &TypedModel {
141        self
142    }
143}
144
145pub type TypedModelOpState = TypedSimpleState;
146
147impl OpState for TypedModelOpState {
148    fn eval(
149        &mut self,
150        _session: &mut TurnState,
151        _op: &dyn Op,
152        inputs: TVec<TValue>,
153    ) -> TractResult<TVec<TValue>> {
154        let inference_out = self.run(inputs)?;
155        Ok(inference_out)
156    }
157}
158
159pub type FrozenSubmodelOpState = TypedFrozenSimpleState;
160
161impl FrozenOpState for FrozenSubmodelOpState {
162    fn unfreeze(&self) -> Box<dyn OpState> {
163        Box::new(self.unfreeze())
164    }
165}
166
167impl OpStateFreeze for TypedModelOpState {
168    fn freeze(&self) -> Box<dyn FrozenOpState> {
169        Box::new(self.freeze())
170    }
171}