tract_onnx/ops/rec/
gru.rs

1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4use tract_hir::ops;
5use tract_hir::tract_core::ops::einsum::EinSum;
6
7use super::common::CommonRec;
8use super::common::WireBody;
9
10pub fn gru(
11    _ctx: &ParsingContext,
12    pb: &NodeProto,
13) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
14    let gru = GRU {
15        f: Box::new(ops::nn::sigmoid()),
16        g: Box::new(ops::math::tanh()),
17        linear_before_reset: pb.get_attr("linear_before_reset").unwrap_or(false),
18    };
19    let common = CommonRec::from_node_and_options(pb, 3, 0, Box::new(gru))?;
20
21    Ok((expand(common), vec![]))
22}
23
24#[derive(Debug, Clone)]
25pub struct GRU {
26    pub f: Box<dyn TypedOp>,
27    pub g: Box<dyn TypedOp>,
28    pub linear_before_reset: bool,
29}
30
31impl WireBody for GRU {
32    fn name(&self) -> &'static str {
33        "GRU"
34    }
35
36    fn w_b_multipliers(&self) -> (usize, usize) {
37        (3, 6)
38    }
39
40    fn have_extra_c_state(&self) -> bool {
41        false
42    }
43
44    #[allow(non_snake_case)]
45    fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()> {
46        use tract_hir::ops::{array, math};
47        macro_rules! wire {
48            ($name: ident = $op: expr, $($param: expr),*) => {
49                let $name = body.wire_node(
50                    format!("{}.{}", prefix, stringify!($name)),
51                    $op, [$($param),*].as_ref())?[0];
52            }
53        }
54
55        let Xt: OutletId = body.node_by_name("Xt").unwrap().id.into();
56        let W: OutletId = body.node_by_name("W").unwrap().id.into();
57        let R: OutletId = body.node_by_name("R").unwrap().id.into();
58        let Ht_1: OutletId = body.node_by_name("Ht_1").unwrap().id.into();
59        let b: Option<OutletId> = body.node_by_name("b").ok().map(|n| n.id.into());
60
61        let h_size = body.outlet_fact(R)?.shape[1].clone();
62
63        wire!(Rz = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), R);
64        wire!(Rr = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), R);
65        wire!(Rh = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), R);
66
67        wire!(Wz = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), W);
68        wire!(Wr = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), W);
69        wire!(Wh = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), W);
70
71        let matmul_t = EinSum::new("mk,nk->mn".parse()?, f32::datum_type());
72
73        // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
74        wire!(Xt_WzT = matmul_t.clone(), Xt, Wz);
75        wire!(Ht_1_RzT = matmul_t.clone(), Ht_1, Rz);
76        wire!(zt0 = math::add(), Xt_WzT, Ht_1_RzT);
77        let mut zt0 = zt0;
78        if let Some(b) = b {
79            wire!(Wbz = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), b);
80            wire!(Rbz = array::Slice::new(1, 3.to_dim() * &h_size, 4.to_dim() * &h_size), b);
81            wire!(Wbz_Rbz = math::add(), Wbz, Rbz);
82            wire!(zt0_biased = math::add(), zt0, Wbz_Rbz);
83            zt0 = zt0_biased
84        };
85        wire!(zt = self.f.clone(), zt0);
86
87        // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
88        wire!(Xt_WrT = matmul_t.clone(), Xt, Wr);
89        wire!(Ht_1_RrT = matmul_t.clone(), Ht_1, Rr);
90        wire!(rt0 = math::add(), Xt_WrT, Ht_1_RrT);
91        let mut rt0 = rt0;
92        if let Some(b) = b {
93            wire!(Wbr = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), b);
94            wire!(Rbr = array::Slice::new(1, 4.to_dim() * &h_size, 5.to_dim() * &h_size), b);
95            wire!(Wbr_Rbr = math::add(), Wbr, Rbr);
96            wire!(rt0_biased = math::add(), rt0, Wbr_Rbr);
97            rt0 = rt0_biased
98        };
99        wire!(rt = self.f.clone(), rt0);
100
101        // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0
102        // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0
103        wire!(Xt_WhT = matmul_t.clone(), Xt, Wh);
104        let rt_Ht_1_RhT_Rbh = if self.linear_before_reset {
105            // rt (.) (Ht-1*(Rh^T) + Rbh)
106            wire!(Ht_1_RhT = matmul_t, Ht_1, Rh);
107            let Ht_1_RhT_Rbh = if let Some(b) = b {
108                wire!(Rbh = array::Slice::new(1, 5.to_dim() * &h_size, 6.to_dim() * &h_size), b);
109                wire!(Ht_1_RhT_Rbh = math::add(), Ht_1_RhT, Rbh);
110                Ht_1_RhT_Rbh
111            } else {
112                Ht_1_RhT
113            };
114            wire!(rt_Ht_1_RhT_Rbh = math::mul(), rt, Ht_1_RhT_Rbh);
115            rt_Ht_1_RhT_Rbh
116        } else {
117            // (rt (.) Ht-1)*(Rh^T) + Rbh
118            wire!(rt_Ht_1 = math::mul(), rt, Ht_1);
119            wire!(rt_Ht_1_RhT = matmul_t, rt_Ht_1, Rh);
120            if let Some(b) = b {
121                wire!(Rbh = array::Slice::new(1, 5.to_dim() * &h_size, 6.to_dim() * &h_size), b);
122                wire!(rt_Ht_1_RhT_Rbh = math::add(), rt_Ht_1_RhT, Rbh);
123                rt_Ht_1_RhT_Rbh
124            } else {
125                rt_Ht_1_RhT
126            }
127        };
128        wire!(ht0 = math::add(), Xt_WhT, rt_Ht_1_RhT_Rbh);
129        let mut ht0 = ht0;
130        if let Some(b) = b {
131            wire!(Wbh = array::Slice::new(1, 2.to_dim() * &h_size, 3.to_dim() * &h_size), b);
132            wire!(ht0_biased = math::add(), ht0, Wbh);
133            ht0 = ht0_biased
134        }
135        wire!(ht = self.g.clone(), ht0);
136
137        // Ht = (1 - zt) (.) ht + zt (.) Ht-1
138        let one: OutletId = body.add_const("one", tensor2(&[[1f32]]))?;
139        wire!(one_sub_zt = math::sub(), one, zt);
140        wire!(one_sub_zt_ht = math::mul(), one_sub_zt, ht);
141        wire!(zt_Ht_1 = math::mul(), zt, Ht_1);
142        wire!(Ht = math::add(), one_sub_zt_ht, zt_Ht_1);
143
144        /*
145        // Ht = ht + (- (zt (.) ht) + zt (.) Ht-1)
146        wire!(zt_ht = math::mul(), zt, ht);
147        wire!(zt_Ht_1 = math::mul(), zt, Ht_1);
148        wire!(zt_Ht_1_sub_zt_ht = math::sub(), zt_Ht_1, zt_ht);
149        wire!(Ht = math::add(), ht, zt_Ht_1_sub_zt_ht);
150        */
151
152        // Ht = ht - (zt (.) ht) + zt (.) Ht-1)
153        /*
154        wire!(zt_ht = math::mul(), zt, ht);
155        wire!(zt_Ht_1 = math::mul(), zt, Ht_1);
156        wire!(ht_zt_ht = math::sub(), ht, zt_ht);
157        wire!(Ht = math::add(), ht_zt_ht, zt_Ht_1);
158        */
159
160        wire!(y_h = AxisOp::Add(1), Ht);
161        body.set_output_outlets(&[y_h])?;
162        Ok(())
163    }
164}