ruvector_cnn/layers/linear.rs

//! Linear (Fully Connected) layer implementation.
//!
//! Implements a standard linear transformation: y = xW^T + b
5use super::{Layer, TensorShape};
6use crate::error::{CnnError, CnnResult};
7use crate::Tensor;
8
/// Linear (Fully Connected) layer.
///
/// Performs the operation: output = input @ weight^T + bias
///
/// Weights are stored as one flat row-major buffer: row `o` holds the
/// `in_features` weights that produce output feature `o`.
#[derive(Clone, Debug)]
pub struct Linear {
    /// Input features
    in_features: usize,
    /// Output features
    out_features: usize,
    /// Weight matrix [out_features, in_features], flattened row-major
    weight: Vec<f32>,
    /// Bias vector [out_features], None for no bias
    bias: Option<Vec<f32>>,
}
23
24impl Linear {
25    /// Creates a new Linear layer with zero-initialized weights.
26    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> CnnResult<Self> {
27        if in_features == 0 || out_features == 0 {
28            return Err(CnnError::InvalidParameter(
29                "Features must be > 0".to_string(),
30            ));
31        }
32
33        let weight = vec![0.0; out_features * in_features];
34        let bias = if use_bias {
35            Some(vec![0.0; out_features])
36        } else {
37            None
38        };
39
40        Ok(Self {
41            in_features,
42            out_features,
43            weight,
44            bias,
45        })
46    }
47
48    /// Creates a Linear layer with provided weights.
49    pub fn with_weights(
50        in_features: usize,
51        out_features: usize,
52        weight: Vec<f32>,
53        bias: Option<Vec<f32>>,
54    ) -> CnnResult<Self> {
55        if weight.len() != out_features * in_features {
56            return Err(CnnError::dim_mismatch(
57                out_features * in_features,
58                weight.len(),
59            ));
60        }
61
62        if let Some(ref b) = bias {
63            if b.len() != out_features {
64                return Err(CnnError::dim_mismatch(out_features, b.len()));
65            }
66        }
67
68        Ok(Self {
69            in_features,
70            out_features,
71            weight,
72            bias,
73        })
74    }
75
76    /// Returns the input features.
77    pub fn in_features(&self) -> usize {
78        self.in_features
79    }
80
81    /// Returns the output features.
82    pub fn out_features(&self) -> usize {
83        self.out_features
84    }
85
86    /// Returns a reference to the weight matrix.
87    pub fn weight(&self) -> &[f32] {
88        &self.weight
89    }
90
91    /// Returns a reference to the bias vector.
92    pub fn bias(&self) -> Option<&[f32]> {
93        self.bias.as_deref()
94    }
95
96    /// Sets the weight matrix.
97    pub fn set_weight(&mut self, weight: Vec<f32>) -> CnnResult<()> {
98        if weight.len() != self.out_features * self.in_features {
99            return Err(CnnError::dim_mismatch(
100                self.out_features * self.in_features,
101                weight.len(),
102            ));
103        }
104        self.weight = weight;
105        Ok(())
106    }
107
108    /// Sets the bias vector.
109    pub fn set_bias(&mut self, bias: Option<Vec<f32>>) -> CnnResult<()> {
110        if let Some(ref b) = bias {
111            if b.len() != self.out_features {
112                return Err(CnnError::dim_mismatch(self.out_features, b.len()));
113            }
114        }
115        self.bias = bias;
116        Ok(())
117    }
118
119    /// Forward pass for a single input vector.
120    pub fn forward_vec(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
121        if input.len() != self.in_features {
122            return Err(CnnError::dim_mismatch(self.in_features, input.len()));
123        }
124
125        let mut output = vec![0.0; self.out_features];
126
127        // output = input @ weight^T
128        for o in 0..self.out_features {
129            let mut sum = 0.0f32;
130            for i in 0..self.in_features {
131                sum += input[i] * self.weight[o * self.in_features + i];
132            }
133            if let Some(ref bias) = self.bias {
134                sum += bias[o];
135            }
136            output[o] = sum;
137        }
138
139        Ok(output)
140    }
141
142    /// Forward pass for a batch of input vectors.
143    pub fn forward_batch(&self, input: &[f32], batch_size: usize) -> CnnResult<Vec<f32>> {
144        if input.len() != batch_size * self.in_features {
145            return Err(CnnError::dim_mismatch(
146                batch_size * self.in_features,
147                input.len(),
148            ));
149        }
150
151        let mut output = vec![0.0; batch_size * self.out_features];
152
153        for n in 0..batch_size {
154            let input_offset = n * self.in_features;
155            let output_offset = n * self.out_features;
156
157            for o in 0..self.out_features {
158                let mut sum = 0.0f32;
159                for i in 0..self.in_features {
160                    sum += input[input_offset + i] * self.weight[o * self.in_features + i];
161                }
162                if let Some(ref bias) = self.bias {
163                    sum += bias[o];
164                }
165                output[output_offset + o] = sum;
166            }
167        }
168
169        Ok(output)
170    }
171}
172
173impl Layer for Linear {
174    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
175        let shape = input.shape();
176        // For linear, flatten all dimensions except batch
177        let batch_size = if shape.is_empty() { 1 } else { shape[0] };
178        let features = input.numel() / batch_size;
179
180        if features != self.in_features {
181            return Err(CnnError::dim_mismatch(self.in_features, features));
182        }
183
184        let output_data = self.forward_batch(input.data(), batch_size)?;
185        let out_shape = vec![batch_size, self.out_features];
186        Tensor::from_data(output_data, &out_shape)
187    }
188
189    fn name(&self) -> &'static str {
190        "Linear"
191    }
192
193    fn num_params(&self) -> usize {
194        let weight_params = self.out_features * self.in_features;
195        let bias_params = if self.bias.is_some() {
196            self.out_features
197        } else {
198            0
199        };
200        weight_params + bias_params
201    }
202}
203
204impl Linear {
205    /// Returns the output TensorShape for a given input TensorShape
206    pub fn output_shape(&self, input_shape: &TensorShape) -> TensorShape {
207        TensorShape {
208            n: input_shape.n,
209            c: self.out_features,
210            h: 1,
211            w: 1,
212        }
213    }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison with the tolerance used throughout these tests.
    fn close(a: f32, b: f32) -> bool {
        (a - b).abs() < 1e-6
    }

    #[test]
    fn test_linear_creation() {
        let layer = Linear::new(512, 1000, true).unwrap();
        assert_eq!(layer.in_features(), 512);
        assert_eq!(layer.out_features(), 1000);
        assert!(layer.bias().is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        assert!(Linear::new(512, 1000, false).unwrap().bias().is_none());
    }

    #[test]
    fn test_linear_forward_identity() {
        // Identity weights with a zero bias must pass the input through.
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let layer = Linear::with_weights(2, 2, identity, Some(vec![0.0, 0.0])).unwrap();

        let out = layer.forward_vec(&[1.0, 2.0]).unwrap();

        assert!(close(out[0], 1.0));
        assert!(close(out[1], 2.0));
    }

    #[test]
    fn test_linear_forward_with_bias() {
        let layer =
            Linear::with_weights(2, 2, vec![1.0, 0.0, 0.0, 1.0], Some(vec![5.0, 10.0])).unwrap();

        let out = layer.forward_vec(&[1.0, 2.0]).unwrap();

        assert!(close(out[0], 6.0));
        assert!(close(out[1], 12.0));
    }

    #[test]
    fn test_linear_forward_batch() {
        let layer = Linear::with_weights(2, 2, vec![1.0, 0.0, 0.0, 1.0], None).unwrap();

        // Two samples packed contiguously; identity weights echo the input.
        let out = layer.forward_batch(&[1.0, 2.0, 3.0, 4.0], 2).unwrap();

        for (got, want) in out.iter().zip([1.0, 2.0, 3.0, 4.0]) {
            assert!(close(*got, want));
        }
    }

    #[test]
    fn test_linear_num_params() {
        let layer = Linear::new(512, 1000, true).unwrap();
        assert_eq!(layer.num_params(), 512 * 1000 + 1000);
    }

    #[test]
    fn test_linear_output_shape() {
        let layer = Linear::new(576, 1024, true).unwrap();
        let shape = layer.output_shape(&TensorShape::new(2, 576, 1, 1));

        assert_eq!(shape.n, 2);
        assert_eq!(shape.c, 1024);
        assert_eq!(shape.h, 1);
        assert_eq!(shape.w, 1);
    }
}