1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4use tract_hir::ops;
5use tract_hir::tract_core::ops::einsum::EinSum;
6
7use super::common::CommonRec;
8use super::common::WireBody;
9
10pub fn gru(
11 _ctx: &ParsingContext,
12 pb: &NodeProto,
13) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
14 let gru = GRU {
15 f: Box::new(ops::nn::sigmoid()),
16 g: Box::new(ops::math::tanh()),
17 linear_before_reset: pb.get_attr("linear_before_reset").unwrap_or(false),
18 };
19 let common = CommonRec::from_node_and_options(pb, 3, 0, Box::new(gru))?;
20
21 Ok((expand(common), vec![]))
22}
23
/// Loop-body description of a GRU cell, consumed through [`WireBody`].
#[derive(Debug, Clone)]
pub struct GRU {
    /// Activation applied to the `zt` and `rt` pre-activations when the body is wired.
    pub f: Box<dyn TypedOp>,
    /// Activation applied to the `ht` pre-activation when the body is wired.
    pub g: Box<dyn TypedOp>,
    /// When `true`, the recurrent product `Ht-1*(Rh^T)` (plus its bias, if any)
    /// is computed first and then multiplied by `rt`; when `false`, `rt` is
    /// multiplied with `Ht-1` before the recurrent product is taken.
    pub linear_before_reset: bool,
}
30
/// Wires one GRU time step into the body model of the recurrent op.
impl WireBody for GRU {
    /// Operator name reported for this body.
    fn name(&self) -> &'static str {
        "GRU"
    }

    /// Multipliers for the weight and bias tensors. These match `wire_body`,
    /// which cuts `W`/`R` into 3 chunks of `h_size` and `b` into 6 chunks.
    /// NOTE(review): how the caller uses these values is not visible here.
    fn w_b_multipliers(&self) -> (usize, usize) {
        (3, 6)
    }

    /// This body does not use an extra `c` state.
    fn have_extra_c_state(&self) -> bool {
        false
    }

    /// Adds the nodes computing one GRU step to `body` and sets its output.
    ///
    /// Expects `body` to already contain nodes named `Xt`, `W`, `R` and
    /// `Ht_1` (their absence panics through `unwrap`), and optionally a node
    /// named `b` holding the biases. Every node created through the local
    /// `wire!` macro is named `"{prefix}.{local identifier}"`.
    ///
    /// The single output of the body is `Ht` with an axis inserted at
    /// position 1.
    ///
    /// # Errors
    ///
    /// Propagates failures from fact lookup, einsum expression parsing, node
    /// wiring, constant creation and output registration.
    // Local names deliberately follow the notation of the GRU equations.
    #[allow(non_snake_case)]
    fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()> {
        use tract_hir::ops::{array, math};
        // Wires `$op` on the given inputs as a node named "<prefix>.<name>" and
        // binds its first output outlet to a local called `<name>`. Because the
        // identifier is stringified into the node name, renaming a local here
        // renames the corresponding node in the graph.
        macro_rules! wire {
            ($name: ident = $op: expr, $($param: expr),*) => {
                let $name = body.wire_node(
                    format!("{}.{}", prefix, stringify!($name)),
                    $op, [$($param),*].as_ref())?[0];
            }
        }

        // Inputs of the step, looked up by name in the body model.
        let Xt: OutletId = body.node_by_name("Xt").unwrap().id.into();
        let W: OutletId = body.node_by_name("W").unwrap().id.into();
        let R: OutletId = body.node_by_name("R").unwrap().id.into();
        let Ht_1: OutletId = body.node_by_name("Ht_1").unwrap().id.into();
        // The bias input is optional: a failed lookup simply yields `None`.
        let b: Option<OutletId> = body.node_by_name("b").ok().map(|n| n.id.into());

        // Chunk size, taken from axis 1 of R.
        let h_size = body.outlet_fact(R)?.shape[1].clone();

        // R is cut along axis 0 into three chunks of h_size, in z, r, h order.
        wire!(Rz = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), R);
        wire!(Rr = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), R);
        wire!(Rh = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), R);

        // W is cut the same way.
        wire!(Wz = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), W);
        wire!(Wr = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), W);
        wire!(Wh = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), W);

        // "mk,nk->mn" contracts the second axis of both operands, i.e. A * B^T.
        let matmul_t = EinSum::new("mk,nk->mn".parse()?, f32::datum_type());

        // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
        wire!(Xt_WzT = matmul_t.clone(), Xt, Wz);
        wire!(Ht_1_RzT = matmul_t.clone(), Ht_1, Rz);
        wire!(zt0 = math::add(), Xt_WzT, Ht_1_RzT);
        // `wire!` introduces an immutable binding; rebind so the biased value
        // can replace it below.
        let mut zt0 = zt0;
        if let Some(b) = b {
            // b is cut along axis 1: chunks 0..3 are Wb[z, r, h], chunks 3..6 are Rb[z, r, h].
            wire!(Wbz = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), b);
            wire!(Rbz = array::Slice::new(1, 3.to_dim() * &h_size, 4.to_dim() * &h_size), b);
            wire!(Wbz_Rbz = math::add(), Wbz, Rbz);
            wire!(zt0_biased = math::add(), zt0, Wbz_Rbz);
            zt0 = zt0_biased
        };
        wire!(zt = self.f.clone(), zt0);

        // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        wire!(Xt_WrT = matmul_t.clone(), Xt, Wr);
        wire!(Ht_1_RrT = matmul_t.clone(), Ht_1, Rr);
        wire!(rt0 = math::add(), Xt_WrT, Ht_1_RrT);
        let mut rt0 = rt0;
        if let Some(b) = b {
            wire!(Wbr = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), b);
            wire!(Rbr = array::Slice::new(1, 4.to_dim() * &h_size, 5.to_dim() * &h_size), b);
            wire!(Wbr_Rbr = math::add(), Wbr, Rbr);
            wire!(rt0_biased = math::add(), rt0, Wbr_Rbr);
            rt0 = rt0_biased
        };
        wire!(rt = self.f.clone(), rt0);

        // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)    when !linear_before_reset
        // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)  when linear_before_reset
        wire!(Xt_WhT = matmul_t.clone(), Xt, Wh);
        let rt_Ht_1_RhT_Rbh = if self.linear_before_reset {
            // rt (.) (Ht-1*(Rh^T) + Rbh)
            wire!(Ht_1_RhT = matmul_t, Ht_1, Rh);
            let Ht_1_RhT_Rbh = if let Some(b) = b {
                wire!(Rbh = array::Slice::new(1, 5.to_dim() * &h_size, 6.to_dim() * &h_size), b);
                wire!(Ht_1_RhT_Rbh = math::add(), Ht_1_RhT, Rbh);
                Ht_1_RhT_Rbh
            } else {
                Ht_1_RhT
            };
            wire!(rt_Ht_1_RhT_Rbh = math::mul(), rt, Ht_1_RhT_Rbh);
            rt_Ht_1_RhT_Rbh
        } else {
            // (rt (.) Ht-1)*(Rh^T) + Rbh
            wire!(rt_Ht_1 = math::mul(), rt, Ht_1);
            wire!(rt_Ht_1_RhT = matmul_t, rt_Ht_1, Rh);
            if let Some(b) = b {
                wire!(Rbh = array::Slice::new(1, 5.to_dim() * &h_size, 6.to_dim() * &h_size), b);
                wire!(rt_Ht_1_RhT_Rbh = math::add(), rt_Ht_1_RhT, Rbh);
                rt_Ht_1_RhT_Rbh
            } else {
                rt_Ht_1_RhT
            }
        };
        wire!(ht0 = math::add(), Xt_WhT, rt_Ht_1_RhT_Rbh);
        let mut ht0 = ht0;
        if let Some(b) = b {
            wire!(Wbh = array::Slice::new(1, 2.to_dim() * &h_size, 3.to_dim() * &h_size), b);
            wire!(ht0_biased = math::add(), ht0, Wbh);
            ht0 = ht0_biased
        }
        wire!(ht = self.g.clone(), ht0);

        // Ht = (1 - zt) (.) ht + zt (.) Ht-1
        // The constant is registered under the bare name "one" (no prefix).
        let one: OutletId = body.add_const("one", tensor2(&[[1f32]]))?;
        wire!(one_sub_zt = math::sub(), one, zt);
        wire!(one_sub_zt_ht = math::mul(), one_sub_zt, ht);
        wire!(zt_Ht_1 = math::mul(), zt, Ht_1);
        wire!(Ht = math::add(), one_sub_zt_ht, zt_Ht_1);

        // The body output is Ht with a new axis inserted at position 1.
        wire!(y_h = AxisOp::Add(1), Ht);
        body.set_output_outlets(&[y_h])?;
        Ok(())
    }
}