tract_tensorflow/ops/
logic.rs1use tract_hir::internal::*;
2use tract_hir::ops;
3use tract_hir::ops::logic::Comp;
4
5use crate::model::ParsingContext;
6use crate::model::TfOpRegister;
7use crate::tfpb::tensorflow::NodeDef;
8use std::collections::HashSet;
9
10pub fn register_all_ops(reg: &mut TfOpRegister) {
11 reg.insert("Equal", |_, _| Ok(expand(Comp::Eq)));
12 reg.insert("Greater", |_, _| Ok(expand(Comp::GT)));
13 reg.insert("GreaterEqual", |_, _| Ok(expand(Comp::GTE)));
14 reg.insert("Less", |_, _| Ok(expand(Comp::LT)));
15 reg.insert("LessEqual", |_, _| Ok(expand(Comp::LTE)));
16 reg.insert("LogicalAnd", |_, _| Ok(ops::logic::And.into_hir()));
17 reg.insert("LogicalOr", |_, _| Ok(ops::logic::Or.into_hir()));
18 reg.insert("Merge", merge);
19 reg.insert("Switch", |_, _| Ok(Box::new(Switch)));
20}
21
22#[derive(Debug, Clone, new, Hash)]
23pub struct Switch;
24
25impl Op for Switch {
26 fn name(&self) -> StaticName {
27 "Switch".into()
28 }
29
30 not_a_typed_op!();
31}
32
33impl EvalOp for Switch {
34 fn is_stateless(&self) -> bool {
35 true
36 }
37
38 fn state(
39 &self,
40 _session: &mut SessionState,
41 _node_id: usize,
42 ) -> TractResult<Option<Box<dyn OpState>>> {
43 Ok(None)
44 }
45}
46
47impl InferenceRulesOp for Switch {
48 fn rules<'r, 'p: 'r, 's: 'r>(
49 &'s self,
50 s: &mut Solver<'r>,
51 inputs: &'p [TensorProxy],
52 outputs: &'p [TensorProxy],
53 ) -> InferenceResult {
54 check_input_arity(inputs, 2)?;
55 check_output_arity(outputs, 2)?;
56 s.equals(&inputs[1].datum_type, DatumType::Bool)?;
57 s.equals(&inputs[1].shape, shapefactoid!())?;
58 for output in outputs {
59 s.equals(&inputs[0].datum_type, &output.datum_type)?;
60 s.equals(&inputs[0].shape, &output.shape)?;
61 }
62 Ok(())
63 }
64
65 fn incorporate(
66 &self,
67 model: &InferenceModel,
68 node: &InferenceNode,
69 ) -> TractResult<Option<InferenceModelPatch>> {
70 let pred = model.outlet_fact(node.inputs[1])?;
71 if let Some(pred) = pred.concretize() {
72 let pred = *pred.to_scalar::<bool>()?;
73 let mut dead_to_visit = HashSet::new();
74 let mut dead_done = HashSet::new();
75 let mut patch = InferenceModelPatch::default();
76 dead_to_visit.insert(OutletId::new(node.id, !pred as usize));
77 while let Some(dead_outlet) = dead_to_visit.iter().cloned().next() {
78 dead_to_visit.remove(&dead_outlet);
79 dead_done.insert(dead_outlet);
80 for succ in model.outlet_successors(dead_outlet) {
81 if model.node(succ.node).op_is::<Merge>() {
82 let outlet = model.node(succ.node).inputs[(succ.slot == 0) as usize];
83 let tap = patch.tap_model(model, outlet)?;
84 patch.shunt_outside(model, succ.node.into(), tap)?;
85 } else {
86 for slot in 0..model.node(succ.node).outputs.len() {
87 let new = OutletId::new(succ.node, slot);
88 if !dead_done.contains(&new) {
89 dead_to_visit.insert(new);
90 }
91 }
92 }
93 }
94 }
95 let tap = patch.tap_model(model, node.inputs[0])?;
96 patch.shunt_outside(model, OutletId::new(node.id, 0), tap)?;
97 patch.shunt_outside(model, OutletId::new(node.id, 1), tap)?;
98 return Ok(Some(patch));
99 }
100 Ok(None)
101 }
102
103 fn nboutputs(&self) -> TractResult<usize> {
104 Ok(2)
105 }
106
107 as_op!();
108}
109
110fn merge(_ctx: &ParsingContext, pb: &NodeDef) -> TractResult<Box<dyn InferenceOp>> {
111 let inputs = pb.get_attr_int::<i32>("N")?;
112 Ok(Box::new(Merge::new(inputs as usize)))
113}
114
115#[derive(Debug, Clone, new, Hash)]
116pub struct Merge {
117 n: usize,
118}
119
120impl Op for Merge {
121 fn name(&self) -> StaticName {
122 "Merge".into()
123 }
124
125 op_as_typed_op!();
126}
127
128impl EvalOp for Merge {
129 fn is_stateless(&self) -> bool {
130 true
131 }
132
133 fn state(
134 &self,
135 _session: &mut SessionState,
136 _node_id: usize,
137 ) -> TractResult<Option<Box<dyn OpState>>> {
138 Ok(None)
139 }
140}
141
142impl InferenceRulesOp for Merge {
143 fn rules<'r, 'p: 'r, 's: 'r>(
144 &'s self,
145 s: &mut Solver<'r>,
146 inputs: &'p [TensorProxy],
147 outputs: &'p [TensorProxy],
148 ) -> InferenceResult {
149 check_input_arity(inputs, self.n)?;
150 check_output_arity(outputs, 1)?;
151 for i in 1..self.n {
152 s.equals(&inputs[0].datum_type, &inputs[i].datum_type)?;
153 s.equals(&inputs[0].shape, &inputs[i].shape)?;
154 }
155 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
156 s.equals(&inputs[0].shape, &outputs[0].shape)?;
157 Ok(())
158 }
159
160 as_op!();
161 to_typed!();
162}
163
164impl TypedOp for Merge {
165 as_op!();
166
167 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
168 Ok(tvec!(f32::fact(inputs[0].shape.iter()), i32::fact([0; 0])))
169 }
170}