use crate::activations_minimal::Activation;
use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use scirs2_core::ndarray::{Array, IxDyn, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{Distribution, Uniform};
use std::fmt::Debug;

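/// A fully connected (dense) layer computing `output = input · weights + biases`,
/// optionally followed by an activation function.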
pub struct Dense<F: Float + Debug + Send + Sync> {
    /// Number of input features.
    input_dim: usize,
    /// Number of output features.
    output_dim: usize,
    /// Weight matrix of shape `[input_dim, output_dim]`.
    weights: Array<F, IxDyn>,
    /// Bias vector of shape `[output_dim]`.
    biases: Array<F, IxDyn>,
    /// Gradient of the loss with respect to the weights, filled in by `backward`.
    dweights: std::sync::RwLock<Array<F, IxDyn>>,
    /// Gradient of the loss with respect to the biases, filled in by `backward`.
    dbiases: std::sync::RwLock<Array<F, IxDyn>>,
    /// Optional activation applied to the layer output.
    activation: Option<Box<dyn Activation<F> + Send + Sync>>,
    /// Input cached during the forward pass for use in the backward pass.
    input: std::sync::RwLock<Option<Array<F, IxDyn>>>,
    /// Pre-activation output cached during the forward pass.
    output_pre_activation: std::sync::RwLock<Option<Array<F, IxDyn>>>,
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> std::fmt::Debug for Dense<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Dense")
            .field("input_dim", &self.input_dim)
            .field("output_dim", &self.output_dim)
            .field("weights_shape", &self.weights.shape())
            .field("biases_shape", &self.biases.shape())
            .field("has_activation", &self.activation.is_some())
            .finish()
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Clone for Dense<F> {
    fn clone(&self) -> Self {
        Self {
            input_dim: self.input_dim,
            output_dim: self.output_dim,
            weights: self.weights.clone(),
            biases: self.biases.clone(),
            dweights: std::sync::RwLock::new(self.dweights.read().unwrap().clone()),
            dbiases: std::sync::RwLock::new(self.dbiases.read().unwrap().clone()),
            // The boxed activation trait object cannot be cloned, so clones start
            // without an activation; set one again if it is needed.
            activation: None,
            input: std::sync::RwLock::new(self.input.read().unwrap().clone()),
            output_pre_activation: std::sync::RwLock::new(
                self.output_pre_activation.read().unwrap().clone(),
            ),
        }
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Dense<F> {
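    /// Creates a new dense layer.
    ///
    /// `activation_name` selects an optional activation ("relu", "sigmoid",
    /// "tanh", "softmax", or "gelu"); unrecognized names leave the layer linear.
    /// Weights are drawn uniformly from [-1, 1] and scaled by 1/sqrt(input_dim);
    /// biases start at zero.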
    pub fn new<R: scirs2_core::random::Rng + scirs2_core::random::RngCore>(
        input_dim: usize,
        output_dim: usize,
        activation_name: Option<&str>,
        rng: &mut R,
    ) -> Result<Self> {
        let activation = if let Some(name) = activation_name {
            match name.to_lowercase().as_str() {
                "relu" => Some(Box::new(crate::activations_minimal::ReLU::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                "sigmoid" => Some(Box::new(crate::activations_minimal::Sigmoid::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                "tanh" => Some(Box::new(crate::activations_minimal::Tanh::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                "softmax" => Some(Box::new(crate::activations_minimal::Softmax::new(-1))
                    as Box<dyn Activation<F> + Send + Sync>),
                "gelu" => Some(Box::new(crate::activations_minimal::GELU::new())
                    as Box<dyn Activation<F> + Send + Sync>),
                _ => None,
            }
        } else {
            None
        };

        let scale = F::from(1.0 / f64::sqrt(input_dim as f64)).ok_or_else(|| {
            NeuralError::InvalidArchitecture("Failed to convert scale factor".to_string())
        })?;

        let uniform = Uniform::new(-1.0, 1.0).map_err(|e| {
            NeuralError::InvalidArchitecture(format!("Failed to create uniform distribution: {e}"))
        })?;
        let weights_vec: Vec<F> = (0..(input_dim * output_dim))
            .map(|_| {
                let val = F::from(uniform.sample(rng)).ok_or_else(|| {
                    NeuralError::InvalidArchitecture("Failed to convert random value".to_string())
                });
                val.map(|v| v * scale).unwrap_or_else(|_| F::zero())
            })
            .collect();

        let weights =
            Array::from_shape_vec(IxDyn(&[input_dim, output_dim]), weights_vec).map_err(|e| {
                NeuralError::InvalidArchitecture(format!("Failed to create weights array: {e}"))
            })?;

        let biases = Array::zeros(IxDyn(&[output_dim]));

        let dweights = std::sync::RwLock::new(Array::zeros(weights.dim()));
        let dbiases = std::sync::RwLock::new(Array::zeros(biases.dim()));

        Ok(Self {
            input_dim,
            output_dim,
            weights,
            biases,
            dweights,
            dbiases,
            activation,
            input: std::sync::RwLock::new(None),
            output_pre_activation: std::sync::RwLock::new(None),
        })
    }

    /// Returns the input dimension of the layer.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// Returns the output dimension of the layer.
    pub fn output_dim(&self) -> usize {
        self.output_dim
    }

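    /// Computes `input · weights + biases` for a 2-D input of shape
    /// `[batch_size, input_dim]` using a straightforward triple loop.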
    fn compute_forward(&self, input: &Array<F, IxDyn>) -> Result<Array<F, IxDyn>> {
        let batch_size = input.shape()[0];
        let mut output = Array::zeros(IxDyn(&[batch_size, self.output_dim]));

        for batch in 0..batch_size {
            for out_idx in 0..self.output_dim {
                let mut sum = F::zero();
                for in_idx in 0..self.input_dim {
                    sum = sum + input[[batch, in_idx]] * self.weights[[in_idx, out_idx]];
                }
                output[[batch, out_idx]] = sum + self.biases[out_idx];
            }
        }

        Ok(output)
    }
}

impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> Layer<F> for Dense<F> {
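    // Forward pass: caches the input and the pre-activation output so that
    // `backward` can compute gradients, then applies the optional activation.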
    fn forward(
        &self,
        input: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        {
            let mut input_cache = self.input.write().unwrap();
            *input_cache = Some(input.clone());
        }

        let input_2d = if input.ndim() == 1 {
            input
                .clone()
                .into_shape_with_order(IxDyn(&[1, self.input_dim]))
                .map_err(|e| NeuralError::InferenceError(format!("Failed to reshape input: {e}")))?
        } else {
            input.clone()
        };

        if input_2d.shape()[1] != self.input_dim {
            return Err(NeuralError::InvalidArgument(format!(
                "Input dimension mismatch: expected {}, got {}",
                self.input_dim,
                input_2d.shape()[1]
            )));
        }

        let output = self.compute_forward(&input_2d)?;

        {
            let mut pre_activation_cache = self.output_pre_activation.write().unwrap();
            *pre_activation_cache = Some(output.clone());
        }

        if let Some(ref activation) = self.activation {
            activation.forward(&output)
        } else {
            Ok(output)
        }
    }

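    // Backward pass: uses the cached input and pre-activation output to compute
    // dL/dW and dL/db (stored for `update`) and returns dL/dinput.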
    fn backward(
        &self,
        _input: &Array<F, scirs2_core::ndarray::IxDyn>,
        grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
    ) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
        let cached_input = {
            let cache = self.input.read().unwrap();
            cache.clone().ok_or_else(|| {
                NeuralError::InferenceError("No cached input for backward pass".to_string())
            })?
        };

        let pre_activation = {
            let cache = self.output_pre_activation.read().unwrap();
            cache.clone().ok_or_else(|| {
                NeuralError::InferenceError(
                    "No cached pre-activation output for backward pass".to_string(),
                )
            })?
        };

        let grad_pre_activation = if let Some(ref activation) = self.activation {
            activation.backward(grad_output, &pre_activation)?
        } else {
            grad_output.clone()
        };

        let grad_2d = if grad_pre_activation.ndim() == 1 {
            grad_pre_activation
                .into_shape_with_order(IxDyn(&[1, self.output_dim]))
                .map_err(|e| {
                    NeuralError::InferenceError(format!("Failed to reshape gradient: {e}"))
                })?
        } else {
            grad_pre_activation
        };

        let input_2d = if cached_input.ndim() == 1 {
            cached_input
                .into_shape_with_order(IxDyn(&[1, self.input_dim]))
                .map_err(|e| {
                    NeuralError::InferenceError(format!("Failed to reshape cached input: {e}"))
                })?
        } else {
            cached_input
        };

        let batch_size = grad_2d.shape()[0];

        let mut dweights = Array::zeros(IxDyn(&[self.input_dim, self.output_dim]));
        for i in 0..self.input_dim {
            for j in 0..self.output_dim {
                let mut sum = F::zero();
                for b in 0..batch_size {
                    sum = sum + input_2d[[b, i]] * grad_2d[[b, j]];
                }
                dweights[[i, j]] = sum;
            }
        }

        let mut dbiases = Array::zeros(IxDyn(&[self.output_dim]));
        for j in 0..self.output_dim {
            let mut sum = F::zero();
            for b in 0..batch_size {
                sum = sum + grad_2d[[b, j]];
            }
            dbiases[j] = sum;
        }

        {
            let mut dweights_guard = self.dweights.write().unwrap();
            *dweights_guard = dweights;
        }
        {
            let mut dbiases_guard = self.dbiases.write().unwrap();
            *dbiases_guard = dbiases;
        }

        let mut grad_input = Array::zeros(IxDyn(&[batch_size, self.input_dim]));
        for b in 0..batch_size {
            for i in 0..self.input_dim {
                let mut sum = F::zero();
                for j in 0..self.output_dim {
                    sum = sum + grad_2d[[b, j]] * self.weights[[i, j]];
                }
                grad_input[[b, i]] = sum;
            }
        }

        Ok(grad_input)
    }

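    // Applies one step of plain gradient descent using the gradients cached by
    // the most recent call to `backward`.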
    fn update(&mut self, learning_rate: F) -> Result<()> {
        let dweights = {
            let dweights_guard = self.dweights.read().unwrap();
            dweights_guard.clone()
        };
        let dbiases = {
            let dbiases_guard = self.dbiases.read().unwrap();
            dbiases_guard.clone()
        };

        for i in 0..self.input_dim {
            for j in 0..self.output_dim {
                self.weights[[i, j]] = self.weights[[i, j]] - learning_rate * dweights[[i, j]];
            }
        }

        for j in 0..self.output_dim {
            self.biases[j] = self.biases[j] - learning_rate * dbiases[j];
        }

        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    fn layer_type(&self) -> &str {
        "Dense"
    }

    fn parameter_count(&self) -> usize {
        self.weights.len() + self.biases.len()
    }

    fn layer_description(&self) -> String {
        format!(
            "type:Dense, input_dim:{}, output_dim:{}, params:{}",
            self.input_dim,
            self.output_dim,
            self.parameter_count()
        )
    }
}

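// ParamLayer exposes the weights, biases, and their cached gradients so that
// external optimizers can read and overwrite them.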
impl<F: Float + Debug + ScalarOperand + Send + Sync + 'static> ParamLayer<F> for Dense<F> {
    fn get_parameters(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        vec![self.weights.clone(), self.biases.clone()]
    }

    fn get_gradients(&self) -> Vec<Array<F, scirs2_core::ndarray::IxDyn>> {
        // Return the gradients cached by `backward`, in the same order as
        // `get_parameters` (weights first, then biases).
        vec![
            self.dweights.read().unwrap().clone(),
            self.dbiases.read().unwrap().clone(),
        ]
    }

    fn set_parameters(&mut self, params: Vec<Array<F, scirs2_core::ndarray::IxDyn>>) -> Result<()> {
        if params.len() != 2 {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Expected 2 parameters (weights, biases), got {}",
                params.len()
            )));
        }

        let weights = &params[0];
        let biases = &params[1];

        if weights.shape() != self.weights.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Weights shape mismatch: expected {:?}, got {:?}",
                self.weights.shape(),
                weights.shape()
            )));
        }

        if biases.shape() != self.biases.shape() {
            return Err(NeuralError::InvalidArchitecture(format!(
                "Biases shape mismatch: expected {:?}, got {:?}",
                self.biases.shape(),
                biases.shape()
            )));
        }

        self.weights = weights.clone();
        self.biases = biases.clone();

        Ok(())
    }
}

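// SAFETY: every field is Send + Sync (owned ndarray arrays, std RwLocks, and an
// activation box that is itself bounded by Send + Sync), so these impls do not
// grant anything beyond what the compiler would derive automatically.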
unsafe impl<F: Float + Debug + Send + Sync> Send for Dense<F> {}
unsafe impl<F: Float + Debug + Send + Sync> Sync for Dense<F> {}