rust_lstm/layers/linear.rs

use ndarray::Array2;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::optimizers::Optimizer;

/// Holds gradients for linear layer parameters during backpropagation
#[derive(Clone, Debug)]
pub struct LinearGradients {
    pub weight: Array2<f64>,
    pub bias: Array2<f64>,
}

/// A fully connected (linear/dense) layer for neural networks
///
/// Performs the transformation: output = weight @ input + bias,
/// where weight has shape (output_size, input_size), input has shape (input_size, batch_size),
/// and bias has shape (output_size, 1) and is broadcast across the batch
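///
/// # Example
///
/// A minimal usage sketch (illustrative, assuming the layer and `ndarray` are in scope):
///
/// ```ignore
/// let mut layer = LinearLayer::new(3, 2);
/// let input = ndarray::Array2::<f64>::zeros((3, 4)); // (input_size, batch_size)
/// let output = layer.forward(&input);
/// assert_eq!(output.shape(), &[2, 4]);               // (output_size, batch_size)
/// ```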
#[derive(Clone, Debug)]
pub struct LinearLayer {
    pub weight: Array2<f64>,     // (output_size, input_size)
    pub bias: Array2<f64>,       // (output_size, 1)
    pub input_size: usize,
    pub output_size: usize,
    input_cache: Option<Array2<f64>>, // Cache input for backward pass
}

impl LinearLayer {
    /// Create a new linear layer with random initialization
    ///
    /// # Arguments
    /// * `input_size` - Size of input features
    /// * `output_size` - Size of output features
    ///
    /// # Returns
    /// * New LinearLayer with Xavier/Glorot initialization
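    ///
    /// Weights are drawn uniformly from the symmetric range [-s, s] with
    /// s = sqrt(2 / (input_size + output_size)), and biases start at zero;
    /// e.g. a 10 -> 5 layer gives s = sqrt(2/15), roughly 0.365.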
    pub fn new(input_size: usize, output_size: usize) -> Self {
        // Xavier/Glorot initialization: scale by sqrt(2 / (input_size + output_size))
        let scale = (2.0 / (input_size + output_size) as f64).sqrt();

        let weight = Array2::random((output_size, input_size), Uniform::new(-scale, scale));
        let bias = Array2::zeros((output_size, 1));

        Self {
            weight,
            bias,
            input_size,
            output_size,
            input_cache: None,
        }
    }

    /// Create a new linear layer with zero initialization
    pub fn new_zeros(input_size: usize, output_size: usize) -> Self {
        let weight = Array2::zeros((output_size, input_size));
        let bias = Array2::zeros((output_size, 1));

        Self {
            weight,
            bias,
            input_size,
            output_size,
            input_cache: None,
        }
    }

    /// Create a new linear layer with custom initialization
    pub fn from_weights(weight: Array2<f64>, bias: Array2<f64>) -> Self {
        let (output_size, input_size) = weight.dim();
        assert_eq!(bias.shape(), &[output_size, 1], "Bias shape must be (output_size, 1)");

        Self {
            weight,
            bias,
            input_size,
            output_size,
            input_cache: None,
        }
    }

    /// Forward pass through the linear layer
    ///
    /// # Arguments
    /// * `input` - Input tensor of shape (input_size, batch_size)
    ///
    /// # Returns
    /// * Output tensor of shape (output_size, batch_size)
    pub fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        let (input_features, _batch_size) = input.dim();
        assert_eq!(input_features, self.input_size,
                  "Input size {} doesn't match layer input size {}",
                  input_features, self.input_size);

        // Cache input for backward pass
        self.input_cache = Some(input.clone());

        // output = weight @ input + bias (bias broadcasts automatically)
        &self.weight.dot(input) + &self.bias
    }

    /// Backward pass through the linear layer
    ///
    /// # Arguments
    /// * `grad_output` - Gradient w.r.t. output of shape (output_size, batch_size)
    ///
    /// # Returns
    /// * Tuple of (gradients, input_gradient)
    ///   - gradients: LinearGradients containing weight and bias gradients
    ///   - input_gradient: Gradient w.r.t. input of shape (input_size, batch_size)
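    ///
    /// The gradient formulas (matching the code below) are:
    /// * weight gradient: grad_output @ input^T, shape (output_size, input_size)
    /// * bias gradient: grad_output summed over the batch axis, shape (output_size, 1)
    /// * input gradient: weight^T @ grad_output, shape (input_size, batch_size)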
    pub fn backward(&self, grad_output: &Array2<f64>) -> (LinearGradients, Array2<f64>) {
        let input = self.input_cache.as_ref().expect("Input cache not found for backward pass");
        let (output_features, batch_size) = grad_output.dim();
        let (input_features, input_batch_size) = input.dim();

        assert_eq!(output_features, self.output_size, "Gradient output size mismatch");
        assert_eq!(input_features, self.input_size, "Input size mismatch");
        assert_eq!(batch_size, input_batch_size, "Batch size mismatch");

        // Gradient w.r.t. weight: grad_output @ input^T
        let weight_grad = grad_output.dot(&input.t());

        // Gradient w.r.t. bias: sum over batch dimension, keep as column vector
        let bias_grad = grad_output.sum_axis(ndarray::Axis(1)).insert_axis(ndarray::Axis(1));

        // Gradient w.r.t. input: weight^T @ grad_output
        let input_grad = self.weight.t().dot(grad_output);

        let gradients = LinearGradients {
            weight: weight_grad,
            bias: bias_grad,
        };

        (gradients, input_grad)
    }

    /// Update parameters using the provided optimizer
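    ///
    /// Gradients are applied under the keys `"{prefix}_weight"` and `"{prefix}_bias"`, so
    /// giving each layer a distinct prefix presumably keeps any per-parameter optimizer
    /// state (e.g. momentum) separate.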
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &LinearGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_weight", prefix), &mut self.weight, &gradients.weight);
        optimizer.update(&format!("{}_bias", prefix), &mut self.bias, &gradients.bias);
    }

    /// Initialize zero gradients for accumulation
    pub fn zero_gradients(&self) -> LinearGradients {
        LinearGradients {
            weight: Array2::zeros(self.weight.raw_dim()),
            bias: Array2::zeros(self.bias.raw_dim()),
        }
    }

    /// Get the number of parameters in this layer
    pub fn num_parameters(&self) -> usize {
        self.weight.len() + self.bias.len()
    }

    /// Get layer dimensions as (input_size, output_size)
    pub fn dimensions(&self) -> (usize, usize) {
        (self.input_size, self.output_size)
    }

    /// Set the layer to training mode
    pub fn train(&mut self) {
        // No-op: a linear layer has no training-specific behavior (unlike e.g. dropout)
    }

    /// Set the layer to evaluation mode
    pub fn eval(&mut self) {
        // No-op: a linear layer has no evaluation-specific behavior
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;
    use crate::optimizers::SGD;

    #[test]
    fn test_linear_layer_creation() {
        let layer = LinearLayer::new(10, 5);
        assert_eq!(layer.input_size, 10);
        assert_eq!(layer.output_size, 5);
        assert_eq!(layer.weight.shape(), &[5, 10]);
        assert_eq!(layer.bias.shape(), &[5, 1]);
    }

    #[test]
    fn test_linear_layer_forward() {
        let mut layer = LinearLayer::new_zeros(3, 2);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]); // (3, 2)

        let output = layer.forward(&input);
        assert_eq!(output.shape(), &[2, 2]); // (output_size, batch_size)

        // With zero weights and bias, output should be zero
        assert!(output.iter().all(|&x| x == 0.0));
    }
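
    // Added illustrative check (not part of the original suite): with known weights the
    // forward pass output = weight @ input + bias can be verified by hand.
    #[test]
    fn test_linear_layer_forward_values() {
        let weight = arr2(&[[1.0, 2.0], [3.0, 4.0]]); // (2, 2)
        let bias = arr2(&[[0.5], [-0.5]]);            // (2, 1)
        let mut layer = LinearLayer::from_weights(weight, bias);

        let input = arr2(&[[1.0, 0.0], [1.0, 1.0]]); // (input_size, batch_size) = (2, 2)
        let output = layer.forward(&input);

        // Column 0: [1*1 + 2*1 + 0.5, 3*1 + 4*1 - 0.5] = [3.5, 6.5]
        // Column 1: [1*0 + 2*1 + 0.5, 3*0 + 4*1 - 0.5] = [2.5, 3.5]
        assert_eq!(output, arr2(&[[3.5, 2.5], [6.5, 3.5]]));
    }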

    #[test]
    fn test_linear_layer_backward() {
        let mut layer = LinearLayer::new(3, 2);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]); // (3, 2)
        let grad_output = arr2(&[[1.0, 1.0], [1.0, 1.0]]); // (2, 2)

        // Forward pass first to cache input
        let _output = layer.forward(&input);

        let (gradients, input_grad) = layer.backward(&grad_output);

        assert_eq!(gradients.weight.shape(), &[2, 3]);
        assert_eq!(gradients.bias.shape(), &[2, 1]);
        assert_eq!(input_grad.shape(), &[3, 2]);
    }
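
    // Added illustrative check (not part of the original suite): with known weights and an
    // all-ones upstream gradient, the backward formulas can be verified by hand:
    // dW = grad_output @ input^T, db = batch sum of grad_output, dx = weight^T @ grad_output.
    #[test]
    fn test_linear_layer_backward_values() {
        let weight = arr2(&[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); // (2, 3)
        let bias = arr2(&[[0.0], [0.0]]);                       // (2, 1)
        let mut layer = LinearLayer::from_weights(weight, bias);

        let input = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]); // (3, 2)
        let grad_output = arr2(&[[1.0, 1.0], [1.0, 1.0]]);       // (2, 2)

        let _output = layer.forward(&input);
        let (gradients, input_grad) = layer.backward(&grad_output);

        // dW[i][j] = sum over the batch of input[j] (same for every output row i)
        assert_eq!(gradients.weight, arr2(&[[3.0, 7.0, 11.0], [3.0, 7.0, 11.0]]));
        // db = per-row batch sum of the all-ones upstream gradient
        assert_eq!(gradients.bias, arr2(&[[2.0], [2.0]]));
        // dx = weight^T @ grad_output
        assert_eq!(input_grad, arr2(&[[5.0, 5.0], [7.0, 7.0], [9.0, 9.0]]));
    }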

    #[test]
    fn test_linear_layer_with_optimizer() {
        let mut layer = LinearLayer::new(2, 1);
        let mut optimizer = SGD::new(0.1);

        let input = arr2(&[[1.0], [2.0]]); // (2, 1)
        let target = arr2(&[[3.0]]); // (1, 1)

        // Forward pass
        let output = layer.forward(&input);

        // Simple loss gradient (output - target)
        let grad_output = &output - &target;

        // Backward pass
        let (gradients, _) = layer.backward(&grad_output);

        // Capture the parameters before the update so we can verify they actually change
        let weight_before = layer.weight.clone();
        let bias_before = layer.bias.clone();

        // Update parameters
        layer.update_parameters(&gradients, &mut optimizer, "linear");

        // Parameters should have changed from their initial values
        assert!(layer.weight != weight_before || layer.bias != bias_before);
    }

    #[test]
    fn test_linear_layer_dimensions() {
        let layer = LinearLayer::new(128, 10);
        assert_eq!(layer.dimensions(), (128, 10));
        assert_eq!(layer.num_parameters(), 128 * 10 + 10); // weights + bias
    }

    #[test]
    fn test_from_weights() {
        let weight = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let bias = arr2(&[[0.5], [-0.5]]);

        let layer = LinearLayer::from_weights(weight.clone(), bias.clone());
        assert_eq!(layer.weight, weight);
        assert_eq!(layer.bias, bias);
        assert_eq!(layer.input_size, 2);
        assert_eq!(layer.output_size, 2);
    }
}