Skip to main content

tract_core/ops/lstm_cell.rs

1use crate::internal::*;
2use tract_linalg::element_wise::ElementWise;
3
/// Fused epilogue of one LSTM cell step.
///
/// Takes the summed gate pre-activations
/// `preact = Xt·Wᵀ + Ht-1·Rᵀ + bias`, shaped `[batch, 4*hidden]` with the gates laid out
/// in ONNX order i, o, f, c, together with the previous cell state `c_prev`
/// (`[batch, hidden]`), and produces both the new hidden state `Ht` and the new cell
/// state `Ct` in ONE fused pass.
///
/// Without this fusion the step is a chain of roughly 15 per-gate `Sigmoid`/`Tanh` and
/// elementwise `Mul`/`Add` ops, each dispatched separately and each materialising its
/// own intermediate tensor — the dominant non-matmul cost for streaming LSTM inference.
/// Only the standard activations (`f = sigmoid`, `g = h = tanh`) without peepholes are
/// handled; the importer keeps the decomposed form for anything else.
///
/// The activations are tract's vectorised `sigmoid`/`tanh` linalg kernels (NEON on
/// aarch64) run over contiguous gate slices, so results are numerically identical to the
/// decomposed Sigmoid/Tanh path while the per-gate dispatch collapses into one op.
/// Evaluation runs in `f32` or `f16`, whichever dtype the precision transform settled the
/// surrounding graph on.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LstmEpilogue {
    /// Hidden size: values per gate, i.e. the innermost extent of `c_prev` (`[.., hidden]`)
    /// and a quarter of each `preact` row.
    pub hidden: usize,
}
26
27impl Op for LstmEpilogue {
28    fn name(&self) -> StaticName {
29        "LstmEpilogue".into()
30    }
31
32    fn info(&self) -> TractResult<Vec<String>> {
33        Ok(vec![format!("hidden={}", self.hidden)])
34    }
35
36    op_as_typed_op!();
37}
38
39impl EvalOp for LstmEpilogue {
40    fn is_stateless(&self) -> bool {
41        true
42    }
43
44    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
45        // Dispatch on the dtype the precision transform left the graph in. The
46        // ONNX LSTM is f32-native, but `FloatPrecisionTranslator` rewrites the
47        // whole float graph (including this op's inputs) to f16, so we must run
48        // in whichever float type actually arrives — reading an f16 buffer as
49        // f32 would walk off the end of the allocation.
50        let ops = tract_linalg::ops();
51        match inputs[0].datum_type().unquantized() {
52            DatumType::F32 => self.eval_t::<f32>(inputs, (ops.sigmoid_f32)(), (ops.tanh_f32)()),
53            DatumType::F16 => self.eval_t::<f16>(inputs, (ops.sigmoid_f16)(), (ops.tanh_f16)()),
54            dt => bail!("LstmEpilogue only supports f32 and f16 preactivations, got {dt:?}"),
55        }
56    }
57}
58
59impl LstmEpilogue {
60    fn eval_t<T>(
61        &self,
62        inputs: TVec<TValue>,
63        sigmoid: Box<dyn ElementWise<T>>,
64        tanh: Box<dyn ElementWise<T>>,
65    ) -> TractResult<TVec<TValue>>
66    where
67        T: Datum + Copy + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
68    {
69        let h = self.hidden;
70        let c_prev = &inputs[1]; // [.., h]
71        let cp = unsafe { c_prev.as_slice_unchecked::<T>() };
72        let rows = inputs[0].len() / (4 * h); // any leading-dim layout, row-major
73        // Mutable copy of preact so the activation kernels run in place.
74        let mut pre_t = inputs[0].clone().into_tensor();
75        let pre = unsafe { pre_t.as_slice_mut_unchecked::<T>() };
76        let mut ht = unsafe { Tensor::uninitialized_dt(T::datum_type(), c_prev.shape())? };
77        let mut ct = unsafe { Tensor::uninitialized_dt(T::datum_type(), c_prev.shape())? };
78        {
79            let hs = unsafe { ht.as_slice_mut_unchecked::<T>() };
80            let cs = unsafe { ct.as_slice_mut_unchecked::<T>() };
81            for r in 0..rows {
82                let pb = r * 4 * h;
83                let cb = r * h;
84                let row = &mut pre[pb..pb + 4 * h];
85                // gate order i,o,f,c: sigmoid the i,o,f block, tanh the c block
86                sigmoid.run(&mut row[0..3 * h])?;
87                tanh.run(&mut row[3 * h..4 * h])?;
88                // Ct = ft*c_prev + it*cc  (it=row[j], ot=row[h+j], ft=row[2h+j], cc=row[3h+j])
89                for j in 0..h {
90                    cs[cb + j] = row[2 * h + j] * cp[cb + j] + row[j] * row[3 * h + j];
91                }
92                // Ht = ot * tanh(Ct): stage tanh(Ct) in hs, then scale by ot
93                hs[cb..cb + h].copy_from_slice(&cs[cb..cb + h]);
94                tanh.run(&mut hs[cb..cb + h])?;
95                for j in 0..h {
96                    hs[cb + j] = hs[cb + j] * row[h + j];
97                }
98            }
99        }
100        Ok(tvec!(ht.into_tvalue(), ct.into_tvalue()))
101    }
102}
103
104impl TypedOp for LstmEpilogue {
105    as_op!();
106
107    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
108        ensure!(inputs.len() == 2, "LstmEpilogue expects [preact, c_prev]");
109        // Ht and Ct share c_prev's shape and dtype ([.., hidden]).
110        let c_prev = inputs[1];
111        let fact = c_prev.datum_type.fact(c_prev.shape.clone());
112        Ok(tvec!(fact.clone(), fact))
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the fused epilogue in `dt` and compares it with a scalar f32 reference LSTM
    /// cell (catches gate-order and cell/hidden-formula bugs). Inputs are cast to `dt`
    /// first and the reference is computed from the cast values, so only the op's own
    /// arithmetic and activation approximations count against `tol`. Multiple rows
    /// exercise batching.
    fn check_against_reference(dt: DatumType, tol: f32) {
        let h = 6usize;
        let batch = 3usize;
        let preact: Vec<f32> =
            (0..batch * 4 * h).map(|i| ((i * 7 % 29) as f32 - 14.0) * 0.25).collect();
        let cprev: Vec<f32> = (0..batch * h).map(|i| ((i * 5 % 17) as f32 - 8.0) * 0.2).collect();
        let pre_t = Tensor::from_shape(&[batch, 4 * h], &preact)
            .unwrap()
            .cast_to_dt(dt)
            .unwrap()
            .into_owned();
        let cprev_t =
            Tensor::from_shape(&[batch, h], &cprev).unwrap().cast_to_dt(dt).unwrap().into_owned();
        // Reference inputs exactly as the op sees them (rounded to `dt`).
        let preact = pre_t.cast_to::<f32>().unwrap().as_slice::<f32>().unwrap().to_vec();
        let cprev = cprev_t.cast_to::<f32>().unwrap().as_slice::<f32>().unwrap().to_vec();

        let op = LstmEpilogue { hidden: h };
        let out = op.eval(tvec!(pre_t.into_tvalue(), cprev_t.into_tvalue())).unwrap();
        assert_eq!(out[0].datum_type(), dt);
        assert_eq!(out[1].datum_type(), dt);
        let ht_f32 = out[0].cast_to::<f32>().unwrap();
        let ct_f32 = out[1].cast_to::<f32>().unwrap();
        let ht = ht_f32.as_slice::<f32>().unwrap();
        let ct = ct_f32.as_slice::<f32>().unwrap();

        let sig = |x: f32| 1.0 / (1.0 + (-x).exp());
        for r in 0..batch {
            for j in 0..h {
                let p = r * 4 * h; // gate order on the 4*h axis: i, o, f, c
                let it = sig(preact[p + j]);
                let ot = sig(preact[p + h + j]);
                let ft = sig(preact[p + 2 * h + j]);
                let cc = preact[p + 3 * h + j].tanh();
                let c_ref = ft * cprev[r * h + j] + it * cc;
                let h_ref = ot * c_ref.tanh();
                assert!((ct[r * h + j] - c_ref).abs() < tol, "Ct mismatch at ({r},{j}) for {dt:?}");
                assert!((ht[r * h + j] - h_ref).abs() < tol, "Ht mismatch at ({r},{j}) for {dt:?}");
            }
        }
    }

    // f32 path: tolerance comfortably covers the rational sigmoid/tanh approximation
    // against the exact reference.
    #[test]
    fn epilogue_matches_scalar_reference() {
        check_against_reference(DatumType::F32, 1e-3);
    }

    // f16 path: the precision transform can hand this op f16 tensors, which must be
    // computed as f16 rather than reinterpreted as f32. The looser tolerance covers f16
    // rounding of every intermediate.
    #[test]
    fn epilogue_matches_scalar_reference_f16() {
        check_against_reference(DatumType::F16, 3e-2);
    }
}