//! Recurrent layers for `yscv_model`: RNN, LSTM, and GRU cells plus
//! multi-step sequence runners (`recurrent.rs`).

use yscv_kernels::matmul_2d;
use yscv_tensor::Tensor;

use crate::ModelError;
/// Vanilla RNN cell: h_t = tanh(x_t @ W_ih + h_{t-1} @ W_hh + b).
#[derive(Debug, Clone)]
pub struct RnnCell {
    pub w_ih: Tensor, // input-to-hidden weights, [input_size, hidden_size]
    pub w_hh: Tensor, // hidden-to-hidden weights, [hidden_size, hidden_size]
    pub bias: Tensor, // [hidden_size], broadcast over the batch in forward()
    pub hidden_size: usize, // width of the hidden state
}

15impl RnnCell {
16    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, ModelError> {
17        Ok(Self {
18            w_ih: Tensor::from_vec(
19                vec![input_size, hidden_size],
20                vec![0.0; input_size * hidden_size],
21            )?,
22            w_hh: Tensor::from_vec(
23                vec![hidden_size, hidden_size],
24                vec![0.0; hidden_size * hidden_size],
25            )?,
26            bias: Tensor::from_vec(vec![hidden_size], vec![0.0; hidden_size])?,
27            hidden_size,
28        })
29    }
30
31    /// Forward one timestep: x `[batch, input_size]`, h `[batch, hidden_size]` -> h' `[batch, hidden_size]`.
32    pub fn forward(&self, x: &Tensor, h: &Tensor) -> Result<Tensor, ModelError> {
33        let xw = matmul_2d(x, &self.w_ih)?;
34        let hw = matmul_2d(h, &self.w_hh)?;
35        let sum = xw.add(&hw)?;
36        let sum = sum.add(&self.bias.unsqueeze(0)?)?;
37        let data: Vec<f32> = sum.data().iter().map(|&v| v.tanh()).collect();
38        Tensor::from_vec(sum.shape().to_vec(), data).map_err(Into::into)
39    }
40}
41
/// LSTM cell: standard gates (input, forget, cell, output).
#[derive(Debug, Clone)]
pub struct LstmCell {
    pub w_ih: Tensor, // [input_size, 4 * hidden_size]; gate columns packed [i | f | g | o]
    pub w_hh: Tensor, // [hidden_size, 4 * hidden_size]; same gate packing
    pub bias: Tensor, // [4 * hidden_size]; same gate packing
    pub hidden_size: usize, // width of h and c
}

51impl LstmCell {
52    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, ModelError> {
53        let h4 = 4 * hidden_size;
54        Ok(Self {
55            w_ih: Tensor::from_vec(vec![input_size, h4], vec![0.0; input_size * h4])?,
56            w_hh: Tensor::from_vec(vec![hidden_size, h4], vec![0.0; hidden_size * h4])?,
57            bias: Tensor::from_vec(vec![h4], vec![0.0; h4])?,
58            hidden_size,
59        })
60    }
61
62    /// Forward one timestep. Returns `(h_new, c_new)`.
63    ///
64    /// x: `[batch, input_size]`, h: `[batch, hidden_size]`, c: `[batch, hidden_size]`.
65    pub fn forward(
66        &self,
67        x: &Tensor,
68        h: &Tensor,
69        c: &Tensor,
70    ) -> Result<(Tensor, Tensor), ModelError> {
71        let batch = x.shape()[0];
72        let hs = self.hidden_size;
73
74        let gates = {
75            let xw = matmul_2d(x, &self.w_ih)?;
76            let hw = matmul_2d(h, &self.w_hh)?;
77            let g = xw.add(&hw)?;
78            g.add(&self.bias.unsqueeze(0)?)?
79        };
80
81        let gd = gates.data();
82        let cd = c.data();
83        let mut h_new = Vec::with_capacity(batch * hs);
84        let mut c_new = Vec::with_capacity(batch * hs);
85
86        for b in 0..batch {
87            let base = b * 4 * hs;
88            for j in 0..hs {
89                let i_gate = sigmoid_f32(gd[base + j]);
90                let f_gate = sigmoid_f32(gd[base + hs + j]);
91                let g_gate = gd[base + 2 * hs + j].tanh();
92                let o_gate = sigmoid_f32(gd[base + 3 * hs + j]);
93                let c_val = f_gate * cd[b * hs + j] + i_gate * g_gate;
94                let h_val = o_gate * c_val.tanh();
95                c_new.push(c_val);
96                h_new.push(h_val);
97            }
98        }
99
100        let h_out = Tensor::from_vec(vec![batch, hs], h_new)?;
101        let c_out = Tensor::from_vec(vec![batch, hs], c_new)?;
102        Ok((h_out, c_out))
103    }
104}
105
/// GRU cell: update and reset gates.
#[derive(Debug, Clone)]
pub struct GruCell {
    pub w_ih: Tensor,    // [input_size, 3 * hidden_size]; gate columns packed [r | z | n]
    pub w_hh: Tensor,    // [hidden_size, 3 * hidden_size]; same gate packing
    pub bias_ih: Tensor, // [3 * hidden_size], added to the input-side pre-activation
    pub bias_hh: Tensor, // [3 * hidden_size], added to the hidden-side pre-activation
    pub hidden_size: usize, // width of the hidden state
}

116impl GruCell {
117    pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, ModelError> {
118        let h3 = 3 * hidden_size;
119        Ok(Self {
120            w_ih: Tensor::from_vec(vec![input_size, h3], vec![0.0; input_size * h3])?,
121            w_hh: Tensor::from_vec(vec![hidden_size, h3], vec![0.0; hidden_size * h3])?,
122            bias_ih: Tensor::from_vec(vec![h3], vec![0.0; h3])?,
123            bias_hh: Tensor::from_vec(vec![h3], vec![0.0; h3])?,
124            hidden_size,
125        })
126    }
127
128    /// Forward one timestep: x `[batch, input_size]`, h `[batch, hidden_size]` -> h' `[batch, hidden_size]`.
129    pub fn forward(&self, x: &Tensor, h: &Tensor) -> Result<Tensor, ModelError> {
130        let batch = x.shape()[0];
131        let hs = self.hidden_size;
132
133        let xw = matmul_2d(x, &self.w_ih)?;
134        let xw = xw.add(&self.bias_ih.unsqueeze(0)?)?;
135        let hw = matmul_2d(h, &self.w_hh)?;
136        let hw = hw.add(&self.bias_hh.unsqueeze(0)?)?;
137
138        let xd = xw.data();
139        let hd = hw.data();
140        let h_prev = h.data();
141        let mut h_new = Vec::with_capacity(batch * hs);
142
143        for b in 0..batch {
144            let xb = b * 3 * hs;
145            let hb = b * 3 * hs;
146            for j in 0..hs {
147                let r = sigmoid_f32(xd[xb + j] + hd[hb + j]);
148                let z = sigmoid_f32(xd[xb + hs + j] + hd[hb + hs + j]);
149                let n = (xd[xb + 2 * hs + j] + r * hd[hb + 2 * hs + j]).tanh();
150                let h_val = (1.0 - z) * n + z * h_prev[b * hs + j];
151                h_new.push(h_val);
152            }
153        }
154
155        Tensor::from_vec(vec![batch, hs], h_new).map_err(Into::into)
156    }
157}
158
// ── Multi-step sequence wrappers ───────────────────────────────────

161/// Runs an RNN cell over a sequence `[batch, seq_len, input_size]`.
162///
163/// Returns all hidden states `[batch, seq_len, hidden_size]` and final hidden `[batch, hidden_size]`.
164pub fn rnn_forward_sequence(
165    cell: &RnnCell,
166    input: &Tensor,
167    h0: Option<&Tensor>,
168) -> Result<(Tensor, Tensor), ModelError> {
169    let shape = input.shape();
170    let (batch, seq_len, _input_size) = (shape[0], shape[1], shape[2]);
171    let hs = cell.hidden_size;
172
173    let mut h = match h0 {
174        Some(h) => h.clone(),
175        None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
176    };
177
178    let mut all_h = Vec::with_capacity(batch * seq_len * hs);
179
180    for t in 0..seq_len {
181        let xt = input.narrow(1, t, 1)?;
182        let xt = xt.reshape(vec![batch, input.shape()[2]])?;
183        h = cell.forward(&xt, &h)?;
184        all_h.extend_from_slice(h.data());
185    }
186
187    let output = Tensor::from_vec(vec![batch, seq_len, hs], all_h)?;
188    Ok((output, h))
189}
190
191/// Runs an LSTM cell over a sequence `[batch, seq_len, input_size]`.
192///
193/// Returns all hidden states `[batch, seq_len, hidden_size]`, final `(h, c)`.
194pub fn lstm_forward_sequence(
195    cell: &LstmCell,
196    input: &Tensor,
197    h0: Option<&Tensor>,
198    c0: Option<&Tensor>,
199) -> Result<(Tensor, Tensor, Tensor), ModelError> {
200    let shape = input.shape();
201    let (batch, seq_len, _input_size) = (shape[0], shape[1], shape[2]);
202    let hs = cell.hidden_size;
203
204    let mut h = match h0 {
205        Some(h) => h.clone(),
206        None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
207    };
208    let mut c = match c0 {
209        Some(c) => c.clone(),
210        None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
211    };
212
213    let mut all_h = Vec::with_capacity(batch * seq_len * hs);
214
215    for t in 0..seq_len {
216        let xt = input.narrow(1, t, 1)?;
217        let xt = xt.reshape(vec![batch, input.shape()[2]])?;
218        let (h_new, c_new) = cell.forward(&xt, &h, &c)?;
219        all_h.extend_from_slice(h_new.data());
220        h = h_new;
221        c = c_new;
222    }
223
224    let output = Tensor::from_vec(vec![batch, seq_len, hs], all_h)?;
225    Ok((output, h, c))
226}
227
228/// Runs a GRU cell over a sequence `[batch, seq_len, input_size]`.
229///
230/// Returns all hidden states `[batch, seq_len, hidden_size]` and final hidden `[batch, hidden_size]`.
231pub fn gru_forward_sequence(
232    cell: &GruCell,
233    input: &Tensor,
234    h0: Option<&Tensor>,
235) -> Result<(Tensor, Tensor), ModelError> {
236    let shape = input.shape();
237    let (batch, seq_len, _input_size) = (shape[0], shape[1], shape[2]);
238    let hs = cell.hidden_size;
239
240    let mut h = match h0 {
241        Some(h) => h.clone(),
242        None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
243    };
244
245    let mut all_h = Vec::with_capacity(batch * seq_len * hs);
246
247    for t in 0..seq_len {
248        let xt = input.narrow(1, t, 1)?;
249        let xt = xt.reshape(vec![batch, input.shape()[2]])?;
250        h = cell.forward(&xt, &h)?;
251        all_h.extend_from_slice(h.data());
252    }
253
254    let output = Tensor::from_vec(vec![batch, seq_len, hs], all_h)?;
255    Ok((output, h))
256}
257
258/// Bidirectional LSTM: runs forward and backward LSTMs, concatenates outputs.
259///
260/// Returns `[batch, seq_len, 2 * hidden_size]`.
261pub fn bilstm_forward_sequence(
262    fwd_cell: &LstmCell,
263    bwd_cell: &LstmCell,
264    input: &Tensor,
265) -> Result<Tensor, ModelError> {
266    let shape = input.shape();
267    let (batch, seq_len, input_size) = (shape[0], shape[1], shape[2]);
268    let hs = fwd_cell.hidden_size;
269
270    // Forward pass
271    let (fwd_out, _, _) = lstm_forward_sequence(fwd_cell, input, None, None)?;
272
273    // Reverse input along time axis
274    let mut rev_data = Vec::with_capacity(batch * seq_len * input_size);
275    let in_data = input.data();
276    for b in 0..batch {
277        for t in (0..seq_len).rev() {
278            let start = (b * seq_len + t) * input_size;
279            rev_data.extend_from_slice(&in_data[start..start + input_size]);
280        }
281    }
282    let rev_input = Tensor::from_vec(vec![batch, seq_len, input_size], rev_data)?;
283
284    // Backward pass
285    let (bwd_out_rev, _, _) = lstm_forward_sequence(bwd_cell, &rev_input, None, None)?;
286
287    // Reverse backward output and concatenate
288    let fwd_d = fwd_out.data();
289    let bwd_d = bwd_out_rev.data();
290    let mut out = Vec::with_capacity(batch * seq_len * 2 * hs);
291    for b in 0..batch {
292        for t in 0..seq_len {
293            let fwd_start = (b * seq_len + t) * hs;
294            out.extend_from_slice(&fwd_d[fwd_start..fwd_start + hs]);
295            let bwd_t = seq_len - 1 - t;
296            let bwd_start = (b * seq_len + bwd_t) * hs;
297            out.extend_from_slice(&bwd_d[bwd_start..bwd_start + hs]);
298        }
299    }
300
301    Tensor::from_vec(vec![batch, seq_len, 2 * hs], out).map_err(Into::into)
302}
303
/// Logistic sigmoid for a single value: 1 / (1 + e^(-x)).
fn sigmoid_f32(x: f32) -> f32 {
    ((-x).exp() + 1.0).recip()
}