use ndarray::Array2;
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;
use crate::optimizers::Optimizer;

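/// Gradients of a [`LinearLayer`]'s parameters, as produced by
/// [`LinearLayer::backward`]. Shapes mirror the parameters they belong to:
/// `weight` is `(output_size, input_size)` and `bias` is `(output_size, 1)`.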
#[derive(Clone, Debug)]
pub struct LinearGradients {
    pub weight: Array2<f64>,
    pub bias: Array2<f64>,
}

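/// A fully connected layer computing `y = W·x + b`.
///
/// Data is column-major: an input of shape `(input_size, batch_size)` holds
/// one sample per column, and the bias (shape `(output_size, 1)`) is
/// broadcast across columns. `forward` caches its input so that `backward`
/// can compute parameter gradients.
///
/// A minimal sketch of one training step (hypothetical usage; `input` and
/// `grad_output` are assumed to be suitably shaped `Array2<f64>` values):
///
/// ```ignore
/// let mut layer = LinearLayer::new(3, 2);
/// let mut optimizer = SGD::new(0.1);
///
/// let output = layer.forward(&input);               // input: (3, batch)
/// let (grads, _grad_input) = layer.backward(&grad_output);
/// layer.update_parameters(&grads, &mut optimizer, "layer1");
/// ```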
#[derive(Clone, Debug)]
pub struct LinearLayer {
    pub weight: Array2<f64>,
    pub bias: Array2<f64>,
    pub input_size: usize,
    pub output_size: usize,
    /// Input cached by `forward` for use in `backward`.
    input_cache: Option<Array2<f64>>,
}

impl LinearLayer {
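    /// Creates a layer with uniform random weights in `[-r, r]`, where
    /// `r = sqrt(2 / (input_size + output_size))` (a Xavier/Glorot-style
    /// scale based on fan-in and fan-out), and zero biases.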
    pub fn new(input_size: usize, output_size: usize) -> Self {
        // Xavier/Glorot-style range based on fan-in and fan-out.
        let weight_range = (2.0 / (input_size + output_size) as f64).sqrt();

        let weight = Array2::random(
            (output_size, input_size),
            Uniform::new(-weight_range, weight_range),
        );
        let bias = Array2::zeros((output_size, 1));

        Self {
            weight,
            bias,
            input_size,
            output_size,
            input_cache: None,
        }
    }

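    /// Creates a layer with all-zero weights and biases, e.g. for tests
    /// or for layers whose parameters are filled in later.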
    pub fn new_zeros(input_size: usize, output_size: usize) -> Self {
        let weight = Array2::zeros((output_size, input_size));
        let bias = Array2::zeros((output_size, 1));

        Self {
            weight,
            bias,
            input_size,
            output_size,
            input_cache: None,
        }
    }

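    /// Builds a layer from existing parameters. Dimensions are inferred
    /// from `weight`, which must have shape `(output_size, input_size)`;
    /// panics if `bias` does not have shape `(output_size, 1)`.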
    pub fn from_weights(weight: Array2<f64>, bias: Array2<f64>) -> Self {
        let (output_size, input_size) = weight.dim();
        assert_eq!(bias.shape(), &[output_size, 1], "Bias shape must be (output_size, 1)");

        Self {
            weight,
            bias,
            input_size,
            output_size,
            input_cache: None,
        }
    }

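    /// Computes `W·input + b`, broadcasting the bias across the batch.
    ///
    /// `input` must have shape `(input_size, batch_size)`; the result has
    /// shape `(output_size, batch_size)`. The input is cached for the
    /// subsequent `backward` call.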
    pub fn forward(&mut self, input: &Array2<f64>) -> Array2<f64> {
        let (input_features, _batch_size) = input.dim();
        assert_eq!(input_features, self.input_size,
            "Input size {} doesn't match layer input size {}",
            input_features, self.input_size);

        self.input_cache = Some(input.clone());

        &self.weight.dot(input) + &self.bias
    }

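    /// Backpropagates `grad_output` (shape `(output_size, batch_size)`)
    /// through the layer, using the input cached by `forward`:
    ///
    /// * `dL/dW = grad_output · inputᵀ`
    /// * `dL/db = grad_output` summed over the batch axis
    /// * `dL/dinput = Wᵀ · grad_output`
    ///
    /// Returns the parameter gradients together with the gradient with
    /// respect to the input. Panics if `forward` has not been called first.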
    pub fn backward(&self, grad_output: &Array2<f64>) -> (LinearGradients, Array2<f64>) {
        let input = self.input_cache.as_ref().expect("Input cache not found for backward pass");
        let (output_features, batch_size) = grad_output.dim();
        let (input_features, input_batch_size) = input.dim();

        assert_eq!(output_features, self.output_size, "Gradient output size mismatch");
        assert_eq!(input_features, self.input_size, "Input size mismatch");
        assert_eq!(batch_size, input_batch_size, "Batch size mismatch");

        // dL/dW = dL/dy · xᵀ
        let weight_grad = grad_output.dot(&input.t());

        // dL/db: sum dL/dy over the batch axis, keeping shape (output_size, 1).
        let bias_grad = grad_output.sum_axis(ndarray::Axis(1)).insert_axis(ndarray::Axis(1));

        // dL/dx = Wᵀ · dL/dy
        let input_grad = self.weight.t().dot(grad_output);

        let gradients = LinearGradients {
            weight: weight_grad,
            bias: bias_grad,
        };

        (gradients, input_grad)
    }

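    /// Applies `gradients` to the layer's parameters via `optimizer`.
    /// `prefix` namespaces the parameter keys ("{prefix}_weight",
    /// "{prefix}_bias") so that stateful optimizers can tell this layer's
    /// parameters apart from others'.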
    pub fn update_parameters<O: Optimizer>(&mut self, gradients: &LinearGradients, optimizer: &mut O, prefix: &str) {
        optimizer.update(&format!("{}_weight", prefix), &mut self.weight, &gradients.weight);
        optimizer.update(&format!("{}_bias", prefix), &mut self.bias, &gradients.bias);
    }

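    /// Returns a zero-filled `LinearGradients` matching this layer's
    /// parameter shapes, e.g. as a starting point for gradient accumulation.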
    pub fn zero_gradients(&self) -> LinearGradients {
        LinearGradients {
            weight: Array2::zeros(self.weight.raw_dim()),
            bias: Array2::zeros(self.bias.raw_dim()),
        }
    }

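    /// Total number of trainable parameters (weights plus biases).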
    pub fn num_parameters(&self) -> usize {
        self.weight.len() + self.bias.len()
    }

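    /// Returns `(input_size, output_size)`.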
    pub fn dimensions(&self) -> (usize, usize) {
        (self.input_size, self.output_size)
    }

    /// No-op: a linear layer has no mode-dependent behavior (such as
    /// dropout), so switching to training mode changes nothing.
    pub fn train(&mut self) {}

    /// No-op counterpart to `train`; evaluation mode likewise changes nothing.
    pub fn eval(&mut self) {}
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;
    use crate::optimizers::SGD;

    #[test]
    fn test_linear_layer_creation() {
        let layer = LinearLayer::new(10, 5);
        assert_eq!(layer.input_size, 10);
        assert_eq!(layer.output_size, 5);
        assert_eq!(layer.weight.shape(), &[5, 10]);
        assert_eq!(layer.bias.shape(), &[5, 1]);
    }

    #[test]
    fn test_linear_layer_forward() {
        let mut layer = LinearLayer::new_zeros(3, 2);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let output = layer.forward(&input);

        // Zero weights and biases must produce an all-zero output.
        assert_eq!(output.shape(), &[2, 2]);
        assert!(output.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_linear_layer_backward() {
        let mut layer = LinearLayer::new(3, 2);
        let input = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let grad_output = arr2(&[[1.0, 1.0], [1.0, 1.0]]);
        let _output = layer.forward(&input);

        let (gradients, input_grad) = layer.backward(&grad_output);

        assert_eq!(gradients.weight.shape(), &[2, 3]);
        assert_eq!(gradients.bias.shape(), &[2, 1]);
        assert_eq!(input_grad.shape(), &[3, 2]);
    }

    #[test]
    fn test_linear_layer_with_optimizer() {
        let mut layer = LinearLayer::new(2, 1);
        let mut optimizer = SGD::new(0.1);

        let input = arr2(&[[1.0], [2.0]]);
        let target = arr2(&[[3.0]]);
        let output = layer.forward(&input);

        // Gradient of 0.5 * (output - target)^2 with respect to output.
        let grad_output = &output - &target;

        let (gradients, _) = layer.backward(&grad_output);

        layer.update_parameters(&gradients, &mut optimizer, "linear");

        // Parameters should (almost surely) be nonzero after the update.
        assert!(layer.weight.iter().any(|&x| x != 0.0) || layer.bias.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn test_linear_layer_dimensions() {
        let layer = LinearLayer::new(128, 10);
        assert_eq!(layer.dimensions(), (128, 10));
        assert_eq!(layer.num_parameters(), 128 * 10 + 10);
    }

    #[test]
    fn test_from_weights() {
        let weight = arr2(&[[1.0, 2.0], [3.0, 4.0]]);
        let bias = arr2(&[[0.5], [-0.5]]);

        let layer = LinearLayer::from_weights(weight.clone(), bias.clone());
        assert_eq!(layer.weight, weight);
        assert_eq!(layer.bias, bias);
        assert_eq!(layer.input_size, 2);
        assert_eq!(layer.output_size, 2);
    }
}