Skip to main content

tract_core/ops/logic/
ite.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct IfThenElse {
5    pub then_body: TypedModel,
6    pub then_input_mapping: Vec<usize>,
7    pub else_body: TypedModel,
8    pub else_input_mapping: Vec<usize>,
9}
10
11impl PartialEq for IfThenElse {
12    fn eq(&self, _other: &Self) -> bool {
13        false
14    }
15}
16impl Eq for IfThenElse {}
17
18impl Op for IfThenElse {
19    fn name(&self) -> StaticName {
20        "IfThenElse".into()
21    }
22
23    op_as_typed_op!();
24}
25
26impl TypedOp for IfThenElse {
27    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
28        ensure!(inputs[0].datum_type == bool::datum_type());
29        ensure!(inputs[0].shape.volume() == 1.to_dim());
30        ensure!(self.then_body.inputs.len() == self.then_input_mapping.len());
31        ensure!(self.else_body.inputs.len() == self.else_input_mapping.len());
32        let mut facts = tvec!();
33        for i in 0..self.then_body.outputs.len() {
34            ensure!(
35                self.then_body.output_fact(i)?.without_value()
36                    == self.else_body.output_fact(i)?.without_value()
37            );
38            facts.push(self.then_body.output_fact(i)?.clone());
39        }
40        Ok(facts)
41    }
42
43    fn declutter(
44        &self,
45        model: &TypedModel,
46        node: &TypedNode,
47    ) -> TractResult<Option<TypedModelPatch>> {
48        if let Some(cond) = &model.outlet_fact(node.inputs[0])?.konst {
49            let cond = cond.cast_to_scalar::<bool>()?;
50            let (body, input_mapping) = if cond {
51                (&self.then_body, &self.then_input_mapping)
52            } else {
53                (&self.else_body, &self.else_input_mapping)
54            };
55            let mut inner_mapping: HashMap<OutletId, OutletId> = HashMap::default();
56            let mut patch = TypedModelPatch::default();
57            for (input_ix, outlet) in tract_itertools::izip!(input_mapping, body.input_outlets()?) {
58                let tap = patch.tap_model(model, node.inputs[*input_ix])?;
59                inner_mapping.insert(*outlet, tap);
60            }
61            for node in body.eval_order()? {
62                if Graph::is_source(&body.node(node).op) {
63                    continue;
64                }
65                let node_inputs =
66                    body.node(node).inputs.iter().map(|o| inner_mapping[o]).collect::<TVec<_>>();
67                let node_outputs =
68                    patch.wire_node(&body.node(node).name, &body.node(node).op, &node_inputs)?;
69                for (slot_ix, outlet) in node_outputs.iter().enumerate() {
70                    inner_mapping.insert((node, slot_ix).into(), *outlet);
71                }
72            }
73            for (ix, output) in body.outputs.iter().enumerate() {
74                patch.shunt_outside(model, OutletId::new(node.id, ix), inner_mapping[output])?;
75            }
76            Ok(Some(patch))
77        } else {
78            Ok(None)
79        }
80    }
81
82    as_op!();
83}
84
85impl EvalOp for IfThenElse {
86    fn is_stateless(&self) -> bool {
87        true
88    }
89
90    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
91        let cond = inputs[0].cast_to_scalar::<bool>()?;
92        let (input_mapping, body) = if cond {
93            (&self.then_input_mapping, &self.then_body)
94        } else {
95            (&self.else_input_mapping, &self.else_body)
96        };
97        let inputs: TVec<TValue> = input_mapping.iter().map(|&ix| inputs[ix].clone()).collect();
98        body.clone().into_runnable()?.run(inputs)
99    }
100}