scirs2_neural/models/sequential.rs

//! Sequential model implementation
//!
//! This module provides a sequential model implementation that chains
//! layers together in a linear sequence.
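//!
//! # Examples
//!
//! A minimal sketch of building and running a model. `Dense::new(..)` and
//! `input` are placeholders standing in for a real layer constructor from
//! `crate::layers` and an `ArrayD<f64>` batch, so the example is marked
//! `ignore` and is not compiled as a doctest.
//!
//! ```ignore
//! let mut model: Sequential<f64> = Sequential::new();
//! model.add_layer(Dense::new(4, 8));
//! model.add_layer(Dense::new(8, 1));
//!
//! // Forward pass through all layers in insertion order.
//! let output = model.forward(&input)?;
//! ```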

use ndarray::{Array, ScalarOperand};
use num_traits::Float;
use std::fmt::Debug;

use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use crate::losses::Loss;
use crate::models::Model;
use crate::optimizers::Optimizer;

/// A sequential model that chains layers together in a linear sequence
pub struct Sequential<F: Float + Debug + ScalarOperand + 'static> {
    /// Layers applied in insertion order
    layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
    /// Per-layer outputs cached by the most recent call to `train_batch`
    layer_outputs: Vec<Array<F, ndarray::IxDyn>>,
    /// Input cached by the most recent call to `train_batch`
    input: Option<Array<F, ndarray::IxDyn>>,
}

impl<F: Float + Debug + ScalarOperand + 'static> Default for Sequential<F> {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Clone for Sequential<F> {
    fn clone(&self) -> Self {
        // Boxed `Layer` trait objects cannot be cloned, so this returns a new
        // empty model rather than a copy of the layers. The Clone impl exists
        // mainly for testing.
        Sequential {
            layers: Vec::new(),
            layer_outputs: Vec::new(),
            input: None,
        }
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Sequential<F> {
    /// Create a new empty sequential model
    pub fn new() -> Self {
        Sequential {
            layers: Vec::new(),
            layer_outputs: Vec::new(),
            input: None,
        }
    }

    /// Create a new sequential model from existing layers
    pub fn from_layers(layers: Vec<Box<dyn Layer<F> + Send + Sync>>) -> Self {
        Sequential {
            layers,
            layer_outputs: Vec::new(),
            input: None,
        }
    }

    /// Add a layer to the model
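    ///
    /// Returns `&mut Self`, so calls can be chained. A minimal sketch, assuming
    /// `layer_a` and `layer_b` are values of types implementing
    /// `Layer<F> + Send + Sync` (constructed elsewhere), hence `ignore`:
    ///
    /// ```ignore
    /// let mut model: Sequential<f64> = Sequential::new();
    /// model.add_layer(layer_a).add_layer(layer_b);
    /// assert_eq!(model.num_layers(), 2);
    /// ```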
    pub fn add_layer<L: Layer<F> + 'static + Send + Sync>(&mut self, layer: L) -> &mut Self {
        self.layers.push(Box::new(layer));
        self
    }

    /// Get the number of layers in the model
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Get the layers in the model
    pub fn layers(&self) -> &[Box<dyn Layer<F> + Send + Sync>] {
        &self.layers
    }

    /// Get a mutable reference to the layers in the model
    pub fn layers_mut(&mut self) -> &mut Vec<Box<dyn Layer<F> + Send + Sync>> {
        &mut self.layers
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Model<F> for Sequential<F> {
    fn forward(&self, input: &Array<F, ndarray::IxDyn>) -> Result<Array<F, ndarray::IxDyn>> {
        // Run the input through each layer in order. Intermediate outputs are
        // not cached here; `backward` relies on the caches populated by
        // `train_batch`.
        let mut current_output = input.clone();

        for layer in &self.layers {
            current_output = layer.forward(&current_output)?;
        }

        Ok(current_output)
    }

    fn backward(
        &self,
        input: &Array<F, ndarray::IxDyn>,
        grad_output: &Array<F, ndarray::IxDyn>,
    ) -> Result<Array<F, ndarray::IxDyn>> {
        if self.layer_outputs.is_empty() {
            return Err(NeuralError::InferenceError(
                "No forward pass performed before backward pass".to_string(),
            ));
        }

        let mut grad_input = grad_output.clone();

        // Iterate through layers in reverse
        for (i, layer) in self.layers.iter().enumerate().rev() {
            // Get the input for this layer: the previous layer's output, or the
            // original input for the first layer
            let layer_input = if i > 0 {
                &self.layer_outputs[i - 1]
            } else if let Some(saved_input) = &self.input {
                saved_input
            } else {
                // Fall back to the provided input if nothing is saved
                input
            };

            grad_input = layer.backward(layer_input, &grad_input)?;
        }

        Ok(grad_input)
    }

    fn update(&mut self, learning_rate: F) -> Result<()> {
        for layer in &mut self.layers {
            layer.update(learning_rate)?;
        }

        Ok(())
    }

    fn train_batch(
        &mut self,
        inputs: &Array<F, ndarray::IxDyn>,
        targets: &Array<F, ndarray::IxDyn>,
        loss_fn: &dyn Loss<F>,
        optimizer: &mut dyn Optimizer<F>,
    ) -> Result<F> {
        // Forward pass, recording each layer's output
        let mut layer_outputs = Vec::with_capacity(self.layers.len());
        let mut current_output = inputs.clone();

        for layer in &self.layers {
            current_output = layer.forward(&current_output)?;
            layer_outputs.push(current_output.clone());
        }

        // Save the input and per-layer outputs for the backward pass
        self.input = Some(inputs.clone());
        self.layer_outputs = layer_outputs;

        // Compute loss
        let predictions = self
            .layer_outputs
            .last()
            .ok_or_else(|| NeuralError::InferenceError("No layers in model".to_string()))?;
        let loss = loss_fn.forward(predictions, targets)?;

        // Backward pass to compute gradients
        let loss_grad = loss_fn.backward(predictions, targets)?;

        let mut grad_input = loss_grad;

        // Iterate through layers in reverse
        for (i, layer) in self.layers.iter_mut().enumerate().rev() {
            // Get the input for this layer: the previous layer's output, or the
            // original batch input for the first layer
            let layer_input = if i > 0 {
                &self.layer_outputs[i - 1]
            } else {
                inputs
            };

            grad_input = layer.backward(layer_input, &grad_input)?;
        }

        // Update parameters using the optimizer
        let mut all_params = Vec::new();
        let mut all_grads = Vec::new();
        let mut param_layers = Vec::new();

        // First, collect all parameters and gradients. `downcast_ref` only
        // matches the concrete type behind `as_any()`, so this branch applies
        // only to layers whose concrete type is
        // `Box<dyn ParamLayer<F> + Send + Sync>`; layers of other concrete
        // types are skipped here.
        for (i, layer) in self.layers.iter().enumerate() {
            if let Some(param_layer) = layer
                .as_any()
                .downcast_ref::<Box<dyn ParamLayer<F> + Send + Sync>>()
            {
                param_layers.push(i);

                for param in param_layer.get_parameters() {
                    all_params.push(param.clone());
                }

                for grad in param_layer.get_gradients() {
                    all_grads.push(grad.clone());
                }
            }
        }

        // Apply the optimizer step to the collected parameters
        optimizer.update(&mut all_params, &all_grads)?;

        // Write the optimized parameters back into the matching layers
        let mut param_idx = 0;
        for i in param_layers {
            if let Some(param_layer) = self.layers[i]
                .as_any_mut()
                .downcast_mut::<Box<dyn ParamLayer<F> + Send + Sync>>()
            {
                let num_params = param_layer.get_parameters().len();
                if param_idx + num_params <= all_params.len() {
                    let layer_params = all_params[param_idx..param_idx + num_params].to_vec();
                    param_layer.set_parameters(layer_params)?;
                    param_idx += num_params;
                }
            }
        }

        Ok(loss)
    }

    fn predict(&self, inputs: &Array<F, ndarray::IxDyn>) -> Result<Array<F, ndarray::IxDyn>> {
        self.forward(inputs)
    }

    fn evaluate(
        &self,
        inputs: &Array<F, ndarray::IxDyn>,
        targets: &Array<F, ndarray::IxDyn>,
        loss_fn: &dyn Loss<F>,
    ) -> Result<F> {
        let predictions = self.forward(inputs)?;
        loss_fn.forward(&predictions, targets)
    }
}
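
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal sanity check using only the constructors defined in this file.
    // Exercising forward/backward would require concrete layers from
    // `crate::layers` and is out of scope for this sketch; it only verifies
    // the empty-model invariants.
    #[test]
    fn empty_model_has_no_layers() {
        let model: Sequential<f64> = Sequential::new();
        assert_eq!(model.num_layers(), 0);

        let from_empty = Sequential::<f64>::from_layers(Vec::new());
        assert_eq!(from_empty.num_layers(), 0);
    }
}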