tract_tensorflow/ops/
control_flow.rs1use tract_hir::internal::*;
2
3use crate::model::TfOpRegister;
4
5pub fn register_all_ops(reg: &mut TfOpRegister) {
6 reg.insert("Enter", |_, node| {
7 Ok(Box::new(LoopGate(LoopGateRole::Enter(node.get_attr_str("frame_name")?))))
8 });
9 reg.insert("Exit", |_, _| Ok(Box::new(LoopGate(LoopGateRole::Exit))));
10 reg.insert("LoopCond", |_, _| Ok(Box::new(LoopGate(LoopGateRole::LoopCond))));
11}
12
/// Which TensorFlow loop-gate op a [`LoopGate`] stands for.
///
/// Derives `PartialEq`/`Eq` (like `NextIterationRole`) so roles can be
/// compared directly. Variant order is significant for the derived `Hash`
/// and must not be changed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LoopGateRole {
    /// `Enter` gate, carrying the node's `frame_name` attribute.
    Enter(String),
    /// `Exit` gate.
    Exit,
    /// `LoopCond` gate.
    LoopCond,
}
19
20#[derive(Debug, Clone, Hash)]
21pub struct LoopGate(LoopGateRole);
22
23impl Op for LoopGate {
24 fn name(&self) -> StaticName {
25 format!("{:?}", self.0).into()
26 }
27
28 not_a_typed_op!();
29}
30
31impl EvalOp for LoopGate {
32 fn is_stateless(&self) -> bool {
33 true
34 }
35
36 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
37 Ok(inputs)
38 }
39}
40
41impl InferenceRulesOp for LoopGate {
42 fn rules<'r, 'p: 'r, 's: 'r>(
43 &'s self,
44 s: &mut Solver<'r>,
45 inputs: &'p [TensorProxy],
46 outputs: &'p [TensorProxy],
47 ) -> InferenceResult {
48 check_input_arity(inputs, 1)?;
49 check_output_arity(outputs, 1)?;
50 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
51 s.equals(&inputs[0].shape, &outputs[0].shape)?;
52 Ok(())
53 }
54
55 as_op!();
56}
57
/// Role of a [`NextIteration`] op. Its inference rules check the arity that
/// each role implies.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub enum NextIterationRole {
    /// No inputs, exactly one output.
    Source,
    /// Exactly one input, no outputs.
    Sink,
}
63
64#[derive(Debug, Clone, new, Hash)]
65pub struct NextIteration {
66 name: String,
67 role: NextIterationRole,
68}
69
70impl Op for NextIteration {
71 fn name(&self) -> StaticName {
72 format!("{:?}({})", self.role, self.name).into()
73 }
74
75 not_a_typed_op!();
76}
77
78impl EvalOp for NextIteration {
79 fn is_stateless(&self) -> bool {
80 false
81 }
82
83 fn state(&self, _state: &TurnState, _id: usize) -> TractResult<Option<Box<dyn OpState>>> {
84 unimplemented!();
85 }
86}
87
88impl InferenceRulesOp for NextIteration {
89 fn rules<'r, 'p: 'r, 's: 'r>(
90 &'s self,
91 _s: &mut Solver<'r>,
92 inputs: &'p [TensorProxy],
93 outputs: &'p [TensorProxy],
94 ) -> InferenceResult {
95 match self.role {
96 NextIterationRole::Source => {
97 check_input_arity(inputs, 0)?;
98 check_output_arity(outputs, 1)?;
99 }
100 NextIterationRole::Sink => {
101 check_input_arity(inputs, 1)?;
102 check_output_arity(outputs, 0)?;
103 }
104 }
105 Ok(())
106 }
107
108 as_op!();
109}