1use crate::model::OnnxOpRegister;
2use crate::model::ParseResult;
3use crate::model::ParsingContext;
4use crate::pb::NodeProto;
5use tract_core::ops;
6use tract_hir::internal::*;
7use tract_hir::ops::logic::Comp;
8use tract_itertools::Itertools;
9
10pub fn register_all_ops(reg: &mut OnnxOpRegister) {
11 reg.insert("Not", |_, _| Ok((ops::logic::not().into_hir(), vec![])));
12 reg.insert("And", |_, _| Ok((ops::logic::And.into_hir(), vec![])));
13 reg.insert("Or", |_, _| Ok((ops::logic::Or.into_hir(), vec![])));
14 reg.insert("Xor", |_, _| Ok((ops::logic::Xor.into_hir(), vec![])));
15
16 reg.insert("Equal", |_, _| Ok((expand(Comp::Eq), vec![])));
17 reg.insert("Greater", |_, _| Ok((expand(Comp::GT), vec![])));
18 reg.insert("Less", |_, _| Ok((expand(Comp::LT), vec![])));
19 reg.insert("LessOrEqual", |_, _| Ok((expand(Comp::LTE), vec![])));
20 reg.insert("GreaterOrEqual", |_, _| Ok((expand(Comp::GTE), vec![])));
21
22 reg.insert("Where", |_, _| Ok((expand(tract_hir::ops::logic::Iff), vec![])));
23
24 reg.insert("If", _if)
25}
26
27pub fn _if(
28 ctx: &ParsingContext,
29 node: &NodeProto,
30) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
31 let graph_then = node.get_attr("then_branch")?;
32 let graph_else = node.get_attr("else_branch")?;
33 let ParseResult { model: then_body, unresolved_inputs: unresolved_inputs_then, .. } =
34 ctx.parse_graph(graph_then)?;
35 let ParseResult { model: else_body, unresolved_inputs: unresolved_inputs_else, .. } =
36 ctx.parse_graph(graph_else)?;
37 let unresolved_inputs: Vec<String> = unresolved_inputs_then
38 .iter()
39 .chain(unresolved_inputs_else.iter())
40 .sorted()
41 .unique()
42 .cloned()
43 .collect();
44 let then_input_mapping = unresolved_inputs_then
45 .iter()
46 .map(|i| unresolved_inputs.iter().position(|s| s == i).unwrap() + 1)
47 .collect();
48 let else_input_mapping = unresolved_inputs_else
49 .iter()
50 .map(|i| unresolved_inputs.iter().position(|s| s == i).unwrap() + 1)
51 .collect();
52 Ok((
53 Box::new(If { then_body, then_input_mapping, else_body, else_input_mapping }),
54 unresolved_inputs,
55 ))
56}
57
/// Inference-level representation of the ONNX `If` operator.
///
/// Outer input 0 carries the condition. The remaining outer inputs are the
/// values that the branch bodies leave unresolved and read from the enclosing graph.
#[derive(Debug, Clone, new)]
pub struct If {
    /// Sub-graph evaluated when the condition is true.
    pub then_body: InferenceModel,
    /// For each input of `then_body` (by position), the index of the outer
    /// input that feeds it.
    then_input_mapping: Vec<usize>,
    /// Sub-graph evaluated when the condition is false.
    pub else_body: InferenceModel,
    /// For each input of `else_body` (by position), the index of the outer
    /// input that feeds it.
    else_input_mapping: Vec<usize>,
}
65
66impl Op for If {
67 fn name(&self) -> StaticName {
68 "If".into()
69 }
70
71 not_a_typed_op!();
72}
73
74impl EvalOp for If {
75 fn is_stateless(&self) -> bool {
76 true
77 }
78
79 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
80 let cond = inputs[0].cast_to_scalar::<bool>()?;
81 let (input_mapping, body) = if cond {
82 (&self.then_input_mapping, &self.then_body)
83 } else {
84 (&self.else_input_mapping, &self.else_body)
85 };
86 let inputs: TVec<TValue> = input_mapping.iter().map(|&ix| inputs[ix].clone()).collect();
87 body.clone().into_runnable()?.run(inputs)
88 }
89}
90
91impl InferenceOp for If {
92 fn infer_facts(
93 &mut self,
94 inputs: TVec<&InferenceFact>,
95 outputs: TVec<&InferenceFact>,
96 observed: TVec<&InferenceFact>,
97 ) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
98 let mut inputs: TVec<InferenceFact> = inputs.into_iter().cloned().collect();
99 let mut outputs: TVec<InferenceFact> = outputs.into_iter().cloned().collect();
100 loop {
101 let mut changed = false;
102 changed = changed || inputs[0].datum_type.unify_with(&bool::datum_type().into())?;
103 for (body_ix, outer_ix) in self.then_input_mapping.iter().enumerate() {
104 changed = changed
105 || self
106 .then_body
107 .input_fact_mut(body_ix)?
108 .unify_with_mut(&mut inputs[*outer_ix])?;
109 }
110 for (body_ix, outer_ix) in self.else_input_mapping.iter().enumerate() {
111 changed = changed
112 || self
113 .else_body
114 .input_fact_mut(body_ix)?
115 .unify_with_mut(&mut inputs[*outer_ix])?;
116 }
117 if let Some(a) = inputs[0].value.concretize() {
118 let a = a.cast_to_scalar()?;
119 let body = if a { &mut self.then_body } else { &mut self.else_body };
120 for oix in 0..body.output_outlets()?.len() {
121 changed =
122 changed || body.output_fact_mut(oix)?.unify_with_mut(&mut outputs[oix])?;
123 }
124 } else {
125 for ix in 0..self.nboutputs()? {
126 changed = changed
127 || self
128 .then_body
129 .output_fact_mut(ix)?
130 .shape
131 .unify_with_mut(&mut outputs[ix].shape)?
132 || self
133 .else_body
134 .output_fact_mut(ix)?
135 .shape
136 .unify_with_mut(&mut outputs[ix].shape)?
137 || self
138 .then_body
139 .output_fact_mut(ix)?
140 .datum_type
141 .unify_with_mut(&mut outputs[ix].datum_type)?
142 || self
143 .else_body
144 .output_fact_mut(ix)?
145 .datum_type
146 .unify_with_mut(&mut outputs[ix].datum_type)?;
147 }
148 }
149 changed = changed || self.then_body.analyse(false)?;
150 changed = changed || self.else_body.analyse(false)?;
151 if !changed {
152 return Ok((inputs, outputs, observed.into_iter().cloned().collect()));
153 }
154 }
155 }
156
157 fn nboutputs(&self) -> TractResult<usize> {
158 let then_outputs = self.then_body.outputs.len();
159 let else_outputs = self.else_body.outputs.len();
160 ensure!(then_outputs == else_outputs, "If Operators expect the `then_branch` {} and `else_branch` {} to produce the same number of outputs", then_outputs, else_outputs);
161 Ok(then_outputs)
162 }
163
164 fn to_typed(
165 &self,
166 _source: &InferenceModel,
167 node: &InferenceNode,
168 target: &mut TypedModel,
169 mapping: &HashMap<OutletId, OutletId>,
170 ) -> TractResult<TVec<OutletId>> {
171 let then_body = self.then_body.clone().into_typed()?;
172 let else_body = self.else_body.clone().into_typed()?;
173 let inputs: TVec<_> = node.inputs.iter().map(|o| mapping[o]).collect();
174 let op = tract_core::ops::logic::IfThenElse {
175 then_body,
176 else_body,
177 then_input_mapping: self.then_input_mapping.clone(),
178 else_input_mapping: self.else_input_mapping.clone(),
179 };
180 target.wire_node(self.name(), op, &inputs)
181 }
182
183 as_op!();
184}