Skip to main content

tract_tensorflow/ops/
control_flow.rs

1use tract_hir::internal::*;
2
3use crate::model::TfOpRegister;
4
5pub fn register_all_ops(reg: &mut TfOpRegister) {
6    reg.insert("Enter", |_, node| {
7        Ok(Box::new(LoopGate(LoopGateRole::Enter(node.get_attr_str("frame_name")?))))
8    });
9    reg.insert("Exit", |_, _| Ok(Box::new(LoopGate(LoopGateRole::Exit))));
10    reg.insert("LoopCond", |_, _| Ok(Box::new(LoopGate(LoopGateRole::LoopCond))));
11}
12
/// Which loop control-flow op a [`LoopGate`] stands for.
///
/// There is one variant per op name registered by `register_all_ops`.
/// The derived `Debug` output of this value is what `LoopGate` reports as its
/// op name.
///
/// `PartialEq`/`Eq` are derived alongside `Hash` so that roles can be compared
/// by value, in line with `NextIterationRole`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum LoopGateRole {
    /// An `Enter` op; the payload is the node's `frame_name` attribute.
    Enter(String),
    /// An `Exit` op.
    Exit,
    /// A `LoopCond` op.
    LoopCond,
}
19
20#[derive(Debug, Clone, Hash)]
21pub struct LoopGate(LoopGateRole);
22
23impl Op for LoopGate {
24    fn name(&self) -> StaticName {
25        format!("{:?}", self.0).into()
26    }
27
28    not_a_typed_op!();
29}
30
31impl EvalOp for LoopGate {
32    fn is_stateless(&self) -> bool {
33        true
34    }
35
36    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
37        Ok(inputs)
38    }
39}
40
41impl InferenceRulesOp for LoopGate {
42    fn rules<'r, 'p: 'r, 's: 'r>(
43        &'s self,
44        s: &mut Solver<'r>,
45        inputs: &'p [TensorProxy],
46        outputs: &'p [TensorProxy],
47    ) -> InferenceResult {
48        check_input_arity(inputs, 1)?;
49        check_output_arity(outputs, 1)?;
50        s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
51        s.equals(&inputs[0].shape, &outputs[0].shape)?;
52        Ok(())
53    }
54
55    as_op!();
56}
57
/// Which end of a `NextIteration` pair an op represents.
///
/// The two roles differ in the arities enforced by the inference rules of
/// `NextIteration`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NextIterationRole {
    /// Takes no input and yields one output.
    Source,
    /// Takes one input and yields no output.
    Sink,
}
63
64#[derive(Debug, Clone, new, Hash)]
65pub struct NextIteration {
66    name: String,
67    role: NextIterationRole,
68}
69
70impl Op for NextIteration {
71    fn name(&self) -> StaticName {
72        format!("{:?}({})", self.role, self.name).into()
73    }
74
75    not_a_typed_op!();
76}
77
78impl EvalOp for NextIteration {
79    fn is_stateless(&self) -> bool {
80        false
81    }
82
83    fn state(&self, _state: &TurnState, _id: usize) -> TractResult<Option<Box<dyn OpState>>> {
84        unimplemented!();
85    }
86}
87
88impl InferenceRulesOp for NextIteration {
89    fn rules<'r, 'p: 'r, 's: 'r>(
90        &'s self,
91        _s: &mut Solver<'r>,
92        inputs: &'p [TensorProxy],
93        outputs: &'p [TensorProxy],
94    ) -> InferenceResult {
95        match self.role {
96            NextIterationRole::Source => {
97                check_input_arity(inputs, 0)?;
98                check_output_arity(outputs, 1)?;
99            }
100            NextIterationRole::Sink => {
101                check_input_arity(inputs, 1)?;
102                check_output_arity(outputs, 0)?;
103            }
104        }
105        Ok(())
106    }
107
108    as_op!();
109}