use crate::error::{NeuralError, Result};
use crate::layers::recurrent::{LstmGateCache, LstmStepOutput};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, ArrayView, Ix2, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Distribution, Uniform};
use std::fmt::Debug;
use std::sync::{Arc, RwLock};

/// Configuration for an LSTM layer.
#[derive(Debug, Clone)]
pub struct LSTMConfig {
    /// Number of input features per time step.
    pub input_size: usize,
    /// Number of hidden units.
    pub hidden_size: usize,
}
/// Long Short-Term Memory (LSTM) layer with input, forget, cell, and output gates.
pub struct LSTM<F: Float + Debug + Send + Sync> {
    /// Number of input features per time step.
    input_size: usize,
    /// Number of hidden units.
    hidden_size: usize,
    // Input gate parameters (input-to-hidden and hidden-to-hidden weights and biases).
    weight_ii: Array<F, IxDyn>,
    weight_hi: Array<F, IxDyn>,
    bias_ii: Array<F, IxDyn>,
    bias_hi: Array<F, IxDyn>,
    // Forget gate parameters.
    weight_if: Array<F, IxDyn>,
    weight_hf: Array<F, IxDyn>,
    bias_if: Array<F, IxDyn>,
    bias_hf: Array<F, IxDyn>,
    // Cell (candidate) gate parameters.
    weight_ig: Array<F, IxDyn>,
    weight_hg: Array<F, IxDyn>,
    bias_ig: Array<F, IxDyn>,
    bias_hg: Array<F, IxDyn>,
    // Output gate parameters.
    weight_io: Array<F, IxDyn>,
    weight_ho: Array<F, IxDyn>,
    bias_io: Array<F, IxDyn>,
    bias_ho: Array<F, IxDyn>,
    /// Gradient accumulators for all 16 parameters (not yet populated by `backward`).
    #[allow(dead_code)]
    gradients: Arc<RwLock<Vec<Array<F, IxDyn>>>>,
    /// Cached input from the most recent forward pass.
    input_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached hidden states from the most recent forward pass.
    hidden_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached cell states from the most recent forward pass.
    cell_states_cache: Arc<RwLock<Option<Array<F, IxDyn>>>>,
    /// Cached gate activations (not yet populated by `forward`).
    #[allow(dead_code)]
    gate_cache: LstmGateCache<F>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> LSTM<F> {
    /// Creates a new LSTM layer with `input_size` input features and `hidden_size` hidden units.
    ///
    /// Weights are initialized from a uniform distribution scaled by `1/sqrt(fan_in)`;
    /// forget-gate biases start at 1.
    pub fn new<R: scirs2_core::random::Rng + scirs2_core::random::RngCore>(
        input_size: usize,
        hidden_size: usize,
        rng: &mut R,
    ) -> Result<Self> {
        if input_size == 0 || hidden_size == 0 {
            return Err(NeuralError::InvalidArchitecture(
                "input_size and hidden_size must be positive".to_string(),
            ));
        }

        // Initialization scale factors, one per fan-in.
        let scale_ih = F::from(1.0 / (input_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
        })?;
        let scale_hh = F::from(1.0 / (hidden_size as f64).sqrt()).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
        })?;

        // Helper closure that builds a `rows x cols` weight matrix with uniform values in
        // (-1, 1) scaled by `scale`.
        let mut create_weight_matrix =
            |rows: usize, cols: usize, scale: F| -> Result<Array<F, IxDyn>> {
                let mut weights_vec: Vec<F> = Vec::with_capacity(rows * cols);
                let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
                    NeuralError::InvalidArchitecture(format!(
                        "Failed to create uniform distribution: {e}"
                    ))
                })?;
                for _ in 0..(rows * cols) {
                    let rand_val = uniform.sample(rng);
                    let val = F::from(rand_val).ok_or_else(|| {
                        NeuralError::InvalidArchitecture(
                            "Failed to convert random value".to_string(),
                        )
                    })?;
                    weights_vec.push(val * scale);
                }
                Array::from_shape_vec(IxDyn(&[rows, cols]), weights_vec).map_err(|e| {
                    NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
                })
            };
        // Input gate parameters.
        let weight_ii = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hi = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_ii: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hi: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));

        // Forget gate parameters; biases are initialized to 1 so the gate starts open.
        let weight_if = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hf = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let mut bias_if: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let mut bias_hf: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let one = F::one();
        for i in 0..hidden_size {
            bias_if[i] = one;
            bias_hf[i] = one;
        }

        // Cell (candidate) gate parameters.
        let weight_ig = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_hg = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_ig: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_hg: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));

        // Output gate parameters.
        let weight_io = create_weight_matrix(hidden_size, input_size, scale_ih)?;
        let weight_ho = create_weight_matrix(hidden_size, hidden_size, scale_hh)?;
        let bias_io: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        let bias_ho: Array<F, IxDyn> = Array::zeros(IxDyn(&[hidden_size]));
        // One zero-initialized gradient buffer per parameter, in the same order as
        // `get_parameters`.
        let gradients = vec![
            Array::zeros(weight_ii.dim()),
            Array::zeros(weight_hi.dim()),
            Array::zeros(bias_ii.dim()),
            Array::zeros(bias_hi.dim()),
            Array::zeros(weight_if.dim()),
            Array::zeros(weight_hf.dim()),
            Array::zeros(bias_if.dim()),
            Array::zeros(bias_hf.dim()),
            Array::zeros(weight_ig.dim()),
            Array::zeros(weight_hg.dim()),
            Array::zeros(bias_ig.dim()),
            Array::zeros(bias_hg.dim()),
            Array::zeros(weight_io.dim()),
            Array::zeros(weight_ho.dim()),
            Array::zeros(bias_io.dim()),
            Array::zeros(bias_ho.dim()),
        ];
        Ok(Self {
            input_size,
            hidden_size,
            weight_ii,
            weight_hi,
            bias_ii,
            bias_hi,
            weight_if,
            weight_hf,
            bias_if,
            bias_hf,
            weight_ig,
            weight_hg,
            bias_ig,
            bias_hg,
            weight_io,
            weight_ho,
            bias_io,
            bias_ho,
            gradients: Arc::new(RwLock::new(gradients)),
            input_cache: Arc::new(RwLock::new(None)),
            hidden_states_cache: Arc::new(RwLock::new(None)),
            cell_states_cache: Arc::new(RwLock::new(None)),
            gate_cache: Arc::new(RwLock::new(None)),
        })
    }
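
    /// Performs a single LSTM time step for a batch of inputs.
    ///
    /// The equations below summarize what the loops in this method compute for each
    /// batch element (`⊙` is element-wise multiplication):
    ///
    /// ```text
    /// i_t = sigmoid(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)   // input gate
    /// f_t = sigmoid(W_if x_t + b_if + W_hf h_{t-1} + b_hf)   // forget gate
    /// g_t = tanh(W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)      // cell candidate
    /// o_t = sigmoid(W_io x_t + b_io + W_ho h_{t-1} + b_ho)   // output gate
    /// c_t = f_t ⊙ c_{t-1} + i_t ⊙ g_t
    /// h_t = o_t ⊙ tanh(c_t)
    /// ```
    ///
    /// Returns `(h_t, c_t, (i_t, f_t, g_t, o_t))`.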
    fn step(
        &self,
        x: &ArrayView<F, IxDyn>,
        h: &ArrayView<F, IxDyn>,
        c: &ArrayView<F, IxDyn>,
    ) -> Result<LstmStepOutput<F>> {
        let xshape = x.shape();
        let hshape = h.shape();
        let cshape = c.shape();
        let batch_size = xshape[0];

        if xshape[1] != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input feature dimension mismatch: expected {}, got {}",
                self.input_size, xshape[1]
            )));
        }
        if hshape[1] != self.hidden_size || cshape[1] != self.hidden_size {
            return Err(NeuralError::InferenceError(format!(
                "Hidden/cell state dimension mismatch: expected {}, got {}/{}",
                self.hidden_size, hshape[1], cshape[1]
            )));
        }
        if xshape[0] != hshape[0] || xshape[0] != cshape[0] {
            return Err(NeuralError::InferenceError(format!(
                "Batch size mismatch: input has {}, hidden state has {}, cell state has {}",
                xshape[0], hshape[0], cshape[0]
            )));
        }
        let mut i_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut f_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut g_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut o_gate: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut new_c: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));
        let mut new_h: Array<F, IxDyn> = Array::zeros(IxDyn(&[batch_size, self.hidden_size]));

        for b in 0..batch_size {
            for i in 0..self.hidden_size {
                // Input gate: i_t = sigmoid(W_ii x_t + b_ii + W_hi h_{t-1} + b_hi)
                let mut i_sum = self.bias_ii[i] + self.bias_hi[i];
                for j in 0..self.input_size {
                    i_sum = i_sum + self.weight_ii[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    i_sum = i_sum + self.weight_hi[[i, j]] * h[[b, j]];
                }
                i_gate[[b, i]] = F::one() / (F::one() + (-i_sum).exp());

                // Forget gate: f_t = sigmoid(W_if x_t + b_if + W_hf h_{t-1} + b_hf)
                let mut f_sum = self.bias_if[i] + self.bias_hf[i];
                for j in 0..self.input_size {
                    f_sum = f_sum + self.weight_if[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    f_sum = f_sum + self.weight_hf[[i, j]] * h[[b, j]];
                }
                f_gate[[b, i]] = F::one() / (F::one() + (-f_sum).exp());

                // Cell candidate: g_t = tanh(W_ig x_t + b_ig + W_hg h_{t-1} + b_hg)
                let mut g_sum = self.bias_ig[i] + self.bias_hg[i];
                for j in 0..self.input_size {
                    g_sum = g_sum + self.weight_ig[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    g_sum = g_sum + self.weight_hg[[i, j]] * h[[b, j]];
                }
                g_gate[[b, i]] = g_sum.tanh();

                // Output gate: o_t = sigmoid(W_io x_t + b_io + W_ho h_{t-1} + b_ho)
                let mut o_sum = self.bias_io[i] + self.bias_ho[i];
                for j in 0..self.input_size {
                    o_sum = o_sum + self.weight_io[[i, j]] * x[[b, j]];
                }
                for j in 0..self.hidden_size {
                    o_sum = o_sum + self.weight_ho[[i, j]] * h[[b, j]];
                }
                o_gate[[b, i]] = F::one() / (F::one() + (-o_sum).exp());

                // State updates: c_t = f_t*c_{t-1} + i_t*g_t and h_t = o_t*tanh(c_t)
                new_c[[b, i]] = f_gate[[b, i]] * c[[b, i]] + i_gate[[b, i]] * g_gate[[b, i]];
                new_h[[b, i]] = o_gate[[b, i]] * new_c[[b, i]].tanh();
            }
        }

        let new_h_dyn = new_h.into_dyn();
        let new_c_dyn = new_c.into_dyn();
        let i_gate_dyn = i_gate.into_dyn();
        let f_gate_dyn = f_gate.into_dyn();
        let g_gate_dyn = g_gate.into_dyn();
        let o_gate_dyn = o_gate.into_dyn();
        Ok((
            new_h_dyn,
            new_c_dyn,
            (i_gate_dyn, f_gate_dyn, g_gate_dyn, o_gate_dyn),
        ))
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for LSTM<F> {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

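    /// Runs the LSTM over a full sequence.
    ///
    /// The input must be a 3D array of shape `[batch_size, seq_len, input_size]`; the
    /// returned array holds the hidden state at every time step and has shape
    /// `[batch_size, seq_len, hidden_size]`.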
    fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        // Cache the input for a later backward pass.
        *self.input_cache.write().unwrap() = Some(input.clone());

        let inputshape = input.shape();
        if inputshape.len() != 3 {
            return Err(NeuralError::InferenceError(format!(
                "Expected 3D input [batch_size, seq_len, features], got {inputshape:?}"
            )));
        }

        let batch_size = inputshape[0];
        let seq_len = inputshape[1];
        let features = inputshape[2];
        if features != self.input_size {
            return Err(NeuralError::InferenceError(format!(
                "Input features dimension mismatch: expected {}, got {}",
                self.input_size, features
            )));
        }

        // Hidden and cell states start at zero and are updated at every time step.
        let mut h = Array::zeros((batch_size, self.hidden_size));
        let mut c = Array::zeros((batch_size, self.hidden_size));
        let mut all_hidden_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        let mut all_cell_states = Array::zeros((batch_size, seq_len, self.hidden_size));
        let mut all_gates = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Slice out the t-th time step: shape [batch_size, input_size].
            let x_t = input.slice(scirs2_core::ndarray::s![.., t, ..]);
            let x_t_view = x_t.view().into_dyn();
            let h_view = h.view().into_dyn();
            let c_view = c.view().into_dyn();

            let (new_h, new_c, gates) = self.step(&x_t_view, &h_view, &c_view)?;
            h = new_h.into_dimensionality::<Ix2>().unwrap();
            c = new_c.into_dimensionality::<Ix2>().unwrap();
            // Gate activations are collected per step but not yet persisted into `gate_cache`.
            all_gates.push(gates);

            for b in 0..batch_size {
                for i in 0..self.hidden_size {
                    all_hidden_states[[b, t, i]] = h[[b, i]];
                    all_cell_states[[b, t, i]] = c[[b, i]];
                }
            }
        }

        // Cache the full state trajectories for the backward pass and return all hidden states.
        *self.hidden_states_cache.write().unwrap() = Some(all_hidden_states.clone().into_dyn());
        *self.cell_states_cache.write().unwrap() = Some(all_cell_states.into_dyn());
        Ok(all_hidden_states.into_dyn())
    }

    fn backward(
        &self,
        input: &Array<F, IxDyn>,
        _grad_output: &Array<F, IxDyn>,
    ) -> Result<Array<F, IxDyn>> {
        // Verify that forward() has been called so the caches are populated.
        let input_ref = self.input_cache.read().map_err(|_| {
            NeuralError::InferenceError("Failed to acquire read lock on input cache".to_string())
        })?;
        let hidden_states_ref = self.hidden_states_cache.read().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire read lock on hidden states cache".to_string(),
            )
        })?;
        let cell_states_ref = self.cell_states_cache.read().map_err(|_| {
            NeuralError::InferenceError(
                "Failed to acquire read lock on cell states cache".to_string(),
            )
        })?;
        if input_ref.is_none() || hidden_states_ref.is_none() || cell_states_ref.is_none() {
            return Err(NeuralError::InferenceError(
                "No cached values for backward pass. Call forward() first.".to_string(),
            ));
        }

        // Placeholder: full backpropagation through time is not implemented here, so the
        // gradient with respect to the input is returned as zeros of the input's shape.
        let grad_input = Array::zeros(input.dim());
        Ok(grad_input)
    }

    fn update(&mut self, learningrate: F) -> Result<()> {
        // Placeholder update: without stored gradients, every parameter is nudged by a
        // small constant step scaled by the learning rate.
        let small_change = F::from(0.001).unwrap();
        let lr = small_change * learningrate;
        let update_param = |param: &mut Array<F, IxDyn>| {
            for w in param.iter_mut() {
                *w = *w - lr;
            }
        };

        update_param(&mut self.weight_ii);
        update_param(&mut self.weight_hi);
        update_param(&mut self.bias_ii);
        update_param(&mut self.bias_hi);
        update_param(&mut self.weight_if);
        update_param(&mut self.weight_hf);
        update_param(&mut self.bias_if);
        update_param(&mut self.bias_hf);
        update_param(&mut self.weight_ig);
        update_param(&mut self.weight_hg);
        update_param(&mut self.bias_ig);
        update_param(&mut self.bias_hg);
        update_param(&mut self.weight_io);
        update_param(&mut self.weight_ho);
        update_param(&mut self.bias_io);
        update_param(&mut self.bias_ho);
        Ok(())
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for LSTM<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![
            self.weight_ii.clone(),
            self.weight_hi.clone(),
            self.bias_ii.clone(),
            self.bias_hi.clone(),
            self.weight_if.clone(),
            self.weight_hf.clone(),
            self.bias_if.clone(),
            self.bias_hf.clone(),
            self.weight_ig.clone(),
            self.weight_hg.clone(),
            self.bias_ig.clone(),
            self.bias_hg.clone(),
            self.weight_io.clone(),
            self.weight_ho.clone(),
            self.bias_io.clone(),
            self.bias_ho.clone(),
        ]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Gradients are not tracked yet (backward is a placeholder), so nothing is exposed.
        Vec::new()
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 16 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 16 parameters, got {}",
                params.len()
            )));
        }

        let expectedshapes = vec![
            self.weight_ii.shape(),
            self.weight_hi.shape(),
            self.bias_ii.shape(),
            self.bias_hi.shape(),
            self.weight_if.shape(),
            self.weight_hf.shape(),
            self.bias_if.shape(),
            self.bias_hf.shape(),
            self.weight_ig.shape(),
            self.weight_hg.shape(),
            self.bias_ig.shape(),
            self.bias_hg.shape(),
            self.weight_io.shape(),
            self.weight_ho.shape(),
            self.bias_io.shape(),
            self.bias_ho.shape(),
        ];

        for (i, (param, expected)) in params.iter().zip(expectedshapes.iter()).enumerate() {
            if param.shape() != *expected {
                return Err(NeuralError::InvalidArchitecture(format!(
                    "Parameter {} shape mismatch: expected {:?}, got {:?}",
                    i,
                    expected,
                    param.shape()
                )));
            }
        }

        self.weight_ii = params[0].clone();
        self.weight_hi = params[1].clone();
        self.bias_ii = params[2].clone();
        self.bias_hi = params[3].clone();
        self.weight_if = params[4].clone();
        self.weight_hf = params[5].clone();
        self.bias_if = params[6].clone();
        self.bias_hf = params[7].clone();
        self.weight_ig = params[8].clone();
        self.weight_hg = params[9].clone();
        self.bias_ig = params[10].clone();
        self.bias_hg = params[11].clone();
        self.weight_io = params[12].clone();
        self.weight_ho = params[13].clone();
        self.bias_io = params[14].clone();
        self.bias_ho = params[15].clone();

        Ok(())
    }
}
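
// ---------------------------------------------------------------------------
// A minimal usage sketch added for illustration, not part of the original code.
// It assumes that `scirs2_core::random` re-exports rand's `SeedableRng` trait and
// `rngs::StdRng` type; if the crate exposes a different RNG constructor, substitute
// it below. The shape assertion only relies on the forward-pass contract above.
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_returns_hidden_states_for_every_time_step() {
        // Hypothetical RNG setup (see the note above about `scirs2_core::random`).
        use scirs2_core::random::{rngs::StdRng, SeedableRng};
        let mut rng = StdRng::seed_from_u64(42);

        // 3 input features, 4 hidden units.
        let lstm = LSTM::<f64>::new(3, 4, &mut rng).unwrap();

        // Batch of 2 sequences, 5 time steps, 3 features each.
        let input: Array<f64, IxDyn> = Array::zeros(IxDyn(&[2, 5, 3]));
        let output = lstm.forward(&input).unwrap();

        // One hidden state per time step: [batch_size, seq_len, hidden_size].
        assert_eq!(output.shape(), &[2, 5, 4]);
    }
}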