1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4use tract_hir::ops;
5use tract_hir::tract_core::ops::einsum::EinSum;
6
7use super::common::CommonRec;
8use super::common::WireBody;
9
10pub fn lstm(
11 _ctx: &ParsingContext,
12 pb: &NodeProto,
13) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
14 let lstm = LSTM {
15 f: Box::new(ops::nn::sigmoid()),
16 g: Box::new(ops::math::tanh()),
17 h: Box::new(ops::math::tanh()),
18 };
19 let common = CommonRec::from_node_and_options(pb, 3, 0, Box::new(lstm))?;
20 Ok((expand(common), vec![]))
21}
22
23#[derive(Debug, Clone)]
24pub struct LSTM {
25 pub f: Box<dyn TypedOp>,
26 pub g: Box<dyn TypedOp>,
27 pub h: Box<dyn TypedOp>,
28}
29
30impl WireBody for LSTM {
31 fn name(&self) -> &'static str {
32 "LSTM"
33 }
34
35 fn w_b_multipliers(&self) -> (usize, usize) {
36 (4, 8)
37 }
38
39 fn have_extra_c_state(&self) -> bool {
40 true
41 }
42
43 #[allow(non_snake_case)]
44 fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()> {
45 use tract_hir::ops::{array, math};
46 macro_rules! wire {
47 ($name: ident = $op: expr, $($param: expr),*) => {
48 let $name = body.wire_node(
49 format!("{}.{}", prefix, stringify!($name)),
50 $op, [$($param),*].as_ref())?[0];
51 }
52 }
53
54 let Xt: OutletId = body.node_by_name("Xt").unwrap().id.into();
55 let W: OutletId = body.node_by_name("W").unwrap().id.into();
56 let R: OutletId = body.node_by_name("R").unwrap().id.into();
57 let Ht_1: OutletId = body.node_by_name("Ht_1").unwrap().id.into();
58 let Ct_1: OutletId = body.node_by_name("Ct_1").unwrap().id.into();
59 let b: Option<OutletId> = body.node_by_name("b").ok().map(|n| n.id.into());
60 let peepholes: Option<OutletId> = body.node_by_name("peepholes").ok().map(|n| n.id.into());
61
62 let h_size = body.outlet_fact(R)?.shape[1].clone();
63
64 wire!(Wi = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), W);
65 wire!(Wo = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), W);
66 wire!(Wf = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), W);
67 wire!(Wc = array::Slice::new(0, 3.to_dim() * &h_size, 4.to_dim() * &h_size), W);
68
69 wire!(Ri = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), R);
70 wire!(Ro = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), R);
71 wire!(Rf = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), R);
72 wire!(Rc = array::Slice::new(0, 3.to_dim() * &h_size, 4.to_dim() * &h_size), R);
73
74 let biases = if let Some(b) = b {
75 wire!(Wbi = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), b);
76 wire!(Wbo = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), b);
77 wire!(Wbf = array::Slice::new(1, 2.to_dim() * &h_size, 3.to_dim() * &h_size), b);
78 wire!(Wbc = array::Slice::new(1, 3.to_dim() * &h_size, 4.to_dim() * &h_size), b);
79
80 wire!(Rbi = array::Slice::new(1, 4.to_dim() * &h_size, 5.to_dim() * &h_size), b);
81 wire!(Rbo = array::Slice::new(1, 5.to_dim() * &h_size, 6.to_dim() * &h_size), b);
82 wire!(Rbf = array::Slice::new(1, 6.to_dim() * &h_size, 7.to_dim() * &h_size), b);
83 wire!(Rbc = array::Slice::new(1, 7.to_dim() * &h_size, 8.to_dim() * &h_size), b);
84
85 wire!(bi = math::add(), Wbi, Rbi);
86 wire!(bo = math::add(), Wbo, Rbo);
87 wire!(bf = math::add(), Wbf, Rbf);
88 wire!(bc = math::add(), Wbc, Rbc);
89
90 Some((bi, bo, bf, bc))
91 } else {
92 None
93 };
94
95 let peepholes = if let Some(p) = peepholes {
96 wire!(pi = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), p);
97 wire!(po = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), p);
98 wire!(pf = array::Slice::new(1, 2.to_dim() * &h_size, 3.to_dim() * &h_size), p);
99 Some((pi, po, pf))
100 } else {
101 None
102 };
103
104 let matmul_t = EinSum::new("mk,nk->mn".parse()?, f32::datum_type());
105
106 wire!(Xt_WiT = matmul_t.clone(), Xt, Wi);
108 wire!(Ht_1_RiT = matmul_t.clone(), Ht_1, Ri);
109 wire!(it0 = math::add(), Xt_WiT, Ht_1_RiT);
110 let mut it0 = it0;
111 if let Some(biases) = biases {
112 wire!(it_bias = math::add(), it0, biases.0);
113 it0 = it_bias;
114 };
115 if let Some(peephole) = peepholes {
116 wire!(Pi_Ct_1 = math::mul(), peephole.0, Ct_1);
117 wire!(it_peep = math::add(), Pi_Ct_1, it0);
118 it0 = it_peep;
119 }
120 wire!(it = self.f.clone(), it0);
121
122 wire!(Xt_WfT = matmul_t.clone(), Xt, Wf);
124 wire!(Ht_1_RfT = matmul_t.clone(), Ht_1, Rf);
125 wire!(ft0 = math::add(), Xt_WfT, Ht_1_RfT);
126 let mut ft0 = ft0;
127 if let Some(biases) = biases {
128 wire!(ft_bias = math::add(), ft0, biases.2);
129 ft0 = ft_bias;
130 };
131 if let Some(peephole) = peepholes {
132 wire!(Pf_Ct_1 = math::mul(), peephole.2, Ct_1);
133 wire!(ft_peep = math::add(), Pf_Ct_1, ft0);
134 ft0 = ft_peep;
135 }
136 wire!(ft = self.f.clone(), ft0);
137
138 wire!(Xt_WcT = matmul_t.clone(), Xt, Wc);
140 wire!(Ht_1_RcT = matmul_t.clone(), Ht_1, Rc);
141 wire!(ct0 = math::add(), Xt_WcT, Ht_1_RcT);
142 let mut ct0 = ct0;
143 if let Some(biases) = biases {
144 wire!(ct_bias = math::add(), ct0, biases.3);
145 ct0 = ct_bias
146 };
147 wire!(ct = self.g.clone(), ct0);
148
149 wire!(ft_Ct_1 = math::mul(), ft, Ct_1);
151 wire!(it_ct = math::mul(), it, ct);
152 wire!(Ct = math::add(), ft_Ct_1, it_ct);
153
154 wire!(Xt_WoT = matmul_t.clone(), Xt, Wo);
156 wire!(Ht_1_RoT = matmul_t, Ht_1, Ro);
157 wire!(ot0 = math::add(), Xt_WoT, Ht_1_RoT);
158 let mut ot0 = ot0;
159 if let Some(biases) = biases {
160 wire!(ot_bias = math::add(), ot0, biases.1);
161 ot0 = ot_bias
162 };
163 if let Some(peephole) = peepholes {
164 wire!(Po_Ct = math::mul(), peephole.1, Ct);
165 wire!(ot_peep = math::add(), Po_Ct, ot0);
166 ot0 = ot_peep;
167 }
168 wire!(ot = self.f.clone(), ot0);
169
170 wire!(h_Ct = self.h.clone(), Ct);
172 wire!(Ht = math::mul(), ot, h_Ct);
173
174 wire!(Ht_fixed = AxisOp::Add(1), Ht);
177 wire!(Ct_fixed = AxisOp::Add(1), Ct);
178 body.set_output_outlets(&[Ht_fixed, Ct_fixed])?;
179
180 Ok(())
181 }
182}