scirs2_neural/layers/mod.rs

//! Neural network layers implementation
//!
//! This module provides implementations of various neural network layers
//! such as dense (fully connected), attention, convolution, pooling, etc.
//! Layers are the fundamental building blocks of neural networks.

use crate::error::Result;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::Float;
use std::fmt::Debug;

/// Base trait for neural network layers
///
/// This trait defines the core interface that all neural network layers must implement.
/// It supports forward propagation, backpropagation, parameter management, and
/// training/evaluation mode switching.
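///
/// # Examples
///
/// A minimal pass-through layer, shown as an illustrative sketch; the
/// `Identity` type and the import paths are assumptions for the example and
/// not part of this crate's public API:
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
/// use scirs2_core::numeric::Float;
/// use scirs2_neural::error::Result;
/// use scirs2_neural::layers::Layer;
/// use std::fmt::Debug;
///
/// struct Identity;
///
/// impl<F: Float + Debug + ScalarOperand> Layer<F> for Identity {
///     fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
///         // The identity function returns its input unchanged.
///         Ok(input.clone())
///     }
///
///     fn backward(
///         &self,
///         _input: &Array<F, IxDyn>,
///         grad_output: &Array<F, IxDyn>,
///     ) -> Result<Array<F, IxDyn>> {
///         // The Jacobian of the identity function is I, so the incoming
///         // gradient passes through unchanged.
///         Ok(grad_output.clone())
///     }
///
///     fn update(&mut self, _learningrate: F) -> Result<()> {
///         Ok(()) // no trainable parameters
///     }
///
///     fn as_any(&self) -> &dyn std::any::Any {
///         self
///     }
///
///     fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
///         self
///     }
///
///     fn layer_type(&self) -> &str {
///         "Identity"
///     }
/// }
/// ```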
pub trait Layer<F: Float + Debug + ScalarOperand>: Send + Sync {
    /// Forward pass of the layer
    ///
    /// Computes the output of the layer given an input tensor.
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Backward pass of the layer to compute gradients
    ///
    /// Computes gradients with respect to the layer's input, which is needed
    /// for backpropagation.
    fn backward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Update the layer parameters with the given learning rate
    fn update(&mut self, learningrate: F) -> Result<()>;

    /// Get the layer as a dyn Any for downcasting
    fn as_any(&self) -> &dyn std::any::Any;

    /// Get the layer as a mutable dyn Any for downcasting
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;

    /// Get the parameters of the layer
    fn params(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        Vec::new()
    }

    /// Get the gradients of the layer parameters
    fn gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        Vec::new()
    }

    /// Set the gradients of the layer parameters
    fn set_gradients(
        &mut self,
        _gradients: &[Array<F, scirs2_core::ndarray::IxDyn>],
    ) -> Result<()> {
        Ok(())
    }

    /// Set the parameters of the layer
    fn set_params(&mut self, _params: &[Array<F, scirs2_core::ndarray::IxDyn>]) -> Result<()> {
        Ok(())
    }

    /// Set the layer to training mode (true) or evaluation mode (false)
    fn set_training(&mut self, _training: bool) {
        // Default implementation: do nothing
    }

    /// Get the current training mode
    fn is_training(&self) -> bool {
        true // Default implementation: always in training mode
    }

    /// Get the type of the layer (e.g., "Dense", "Conv2D")
    fn layer_type(&self) -> &str {
        "Unknown"
    }

    /// Get the number of trainable parameters in this layer
    fn parameter_count(&self) -> usize {
        0
    }

    /// Get a detailed description of this layer
    fn layer_description(&self) -> String {
        format!("type:{}", self.layer_type())
    }

    /// Get the input shape if known
    fn inputshape(&self) -> Option<Vec<usize>> {
        None
    }

    /// Get the output shape if known
    fn outputshape(&self) -> Option<Vec<usize>> {
        None
    }

    /// Get the name of the layer if set
    fn name(&self) -> Option<&str> {
        None
    }
}

/// Trait for layers with parameters (weights, biases)
pub trait ParamLayer<F: Float + Debug + ScalarOperand>: Layer<F> {
    /// Get the parameters of the layer as a vector of arrays
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Get the gradients of the parameters
    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>>;

    /// Set the parameters
    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()>;
}

/// Information about a layer for visualization purposes
#[derive(Debug, Clone)]
pub struct LayerInfo {
    /// Index of the layer in the sequence
    pub index: usize,
    /// Name of the layer
    pub name: String,
    /// Type of the layer
    pub layer_type: String,
    /// Number of parameters in the layer
    pub parameter_count: usize,
    /// Input shape of the layer
    pub inputshape: Option<Vec<usize>>,
    /// Output shape of the layer
    pub outputshape: Option<Vec<usize>>,
}

/// Sequential container for neural network layers
///
/// A Sequential model is a linear stack of layers where data flows through
/// each layer in order.
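///
/// # Examples
///
/// A sketch of building and running a model; `dense_layer_1`, `dense_layer_2`,
/// and `input` are placeholders for concrete layers and data:
///
/// ```ignore
/// let mut model: Sequential<f32> = Sequential::new();
/// model.add(dense_layer_1); // any value implementing `Layer<f32>`
/// model.add(dense_layer_2);
///
/// // Data flows through the layers in insertion order.
/// let output = model.forward(&input)?;
/// ```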
pub struct Sequential<F: Float + Debug + ScalarOperand> {
    layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
    training: bool,
}

impl<F: Float + Debug + ScalarOperand> std::fmt::Debug for Sequential<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequential")
            .field("num_layers", &self.layers.len())
            .field("training", &self.training)
            .finish()
    }
}

impl<F: Float + Debug + ScalarOperand + 'static> Clone for Sequential<F> {
    fn clone(&self) -> Self {
        // Boxed `dyn Layer` trait objects cannot be cloned, so the clone
        // starts with no layers and preserves only the training flag.
        Self {
            layers: Vec::new(),
            training: self.training,
        }
    }
}

impl<F: Float + Debug + ScalarOperand> Default for Sequential<F> {
    fn default() -> Self {
        Self::new()
    }
}

impl<F: Float + Debug + ScalarOperand> Sequential<F> {
    /// Create a new Sequential container
    pub fn new() -> Self {
        Self {
            layers: Vec::new(),
            training: true,
        }
    }

    /// Add a layer to the container
    pub fn add<L: Layer<F> + Send + Sync + 'static>(&mut self, layer: L) {
        self.layers.push(Box::new(layer));
    }

    /// Get the number of layers
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Check if there are no layers
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Get total parameter count across all layers
    pub fn total_parameters(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.parameter_count())
            .sum()
    }

    /// Get layer information for visualization purposes
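    ///
    /// # Examples
    ///
    /// A sketch of printing a model summary, assuming `model` is a populated
    /// `Sequential`:
    ///
    /// ```ignore
    /// for info in model.layer_info() {
    ///     println!(
    ///         "{}: {} ({} params)",
    ///         info.index, info.layer_type, info.parameter_count
    ///     );
    /// }
    /// ```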
    pub fn layer_info(&self) -> Vec<LayerInfo> {
        self.layers
            .iter()
            .enumerate()
            .map(|(i, layer)| LayerInfo {
                index: i,
                name: layer
                    .name()
                    .map_or_else(|| format!("Layer_{i}"), |name| name.to_string()),
                layer_type: layer.layer_type().to_string(),
                parameter_count: layer.parameter_count(),
                inputshape: layer.inputshape(),
                outputshape: layer.outputshape(),
            })
            .collect()
    }
}

impl<F: Float + Debug + ScalarOperand> Layer<F> for Sequential<F> {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output)?;
        }
        Ok(output)
    }

    fn backward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Recompute each layer's input, then propagate the gradient through the layers in reverse.
        let mut layer_inputs = Vec::with_capacity(self.layers.len());
        let mut current = input.clone();
        for layer in &self.layers {
            layer_inputs.push(current.clone());
            current = layer.forward(&current)?;
        }
        let mut grad = grad_output.clone();
        for (layer, layer_input) in self.layers.iter().zip(layer_inputs.iter()).rev() {
            grad = layer.backward(layer_input, &grad)?;
        }
        Ok(grad)
    }

    fn update(&mut self, learningrate: F) -> Result<()> {
        for layer in &mut self.layers {
            layer.update(learningrate)?;
        }
        Ok(())
    }

    fn params(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        let mut params = Vec::new();
        for layer in &self.layers {
            params.extend(layer.params());
        }
        params
    }

    fn set_training(&mut self, training: bool) {
        self.training = training;
        for layer in &mut self.layers {
            layer.set_training(training);
        }
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "Sequential"
    }

    fn parameter_count(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.parameter_count())
            .sum()
    }
}

/// Configuration enum for different types of layers
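///
/// # Examples
///
/// Constructing a configuration for a dense layer (a usage sketch):
///
/// ```ignore
/// let config = LayerConfig::Dense {
///     input_size: 128,
///     output_size: 64,
///     activation: Some("relu".to_string()),
/// };
/// ```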
#[derive(Debug, Clone)]
pub enum LayerConfig {
    /// Dense (fully connected) layer
    Dense {
        /// Number of input features
        input_size: usize,
        /// Number of output features
        output_size: usize,
        /// Optional activation function name (e.g., "relu")
        activation: Option<String>,
    },
    /// 2D Convolutional layer
    Conv2D {
        /// Number of input channels
        in_channels: usize,
        /// Number of output channels
        out_channels: usize,
        /// Kernel size of the convolution
        kernel_size: (usize, usize),
    },
    /// Dropout layer
    Dropout {
        /// Dropout rate (fraction of elements dropped)
        rate: f64,
    },
}

// Fixed modules
pub mod conv;
pub mod dense;
pub mod dropout;
pub mod normalization;
pub mod recurrent;

// Layer modules temporarily disabled until they are fixed
// mod attention;
// mod embedding;
// mod regularization;

// Re-export fixed modules
pub use conv::Conv2D;
pub use dense::Dense;
pub use dropout::Dropout;
pub use normalization::{BatchNorm, LayerNorm};
pub use recurrent::LSTM;

// Re-exports will be added as modules are fixed
// pub use attention::{AttentionConfig, AttentionMask, MultiHeadAttention, SelfAttention};