tract_tensorflow/ops/control_flow.rs

1use tract_hir::internal::*;
2
3use crate::model::TfOpRegister;
4
5pub fn register_all_ops(reg: &mut TfOpRegister) {
6    reg.insert("Enter", |_, node| {
7        Ok(Box::new(LoopGate(LoopGateRole::Enter(node.get_attr_str("frame_name")?))))
8    });
9    reg.insert("Exit", |_, _| Ok(Box::new(LoopGate(LoopGateRole::Exit))));
10    reg.insert("LoopCond", |_, _| Ok(Box::new(LoopGate(LoopGateRole::LoopCond))));
11}
12
/// Which TensorFlow control-flow primitive a [`LoopGate`] stands for.
#[derive(Clone, Debug, Hash)]
pub enum LoopGateRole {
    /// An `Enter` node, carrying the value of its `frame_name` attribute.
    Enter(String),
    /// An `Exit` node.
    Exit,
    /// A `LoopCond` node.
    LoopCond,
}

/// Single-input, single-output gate op for the TensorFlow `Enter`, `Exit` and
/// `LoopCond` nodes; at evaluation time it forwards its input unchanged.
#[derive(Clone, Debug, Hash)]
pub struct LoopGate(LoopGateRole);
22
23
24
25impl Op for LoopGate {
26    fn name(&self) -> StaticName {
27        format!("{:?}", self.0).into()
28    }
29
30    not_a_typed_op!();
31}
32
33impl EvalOp for LoopGate {
34    fn is_stateless(&self) -> bool {
35        true
36    }
37
38    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
39        Ok(inputs)
40    }
41}
42
43impl InferenceRulesOp for LoopGate {
44    fn rules<'r, 'p: 'r, 's: 'r>(
45        &'s self,
46        s: &mut Solver<'r>,
47        inputs: &'p [TensorProxy],
48        outputs: &'p [TensorProxy],
49    ) -> InferenceResult {
50        check_input_arity(inputs, 1)?;
51        check_output_arity(outputs, 1)?;
52        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
53        s.equals(&inputs[0].shape, &outputs[0].shape)?;
54        Ok(())
55    }
56
57    as_op!();
58}
59
/// Which side a [`NextIteration`] op plays; presumably the two halves of a TensorFlow
/// `NextIteration` node (TODO confirm against the model loader).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NextIterationRole {
    /// Producing side: its inference rules require no input and one output.
    Source,
    /// Consuming side: its inference rules require one input and no output.
    Sink,
}
65
66#[derive(Debug, Clone, new, Hash)]
67pub struct NextIteration {
68    name: String,
69    role: NextIterationRole,
70}
71
72
73
74impl Op for NextIteration {
75    fn name(&self) -> StaticName {
76        format!("{:?}({})", self.role, self.name).into()
77    }
78
79    not_a_typed_op!();
80}
81
82impl EvalOp for NextIteration {
83    fn is_stateless(&self) -> bool {
84        false
85    }
86
87    fn state(
88        &self,
89        _state: &mut SessionState,
90        _id: usize,
91    ) -> TractResult<Option<Box<dyn OpState>>> {
92        unimplemented!();
93    }
94}
95
96impl InferenceRulesOp for NextIteration {
97    fn rules<'r, 'p: 'r, 's: 'r>(
98        &'s self,
99        _s: &mut Solver<'r>,
100        inputs: &'p [TensorProxy],
101        outputs: &'p [TensorProxy],
102    ) -> InferenceResult {
103        match self.role {
104            NextIterationRole::Source => {
105                check_input_arity(inputs, 0)?;
106                check_output_arity(outputs, 1)?;
107            }
108            NextIterationRole::Sink => {
109                check_input_arity(inputs, 1)?;
110                check_output_arity(outputs, 0)?;
111            }
112        }
113        Ok(())
114    }
115
116    as_op!();
117}