Skip to main content

ruvector_cnn/layers/
conv.rs

1//! Convolutional Layers
2//!
3//! SIMD-optimized 2D convolution implementations:
4//! - Conv2d: Standard 2D convolution
5//! - DepthwiseSeparableConv: MobileNet-style efficient convolution
6
7use crate::{simd, CnnError, CnnResult, Tensor};
8
9use super::{Layer, TensorShape};
10
/// 2D Convolution Layer
///
/// Performs 2D convolution on NHWC tensors with configurable:
/// - Kernel size (square: `kernel_size x kernel_size`)
/// - Stride (shared by both spatial dimensions)
/// - Padding (zero padding on every spatial border)
/// - Groups (for depthwise and grouped convolutions)
///
/// Kernel layout: `[out_channels, kernel_h, kernel_w, in_channels / groups]`
/// (OHWI). With the default `groups == 1` the last dimension is simply
/// `in_channels`.
#[derive(Debug, Clone)]
pub struct Conv2d {
    /// Number of input channels
    in_channels: usize,
    /// Number of output channels
    out_channels: usize,
    /// Kernel size (height and width)
    kernel_size: usize,
    /// Stride (same for height and width)
    stride: usize,
    /// Zero padding added on each side of both spatial dimensions
    padding: usize,
    /// Groups for grouped/depthwise convolution
    groups: usize,
    /// Kernel weights: `[out_c, kh, kw, in_c / groups]`
    weights: Vec<f32>,
    /// Optional per-output-channel bias: `[out_c]`
    bias: Option<Vec<f32>>,
}
39
/// Fluent builder for a [`Conv2d`] layer.
///
/// `new` starts from `stride = 1`, `padding = 0`, `groups = 1` with a bias
/// enabled; every setter consumes the builder and hands it back, and `build`
/// checks the configuration before allocating weights.
#[derive(Debug, Clone)]
pub struct Conv2dBuilder {
    /// Channels expected in the input tensor
    in_channels: usize,
    /// Channels produced in the output tensor
    out_channels: usize,
    /// Side length of the square kernel
    kernel_size: usize,
    /// Step between successive kernel applications
    stride: usize,
    /// Zero padding on each spatial border
    padding: usize,
    /// Number of channel groups
    groups: usize,
    /// Whether the built layer carries a bias vector
    bias: bool,
}
51
52impl Conv2dBuilder {
53    /// Create a new builder
54    pub fn new(in_channels: usize, out_channels: usize, kernel_size: usize) -> Self {
55        Self {
56            in_channels,
57            out_channels,
58            kernel_size,
59            stride: 1,
60            padding: 0,
61            groups: 1,
62            bias: true,
63        }
64    }
65
66    /// Set stride
67    pub fn stride(mut self, stride: usize) -> Self {
68        self.stride = stride;
69        self
70    }
71
72    /// Set padding
73    pub fn padding(mut self, padding: usize) -> Self {
74        self.padding = padding;
75        self
76    }
77
78    /// Set groups for grouped convolution
79    pub fn groups(mut self, groups: usize) -> Self {
80        self.groups = groups;
81        self
82    }
83
84    /// Set whether to use bias
85    pub fn bias(mut self, bias: bool) -> Self {
86        self.bias = bias;
87        self
88    }
89
90    /// Build the Conv2d layer
91    pub fn build(self) -> CnnResult<Conv2d> {
92        if self.in_channels % self.groups != 0 {
93            return Err(CnnError::InvalidParameter(
94                format!("in_channels {} must be divisible by groups {}", self.in_channels, self.groups)
95            ));
96        }
97        if self.out_channels % self.groups != 0 {
98            return Err(CnnError::InvalidParameter(
99                format!("out_channels {} must be divisible by groups {}", self.out_channels, self.groups)
100            ));
101        }
102
103        let in_channels_per_group = self.in_channels / self.groups;
104        let num_weights = self.out_channels * self.kernel_size * self.kernel_size * in_channels_per_group;
105
106        // Xavier/Glorot initialization
107        let fan_in = in_channels_per_group * self.kernel_size * self.kernel_size;
108        let fan_out = (self.out_channels / self.groups) * self.kernel_size * self.kernel_size;
109        let std_dev = (2.0 / (fan_in + fan_out) as f32).sqrt();
110
111        let weights: Vec<f32> = (0..num_weights)
112            .map(|i| {
113                let x = ((i * 1103515245 + 12345) % (1 << 31)) as f32 / (1u32 << 31) as f32;
114                (x * 2.0 - 1.0) * std_dev
115            })
116            .collect();
117
118        let bias = if self.bias {
119            Some(vec![0.0; self.out_channels])
120        } else {
121            None
122        };
123
124        Ok(Conv2d {
125            in_channels: self.in_channels,
126            out_channels: self.out_channels,
127            kernel_size: self.kernel_size,
128            stride: self.stride,
129            padding: self.padding,
130            groups: self.groups,
131            weights,
132            bias,
133        })
134    }
135}
136
137impl Conv2d {
138    /// Create a new Conv2d layer with Xavier initialization
139    pub fn new(
140        in_channels: usize,
141        out_channels: usize,
142        kernel_size: usize,
143        stride: usize,
144        padding: usize,
145    ) -> Self {
146        let num_weights = out_channels * kernel_size * kernel_size * in_channels;
147
148        // Xavier/Glorot initialization
149        let fan_in = in_channels * kernel_size * kernel_size;
150        let fan_out = out_channels * kernel_size * kernel_size;
151        let std_dev = (2.0 / (fan_in + fan_out) as f32).sqrt();
152
153        // Simple pseudo-random initialization (for deterministic tests)
154        let weights: Vec<f32> = (0..num_weights)
155            .map(|i| {
156                let x = ((i * 1103515245 + 12345) % (1 << 31)) as f32 / (1u32 << 31) as f32;
157                (x * 2.0 - 1.0) * std_dev
158            })
159            .collect();
160
161        Self {
162            in_channels,
163            out_channels,
164            kernel_size,
165            stride,
166            padding,
167            groups: 1,
168            weights,
169            bias: None,
170        }
171    }
172
173    /// Create a Conv2d builder
174    pub fn builder(in_channels: usize, out_channels: usize, kernel_size: usize) -> Conv2dBuilder {
175        Conv2dBuilder::new(in_channels, out_channels, kernel_size)
176    }
177
178    /// Create Conv2d with bias
179    pub fn with_bias(
180        in_channels: usize,
181        out_channels: usize,
182        kernel_size: usize,
183        stride: usize,
184        padding: usize,
185    ) -> Self {
186        let mut conv = Self::new(in_channels, out_channels, kernel_size, stride, padding);
187        conv.bias = Some(vec![0.0; out_channels]);
188        conv
189    }
190
191    /// Get the output shape for a TensorShape input (NCHW format)
192    pub fn output_shape_nchw(&self, input_shape: &TensorShape) -> TensorShape {
193        let out_h = (input_shape.h + 2 * self.padding - self.kernel_size) / self.stride + 1;
194        let out_w = (input_shape.w + 2 * self.padding - self.kernel_size) / self.stride + 1;
195        TensorShape::new(input_shape.n, self.out_channels, out_h, out_w)
196    }
197
198    /// Set the weights directly
199    pub fn set_weights(&mut self, weights: Vec<f32>) -> CnnResult<()> {
200        let expected = self.out_channels * self.kernel_size * self.kernel_size * self.in_channels;
201        if weights.len() != expected {
202            return Err(CnnError::invalid_shape(
203                format!("{} weights", expected),
204                format!("{} weights", weights.len()),
205            ));
206        }
207        self.weights = weights;
208        Ok(())
209    }
210
211    /// Set the bias
212    pub fn set_bias(&mut self, bias: Vec<f32>) -> CnnResult<()> {
213        if bias.len() != self.out_channels {
214            return Err(CnnError::invalid_shape(
215                format!("{} bias values", self.out_channels),
216                format!("{} bias values", bias.len()),
217            ));
218        }
219        self.bias = Some(bias);
220        Ok(())
221    }
222
223    /// Get the output shape for a given input shape
224    pub fn output_shape(&self, input_shape: &[usize]) -> CnnResult<Vec<usize>> {
225        if input_shape.len() != 4 {
226            return Err(CnnError::invalid_shape(
227                "4D tensor (NHWC)",
228                format!("{}D tensor", input_shape.len()),
229            ));
230        }
231
232        let batch = input_shape[0];
233        let in_h = input_shape[1];
234        let in_w = input_shape[2];
235
236        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
237        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
238
239        Ok(vec![batch, out_h, out_w, self.out_channels])
240    }
241
242    /// Get weights reference
243    pub fn weights(&self) -> &[f32] {
244        &self.weights
245    }
246
247    /// Get bias reference
248    pub fn bias(&self) -> Option<&[f32]> {
249        self.bias.as_deref()
250    }
251
252    /// Get the kernel size
253    pub fn kernel_size(&self) -> usize {
254        self.kernel_size
255    }
256
257    /// Get the stride
258    pub fn stride(&self) -> usize {
259        self.stride
260    }
261
262    /// Get the padding
263    pub fn padding(&self) -> usize {
264        self.padding
265    }
266
267    /// Get the number of output channels
268    pub fn out_channels(&self) -> usize {
269        self.out_channels
270    }
271
272    /// Get the number of input channels
273    pub fn in_channels(&self) -> usize {
274        self.in_channels
275    }
276
277    /// Get the number of groups
278    pub fn groups(&self) -> usize {
279        self.groups
280    }
281}
282
283impl Layer for Conv2d {
284    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
285        let shape = input.shape();
286        if shape.len() != 4 {
287            return Err(CnnError::invalid_shape(
288                "4D tensor (NHWC)",
289                format!("{}D tensor", shape.len()),
290            ));
291        }
292
293        let in_channels = shape[3];
294        if in_channels != self.in_channels {
295            return Err(CnnError::invalid_shape(
296                format!("{} input channels", self.in_channels),
297                format!("{} input channels", in_channels),
298            ));
299        }
300
301        let batch = shape[0];
302        let in_h = shape[1];
303        let in_w = shape[2];
304
305        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
306        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
307
308        let out_shape = vec![batch, out_h, out_w, self.out_channels];
309        let mut output = Tensor::zeros(&out_shape);
310
311        // Process each batch
312        let batch_in_size = in_h * in_w * in_channels;
313        let batch_out_size = out_h * out_w * self.out_channels;
314
315        for b in 0..batch {
316            let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
317            let output_slice = &mut output.data_mut()[b * batch_out_size..(b + 1) * batch_out_size];
318
319            if self.kernel_size == 3 && self.groups == 1 {
320                // Standard 3x3 convolution (non-grouped)
321                simd::conv_3x3_simd(
322                    input_slice,
323                    &self.weights,
324                    output_slice,
325                    in_h,
326                    in_w,
327                    self.in_channels,
328                    self.out_channels,
329                    self.stride,
330                    self.padding,
331                );
332            } else if self.kernel_size == 3 && self.groups == self.in_channels && self.in_channels == self.out_channels {
333                // Depthwise 3x3 convolution (groups == in_channels == out_channels)
334                simd::depthwise_conv_3x3_simd(
335                    input_slice,
336                    &self.weights,
337                    output_slice,
338                    in_h,
339                    in_w,
340                    self.in_channels,
341                    self.stride,
342                    self.padding,
343                );
344            } else {
345                // Fallback to generic convolution for other cases
346                self.conv_generic(input_slice, output_slice, in_h, in_w, out_h, out_w);
347            }
348        }
349
350        // Add bias if present
351        if let Some(bias) = &self.bias {
352            for val in output.data_mut().chunks_mut(self.out_channels) {
353                for (i, v) in val.iter_mut().enumerate() {
354                    *v += bias[i];
355                }
356            }
357        }
358
359        Ok(output)
360    }
361
362    fn name(&self) -> &'static str {
363        "Conv2d"
364    }
365
366    fn num_params(&self) -> usize {
367        let weight_params =
368            self.out_channels * self.kernel_size * self.kernel_size * self.in_channels;
369        let bias_params = if self.bias.is_some() {
370            self.out_channels
371        } else {
372            0
373        };
374        weight_params + bias_params
375    }
376}
377
378impl Conv2d {
379    /// Generic convolution for arbitrary kernel sizes
380    fn conv_generic(
381        &self,
382        input: &[f32],
383        output: &mut [f32],
384        in_h: usize,
385        in_w: usize,
386        out_h: usize,
387        out_w: usize,
388    ) {
389        let ks = self.kernel_size;
390        let in_channels_per_group = self.in_channels / self.groups;
391        let out_channels_per_group = self.out_channels / self.groups;
392
393        for oh in 0..out_h {
394            for ow in 0..out_w {
395                for g in 0..self.groups {
396                    let in_c_start = g * in_channels_per_group;
397                    let out_c_start = g * out_channels_per_group;
398
399                    for oc_local in 0..out_channels_per_group {
400                        let oc = out_c_start + oc_local;
401                        let mut sum = 0.0f32;
402
403                        for kh in 0..ks {
404                            for kw in 0..ks {
405                                let ih = (oh * self.stride + kh) as isize - self.padding as isize;
406                                let iw = (ow * self.stride + kw) as isize - self.padding as isize;
407
408                                if ih >= 0
409                                    && ih < in_h as isize
410                                    && iw >= 0
411                                    && iw < in_w as isize
412                                {
413                                    let ih = ih as usize;
414                                    let iw = iw as usize;
415
416                                    for ic_local in 0..in_channels_per_group {
417                                        let ic = in_c_start + ic_local;
418                                        let input_idx =
419                                            ih * in_w * self.in_channels + iw * self.in_channels + ic;
420                                        // Kernel layout: [out_c, kh, kw, in_c_per_group]
421                                        let kernel_idx = oc * ks * ks * in_channels_per_group
422                                            + kh * ks * in_channels_per_group
423                                            + kw * in_channels_per_group
424                                            + ic_local;
425                                        sum += input[input_idx] * self.weights[kernel_idx];
426                                    }
427                                }
428                            }
429                        }
430
431                        output[oh * out_w * self.out_channels + ow * self.out_channels + oc] = sum;
432                    }
433                }
434            }
435        }
436    }
437}
438
/// Depthwise Separable Convolution
///
/// MobileNet-style factorization of a standard convolution into two stages:
/// 1. Depthwise convolution: one `kernel_size x kernel_size` filter per input
///    channel.
/// 2. Pointwise convolution: a 1x1 convolution that mixes the channels into
///    `out_channels` outputs.
///
/// This cuts the weight count from O(K^2 * C_in * C_out) to
/// O(K^2 * C_in + C_in * C_out). The layer carries no bias terms.
#[derive(Debug, Clone)]
pub struct DepthwiseSeparableConv {
    /// Channels in the input (and in the depthwise intermediate result)
    in_channels: usize,
    /// Channels produced by the pointwise stage
    out_channels: usize,
    /// Side length of the square depthwise kernel
    kernel_size: usize,
    /// Stride of the depthwise stage
    stride: usize,
    /// Zero padding of the depthwise stage
    padding: usize,
    /// Depthwise filters: `[in_channels, kernel_h, kernel_w]`
    depthwise_weights: Vec<f32>,
    /// Pointwise filters: `[out_channels, in_channels]`
    pointwise_weights: Vec<f32>,
}
463
464impl DepthwiseSeparableConv {
465    /// Create a new depthwise separable convolution
466    pub fn new(
467        in_channels: usize,
468        out_channels: usize,
469        kernel_size: usize,
470        stride: usize,
471        padding: usize,
472    ) -> Self {
473        let dw_size = in_channels * kernel_size * kernel_size;
474        let pw_size = out_channels * in_channels;
475
476        // Initialize with small random values
477        let depthwise_weights: Vec<f32> = (0..dw_size)
478            .map(|i| {
479                let x = ((i * 1103515245 + 12345) % (1 << 31)) as f32 / (1u32 << 31) as f32;
480                (x * 2.0 - 1.0) * 0.1
481            })
482            .collect();
483
484        let pointwise_weights: Vec<f32> = (0..pw_size)
485            .map(|i| {
486                let x = ((i * 1103515245 + 54321) % (1 << 31)) as f32 / (1u32 << 31) as f32;
487                (x * 2.0 - 1.0) * 0.1
488            })
489            .collect();
490
491        Self {
492            in_channels,
493            out_channels,
494            kernel_size,
495            stride,
496            padding,
497            depthwise_weights,
498            pointwise_weights,
499        }
500    }
501
502    /// Set depthwise weights
503    pub fn set_depthwise_weights(&mut self, weights: Vec<f32>) -> CnnResult<()> {
504        let expected = self.in_channels * self.kernel_size * self.kernel_size;
505        if weights.len() != expected {
506            return Err(CnnError::invalid_shape(
507                format!("{} depthwise weights", expected),
508                format!("{} weights", weights.len()),
509            ));
510        }
511        self.depthwise_weights = weights;
512        Ok(())
513    }
514
515    /// Set pointwise weights
516    pub fn set_pointwise_weights(&mut self, weights: Vec<f32>) -> CnnResult<()> {
517        let expected = self.out_channels * self.in_channels;
518        if weights.len() != expected {
519            return Err(CnnError::invalid_shape(
520                format!("{} pointwise weights", expected),
521                format!("{} weights", weights.len()),
522            ));
523        }
524        self.pointwise_weights = weights;
525        Ok(())
526    }
527}
528
529impl Layer for DepthwiseSeparableConv {
530    fn forward(&self, input: &Tensor) -> CnnResult<Tensor> {
531        let shape = input.shape();
532        if shape.len() != 4 {
533            return Err(CnnError::invalid_shape(
534                "4D tensor (NHWC)",
535                format!("{}D tensor", shape.len()),
536            ));
537        }
538
539        let in_channels = shape[3];
540        if in_channels != self.in_channels {
541            return Err(CnnError::invalid_shape(
542                format!("{} input channels", self.in_channels),
543                format!("{} input channels", in_channels),
544            ));
545        }
546
547        let batch = shape[0];
548        let in_h = shape[1];
549        let in_w = shape[2];
550
551        let out_h = (in_h + 2 * self.padding - self.kernel_size) / self.stride + 1;
552        let out_w = (in_w + 2 * self.padding - self.kernel_size) / self.stride + 1;
553
554        // Step 1: Depthwise convolution
555        let dw_shape = vec![batch, out_h, out_w, self.in_channels];
556        let mut dw_output = Tensor::zeros(&dw_shape);
557
558        let batch_in_size = in_h * in_w * self.in_channels;
559        let batch_dw_size = out_h * out_w * self.in_channels;
560
561        for b in 0..batch {
562            let input_slice = &input.data()[b * batch_in_size..(b + 1) * batch_in_size];
563            let output_slice = &mut dw_output.data_mut()[b * batch_dw_size..(b + 1) * batch_dw_size];
564
565            if self.kernel_size == 3 {
566                simd::depthwise_conv_3x3_simd(
567                    input_slice,
568                    &self.depthwise_weights,
569                    output_slice,
570                    in_h,
571                    in_w,
572                    self.in_channels,
573                    self.stride,
574                    self.padding,
575                );
576            } else {
577                self.depthwise_generic(input_slice, output_slice, in_h, in_w, out_h, out_w);
578            }
579        }
580
581        // Step 2: Pointwise (1x1) convolution
582        let pw_shape = vec![batch, out_h, out_w, self.out_channels];
583        let mut output = Tensor::zeros(&pw_shape);
584
585        let batch_pw_size = out_h * out_w * self.out_channels;
586
587        for b in 0..batch {
588            let dw_slice = &dw_output.data()[b * batch_dw_size..(b + 1) * batch_dw_size];
589            let output_slice = &mut output.data_mut()[b * batch_pw_size..(b + 1) * batch_pw_size];
590
591            simd::scalar::conv_1x1_scalar(
592                dw_slice,
593                &self.pointwise_weights,
594                output_slice,
595                out_h,
596                out_w,
597                self.in_channels,
598                self.out_channels,
599            );
600        }
601
602        Ok(output)
603    }
604
605    fn name(&self) -> &'static str {
606        "DepthwiseSeparableConv"
607    }
608
609    fn num_params(&self) -> usize {
610        let dw_params = self.in_channels * self.kernel_size * self.kernel_size;
611        let pw_params = self.out_channels * self.in_channels;
612        dw_params + pw_params
613    }
614}
615
616impl DepthwiseSeparableConv {
617    /// Generic depthwise convolution for arbitrary kernel sizes
618    fn depthwise_generic(
619        &self,
620        input: &[f32],
621        output: &mut [f32],
622        in_h: usize,
623        in_w: usize,
624        out_h: usize,
625        out_w: usize,
626    ) {
627        let ks = self.kernel_size;
628
629        for oh in 0..out_h {
630            for ow in 0..out_w {
631                for ch in 0..self.in_channels {
632                    let mut sum = 0.0f32;
633
634                    for kh in 0..ks {
635                        for kw in 0..ks {
636                            let ih = (oh * self.stride + kh) as isize - self.padding as isize;
637                            let iw = (ow * self.stride + kw) as isize - self.padding as isize;
638
639                            if ih >= 0
640                                && ih < in_h as isize
641                                && iw >= 0
642                                && iw < in_w as isize
643                            {
644                                let ih = ih as usize;
645                                let iw = iw as usize;
646
647                                let input_idx =
648                                    ih * in_w * self.in_channels + iw * self.in_channels + ch;
649                                let kernel_idx = ch * ks * ks + kh * ks + kw;
650                                sum += input[input_idx] * self.depthwise_weights[kernel_idx];
651                            }
652                        }
653                    }
654
655                    output[oh * out_w * self.in_channels + ow * self.in_channels + ch] = sum;
656                }
657            }
658        }
659    }
660}
661
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conv2d_creation() {
        let layer = Conv2d::new(3, 64, 3, 1, 1);
        // Bias-free layer: only the 64 filters of 3x3x3 weights count.
        let weights_only = 64 * 3 * 3 * 3;
        assert_eq!(layer.num_params(), weights_only);
    }

    #[test]
    fn test_conv2d_output_shape() {
        let same_padding = Conv2d::new(3, 64, 3, 1, 1);
        let out = same_padding.output_shape(&[1, 224, 224, 3]).unwrap();
        assert_eq!(out, vec![1, 224, 224, 64]);
    }

    #[test]
    fn test_conv2d_output_shape_stride2() {
        let downsample = Conv2d::new(3, 64, 3, 2, 1);
        let out = downsample.output_shape(&[1, 224, 224, 3]).unwrap();
        assert_eq!(out, vec![1, 112, 112, 64]);
    }

    #[test]
    fn test_conv2d_forward() {
        let layer = Conv2d::new(3, 16, 3, 1, 1);
        let result = layer.forward(&Tensor::ones(&[1, 8, 8, 3])).unwrap();
        assert_eq!(result.shape(), &[1, 8, 8, 16]);
    }

    #[test]
    fn test_depthwise_separable_conv() {
        let layer = DepthwiseSeparableConv::new(16, 32, 3, 1, 1);
        let result = layer.forward(&Tensor::ones(&[1, 8, 8, 16])).unwrap();
        assert_eq!(result.shape(), &[1, 8, 8, 32]);
    }

    #[test]
    fn test_depthwise_separable_conv_params() {
        let layer = DepthwiseSeparableConv::new(16, 32, 3, 1, 1);

        let depthwise = 16 * 3 * 3; // one 3x3 filter per input channel = 144
        let pointwise = 32 * 16; // 1x1 channel mixing = 512
        assert_eq!(layer.num_params(), depthwise + pointwise);

        // A dense 3x3 conv with the same channels would need 32 * 3 * 3 * 16 = 4608
        // parameters, roughly 7x more than the 656 used here.
    }
}