Skip to main content

tract_core/ops/
submodel.rs

1use std::fmt::Debug;
2
3use tract_downcast_rs::Downcast;
4
5use crate::{internal::*, ops::OpStateFreeze};
6
/// Operator wrapping a boxed [`InnerModel`] together with a label and two
/// "already processed" flags used by the `declutter` and `codegen` passes.
#[derive(Debug, Clone)]
pub struct SubmodelOp {
    /// The wrapped model, type-erased behind the [`InnerModel`] trait.
    pub model: Box<dyn InnerModel>,
    /// Label supplied to `SubmodelOp::new`; readable through `label()`.
    label: String,
    /// `true` once `declutter` has run on the inner model, so the pass is not repeated.
    decluttered: bool,
    /// `true` once `codegen` has run on the inner model, so the pass is not repeated.
    codegen: bool,
}
14
15impl PartialEq for SubmodelOp {
16    fn eq(&self, _other: &Self) -> bool {
17        false
18    }
19}
20impl Eq for SubmodelOp {}
21
22impl SubmodelOp {
23    pub fn new(model: Box<dyn InnerModel>, label: &str) -> TractResult<Self> {
24        Ok(Self { model, label: label.to_string(), decluttered: false, codegen: false })
25    }
26
27    pub fn iteration_count(&self, _inputs: &[&TypedFact]) -> Option<TDim> {
28        None
29    }
30
31    pub fn model(&self) -> &TypedModel {
32        self.model.as_typed()
33    }
34
35    pub fn label(&self) -> &str {
36        self.label.as_str()
37    }
38}
39
40impl Op for SubmodelOp {
41    fn name(&self) -> StaticName {
42        "SubmodelOp".into()
43    }
44
45    op_as_typed_op!();
46}
47
48impl EvalOp for SubmodelOp {
49    fn is_stateless(&self) -> bool {
50        false
51    }
52
53    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
54        self.model.state(session, node_id)
55    }
56}
57
58impl TypedOp for SubmodelOp {
59    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
60        let facts = self.model.output_facts(inputs)?;
61        Ok(facts)
62    }
63
64    fn declutter(
65        &self,
66        model: &TypedModel,
67        node: &TypedNode,
68    ) -> TractResult<Option<TypedModelPatch>> {
69        if !self.decluttered {
70            let mut new = self.clone();
71            new.model.declutter()?;
72            new.decluttered = true;
73            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
74        } else {
75            Ok(None)
76        }
77    }
78
79    fn codegen(
80        &self,
81        model: &TypedModel,
82        node: &TypedNode,
83    ) -> TractResult<Option<TypedModelPatch>> {
84        if !self.codegen {
85            let mut new = self.clone();
86            new.model.codegen()?;
87            new.codegen = true;
88            Ok(Some(TypedModelPatch::replace_single_op(model, node, &node.inputs, new)?))
89        } else {
90            Ok(None)
91        }
92    }
93
94    as_op!();
95}
96
97pub trait InnerModel: Debug + dyn_clone::DynClone + Downcast + Sync + Send + 'static {
98    #[allow(unused_variables)]
99    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
100    fn is_stateless(&self) -> bool;
101
102    #[allow(unused_variables)]
103    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
104        Ok(None)
105    }
106
107    #[allow(unused_variables)]
108    fn declutter(&mut self) -> TractResult<()>;
109
110    fn codegen(&mut self) -> TractResult<()>;
111
112    fn as_typed(&self) -> &TypedModel;
113}
114
// Makes `Box<dyn InnerModel>` clonable; `SubmodelOp` derives `Clone` and holds
// such a box, so it depends on this.
dyn_clone::clone_trait_object!(InnerModel);
// Generates the downcasting helpers for `dyn InnerModel` trait objects
// (see the downcast-rs documentation for the exact set of methods).
downcast_rs::impl_downcast!(InnerModel);
117
118impl InnerModel for TypedModel {
119    fn output_facts(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
120        let facts = self
121            .output_outlets()?
122            .iter()
123            .map(|outlet| self.outlet_fact(*outlet).cloned())
124            .collect::<TractResult<TVec<_>>>()?;
125        Ok(facts)
126    }
127    fn is_stateless(&self) -> bool {
128        false
129    }
130
131    #[allow(unused_variables)]
132    fn state(&self, session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
133        let plan = self.clone().into_runnable()?;
134        let state = plan.spawn()?;
135        Ok(Some(Box::new(state)))
136    }
137
138    #[allow(unused_variables)]
139    fn declutter(&mut self) -> TractResult<()> {
140        self.declutter()
141    }
142
143    fn codegen(&mut self) -> TractResult<()> {
144        self.optimize()
145    }
146
147    fn as_typed(&self) -> &TypedModel {
148        self
149    }
150}
151
152pub type TypedModelOpState = TypedSimpleState;
153
154impl OpState for TypedModelOpState {
155    fn eval(
156        &mut self,
157        _session: &mut TurnState,
158        _op: &dyn Op,
159        inputs: TVec<TValue>,
160    ) -> TractResult<TVec<TValue>> {
161        let inference_out = self.run(inputs)?;
162        Ok(inference_out)
163    }
164}
165
166pub type FrozenSubmodelOpState = TypedFrozenSimpleState;
167
168impl FrozenOpState for FrozenSubmodelOpState {
169    fn unfreeze(&self) -> Box<dyn OpState> {
170        Box::new(self.unfreeze())
171    }
172}
173
174impl OpStateFreeze for TypedModelOpState {
175    fn freeze(&self) -> Box<dyn FrozenOpState> {
176        Box::new(self.freeze())
177    }
178}