tract_tensorflow/ops/
control_flow.rs1use tract_hir::internal::*;
2
3use crate::model::TfOpRegister;
4
5pub fn register_all_ops(reg: &mut TfOpRegister) {
6 reg.insert("Enter", |_, node| {
7 Ok(Box::new(LoopGate(LoopGateRole::Enter(node.get_attr_str("frame_name")?))))
8 });
9 reg.insert("Exit", |_, _| Ok(Box::new(LoopGate(LoopGateRole::Exit))));
10 reg.insert("LoopCond", |_, _| Ok(Box::new(LoopGate(LoopGateRole::LoopCond))));
11}
12
/// Which TensorFlow node kind a [`LoopGate`] was built from.
///
/// The `Debug` rendering of this enum is used as the gate's op name, so variant
/// names are user-visible.
#[derive(Clone, Debug, Hash)]
pub enum LoopGateRole {
    /// `Enter` node, carrying the value of its `frame_name` attribute.
    Enter(String),
    /// `Exit` node.
    Exit,
    /// `LoopCond` node.
    LoopCond,
}
19
20#[derive(Debug, Clone, Hash)]
21pub struct LoopGate(LoopGateRole);
22
23
24
25impl Op for LoopGate {
26 fn name(&self) -> StaticName {
27 format!("{:?}", self.0).into()
28 }
29
30 not_a_typed_op!();
31}
32
33impl EvalOp for LoopGate {
34 fn is_stateless(&self) -> bool {
35 true
36 }
37
38 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
39 Ok(inputs)
40 }
41}
42
43impl InferenceRulesOp for LoopGate {
44 fn rules<'r, 'p: 'r, 's: 'r>(
45 &'s self,
46 s: &mut Solver<'r>,
47 inputs: &'p [TensorProxy],
48 outputs: &'p [TensorProxy],
49 ) -> InferenceResult {
50 check_input_arity(inputs, 1)?;
51 check_output_arity(outputs, 1)?;
52 s.equals(&inputs[0].datum_type, &outputs[0].datum_type)?;
53 s.equals(&inputs[0].shape, &outputs[0].shape)?;
54 Ok(())
55 }
56
57 as_op!();
58}
59
/// Which half of a split `NextIteration` node an op instance represents.
///
/// A `Source` has no inputs and one output; a `Sink` has one input and no outputs.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum NextIterationRole {
    /// Output-producing half (zero inputs, one output).
    Source,
    /// Input-consuming half (one input, zero outputs).
    Sink,
}
65
66#[derive(Debug, Clone, new, Hash)]
67pub struct NextIteration {
68 name: String,
69 role: NextIterationRole,
70}
71
72
73
74impl Op for NextIteration {
75 fn name(&self) -> StaticName {
76 format!("{:?}({})", self.role, self.name).into()
77 }
78
79 not_a_typed_op!();
80}
81
82impl EvalOp for NextIteration {
83 fn is_stateless(&self) -> bool {
84 false
85 }
86
87 fn state(
88 &self,
89 _state: &mut SessionState,
90 _id: usize,
91 ) -> TractResult<Option<Box<dyn OpState>>> {
92 unimplemented!();
93 }
94}
95
96impl InferenceRulesOp for NextIteration {
97 fn rules<'r, 'p: 'r, 's: 'r>(
98 &'s self,
99 _s: &mut Solver<'r>,
100 inputs: &'p [TensorProxy],
101 outputs: &'p [TensorProxy],
102 ) -> InferenceResult {
103 match self.role {
104 NextIterationRole::Source => {
105 check_input_arity(inputs, 0)?;
106 check_output_arity(outputs, 1)?;
107 }
108 NextIterationRole::Sink => {
109 check_input_arity(inputs, 1)?;
110 check_output_arity(outputs, 0)?;
111 }
112 }
113 Ok(())
114 }
115
116 as_op!();
117}