1use ::ndarray::{Array, Ix1};
13
14use rand::{Rng, RngExt};
15
16use crate::array_protocol::ml_ops::ActivationFunc;
17use crate::array_protocol::operations::OperationError;
18use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};
19
/// A neural-network layer: a named, trainable transformation over
/// `ArrayProtocol` arrays with a train/eval mode toggle.
pub trait Layer: Send + Sync {
    /// Static identifier for the layer kind (e.g. "Linear", "Conv2D").
    fn layer_type(&self) -> &str;

    /// Runs the layer's computation on `inputs`, returning a new array.
    ///
    /// # Errors
    /// Propagates any `OperationError` from the underlying array ops.
    fn forward(&self, inputs: &dyn ArrayProtocol)
        -> Result<Box<dyn ArrayProtocol>, OperationError>;

    /// Clones of all trainable parameters, in the same order as
    /// `parameter_names`.
    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>>;

    /// Mutable references to the trainable parameters (same order as
    /// `parameters`).
    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>>;

    /// Replaces the parameter called `name` with `value`.
    ///
    /// # Errors
    /// Returns `OperationError::Other` when `name` is not a parameter of
    /// this layer.
    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError>;

    /// Names of the trainable parameters, aligned with `parameters`.
    fn parameter_names(&self) -> Vec<String>;

    /// Puts the layer into training mode.
    fn train(&mut self);

    /// Puts the layer into inference (evaluation) mode.
    fn eval(&mut self);

    /// Whether the layer is currently in training mode.
    fn is_training(&self) -> bool;

    /// The instance name given at construction.
    fn name(&self) -> &str;
}
57
/// A fully-connected (dense) layer: `activation(weights · x + bias)`.
pub struct Linear {
    /// Instance name.
    name: String,

    /// Weight matrix; `forward` computes `matmul(weights, inputs)`.
    weights: Box<dyn ArrayProtocol>,

    /// Optional bias added after the matrix product.
    bias: Option<Box<dyn ArrayProtocol>>,

    /// Optional activation applied last.
    activation: Option<ActivationFunc>,

    /// Training-mode flag, toggled by `train`/`eval`.
    training: bool,
}
75
76impl Linear {
77 pub fn new(
79 name: &str,
80 weights: Box<dyn ArrayProtocol>,
81 bias: Option<Box<dyn ArrayProtocol>>,
82 activation: Option<ActivationFunc>,
83 ) -> Self {
84 Self {
85 name: name.to_string(),
86 weights,
87 bias,
88 activation,
89 training: true,
90 }
91 }
92
93 pub fn new_random(
95 name: &str,
96 in_features: usize,
97 out_features: usize,
98 withbias: bool,
99 activation: Option<ActivationFunc>,
100 ) -> Self {
101 let scale = (6.0 / (in_features + out_features) as f64).sqrt();
103 let mut rng = rand::rng();
104 let weights = Array::from_shape_fn((out_features, in_features), |_| {
105 (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
106 });
107
108 let bias = if withbias {
110 let bias_array: Array<f64, Ix1> = Array::zeros(out_features);
111 Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
112 } else {
113 None
114 };
115
116 Self {
117 name: name.to_string(),
118 weights: Box::new(NdarrayWrapper::new(weights)),
119 bias,
120 activation,
121 training: true,
122 }
123 }
124}
125
126impl Layer for Linear {
127 fn layer_type(&self) -> &str {
128 "Linear"
129 }
130
131 fn forward(
132 &self,
133 inputs: &dyn ArrayProtocol,
134 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
135 let mut result = crate::array_protocol::matmul(self.weights.as_ref(), inputs)?;
137
138 if let Some(bias) = &self.bias {
140 let intermediate = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
142 result = intermediate;
143 }
144
145 if let Some(act_fn) = self.activation {
147 let intermediate = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
149 result = intermediate;
150 }
151
152 Ok(result)
153 }
154
155 fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
156 let mut params = vec![self.weights.clone()];
157 if let Some(bias) = &self.bias {
158 params.push(bias.clone());
159 }
160 params
161 }
162
163 fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
164 let mut params = vec![&mut self.weights];
165 if let Some(bias) = &mut self.bias {
166 params.push(bias);
167 }
168 params
169 }
170
171 fn update_parameter(
172 &mut self,
173 name: &str,
174 value: Box<dyn ArrayProtocol>,
175 ) -> Result<(), OperationError> {
176 match name {
177 "weights" => {
178 self.weights = value;
179 Ok(())
180 }
181 "bias" => {
182 self.bias = Some(value);
183 Ok(())
184 }
185 _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
186 }
187 }
188
189 fn parameter_names(&self) -> Vec<String> {
190 let mut names = vec!["weights".to_string()];
191 if self.bias.is_some() {
192 names.push("bias".to_string());
193 }
194 names
195 }
196
197 fn train(&mut self) {
198 self.training = true;
199 }
200
201 fn eval(&mut self) {
202 self.training = false;
203 }
204
205 fn is_training(&self) -> bool {
206 self.training
207 }
208
209 fn name(&self) -> &str {
210 &self.name
211 }
212}
213
/// A 2-D convolution layer: `activation(conv2d(x, filters) + bias)`.
pub struct Conv2D {
    /// Instance name.
    name: String,

    /// Convolution filters; `withshape` creates them with shape
    /// (filter_height, filter_width, in_channels, out_channels).
    filters: Box<dyn ArrayProtocol>,

    /// Optional bias added after the convolution.
    bias: Option<Box<dyn ArrayProtocol>>,

    /// Convolution stride (rows, cols).
    stride: (usize, usize),

    /// Zero-padding (rows, cols).
    padding: (usize, usize),

    /// Optional activation applied last.
    activation: Option<ActivationFunc>,

    /// Training-mode flag, toggled by `train`/`eval`.
    training: bool,
}
237
238impl Conv2D {
239 pub fn new(
241 name: &str,
242 filters: Box<dyn ArrayProtocol>,
243 bias: Option<Box<dyn ArrayProtocol>>,
244 stride: (usize, usize),
245 padding: (usize, usize),
246 activation: Option<ActivationFunc>,
247 ) -> Self {
248 Self {
249 name: name.to_string(),
250 filters,
251 bias,
252 stride,
253 padding,
254 activation,
255 training: true,
256 }
257 }
258
259 #[allow(clippy::too_many_arguments)]
261 pub fn withshape(
262 name: &str,
263 filter_height: usize,
264 filter_width: usize,
265 in_channels: usize,
266 out_channels: usize,
267 stride: (usize, usize),
268 padding: (usize, usize),
269 withbias: bool,
270 activation: Option<ActivationFunc>,
271 ) -> Self {
272 let fan_in = filter_height * filter_width * in_channels;
274 let scale = (2.0 / fan_in as f64).sqrt();
275 let mut rng = rand::rng();
276 let filters = Array::from_shape_fn(
277 (filter_height, filter_width, in_channels, out_channels),
278 |_| (rng.random::<f64>() * 2.0_f64 - 1.0) * scale,
279 );
280
281 let bias = if withbias {
283 let bias_array: Array<f64, Ix1> = Array::zeros(out_channels);
284 Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
285 } else {
286 None
287 };
288
289 Self {
290 name: name.to_string(),
291 filters: Box::new(NdarrayWrapper::new(filters)),
292 bias,
293 stride,
294 padding,
295 activation,
296 training: true,
297 }
298 }
299}
300
301impl Layer for Conv2D {
302 fn layer_type(&self) -> &str {
303 "Conv2D"
304 }
305
306 fn forward(
307 &self,
308 inputs: &dyn ArrayProtocol,
309 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
310 let mut result = crate::array_protocol::ml_ops::conv2d(
312 inputs,
313 self.filters.as_ref(),
314 self.stride,
315 self.padding,
316 )?;
317
318 if let Some(bias) = &self.bias {
320 result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
321 }
322
323 if let Some(act_fn) = self.activation {
325 result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
326 }
327
328 Ok(result)
329 }
330
331 fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
332 let mut params = vec![self.filters.clone()];
333 if let Some(bias) = &self.bias {
334 params.push(bias.clone());
335 }
336 params
337 }
338
339 fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
340 let mut params = vec![&mut self.filters];
341 if let Some(bias) = &mut self.bias {
342 params.push(bias);
343 }
344 params
345 }
346
347 fn update_parameter(
348 &mut self,
349 name: &str,
350 value: Box<dyn ArrayProtocol>,
351 ) -> Result<(), OperationError> {
352 match name {
353 "filters" => {
354 self.filters = value;
355 Ok(())
356 }
357 "bias" => {
358 self.bias = Some(value);
359 Ok(())
360 }
361 _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
362 }
363 }
364
365 fn parameter_names(&self) -> Vec<String> {
366 let mut names = vec!["filters".to_string()];
367 if self.bias.is_some() {
368 names.push("bias".to_string());
369 }
370 names
371 }
372
373 fn train(&mut self) {
374 self.training = true;
375 }
376
377 fn eval(&mut self) {
378 self.training = false;
379 }
380
381 fn is_training(&self) -> bool {
382 self.training
383 }
384
385 fn name(&self) -> &str {
386 &self.name
387 }
388}
389
/// Fluent builder for `Conv2D`; `build` delegates to `Conv2D::withshape`
/// with randomly initialized filters.
pub struct Conv2DBuilder {
    /// Name given to the built layer.
    name: String,
    /// Filter (kernel) height; default 3.
    filter_height: usize,
    /// Filter (kernel) width; default 3.
    filter_width: usize,
    /// Input channel count; default 1.
    in_channels: usize,
    /// Output channel count; default 1.
    out_channels: usize,
    /// Convolution stride; default (1, 1).
    stride: (usize, usize),
    /// Zero-padding; default (0, 0).
    padding: (usize, usize),
    /// Whether the built layer gets a bias; default true.
    withbias: bool,
    /// Optional activation; default none.
    activation: Option<ActivationFunc>,
}
402
403impl Conv2DBuilder {
404 pub fn new(name: &str) -> Self {
406 Self {
407 name: name.to_string(),
408 filter_height: 3,
409 filter_width: 3,
410 in_channels: 1,
411 out_channels: 1,
412 stride: (1, 1),
413 padding: (0, 0),
414 withbias: true,
415 activation: None,
416 }
417 }
418
419 pub const fn filter_size(mut self, height: usize, width: usize) -> Self {
421 self.filter_height = height;
422 self.filter_width = width;
423 self
424 }
425
426 pub const fn channels(mut self, input: usize, output: usize) -> Self {
428 self.in_channels = input;
429 self.out_channels = output;
430 self
431 }
432
433 pub fn stride(mut self, stride: (usize, usize)) -> Self {
435 self.stride = stride;
436 self
437 }
438
439 pub fn padding(mut self, padding: (usize, usize)) -> Self {
441 self.padding = padding;
442 self
443 }
444
445 pub fn withbias(mut self, withbias: bool) -> Self {
447 self.withbias = withbias;
448 self
449 }
450
451 pub fn activation(mut self, activation: ActivationFunc) -> Self {
453 self.activation = Some(activation);
454 self
455 }
456
457 pub fn build(self) -> Conv2D {
459 Conv2D::withshape(
460 &self.name,
461 self.filter_height,
462 self.filter_width,
463 self.in_channels,
464 self.out_channels,
465 self.stride,
466 self.padding,
467 self.withbias,
468 self.activation,
469 )
470 }
471}
472
/// A 2-D max-pooling layer. Holds no trainable parameters.
#[allow(dead_code)]
pub struct MaxPool2D {
    /// Instance name.
    name: String,

    /// Pooling window size (rows, cols).
    kernel_size: (usize, usize),

    /// Pooling stride; defaults to `kernel_size` when not given to `new`.
    stride: (usize, usize),

    /// Zero-padding (rows, cols).
    padding: (usize, usize),

    /// Training-mode flag; tracked but not read by `forward`.
    training: bool,
}
491
492impl MaxPool2D {
493 pub fn new(
495 name: &str,
496 kernel_size: (usize, usize),
497 stride: Option<(usize, usize)>,
498 padding: (usize, usize),
499 ) -> Self {
500 let stride = stride.unwrap_or(kernel_size);
501
502 Self {
503 name: name.to_string(),
504 kernel_size,
505 stride,
506 padding,
507 training: true,
508 }
509 }
510}
511
512impl Layer for MaxPool2D {
513 fn layer_type(&self) -> &str {
514 "MaxPool2D"
515 }
516
517 fn forward(
518 &self,
519 inputs: &dyn ArrayProtocol,
520 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
521 crate::array_protocol::ml_ops::max_pool2d(
523 inputs,
524 self.kernel_size,
525 self.stride,
526 self.padding,
527 )
528 }
529
530 fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
531 Vec::new()
533 }
534
535 fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
536 Vec::new()
538 }
539
540 fn update_parameter(
541 &mut self,
542 name: &str,
543 _value: Box<dyn ArrayProtocol>,
544 ) -> Result<(), OperationError> {
545 Err(OperationError::Other(format!(
546 "MaxPool2D has no parameter: {name}"
547 )))
548 }
549
550 fn parameter_names(&self) -> Vec<String> {
551 Vec::new()
553 }
554
555 fn train(&mut self) {
556 self.training = true;
557 }
558
559 fn eval(&mut self) {
560 self.training = false;
561 }
562
563 fn is_training(&self) -> bool {
564 self.training
565 }
566
567 fn name(&self) -> &str {
568 &self.name
569 }
570}
571
/// A batch-normalization layer with learned scale/offset and stored
/// running statistics.
pub struct BatchNorm {
    /// Instance name.
    name: String,

    /// Learned per-feature scale (gamma).
    scale: Box<dyn ArrayProtocol>,

    /// Learned per-feature offset (beta).
    offset: Box<dyn ArrayProtocol>,

    /// Running mean used for normalization in `forward`.
    running_mean: Box<dyn ArrayProtocol>,

    /// Running variance used for normalization in `forward`.
    running_var: Box<dyn ArrayProtocol>,

    /// Small constant added to the variance for numerical stability.
    epsilon: f64,

    /// Training-mode flag; tracked but not read by `forward`.
    training: bool,
}
595
596impl BatchNorm {
597 pub fn new(
599 name: &str,
600 scale: Box<dyn ArrayProtocol>,
601 offset: Box<dyn ArrayProtocol>,
602 running_mean: Box<dyn ArrayProtocol>,
603 running_var: Box<dyn ArrayProtocol>,
604 epsilon: f64,
605 ) -> Self {
606 Self {
607 name: name.to_string(),
608 scale,
609 offset,
610 running_mean,
611 running_var,
612 epsilon,
613 training: true,
614 }
615 }
616
617 pub fn withshape(
619 name: &str,
620 num_features: usize,
621 epsilon: Option<f64>,
622 _momentum: Option<f64>,
623 ) -> Self {
624 let scale: Array<f64, Ix1> = Array::ones(num_features);
626 let offset: Array<f64, Ix1> = Array::zeros(num_features);
627 let running_mean: Array<f64, Ix1> = Array::zeros(num_features);
628 let running_var: Array<f64, Ix1> = Array::ones(num_features);
629
630 Self {
631 name: name.to_string(),
632 scale: Box::new(NdarrayWrapper::new(scale)),
633 offset: Box::new(NdarrayWrapper::new(offset)),
634 running_mean: Box::new(NdarrayWrapper::new(running_mean)),
635 running_var: Box::new(NdarrayWrapper::new(running_var)),
636 epsilon: epsilon.unwrap_or(1e-5),
637 training: true,
638 }
639 }
640}
641
642impl Layer for BatchNorm {
643 fn layer_type(&self) -> &str {
644 "BatchNorm"
645 }
646
647 fn forward(
648 &self,
649 inputs: &dyn ArrayProtocol,
650 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
651 crate::array_protocol::ml_ops::batch_norm(
652 inputs,
653 self.scale.as_ref(),
654 self.offset.as_ref(),
655 self.running_mean.as_ref(),
656 self.running_var.as_ref(),
657 self.epsilon,
658 )
659 }
660
661 fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
662 vec![self.scale.clone(), self.offset.clone()]
663 }
664
665 fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
666 vec![&mut self.scale, &mut self.offset]
667 }
668
669 fn update_parameter(
670 &mut self,
671 name: &str,
672 value: Box<dyn ArrayProtocol>,
673 ) -> Result<(), OperationError> {
674 match name {
675 "scale" => {
676 self.scale = value;
677 Ok(())
678 }
679 "offset" => {
680 self.offset = value;
681 Ok(())
682 }
683 _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
684 }
685 }
686
687 fn parameter_names(&self) -> Vec<String> {
688 vec!["scale".to_string(), "offset".to_string()]
689 }
690
691 fn train(&mut self) {
692 self.training = true;
693 }
694
695 fn eval(&mut self) {
696 self.training = false;
697 }
698
699 fn is_training(&self) -> bool {
700 self.training
701 }
702
703 fn name(&self) -> &str {
704 &self.name
705 }
706}
707
/// A dropout layer; the drop behavior itself lives in `ml_ops::dropout`,
/// which receives the rate, training flag, and seed from `forward`.
pub struct Dropout {
    /// Instance name.
    name: String,

    /// Drop probability passed to `ml_ops::dropout`.
    rate: f64,

    /// Optional RNG seed for a reproducible mask.
    seed: Option<u64>,

    /// Training-mode flag, forwarded to `ml_ops::dropout`.
    training: bool,
}
722
723impl Dropout {
724 pub fn new(name: &str, rate: f64, seed: Option<u64>) -> Self {
726 Self {
727 name: name.to_string(),
728 rate,
729 seed,
730 training: true,
731 }
732 }
733}
734
735impl Layer for Dropout {
736 fn layer_type(&self) -> &str {
737 "Dropout"
738 }
739
740 fn forward(
741 &self,
742 inputs: &dyn ArrayProtocol,
743 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
744 crate::array_protocol::ml_ops::dropout(inputs, self.rate, self.training, self.seed)
745 }
746
747 fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
748 Vec::new()
750 }
751
752 fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
753 Vec::new()
755 }
756
757 fn update_parameter(
758 &mut self,
759 name: &str,
760 _value: Box<dyn ArrayProtocol>,
761 ) -> Result<(), OperationError> {
762 Err(OperationError::Other(format!(
763 "Dropout has no parameter: {name}"
764 )))
765 }
766
767 fn parameter_names(&self) -> Vec<String> {
768 Vec::new()
770 }
771
772 fn train(&mut self) {
773 self.training = true;
774 }
775
776 fn eval(&mut self) {
777 self.training = false;
778 }
779
780 fn is_training(&self) -> bool {
781 self.training
782 }
783
784 fn name(&self) -> &str {
785 &self.name
786 }
787}
788
/// Multi-head self-attention with learned input projections (wq, wk, wv)
/// and an output projection (wo).
pub struct MultiHeadAttention {
    /// Instance name.
    name: String,

    /// Query projection weights.
    wq: Box<dyn ArrayProtocol>,

    /// Key projection weights.
    wk: Box<dyn ArrayProtocol>,

    /// Value projection weights.
    wv: Box<dyn ArrayProtocol>,

    /// Output projection weights.
    wo: Box<dyn ArrayProtocol>,

    /// Number of attention heads; `with_params` asserts it divides `dmodel`.
    num_heads: usize,

    /// Model (embedding) dimension.
    dmodel: usize,

    /// Training-mode flag; tracked but not read by `forward`.
    training: bool,
}
815
816impl MultiHeadAttention {
817 pub fn new(
819 name: &str,
820 wq: Box<dyn ArrayProtocol>,
821 wk: Box<dyn ArrayProtocol>,
822 wv: Box<dyn ArrayProtocol>,
823 wo: Box<dyn ArrayProtocol>,
824 num_heads: usize,
825 dmodel: usize,
826 ) -> Self {
827 Self {
828 name: name.to_string(),
829 wq,
830 wk,
831 wv,
832 wo,
833 num_heads,
834 dmodel,
835 training: true,
836 }
837 }
838
839 pub fn with_params(name: &str, num_heads: usize, dmodel: usize) -> Self {
841 assert!(
843 dmodel % num_heads == 0,
844 "dmodel must be divisible by num_heads"
845 );
846
847 let scale = (1.0_f64 / dmodel as f64).sqrt();
849 let mut rng = rand::rng();
850
851 let wq = Array::from_shape_fn((dmodel, dmodel), |_| {
852 (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
853 });
854
855 let wk = Array::from_shape_fn((dmodel, dmodel), |_| {
856 (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
857 });
858
859 let wv = Array::from_shape_fn((dmodel, dmodel), |_| {
860 (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
861 });
862
863 let wo = Array::from_shape_fn((dmodel, dmodel), |_| {
864 (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
865 });
866
867 Self {
868 name: name.to_string(),
869 wq: Box::new(NdarrayWrapper::new(wq)),
870 wk: Box::new(NdarrayWrapper::new(wk)),
871 wv: Box::new(NdarrayWrapper::new(wv)),
872 wo: Box::new(NdarrayWrapper::new(wo)),
873 num_heads,
874 dmodel,
875 training: true,
876 }
877 }
878}
879
880impl Layer for MultiHeadAttention {
881 fn layer_type(&self) -> &str {
882 "MultiHeadAttention"
883 }
884
885 fn forward(
886 &self,
887 inputs: &dyn ArrayProtocol,
888 ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
889 let queries = crate::array_protocol::matmul(self.wq.as_ref(), inputs)?;
897 let keys = crate::array_protocol::matmul(self.wk.as_ref(), inputs)?;
898 let values = crate::array_protocol::matmul(self.wv.as_ref(), inputs)?;
899
900 let attention = crate::array_protocol::ml_ops::self_attention(
902 queries.as_ref(),
903 keys.as_ref(),
904 values.as_ref(),
905 None,
906 Some((self.dmodel / self.num_heads) as f64),
907 )?;
908
909 let output = crate::array_protocol::matmul(self.wo.as_ref(), attention.as_ref())?;
911
912 Ok(output)
913 }
914
915 fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
916 vec![
917 self.wq.clone(),
918 self.wk.clone(),
919 self.wv.clone(),
920 self.wo.clone(),
921 ]
922 }
923
924 fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
925 vec![&mut self.wq, &mut self.wk, &mut self.wv, &mut self.wo]
926 }
927
928 fn update_parameter(
929 &mut self,
930 name: &str,
931 value: Box<dyn ArrayProtocol>,
932 ) -> Result<(), OperationError> {
933 match name {
934 "wq" => {
935 self.wq = value;
936 Ok(())
937 }
938 "wk" => {
939 self.wk = value;
940 Ok(())
941 }
942 "wv" => {
943 self.wv = value;
944 Ok(())
945 }
946 "wo" => {
947 self.wo = value;
948 Ok(())
949 }
950 _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
951 }
952 }
953
954 fn parameter_names(&self) -> Vec<String> {
955 vec![
956 "wq".to_string(),
957 "wk".to_string(),
958 "wv".to_string(),
959 "wo".to_string(),
960 ]
961 }
962
963 fn train(&mut self) {
964 self.training = true;
965 }
966
967 fn eval(&mut self) {
968 self.training = false;
969 }
970
971 fn is_training(&self) -> bool {
972 self.training
973 }
974
975 fn name(&self) -> &str {
976 &self.name
977 }
978}
979
/// A container that runs its layers in insertion order.
pub struct Sequential {
    /// Model name.
    name: String,

    /// Layers, applied front to back by `forward`.
    layers: Vec<Box<dyn Layer>>,

    /// Training-mode flag; `train`/`eval` also propagate to each layer.
    training: bool,
}
991
impl Sequential {
    /// Creates a container that will run `layers` in order.
    pub fn new(name: &str, layers: Vec<Box<dyn Layer>>) -> Self {
        Self {
            name: name.to_string(),
            layers,
            training: true,
        }
    }

    /// Appends a layer to the end of the pipeline.
    pub fn add_layer(&mut self, layer: Box<dyn Layer>) {
        self.layers.push(layer);
    }

    /// Feeds `inputs` through every layer in order; returns the final output.
    ///
    /// # Errors
    /// Propagates the first `OperationError` raised by any layer.
    pub fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Clone the input once so the loop can own each intermediate value.
        let mut x: Box<dyn ArrayProtocol> = inputs.box_clone();

        for layer in &self.layers {
            let x_ref: &dyn ArrayProtocol = x.as_ref();
            x = layer.forward(x_ref)?;
        }

        Ok(x)
    }

    /// Clones of every layer's parameters, concatenated in layer order.
    pub fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = Vec::new();

        for layer in &self.layers {
            params.extend(layer.parameters());
        }

        params
    }

    /// Switches the model and every layer to training mode.
    pub fn train(&mut self) {
        self.training = true;

        for layer in &mut self.layers {
            layer.train();
        }
    }

    /// Switches the model and every layer to evaluation mode.
    pub fn eval(&mut self) {
        self.training = false;

        for layer in &mut self.layers {
            layer.eval();
        }
    }

    /// The model's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Read-only view of the layer pipeline.
    pub fn layers(&self) -> &[Box<dyn Layer>] {
        &self.layers
    }

    /// Backpropagation stub: always returns an empty gradient dictionary.
    /// NOTE(review): no gradients are computed here — confirm whether real
    /// backprop is implemented elsewhere before relying on this.
    pub fn backward(
        &self,
        _output: &dyn ArrayProtocol,
        _target: &dyn ArrayProtocol,
    ) -> Result<crate::array_protocol::grad::GradientDict, crate::error::CoreError> {
        Ok(crate::array_protocol::grad::GradientDict::new())
    }

    /// Applies one gradient-descent step to a single parameter addressed
    /// as "layer_index.param_name": param -= learningrate * gradient.
    ///
    /// # Errors
    /// `CoreError::ValueError` for a malformed name, bad index, or unknown
    /// parameter; `CoreError::ComputationError` when the underlying array
    /// operations fail.
    pub fn update_parameter(
        &mut self,
        param_name: &str,
        gradient: &dyn ArrayProtocol,
        learningrate: f64,
    ) -> Result<(), crate::error::CoreError> {
        // Expected format: "<layer_index>.<param_name>".
        let parts: Vec<&str> = param_name.split('.').collect();
        if parts.len() != 2 {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Invalid parameter name format. Expected 'layer_index.param_name', got: {param_name}"
                )),
            ));
        }

        let layer_index: usize = parts[0].parse().map_err(|_| {
            crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                "Invalid layer index: {layer_idx}",
                layer_idx = parts[0]
            )))
        })?;

        // From here on, `param_name` is the bare parameter name.
        let param_name = parts[1];

        if layer_index >= self.layers.len() {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Layer index {layer_index} out of bounds (model has {num_layers} layers)",
                    num_layers = self.layers.len()
                )),
            ));
        }

        let layer = &mut self.layers[layer_index];
        // `parameters` and `parameter_names` share ordering, so the
        // position of the name identifies the current parameter value.
        let current_params = layer.parameters();
        let param_names = layer.parameter_names();

        let param_idx = param_names
            .iter()
            .position(|name| name == param_name)
            .ok_or_else(|| {
                crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                    "Parameter '{param_name}' not found in layer {layer_index}"
                )))
            })?;

        let current_param = &current_params[param_idx];

        // scaled_gradient = learningrate * gradient
        let scaled_gradient =
            crate::array_protocol::operations::multiply_by_scalar_f64(gradient, learningrate)
                .map_err(|e| {
                    crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(
                        format!("Failed to scale gradient: {e}"),
                    ))
                })?;

        // updated_param = current_param - scaled_gradient
        let updated_param = crate::array_protocol::operations::subtract(
            current_param.as_ref(),
            scaled_gradient.as_ref(),
        )
        .map_err(|e| {
            crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                "Failed to update parameter: {e}"
            )))
        })?;

        // Write the new value back into the layer.
        layer
            .update_parameter(param_name, updated_param)
            .map_err(|e| {
                crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                    "Failed to set parameter in layer: {e}"
                )))
            })?;

        Ok(())
    }

    /// All parameter names across layers, qualified as "layer_index.name"
    /// (the format `update_parameter` expects).
    pub fn all_parameter_names(&self) -> Vec<String> {
        let mut all_names = Vec::new();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let layer_param_names = layer.parameter_names();
            for param_name in layer_param_names {
                all_names.push(format!("{layer_idx}.{param_name}"));
            }
        }
        all_names
    }

    /// Clones of all parameters across layers (same result as `parameters`).
    pub fn all_parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut all_params = Vec::new();
        for layer in &self.layers {
            all_params.extend(layer.parameters());
        }
        all_params
    }
}
1181
/// Builds a small example CNN: two conv+pool stages, then two dense
/// layers with dropout in between, ending in `num_classes` outputs.
#[allow(dead_code)]
pub fn create_simple_cnn(inputshape: (usize, usize, usize), num_classes: usize) -> Sequential {
    // inputshape is (height, width, channels).
    let (height, width, channels) = inputshape;

    let mut model = Sequential::new("SimpleCNN", Vec::new());

    // Stage 1: 3x3 conv (channels -> 32), stride 1, padding 1, ReLU.
    model.add_layer(Box::new(Conv2D::withshape(
        "conv1",
        3,
        3, channels,
        32, (1, 1), (1, 1), true, Some(ActivationFunc::ReLU),
    )));

    // 2x2 max-pool; stride defaults to the kernel size.
    model.add_layer(Box::new(MaxPool2D::new(
        "pool1",
        (2, 2), None, (0, 0), )));

    // Stage 2: 3x3 conv (32 -> 64), stride 1, padding 1, ReLU.
    model.add_layer(Box::new(Conv2D::withshape(
        "conv2",
        3,
        3, 32,
        64, (1, 1), (1, 1), true, Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(MaxPool2D::new(
        "pool2",
        (2, 2), None, (0, 0), )));

    // Dense head. After two 2x2 pools the spatial dims are height/4 x
    // width/4 (assumes height and width divisible by 4).
    // NOTE(review): no explicit flatten layer precedes fc1 — confirm that
    // Linear::forward's matmul accepts the pooled 4-D output.
    model.add_layer(Box::new(Linear::new_random(
        "fc1",
        64 * (height / 4) * (width / 4), 128, true, Some(ActivationFunc::ReLU),
    )));

    // Drop half the activations during training.
    model.add_layer(Box::new(Dropout::new(
        "dropout", 0.5, None, )));

    // Output layer: one raw score per class (no activation).
    model.add_layer(Box::new(Linear::new_random(
        "fc2",
        128, num_classes, true, None, )));

    model
}
1255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::{Array1, Array2};

    /// A Linear layer built from explicit arrays reports its name and
    /// starts in training mode.
    #[test]
    fn test_linear_layer() {
        array_protocol::init();

        // Identity weights with a unit bias form a simple valid layer.
        let identity = Array2::<f64>::eye(3);
        let unit_bias = Array1::<f64>::ones(3);

        let layer = Linear::new(
            "linear",
            Box::new(NdarrayWrapper::new(identity)),
            Some(Box::new(NdarrayWrapper::new(unit_bias))),
            Some(ActivationFunc::ReLU),
        );

        assert_eq!(layer.name(), "linear");
        assert!(layer.is_training());
    }

    /// Layers added to a Sequential are tracked and the model defaults
    /// to training mode.
    #[test]
    fn test_sequential_model() {
        array_protocol::init();

        let mut model = Sequential::new("test_model", Vec::new());
        model.add_layer(Box::new(Linear::new_random(
            "fc1",
            3,
            2,
            true,
            Some(ActivationFunc::ReLU),
        )));
        model.add_layer(Box::new(Linear::new_random(
            "fc2",
            2,
            1,
            true,
            Some(ActivationFunc::Sigmoid),
        )));

        assert_eq!(model.name(), "test_model");
        assert_eq!(model.layers().len(), 2);
        assert!(model.training);
    }

    /// The example CNN has its seven layers and non-empty parameters.
    #[test]
    fn test_simple_cnn_creation() {
        array_protocol::init();

        let model = create_simple_cnn((28, 28, 1), 10);

        assert_eq!(model.layers().len(), 7);
        assert_eq!(model.name(), "SimpleCNN");

        let params = model.parameters();
        assert!(!params.is_empty());
    }
}
1339}