tract_onnx/ops/rec/
lstm.rs

1use crate::model::ParsingContext;
2use crate::pb::*;
3use tract_hir::internal::*;
4use tract_hir::ops;
5use tract_hir::tract_core::ops::einsum::EinSum;
6
7use super::common::CommonRec;
8use super::common::WireBody;
9
10pub fn lstm(
11    _ctx: &ParsingContext,
12    pb: &NodeProto,
13) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
14    let lstm = LSTM {
15        f: Box::new(ops::nn::sigmoid()),
16        g: Box::new(ops::math::tanh()),
17        h: Box::new(ops::math::tanh()),
18    };
19    let common = CommonRec::from_node_and_options(pb, 3, 0, Box::new(lstm))?;
20    Ok((expand(common), vec![]))
21}
22
23#[derive(Debug, Clone)]
24pub struct LSTM {
25    pub f: Box<dyn TypedOp>,
26    pub g: Box<dyn TypedOp>,
27    pub h: Box<dyn TypedOp>,
28}
29
30impl WireBody for LSTM {
31    fn name(&self) -> &'static str {
32        "LSTM"
33    }
34
35    fn w_b_multipliers(&self) -> (usize, usize) {
36        (4, 8)
37    }
38
39    fn have_extra_c_state(&self) -> bool {
40        true
41    }
42
43    #[allow(non_snake_case)]
44    fn wire_body(&self, prefix: &str, body: &mut TypedModel) -> TractResult<()> {
45        use tract_hir::ops::{array, math};
46        macro_rules! wire {
47            ($name: ident = $op: expr, $($param: expr),*) => {
48                let $name = body.wire_node(
49                    format!("{}.{}", prefix, stringify!($name)),
50                    $op, [$($param),*].as_ref())?[0];
51            }
52        }
53
54        let Xt: OutletId = body.node_by_name("Xt").unwrap().id.into();
55        let W: OutletId = body.node_by_name("W").unwrap().id.into();
56        let R: OutletId = body.node_by_name("R").unwrap().id.into();
57        let Ht_1: OutletId = body.node_by_name("Ht_1").unwrap().id.into();
58        let Ct_1: OutletId = body.node_by_name("Ct_1").unwrap().id.into();
59        let b: Option<OutletId> = body.node_by_name("b").ok().map(|n| n.id.into());
60        let peepholes: Option<OutletId> = body.node_by_name("peepholes").ok().map(|n| n.id.into());
61
62        let h_size = body.outlet_fact(R)?.shape[1].clone();
63
64        wire!(Wi = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), W);
65        wire!(Wo = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), W);
66        wire!(Wf = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), W);
67        wire!(Wc = array::Slice::new(0, 3.to_dim() * &h_size, 4.to_dim() * &h_size), W);
68
69        wire!(Ri = array::Slice::new(0, 0.to_dim() * &h_size, 1.to_dim() * &h_size), R);
70        wire!(Ro = array::Slice::new(0, 1.to_dim() * &h_size, 2.to_dim() * &h_size), R);
71        wire!(Rf = array::Slice::new(0, 2.to_dim() * &h_size, 3.to_dim() * &h_size), R);
72        wire!(Rc = array::Slice::new(0, 3.to_dim() * &h_size, 4.to_dim() * &h_size), R);
73
74        let biases = if let Some(b) = b {
75            wire!(Wbi = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), b);
76            wire!(Wbo = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), b);
77            wire!(Wbf = array::Slice::new(1, 2.to_dim() * &h_size, 3.to_dim() * &h_size), b);
78            wire!(Wbc = array::Slice::new(1, 3.to_dim() * &h_size, 4.to_dim() * &h_size), b);
79
80            wire!(Rbi = array::Slice::new(1, 4.to_dim() * &h_size, 5.to_dim() * &h_size), b);
81            wire!(Rbo = array::Slice::new(1, 5.to_dim() * &h_size, 6.to_dim() * &h_size), b);
82            wire!(Rbf = array::Slice::new(1, 6.to_dim() * &h_size, 7.to_dim() * &h_size), b);
83            wire!(Rbc = array::Slice::new(1, 7.to_dim() * &h_size, 8.to_dim() * &h_size), b);
84
85            wire!(bi = math::add(), Wbi, Rbi);
86            wire!(bo = math::add(), Wbo, Rbo);
87            wire!(bf = math::add(), Wbf, Rbf);
88            wire!(bc = math::add(), Wbc, Rbc);
89
90            Some((bi, bo, bf, bc))
91        } else {
92            None
93        };
94
95        let peepholes = if let Some(p) = peepholes {
96            wire!(pi = array::Slice::new(1, 0.to_dim() * &h_size, 1.to_dim() * &h_size), p);
97            wire!(po = array::Slice::new(1, 1.to_dim() * &h_size, 2.to_dim() * &h_size), p);
98            wire!(pf = array::Slice::new(1, 2.to_dim() * &h_size, 3.to_dim() * &h_size), p);
99            Some((pi, po, pf))
100        } else {
101            None
102        };
103
104        let matmul_t = EinSum::new("mk,nk->mn".parse()?, f32::datum_type());
105
106        // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
107        wire!(Xt_WiT = matmul_t.clone(), Xt, Wi);
108        wire!(Ht_1_RiT = matmul_t.clone(), Ht_1, Ri);
109        wire!(it0 = math::add(), Xt_WiT, Ht_1_RiT);
110        let mut it0 = it0;
111        if let Some(biases) = biases {
112            wire!(it_bias = math::add(), it0, biases.0);
113            it0 = it_bias;
114        };
115        if let Some(peephole) = peepholes {
116            wire!(Pi_Ct_1 = math::mul(), peephole.0, Ct_1);
117            wire!(it_peep = math::add(), Pi_Ct_1, it0);
118            it0 = it_peep;
119        }
120        wire!(it = self.f.clone(), it0);
121
122        // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
123        wire!(Xt_WfT = matmul_t.clone(), Xt, Wf);
124        wire!(Ht_1_RfT = matmul_t.clone(), Ht_1, Rf);
125        wire!(ft0 = math::add(), Xt_WfT, Ht_1_RfT);
126        let mut ft0 = ft0;
127        if let Some(biases) = biases {
128            wire!(ft_bias = math::add(), ft0, biases.2);
129            ft0 = ft_bias;
130        };
131        if let Some(peephole) = peepholes {
132            wire!(Pf_Ct_1 = math::mul(), peephole.2, Ct_1);
133            wire!(ft_peep = math::add(), Pf_Ct_1, ft0);
134            ft0 = ft_peep;
135        }
136        wire!(ft = self.f.clone(), ft0);
137
138        // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
139        wire!(Xt_WcT = matmul_t.clone(), Xt, Wc);
140        wire!(Ht_1_RcT = matmul_t.clone(), Ht_1, Rc);
141        wire!(ct0 = math::add(), Xt_WcT, Ht_1_RcT);
142        let mut ct0 = ct0;
143        if let Some(biases) = biases {
144            wire!(ct_bias = math::add(), ct0, biases.3);
145            ct0 = ct_bias
146        };
147        wire!(ct = self.g.clone(), ct0);
148
149        // Ct = ft (.) Ct-1 + it (.) ct
150        wire!(ft_Ct_1 = math::mul(), ft, Ct_1);
151        wire!(it_ct = math::mul(), it, ct);
152        wire!(Ct = math::add(), ft_Ct_1, it_ct);
153
154        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
155        wire!(Xt_WoT = matmul_t.clone(), Xt, Wo);
156        wire!(Ht_1_RoT = matmul_t, Ht_1, Ro);
157        wire!(ot0 = math::add(), Xt_WoT, Ht_1_RoT);
158        let mut ot0 = ot0;
159        if let Some(biases) = biases {
160            wire!(ot_bias = math::add(), ot0, biases.1);
161            ot0 = ot_bias
162        };
163        if let Some(peephole) = peepholes {
164            wire!(Po_Ct = math::mul(), peephole.1, Ct);
165            wire!(ot_peep = math::add(), Po_Ct, ot0);
166            ot0 = ot_peep;
167        }
168        wire!(ot = self.f.clone(), ot0);
169
170        // Ht = ot (.) h(Ct)
171        wire!(h_Ct = self.h.clone(), Ct);
172        wire!(Ht = math::mul(), ot, h_Ct);
173
174        // onnx inner interface: [batch_size, input_size]
175        // add sequence axis (chunk == 1)
176        wire!(Ht_fixed = AxisOp::Add(1), Ht);
177        wire!(Ct_fixed = AxisOp::Add(1), Ct);
178        body.set_output_outlets(&[Ht_fixed, Ct_fixed])?;
179
180        Ok(())
181    }
182}