// scirs2_neural/layers/conv/common.rs
//
// Shared types and helpers for the convolution/pooling layer modules.
use scirs2_core::ndarray::{Array, IxDyn};
use std::sync::{Arc, RwLock};
/// Shared, thread-safe, lazily-populated cache of per-element `(row, col)`
/// index pairs. Presumably the argmax positions recorded by a 2-D max-pool
/// forward pass for reuse in backward — TODO confirm against the pooling layers.
pub type MaxIndicesCache = Arc<RwLock<Option<Array<(usize, usize), IxDyn>>>>;

/// 3-D counterpart of [`MaxIndicesCache`], storing `(depth, row, col)` triples.
pub type MaxIndicesCache3D = Arc<RwLock<Option<Array<(usize, usize, usize), IxDyn>>>>;
/// Padding strategy applied before a convolution/pooling window slides over
/// the input.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum PaddingMode {
    /// No padding; the output shrinks by the dilated kernel extent.
    #[default]
    Valid,
    /// Symmetric padding sized so a stride-1 output matches the input.
    Same,
    /// The same explicit padding amount on every side.
    Custom(usize),
}

impl PaddingMode {
    /// Resolve this mode into concrete `(height, width)` padding amounts for
    /// the given kernel size and dilation.
    pub fn calculate_padding(
        &self,
        kernel_size: (usize, usize),
        dilation: (usize, usize),
    ) -> (usize, usize) {
        let (kh, kw) = kernel_size;
        let (dh, dw) = dilation;
        match *self {
            PaddingMode::Valid => (0, 0),
            // Half the dilated kernel extent per side; integer division means
            // even kernels round down.
            PaddingMode::Same => ((kh - 1) * dh / 2, (kw - 1) * dw / 2),
            PaddingMode::Custom(pad) => (pad, pad),
        }
    }

    /// Human-readable name of the mode: `"valid"`, `"same"`, or the custom
    /// padding amount rendered as a number.
    pub fn as_str(&self) -> String {
        match *self {
            PaddingMode::Valid => String::from("valid"),
            PaddingMode::Same => String::from("same"),
            PaddingMode::Custom(p) => p.to_string(),
        }
    }
}
/// Validate basic convolution hyperparameters.
///
/// # Errors
/// Returns a human-readable message when any of the channel counts, kernel
/// dimensions, or stride components is zero; `Ok(())` otherwise.
#[allow(dead_code)]
pub fn validate_conv_params(
    in_channels: usize,
    out_channels: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
) -> Result<(), String> {
    // Fix: the original messages contained stray underscores ("Input
    // _channels", "Kernel _size") — artifacts of an automated parameter
    // rename that leaked into the user-facing error strings.
    if in_channels == 0 {
        return Err("Input channels must be greater than 0".to_string());
    }
    if out_channels == 0 {
        return Err("Output channels must be greater than 0".to_string());
    }
    if kernel_size.0 == 0 || kernel_size.1 == 0 {
        return Err("Kernel size must be greater than 0".to_string());
    }
    if stride.0 == 0 || stride.1 == 0 {
        return Err("Stride must be greater than 0".to_string());
    }
    Ok(())
}
/// Compute the `(height, width)` of a convolution output using the standard
/// formula `out = (in + 2p - ((k - 1) * d + 1)) / s + 1` per dimension.
#[allow(dead_code)]
pub fn calculate_outputshape(
    input_height: usize,
    input_width: usize,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
) -> (usize, usize) {
    // One dimension of the output-size formula; applied to height then width.
    let out_dim = |input: usize, k: usize, s: usize, p: usize, d: usize| {
        let effective_kernel = (k - 1) * d + 1;
        (input + 2 * p - effective_kernel) / s + 1
    };
    (
        out_dim(input_height, kernel_size.0, stride.0, padding.0, dilation.0),
        out_dim(input_width, kernel_size.1, stride.1, padding.1, dilation.1),
    )
}
/// Derive `(kernel_size, stride, padding)` for an adaptive pooling layer that
/// maps `input_size` elements down to exactly `output_size` windows.
///
/// NOTE(review): assumes `0 < output_size <= input_size`; `output_size == 0`
/// divides by zero and `output_size > input_size` yields a zero stride —
/// confirm that callers guarantee this precondition.
#[allow(dead_code)]
pub fn calculate_adaptive_pooling_params(
    input_size: usize,
    output_size: usize,
) -> (usize, usize, usize) {
    let stride = input_size / output_size;
    // Size the kernel so the final window ends exactly at the input edge.
    let kernel_size = input_size - (output_size - 1) * stride;
    let padding = 0;
    (kernel_size, stride, padding)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_padding_mode_calculation() {
        let kernel = (3, 3);
        let dilation = (1, 1);

        // Each mode resolves to its expected symmetric padding.
        assert_eq!(PaddingMode::Valid.calculate_padding(kernel, dilation), (0, 0));
        assert_eq!(PaddingMode::Same.calculate_padding(kernel, dilation), (1, 1));
        assert_eq!(PaddingMode::Custom(2).calculate_padding(kernel, dilation), (2, 2));
    }

    #[test]
    fn test_outputshape_calculation() {
        // (stride, padding) -> expected shape, for a 32x32 input and 3x3 kernel.
        let cases = [
            ((1, 1), (0, 0), (30, 30)),
            ((1, 1), (1, 1), (32, 32)),
            ((2, 2), (1, 1), (16, 16)),
        ];
        for (stride, padding, expected) in cases {
            assert_eq!(
                calculate_outputshape(32, 32, (3, 3), stride, padding, (1, 1)),
                expected
            );
        }
    }

    #[test]
    fn test_calculate_adaptive_pooling_params() {
        let params = calculate_adaptive_pooling_params(8, 4);
        assert_eq!(params, (2, 2, 0));
    }
}