// tract_onnx/ops/logic.rs

1use crate::model::OnnxOpRegister;
2use crate::model::ParseResult;
3use crate::model::ParsingContext;
4use crate::pb::NodeProto;
5use tract_core::ops;
6use tract_hir::internal::*;
7use tract_hir::ops::logic::Comp;
8use tract_itertools::Itertools;
9
10pub fn register_all_ops(reg: &mut OnnxOpRegister) {
11    reg.insert("Not", |_, _| Ok((ops::logic::not().into_hir(), vec![])));
12    reg.insert("And", |_, _| Ok((ops::logic::And.into_hir(), vec![])));
13    reg.insert("Or", |_, _| Ok((ops::logic::Or.into_hir(), vec![])));
14    reg.insert("Xor", |_, _| Ok((ops::logic::Xor.into_hir(), vec![])));
15
16    reg.insert("Equal", |_, _| Ok((expand(Comp::Eq), vec![])));
17    reg.insert("Greater", |_, _| Ok((expand(Comp::GT), vec![])));
18    reg.insert("Less", |_, _| Ok((expand(Comp::LT), vec![])));
19    reg.insert("LessOrEqual", |_, _| Ok((expand(Comp::LTE), vec![])));
20    reg.insert("GreaterOrEqual", |_, _| Ok((expand(Comp::GTE), vec![])));
21
22    reg.insert("Where", |_, _| Ok((expand(tract_hir::ops::logic::Iff), vec![])));
23
24    reg.insert("If", _if)
25}
26
27pub fn _if(
28    ctx: &ParsingContext,
29    node: &NodeProto,
30) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
31    let graph_then = node.get_attr("then_branch")?;
32    let graph_else = node.get_attr("else_branch")?;
33    let ParseResult { model: then_body, unresolved_inputs: unresolved_inputs_then, .. } =
34        ctx.parse_graph(graph_then)?;
35    let ParseResult { model: else_body, unresolved_inputs: unresolved_inputs_else, .. } =
36        ctx.parse_graph(graph_else)?;
37    let unresolved_inputs: Vec<String> = unresolved_inputs_then
38        .iter()
39        .chain(unresolved_inputs_else.iter())
40        .sorted()
41        .unique()
42        .cloned()
43        .collect();
44    let then_input_mapping = unresolved_inputs_then
45        .iter()
46        .map(|i| unresolved_inputs.iter().position(|s| s == i).unwrap() + 1)
47        .collect();
48    let else_input_mapping = unresolved_inputs_else
49        .iter()
50        .map(|i| unresolved_inputs.iter().position(|s| s == i).unwrap() + 1)
51        .collect();
52    Ok((
53        Box::new(If { then_body, then_input_mapping, else_body, else_input_mapping }),
54        unresolved_inputs,
55    ))
56}
57
/// Inference-level ONNX `If` operator.
///
/// Outer input 0 is the boolean condition. When it is true `then_body` is run, otherwise
/// `else_body`. Each branch receives only the outer inputs listed in its input mapping.
// NOTE(review): `new` is derived, so field order presumably defines the constructor's
// argument order — keep it stable.
#[derive(Debug, Clone, new)]
pub struct If {
    /// Sub-model evaluated when the condition is true.
    pub then_body: InferenceModel,
    /// For each input of `then_body`, in order, the index of the outer node input that feeds it.
    /// As built by `_if`, these start at 1 because index 0 is the condition.
    then_input_mapping: Vec<usize>,
    /// Sub-model evaluated when the condition is false.
    pub else_body: InferenceModel,
    /// For each input of `else_body`, in order, the index of the outer node input that feeds it.
    /// As built by `_if`, these start at 1 because index 0 is the condition.
    else_input_mapping: Vec<usize>,
}
65
66impl Op for If {
67    fn name(&self) -> StaticName {
68        "If".into()
69    }
70
71    not_a_typed_op!();
72}
73
74impl EvalOp for If {
75    fn is_stateless(&self) -> bool {
76        true
77    }
78
79    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
80        let cond = inputs[0].cast_to_scalar::<bool>()?;
81        let (input_mapping, body) = if cond {
82            (&self.then_input_mapping, &self.then_body)
83        } else {
84            (&self.else_input_mapping, &self.else_body)
85        };
86        let inputs: TVec<TValue> = input_mapping.iter().map(|&ix| inputs[ix].clone()).collect();
87        body.clone().into_runnable()?.run(inputs)
88    }
89}
90
91impl InferenceOp for If {
92    fn infer_facts(
93        &mut self,
94        inputs: TVec<&InferenceFact>,
95        outputs: TVec<&InferenceFact>,
96        observed: TVec<&InferenceFact>,
97    ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
98        let mut inputs: TVec<InferenceFact> = inputs.into_iter().cloned().collect();
99        let mut outputs: TVec<InferenceFact> = outputs.into_iter().cloned().collect();
100        loop {
101            let mut changed = false;
102            changed = changed || inputs[0].datum_type.unify_with(&bool::datum_type().into())?;
103            for (body_ix, outer_ix) in self.then_input_mapping.iter().enumerate() {
104                changed = changed
105                    || self
106                        .then_body
107                        .input_fact_mut(body_ix)?
108                        .unify_with_mut(&mut inputs[*outer_ix])?;
109            }
110            for (body_ix, outer_ix) in self.else_input_mapping.iter().enumerate() {
111                changed = changed
112                    || self
113                        .else_body
114                        .input_fact_mut(body_ix)?
115                        .unify_with_mut(&mut inputs[*outer_ix])?;
116            }
117            if let Some(a) = inputs[0].value.concretize() {
118                let a = a.cast_to_scalar()?;
119                let body = if a { &mut self.then_body } else { &mut self.else_body };
120                for oix in 0..body.output_outlets()?.len() {
121                    changed =
122                        changed || body.output_fact_mut(oix)?.unify_with_mut(&mut outputs[oix])?;
123                }
124            } else {
125                for ix in 0..self.nboutputs()? {
126                    changed = changed
127                        || self
128                            .then_body
129                            .output_fact_mut(ix)?
130                            .shape
131                            .unify_with_mut(&mut outputs[ix].shape)?
132                        || self
133                            .else_body
134                            .output_fact_mut(ix)?
135                            .shape
136                            .unify_with_mut(&mut outputs[ix].shape)?
137                        || self
138                            .then_body
139                            .output_fact_mut(ix)?
140                            .datum_type
141                            .unify_with_mut(&mut outputs[ix].datum_type)?
142                        || self
143                            .else_body
144                            .output_fact_mut(ix)?
145                            .datum_type
146                            .unify_with_mut(&mut outputs[ix].datum_type)?;
147                }
148            }
149            changed = changed || self.then_body.analyse(false)?;
150            changed = changed || self.else_body.analyse(false)?;
151            if !changed {
152                return Ok((inputs, outputs, observed.into_iter().cloned().collect()));
153            }
154        }
155    }
156
157    fn nboutputs(&self) -> TractResult<usize> {
158        let then_outputs = self.then_body.outputs.len();
159        let else_outputs = self.else_body.outputs.len();
160        ensure!(then_outputs == else_outputs, "If Operators expect the `then_branch` {} and `else_branch` {} to produce the same number of outputs", then_outputs, else_outputs);
161        Ok(then_outputs)
162    }
163
164    fn to_typed(
165        &self,
166        _source: &InferenceModel,
167        node: &InferenceNode,
168        target: &mut TypedModel,
169        mapping: &HashMap<OutletId, OutletId>,
170    ) -> TractResult<TVec<OutletId>> {
171        let then_body = self.then_body.clone().into_typed()?;
172        let else_body = self.else_body.clone().into_typed()?;
173        let inputs: TVec<_> = node.inputs.iter().map(|o| mapping[o]).collect();
174        let op = tract_core::ops::logic::IfThenElse {
175            then_body,
176            else_body,
177            then_input_mapping: self.then_input_mapping.clone(),
178            else_input_mapping: self.else_input_mapping.clone(),
179        };
180        target.wire_node(self.name(), op, &inputs)
181    }
182
183    as_op!();
184}