tract_core/ops/logic/
ite.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct IfThenElse {
5 pub then_body: TypedModel,
6 pub then_input_mapping: Vec<usize>,
7 pub else_body: TypedModel,
8 pub else_input_mapping: Vec<usize>,
9}
10
11impl PartialEq for IfThenElse {
12 fn eq(&self, _other: &Self) -> bool {
13 false
14 }
15}
16impl Eq for IfThenElse {}
17
18impl Op for IfThenElse {
19 fn name(&self) -> StaticName {
20 "IfThenElse".into()
21 }
22
23 op_as_typed_op!();
24}
25
26impl TypedOp for IfThenElse {
27 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
28 ensure!(inputs[0].datum_type == bool::datum_type());
29 ensure!(inputs[0].shape.volume() == 1.to_dim());
30 ensure!(self.then_body.inputs.len() == self.then_input_mapping.len());
31 ensure!(self.else_body.inputs.len() == self.else_input_mapping.len());
32 let mut facts = tvec!();
33 for i in 0..self.then_body.outputs.len() {
34 ensure!(
35 self.then_body.output_fact(i)?.without_value()
36 == self.else_body.output_fact(i)?.without_value()
37 );
38 facts.push(self.then_body.output_fact(i)?.clone());
39 }
40 Ok(facts)
41 }
42
43 fn declutter(
44 &self,
45 model: &TypedModel,
46 node: &TypedNode,
47 ) -> TractResult<Option<TypedModelPatch>> {
48 if let Some(cond) = &model.outlet_fact(node.inputs[0])?.konst {
49 let cond = cond.cast_to_scalar::<bool>()?;
50 let (body, input_mapping) = if cond {
51 (&self.then_body, &self.then_input_mapping)
52 } else {
53 (&self.else_body, &self.else_input_mapping)
54 };
55 let mut inner_mapping: HashMap<OutletId, OutletId> = HashMap::default();
56 let mut patch = TypedModelPatch::default();
57 for (input_ix, outlet) in tract_itertools::izip!(input_mapping, body.input_outlets()?) {
58 let tap = patch.tap_model(model, node.inputs[*input_ix])?;
59 inner_mapping.insert(*outlet, tap);
60 }
61 for node in body.eval_order()? {
62 if Graph::is_source(&body.node(node).op) {
63 continue;
64 }
65 let node_inputs =
66 body.node(node).inputs.iter().map(|o| inner_mapping[o]).collect::<TVec<_>>();
67 let node_outputs =
68 patch.wire_node(&body.node(node).name, &body.node(node).op, &node_inputs)?;
69 for (slot_ix, outlet) in node_outputs.iter().enumerate() {
70 inner_mapping.insert((node, slot_ix).into(), *outlet);
71 }
72 }
73 for (ix, output) in body.outputs.iter().enumerate() {
74 patch.shunt_outside(model, OutletId::new(node.id, ix), inner_mapping[output])?;
75 }
76 Ok(Some(patch))
77 } else {
78 Ok(None)
79 }
80 }
81
82 as_op!();
83}
84
85impl EvalOp for IfThenElse {
86 fn is_stateless(&self) -> bool {
87 true
88 }
89
90 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
91 let cond = inputs[0].cast_to_scalar::<bool>()?;
92 let (input_mapping, body) = if cond {
93 (&self.then_input_mapping, &self.then_body)
94 } else {
95 (&self.else_input_mapping, &self.else_body)
96 };
97 let inputs: TVec<TValue> = input_mapping.iter().map(|&ix| inputs[ix].clone()).collect();
98 body.clone().into_runnable()?.run(inputs)
99 }
100}