tract_core/ops/logic/
ite.rs1use crate::internal::*;
2
3#[derive(Debug, Clone, Default)]
4pub struct IfThenElse {
5 pub then_body: TypedModel,
6 pub then_input_mapping: Vec<usize>,
7 pub else_body: TypedModel,
8 pub else_input_mapping: Vec<usize>,
9}
10
11impl Op for IfThenElse {
12 fn name(&self) -> Cow<str> {
13 "IfThenElse".into()
14 }
15
16 op_as_typed_op!();
17}
18
19impl TypedOp for IfThenElse {
20 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
21 ensure!(inputs[0].datum_type == bool::datum_type());
22 ensure!(inputs[0].shape.volume() == 1.to_dim());
23 ensure!(self.then_body.inputs.len() == self.then_input_mapping.len());
24 ensure!(self.else_body.inputs.len() == self.else_input_mapping.len());
25 let mut facts = tvec!();
26 for i in 0..self.then_body.outputs.len() {
27 ensure!(
28 self.then_body.output_fact(i)?.without_value()
29 == self.else_body.output_fact(i)?.without_value()
30 );
31 facts.push(self.then_body.output_fact(i)?.clone());
32 }
33 Ok(facts)
34 }
35
36 fn declutter(
37 &self,
38 model: &TypedModel,
39 node: &TypedNode,
40 ) -> TractResult<Option<TypedModelPatch>> {
41 if let Some(cond) = &model.outlet_fact(node.inputs[0])?.konst {
42 let cond = cond.cast_to_scalar::<bool>()?;
43 let (body, input_mapping) = if cond {
44 (&self.then_body, &self.then_input_mapping)
45 } else {
46 (&self.else_body, &self.else_input_mapping)
47 };
48 let mut inner_mapping: HashMap<OutletId, OutletId> = HashMap::default();
49 let mut patch = TypedModelPatch::default();
50 for (input_ix, outlet) in tract_itertools::izip!(input_mapping, body.input_outlets()?) {
51 let tap = patch.tap_model(model, node.inputs[*input_ix])?;
52 inner_mapping.insert(*outlet, tap);
53 }
54 for node in body.eval_order()? {
55 if Graph::is_source(&body.node(node).op) {
56 continue;
57 }
58 let node_inputs =
59 body.node(node).inputs.iter().map(|o| inner_mapping[o]).collect::<TVec<_>>();
60 let node_outputs =
61 patch.wire_node(&body.node(node).name, &body.node(node).op, &node_inputs)?;
62 for (slot_ix, outlet) in node_outputs.iter().enumerate() {
63 inner_mapping.insert((node, slot_ix).into(), *outlet);
64 }
65 }
66 for (ix, output) in body.outputs.iter().enumerate() {
67 patch.shunt_outside(model, OutletId::new(node.id, ix), inner_mapping[output])?;
68 }
69 Ok(Some(patch))
70 } else {
71 Ok(None)
72 }
73 }
74
75 as_op!();
76}
77
78impl EvalOp for IfThenElse {
79 fn is_stateless(&self) -> bool {
80 true
81 }
82
83 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
84 let cond = inputs[0].cast_to_scalar::<bool>()?;
85 let (input_mapping, body) = if cond {
86 (&self.then_input_mapping, &self.then_body)
87 } else {
88 (&self.else_input_mapping, &self.else_body)
89 };
90 let inputs: TVec<TValue> = input_mapping.iter().map(|&ix| inputs[ix].clone()).collect();
91 body.clone().into_runnable()?.run(inputs)
92 }
93}