1use crate::{simd, CnnError, CnnResult, Tensor};
8
9use super::{Layer, TensorShape};
10
/// A 2D convolution layer operating on NHWC (`[batch, h, w, channels]`) tensors.
#[derive(Debug, Clone)]
pub struct Conv2d {
    // Number of input channels the layer expects (NHWC last axis).
    in_channels: usize,
    // Number of output channels (filters).
    out_channels: usize,
    // Side length of the square kernel.
    kernel_size: usize,
    // Step between successive kernel applications.
    stride: usize,
    // Zero-padding added to each spatial border.
    padding: usize,
    // Number of groups for grouped convolution (1 = standard conv).
    groups: usize,
    // Flat kernel weights, laid out as
    // [out_channel][kh][kw][in_channel_within_group]
    // (see the indexing in `conv_generic`).
    weights: Vec<f32>,
    // Optional per-output-channel bias, length `out_channels` when present.
    bias: Option<Vec<f32>>,
}
39
/// Builder for [`Conv2d`] with validation of the group configuration.
#[derive(Debug, Clone)]
pub struct Conv2dBuilder {
    // Required: input channel count.
    in_channels: usize,
    // Required: output channel (filter) count.
    out_channels: usize,
    // Required: square kernel side length.
    kernel_size: usize,
    // Defaults to 1.
    stride: usize,
    // Defaults to 0 (no padding).
    padding: usize,
    // Defaults to 1 (standard, non-grouped convolution).
    groups: usize,
    // Defaults to true: build with a zero-initialized bias vector.
    bias: bool,
}
51
52impl Conv2dBuilder {
53 pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
55 Self {
56 in_channels,
57 out_channels,
58 kernel_size,
59 stride: 1,
60 padding: 0,
61 groups: 1,
62 bias: true,
63 }
64 }
65
66 pub fn stride(mut self, stride: usize) -> Self {
68 self.stride = stride;
69 self
70 }
71
72 pub fn padding(mut self, padding: usize) -> Self {
74 self.padding = padding;
75 self
76 }
77
78 pub fn groups(mut self, groups: usize) -> Self {
80 self.groups = groups;
81 self
82 }
83
84 pub fn bias(mut self, bias: bool) -> Self {
86 self.bias = bias;
87 self
88 }
89
90 pub fn build(self) -> CnnResult<Conv2d> {
92 if self.in_channels % self.groups != 0 {
93 return Err(CnnError::InvalidParameter(
94 format!("in_channels {} must be divisible by groups {}", self.in_channels, self.groups)
95 ));
96 }
97 if self.out_channels % self.groups != 0 {
98 return Err(CnnError::InvalidParameter(
99 format!("out_channels {} must be divisible by groups {}", self.out_channels, self.groups)
100 ));
101 }
102
103 let in_channels_per_group = self.in_channels / self.groups;
104 let num_weights = self.out_channels * self.kernel_size * self.kernel_size * in_channels_per_group;
105
106 let fan_in = in_channels_per_group * self.kernel_size * self.kernel_size;
108 let fan_out = (self.out_channels / self.groups) * self.kernel_size * self.kernel_size;
109 let std_dev = (2.0 / (fan_in + fan_out) as f32).sqrt();
110
111 let weights: Vec<f32> = (0..num_weights)
112 .map(|i| {
113 let x = ((i * 1103515245 + 12345) % (1 << 31)) as f32 / (1u32 << 31) as f32;
114 (x * 2.0 - 1.0) * std_dev
115 })
116 .collect();
117
118 let bias = if self.bias {
119 Some(vec![0.0; self.out_channels])
120 } else {
121 None
122 };
123
124 Ok(Conv2d {
125 in_channels: self.in_channels,
126 out_channels: self.out_channels,
127 kernel_size: self.kernel_size,
128 stride: self.stride,
129 padding: self.padding,
130 groups: self.groups,
131 weights,
132 bias,
133 })
134 }
135}
136
137impl Conv2d {
138 pub fn new(
140 in_channels: usize,
141 out_channels: usize,
142 kernel_size: usize,
143 stride: usize,
144 padding: usize,
145 ) -> Self {
146 let num_weights = out_channels * kernel_size * kernel_size * in_channels;
147
148 let fan_in = in_channels * kernel_size * kernel_size;
150 let fan_out = out_channels * kernel_size * kernel_size;
151 let std_dev = (2.0 / (fan_in + fan_out) as f32).sqrt();
152
153 let weights: Vec<f32> = (0..num_weights)
155 .map(|i| {
156 let x = ((i * 1103515245 + 12345) % (1 << 31)) as f32 / (1u32 << 31) as f32;
157 (x * 2.0 - 1.0) * std_dev
158 })
159 .collect();
160
161 Self {
162 in_channels,
163 out_channels,
164 kernel_size,
165 stride,
166 padding,
167 groups: 1,
168 weights,
169 bias: None,
170 }
171 }
172
173 pub fn builder(in_channels: usize, out_channels: usize, kernel_size: usize) -> Conv2dBuilder {
175 Conv2dBuilder::new(in_channels, out_channels, kernel_size)
176 }
177
178 pub fn with_bias(
180 in_channels: usize,
181 out_channels: usize,
182 kernel_size: usize,
183 stride: usize,
184 padding: usize,
185 ) -> Self {
186 let mut conv = Self::new(in_channels, out_channels, kernel_size, stride, padding);
187 conv.bias = Some(vec![0.0; out_channels]);
188 conv
189 }
190
191 pub fn output_shape_nchw(&self, input_shape: &TensorShape) -> TensorShape {
193 let out_h = (input_shape.h + 2 * self.padding - self.kernel_size) / self.stride + 1;
194 let out_w = (input_shape.w + 2 * self.padding - self.kernel_size) / self.stride + 1;
195 TensorShape::new(input_shape.n, self.out_channels, out_h, out_w)
196 }
197
198 pub fn set_weights(&mut self, weights: Vec<f32>) -> CnnResult<()> {
200 let expected = self.out_channels * self.kernel_size * self.kernel_size * self.in_channels;
201 if weights.len() != expected {
202 return Err(CnnError::invalid_shape(
203 format!("{} weights", expected),
204 format!("{} weights", weights.len()),
205 ));
206 }
207 self.weights = weights;
208 Ok(())
209 }
210
211 pub fn set_bias(&mut self, bias: Vec<f32>) -> CnnResult<()> {
213 if bias.len() != self.out_channels {
214 return Err(CnnError::invalid_shape(
215 format!("{} bias values", self.out_channels),
216 format!("{} bias values", bias.len()),
217 ));
218 }
219 self.bias = Some(bias);
220 Ok(())
221 }
222
223 pub fn output_shape(&self, input_shape: &[usize]) -> CnnResult<Vec<usize>> {
225 if input_shape.len() != 4 {
226 return Err(CnnError::invalid_shape(
227 "4D tensor (NHWC)",
228 format!("{}D tensor", input_shape.len()),
229 ));
230 }
231
232 let batch = input_shape[0];
233 let in_h = input_shape[1];
234 let in_w = input_shape[2];
235
236 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
237 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
238
239 Ok(vec![batch, out_h, out_w, self.out_channels])
240 }
241
242 pub fn weights(&self) -> &[f32] {
244 &self.weights
245 }
246
247 pub fn bias(&self) -> Option<&[f32]> {
249 self.bias.as_deref()
250 }
251
252 pub fn kernel_size(&self) -> usize {
254 self.kernel_size
255 }
256
257 pub fn stride(&self) -> usize {
259 self.stride
260 }
261
262 pub fn padding(&self) -> usize {
264 self.padding
265 }
266
267 pub fn out_channels(&self) -> usize {
269 self.out_channels
270 }
271
272 pub fn in_channels(&self) -> usize {
274 self.in_channels
275 }
276
277 pub fn groups(&self) -> usize {
279 self.groups
280 }
281}
282
283impl Layer for Conv2d {
284 fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
285 let shape = input.shape();
286 if shape.len() != 4 {
287 return Err(CnnError::invalid_shape(
288 "4D tensor (NHWC)",
289 format!("{}D tensor", shape.len()),
290 ));
291 }
292
293 let in_channels = shape[3];
294 if in_channels != self.in_channels {
295 return Err(CnnError::invalid_shape(
296 format!("{} input channels", self.in_channels),
297 format!("{} input channels", in_channels),
298 ));
299 }
300
301 let batch = shape[0];
302 let in_h = shape[1];
303 let in_w = shape[2];
304
305 let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
306 let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
307
308 let out_shape = vec![batch, out_h, out_w, self.out_channels];
309 let mut output = Tensor::zeros(&out_shape);
310
311 let batch_in_size = in_h * in_w * in_channels;
313 let batch_out_size = out_h * out_w * self.out_channels;
314
315 for b in 0..batch {
316 let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
317 let output_slice = &mut output.data_mut()[b * batch_out_size..(b + 1) * batch_out_size];
318
319 if self.kernel_size == 3 && self.groups == 1 {
320 simd::conv_3x3_simd(
322 input_slice,
323 &self.weights,
324 output_slice,
325 in_h,
326 in_w,
327 self.in_channels,
328 self.out_channels,
329 self.stride,
330 self.padding,
331 );
332 } else if self.kernel_size == 3 && self.groups == self.in_channels && self.in_channels == self.out_channels {
333 simd::depthwise_conv_3x3_simd(
335 input_slice,
336 &self.weights,
337 output_slice,
338 in_h,
339 in_w,
340 self.in_channels,
341 self.stride,
342 self.padding,
343 );
344 } else {
345 self.conv_generic(input_slice, output_slice, in_h, in_w, out_h, out_w);
347 }
348 }
349
350 if let Some(bias) = &self.bias {
352 for val in output.data_mut().chunks_mut(self.out_channels) {
353 for (i, v) in val.iter_mut().enumerate() {
354 *v += bias[i];
355 }
356 }
357 }
358
359 Ok(output)
360 }
361
362 fn name(&self) -> &'static str {
363 "Conv2d"
364 }
365
366 fn num_params(&self) -> usize {
367 let weight_params =
368 self.out_channels * self.kernel_size * self.kernel_size * self.in_channels;
369 let bias_params = if self.bias.is_some() {
370 self.out_channels
371 } else {
372 0
373 };
374 weight_params + bias_params
375 }
376}
377
impl Conv2d {
    /// Scalar fallback convolution for one batch element (NHWC layout).
    ///
    /// Handles arbitrary kernel sizes and grouped convolution; the SIMD
    /// fast paths in `forward` cover the common 3x3 cases instead.
    /// `input` is `in_h * in_w * in_channels` long; `output` is
    /// `out_h * out_w * out_channels` long and fully overwritten.
    fn conv_generic(
        &self,
        input: &[f32],
        output: &mut [f32],
        in_h: usize,
        in_w: usize,
        out_h: usize,
        out_w: usize,
    ) {
        let ks = self.kernel_size;
        let in_channels_per_group = self.in_channels / self.groups;
        let out_channels_per_group = self.out_channels / self.groups;

        for oh in 0..out_h {
            for ow in 0..out_w {
                for g in 0..self.groups {
                    // Each group connects only its own contiguous slice of
                    // input channels to its own slice of output channels.
                    let in_c_start = g * in_channels_per_group;
                    let out_c_start = g * out_channels_per_group;

                    for oc_local in 0..out_channels_per_group {
                        let oc = out_c_start + oc_local;
                        let mut sum = 0.0f32;

                        for kh in 0..ks {
                            for kw in 0..ks {
                                // Map output position back to input coords;
                                // signed math lets padding reach "outside"
                                // the input.
                                let ih = (oh * self.stride + kh) as isize - self.padding as isize;
                                let iw = (ow * self.stride + kw) as isize - self.padding as isize;

                                // Positions in the zero-padding contribute
                                // nothing and are skipped.
                                if ih >= 0
                                    && ih < in_h as isize
                                    && iw >= 0
                                    && iw < in_w as isize
                                {
                                    let ih = ih as usize;
                                    let iw = iw as usize;

                                    for ic_local in 0..in_channels_per_group {
                                        let ic = in_c_start + ic_local;
                                        // NHWC input index.
                                        let input_idx =
                                            ih * in_w * self.in_channels + iw * self.in_channels + ic;
                                        // Weight layout: [oc][kh][kw][ic_local].
                                        let kernel_idx = oc * ks * ks * in_channels_per_group
                                            + kh * ks * in_channels_per_group
                                            + kw * in_channels_per_group
                                            + ic_local;
                                        sum += input[input_idx] * self.weights[kernel_idx];
                                    }
                                }
                            }
                        }

                        output[oh * out_w * self.out_channels + ow * self.out_channels + oc] = sum;
                    }
                }
            }
        }
    }
}
438
/// A depthwise-separable convolution: a per-channel (depthwise) spatial
/// convolution followed by a pointwise (1x1) channel-mixing convolution.
#[derive(Debug, Clone)]
pub struct DepthwiseSeparableConv {
    // Number of input channels (also the depthwise stage's output count).
    in_channels: usize,
    // Number of output channels produced by the pointwise stage.
    out_channels: usize,
    // Side length of the square depthwise kernel.
    kernel_size: usize,
    // Stride of the depthwise stage (pointwise stage is always stride 1).
    stride: usize,
    // Zero-padding of the depthwise stage.
    padding: usize,
    // Depthwise weights, laid out [channel][kh][kw] (see `depthwise_generic`).
    depthwise_weights: Vec<f32>,
    // Pointwise weights, length out_channels * in_channels; exact layout is
    // defined by `simd::scalar::conv_1x1_scalar` — TODO confirm.
    pointwise_weights: Vec<f32>,
}
463
464impl DepthwiseSeparableConv {
465 pub fn new(
467 in_channels: usize,
468 out_channels: usize,
469 kernel_size: usize,
470 stride: usize,
471 padding: usize,
472 ) -> Self {
473 let dw_size = in_channels * kernel_size * kernel_size;
474 let pw_size = out_channels * in_channels;
475
476 let depthwise_weights: Vec<f32> = (0..dw_size)
478 .map(|i| {
479 let x = ((i * 1103515245 + 12345) % (1 << 31)) as f32 / (1u32 << 31) as f32;
480 (x * 2.0 - 1.0) * 0.1
481 })
482 .collect();
483
484 let pointwise_weights: Vec<f32> = (0..pw_size)
485 .map(|i| {
486 let x = ((i * 1103515245 + 54321) % (1 << 31)) as f32 / (1u32 << 31) as f32;
487 (x * 2.0 - 1.0) * 0.1
488 })
489 .collect();
490
491 Self {
492 in_channels,
493 out_channels,
494 kernel_size,
495 stride,
496 padding,
497 depthwise_weights,
498 pointwise_weights,
499 }
500 }
501
502 pub fn set_depthwise_weights(&mut self, weights: Vec<f32>) -> CnnResult<()> {
504 let expected = self.in_channels * self.kernel_size * self.kernel_size;
505 if weights.len() != expected {
506 return Err(CnnError::invalid_shape(
507 format!("{} depthwise weights", expected),
508 format!("{} weights", weights.len()),
509 ));
510 }
511 self.depthwise_weights = weights;
512 Ok(())
513 }
514
515 pub fn set_pointwise_weights(&mut self, weights: Vec<f32>) -> CnnResult<()> {
517 let expected = self.out_channels * self.in_channels;
518 if weights.len() != expected {
519 return Err(CnnError::invalid_shape(
520 format!("{} pointwise weights", expected),
521 format!("{} weights", weights.len()),
522 ));
523 }
524 self.pointwise_weights = weights;
525 Ok(())
526 }
527}
528
impl Layer for DepthwiseSeparableConv {
    /// Two-stage forward pass over NHWC input: a per-channel (depthwise)
    /// spatial convolution, then a pointwise (1x1) convolution that mixes
    /// channels into `out_channels`.
    ///
    /// # Errors
    /// Returns `CnnError` if the input is not 4D or its channel count does
    /// not match `in_channels`.
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
        let shape = input.shape();
        if shape.len() != 4 {
            return Err(CnnError::invalid_shape(
                "4D tensor (NHWC)",
                format!("{}D tensor", shape.len()),
            ));
        }

        let in_channels = shape[3];
        if in_channels != self.in_channels {
            return Err(CnnError::invalid_shape(
                format!("{} input channels", self.in_channels),
                format!("{} input channels", in_channels),
            ));
        }

        let batch = shape[0];
        let in_h = shape[1];
        let in_w = shape[2];

        // Spatial downsampling happens only in the depthwise stage; the
        // pointwise stage preserves out_h x out_w.
        // NOTE(review): underflows if in_h + 2 * padding < kernel_size —
        // confirm callers guarantee this.
        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;

        // Stage 1: depthwise convolution into an intermediate tensor with
        // the same channel count as the input.
        let dw_shape = vec![batch, out_h, out_w, self.in_channels];
        let mut dw_output = Tensor::zeros(&dw_shape);

        let batch_in_size = in_h * in_w * self.in_channels;
        let batch_dw_size = out_h * out_w * self.in_channels;

        for b in 0..batch {
            let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
            let output_slice = &mut dw_output.data_mut()[b * batch_dw_size..(b + 1) * batch_dw_size];

            if self.kernel_size == 3 {
                // Fast path for the common 3x3 depthwise kernel.
                simd::depthwise_conv_3x3_simd(
                    input_slice,
                    &self.depthwise_weights,
                    output_slice,
                    in_h,
                    in_w,
                    self.in_channels,
                    self.stride,
                    self.padding,
                );
            } else {
                self.depthwise_generic(input_slice, output_slice, in_h, in_w, out_h, out_w);
            }
        }

        // Stage 2: pointwise (1x1) convolution mixing channels.
        let pw_shape = vec![batch, out_h, out_w, self.out_channels];
        let mut output = Tensor::zeros(&pw_shape);

        let batch_pw_size = out_h * out_w * self.out_channels;

        for b in 0..batch {
            let dw_slice = &dw_output.data()[b * batch_dw_size..(b + 1) * batch_dw_size];
            let output_slice = &mut output.data_mut()[b * batch_pw_size..(b + 1) * batch_pw_size];

            simd::scalar::conv_1x1_scalar(
                dw_slice,
                &self.pointwise_weights,
                output_slice,
                out_h,
                out_w,
                self.in_channels,
                self.out_channels,
            );
        }

        Ok(output)
    }

    fn name(&self) -> &'static str {
        "DepthwiseSeparableConv"
    }

    /// Total learnable parameters: depthwise kernel plus pointwise matrix
    /// (this layer stores no bias terms).
    fn num_params(&self) -> usize {
        let dw_params = self.in_channels * self.kernel_size * self.kernel_size;
        let pw_params = self.out_channels * self.in_channels;
        dw_params + pw_params
    }
}
615
impl DepthwiseSeparableConv {
    /// Scalar fallback for the depthwise stage on one batch element (NHWC).
    ///
    /// Each channel is convolved with its own `ks x ks` kernel; channels
    /// never mix here (that is the pointwise stage's job). `output` is
    /// `out_h * out_w * in_channels` long and fully overwritten.
    fn depthwise_generic(
        &self,
        input: &[f32],
        output: &mut [f32],
        in_h: usize,
        in_w: usize,
        out_h: usize,
        out_w: usize,
    ) {
        let ks = self.kernel_size;

        for oh in 0..out_h {
            for ow in 0..out_w {
                for ch in 0..self.in_channels {
                    let mut sum = 0.0f32;

                    for kh in 0..ks {
                        for kw in 0..ks {
                            // Signed coordinates so padding can reach
                            // "outside" the input.
                            let ih = (oh * self.stride + kh) as isize - self.padding as isize;
                            let iw = (ow * self.stride + kw) as isize - self.padding as isize;

                            // Zero-padding region contributes nothing.
                            if ih >= 0
                                && ih < in_h as isize
                                && iw >= 0
                                && iw < in_w as isize
                            {
                                let ih = ih as usize;
                                let iw = iw as usize;

                                // NHWC input index; weights laid out [ch][kh][kw].
                                let input_idx =
                                    ih * in_w * self.in_channels + iw * self.in_channels + ch;
                                let kernel_idx = ch * ks * ks + kh * ks + kw;
                                sum += input[input_idx] * self.depthwise_weights[kernel_idx];
                            }
                        }
                    }

                    output[oh * out_w * self.in_channels + ow * self.in_channels + ch] = sum;
                }
            }
        }
    }
}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv2d_creation() {
        let conv = Conv2d::new(3, 64, 3, 1, 1);
        // 64 filters x 3x3 kernel x 3 input channels, no bias.
        assert_eq!(conv.num_params(), 64 * 3 * 3 * 3);
    }

    #[test]
    fn test_conv2d_output_shape() {
        // "Same" padding: 3x3 kernel with padding 1 keeps spatial dims.
        let conv = Conv2d::new(3, 64, 3, 1, 1);
        let shape = conv.output_shape(&[1, 224, 224, 3]).unwrap();
        assert_eq!(shape, vec![1, 224, 224, 64]);
    }

    #[test]
    fn test_conv2d_output_shape_stride2() {
        // Stride 2 halves the spatial resolution.
        let conv = Conv2d::new(3, 64, 3, 2, 1);
        assert_eq!(
            conv.output_shape(&[1, 224, 224, 3]).unwrap(),
            vec![1, 112, 112, 64]
        );
    }

    #[test]
    fn test_conv2d_forward() {
        let conv = Conv2d::new(3, 16, 3, 1, 1);
        let output = conv.forward(&Tensor::ones(&[1, 8, 8, 3])).unwrap();
        assert_eq!(output.shape(), &[1, 8, 8, 16]);
    }

    #[test]
    fn test_depthwise_separable_conv() {
        let conv = DepthwiseSeparableConv::new(16, 32, 3, 1, 1);
        let output = conv.forward(&Tensor::ones(&[1, 8, 8, 16])).unwrap();
        assert_eq!(output.shape(), &[1, 8, 8, 32]);
    }

    #[test]
    fn test_depthwise_separable_conv_params() {
        let conv = DepthwiseSeparableConv::new(16, 32, 3, 1, 1);
        // Depthwise: 16 * 3 * 3 = 144; pointwise: 32 * 16 = 512.
        assert_eq!(conv.num_params(), 144 + 512);
    }
}