tract_core/ops/
lstm_cell.rs1use crate::internal::*;
2use tract_linalg::element_wise::ElementWise;
3
/// Element-wise tail of an LSTM cell, applied to already-computed gate pre-activations.
///
/// The op takes two inputs, `[preact, c_prev]`. Each row of `preact` holds `4 * hidden`
/// values laid out as `[input | output | forget | candidate]`: the first three slices go
/// through a sigmoid, the last one through tanh. It produces `[h_t, c_t]`, both shaped
/// like `c_prev`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct LstmEpilogue {
    /// Number of elements in one gate slice, i.e. a quarter of a pre-activation row.
    pub hidden: usize,
}
26
27impl Op for LstmEpilogue {
28 fn name(&self) -> StaticName {
29 "LstmEpilogue".into()
30 }
31
32 fn info(&self) -> TractResult<Vec<String>> {
33 Ok(vec![format!("hidden={}", self.hidden)])
34 }
35
36 op_as_typed_op!();
37}
38
39impl EvalOp for LstmEpilogue {
40 fn is_stateless(&self) -> bool {
41 true
42 }
43
44 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
45 let ops = tract_linalg::ops();
51 match inputs[0].datum_type().unquantized() {
52 DatumType::F32 => self.eval_t::<f32>(inputs, (ops.sigmoid_f32)(), (ops.tanh_f32)()),
53 DatumType::F16 => self.eval_t::<f16>(inputs, (ops.sigmoid_f16)(), (ops.tanh_f16)()),
54 dt => bail!("LstmEpilogue only supports f32 and f16 preactivations, got {dt:?}"),
55 }
56 }
57}
58
59impl LstmEpilogue {
60 fn eval_t<T>(
61 &self,
62 inputs: TVec<TValue>,
63 sigmoid: Box<dyn ElementWise<T>>,
64 tanh: Box<dyn ElementWise<T>>,
65 ) -> TractResult<TVec<TValue>>
66 where
67 T: Datum + Copy + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
68 {
69 let h = self.hidden;
70 let c_prev = &inputs[1]; let cp = unsafe { c_prev.as_slice_unchecked::<T>() };
72 let rows = inputs[0].len() / (4 * h); let mut pre_t = inputs[0].clone().into_tensor();
75 let pre = unsafe { pre_t.as_slice_mut_unchecked::<T>() };
76 let mut ht = unsafe { Tensor::uninitialized_dt(T::datum_type(), c_prev.shape())? };
77 let mut ct = unsafe { Tensor::uninitialized_dt(T::datum_type(), c_prev.shape())? };
78 {
79 let hs = unsafe { ht.as_slice_mut_unchecked::<T>() };
80 let cs = unsafe { ct.as_slice_mut_unchecked::<T>() };
81 for r in 0..rows {
82 let pb = r * 4 * h;
83 let cb = r * h;
84 let row = &mut pre[pb..pb + 4 * h];
85 sigmoid.run(&mut row[0..3 * h])?;
87 tanh.run(&mut row[3 * h..4 * h])?;
88 for j in 0..h {
90 cs[cb + j] = row[2 * h + j] * cp[cb + j] + row[j] * row[3 * h + j];
91 }
92 hs[cb..cb + h].copy_from_slice(&cs[cb..cb + h]);
94 tanh.run(&mut hs[cb..cb + h])?;
95 for j in 0..h {
96 hs[cb + j] = hs[cb + j] * row[h + j];
97 }
98 }
99 }
100 Ok(tvec!(ht.into_tvalue(), ct.into_tvalue()))
101 }
102}
103
104impl TypedOp for LstmEpilogue {
105 as_op!();
106
107 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
108 ensure!(inputs.len() == 2, "LstmEpilogue expects [preact, c_prev]");
109 let c_prev = inputs[1];
111 let fact = c_prev.datum_type.fact(c_prev.shape.clone());
112 Ok(tvec!(fact.clone(), fact))
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar logistic function used by the reference computation.
    fn logistic(x: f32) -> f32 {
        1.0 / (1.0 + (-x).exp())
    }

    /// Runs the op on a small deterministic f32 batch and compares both outputs against
    /// a plain scalar implementation of the same formulas.
    #[test]
    fn epilogue_matches_scalar_reference() {
        let hidden = 6usize;
        let batch = 3usize;
        let width = 4 * hidden;

        let preact: Vec<f32> =
            (0..batch * width).map(|i| ((i * 7 % 29) as f32 - 14.0) * 0.25).collect();
        let cprev: Vec<f32> =
            (0..batch * hidden).map(|i| ((i * 5 % 17) as f32 - 8.0) * 0.2).collect();

        let preact_tensor = Tensor::from_shape(&[batch, width], &preact).unwrap();
        let cprev_tensor = Tensor::from_shape(&[batch, hidden], &cprev).unwrap();

        let op = LstmEpilogue { hidden };
        let outputs =
            op.eval(tvec!(preact_tensor.into_tvalue(), cprev_tensor.into_tvalue())).unwrap();
        let ht = unsafe { outputs[0].as_slice_unchecked::<f32>() };
        let ct = unsafe { outputs[1].as_slice_unchecked::<f32>() };

        let reference_rows = preact.chunks_exact(width).zip(cprev.chunks_exact(hidden));
        for (r, (gates, prev_row)) in reference_rows.enumerate() {
            for j in 0..hidden {
                // Row layout: [input | output | forget | candidate].
                let it = logistic(gates[j]);
                let ot = logistic(gates[hidden + j]);
                let ft = logistic(gates[2 * hidden + j]);
                let cc = gates[3 * hidden + j].tanh();

                let c_ref = ft * prev_row[j] + it * cc;
                let h_ref = ot * c_ref.tanh();

                let flat = r * hidden + j;
                assert!((ct[flat] - c_ref).abs() < 1e-3, "Ct mismatch at ({r},{j})");
                assert!((ht[flat] - h_ref).abs() < 1e-3, "Ht mismatch at ({r},{j})");
            }
        }
    }
}