scirs2_neural/layers/dense.rs

//! Dense (fully connected) layer implementation

use crate::activations_minimal::Activation;
use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Distribution, Uniform};
use std::fmt::Debug;

/// Dense (fully connected) layer for neural networks.
///
/// A dense layer computes `y = activation(x · W + b)`, where `W` is the `(input_dim, output_dim)`
/// weight matrix, `x` is the input, `b` is the bias vector, and `activation` is an optional
/// activation function.
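///
/// # Examples
///
/// A minimal usage sketch (marked `ignore`: the concrete RNG passed to `Dense::new` depends on
/// what `scirs2_core::random` provides, and the crate path is assumed to be `scirs2_neural`):
///
/// ```ignore
/// use scirs2_core::ndarray::{Array, IxDyn};
/// use scirs2_neural::layers::{Dense, Layer};
///
/// // `rng` is assumed to be any value satisfying the
/// // `scirs2_core::random::Rng + RngCore` bounds required by `Dense::new`.
/// let layer: Dense<f64> = Dense::new(4, 2, Some("relu"), &mut rng)?;
///
/// let input = Array::<f64, _>::zeros(IxDyn(&[1, 4]));
/// let output = layer.forward(&input)?;
/// assert_eq!(output.shape(), &[1, 2]);
/// ```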
pub struct Dense<F: Float + Debug + Send + Sync> {
    /// Number of input features
    input_dim: usize,
    /// Number of output features
    output_dim: usize,
    /// Weight matrix
    weights: Array<F, IxDyn>,
    /// Bias vector
    biases: Array<F, IxDyn>,
    /// Gradient of the weights
    dweights: std::sync::RwLock<Array<F, IxDyn>>,
    /// Gradient of the biases
    dbiases: std::sync::RwLock<Array<F, IxDyn>>,
    /// Activation function, if any
    activation: Option<Box<dyn Activation<F> + Send + Sync>>,
    /// Input from the forward pass, needed in backward pass
    input: std::sync::RwLock<Option<Array<F, IxDyn>>>,
    /// Output before activation, needed in backward pass
    output_pre_activation: std::sync::RwLock<Option<Array<F, IxDyn>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> std::fmt::Debug for Dense<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dense")
            .field("input_dim", &self.input_dim)
            .field("output_dim", &self.output_dim)
            .field("weights_shape", &self.weights.shape())
            .field("biases_shape", &self.biases.shape())
            .field("has_activation", &self.activation.is_some())
            .finish()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Clone for Dense<F> {
    fn clone(&self) -> Self {
        Self {
            input_dim: self.input_dim,
            output_dim: self.output_dim,
            weights: self.weights.clone(),
            biases: self.biases.clone(),
            dweights: std::sync::RwLock::new(self.dweights.read().unwrap().clone()),
            dbiases: std::sync::RwLock::new(self.dbiases.read().unwrap().clone()),
            // We can't clone trait objects, so we skip the activation
            activation: None,
            input: std::sync::RwLock::new(self.input.read().unwrap().clone()),
            output_pre_activation: std::sync::RwLock::new(
                self.output_pre_activation.read().unwrap().clone(),
            ),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Dense<F> {
    /// Create a new dense layer.
    ///
    /// # Arguments
    /// * `input_dim` - Number of input features
    /// * `output_dim` - Number of output features
    /// * `activation_name` - Optional activation function name
    /// * `rng` - Random number generator for weight initialization
    pub fn new<R: scirs2_core::random::Rng + scirs2_core::random::RngCore>(
        input_dim: usize,
        output_dim: usize,
        activation_name: Option<&str>,
        rng: &mut R,
    ) -> Result<Self> {
        // Create the activation function from its name, if one was given
        let activation = if let Some(name) = activation_name {
            match name.to_lowercase().as_str() {
                "relu" => Some(Box::new(crate::activations_minimal::ReLU::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                "sigmoid" => Some(Box::new(crate::activations_minimal::Sigmoid::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                "tanh" => Some(Box::new(crate::activations_minimal::Tanh::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                "softmax" => Some(Box::new(crate::activations_minimal::Softmax::new(-1))
                    as Box<dyn Activation<F> + Send + Sync>),
                "gelu" => Some(Box::new(crate::activations_minimal::GELU::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                _ => None,
            }
        } else {
            None
        };

        // Initialize weights with uniform values scaled by 1/sqrt(input_dim)
        let scale = F::from(1.0 / f64::sqrt(input_dim as f64)).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
        })?;

        // Create a 2D weights array
        let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
            NeuralError::InvalidArchitecture(format!("Failed to create uniform distribution: {e}"))
        })?;
        let weights_vec: Vec<F> = (0..(input_dim * output_dim))
            .map(|_| {
                let val = F::from(uniform.sample(rng)).ok_or_else(|| {
                    NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
                });
                val.map(|v| v * scale).unwrap_or_else(|_| F::zero())
            })
            .collect();

        let weights =
            Array::from_shape_vec(IxDyn(&[input_dim, output_dim]), weights_vec).map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
            })?;

        // Initialize biases with zeros
        let biases = Array::zeros(IxDyn(&[output_dim]));

        // Initialize gradient arrays with zeros
        let dweights = std::sync::RwLock::new(Array::zeros(weights.dim()));
        let dbiases = std::sync::RwLock::new(Array::zeros(biases.dim()));

        Ok(Self {
            input_dim,
            output_dim,
            weights,
            biases,
            dweights,
            dbiases,
            activation,
            input: std::sync::RwLock::new(None),
            output_pre_activation: std::sync::RwLock::new(None),
        })
    }

    /// Get the input dimension
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Get the output dimension
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

    /// Simple matrix multiplication for forward pass
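    ///
    /// For each sample `b` in the batch this computes
    /// `output[b, o] = sum_i(input[b, i] * weights[i, o]) + biases[o]`.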
    fn compute_forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let batch_size = input.shape()[0];
        let mut output = Array::zeros(IxDyn(&[batch_size, self.output_dim]));

        // Matrix multiplication: output = input @ weights
        for batch in 0..batch_size {
            for out_idx in 0..self.output_dim {
                let mut sum = F::zero();
                for in_idx in 0..self.input_dim {
                    sum = sum + input[[batch, in_idx]] * self.weights[[in_idx, out_idx]];
                }
                // Add bias
                output[[batch, out_idx]] = sum + self.biases[out_idx];
            }
        }

        Ok(output)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dense<F> {
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Cache input for backward pass
        {
            let mut input_cache = self.input.write().unwrap();
            *input_cache = Some(input.clone());
        }

        // Ensure input is 2D
        let input_2d = if input.ndim() == 1 {
            input
                .clone()
                .into_shape_with_order(IxDyn(&[1, self.input_dim]))
                .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {e}")))?
        } else {
            input.clone()
        };

        // Validate input dimensions
        if input_2d.shape()[1] != self.input_dim {
            return Err(NeuralError::InvalidArgument(format!(
                "Input dimension mismatch: expected {}, got {}",
                self.input_dim,
                input_2d.shape()[1]
            )));
        }

        // Compute linear transformation
        let output = self.compute_forward(&input_2d)?;

        // Cache pre-activation output
        {
            let mut pre_activation_cache = self.output_pre_activation.write().unwrap();
            *pre_activation_cache = Some(output.clone());
        }

        // Apply activation function if present
        if let Some(ref activation) = self.activation {
            activation.forward(&output)
        } else {
            Ok(output)
        }
    }

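    // Backward pass for y = activation(x @ W + b). With g = dL/d(pre-activation), the loops
    // below compute:
    //   dW = x.T @ g          (input_dim x output_dim)
    //   db = sum_batch(g)     (output_dim)
    //   dx = g @ W.T          (batch_size x input_dim)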
    fn backward(
        &self,
        _input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Get cached data
        let cached_input = {
            let cache = self.input.read().unwrap();
            cache.clone().ok_or_else(|| {
                NeuralError::InferenceError("No cached input for backward pass".to_string())
            })?
        };

        let pre_activation = {
            let cache = self.output_pre_activation.read().unwrap();
            cache.clone().ok_or_else(|| {
                NeuralError::InferenceError(
                    "No cached pre-activation output for backward pass".to_string(),
                )
            })?
        };

        // Apply activation gradient if present
        let grad_pre_activation = if let Some(ref activation) = self.activation {
            activation.backward(grad_output, &pre_activation)?
        } else {
            grad_output.clone()
        };

        // Ensure gradients are 2D
        let grad_2d = if grad_pre_activation.ndim() == 1 {
            grad_pre_activation
                .into_shape_with_order(IxDyn(&[1, self.output_dim]))
                .map_err(|e| {
                    NeuralError::InferenceError(format!("Failed to reshape gradient: {e}"))
                })?
        } else {
            grad_pre_activation
        };

        let input_2d = if cached_input.ndim() == 1 {
            cached_input
                .into_shape_with_order(IxDyn(&[1, self.input_dim]))
                .map_err(|e| {
                    NeuralError::InferenceError(format!("Failed to reshape cached input: {e}"))
                })?
        } else {
            cached_input
        };

        let batch_size = grad_2d.shape()[0];

        // Compute weight gradients: dW = input.T @ grad_output
        let mut dweights = Array::zeros(IxDyn(&[self.input_dim, self.output_dim]));
        for i in 0..self.input_dim {
            for j in 0..self.output_dim {
                let mut sum = F::zero();
                for b in 0..batch_size {
                    sum = sum + input_2d[[b, i]] * grad_2d[[b, j]];
                }
                dweights[[i, j]] = sum;
            }
        }

        // Compute bias gradients: db = sum(grad_output, axis=0)
        let mut dbiases = Array::zeros(IxDyn(&[self.output_dim]));
        for j in 0..self.output_dim {
            let mut sum = F::zero();
            for b in 0..batch_size {
                sum = sum + grad_2d[[b, j]];
            }
            dbiases[j] = sum;
        }

        // Update internal gradients
        {
            let mut dweights_guard = self.dweights.write().unwrap();
            *dweights_guard = dweights;
        }
        {
            let mut dbiases_guard = self.dbiases.write().unwrap();
            *dbiases_guard = dbiases;
        }

        // Compute gradient w.r.t. the input: grad_input = grad_output @ weights.T
        let mut grad_input = Array::zeros(IxDyn(&[batch_size, self.input_dim]));
        for b in 0..batch_size {
            for i in 0..self.input_dim {
                let mut sum = F::zero();
                for j in 0..self.output_dim {
                    sum = sum + grad_2d[[b, j]] * self.weights[[i, j]];
                }
                grad_input[[b, i]] = sum;
            }
        }

        Ok(grad_input)
    }

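    // Plain gradient-descent step using the gradients cached by `backward`:
    //   W <- W - learning_rate * dW,   b <- b - learning_rate * db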
    fn update(&mut self, learning_rate: F) -> Result<()> {
        let dweights = {
            let dweights_guard = self.dweights.read().unwrap();
            dweights_guard.clone()
        };
        let dbiases = {
            let dbiases_guard = self.dbiases.read().unwrap();
            dbiases_guard.clone()
        };

        // Update weights and biases using gradient descent
        for i in 0..self.input_dim {
            for j in 0..self.output_dim {
                self.weights[[i, j]] = self.weights[[i, j]] - learning_rate * dweights[[i, j]];
            }
        }

        for j in 0..self.output_dim {
            self.biases[j] = self.biases[j] - learning_rate * dbiases[j];
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "Dense"
    }

    fn parameter_count(&self) -> usize {
        self.weights.len() + self.biases.len()
    }

    fn layer_description(&self) -> String {
        format!(
            "type:Dense, input_dim:{}, output_dim:{}, params:{}",
            self.input_dim,
            self.output_dim,
            self.parameter_count()
        )
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for Dense<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![self.weights.clone(), self.biases.clone()]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Return clones of the gradients accumulated behind the RwLocks
        vec![
            self.dweights.read().unwrap().clone(),
            self.dbiases.read().unwrap().clone(),
        ]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters (weights, biases), got {}",
                params.len()
            )));
        }

        let weights = &params[0];
        let biases = &params[1];

        if weights.shape() != self.weights.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Weights shape mismatch: expected {:?}, got {:?}",
                self.weights.shape(),
                weights.shape()
            )));
        }

        if biases.shape() != self.biases.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Biases shape mismatch: expected {:?}, got {:?}",
                self.biases.shape(),
                biases.shape()
            )));
        }

        self.weights = weights.clone();
        self.biases = biases.clone();

        Ok(())
    }
}

// Explicit Send + Sync implementations for Dense layer
unsafe impl<F: Float + Debug + Send + Sync> Send for Dense<F> {}
unsafe impl<F: Float + Debug + Send + Sync> Sync for Dense<F> {}
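
// A minimal test sketch: it constructs the layer directly from its module-private fields, so it
// does not depend on any particular RNG type from `scirs2_core::random`, and checks `forward`
// against a hand-computed `x @ W + b`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_matches_manual_matmul() {
        // 2 inputs -> 3 outputs, no activation, hand-picked weights and biases.
        let weights =
            Array::from_shape_vec(IxDyn(&[2, 3]), vec![1.0, 0.0, -1.0, 2.0, 1.0, 0.5]).unwrap();
        let biases = Array::from_shape_vec(IxDyn(&[3]), vec![0.1, 0.2, 0.3]).unwrap();
        let layer = Dense::<f64> {
            input_dim: 2,
            output_dim: 3,
            dweights: std::sync::RwLock::new(Array::zeros(weights.dim())),
            dbiases: std::sync::RwLock::new(Array::zeros(biases.dim())),
            weights,
            biases,
            activation: None,
            input: std::sync::RwLock::new(None),
            output_pre_activation: std::sync::RwLock::new(None),
        };

        // x = [1.0, 2.0]  =>  y[0] = 1*1 + 2*2 + 0.1 = 5.1
        let x = Array::from_shape_vec(IxDyn(&[1, 2]), vec![1.0, 2.0]).unwrap();
        let y = layer.forward(&x).unwrap();
        assert_eq!(y.shape(), &[1, 3]);
        assert!((y[[0, 0]] - 5.1_f64).abs() < 1e-12);
    }
}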