1use super::{Layer, TensorShape};
6use crate::error::{CnnError, CnnResult};
7use crate::Tensor;
8
/// A fully connected (dense) layer computing `y = W * x (+ b)`.
#[derive(Clone, Debug)]
pub struct Linear {
    /// Number of input features per sample.
    in_features: usize,
    /// Number of output features per sample.
    out_features: usize,
    /// Weight matrix in row-major order, logical shape `[out_features, in_features]`
    /// (row `o` starts at index `o * in_features`).
    weight: Vec<f32>,
    /// Optional per-output bias, length `out_features`.
    bias: Option<Vec<f32>>,
}
23
24impl Linear {
25 pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> CnnResult<Self> {
27 if in_features == 0 || out_features == 0 {
28 return Err(CnnError::InvalidParameter(
29 "Features must be > 0".to_string(),
30 ));
31 }
32
33 let weight = vec![0.0; out_features * in_features];
34 let bias = if use_bias {
35 Some(vec![0.0; out_features])
36 } else {
37 None
38 };
39
40 Ok(Self {
41 in_features,
42 out_features,
43 weight,
44 bias,
45 })
46 }
47
48 pub fn with_weights(
50 in_features: usize,
51 out_features: usize,
52 weight: Vec<f32>,
53 bias: Option<Vec<f32>>,
54 ) -> CnnResult<Self> {
55 if weight.len() != out_features * in_features {
56 return Err(CnnError::dim_mismatch(
57 out_features * in_features,
58 weight.len(),
59 ));
60 }
61
62 if let Some(ref b) = bias {
63 if b.len() != out_features {
64 return Err(CnnError::dim_mismatch(out_features, b.len()));
65 }
66 }
67
68 Ok(Self {
69 in_features,
70 out_features,
71 weight,
72 bias,
73 })
74 }
75
76 pub fn in_features(&self) -> usize {
78 self.in_features
79 }
80
81 pub fn out_features(&self) -> usize {
83 self.out_features
84 }
85
86 pub fn weight(&self) -> &[f32] {
88 &self.weight
89 }
90
91 pub fn bias(&self) -> Option<&[f32]> {
93 self.bias.as_deref()
94 }
95
96 pub fn set_weight(&mut self, weight: Vec<f32>) -> CnnResult<()> {
98 if weight.len() != self.out_features * self.in_features {
99 return Err(CnnError::dim_mismatch(
100 self.out_features * self.in_features,
101 weight.len(),
102 ));
103 }
104 self.weight = weight;
105 Ok(())
106 }
107
108 pub fn set_bias(&mut self, bias: Option<Vec<f32>>) -> CnnResult<()> {
110 if let Some(ref b) = bias {
111 if b.len() != self.out_features {
112 return Err(CnnError::dim_mismatch(self.out_features, b.len()));
113 }
114 }
115 self.bias = bias;
116 Ok(())
117 }
118
119 pub fn forward_vec(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
121 if input.len() != self.in_features {
122 return Err(CnnError::dim_mismatch(self.in_features, input.len()));
123 }
124
125 let mut output = vec![0.0; self.out_features];
126
127 for o in 0..self.out_features {
129 let mut sum = 0.0f32;
130 for i in 0..self.in_features {
131 sum += input[i] * self.weight[o * self.in_features + i];
132 }
133 if let Some(ref bias) = self.bias {
134 sum += bias[o];
135 }
136 output[o] = sum;
137 }
138
139 Ok(output)
140 }
141
142 pub fn forward_batch(&self, input: &[f32], batch_size: usize) -> CnnResult<Vec<f32>> {
144 if input.len() != batch_size * self.in_features {
145 return Err(CnnError::dim_mismatch(
146 batch_size * self.in_features,
147 input.len(),
148 ));
149 }
150
151 let mut output = vec![0.0; batch_size * self.out_features];
152
153 for n in 0..batch_size {
154 let input_offset = n * self.in_features;
155 let output_offset = n * self.out_features;
156
157 for o in 0..self.out_features {
158 let mut sum = 0.0f32;
159 for i in 0..self.in_features {
160 sum += input[input_offset + i] * self.weight[o * self.in_features + i];
161 }
162 if let Some(ref bias) = self.bias {
163 sum += bias[o];
164 }
165 output[output_offset + o] = sum;
166 }
167 }
168
169 Ok(output)
170 }
171}
172
173impl Layer for Linear {
174 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
175 let shape = input.shape();
176 let batch_size = if shape.is_empty() { 1 } else { shape[0] };
178 let features = input.numel() / batch_size;
179
180 if features != self.in_features {
181 return Err(CnnError::dim_mismatch(self.in_features, features));
182 }
183
184 let output_data = self.forward_batch(input.data(), batch_size)?;
185 let out_shape = vec![batch_size, self.out_features];
186 Tensor::from_data(output_data, &out_shape)
187 }
188
189 fn name(&self) -> &'static str {
190 "Linear"
191 }
192
193 fn num_params(&self) -> usize {
194 let weight_params = self.out_features * self.in_features;
195 let bias_params = if self.bias.is_some() {
196 self.out_features
197 } else {
198 0
199 };
200 weight_params + bias_params
201 }
202}
203
204impl Linear {
205 pub fn output_shape(&self, input_shape: &TensorShape) -> TensorShape {
207 TensorShape {
208 n: input_shape.n,
209 c: self.out_features,
210 h: 1,
211 w: 1,
212 }
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree within a small absolute tolerance.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-6);
    }

    #[test]
    fn test_linear_creation() {
        let layer = Linear::new(512, 1000, true).unwrap();
        assert_eq!(layer.in_features(), 512);
        assert_eq!(layer.out_features(), 1000);
        assert!(layer.bias().is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let layer = Linear::new(512, 1000, false).unwrap();
        assert!(layer.bias().is_none());
    }

    #[test]
    fn test_linear_forward_identity() {
        // 2x2 identity weights with a zero bias leave the input unchanged.
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let layer = Linear::with_weights(2, 2, identity, Some(vec![0.0, 0.0])).unwrap();

        let out = layer.forward_vec(&[1.0, 2.0]).unwrap();

        assert_close(out[0], 1.0);
        assert_close(out[1], 2.0);
    }

    #[test]
    fn test_linear_forward_with_bias() {
        // Identity weights plus a bias: output should be input + bias.
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let layer = Linear::with_weights(2, 2, identity, Some(vec![5.0, 10.0])).unwrap();

        let out = layer.forward_vec(&[1.0, 2.0]).unwrap();

        assert_close(out[0], 6.0);
        assert_close(out[1], 12.0);
    }

    #[test]
    fn test_linear_forward_batch() {
        // Identity weights, no bias: each sample passes through unchanged.
        let identity = vec![1.0, 0.0, 0.0, 1.0];
        let layer = Linear::with_weights(2, 2, identity, None).unwrap();

        let batch = vec![1.0, 2.0, 3.0, 4.0];
        let out = layer.forward_batch(&batch, 2).unwrap();

        for (i, expected) in [1.0, 2.0, 3.0, 4.0].into_iter().enumerate() {
            assert_close(out[i], expected);
        }
    }

    #[test]
    fn test_linear_num_params() {
        let layer = Linear::new(512, 1000, true).unwrap();
        assert_eq!(layer.num_params(), 512 * 1000 + 1000);
    }

    #[test]
    fn test_linear_output_shape() {
        let layer = Linear::new(576, 1024, true).unwrap();
        let out = layer.output_shape(&TensorShape::new(2, 576, 1, 1));

        assert_eq!(out.n, 2);
        assert_eq!(out.c, 1024);
        assert_eq!(out.h, 1);
        assert_eq!(out.w, 1);
    }
}