tract_tensorflow/ops/logic.rs

1use tract_hir::internal::*;
2use tract_hir::ops;
3use tract_hir::ops::logic::Comp;
4
5use crate::model::ParsingContext;
6use crate::model::TfOpRegister;
7use crate::tfpb::tensorflow::NodeDef;
8use std::collections::HashSet;
9
10pub fn register_all_ops(reg: &mut TfOpRegister) {
11    reg.insert("Equal", |_, _| Ok(expand(Comp::Eq)));
12    reg.insert("Greater", |_, _| Ok(expand(Comp::GT)));
13    reg.insert("GreaterEqual", |_, _| Ok(expand(Comp::GTE)));
14    reg.insert("Less", |_, _| Ok(expand(Comp::LT)));
15    reg.insert("LessEqual", |_, _| Ok(expand(Comp::LTE)));
16    reg.insert("LogicalAnd", |_, _| Ok(ops::logic::And.into_hir()));
17    reg.insert("LogicalOr", |_, _| Ok(ops::logic::Or.into_hir()));
18    reg.insert("Merge", merge);
19    reg.insert("Switch", |_, _| Ok(Box::new(Switch)));
20}
21
22#[derive(Debug, Clone, new, Hash)]
23pub struct Switch;
24
25impl Op for Switch {
26    fn name(&self) -> StaticName {
27        "Switch".into()
28    }
29
30    not_a_typed_op!();
31}
32
33impl EvalOp for Switch {
34    fn is_stateless(&self) -> bool {
35        true
36    }
37
38    fn state(
39        &self,
40        _session: &mut SessionState,
41        _node_id: usize,
42    ) -> TractResult<Option<Box<dyn OpState>>> {
43        Ok(None)
44    }
45}
46
impl InferenceRulesOp for Switch {
    /// Type/shape rules.
    ///
    /// Input 0 is the data and input 1 the predicate, which must be a boolean
    /// scalar. Both outputs carry the datum type and shape of the data input.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        check_input_arity(inputs, 2)?;
        check_output_arity(outputs, 2)?;
        s.equals(&inputs[1].datum_type, DatumType::Bool)?;
        // Empty shape factoid: the predicate is a scalar.
        s.equals(&inputs[1].shape, shapefactoid!())?;
        for output in outputs {
            s.equals(&inputs[0].datum_type, &output.datum_type)?;
            s.equals(&inputs[0].shape, &output.shape)?;
        }
        Ok(())
    }

    /// Statically resolves the Switch once its predicate is a known constant.
    ///
    /// The output not selected by the predicate is treated as dead. Its slot
    /// is `!pred as usize`: slot 0 when `pred` is true, slot 1 when it is
    /// false. Everything downstream of it is walked, stopping at `Merge`
    /// nodes. Each Merge reached along a dead edge has its output rerouted to
    /// its other input. Finally, both Switch outputs are rerouted to the data
    /// input. The patch only reroutes consumers and does not itself remove
    /// the dead nodes (presumably they are pruned later; confirm).
    ///
    /// Returns `Ok(None)` while the predicate is not yet known. Fails if the
    /// concretized predicate is not a `bool` scalar.
    fn incorporate(
        &self,
        model: &InferenceModel,
        node: &InferenceNode,
    ) -> TractResult<Option<InferenceModelPatch>> {
        let pred = model.outlet_fact(node.inputs[1])?;
        if let Some(pred) = pred.concretize() {
            let pred = *pred.to_scalar::<bool>()?;
            // Worklist of dead outlets still to explore, plus the set already
            // explored (prevents revisiting).
            let mut dead_to_visit = HashSet::new();
            let mut dead_done = HashSet::new();
            let mut patch = InferenceModelPatch::default();
            // Seed with the Switch output the predicate does NOT select.
            dead_to_visit.insert(OutletId::new(node.id, !pred as usize));
            while let Some(dead_outlet) = dead_to_visit.iter().cloned().next() {
                dead_to_visit.remove(&dead_outlet);
                dead_done.insert(dead_outlet);
                for succ in model.outlet_successors(dead_outlet) {
                    if model.node(succ.node).op_is::<Merge>() {
                        // Pick the Merge's other input: input 1 when the dead
                        // edge enters on slot 0, otherwise input 0.
                        // NOTE(review): this assumes a two-input Merge; with
                        // N > 2 the chosen input may be wrong.
                        let outlet = model.node(succ.node).inputs[(succ.slot == 0) as usize];
                        let tap = patch.tap_model(model, outlet)?;
                        // Reroute consumers of the Merge output (the node id
                        // converted into an OutletId, presumably slot 0) to
                        // that live input.
                        patch.shunt_outside(model, succ.node.into(), tap)?;
                    } else {
                        // A non-Merge consumer of a dead value is itself dead:
                        // enqueue all of its outputs not yet explored.
                        for slot in 0..model.node(succ.node).outputs.len() {
                            let new = OutletId::new(succ.node, slot);
                            if !dead_done.contains(&new) {
                                dead_to_visit.insert(new);
                            }
                        }
                    }
                }
            }
            // Bypass the Switch itself: both outputs now read the data input.
            let tap = patch.tap_model(model, node.inputs[0])?;
            patch.shunt_outside(model, OutletId::new(node.id, 0), tap)?;
            patch.shunt_outside(model, OutletId::new(node.id, 1), tap)?;
            return Ok(Some(patch));
        }
        Ok(None)
    }

    /// Switch always has two outputs.
    fn nboutputs(&self) -> TractResult<usize> {
        Ok(2)
    }

    as_op!();
}
109
110fn merge(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
111    let inputs = pb.get_attr_int::<i32>("N")?;
112    Ok(Box::new(Merge::new(inputs as usize)))
113}
114
115#[derive(Debug, Clone, new, Hash)]
116pub struct Merge {
117    n: usize,
118}
119
120impl Op for Merge {
121    fn name(&self) -> StaticName {
122        "Merge".into()
123    }
124
125    op_as_typed_op!();
126}
127
128impl EvalOp for Merge {
129    fn is_stateless(&self) -> bool {
130        true
131    }
132
133    fn state(
134        &self,
135        _session: &mut SessionState,
136        _node_id: usize,
137    ) -> TractResult<Option<Box<dyn OpState>>> {
138        Ok(None)
139    }
140}
141
142impl InferenceRulesOp for Merge {
143    fn rules<'r, 'p: 'r, 's: 'r>(
144        &'s self,
145        s: &mut Solver<'r>,
146        inputs: &'p [TensorProxy],
147        outputs: &'p [TensorProxy],
148    ) -> InferenceResult {
149        check_input_arity(inputs, self.n)?;
150        check_output_arity(outputs, 1)?;
151        for i in 1..self.n {
152            s.equals(&inputs[0].datum_type, &inputs[i].datum_type)?;
153            s.equals(&inputs[0].shape, &inputs[i].shape)?;
154        }
155        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
156        s.equals(&inputs[0].shape, &outputs[0].shape)?;
157        Ok(())
158    }
159
160    as_op!();
161    to_typed!();
162}
163
164impl TypedOp for Merge {
165    as_op!();
166
167    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
168        Ok(tvec!(f32::fact(inputs[0].shape.iter()), i32::fact([0; 0])))
169    }
170}