use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2, ArrayView3, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

use crate::activation::ActivationFunction;

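/// The type of recurrent cell used in each layer.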
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CellType {
    /// Simple (Elman) recurrent cell with a tanh nonlinearity.
    RNN,
    /// Long short-term memory cell with forget, input, output, and cell gates.
    LSTM,
    /// Gated recurrent unit cell with reset, update, and new gates.
    GRU,
}

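/// How input sequences are mapped to output sequences.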
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SequenceMode {
    /// One output per time step (e.g. sequence tagging).
    ManyToMany,
    /// A single output for the whole sequence (e.g. sequence classification).
    ManyToOne,
    /// A sequence of outputs driven by the input sequence.
    OneToMany,
}

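/// Recurrent neural network estimator with configurable cell type, depth,
/// sequence mode, and training hyperparameters.
///
/// A minimal usage sketch (inputs shaped `(n_samples, seq_len, n_features)`,
/// targets shaped `(n_samples, seq_len, n_outputs)`; the variables `x` and
/// `y` are placeholders):
///
/// ```ignore
/// let model = RecurrentNeuralNetwork::new()
///     .cell_type(CellType::GRU)
///     .hidden_size(32)
///     .max_iter(50)
///     .fit(&x.view(), &y)?;
/// let y_hat = model.predict(&x.view())?;
/// ```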
#[derive(Debug, Clone)]
pub struct RecurrentNeuralNetwork<S = Untrained> {
    state: S,
    cell_type: CellType,
    hidden_size: usize,
    num_layers: usize,
    sequence_mode: SequenceMode,
    bidirectional: bool,
    dropout: Float,
    learning_rate: Float,
    max_iter: usize,
    tolerance: Float,
    random_state: Option<u64>,
    alpha: Float,
}

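/// Learned parameters and training metadata for a fitted
/// `RecurrentNeuralNetwork`.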
#[derive(Debug, Clone)]
pub struct RecurrentNeuralNetworkTrained {
    /// Input-to-hidden weight matrix for each layer.
    input_weights: Vec<Array2<Float>>,
    /// Recurrent hidden-to-hidden weight matrix for each layer.
    hidden_weights: Vec<Array2<Float>>,
    /// Hidden bias vector for each layer.
    biases: Vec<Array1<Float>>,
    /// Hidden-to-output weight matrix.
    output_weights: Array2<Float>,
    /// Output bias vector.
    output_bias: Array1<Float>,
    /// Per-layer gate weight matrices, keyed like `"forget_input"`.
    gate_weights: HashMap<String, Vec<Array2<Float>>>,
    /// Per-layer gate bias vectors, keyed by gate name.
    gate_biases: HashMap<String, Vec<Array1<Float>>>,
    cell_type: CellType,
    hidden_size: usize,
    num_layers: usize,
    sequence_mode: SequenceMode,
    bidirectional: bool,
    n_features: usize,
    n_outputs: usize,
    /// Mean squared error recorded after each epoch.
    loss_curve: Vec<Float>,
    /// Number of epochs actually run.
    n_iter: usize,
}

impl RecurrentNeuralNetwork<Untrained> {
    /// Create a new untrained network with default hyperparameters:
    /// a single LSTM layer of 50 hidden units in many-to-many mode.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            cell_type: CellType::LSTM,
            hidden_size: 50,
            num_layers: 1,
            sequence_mode: SequenceMode::ManyToMany,
            bidirectional: false,
            dropout: 0.0,
            learning_rate: 0.001,
            max_iter: 100,
            tolerance: 1e-4,
            random_state: None,
            alpha: 0.0001,
        }
    }

    /// Set the recurrent cell type (RNN, LSTM, or GRU).
    pub fn cell_type(mut self, cell_type: CellType) -> Self {
        self.cell_type = cell_type;
        self
    }

    /// Set the number of hidden units per layer.
    pub fn hidden_size(mut self, hidden_size: usize) -> Self {
        self.hidden_size = hidden_size;
        self
    }

    /// Set the number of stacked recurrent layers.
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Set how input sequences map to output sequences.
    pub fn sequence_mode(mut self, sequence_mode: SequenceMode) -> Self {
        self.sequence_mode = sequence_mode;
        self
    }

    /// Enable or disable bidirectional processing (stored, but not yet
    /// applied by the forward pass; see `initialize_parameters`).
    pub fn bidirectional(mut self, bidirectional: bool) -> Self {
        self.bidirectional = bidirectional;
        self
    }

    /// Set the dropout rate (stored but not yet applied during training).
    pub fn dropout(mut self, dropout: Float) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set the gradient-descent learning rate.
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set the maximum number of training epochs.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the loss-change tolerance used for early stopping.
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set the random seed (stored but not yet used; initialization draws
    /// from the thread-local RNG).
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set the regularization strength (stored but not yet applied during
    /// training).
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }
}

impl Default for RecurrentNeuralNetwork<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RecurrentNeuralNetwork<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

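// Training consumes `X` of shape (n_samples, seq_len, n_features) and `y` of
// shape (n_samples, seq_len, n_outputs); optimization is per-sample gradient
// descent on mean squared error with an early stop when the loss plateaus.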
impl Fit<ArrayView3<'_, Float>, Array3<Float>> for RecurrentNeuralNetwork<Untrained> {
    type Fitted = RecurrentNeuralNetwork<RecurrentNeuralNetworkTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView3<'_, Float>, y: &Array3<Float>) -> SklResult<Self::Fitted> {
        let (n_samples, max_seq_len, n_features) = X.dim();
        let (n_samples_y, max_seq_len_y, n_outputs) = y.dim();

        if n_samples != n_samples_y {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if self.sequence_mode == SequenceMode::ManyToMany && max_seq_len != max_seq_len_y {
            return Err(SklearsError::InvalidInput(
                "For many-to-many mode, X and y must have the same sequence length".to_string(),
            ));
        }

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit with zero samples".to_string(),
            ));
        }

        // NOTE: `random_state` is not yet honored here; initialization always
        // draws from the thread-local RNG.
        let mut rng = thread_rng();

        let (
            mut input_weights,
            mut hidden_weights,
            mut biases,
            mut output_weights,
            mut output_bias,
            mut gate_weights,
            mut gate_biases,
        ) = self.initialize_parameters(n_features, n_outputs, &mut rng)?;

        let mut loss_curve = Vec::new();
        let X_owned = X.to_owned();
        let y_owned = y.to_owned();

        for epoch in 0..self.max_iter {
            let mut total_loss = 0.0;

            for sample_idx in 0..n_samples {
                let x_seq = X_owned.slice(s![sample_idx, .., ..]);
                let y_seq = y_owned.slice(s![sample_idx, .., ..]);

                // Forward pass through all layers and time steps.
                let (predictions, hidden_states) = self.forward_sequence(
                    &x_seq,
                    &input_weights,
                    &hidden_weights,
                    &biases,
                    &output_weights,
                    &output_bias,
                    &gate_weights,
                    &gate_biases,
                )?;

                let sample_loss = self.compute_sequence_loss(&predictions, &y_seq.to_owned());
                total_loss += sample_loss;

                // Gradient update from this sample.
                self.backward_sequence(
                    &x_seq,
                    &y_seq.to_owned(),
                    &predictions,
                    &hidden_states,
                    &mut input_weights,
                    &mut hidden_weights,
                    &mut biases,
                    &mut output_weights,
                    &mut output_bias,
                    &mut gate_weights,
                    &mut gate_biases,
                )?;
            }

            let avg_loss = total_loss / n_samples as Float;
            loss_curve.push(avg_loss);

            // Stop early once the epoch-to-epoch loss change falls below tolerance.
            if epoch > 0 && (loss_curve[epoch - 1] - avg_loss).abs() < self.tolerance {
                break;
            }
        }

        // Early stopping may end training before `max_iter`; record the
        // number of epochs actually run.
        let n_iter = loss_curve.len();

        let trained_state = RecurrentNeuralNetworkTrained {
            input_weights,
            hidden_weights,
            biases,
            output_weights,
            output_bias,
            gate_weights,
            gate_biases,
            cell_type: self.cell_type,
            hidden_size: self.hidden_size,
            num_layers: self.num_layers,
            sequence_mode: self.sequence_mode,
            bidirectional: self.bidirectional,
            n_features,
            n_outputs,
            loss_curve,
            n_iter,
        };

        Ok(RecurrentNeuralNetwork {
            state: trained_state,
            cell_type: self.cell_type,
            hidden_size: self.hidden_size,
            num_layers: self.num_layers,
            sequence_mode: self.sequence_mode,
            bidirectional: self.bidirectional,
            dropout: self.dropout,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            alpha: self.alpha,
        })
    }
}

impl RecurrentNeuralNetwork<Untrained> {
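    /// Allocate and randomly initialize all trainable parameters, using
    /// Xavier/Glorot-style scaling for the weight matrices.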
    fn initialize_parameters(
        &self,
        n_features: usize,
        n_outputs: usize,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<(
        Vec<Array2<Float>>,
        Vec<Array2<Float>>,
        Vec<Array1<Float>>,
        Array2<Float>,
        Array1<Float>,
        HashMap<String, Vec<Array2<Float>>>,
        HashMap<String, Vec<Array1<Float>>>,
    )> {
        let mut input_weights = Vec::new();
        let mut hidden_weights = Vec::new();
        let mut biases = Vec::new();
        let mut gate_weights = HashMap::new();
        let mut gate_biases = HashMap::new();

        for layer in 0..self.num_layers {
            let input_size = if layer == 0 {
                n_features
            } else {
                self.hidden_size
            };

            // Xavier/Glorot scaling: variance ~ 2 / (fan_in + fan_out).
            let input_scale = (2.0 / (input_size + self.hidden_size) as Float).sqrt();
            let hidden_scale = (2.0 / (self.hidden_size + self.hidden_size) as Float).sqrt();

            let mut input_weight = Array2::<Float>::zeros((self.hidden_size, input_size));
            let normal_dist = RandNormal::new(0.0, input_scale).unwrap();
            for i in 0..self.hidden_size {
                for j in 0..input_size {
                    input_weight[[i, j]] = rng.sample(normal_dist);
                }
            }
            let mut hidden_weight = Array2::<Float>::zeros((self.hidden_size, self.hidden_size));
            let hidden_normal_dist = RandNormal::new(0.0, hidden_scale).unwrap();
            for i in 0..self.hidden_size {
                for j in 0..self.hidden_size {
                    hidden_weight[[i, j]] = rng.sample(hidden_normal_dist);
                }
            }
            let bias = Array1::<Float>::zeros(self.hidden_size);

            input_weights.push(input_weight);
            hidden_weights.push(hidden_weight);
            biases.push(bias);

            // LSTM and GRU gate initialization differs only in the gate
            // names; the vanilla RNN has no extra gates.
            let gate_names: &[&str] = match self.cell_type {
                CellType::LSTM => &["forget", "input", "output", "cell"],
                CellType::GRU => &["reset", "update", "new"],
                CellType::RNN => &[],
            };

            for gate_name in gate_names {
                let input_key = format!("{}_input", gate_name);
                let mut input_weight = Array2::<Float>::zeros((self.hidden_size, input_size));
                let input_normal_dist = RandNormal::new(0.0, input_scale).unwrap();
                for i in 0..self.hidden_size {
                    for j in 0..input_size {
                        input_weight[[i, j]] = rng.sample(input_normal_dist);
                    }
                }
                gate_weights
                    .entry(input_key)
                    .or_insert_with(Vec::new)
                    .push(input_weight);

                let hidden_key = format!("{}_hidden", gate_name);
                let mut hidden_weight =
                    Array2::<Float>::zeros((self.hidden_size, self.hidden_size));
                let hidden_normal_dist = RandNormal::new(0.0, hidden_scale).unwrap();
                for i in 0..self.hidden_size {
                    for j in 0..self.hidden_size {
                        hidden_weight[[i, j]] = rng.sample(hidden_normal_dist);
                    }
                }
                gate_weights
                    .entry(hidden_key)
                    .or_insert_with(Vec::new)
                    .push(hidden_weight);

                gate_biases
                    .entry(gate_name.to_string())
                    .or_insert_with(Vec::new)
                    .push(Array1::<Float>::zeros(self.hidden_size));
            }
        }

        // NOTE: the forward pass does not yet run a reverse direction, so
        // enabling `bidirectional` sizes the output layer for 2 * hidden_size
        // inputs that the current forward pass does not produce.
        let output_input_size = if self.bidirectional {
            2 * self.hidden_size
        } else {
            self.hidden_size
        };
        let output_scale = (2.0 / (output_input_size + n_outputs) as Float).sqrt();
        let mut output_weights = Array2::<Float>::zeros((n_outputs, output_input_size));
        let output_normal_dist = RandNormal::new(0.0, output_scale).unwrap();
        for i in 0..n_outputs {
            for j in 0..output_input_size {
                output_weights[[i, j]] = rng.sample(output_normal_dist);
            }
        }
        let output_bias = Array1::<Float>::zeros(n_outputs);

        Ok((
            input_weights,
            hidden_weights,
            biases,
            output_weights,
            output_bias,
            gate_weights,
            gate_biases,
        ))
    }

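    /// Run one input sequence through the stacked recurrent layers and the
    /// output head, returning per-mode predictions and all hidden states.
    ///
    /// Per time step the cells compute:
    /// - RNN:  `h_t = tanh(W_x x_t + W_h h_{t-1} + b)`
    /// - LSTM: `c_t = f_t * c_{t-1} + i_t * c~_t`, `h_t = o_t * tanh(c_t)`,
    ///   with sigmoid gates `f_t`, `i_t`, `o_t` and tanh candidate `c~_t`
    /// - GRU:  `h_t = z_t * h_{t-1} + (1 - z_t) * n_t`, with sigmoid gates
    ///   `r_t`, `z_t` and tanh candidate `n_t` on the reset-scaled history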
    fn forward_sequence(
        &self,
        x_seq: &ArrayView2<'_, Float>,
        input_weights: &[Array2<Float>],
        hidden_weights: &[Array2<Float>],
        biases: &[Array1<Float>],
        output_weights: &Array2<Float>,
        output_bias: &Array1<Float>,
        gate_weights: &HashMap<String, Vec<Array2<Float>>>,
        gate_biases: &HashMap<String, Vec<Array1<Float>>>,
    ) -> SklResult<(Array2<Float>, Vec<Vec<Array1<Float>>>)> {
        let (seq_len, _) = x_seq.dim();
        let n_outputs = output_weights.nrows();

        // hidden_states[layer][t + 1] holds the layer's hidden state after
        // step t; index 0 is the zero initial state.
        let mut hidden_states = Vec::new();
        for _ in 0..self.num_layers {
            hidden_states.push(vec![Array1::<Float>::zeros(self.hidden_size); seq_len + 1]);
        }

        let mut cell_states = Vec::new();
        if self.cell_type == CellType::LSTM {
            for _ in 0..self.num_layers {
                cell_states.push(vec![Array1::<Float>::zeros(self.hidden_size); seq_len + 1]);
            }
        }

        for t in 0..seq_len {
            let x_t = x_seq.row(t);

            for layer in 0..self.num_layers {
                let input = if layer == 0 {
                    x_t.to_owned()
                } else {
                    hidden_states[layer - 1][t].clone()
                };

                let prev_hidden = &hidden_states[layer][t];

                match self.cell_type {
                    CellType::RNN => {
                        // h_t = tanh(W_x x_t + W_h h_{t-1} + b)
                        let linear = input_weights[layer].dot(&input)
                            + hidden_weights[layer].dot(prev_hidden)
                            + &biases[layer];
                        hidden_states[layer][t + 1] = linear.map(|x| x.tanh());
                    }
                    CellType::LSTM => {
                        let prev_cell = &cell_states[layer][t];

                        // Forget gate.
                        let f_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["forget_input"][layer],
                            &gate_weights["forget_hidden"][layer],
                            &gate_biases["forget"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Input gate.
                        let i_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["input_input"][layer],
                            &gate_weights["input_hidden"][layer],
                            &gate_biases["input"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Candidate cell state.
                        let c_tilde = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["cell_input"][layer],
                            &gate_weights["cell_hidden"][layer],
                            &gate_biases["cell"][layer],
                            ActivationFunction::Tanh,
                        );

                        // c_t = f_t * c_{t-1} + i_t * c~_t
                        let new_cell = &f_t * prev_cell + &i_t * &c_tilde;
                        cell_states[layer][t + 1] = new_cell.clone();

                        // Output gate.
                        let o_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["output_input"][layer],
                            &gate_weights["output_hidden"][layer],
                            &gate_biases["output"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // h_t = o_t * tanh(c_t)
                        hidden_states[layer][t + 1] = &o_t * &new_cell.map(|x| x.tanh());
                    }
                    CellType::GRU => {
                        // Reset gate.
                        let r_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["reset_input"][layer],
                            &gate_weights["reset_hidden"][layer],
                            &gate_biases["reset"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Update gate.
                        let z_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["update_input"][layer],
                            &gate_weights["update_hidden"][layer],
                            &gate_biases["update"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Candidate state computed on the reset-scaled history.
                        let reset_hidden = &r_t * prev_hidden;
                        let n_t = self.compute_gate(
                            &input,
                            &reset_hidden,
                            &gate_weights["new_input"][layer],
                            &gate_weights["new_hidden"][layer],
                            &gate_biases["new"][layer],
                            ActivationFunction::Tanh,
                        );

                        // h_t = z_t * h_{t-1} + (1 - z_t) * n_t
                        let one_minus_z = Array1::<Float>::ones(self.hidden_size) - &z_t;
                        hidden_states[layer][t + 1] = &z_t * prev_hidden + &one_minus_z * &n_t;
                    }
                }
            }
        }

        let predictions = match self.sequence_mode {
            SequenceMode::ManyToMany | SequenceMode::OneToMany => {
                // One output per time step from the top layer's hidden state.
                let mut outputs = Array2::<Float>::zeros((seq_len, n_outputs));
                for t in 0..seq_len {
                    let hidden_t = &hidden_states[self.num_layers - 1][t + 1];
                    let output_t = output_weights.dot(hidden_t) + output_bias;
                    outputs.row_mut(t).assign(&output_t);
                }
                outputs
            }
            SequenceMode::ManyToOne => {
                // A single output from the final hidden state.
                let final_hidden = &hidden_states[self.num_layers - 1][seq_len];
                let output = output_weights.dot(final_hidden) + output_bias;
                Array2::from_shape_vec((1, n_outputs), output.to_vec()).unwrap()
            }
        };

        Ok((predictions, hidden_states))
    }

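    /// Evaluate one gate: `activation(W_x * input + W_h * hidden + b)`.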
    fn compute_gate(
        &self,
        input: &Array1<Float>,
        hidden: &Array1<Float>,
        input_weight: &Array2<Float>,
        hidden_weight: &Array2<Float>,
        bias: &Array1<Float>,
        activation: ActivationFunction,
    ) -> Array1<Float> {
        let linear = input_weight.dot(input) + hidden_weight.dot(hidden) + bias;
        activation.apply(&linear)
    }

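    /// Mean squared error over all time steps and output dimensions.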
    fn compute_sequence_loss(&self, predictions: &Array2<Float>, targets: &Array2<Float>) -> Float {
        let diff = predictions - targets;
        diff.map(|x| x * x).mean().unwrap()
    }

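    /// Apply one gradient update from a single sequence.
    ///
    /// NOTE: this is a simplified per-step update, not full backpropagation
    /// through time: the recurrent weights (`hidden_weights`) and the
    /// LSTM/GRU gate parameters are currently left unchanged, and only the
    /// output layer, input weights, and hidden biases are adjusted.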
    fn backward_sequence(
        &self,
        x_seq: &ArrayView2<'_, Float>,
        y_seq: &Array2<Float>,
        predictions: &Array2<Float>,
        hidden_states: &[Vec<Array1<Float>>],
        input_weights: &mut [Array2<Float>],
        hidden_weights: &mut [Array2<Float>],
        biases: &mut [Array1<Float>],
        output_weights: &mut Array2<Float>,
        output_bias: &mut Array1<Float>,
        gate_weights: &mut HashMap<String, Vec<Array2<Float>>>,
        gate_biases: &mut HashMap<String, Vec<Array1<Float>>>,
    ) -> SklResult<()> {
        let (seq_len, _) = x_seq.dim();

        let output_error = predictions - y_seq;

        match self.sequence_mode {
            SequenceMode::ManyToMany => {
                for t in 0..seq_len {
                    let hidden_t = &hidden_states[self.num_layers - 1][t + 1];
                    let error_t = output_error.row(t).to_owned();

                    // Output-layer gradient: outer product of error and hidden state.
                    let weight_grad = error_t
                        .clone()
                        .insert_axis(Axis(1))
                        .dot(&hidden_t.clone().insert_axis(Axis(0)));
                    *output_weights = output_weights.clone() - self.learning_rate * weight_grad;
                    *output_bias = output_bias.clone() - self.learning_rate * &error_t;

                    for layer in (0..self.num_layers).rev() {
                        let x_t = if layer == 0 {
                            x_seq.row(t).to_owned()
                        } else {
                            hidden_states[layer - 1][t + 1].clone()
                        };

                        // Approximate the hidden error by projecting the output
                        // error back through the output weights.
                        let hidden_error = output_weights.t().dot(&error_t);
                        let weight_grad = hidden_error
                            .clone()
                            .insert_axis(Axis(1))
                            .dot(&x_t.insert_axis(Axis(0)));

                        input_weights[layer] =
                            input_weights[layer].clone() - self.learning_rate * weight_grad;
                        biases[layer] = biases[layer].clone() - self.learning_rate * hidden_error;
                    }
                }
            }
            _ => {
                // Many-to-one / one-to-many: update the output layer from the
                // final hidden state only.
                let hidden_final = &hidden_states[self.num_layers - 1][seq_len];
                let error_final = output_error.row(0).to_owned();
                let weight_grad = error_final
                    .clone()
                    .insert_axis(Axis(1))
                    .dot(&hidden_final.clone().insert_axis(Axis(0)));
                *output_weights = output_weights.clone() - self.learning_rate * weight_grad;
                *output_bias = output_bias.clone() - self.learning_rate * error_final;
            }
        }

        Ok(())
    }
}

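// Prediction mirrors the training shapes: many-to-many and one-to-many fill
// every time step of the output tensor; many-to-one writes a single step.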
impl Predict<ArrayView3<'_, Float>, Array3<Float>>
    for RecurrentNeuralNetwork<RecurrentNeuralNetworkTrained>
{
    fn predict(&self, X: &ArrayView3<'_, Float>) -> SklResult<Array3<Float>> {
        let (n_samples, max_seq_len, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has a different number of features than the training data".to_string(),
            ));
        }

        let mut predictions = match self.state.sequence_mode {
            SequenceMode::ManyToMany | SequenceMode::OneToMany => {
                Array3::<Float>::zeros((n_samples, max_seq_len, self.state.n_outputs))
            }
            SequenceMode::ManyToOne => Array3::<Float>::zeros((n_samples, 1, self.state.n_outputs)),
        };

        for sample_idx in 0..n_samples {
            let x_seq = X.slice(s![sample_idx, .., ..]);

            let (sample_predictions, _) = self.forward_sequence_trained(&x_seq)?;

            match self.state.sequence_mode {
                SequenceMode::ManyToMany | SequenceMode::OneToMany => {
                    for t in 0..sample_predictions.nrows() {
                        for j in 0..sample_predictions.ncols() {
                            predictions[[sample_idx, t, j]] = sample_predictions[[t, j]];
                        }
                    }
                }
                SequenceMode::ManyToOne => {
                    for j in 0..sample_predictions.ncols() {
                        predictions[[sample_idx, 0, j]] = sample_predictions[[0, j]];
                    }
                }
            }
        }

        Ok(predictions)
    }
}

impl RecurrentNeuralNetwork<RecurrentNeuralNetworkTrained> {
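    /// Forward pass used at prediction time.
    ///
    /// NOTE: this simplified pass applies the tanh recurrence on the main
    /// weights for every cell type; the LSTM/GRU gate parameters are not
    /// used at inference, so gated cells are approximated here.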
    fn forward_sequence_trained(
        &self,
        x_seq: &ArrayView2<'_, Float>,
    ) -> SklResult<(Array2<Float>, Vec<Vec<Array1<Float>>>)> {
        let (seq_len, _) = x_seq.dim();
        let n_outputs = self.state.output_weights.nrows();

        let mut hidden_states = Vec::new();
        for _ in 0..self.state.num_layers {
            hidden_states.push(vec![
                Array1::<Float>::zeros(self.state.hidden_size);
                seq_len + 1
            ]);
        }

        for t in 0..seq_len {
            let x_t = x_seq.row(t);

            for layer in 0..self.state.num_layers {
                let input = if layer == 0 {
                    x_t.to_owned()
                } else {
                    hidden_states[layer - 1][t].clone()
                };

                let prev_hidden = &hidden_states[layer][t];

                let linear = self.state.input_weights[layer].dot(&input)
                    + self.state.hidden_weights[layer].dot(prev_hidden)
                    + &self.state.biases[layer];

                hidden_states[layer][t + 1] = match self.state.cell_type {
                    CellType::RNN => linear.map(|x| x.tanh()),
                    // Simplified: gated cells fall back to the tanh recurrence
                    // at inference time.
                    CellType::LSTM | CellType::GRU => linear.map(|x| x.tanh()),
                };
            }
        }

        let predictions = match self.state.sequence_mode {
            SequenceMode::ManyToMany | SequenceMode::OneToMany => {
                let mut outputs = Array2::<Float>::zeros((seq_len, n_outputs));
                for t in 0..seq_len {
                    let hidden_t = &hidden_states[self.state.num_layers - 1][t + 1];
                    let output_t =
                        self.state.output_weights.dot(hidden_t) + &self.state.output_bias;
                    outputs.row_mut(t).assign(&output_t);
                }
                outputs
            }
            SequenceMode::ManyToOne => {
                let final_hidden = &hidden_states[self.state.num_layers - 1][seq_len];
                let output = self.state.output_weights.dot(final_hidden) + &self.state.output_bias;
                Array2::from_shape_vec((1, n_outputs), output.to_vec()).unwrap()
            }
        };

        Ok((predictions, hidden_states))
    }

    /// Per-epoch training loss (mean squared error).
    pub fn loss_curve(&self) -> &[Float] {
        &self.state.loss_curve
    }

    /// Number of training epochs actually run.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// The recurrent cell type used by the fitted model.
    pub fn cell_type(&self) -> CellType {
        self.state.cell_type
    }

    /// Number of hidden units per layer.
    pub fn hidden_size(&self) -> usize {
        self.state.hidden_size
    }

    /// The sequence mapping mode used by the fitted model.
    pub fn sequence_mode(&self) -> SequenceMode {
        self.state.sequence_mode
    }
}