scirs2_neural/layers/conv/
conv2d.rs1use super::common::{validate_conv_params, PaddingMode};
4use crate::error::{NeuralError, Result};
5use crate::layers::{Layer, ParamLayer};
6use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9
10#[derive(Debug)]
12pub struct Conv2D<F: Float + Debug + Send + Sync> {
13 #[allow(dead_code)]
14 in_channels: usize,
15 #[allow(dead_code)]
16 out_channels: usize,
17 #[allow(dead_code)]
18 kernel_size: (usize, usize),
19 #[allow(dead_code)]
20 stride: (usize, usize),
21 #[allow(dead_code)]
22 padding_mode: PaddingMode,
23 weights: Array<F, IxDyn>,
24 bias: Option<Array<F, IxDyn>>,
25 use_bias: bool,
26 name: Option<String>,
27}
28
29impl<F: Float + Debug + Send + Sync + ScalarOperand + Default> Conv2D<F> {
30 pub fn new(
32 in_channels: usize,
33 out_channels: usize,
34 kernel_size: (usize, usize),
35 stride: (usize, usize),
36 name: Option<&str>,
37 ) -> Result<Self> {
38 validate_conv_params(in_channels, out_channels, kernel_size, stride)
39 .map_err(NeuralError::InvalidArchitecture)?;
40
41 let weightsshape = vec![out_channels, in_channels, kernel_size.0, kernel_size.1];
42 let weights = Array::zeros(IxDyn(&weightsshape));
43
44 let bias = Some(Array::zeros(IxDyn(&[out_channels])));
45
46 Ok(Self {
47 in_channels,
48 out_channels,
49 kernel_size,
50 stride,
51 padding_mode: PaddingMode::Valid,
52 weights,
53 bias,
54 use_bias: true,
55 name: name.map(String::from),
56 })
57 }
58}
59
60impl<F: Float + Debug + Send + Sync + ScalarOperand + Default> Layer<F> for Conv2D<F> {
61 fn forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
62 Ok(input.clone())
64 }
65
66 fn layer_type(&self) -> &str {
67 "Conv2D"
68 }
69
70 fn inputshape(&self) -> Option<Vec<usize>> {
71 None
72 }
73
74 fn outputshape(&self) -> Option<Vec<usize>> {
75 None
76 }
77
78 fn name(&self) -> Option<&str> {
79 self.name.as_deref()
80 }
81
82 fn backward(
83 &self,
84 _input: &Array<F, scirs2_core::ndarray::IxDyn>,
85 grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
86 ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
87 Ok(grad_output.clone())
89 }
90
91 fn update(&mut self, _learningrate: F) -> Result<()> {
92 Ok(())
94 }
95
96 fn as_any(&self) -> &dyn std::any::Any {
97 self
98 }
99
100 fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
101 self
102 }
103
104 fn parameter_count(&self) -> usize {
105 let weights_count =
106 self.out_channels * self.in_channels * self.kernel_size.0 * self.kernel_size.1;
107 let bias_count = if self.use_bias { self.out_channels } else { 0 };
108 weights_count + bias_count
109 }
110}
111
112impl<F: Float + Debug + Send + Sync + ScalarOperand + Default> ParamLayer<F> for Conv2D<F> {
113 fn get_parameters(&self) -> Vec<Array<F, IxDyn>> {
114 let mut params = vec![self.weights.clone()];
115 if let Some(ref bias) = self.bias {
116 params.push(bias.clone());
117 }
118 params
119 }
120
121 fn set_parameters(&mut self, params: Vec<Array<F, IxDyn>>) -> Result<()> {
122 match (self.use_bias, params.len()) {
123 (true, 2) => {
124 self.weights = params[0].clone();
125 self.bias = Some(params[1].clone());
126 }
127 (false, 1) => {
128 self.weights = params[0].clone();
129 }
130 _ => {
131 let expected = if self.use_bias { 2 } else { 1 };
132 let got = params.len();
133 return Err(NeuralError::InvalidArchitecture(format!(
134 "Expected {expected} parameters, got {got}"
135 )));
136 }
137 }
138 Ok(())
139 }
140
141 fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
142 Vec::new()
144 }
145}