// ruvector_cnn/layers/mod.rs
1//! Neural Network Layers
2//!
3//! This module provides standard CNN layers with SIMD-optimized implementations:
4//! - **Conv2d**: 2D convolution with configurable kernel, stride, padding
5//! - **DepthwiseSeparableConv**: MobileNet-style efficient convolutions
6//! - **BatchNorm**: Batch normalization with learned parameters
7//! - **Activations**: ReLU, ReLU6, Swish, HardSwish, Sigmoid
8//! - **Pooling**: GlobalAvgPool, MaxPool2d, AvgPool2d
9//! - **Linear**: Fully connected layer
10
11pub mod activation;
12pub mod batchnorm;
13pub mod conv;
14pub mod linear;
15pub mod pooling;
16
17pub use activation::{Activation, ActivationType, HardSwish, ReLU, ReLU6, Sigmoid, Swish};
18pub use batchnorm::{BatchNorm, BatchNorm2d};
19pub use conv::{Conv2d, DepthwiseSeparableConv};
20pub use linear::Linear;
21pub use pooling::{AvgPool2d, GlobalAvgPool, GlobalAvgPool2d, MaxPool2d};
22
23use crate::{CnnResult, Tensor};
24
/// Shape of a 4D tensor in NCHW order (batch, channels, height, width).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct TensorShape {
    /// Batch size
    pub n: usize,
    /// Number of channels
    pub c: usize,
    /// Height
    pub h: usize,
    /// Width
    pub w: usize,
}

impl TensorShape {
    /// Builds a shape from its four NCHW dimensions.
    pub fn new(n: usize, c: usize, h: usize, w: usize) -> Self {
        TensorShape { n, c, h, w }
    }

    /// Total element count: `n * c * h * w`.
    pub fn numel(&self) -> usize {
        self.n * self.channel_size()
    }

    /// Elements in one spatial plane: `h * w`.
    pub fn spatial_size(&self) -> usize {
        self.w * self.h
    }

    /// Elements in one batch item: `c * h * w`.
    pub fn channel_size(&self) -> usize {
        self.c * self.spatial_size()
    }
}

impl std::fmt::Display for TensorShape {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders as "[N, C, H, W]".
        let TensorShape { n, c, h, w } = self;
        write!(f, "[{}, {}, {}, {}]", n, c, h, w)
    }
}
65
/// Computes the output spatial size for a convolution or pooling operation.
///
/// The effective kernel size accounts for dilation:
/// `effective = dilation * (kernel - 1) + 1`, and the result is
/// `(input + 2 * padding - effective) / stride + 1` (integer floor division),
/// matching the standard Conv2d/Pool2d sizing formula.
///
/// # Panics
///
/// In debug builds, panics if `stride == 0`, `kernel == 0`, or the padded
/// input is smaller than the effective kernel (which would otherwise
/// underflow the subtraction). Release-build arithmetic is unchanged.
pub fn conv_output_size(
    input: usize,
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
) -> usize {
    debug_assert!(stride > 0, "stride must be non-zero");
    debug_assert!(kernel > 0, "kernel must be non-zero");
    let effective_kernel = dilation * (kernel - 1) + 1;
    debug_assert!(
        input + 2 * padding >= effective_kernel,
        "padded input ({}) smaller than effective kernel ({})",
        input + 2 * padding,
        effective_kernel
    );
    (input + 2 * padding - effective_kernel) / stride + 1
}
71
/// Trait implemented by all neural network layers in this module.
///
/// The trait exposes an inference-only API: there is no backward/gradient
/// method here.
pub trait Layer {
    /// Runs the forward pass on `input`, returning a new output tensor.
    ///
    /// Takes `&self`, so implementations cannot mutate layer state during
    /// inference (absent interior mutability).
    fn forward(&self, input: &Tensor) -> CnnResult<Tensor>;

    /// Returns a static, human-readable layer name (useful for debugging
    /// and model summaries).
    fn name(&self) -> &'static str;

    /// Returns the number of learnable parameters.
    ///
    /// Defaults to 0, which suits parameter-free layers such as activations
    /// and pooling; layers with weights should override this.
    fn num_params(&self) -> usize {
        0
    }
}
85
86// =============================================================================
87// Standalone layer functions (for backward compatibility with old backbone code)
88// =============================================================================
89
90/// 3x3 convolution function (standalone, for backward compatibility)
91///
92/// Input layout: NCHW flattened as [N * C * H * W]
93/// Output layout: NCHW flattened
94pub fn conv2d_3x3(
95    input: &[f32],
96    weights: &[f32],
97    in_channels: usize,
98    out_channels: usize,
99    height: usize,
100    width: usize,
101) -> Vec<f32> {
102    let out_h = height; // Same padding assumed
103    let out_w = width;
104    let mut output = vec![0.0f32; out_h * out_w * out_channels];
105
106    crate::simd::scalar::conv_3x3_scalar(
107        input,
108        weights,
109        &mut output,
110        height,
111        width,
112        in_channels,
113        out_channels,
114        1, // stride
115        1, // padding for same output size
116    );
117
118    output
119}
120
121/// Batch normalization function (standalone, for backward compatibility)
122///
123/// y = gamma * (x - mean) / sqrt(var + epsilon) + beta
124pub fn batch_norm(
125    input: &[f32],
126    gamma: &[f32],
127    beta: &[f32],
128    mean: &[f32],
129    var: &[f32],
130    epsilon: f32,
131) -> Vec<f32> {
132    let mut output = vec![0.0f32; input.len()];
133    let channels = gamma.len();
134
135    crate::simd::batch_norm_simd(
136        input,
137        &mut output,
138        gamma,
139        beta,
140        mean,
141        var,
142        epsilon,
143        channels,
144    );
145
146    output
147}
148
149/// HardSwish activation function (standalone, for backward compatibility)
150///
151/// hard_swish(x) = x * relu6(x + 3) / 6
152pub fn hard_swish(input: &[f32]) -> Vec<f32> {
153    let mut output = vec![0.0f32; input.len()];
154    crate::simd::scalar::hard_swish_scalar(input, &mut output);
155    output
156}
157
158/// ReLU activation function (standalone, for backward compatibility)
159pub fn relu(input: &[f32]) -> Vec<f32> {
160    let mut output = vec![0.0f32; input.len()];
161    crate::simd::relu_simd(input, &mut output);
162    output
163}
164
165/// ReLU6 activation function (standalone, for backward compatibility)
166pub fn relu6(input: &[f32]) -> Vec<f32> {
167    let mut output = vec![0.0f32; input.len()];
168    crate::simd::relu6_simd(input, &mut output);
169    output
170}
171
/// Standalone global average pooling, kept for backward compatibility.
///
/// Assumes a flattened NCHW buffer with a single batch item, so each channel
/// occupies a contiguous run of `H * W` values. Returns one average per
/// channel.
///
/// Bug fix: the previous implementation indexed channels as `i % channels`,
/// which corresponds to an interleaved (NHWC) layout and contradicted the
/// documented NCHW layout. Channels are now read as the contiguous
/// `spatial_size`-length chunks of `input`.
///
/// Empty input now yields all zeros instead of NaN (previously `0.0 * inf`).
///
/// # Panics
///
/// Panics if `channels == 0`.
pub fn global_avg_pool(input: &[f32], channels: usize) -> Vec<f32> {
    let spatial_size = input.len() / channels;
    if spatial_size == 0 {
        // Degenerate input: no spatial elements to average.
        return vec![0.0f32; channels];
    }
    let inv_spatial = 1.0 / spatial_size as f32;
    input
        .chunks_exact(spatial_size)
        .take(channels) // exactly `channels` chunks even for ragged input
        .map(|channel| channel.iter().sum::<f32>() * inv_spatial)
        .collect()
}
193
194// Re-export Conv2dBuilder from conv module
195pub use conv::Conv2dBuilder;
196
197// =============================================================================
198// Activation helper methods
199// =============================================================================
200
// Convenience constructors that wrap `Activation::new` with a fixed
// `ActivationType`, so call sites can write `Activation::relu()` instead of
// spelling out the enum variant.
impl Activation {
    /// Creates a ReLU activation (`max(0, x)`).
    pub fn relu() -> Self {
        Self::new(ActivationType::ReLU)
    }

    /// Creates a ReLU6 activation (ReLU clamped to an upper bound of 6).
    pub fn relu6() -> Self {
        Self::new(ActivationType::ReLU6)
    }

    /// Creates a HardSwish activation (`x * relu6(x + 3) / 6`).
    pub fn hard_swish() -> Self {
        Self::new(ActivationType::HardSwish)
    }

    /// Creates a HardSigmoid activation.
    ///
    /// NOTE: this actually constructs `ActivationType::Sigmoid` — the exact
    /// sigmoid is used as an approximation of HardSigmoid here, since no
    /// dedicated HardSigmoid variant is visible in this module.
    pub fn hard_sigmoid() -> Self {
        Self::new(ActivationType::Sigmoid)
    }

    /// Creates an identity (no-op) activation that passes input through
    /// unchanged.
    pub fn identity() -> Self {
        Self::new(ActivationType::Identity)
    }
}