1use yscv_kernels::matmul_2d;
2use yscv_tensor::Tensor;
3
4use crate::ModelError;
5
6#[derive(Debug, Clone)]
8pub struct RnnCell {
9 pub w_ih: Tensor, pub w_hh: Tensor, pub bias: Tensor, pub hidden_size: usize,
13}
14
15impl RnnCell {
16 pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, ModelError> {
17 Ok(Self {
18 w_ih: Tensor::from_vec(
19 vec![input_size, hidden_size],
20 vec![0.0; input_size * hidden_size],
21 )?,
22 w_hh: Tensor::from_vec(
23 vec![hidden_size, hidden_size],
24 vec![0.0; hidden_size * hidden_size],
25 )?,
26 bias: Tensor::from_vec(vec![hidden_size], vec![0.0; hidden_size])?,
27 hidden_size,
28 })
29 }
30
31 pub fn forward(&self, x: &Tensor, h: &Tensor) -> Result<Tensor, ModelError> {
33 let xw = matmul_2d(x, &self.w_ih)?;
34 let hw = matmul_2d(h, &self.w_hh)?;
35 let sum = xw.add(&hw)?;
36 let sum = sum.add(&self.bias.unsqueeze(0)?)?;
37 let data: Vec<f32> = sum.data().iter().map(|&v| v.tanh()).collect();
38 Tensor::from_vec(sum.shape().to_vec(), data).map_err(Into::into)
39 }
40}
41
42#[derive(Debug, Clone)]
44pub struct LstmCell {
45 pub w_ih: Tensor, pub w_hh: Tensor, pub bias: Tensor, pub hidden_size: usize,
49}
50
51impl LstmCell {
52 pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, ModelError> {
53 let h4 = 4 * hidden_size;
54 Ok(Self {
55 w_ih: Tensor::from_vec(vec![input_size, h4], vec![0.0; input_size * h4])?,
56 w_hh: Tensor::from_vec(vec![hidden_size, h4], vec![0.0; hidden_size * h4])?,
57 bias: Tensor::from_vec(vec![h4], vec![0.0; h4])?,
58 hidden_size,
59 })
60 }
61
62 pub fn forward(
66 &self,
67 x: &Tensor,
68 h: &Tensor,
69 c: &Tensor,
70 ) -> Result<(Tensor, Tensor), ModelError> {
71 let batch = x.shape()[0];
72 let hs = self.hidden_size;
73
74 let gates = {
75 let xw = matmul_2d(x, &self.w_ih)?;
76 let hw = matmul_2d(h, &self.w_hh)?;
77 let g = xw.add(&hw)?;
78 g.add(&self.bias.unsqueeze(0)?)?
79 };
80
81 let gd = gates.data();
82 let cd = c.data();
83 let mut h_new = Vec::with_capacity(batch * hs);
84 let mut c_new = Vec::with_capacity(batch * hs);
85
86 for b in 0..batch {
87 let base = b * 4 * hs;
88 for j in 0..hs {
89 let i_gate = sigmoid_f32(gd[base + j]);
90 let f_gate = sigmoid_f32(gd[base + hs + j]);
91 let g_gate = gd[base + 2 * hs + j].tanh();
92 let o_gate = sigmoid_f32(gd[base + 3 * hs + j]);
93 let c_val = f_gate * cd[b * hs + j] + i_gate * g_gate;
94 let h_val = o_gate * c_val.tanh();
95 c_new.push(c_val);
96 h_new.push(h_val);
97 }
98 }
99
100 let h_out = Tensor::from_vec(vec![batch, hs], h_new)?;
101 let c_out = Tensor::from_vec(vec![batch, hs], c_new)?;
102 Ok((h_out, c_out))
103 }
104}
105
106#[derive(Debug, Clone)]
108pub struct GruCell {
109 pub w_ih: Tensor, pub w_hh: Tensor, pub bias_ih: Tensor, pub bias_hh: Tensor, pub hidden_size: usize,
114}
115
116impl GruCell {
117 pub fn new(input_size: usize, hidden_size: usize) -> Result<Self, ModelError> {
118 let h3 = 3 * hidden_size;
119 Ok(Self {
120 w_ih: Tensor::from_vec(vec![input_size, h3], vec![0.0; input_size * h3])?,
121 w_hh: Tensor::from_vec(vec![hidden_size, h3], vec![0.0; hidden_size * h3])?,
122 bias_ih: Tensor::from_vec(vec![h3], vec![0.0; h3])?,
123 bias_hh: Tensor::from_vec(vec![h3], vec![0.0; h3])?,
124 hidden_size,
125 })
126 }
127
128 pub fn forward(&self, x: &Tensor, h: &Tensor) -> Result<Tensor, ModelError> {
130 let batch = x.shape()[0];
131 let hs = self.hidden_size;
132
133 let xw = matmul_2d(x, &self.w_ih)?;
134 let xw = xw.add(&self.bias_ih.unsqueeze(0)?)?;
135 let hw = matmul_2d(h, &self.w_hh)?;
136 let hw = hw.add(&self.bias_hh.unsqueeze(0)?)?;
137
138 let xd = xw.data();
139 let hd = hw.data();
140 let h_prev = h.data();
141 let mut h_new = Vec::with_capacity(batch * hs);
142
143 for b in 0..batch {
144 let xb = b * 3 * hs;
145 let hb = b * 3 * hs;
146 for j in 0..hs {
147 let r = sigmoid_f32(xd[xb + j] + hd[hb + j]);
148 let z = sigmoid_f32(xd[xb + hs + j] + hd[hb + hs + j]);
149 let n = (xd[xb + 2 * hs + j] + r * hd[hb + 2 * hs + j]).tanh();
150 let h_val = (1.0 - z) * n + z * h_prev[b * hs + j];
151 h_new.push(h_val);
152 }
153 }
154
155 Tensor::from_vec(vec![batch, hs], h_new).map_err(Into::into)
156 }
157}
158
159pub fn rnn_forward_sequence(
165 cell: &RnnCell,
166 input: &Tensor,
167 h0: Option<&Tensor>,
168) -> Result<(Tensor, Tensor), ModelError> {
169 let shape = input.shape();
170 let (batch, seq_len, _input_size) = (shape[0], shape[1], shape[2]);
171 let hs = cell.hidden_size;
172
173 let mut h = match h0 {
174 Some(h) => h.clone(),
175 None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
176 };
177
178 let mut all_h = Vec::with_capacity(batch * seq_len * hs);
179
180 for t in 0..seq_len {
181 let xt = input.narrow(1, t, 1)?;
182 let xt = xt.reshape(vec![batch, input.shape()[2]])?;
183 h = cell.forward(&xt, &h)?;
184 all_h.extend_from_slice(h.data());
185 }
186
187 let output = Tensor::from_vec(vec![batch, seq_len, hs], all_h)?;
188 Ok((output, h))
189}
190
191pub fn lstm_forward_sequence(
195 cell: &LstmCell,
196 input: &Tensor,
197 h0: Option<&Tensor>,
198 c0: Option<&Tensor>,
199) -> Result<(Tensor, Tensor, Tensor), ModelError> {
200 let shape = input.shape();
201 let (batch, seq_len, _input_size) = (shape[0], shape[1], shape[2]);
202 let hs = cell.hidden_size;
203
204 let mut h = match h0 {
205 Some(h) => h.clone(),
206 None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
207 };
208 let mut c = match c0 {
209 Some(c) => c.clone(),
210 None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
211 };
212
213 let mut all_h = Vec::with_capacity(batch * seq_len * hs);
214
215 for t in 0..seq_len {
216 let xt = input.narrow(1, t, 1)?;
217 let xt = xt.reshape(vec![batch, input.shape()[2]])?;
218 let (h_new, c_new) = cell.forward(&xt, &h, &c)?;
219 all_h.extend_from_slice(h_new.data());
220 h = h_new;
221 c = c_new;
222 }
223
224 let output = Tensor::from_vec(vec![batch, seq_len, hs], all_h)?;
225 Ok((output, h, c))
226}
227
228pub fn gru_forward_sequence(
232 cell: &GruCell,
233 input: &Tensor,
234 h0: Option<&Tensor>,
235) -> Result<(Tensor, Tensor), ModelError> {
236 let shape = input.shape();
237 let (batch, seq_len, _input_size) = (shape[0], shape[1], shape[2]);
238 let hs = cell.hidden_size;
239
240 let mut h = match h0 {
241 Some(h) => h.clone(),
242 None => Tensor::from_vec(vec![batch, hs], vec![0.0; batch * hs])?,
243 };
244
245 let mut all_h = Vec::with_capacity(batch * seq_len * hs);
246
247 for t in 0..seq_len {
248 let xt = input.narrow(1, t, 1)?;
249 let xt = xt.reshape(vec![batch, input.shape()[2]])?;
250 h = cell.forward(&xt, &h)?;
251 all_h.extend_from_slice(h.data());
252 }
253
254 let output = Tensor::from_vec(vec![batch, seq_len, hs], all_h)?;
255 Ok((output, h))
256}
257
258pub fn bilstm_forward_sequence(
262 fwd_cell: &LstmCell,
263 bwd_cell: &LstmCell,
264 input: &Tensor,
265) -> Result<Tensor, ModelError> {
266 let shape = input.shape();
267 let (batch, seq_len, input_size) = (shape[0], shape[1], shape[2]);
268 let hs = fwd_cell.hidden_size;
269
270 let (fwd_out, _, _) = lstm_forward_sequence(fwd_cell, input, None, None)?;
272
273 let mut rev_data = Vec::with_capacity(batch * seq_len * input_size);
275 let in_data = input.data();
276 for b in 0..batch {
277 for t in (0..seq_len).rev() {
278 let start = (b * seq_len + t) * input_size;
279 rev_data.extend_from_slice(&in_data[start..start + input_size]);
280 }
281 }
282 let rev_input = Tensor::from_vec(vec![batch, seq_len, input_size], rev_data)?;
283
284 let (bwd_out_rev, _, _) = lstm_forward_sequence(bwd_cell, &rev_input, None, None)?;
286
287 let fwd_d = fwd_out.data();
289 let bwd_d = bwd_out_rev.data();
290 let mut out = Vec::with_capacity(batch * seq_len * 2 * hs);
291 for b in 0..batch {
292 for t in 0..seq_len {
293 let fwd_start = (b * seq_len + t) * hs;
294 out.extend_from_slice(&fwd_d[fwd_start..fwd_start + hs]);
295 let bwd_t = seq_len - 1 - t;
296 let bwd_start = (b * seq_len + bwd_t) * hs;
297 out.extend_from_slice(&bwd_d[bwd_start..bwd_start + hs]);
298 }
299 }
300
301 Tensor::from_vec(vec![batch, seq_len, 2 * hs], out).map_err(Into::into)
302}
303
/// Numerically plain logistic sigmoid: 1 / (1 + e^{-x}).
/// For large |x| this saturates cleanly to 1.0 / 0.0 (exp underflow/overflow).
fn sigmoid_f32(x: f32) -> f32 {
    // recip() computes 1.0 / self, identical to the explicit division.
    (1.0 + (-x).exp()).recip()
}