scirs2_neural/layers/conv/
common.rs

1//! Common types and utilities for convolutional and pooling layers
2//!
3//! This module provides shared types, enums, and utility functions used across
4//! all convolutional and pooling layer implementations including padding modes,
5//! type aliases for caching, and common validation functions.
6
7use scirs2_core::ndarray::{Array, IxDyn};
8use std::sync::{Arc, RwLock};
9
/// Type alias for caching max indices in 2D pooling operations.
///
/// The alias nests three wrappers around a dynamic-dimensional array whose
/// elements are `(usize, usize)` pairs:
/// * `Arc` lets several owners share the same cache,
/// * `RwLock` guards access to it,
/// * `Option` allows the cache to be empty.
///
/// NOTE(review): each pair is presumably the 2D position of a window's maximum;
/// confirm against the pooling layers that fill this cache.
pub type MaxIndicesCache = Arc<RwLock<Option<Array<(usize, usize), IxDyn>>>>;
12
/// Type alias for caching max indices in 3D pooling operations.
///
/// Same wrapper layout as the 2D cache (`Arc` for sharing, `RwLock` for guarded
/// access, `Option` so the cache may be empty), but each element of the
/// dynamic-dimensional array is a `(usize, usize, usize)` triple.
///
/// NOTE(review): each triple is presumably the 3D position of a window's maximum;
/// confirm against the pooling layers that fill this cache.
pub type MaxIndicesCache3D = Arc<RwLock<Option<Array<(usize, usize, usize), IxDyn>>>>;
15
/// Padding mode for convolutional layers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum PaddingMode {
    /// No padding (will reduce spatial dimensions)
    #[default]
    Valid,
    /// Padding to preserve spatial dimensions
    Same,
    /// Custom padding value, applied identically to both axes
    Custom(usize),
}

impl PaddingMode {
    /// Calculate padding values for a given kernel size and dilation.
    ///
    /// The returned tuple uses the same axis order as `kernel_size` and `dilation`:
    ///
    /// * `Valid` always yields `(0, 0)`.
    /// * `Same` yields `(kernel - 1) * dilation / 2` per axis (integer division).
    /// * `Custom(p)` yields `(p, p)`.
    ///
    /// A kernel size of `0` contributes no `Same` padding instead of underflowing.
    pub fn calculate_padding(
        &self,
        kernel_size: (usize, usize),
        dilation: (usize, usize),
    ) -> (usize, usize) {
        // `saturating_sub` keeps a zero-sized kernel from underflowing `usize`
        // (which would panic in debug builds and wrap in release builds).
        let same_padding =
            |kernel: usize, dilation: usize| kernel.saturating_sub(1) * dilation / 2;

        match *self {
            PaddingMode::Valid => (0, 0),
            PaddingMode::Same => (
                same_padding(kernel_size.0, dilation.0),
                same_padding(kernel_size.1, dilation.1),
            ),
            PaddingMode::Custom(pad) => (pad, pad),
        }
    }

    /// Get a string representation of the padding mode.
    ///
    /// Returns `"valid"`, `"same"`, or the decimal padding amount for `Custom`.
    pub fn as_str(&self) -> String {
        match self {
            PaddingMode::Valid => "valid".to_string(),
            PaddingMode::Same => "same".to_string(),
            PaddingMode::Custom(p) => p.to_string(),
        }
    }
}
54
/// Validate convolution parameters.
///
/// Checks, in order, that `in_channels`, `out_channels`, both components of
/// `kernel_size`, and both components of `stride` are non-zero.
///
/// # Errors
///
/// Returns a human-readable message describing the first check that failed.
#[allow(dead_code)]
pub fn validate_conv_params(
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
) -> Result<(), String> {
    if in_channels == 0 {
        return Err("Input channels must be greater than 0".to_string());
    }
    if out_channels == 0 {
        return Err("Output channels must be greater than 0".to_string());
    }
    if kernel_size.0 == 0 || kernel_size.1 == 0 {
        return Err("Kernel size must be greater than 0".to_string());
    }
    if stride.0 == 0 || stride.1 == 0 {
        return Err("Stride must be greater than 0".to_string());
    }
    Ok(())
}
77
/// Calculate output shape for convolution operations.
///
/// Returns `(output_height, output_width)`. Each axis is computed as
/// `(input + 2 * padding - effective_kernel) / stride + 1` with integer
/// division, where `effective_kernel = (kernel - 1) * dilation + 1`.
///
/// If the padded input is smaller than the effective kernel along an axis,
/// the kernel fits nowhere and that axis has an output extent of `0`.
/// A kernel size of `0` is clamped so that `kernel - 1` does not underflow.
///
/// # Panics
///
/// Panics if either component of `stride` is `0`.
#[allow(dead_code)]
pub fn calculate_outputshape(
    input_height: usize,
    input_width: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> (usize, usize) {
    /// Output extent along a single spatial axis.
    fn axis_extent(
        input: usize,
        kernel: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
    ) -> usize {
        let effective_kernel = kernel.saturating_sub(1) * dilation + 1;
        // `checked_sub` turns "kernel larger than padded input" into zero output
        // positions instead of a `usize` underflow.
        (input + 2 * padding)
            .checked_sub(effective_kernel)
            .map_or(0, |span| span / stride + 1)
    }

    assert!(
        stride.0 > 0 && stride.1 > 0,
        "Stride must be greater than 0"
    );

    (
        axis_extent(input_height, kernel_size.0, stride.0, padding.0, dilation.0),
        axis_extent(input_width, kernel_size.1, stride.1, padding.1, dilation.1),
    )
}
96
/// Calculate adaptive pooling parameters.
///
/// Returns `(kernel_size, stride, padding)` where
///
/// * `stride = input_size / output_size` (integer division),
/// * `kernel_size = input_size - (output_size - 1) * stride`,
/// * `padding` is always `0`.
///
/// When `output_size` exceeds `input_size` the stride is `0` and the kernel
/// equals `input_size`; callers must be prepared for that degenerate result.
///
/// # Panics
///
/// Panics if `output_size` is `0`.
#[allow(dead_code)]
pub fn calculate_adaptive_pooling_params(
    input_size: usize,
    output_size: usize,
) -> (usize, usize, usize) {
    assert!(output_size > 0, "output_size must be greater than 0");

    let stride = input_size / output_size;
    // The last window starts at `(output_size - 1) * stride`; making the kernel
    // reach from there to the end of the input covers every input position.
    // This cannot underflow because `stride <= input_size / output_size`.
    let kernel_size = input_size - (output_size - 1) * stride;

    (kernel_size, stride, 0)
}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Every padding mode maps a 3x3 kernel with unit dilation to its expected padding.
    #[test]
    fn test_padding_mode_calculation() {
        let kernel = (3, 3);
        let unit_dilation = (1, 1);

        let cases = [
            (PaddingMode::Valid, (0, 0)),
            (PaddingMode::Same, (1, 1)),
            (PaddingMode::Custom(2), (2, 2)),
        ];

        for (mode, expected) in cases {
            assert_eq!(mode.calculate_padding(kernel, unit_dilation), expected);
        }
    }

    /// A 32x32 input with a 3x3 kernel under different stride/padding combinations.
    #[test]
    fn test_outputshape_calculation() {
        // (stride, padding, expected output shape)
        let cases = [
            // Valid padding, stride 1
            ((1, 1), (0, 0), (30, 30)),
            // Same padding, stride 1
            ((1, 1), (1, 1), (32, 32)),
            // Stride 2
            ((2, 2), (1, 1), (16, 16)),
        ];

        for (stride, padding, expected) in cases {
            let shape = calculate_outputshape(32, 32, (3, 3), stride, padding, (1, 1));
            assert_eq!(shape, expected);
        }
    }

    /// Pooling 8 inputs down to 4 outputs uses a 2-wide window with stride 2 and no padding.
    #[test]
    fn test_calculate_adaptive_pooling_params() {
        let params = calculate_adaptive_pooling_params(8, 4);
        assert_eq!(params.1, 2);
        assert_eq!(params.0, 2);
        assert_eq!(params.2, 0);
    }
}