//! Neural Sequence Models for Structured Output Prediction
//!
//! This module implements RNN, LSTM, and GRU models for sequence-based tasks
//! such as sequence labeling, sequence-to-sequence prediction, and other
//! structured output problems.

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{s, Array1, Array2, Array3, ArrayView2, ArrayView3, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

use crate::activation::ActivationFunction;

/// Cell types for recurrent neural networks
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum CellType {
    /// Simple RNN cell
    RNN,
    /// Long Short-Term Memory cell
    LSTM,
    /// Gated Recurrent Unit cell
    GRU,
}

/// Output modes for sequence models
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SequenceMode {
    /// Many-to-many: output at each timestep
    ManyToMany,
    /// Many-to-one: single output at the end
    ManyToOne,
    /// One-to-many: single input, sequence output
    OneToMany,
}

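// Shape conventions assumed by `fit` and `predict` below (n = samples, T = timesteps):
//   ManyToMany: X is (n, T, n_features), y is (n, T, n_outputs)
//   ManyToOne:  X is (n, T, n_features), y is (n, 1, n_outputs)
//   OneToMany:  X is (n, T, n_features), y is (n, T, n_outputs)
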
/// Recurrent Neural Network for Sequence Prediction
///
/// This model can handle various sequence prediction tasks using different
/// cell types (RNN, LSTM, GRU) and output modes.
///
/// # Examples
///
/// ```
/// use sklears_multioutput::recurrent::{RecurrentNeuralNetwork, CellType, SequenceMode};
///
/// // Configure a model for sequence labeling (many-to-many)
/// let rnn = RecurrentNeuralNetwork::new()
///     .cell_type(CellType::LSTM)
///     .hidden_size(50)
///     .sequence_mode(SequenceMode::ManyToMany)
///     .learning_rate(0.001)
///     .max_iter(100);
/// ```
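///
/// A full fit/predict round-trip is sketched below (marked `no_run`; it is a
/// sketch that assumes `Float` is `f64` and that 3-D arrays come from
/// `scirs2_core::ndarray::Array3`, as elsewhere in this module):
///
/// ```no_run
/// use scirs2_core::ndarray::Array3;
/// use sklears_core::traits::{Fit, Predict};
/// use sklears_multioutput::recurrent::{RecurrentNeuralNetwork, CellType, SequenceMode};
///
/// // 8 sequences, 5 timesteps, 3 input features -> 2 outputs per timestep
/// let x = Array3::<f64>::zeros((8, 5, 3));
/// let y = Array3::<f64>::zeros((8, 5, 2));
/// let model = RecurrentNeuralNetwork::new()
///     .cell_type(CellType::GRU)
///     .hidden_size(16)
///     .sequence_mode(SequenceMode::ManyToMany)
///     .max_iter(10)
///     .fit(&x.view(), &y)
///     .unwrap();
/// let predictions = model.predict(&x.view()).unwrap();
/// assert_eq!(predictions.dim(), (8, 5, 2));
/// ```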
#[derive(Debug, Clone)]
pub struct RecurrentNeuralNetwork<S = Untrained> {
    state: S,
    cell_type: CellType,
    hidden_size: usize,
    num_layers: usize,
    sequence_mode: SequenceMode,
    bidirectional: bool,
    dropout: Float,
    learning_rate: Float,
    max_iter: usize,
    tolerance: Float,
    random_state: Option<u64>,
    alpha: Float, // L2 regularization
}

/// Trained state for RecurrentNeuralNetwork
#[derive(Debug, Clone)]
pub struct RecurrentNeuralNetworkTrained {
    /// Input-to-hidden weights for each layer
    input_weights: Vec<Array2<Float>>,
    /// Hidden-to-hidden weights for each layer
    hidden_weights: Vec<Array2<Float>>,
    /// Biases for each layer
    biases: Vec<Array1<Float>>,
    /// Output layer weights
    output_weights: Array2<Float>,
    /// Output layer bias
    output_bias: Array1<Float>,
    /// Additional parameters for LSTM/GRU gates
    gate_weights: HashMap<String, Vec<Array2<Float>>>,
    gate_biases: HashMap<String, Vec<Array1<Float>>>,
    /// Network configuration
    cell_type: CellType,
    hidden_size: usize,
    num_layers: usize,
    sequence_mode: SequenceMode,
    bidirectional: bool,
    n_features: usize,
    n_outputs: usize,
    /// Training history
    loss_curve: Vec<Float>,
    n_iter: usize,
}

impl RecurrentNeuralNetwork<Untrained> {
    /// Create a new RecurrentNeuralNetwork instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            cell_type: CellType::LSTM,
            hidden_size: 50,
            num_layers: 1,
            sequence_mode: SequenceMode::ManyToMany,
            bidirectional: false,
            dropout: 0.0,
            learning_rate: 0.001,
            max_iter: 100,
            tolerance: 1e-4,
            random_state: None,
            alpha: 0.0001,
        }
    }

    /// Set the cell type (RNN, LSTM, GRU)
    pub fn cell_type(mut self, cell_type: CellType) -> Self {
        self.cell_type = cell_type;
        self
    }

    /// Set the hidden layer size
    pub fn hidden_size(mut self, hidden_size: usize) -> Self {
        self.hidden_size = hidden_size;
        self
    }

    /// Set the number of recurrent layers
    pub fn num_layers(mut self, num_layers: usize) -> Self {
        self.num_layers = num_layers;
        self
    }

    /// Set the sequence prediction mode
    pub fn sequence_mode(mut self, sequence_mode: SequenceMode) -> Self {
        self.sequence_mode = sequence_mode;
        self
    }

    /// Enable bidirectional processing
    pub fn bidirectional(mut self, bidirectional: bool) -> Self {
        self.bidirectional = bidirectional;
        self
    }

    /// Set dropout rate
    pub fn dropout(mut self, dropout: Float) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set learning rate
    pub fn learning_rate(mut self, learning_rate: Float) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Set maximum iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set tolerance for convergence
    pub fn tolerance(mut self, tolerance: Float) -> Self {
        self.tolerance = tolerance;
        self
    }

    /// Set random state for reproducibility
    ///
    /// Note: training currently draws from `thread_rng`, so this value is
    /// stored but not yet applied to weight initialization.
    pub fn random_state(mut self, random_state: Option<u64>) -> Self {
        self.random_state = random_state;
        self
    }

    /// Set L2 regularization parameter
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }
}

impl Default for RecurrentNeuralNetwork<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RecurrentNeuralNetwork<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView3<'_, Float>, Array3<Float>> for RecurrentNeuralNetwork<Untrained> {
    type Fitted = RecurrentNeuralNetwork<RecurrentNeuralNetworkTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView3<'_, Float>, y: &Array3<Float>) -> SklResult<Self::Fitted> {
        let (n_samples, max_seq_len, n_features) = X.dim();
        let (n_samples_y, max_seq_len_y, n_outputs) = y.dim();

        if n_samples != n_samples_y {
            return Err(SklearsError::InvalidInput(
                "X and y must have the same number of samples".to_string(),
            ));
        }

        if self.sequence_mode == SequenceMode::ManyToMany && max_seq_len != max_seq_len_y {
            return Err(SklearsError::InvalidInput(
                "For many-to-many mode, X and y must have the same sequence length".to_string(),
            ));
        }

        if n_samples == 0 {
            return Err(SklearsError::InvalidInput(
                "Cannot fit with zero samples".to_string(),
            ));
        }

        // Bidirectional weights are allocated with a doubled output fan-in, but the
        // forward pass only runs the forward direction; reject the configuration
        // rather than panic on a shape mismatch during training.
        if self.bidirectional {
            return Err(SklearsError::InvalidInput(
                "bidirectional processing is not yet implemented".to_string(),
            ));
        }

        // Initialize random number generator
        let mut rng = thread_rng();

        // Initialize weights and biases
        let (
            mut input_weights,
            mut hidden_weights,
            mut biases,
            mut output_weights,
            mut output_bias,
            mut gate_weights,
            mut gate_biases,
        ) = self.initialize_parameters(n_features, n_outputs, &mut rng)?;

        // Training loop: one gradient step per sequence per epoch
        let mut loss_curve = Vec::new();
        let X_owned = X.to_owned();
        let y_owned = y.to_owned();

        for epoch in 0..self.max_iter {
            let mut total_loss = 0.0;

            // Process each sequence in the batch
            for sample_idx in 0..n_samples {
                let x_seq = X_owned.slice(s![sample_idx, .., ..]);
                let y_seq = y_owned.slice(s![sample_idx, .., ..]);

                // Forward pass
                let (predictions, hidden_states) = self.forward_sequence(
                    &x_seq,
                    &input_weights,
                    &hidden_weights,
                    &biases,
                    &output_weights,
                    &output_bias,
                    &gate_weights,
                    &gate_biases,
                )?;

                // Compute loss
                let sample_loss = self.compute_sequence_loss(&predictions, &y_seq.to_owned());
                total_loss += sample_loss;

                // Backward pass (BPTT)
                self.backward_sequence(
                    &x_seq,
                    &y_seq.to_owned(),
                    &predictions,
                    &hidden_states,
                    &mut input_weights,
                    &mut hidden_weights,
                    &mut biases,
                    &mut output_weights,
                    &mut output_bias,
                    &mut gate_weights,
                    &mut gate_biases,
                )?;
            }

            let avg_loss = total_loss / n_samples as Float;
            loss_curve.push(avg_loss);

            // Check convergence
            if epoch > 0 && (loss_curve[epoch - 1] - avg_loss).abs() < self.tolerance {
                break;
            }
        }

        let trained_state = RecurrentNeuralNetworkTrained {
            input_weights,
            hidden_weights,
            biases,
            output_weights,
            output_bias,
            gate_weights,
            gate_biases,
            cell_type: self.cell_type,
            hidden_size: self.hidden_size,
            num_layers: self.num_layers,
            sequence_mode: self.sequence_mode,
            bidirectional: self.bidirectional,
            n_features,
            n_outputs,
            // Record the number of epochs actually run (training may stop early
            // on convergence), not the configured maximum.
            n_iter: loss_curve.len(),
            loss_curve,
        };

        Ok(RecurrentNeuralNetwork {
            state: trained_state,
            cell_type: self.cell_type,
            hidden_size: self.hidden_size,
            num_layers: self.num_layers,
            sequence_mode: self.sequence_mode,
            bidirectional: self.bidirectional,
            dropout: self.dropout,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            tolerance: self.tolerance,
            random_state: self.random_state,
            alpha: self.alpha,
        })
    }
}

impl RecurrentNeuralNetwork<Untrained> {
    /// Initialize network parameters
    fn initialize_parameters(
        &self,
        n_features: usize,
        n_outputs: usize,
        rng: &mut scirs2_core::random::CoreRandom,
    ) -> SklResult<(
        Vec<Array2<Float>>,                  // input_weights
        Vec<Array2<Float>>,                  // hidden_weights
        Vec<Array1<Float>>,                  // biases
        Array2<Float>,                       // output_weights
        Array1<Float>,                       // output_bias
        HashMap<String, Vec<Array2<Float>>>, // gate_weights
        HashMap<String, Vec<Array1<Float>>>, // gate_biases
    )> {
        let mut input_weights = Vec::new();
        let mut hidden_weights = Vec::new();
        let mut biases = Vec::new();
        let mut gate_weights = HashMap::new();
        let mut gate_biases = HashMap::new();

        // Initialize parameters for each layer
        for layer in 0..self.num_layers {
            let input_size = if layer == 0 {
                n_features
            } else {
                self.hidden_size
            };

            // Xavier (Glorot) initialization: std = sqrt(2 / (fan_in + fan_out))
            let input_scale = (2.0 / (input_size + self.hidden_size) as Float).sqrt();
            let hidden_scale = (2.0 / (self.hidden_size + self.hidden_size) as Float).sqrt();

            let mut input_weight = Array2::<Float>::zeros((self.hidden_size, input_size));
            let normal_dist = RandNormal::new(0.0, input_scale).unwrap();
            for i in 0..self.hidden_size {
                for j in 0..input_size {
                    input_weight[[i, j]] = rng.sample(normal_dist);
                }
            }
            let mut hidden_weight = Array2::<Float>::zeros((self.hidden_size, self.hidden_size));
            let hidden_normal_dist = RandNormal::new(0.0, hidden_scale).unwrap();
            for i in 0..self.hidden_size {
                for j in 0..self.hidden_size {
                    hidden_weight[[i, j]] = rng.sample(hidden_normal_dist);
                }
            }
            let bias = Array1::<Float>::zeros(self.hidden_size);

            input_weights.push(input_weight);
            hidden_weights.push(hidden_weight);
            biases.push(bias);

            // Initialize gate parameters. LSTM uses forget/input/output gates plus
            // the candidate ("cell") transform; GRU uses reset/update gates plus
            // the candidate ("new") transform; simple RNN has no extra gates.
            let gate_names: &[&str] = match self.cell_type {
                CellType::LSTM => &["forget", "input", "output", "cell"],
                CellType::GRU => &["reset", "update", "new"],
                CellType::RNN => &[],
            };
            for gate_name in gate_names {
                // Input weights for this gate
                let input_key = format!("{}_input", gate_name);
                let mut input_weight = Array2::<Float>::zeros((self.hidden_size, input_size));
                let input_normal_dist = RandNormal::new(0.0, input_scale).unwrap();
                for i in 0..self.hidden_size {
                    for j in 0..input_size {
                        input_weight[[i, j]] = rng.sample(input_normal_dist);
                    }
                }
                gate_weights
                    .entry(input_key)
                    .or_insert_with(Vec::new)
                    .push(input_weight);

                // Hidden weights for this gate
                let hidden_key = format!("{}_hidden", gate_name);
                let mut hidden_weight =
                    Array2::<Float>::zeros((self.hidden_size, self.hidden_size));
                let hidden_normal_dist = RandNormal::new(0.0, hidden_scale).unwrap();
                for i in 0..self.hidden_size {
                    for j in 0..self.hidden_size {
                        hidden_weight[[i, j]] = rng.sample(hidden_normal_dist);
                    }
                }
                gate_weights
                    .entry(hidden_key)
                    .or_insert_with(Vec::new)
                    .push(hidden_weight);

                // Biases for this gate
                gate_biases
                    .entry(gate_name.to_string())
                    .or_insert_with(Vec::new)
                    .push(Array1::<Float>::zeros(self.hidden_size));
            }
        }

        // Output layer
        let output_input_size = if self.bidirectional {
            2 * self.hidden_size
        } else {
            self.hidden_size
        };
        let output_scale = (2.0 / (output_input_size + n_outputs) as Float).sqrt();
        let mut output_weights = Array2::<Float>::zeros((n_outputs, output_input_size));
        let output_normal_dist = RandNormal::new(0.0, output_scale).unwrap();
        for i in 0..n_outputs {
            for j in 0..output_input_size {
                output_weights[[i, j]] = rng.sample(output_normal_dist);
            }
        }
        let output_bias = Array1::<Float>::zeros(n_outputs);

        Ok((
            input_weights,
            hidden_weights,
            biases,
            output_weights,
            output_bias,
            gate_weights,
            gate_biases,
        ))
    }
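
    // After initialization, an L-layer LSTM's `gate_weights` holds the keys
    // "forget_input", "forget_hidden", "input_input", ..., "cell_hidden", each
    // mapping to a Vec of L matrices (one per layer); `gate_biases` holds
    // "forget", "input", "output", "cell" analogously (GRU: "reset", "update", "new").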

    /// Forward pass through sequence
    fn forward_sequence(
        &self,
        x_seq: &ArrayView2<'_, Float>,
        input_weights: &[Array2<Float>],
        hidden_weights: &[Array2<Float>],
        biases: &[Array1<Float>],
        output_weights: &Array2<Float>,
        output_bias: &Array1<Float>,
        gate_weights: &HashMap<String, Vec<Array2<Float>>>,
        gate_biases: &HashMap<String, Vec<Array1<Float>>>,
    ) -> SklResult<(Array2<Float>, Vec<Vec<Array1<Float>>>)> {
        let (seq_len, _) = x_seq.dim();
        let n_outputs = output_weights.nrows();

        // Initialize hidden states for all layers; index 0 holds the initial
        // (zero) state, so the state after timestep t lives at index t + 1.
        let mut hidden_states = Vec::new();
        for _ in 0..self.num_layers {
            hidden_states.push(vec![Array1::<Float>::zeros(self.hidden_size); seq_len + 1]);
        }

        let mut cell_states = Vec::new();
        if self.cell_type == CellType::LSTM {
            for _ in 0..self.num_layers {
                cell_states.push(vec![Array1::<Float>::zeros(self.hidden_size); seq_len + 1]);
            }
        }

        // Process sequence timestep by timestep
        for t in 0..seq_len {
            let x_t = x_seq.row(t);

            for layer in 0..self.num_layers {
                // Layers above the first consume the lower layer's output at the
                // current timestep (stored at index t + 1).
                let input = if layer == 0 {
                    x_t.to_owned()
                } else {
                    hidden_states[layer - 1][t + 1].clone()
                };

                let prev_hidden = &hidden_states[layer][t];

                match self.cell_type {
                    CellType::RNN => {
                        // Simple RNN: h_t = tanh(W_ih * x_t + W_hh * h_{t-1} + b)
                        let linear = input_weights[layer].dot(&input)
                            + hidden_weights[layer].dot(prev_hidden)
                            + &biases[layer];
                        hidden_states[layer][t + 1] = linear.map(|x| x.tanh());
                    }
                    CellType::LSTM => {
                        // LSTM cell computation
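                        //
                        // Gate equations implemented below ("*" is elementwise):
                        //   f_t  = sigmoid(W_f x_t + U_f h_{t-1} + b_f)   forget gate
                        //   i_t  = sigmoid(W_i x_t + U_i h_{t-1} + b_i)   input gate
                        //   c~_t = tanh(W_c x_t + U_c h_{t-1} + b_c)      candidate cell
                        //   c_t  = f_t * c_{t-1} + i_t * c~_t
                        //   o_t  = sigmoid(W_o x_t + U_o h_{t-1} + b_o)   output gate
                        //   h_t  = o_t * tanh(c_t)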
                        let prev_cell = &cell_states[layer][t];

                        // Forget gate
                        let f_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["forget_input"][layer],
                            &gate_weights["forget_hidden"][layer],
                            &gate_biases["forget"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Input gate
                        let i_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["input_input"][layer],
                            &gate_weights["input_hidden"][layer],
                            &gate_biases["input"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Candidate values
                        let c_tilde = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["cell_input"][layer],
                            &gate_weights["cell_hidden"][layer],
                            &gate_biases["cell"][layer],
                            ActivationFunction::Tanh,
                        );

                        // Update cell state
                        let new_cell = &f_t * prev_cell + &i_t * &c_tilde;
                        cell_states[layer][t + 1] = new_cell.clone();

                        // Output gate
                        let o_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["output_input"][layer],
                            &gate_weights["output_hidden"][layer],
                            &gate_biases["output"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        // Update hidden state
                        hidden_states[layer][t + 1] = &o_t * &new_cell.map(|x| x.tanh());
                    }
                    CellType::GRU => {
                        // GRU cell computation
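                        //
                        // Gate equations implemented below ("*" is elementwise):
                        //   r_t = sigmoid(W_r x_t + U_r h_{t-1} + b_r)   reset gate
                        //   z_t = sigmoid(W_z x_t + U_z h_{t-1} + b_z)   update gate
                        //   n_t = tanh(W_n x_t + U_n (r_t * h_{t-1}) + b_n)
                        //   h_t = z_t * h_{t-1} + (1 - z_t) * n_t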
                        let r_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["reset_input"][layer],
                            &gate_weights["reset_hidden"][layer],
                            &gate_biases["reset"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        let z_t = self.compute_gate(
                            &input,
                            prev_hidden,
                            &gate_weights["update_input"][layer],
                            &gate_weights["update_hidden"][layer],
                            &gate_biases["update"][layer],
                            ActivationFunction::Sigmoid,
                        );

                        let reset_hidden = &r_t * prev_hidden;
                        let n_t = self.compute_gate(
                            &input,
                            &reset_hidden,
                            &gate_weights["new_input"][layer],
                            &gate_weights["new_hidden"][layer],
                            &gate_biases["new"][layer],
                            ActivationFunction::Tanh,
                        );

                        let one_minus_z = Array1::<Float>::ones(self.hidden_size) - &z_t;
                        hidden_states[layer][t + 1] = &z_t * prev_hidden + &one_minus_z * &n_t;
                    }
                }
            }
        }

        // Generate outputs based on sequence mode
        let predictions = match self.sequence_mode {
            SequenceMode::ManyToMany => {
                let mut outputs = Array2::<Float>::zeros((seq_len, n_outputs));
                for t in 0..seq_len {
                    let last_layer_hidden = &hidden_states[self.num_layers - 1][t + 1];
                    let output_t = output_weights.dot(last_layer_hidden) + output_bias;
                    outputs.row_mut(t).assign(&output_t);
                }
                outputs
            }
            SequenceMode::ManyToOne => {
                let final_hidden = &hidden_states[self.num_layers - 1][seq_len];
                let output = output_weights.dot(final_hidden) + output_bias;
                Array2::from_shape_vec((1, n_outputs), output.to_vec()).unwrap()
            }
            SequenceMode::OneToMany => {
                // One-to-many reads an output at every timestep; a typical setup
                // feeds a meaningful input only at t = 0.
                let mut outputs = Array2::<Float>::zeros((seq_len, n_outputs));
                for t in 0..seq_len {
                    let hidden_t = &hidden_states[self.num_layers - 1][t + 1];
                    let output_t = output_weights.dot(hidden_t) + output_bias;
                    outputs.row_mut(t).assign(&output_t);
                }
                outputs
            }
        };

        Ok((predictions, hidden_states))
    }

    /// Compute a gate activation: `activation(W_x * input + W_h * hidden + b)`
    fn compute_gate(
        &self,
        input: &Array1<Float>,
        hidden: &Array1<Float>,
        input_weight: &Array2<Float>,
        hidden_weight: &Array2<Float>,
        bias: &Array1<Float>,
        activation: ActivationFunction,
    ) -> Array1<Float> {
        let linear = input_weight.dot(input) + hidden_weight.dot(hidden) + bias;
        activation.apply(&linear)
    }

    /// Compute the mean squared error between predictions and targets,
    /// averaged over all timesteps and output dimensions.
    fn compute_sequence_loss(&self, predictions: &Array2<Float>, targets: &Array2<Float>) -> Float {
        let diff = predictions - targets;
        diff.map(|x| x * x).mean().unwrap()
    }
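
    // Parameter updates in `backward_sequence` below use plain stochastic
    // gradient descent on the squared-error loss:
    //
    //     theta <- theta - learning_rate * dL/dtheta
    //
    // applied exactly to the output layer and, as a simplification, with the
    // output-layer error propagated directly to each recurrent layer's input
    // weights and biases (recurrent and gate parameters are not updated).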

    /// Backward pass through sequence (simplified BPTT)
    ///
    /// Only the output layer and each layer's input weights and biases are
    /// updated; the recurrent (hidden-to-hidden) and gate parameters are left
    /// unchanged by this simplified scheme.
    fn backward_sequence(
        &self,
        x_seq: &ArrayView2<'_, Float>,
        y_seq: &Array2<Float>,
        predictions: &Array2<Float>,
        hidden_states: &[Vec<Array1<Float>>],
        input_weights: &mut [Array2<Float>],
        _hidden_weights: &mut [Array2<Float>],
        biases: &mut [Array1<Float>],
        output_weights: &mut Array2<Float>,
        output_bias: &mut Array1<Float>,
        _gate_weights: &mut HashMap<String, Vec<Array2<Float>>>,
        _gate_biases: &mut HashMap<String, Vec<Array1<Float>>>,
    ) -> SklResult<()> {
        // Simplified gradient computation
        let (seq_len, _) = x_seq.dim();

        // Compute output gradients
        let output_error = predictions - y_seq;

        match self.sequence_mode {
            SequenceMode::ManyToMany => {
                for t in 0..seq_len {
                    let hidden_t = &hidden_states[self.num_layers - 1][t + 1];
                    let error_t = output_error.row(t).to_owned();

                    // Update output layer
                    let weight_grad = error_t
                        .clone()
                        .insert_axis(Axis(1))
                        .dot(&hidden_t.clone().insert_axis(Axis(0)));
                    *output_weights = output_weights.clone() - self.learning_rate * weight_grad;
                    *output_bias = output_bias.clone() - self.learning_rate * &error_t;

                    // Simplified hidden layer updates
                    for layer in (0..self.num_layers).rev() {
                        let x_t = if layer == 0 {
                            x_seq.row(t).to_owned()
                        } else {
                            hidden_states[layer - 1][t + 1].clone()
                        };

                        let hidden_error = output_weights.t().dot(&error_t);
                        let weight_grad = hidden_error
                            .clone()
                            .insert_axis(Axis(1))
                            .dot(&x_t.insert_axis(Axis(0)));

                        input_weights[layer] =
                            input_weights[layer].clone() - self.learning_rate * weight_grad;
                        biases[layer] = biases[layer].clone() - self.learning_rate * hidden_error;
                    }
                }
            }
            _ => {
                // Simplified update for other modes: treat the single output row
                // as the error at the final hidden state.
                let hidden_final = &hidden_states[self.num_layers - 1][seq_len];
                let error_final = output_error.row(0).to_owned();
                let weight_grad = error_final
                    .clone()
                    .insert_axis(Axis(1))
                    .dot(&hidden_final.clone().insert_axis(Axis(0)));
                *output_weights = output_weights.clone() - self.learning_rate * weight_grad;
                *output_bias = output_bias.clone() - self.learning_rate * error_final;
            }
        }

        Ok(())
    }
}

impl Predict<ArrayView3<'_, Float>, Array3<Float>>
    for RecurrentNeuralNetwork<RecurrentNeuralNetworkTrained>
{
    fn predict(&self, X: &ArrayView3<'_, Float>) -> SklResult<Array3<Float>> {
        let (n_samples, max_seq_len, n_features) = X.dim();

        if n_features != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "X has different number of features than training data".to_string(),
            ));
        }

        let mut predictions = match self.state.sequence_mode {
            SequenceMode::ManyToMany => {
                Array3::<Float>::zeros((n_samples, max_seq_len, self.state.n_outputs))
            }
            SequenceMode::ManyToOne => Array3::<Float>::zeros((n_samples, 1, self.state.n_outputs)),
            SequenceMode::OneToMany => {
                Array3::<Float>::zeros((n_samples, max_seq_len, self.state.n_outputs))
            }
        };

        // Process each sequence
        for sample_idx in 0..n_samples {
            let x_seq = X.slice(s![sample_idx, .., ..]);

            let (sample_predictions, _) = self.forward_sequence_trained(&x_seq)?;

            match self.state.sequence_mode {
                SequenceMode::ManyToMany | SequenceMode::OneToMany => {
                    for t in 0..sample_predictions.nrows() {
                        for j in 0..sample_predictions.ncols() {
                            predictions[[sample_idx, t, j]] = sample_predictions[[t, j]];
                        }
                    }
                }
                SequenceMode::ManyToOne => {
                    for j in 0..sample_predictions.ncols() {
                        predictions[[sample_idx, 0, j]] = sample_predictions[[0, j]];
                    }
                }
            }
        }

        Ok(predictions)
    }
}

impl RecurrentNeuralNetwork<RecurrentNeuralNetworkTrained> {
    /// Forward pass for trained model
    fn forward_sequence_trained(
        &self,
        x_seq: &ArrayView2<'_, Float>,
    ) -> SklResult<(Array2<Float>, Vec<Vec<Array1<Float>>>)> {
        let (seq_len, _) = x_seq.dim();
        let n_outputs = self.state.output_weights.nrows();

        // Initialize hidden states
        let mut hidden_states = Vec::new();
        for _ in 0..self.state.num_layers {
            hidden_states.push(vec![
                Array1::<Float>::zeros(self.state.hidden_size);
                seq_len + 1
            ]);
        }

        // Forward pass (simplified for prediction)
        for t in 0..seq_len {
            let x_t = x_seq.row(t);

            for layer in 0..self.state.num_layers {
                let input = if layer == 0 {
                    x_t.to_owned()
                } else {
                    hidden_states[layer - 1][t + 1].clone()
                };

                let prev_hidden = &hidden_states[layer][t];

                // Simplified cell computation for prediction
                let linear = self.state.input_weights[layer].dot(&input)
                    + self.state.hidden_weights[layer].dot(prev_hidden)
                    + &self.state.biases[layer];

                // NOTE: prediction applies the plain tanh recurrence for every
                // cell type; learned LSTM/GRU gate parameters are not used here.
                hidden_states[layer][t + 1] = match self.state.cell_type {
                    CellType::RNN => linear.map(|x| x.tanh()),
                    CellType::LSTM | CellType::GRU => linear.map(|x| x.tanh()), // Simplified
                };
            }
        }

        // Generate outputs
        let predictions = match self.state.sequence_mode {
            SequenceMode::ManyToMany => {
                let mut outputs = Array2::<Float>::zeros((seq_len, n_outputs));
                for t in 0..seq_len {
                    let last_layer_hidden = &hidden_states[self.state.num_layers - 1][t + 1];
                    let output_t =
                        self.state.output_weights.dot(last_layer_hidden) + &self.state.output_bias;
                    outputs.row_mut(t).assign(&output_t);
                }
                outputs
            }
            SequenceMode::ManyToOne => {
                let final_hidden = &hidden_states[self.state.num_layers - 1][seq_len];
                let output = self.state.output_weights.dot(final_hidden) + &self.state.output_bias;
                Array2::from_shape_vec((1, n_outputs), output.to_vec()).unwrap()
            }
            SequenceMode::OneToMany => {
                let mut outputs = Array2::<Float>::zeros((seq_len, n_outputs));
                for t in 0..seq_len {
                    let hidden_t = &hidden_states[self.state.num_layers - 1][t + 1];
                    let output_t =
                        self.state.output_weights.dot(hidden_t) + &self.state.output_bias;
                    outputs.row_mut(t).assign(&output_t);
                }
                outputs
            }
        };

        Ok((predictions, hidden_states))
    }

    /// Get the loss curve from training
    pub fn loss_curve(&self) -> &[Float] {
        &self.state.loss_curve
    }

    /// Get the number of training epochs actually run
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Get the cell type used by the network
    pub fn cell_type(&self) -> CellType {
        self.state.cell_type
    }

    /// Get the hidden layer size
    pub fn hidden_size(&self) -> usize {
        self.state.hidden_size
    }

    /// Get the sequence prediction mode
    pub fn sequence_mode(&self) -> SequenceMode {
        self.state.sequence_mode
    }
}
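
#[cfg(test)]
mod tests {
    // A minimal smoke-test sketch (not from the original source). It assumes
    // `Float` is `f64` and that `Array3` is available from
    // `scirs2_core::ndarray` as imported by this module.
    use super::*;
    use scirs2_core::ndarray::Array3;

    #[test]
    fn fit_predict_shapes_many_to_many() {
        // 4 sequences, 6 timesteps, 3 features -> 2 outputs per timestep
        let x = Array3::<Float>::zeros((4, 6, 3));
        let y = Array3::<Float>::zeros((4, 6, 2));

        let model = RecurrentNeuralNetwork::new()
            .cell_type(CellType::RNN)
            .hidden_size(8)
            .sequence_mode(SequenceMode::ManyToMany)
            .max_iter(5)
            .fit(&x.view(), &y)
            .expect("fit should succeed on well-shaped input");

        let preds = model.predict(&x.view()).expect("predict should succeed");
        assert_eq!(preds.dim(), (4, 6, 2));
        assert!(model.n_iter() >= 1 && model.n_iter() <= 5);
    }
}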