Skip to main content

tract_tensorflow/ops/
logic.rs

1use tract_hir::internal::*;
2use tract_hir::ops;
3use tract_hir::ops::binary::BinIntoHir;
4use tract_hir::ops::logic::{CompEq, CompGT, CompGTE, CompLT, CompLTE};
5
6use crate::model::ParsingContext;
7use crate::model::TfOpRegister;
8use crate::tfpb::tensorflow::NodeDef;
9use std::collections::HashSet;
10
11pub fn register_all_ops(reg: &mut TfOpRegister) {
12    reg.insert("Equal", |_, _| Ok(CompEq.into_hir()));
13    reg.insert("Greater", |_, _| Ok(CompGT.into_hir()));
14    reg.insert("GreaterEqual", |_, _| Ok(CompGTE.into_hir()));
15    reg.insert("Less", |_, _| Ok(CompLT.into_hir()));
16    reg.insert("LessEqual", |_, _| Ok(CompLTE.into_hir()));
17    reg.insert("LogicalAnd", |_, _| Ok(ops::logic::And.into_hir()));
18    reg.insert("LogicalOr", |_, _| Ok(ops::logic::Or.into_hir()));
19    reg.insert("Merge", merge);
20    reg.insert("Switch", |_, _| Ok(Box::new(Switch)));
21}
22
23#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
24pub struct Switch;
25
26impl Op for Switch {
27    fn name(&self) -> StaticName {
28        "Switch".into()
29    }
30
31    not_a_typed_op!();
32}
33
34impl EvalOp for Switch {
35    fn is_stateless(&self) -> bool {
36        true
37    }
38
39    fn state(
40        &self,
41        _session: &TurnState,
42        _node_id: usize,
43    ) -> TractResult<Option<Box<dyn OpState>>> {
44        Ok(None)
45    }
46}
47
48impl InferenceRulesOp for Switch {
49    fn rules<'r, 'p: 'r, 's: 'r>(
50        &'s self,
51        s: &mut Solver<'r>,
52        inputs: &'p [TensorProxy],
53        outputs: &'p [TensorProxy],
54    ) -> InferenceResult {
55        check_input_arity(inputs, 2)?;
56        check_output_arity(outputs, 2)?;
57        s.equals(&inputs[1].datum_type, DatumType::Bool)?;
58        s.equals(&inputs[1].shape, shapefactoid!())?;
59        for output in outputs {
60            s.equals(&inputs[0].datum_type, &output.datum_type)?;
61            s.equals(&inputs[0].shape, &output.shape)?;
62        }
63        Ok(())
64    }
65
66    fn incorporate(
67        &self,
68        model: &InferenceModel,
69        node: &InferenceNode,
70    ) -> TractResult<Option<InferenceModelPatch>> {
71        let pred = model.outlet_fact(node.inputs[1])?;
72        if let Some(pred) = pred.concretize() {
73            let pred = *pred.try_as_plain()?.to_scalar::<bool>()?;
74            let mut dead_to_visit = HashSet::new();
75            let mut dead_done = HashSet::new();
76            let mut patch = InferenceModelPatch::default();
77            dead_to_visit.insert(OutletId::new(node.id, !pred as usize));
78            while let Some(dead_outlet) = dead_to_visit.iter().cloned().next() {
79                dead_to_visit.remove(&dead_outlet);
80                dead_done.insert(dead_outlet);
81                for succ in model.outlet_successors(dead_outlet) {
82                    if model.node(succ.node).op_is::<Merge>() {
83                        let outlet = model.node(succ.node).inputs[(succ.slot == 0) as usize];
84                        let tap = patch.tap_model(model, outlet)?;
85                        patch.shunt_outside(model, succ.node.into(), tap)?;
86                    } else {
87                        for slot in 0..model.node(succ.node).outputs.len() {
88                            let new = OutletId::new(succ.node, slot);
89                            if !dead_done.contains(&new) {
90                                dead_to_visit.insert(new);
91                            }
92                        }
93                    }
94                }
95            }
96            let tap = patch.tap_model(model, node.inputs[0])?;
97            patch.shunt_outside(model, OutletId::new(node.id, 0), tap)?;
98            patch.shunt_outside(model, OutletId::new(node.id, 1), tap)?;
99            return Ok(Some(patch));
100        }
101        Ok(None)
102    }
103
104    fn nboutputs(&self) -> TractResult<usize> {
105        Ok(2)
106    }
107
108    as_op!();
109}
110
111fn merge(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
112    let inputs = pb.get_attr_int::<i32>("N")?;
113    Ok(Box::new(Merge::new(inputs as usize)))
114}
115
116#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
117pub struct Merge {
118    n: usize,
119}
120
121impl Op for Merge {
122    fn name(&self) -> StaticName {
123        "Merge".into()
124    }
125
126    op_as_typed_op!();
127}
128
129impl EvalOp for Merge {
130    fn is_stateless(&self) -> bool {
131        true
132    }
133
134    fn state(
135        &self,
136        _session: &TurnState,
137        _node_id: usize,
138    ) -> TractResult<Option<Box<dyn OpState>>> {
139        Ok(None)
140    }
141}
142
143impl InferenceRulesOp for Merge {
144    fn rules<'r, 'p: 'r, 's: 'r>(
145        &'s self,
146        s: &mut Solver<'r>,
147        inputs: &'p [TensorProxy],
148        outputs: &'p [TensorProxy],
149    ) -> InferenceResult {
150        check_input_arity(inputs, self.n)?;
151        check_output_arity(outputs, 1)?;
152        for i in 1..self.n {
153            s.equals(&inputs[0].datum_type, &inputs[i].datum_type)?;
154            s.equals(&inputs[0].shape, &inputs[i].shape)?;
155        }
156        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
157        s.equals(&inputs[0].shape, &outputs[0].shape)?;
158        Ok(())
159    }
160
161    as_op!();
162    to_typed!();
163}
164
165impl TypedOp for Merge {
166    as_op!();
167
168    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
169        Ok(tvec!(f32::fact(inputs[0].shape.iter()), i32::fact([0; 0])))
170    }
171}