tract_core/ops/logic/ite.rs

1use crate::internal::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct IfThenElse {
5    pub then_body: TypedModel,
6    pub then_input_mapping: Vec<usize>,
7    pub else_body: TypedModel,
8    pub else_input_mapping: Vec<usize>,
9}
10
11impl Op for IfThenElse {
12    fn name(&self) -> Cow<str> {
13        "IfThenElse".into()
14    }
15
16    op_as_typed_op!();
17}
18
19impl TypedOp for IfThenElse {
20    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
21        ensure!(inputs[0].datum_type == bool::datum_type());
22        ensure!(inputs[0].shape.volume() == 1.to_dim());
23        ensure!(self.then_body.inputs.len() == self.then_input_mapping.len());
24        ensure!(self.else_body.inputs.len() == self.else_input_mapping.len());
25        let mut facts = tvec!();
26        for i in 0..self.then_body.outputs.len() {
27            ensure!(
28                self.then_body.output_fact(i)?.without_value()
29                    == self.else_body.output_fact(i)?.without_value()
30            );
31            facts.push(self.then_body.output_fact(i)?.clone());
32        }
33        Ok(facts)
34    }
35
36    fn declutter(
37        &self,
38        model: &TypedModel,
39        node: &TypedNode,
40    ) -> TractResult<Option<TypedModelPatch>> {
41        if let Some(cond) = &model.outlet_fact(node.inputs[0])?.konst {
42            let cond = cond.cast_to_scalar::<bool>()?;
43            let (body, input_mapping) = if cond {
44                (&self.then_body, &self.then_input_mapping)
45            } else {
46                (&self.else_body, &self.else_input_mapping)
47            };
48            let mut inner_mapping: HashMap<OutletId, OutletId> = HashMap::default();
49            let mut patch = TypedModelPatch::default();
50            for (input_ix, outlet) in tract_itertools::izip!(input_mapping, body.input_outlets()?) {
51                let tap = patch.tap_model(model, node.inputs[*input_ix])?;
52                inner_mapping.insert(*outlet, tap);
53            }
54            for node in body.eval_order()? {
55                if Graph::is_source(&body.node(node).op) {
56                    continue;
57                }
58                let node_inputs =
59                    body.node(node).inputs.iter().map(|o| inner_mapping[o]).collect::<TVec<_>>();
60                let node_outputs =
61                    patch.wire_node(&body.node(node).name, &body.node(node).op, &node_inputs)?;
62                for (slot_ix, outlet) in node_outputs.iter().enumerate() {
63                    inner_mapping.insert((node, slot_ix).into(), *outlet);
64                }
65            }
66            for (ix, output) in body.outputs.iter().enumerate() {
67                patch.shunt_outside(model, OutletId::new(node.id, ix), inner_mapping[output])?;
68            }
69            Ok(Some(patch))
70        } else {
71            Ok(None)
72        }
73    }
74
75    as_op!();
76}
77
78impl EvalOp for IfThenElse {
79    fn is_stateless(&self) -> bool {
80        true
81    }
82
83    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
84        let cond = inputs[0].cast_to_scalar::<bool>()?;
85        let (input_mapping, body) = if cond {
86            (&self.then_input_mapping, &self.then_body)
87        } else {
88            (&self.else_input_mapping, &self.else_body)
89        };
90        let inputs: TVec<TValue> = input_mapping.iter().map(|&ix| inputs[ix].clone()).collect();
91        body.clone().into_runnable()?.run(inputs)
92    }
93}