tract_hir/infer/model.rs

1use std::collections::HashMap;
2
3use tract_core::ops::konst::Const;
4
5use super::factoid::Factoid;
6use super::{InferenceFact, InferenceModel, InferenceNode, InferenceOp};
7use crate::internal::*;
8use crate::prelude::TVec;
9
10pub trait InferenceModelExt {
11    /// Analyse all nodes of the graph.
12    ///
13    /// Will stop on first error unless `obstinate` is `true`.
14    fn analyse(&mut self, obstinate: bool) -> TractResult<bool>;
15
16    /// Perform early transformation before going typed.
17    fn incorporate(self) -> TractResult<InferenceModel>;
18
19    /// List OutletId with incomplete type information.
20    ///
21    /// Will stop on first error unless `obstinate` is `true`.
22    fn missing_type_shape(&self) -> TractResult<Vec<OutletId>>;
23
24    /// Eliminate seemingly dead branches of the graph.
25    ///
26    /// This may break stateful networks.
27    fn eliminate_dead_branches(self) -> TractResult<InferenceModel>;
28
29    /// Attempt full analyse and conversion to TypedModel.
30    fn into_typed(self) -> TractResult<TypedModel>;
31
32    /// Attempt full analyse, decluttering and mapping to optimized operations.
33    ///
34    /// This will work even if the network can not be normalized.
35    fn into_optimized(self) -> TractResult<TypedModel>;
36}
37
38impl InferenceModelExt for InferenceModel {
39    /// Analyse all nodes of the graph.
40    ///
41    /// Will stop on first error unless `obstinate` is `true`.
42    fn analyse(&mut self, obstinate: bool) -> TractResult<bool> {
43        super::analyser::Analyser::new(self).analyse_obstinate(obstinate)
44    }
45
46    /// Perform early transformation before going typed.
47    fn incorporate(self) -> TractResult<InferenceModel> {
48        let mut model = self;
49        loop {
50            let mut done_something = false;
51            for p in crate::infer::optim::incorporate() {
52                done_something = done_something || p.pass(&mut model)?;
53                if cfg!(debug_assertions) {
54                    model.check_edges()?;
55                }
56            }
57            if !done_something {
58                break;
59            }
60        }
61        model = model.into_compact()?;
62        model.analyse(false)?;
63        Ok(model)
64    }
65
66    /// List OutletId with incomplete type information.
67    ///
68    /// Will stop on first error unless `obstinate` is `true`.
69    fn missing_type_shape(&self) -> TractResult<Vec<OutletId>> {
70        Ok(self
71            .eval_order()?
72            .iter()
73            .flat_map(|&node| {
74                self.nodes()[node]
75                    .outputs
76                    .iter()
77                    .enumerate()
78                    .map(move |(ix, outlet)| (OutletId::new(node, ix), outlet))
79            })
80            .filter(|(_, o)| !o.fact.datum_type.is_concrete() || !o.fact.shape.is_concrete())
81            .map(|(id, _)| id)
82            .collect())
83    }
84
85    /// Eliminate seemingly dead branches of the graph.
86    ///
87    /// This may break stateful networks.
88    fn eliminate_dead_branches(self) -> TractResult<InferenceModel> {
89        self.into_compact()
90    }
91
92    /// Attempt full analyse and conversion to TypedModel.
93    fn into_typed(mut self) -> TractResult<TypedModel> {
94        use tract_core::internal::translator::Translate;
95
96        self.analyse(false)?;
97        let m = self.incorporate()?;
98
99        #[derive(Debug)]
100        struct ToTypedTranslator;
101        impl Translate<InferenceFact, Box<dyn InferenceOp>, TypedFact, Box<dyn TypedOp>>
102            for ToTypedTranslator
103        {
104            fn translate_node(
105                &self,
106                source: &InferenceModel,
107                node: &InferenceNode,
108                target: &mut TypedModel,
109                mapping: &HashMap<OutletId, OutletId>,
110            ) -> TractResult<TVec<OutletId>> {
111                if node.op.is_stateless()
112                    && source.node_output_facts(node.id)?.iter().all(|f| f.value.is_concrete())
113                {
114                    (0..node.outputs.len())
115                        .map(|ix| {
116                            target.add_const(
117                                format!("{}.{}", node.name, ix),
118                                node.outputs[ix].fact.value.concretize().unwrap(),
119                            )
120                        })
121                        .collect()
122                } else {
123                    let outputs = node.op.to_typed(source, node, target, mapping)?;
124                    for output in &outputs {
125                        let fact = target.outlet_fact(*output)?;
126                        fact.consistent().with_context(|| {
127                            format!(
128                                "Checking oulet fact consistency for {:?}: {:?} after translating {:?}",
129                                output,
130                                fact, node.op,
131                            )
132                        })?;
133                    }
134                    Ok(outputs)
135                }
136            }
137        }
138
139        ToTypedTranslator.translate_model(&m)
140    }
141
142    /// Attempt full analyse, decluttering and mapping to optimized operations.
143    ///
144    /// This is meant for "simple" networks, where no special model
145    /// transformation needs to happen. Aternaltively, use to_typed() and
146    /// manipulate the TypedModel for more control.
147    fn into_optimized(self) -> TractResult<TypedModel> {
148        self.into_typed()?.into_optimized()
149    }
150}
151
152impl SpecialOps<InferenceFact, Box<dyn InferenceOp>> for InferenceModel {
153    fn is_source(op: &Box<dyn InferenceOp>) -> bool {
154        op.as_op().downcast_ref::<crate::ops::source::Source>().is_some()
155    }
156
157    fn create_dummy(&self) -> Box<dyn InferenceOp> {
158        Box::new(tract_core::ops::dummy::Dummy::new())
159    }
160
161    fn create_source(&self, _fact: InferenceFact) -> Box<dyn InferenceOp> {
162        Box::new(crate::ops::source::Source::new())
163    }
164
165    fn wire_node(
166        &mut self,
167        name: impl Into<String>,
168        op: impl Into<Box<dyn InferenceOp>>,
169        inputs: &[OutletId],
170    ) -> TractResult<TVec<OutletId>> {
171        let op = op.into();
172        let output_facts: TVec<InferenceFact> =
173            (0..op.nboutputs()?).map(|_| InferenceFact::default()).collect();
174        let id = self.add_node(name, op, output_facts)?;
175        inputs
176            .iter()
177            .enumerate()
178            .try_for_each(|(ix, i)| self.add_edge(*i, InletId::new(id, ix)))?;
179        Ok(self.node(id).outputs.iter().enumerate().map(|(ix, _)| OutletId::new(id, ix)).collect())
180    }
181
182    fn add_const(
183        &mut self,
184        name: impl Into<String>,
185        v: impl IntoArcTensor,
186    ) -> TractResult<OutletId> {
187        let v = v.into_arc_tensor();
188        for node in &self.nodes {
189            if let Some(op) = node.op_as::<Const>() {
190                if op.val() == &v {
191                    return Ok(node.id.into());
192                }
193            }
194        }
195        let name = name.into();
196        let fact = TypedFact::from(v.clone());
197        self.add_node(name, crate::ops::konst::Const::new(v)?, tvec!(fact.into()))
198            .map(|id| id.into())
199    }
200}
201
#[cfg(test)]
mod test {
    use super::*;

    /// Compile-time check that `InferenceModel` implements `Sync`.
    #[test]
    fn test() {
        fn assert_sync<M: Sync>() {}
        assert_sync::<InferenceModel>();
    }
}