1use std::collections::HashMap;
2
3use tract_core::ops::konst::Const;
4
5use super::factoid::Factoid;
6use super::{InferenceFact, InferenceModel, InferenceNode, InferenceOp};
7use crate::internal::*;
8use crate::prelude::TVec;
9
/// Analysis and conversion entry points for an [`InferenceModel`].
pub trait InferenceModelExt {
    /// Runs the inference analyser over the model's facts.
    ///
    /// For `InferenceModel`, `obstinate` is forwarded unchanged to the analyser's
    /// `analyse_obstinate`. The returned `bool` is whatever that method reports
    /// (presumably "some fact was refined" — TODO confirm).
    fn analyse(&mut self, obstinate: bool) -> TractResult<bool>;

    /// Applies the `incorporate` rewrite passes repeatedly until none of them
    /// reports a change, then compacts and re-analyses the model.
    fn incorporate(self) -> TractResult<InferenceModel>;

    /// Lists, in evaluation order, every outlet whose datum type or shape is
    /// not yet fully concrete.
    fn missing_type_shape(&self) -> TractResult<Vec<OutletId>>;

    /// Removes parts of the graph that do not contribute to the model.
    ///
    /// NOTE(review): for `InferenceModel` this simply delegates to `into_compact`;
    /// presumably that drops nodes unreachable from the outputs — confirm.
    fn eliminate_dead_branches(self) -> TractResult<InferenceModel>;

    /// Analyses and incorporates the model, then translates it node by node
    /// into a [`TypedModel`].
    ///
    /// For `InferenceModel`, stateless nodes whose output values are all known
    /// are replaced by constants instead of being translated.
    fn into_typed(self) -> TractResult<TypedModel>;

    /// Converts to a [`TypedModel`] via `into_typed`, then runs the typed
    /// model's own `into_optimized`.
    fn into_optimized(self) -> TractResult<TypedModel>;
}
37
38impl InferenceModelExt for InferenceModel {
39 fn analyse(&mut self, obstinate: bool) -> TractResult<bool> {
43 super::analyser::Analyser::new(self).analyse_obstinate(obstinate)
44 }
45
46 fn incorporate(self) -> TractResult<InferenceModel> {
48 let mut model = self;
49 loop {
50 let mut done_something = false;
51 for p in crate::infer::optim::incorporate() {
52 done_something = done_something || p.pass(&mut model)?;
53 if cfg!(debug_assertions) {
54 model.check_edges()?;
55 }
56 }
57 if !done_something {
58 break;
59 }
60 }
61 model = model.into_compact()?;
62 model.analyse(false)?;
63 Ok(model)
64 }
65
66 fn missing_type_shape(&self) -> TractResult<Vec<OutletId>> {
70 Ok(self
71 .eval_order()?
72 .iter()
73 .flat_map(|&node| {
74 self.nodes()[node]
75 .outputs
76 .iter()
77 .enumerate()
78 .map(move |(ix, outlet)| (OutletId::new(node, ix), outlet))
79 })
80 .filter(|(_, o)| !o.fact.datum_type.is_concrete() || !o.fact.shape.is_concrete())
81 .map(|(id, _)| id)
82 .collect())
83 }
84
85 fn eliminate_dead_branches(self) -> TractResult<InferenceModel> {
89 self.into_compact()
90 }
91
92 fn into_typed(mut self) -> TractResult<TypedModel> {
94 use tract_core::internal::translator::Translate;
95
96 self.analyse(false)?;
97 let m = self.incorporate()?;
98
99 #[derive(Debug)]
100 struct ToTypedTranslator;
101 impl Translate<InferenceFact, Box<dyn InferenceOp>, TypedFact, Box<dyn TypedOp>>
102 for ToTypedTranslator
103 {
104 fn translate_node(
105 &self,
106 source: &InferenceModel,
107 node: &InferenceNode,
108 target: &mut TypedModel,
109 mapping: &HashMap<OutletId, OutletId>,
110 ) -> TractResult<TVec<OutletId>> {
111 if node.op.is_stateless()
112 && source.node_output_facts(node.id)?.iter().all(|f| f.value.is_concrete())
113 {
114 (0..node.outputs.len())
115 .map(|ix| {
116 target.add_const(
117 format!("{}.{}", node.name, ix),
118 node.outputs[ix].fact.value.concretize().unwrap(),
119 )
120 })
121 .collect()
122 } else {
123 let outputs = node.op.to_typed(source, node, target, mapping)?;
124 for output in &outputs {
125 let fact = target.outlet_fact(*output)?;
126 fact.consistent().with_context(|| {
127 format!(
128 "Checking oulet fact consistency for {:?}: {:?} after translating {:?}",
129 output,
130 fact, node.op,
131 )
132 })?;
133 }
134 Ok(outputs)
135 }
136 }
137 }
138
139 ToTypedTranslator.translate_model(&m)
140 }
141
142 fn into_optimized(self) -> TractResult<TypedModel> {
148 self.into_typed()?.into_optimized()
149 }
150}
151
152impl SpecialOps<InferenceFact, Box<dyn InferenceOp>> for InferenceModel {
153 fn is_source(op: &Box<dyn InferenceOp>) -> bool {
154 op.as_op().downcast_ref::<crate::ops::source::Source>().is_some()
155 }
156
157 fn create_dummy(&self) -> Box<dyn InferenceOp> {
158 Box::new(tract_core::ops::dummy::Dummy::new())
159 }
160
161 fn create_source(&self, _fact: InferenceFact) -> Box<dyn InferenceOp> {
162 Box::new(crate::ops::source::Source::new())
163 }
164
165 fn wire_node(
166 &mut self,
167 name: impl Into<String>,
168 op: impl Into<Box<dyn InferenceOp>>,
169 inputs: &[OutletId],
170 ) -> TractResult<TVec<OutletId>> {
171 let op = op.into();
172 let output_facts: TVec<InferenceFact> =
173 (0..op.nboutputs()?).map(|_| InferenceFact::default()).collect();
174 let id = self.add_node(name, op, output_facts)?;
175 inputs
176 .iter()
177 .enumerate()
178 .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
179 Ok(self.node(id).outputs.iter().enumerate().map(|(ix, _)| OutletId::new(id, ix)).collect())
180 }
181
182 fn add_const(
183 &mut self,
184 name: impl Into<String>,
185 v: impl IntoArcTensor,
186 ) -> TractResult<OutletId> {
187 let v = v.into_arc_tensor();
188 for node in &self.nodes {
189 if let Some(op) = node.op_as::<Const>() {
190 if op.val() == &v {
191 return Ok(node.id.into());
192 }
193 }
194 }
195 let name = name.into();
196 let fact = TypedFact::from(v.clone());
197 self.add_node(name, crate::ops::konst::Const::new(v)?, tvec!(fact.into()))
198 .map(|id| id.into())
199 }
200}
201
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time guard: this test stops building if `InferenceModel` ever
    /// loses its `Sync` implementation.
    #[test]
    fn test() {
        fn require_sync<S: Sync>() {}
        require_sync::<InferenceModel>();
    }
}