tract_tensorflow/ops/
logic.rs1use tract_hir::internal::*;
2use tract_hir::ops;
3use tract_hir::ops::binary::BinIntoHir;
4use tract_hir::ops::logic::{CompEq, CompGT, CompGTE, CompLT, CompLTE};
5
6use crate::model::ParsingContext;
7use crate::model::TfOpRegister;
8use crate::tfpb::tensorflow::NodeDef;
9use std::collections::HashSet;
10
11pub fn register_all_ops(reg: &mut TfOpRegister) {
12 reg.insert("Equal", |_, _| Ok(CompEq.into_hir()));
13 reg.insert("Greater", |_, _| Ok(CompGT.into_hir()));
14 reg.insert("GreaterEqual", |_, _| Ok(CompGTE.into_hir()));
15 reg.insert("Less", |_, _| Ok(CompLT.into_hir()));
16 reg.insert("LessEqual", |_, _| Ok(CompLTE.into_hir()));
17 reg.insert("LogicalAnd", |_, _| Ok(ops::logic::And.into_hir()));
18 reg.insert("LogicalOr", |_, _| Ok(ops::logic::Or.into_hir()));
19 reg.insert("Merge", merge);
20 reg.insert("Switch", |_, _| Ok(Box::new(Switch)));
21}
22
23#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
24pub struct Switch;
25
26impl Op for Switch {
27 fn name(&self) -> StaticName {
28 "Switch".into()
29 }
30
31 not_a_typed_op!();
32}
33
34impl EvalOp for Switch {
35 fn is_stateless(&self) -> bool {
36 true
37 }
38
39 fn state(
40 &self,
41 _session: &TurnState,
42 _node_id: usize,
43 ) -> TractResult<Option<Box<dyn OpState>>> {
44 Ok(None)
45 }
46}
47
48impl InferenceRulesOp for Switch {
49 fn rules<'r, 'p: 'r, 's: 'r>(
50 &'s self,
51 s: &mut Solver<'r>,
52 inputs: &'p [TensorProxy],
53 outputs: &'p [TensorProxy],
54 ) -> InferenceResult {
55 check_input_arity(inputs, 2)?;
56 check_output_arity(outputs, 2)?;
57 s.equals(&inputs[1].datum_type, DatumType::Bool)?;
58 s.equals(&inputs[1].shape, shapefactoid!())?;
59 for output in outputs {
60 s.equals(&inputs[0].datum_type, &output.datum_type)?;
61 s.equals(&inputs[0].shape, &output.shape)?;
62 }
63 Ok(())
64 }
65
66 fn incorporate(
67 &self,
68 model: &InferenceModel,
69 node: &InferenceNode,
70 ) -> TractResult<Option<InferenceModelPatch>> {
71 let pred = model.outlet_fact(node.inputs[1])?;
72 if let Some(pred) = pred.concretize() {
73 let pred = *pred.try_as_plain()?.to_scalar::<bool>()?;
74 let mut dead_to_visit = HashSet::new();
75 let mut dead_done = HashSet::new();
76 let mut patch = InferenceModelPatch::default();
77 dead_to_visit.insert(OutletId::new(node.id, !pred as usize));
78 while let Some(dead_outlet) = dead_to_visit.iter().cloned().next() {
79 dead_to_visit.remove(&dead_outlet);
80 dead_done.insert(dead_outlet);
81 for succ in model.outlet_successors(dead_outlet) {
82 if model.node(succ.node).op_is::<Merge>() {
83 let outlet = model.node(succ.node).inputs[(succ.slot == 0) as usize];
84 let tap = patch.tap_model(model, outlet)?;
85 patch.shunt_outside(model, succ.node.into(), tap)?;
86 } else {
87 for slot in 0..model.node(succ.node).outputs.len() {
88 let new = OutletId::new(succ.node, slot);
89 if !dead_done.contains(&new) {
90 dead_to_visit.insert(new);
91 }
92 }
93 }
94 }
95 }
96 let tap = patch.tap_model(model, node.inputs[0])?;
97 patch.shunt_outside(model, OutletId::new(node.id, 0), tap)?;
98 patch.shunt_outside(model, OutletId::new(node.id, 1), tap)?;
99 return Ok(Some(patch));
100 }
101 Ok(None)
102 }
103
104 fn nboutputs(&self) -> TractResult<usize> {
105 Ok(2)
106 }
107
108 as_op!();
109}
110
111fn merge(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
112 let inputs = pb.get_attr_int::<i32>("N")?;
113 Ok(Box::new(Merge::new(inputs as usize)))
114}
115
116#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
117pub struct Merge {
118 n: usize,
119}
120
121impl Op for Merge {
122 fn name(&self) -> StaticName {
123 "Merge".into()
124 }
125
126 op_as_typed_op!();
127}
128
129impl EvalOp for Merge {
130 fn is_stateless(&self) -> bool {
131 true
132 }
133
134 fn state(
135 &self,
136 _session: &TurnState,
137 _node_id: usize,
138 ) -> TractResult<Option<Box<dyn OpState>>> {
139 Ok(None)
140 }
141}
142
143impl InferenceRulesOp for Merge {
144 fn rules<'r, 'p: 'r, 's: 'r>(
145 &'s self,
146 s: &mut Solver<'r>,
147 inputs: &'p [TensorProxy],
148 outputs: &'p [TensorProxy],
149 ) -> InferenceResult {
150 check_input_arity(inputs, self.n)?;
151 check_output_arity(outputs, 1)?;
152 for i in 1..self.n {
153 s.equals(&inputs[0].datum_type, &inputs[i].datum_type)?;
154 s.equals(&inputs[0].shape, &inputs[i].shape)?;
155 }
156 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
157 s.equals(&inputs[0].shape, &outputs[0].shape)?;
158 Ok(())
159 }
160
161 as_op!();
162 to_typed!();
163}
164
165impl TypedOp for Merge {
166 as_op!();
167
168 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
169 Ok(tvec!(f32::fact(inputs[0].shape.iter()), i32::fact([0; 0])))
170 }
171}