//! Neural network layers built on top of the array protocol.
//!
//! Provides a [`Layer`] trait plus common layer types (`Linear`, `Conv2D`,
//! `MaxPool2D`, `BatchNorm`, `Dropout`, `MultiHeadAttention`) and a
//! [`Sequential`] container for composing them into models.

use ndarray::{Array, Ix1};
use rand::Rng;

use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};

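/// A trainable neural-network layer: a named transformation from one
/// `ArrayProtocol` array to another, with inspectable parameters and a
/// training/evaluation mode switch.
///
/// A minimal sketch of driving a layer through the trait object interface;
/// it assumes this module's items are in scope and that
/// `crate::array_protocol::init()` has been called first:
///
/// ```ignore
/// let layer: Box<dyn Layer> = Box::new(Linear::new_random(
///     "fc", 4, 2, true, Some(ActivationFunc::ReLU),
/// ));
/// let x = NdarrayWrapper::new(ndarray::Array1::<f64>::zeros(4));
/// let y = layer.forward(&x)?;
/// assert_eq!(layer.parameter_names(), vec!["weights", "bias"]);
/// ```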
pub trait Layer: Send + Sync {
    /// Returns the layer type as a string (e.g. `"Linear"`).
    fn layer_type(&self) -> &str;

    /// Performs the layer's forward pass on `inputs`.
    fn forward(&self, inputs: &dyn ArrayProtocol)
        -> Result<Box<dyn ArrayProtocol>, OperationError>;

    /// Returns clones of the layer's trainable parameters.
    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>>;

    /// Returns mutable references to the layer's trainable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>>;

    /// Replaces the parameter named `name` with `value`.
    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError>;

    /// Returns the names of the layer's parameters.
    fn parameter_names(&self) -> Vec<String>;

    /// Puts the layer into training mode.
    fn train(&mut self);

    /// Puts the layer into evaluation mode.
    fn eval(&mut self);

    /// Returns `true` if the layer is in training mode.
    fn is_training(&self) -> bool;

    /// Returns the layer's name.
    fn name(&self) -> &str;
}

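/// A fully connected layer computing `activation(weights · x + bias)`.
///
/// A minimal construction sketch, assuming `crate::array_protocol::init()`
/// has been called:
///
/// ```ignore
/// // Xavier-initialized 3 -> 2 layer with bias and ReLU.
/// let layer = Linear::new_random("fc1", 3, 2, true, Some(ActivationFunc::ReLU));
/// let x = NdarrayWrapper::new(ndarray::Array1::<f64>::ones(3));
/// let y = layer.forward(&x)?;
/// ```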
pub struct Linear {
    /// Layer name.
    name: String,

    /// Weight matrix of shape `(out_features, in_features)`.
    weights: Box<dyn ArrayProtocol>,

    /// Optional bias vector of length `out_features`.
    bias: Option<Box<dyn ArrayProtocol>>,

    /// Optional activation applied after the affine transform.
    activation: Option<ActivationFunc>,

    /// Whether the layer is in training mode.
    training: bool,
}

impl Linear {
    /// Creates a new linear layer from existing weights and optional bias.
    pub fn new(
        name: &str,
        weights: Box<dyn ArrayProtocol>,
        bias: Option<Box<dyn ArrayProtocol>>,
        activation: Option<ActivationFunc>,
    ) -> Self {
        Self {
            name: name.to_string(),
            weights,
            bias,
            activation,
            training: true,
        }
    }

    /// Creates a linear layer with Xavier/Glorot-initialized weights.
    pub fn new_random(
        name: &str,
        in_features: usize,
        out_features: usize,
        withbias: bool,
        activation: Option<ActivationFunc>,
    ) -> Self {
        // Xavier/Glorot uniform initialization: U(-scale, scale) with
        // scale = sqrt(6 / (fan_in + fan_out)).
        let scale = (6.0 / (in_features + out_features) as f64).sqrt();
        let mut rng = rand::rng();
        let weights = Array::from_shape_fn((out_features, in_features), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });

        let bias = if withbias {
            let bias_array: Array<f64, Ix1> = Array::zeros(out_features);
            Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
        } else {
            None
        };

        Self {
            name: name.to_string(),
            weights: Box::new(NdarrayWrapper::new(weights)),
            bias,
            activation,
            training: true,
        }
    }
}

impl Layer for Linear {
    fn layer_type(&self) -> &str {
        "Linear"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Affine transform: weights · inputs.
        let mut result = crate::array_protocol::matmul(self.weights.as_ref(), inputs)?;

        // Add the bias if present.
        if let Some(bias) = &self.bias {
            result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
        }

        // Apply the activation function if configured.
        if let Some(act_fn) = self.activation {
            result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
        }

        Ok(result)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = vec![self.weights.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        let mut params = vec![&mut self.weights];
        if let Some(bias) = &mut self.bias {
            params.push(bias);
        }
        params
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "weights" => {
                self.weights = value;
                Ok(())
            }
            "bias" => {
                self.bias = Some(value);
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        let mut names = vec!["weights".to_string()];
        if self.bias.is_some() {
            names.push("bias".to_string());
        }
        names
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

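/// A 2D convolutional layer.
///
/// Filters are stored with shape
/// `(filter_height, filter_width, in_channels, out_channels)`.
///
/// A minimal sketch, assuming `crate::array_protocol::init()` has been
/// called:
///
/// ```ignore
/// // 3x3 convolution, 1 -> 16 channels, stride 1, padding 1, with ReLU.
/// let conv = Conv2D::withshape(
///     "conv1", 3, 3, 1, 16, (1, 1), (1, 1), true, Some(ActivationFunc::ReLU),
/// );
/// ```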
pub struct Conv2D {
    /// Layer name.
    name: String,

    /// Filter bank of shape `(filter_height, filter_width, in_channels, out_channels)`.
    filters: Box<dyn ArrayProtocol>,

    /// Optional bias vector of length `out_channels`.
    bias: Option<Box<dyn ArrayProtocol>>,

    /// Stride along (height, width).
    stride: (usize, usize),

    /// Zero padding along (height, width).
    padding: (usize, usize),

    /// Optional activation applied after the convolution.
    activation: Option<ActivationFunc>,

    /// Whether the layer is in training mode.
    training: bool,
}

impl Conv2D {
    /// Creates a new convolutional layer from existing filters and bias.
    pub fn new(
        name: &str,
        filters: Box<dyn ArrayProtocol>,
        bias: Option<Box<dyn ArrayProtocol>>,
        stride: (usize, usize),
        padding: (usize, usize),
        activation: Option<ActivationFunc>,
    ) -> Self {
        Self {
            name: name.to_string(),
            filters,
            bias,
            stride,
            padding,
            activation,
            training: true,
        }
    }

    /// Creates a convolutional layer with He-initialized filters.
    #[allow(clippy::too_many_arguments)]
    pub fn withshape(
        name: &str,
        filter_height: usize,
        filter_width: usize,
        in_channels: usize,
        out_channels: usize,
        stride: (usize, usize),
        padding: (usize, usize),
        withbias: bool,
        activation: Option<ActivationFunc>,
    ) -> Self {
        // He initialization: U(-scale, scale) with scale = sqrt(2 / fan_in).
        let fan_in = filter_height * filter_width * in_channels;
        let scale = (2.0 / fan_in as f64).sqrt();
        let mut rng = rand::rng();
        let filters = Array::from_shape_fn(
            (filter_height, filter_width, in_channels, out_channels),
            |_| (rng.random::<f64>() * 2.0 - 1.0) * scale,
        );

        let bias = if withbias {
            let bias_array: Array<f64, Ix1> = Array::zeros(out_channels);
            Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
        } else {
            None
        };

        Self {
            name: name.to_string(),
            filters: Box::new(NdarrayWrapper::new(filters)),
            bias,
            stride,
            padding,
            activation,
            training: true,
        }
    }
}

impl Layer for Conv2D {
    fn layer_type(&self) -> &str {
        "Conv2D"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut result = crate::array_protocol::ml_ops::conv2d(
            inputs,
            self.filters.as_ref(),
            self.stride,
            self.padding,
        )?;

        if let Some(bias) = &self.bias {
            result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
        }

        if let Some(act_fn) = self.activation {
            result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
        }

        Ok(result)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = vec![self.filters.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        let mut params = vec![&mut self.filters];
        if let Some(bias) = &mut self.bias {
            params.push(bias);
        }
        params
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "filters" => {
                self.filters = value;
                Ok(())
            }
            "bias" => {
                self.bias = Some(value);
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        let mut names = vec!["filters".to_string()];
        if self.bias.is_some() {
            names.push("bias".to_string());
        }
        names
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

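/// Builder for [`Conv2D`] with defaults of 3x3 filters, 1 -> 1 channels,
/// stride `(1, 1)`, no padding, bias enabled, and no activation.
///
/// ```ignore
/// let conv = Conv2DBuilder::new("conv1")
///     .filter_size(5, 5)
///     .channels(3, 32)
///     .stride((2, 2))
///     .activation(ActivationFunc::ReLU)
///     .build();
/// ```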
pub struct Conv2DBuilder {
    name: String,
    filter_height: usize,
    filter_width: usize,
    in_channels: usize,
    out_channels: usize,
    stride: (usize, usize),
    padding: (usize, usize),
    withbias: bool,
    activation: Option<ActivationFunc>,
}

impl Conv2DBuilder {
    /// Creates a builder with default settings.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            filter_height: 3,
            filter_width: 3,
            in_channels: 1,
            out_channels: 1,
            stride: (1, 1),
            padding: (0, 0),
            withbias: true,
            activation: None,
        }
    }

    /// Sets the filter height and width.
    pub const fn filter_size(mut self, height: usize, width: usize) -> Self {
        self.filter_height = height;
        self.filter_width = width;
        self
    }

    /// Sets the input and output channel counts.
    pub const fn channels(mut self, input: usize, output: usize) -> Self {
        self.in_channels = input;
        self.out_channels = output;
        self
    }

    /// Sets the stride along (height, width).
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Sets the zero padding along (height, width).
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Enables or disables the bias term.
    pub fn withbias(mut self, withbias: bool) -> Self {
        self.withbias = withbias;
        self
    }

    /// Sets the activation function.
    pub fn activation(mut self, activation: ActivationFunc) -> Self {
        self.activation = Some(activation);
        self
    }

    /// Builds the configured [`Conv2D`] layer.
    pub fn build(self) -> Conv2D {
        Conv2D::withshape(
            &self.name,
            self.filter_height,
            self.filter_width,
            self.in_channels,
            self.out_channels,
            self.stride,
            self.padding,
            self.withbias,
            self.activation,
        )
    }
}

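/// A 2D max-pooling layer. Pooling has no trainable parameters, so the
/// parameter-related [`Layer`] methods return empty collections.
///
/// ```ignore
/// // 2x2 pooling; the stride defaults to the kernel size when `None`.
/// let pool = MaxPool2D::new("pool1", (2, 2), None, (0, 0));
/// ```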
#[allow(dead_code)]
pub struct MaxPool2D {
    /// Layer name.
    name: String,

    /// Pooling window size along (height, width).
    kernel_size: (usize, usize),

    /// Stride along (height, width).
    stride: (usize, usize),

    /// Zero padding along (height, width).
    padding: (usize, usize),

    /// Whether the layer is in training mode.
    training: bool,
}

impl MaxPool2D {
    /// Creates a new max-pooling layer. If `stride` is `None`, it defaults
    /// to `kernel_size` (non-overlapping windows).
    pub fn new(
        name: &str,
        kernel_size: (usize, usize),
        stride: Option<(usize, usize)>,
        padding: (usize, usize),
    ) -> Self {
        let stride = stride.unwrap_or(kernel_size);

        Self {
            name: name.to_string(),
            kernel_size,
            stride,
            padding,
            training: true,
        }
    }
}

impl Layer for MaxPool2D {
    fn layer_type(&self) -> &str {
        "MaxPool2D"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::max_pool2d(
            inputs,
            self.kernel_size,
            self.stride,
            self.padding,
        )
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn update_parameter(
        &mut self,
        name: &str,
        _value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        Err(OperationError::Other(format!(
            "MaxPool2D has no parameter: {name}"
        )))
    }

    fn parameter_names(&self) -> Vec<String> {
        Vec::new()
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

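/// A batch-normalization layer.
///
/// Normalizes inputs with running mean/variance statistics, then applies a
/// learned per-feature `scale` and `offset`. Only `scale` and `offset` are
/// exposed as trainable parameters.
///
/// ```ignore
/// // 64 features with the default epsilon of 1e-5.
/// let bn = BatchNorm::withshape("bn1", 64, None, None);
/// ```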
pub struct BatchNorm {
    /// Layer name.
    name: String,

    /// Learned per-feature scale (gamma).
    scale: Box<dyn ArrayProtocol>,

    /// Learned per-feature offset (beta).
    offset: Box<dyn ArrayProtocol>,

    /// Running mean used for normalization.
    running_mean: Box<dyn ArrayProtocol>,

    /// Running variance used for normalization.
    running_var: Box<dyn ArrayProtocol>,

    /// Small constant added to the variance for numerical stability.
    epsilon: f64,

    /// Whether the layer is in training mode.
    training: bool,
}

impl BatchNorm {
    /// Creates a batch-normalization layer from existing parameters and
    /// running statistics.
    pub fn new(
        name: &str,
        scale: Box<dyn ArrayProtocol>,
        offset: Box<dyn ArrayProtocol>,
        running_mean: Box<dyn ArrayProtocol>,
        running_var: Box<dyn ArrayProtocol>,
        epsilon: f64,
    ) -> Self {
        Self {
            name: name.to_string(),
            scale,
            offset,
            running_mean,
            running_var,
            epsilon,
            training: true,
        }
    }

    /// Creates a batch-normalization layer with identity initialization:
    /// unit scale, zero offset, zero running mean, and unit running
    /// variance. `epsilon` defaults to `1e-5`; `momentum` is accepted for
    /// API compatibility but currently unused.
    pub fn withshape(
        name: &str,
        num_features: usize,
        epsilon: Option<f64>,
        _momentum: Option<f64>,
    ) -> Self {
        let scale: Array<f64, Ix1> = Array::ones(num_features);
        let offset: Array<f64, Ix1> = Array::zeros(num_features);
        let running_mean: Array<f64, Ix1> = Array::zeros(num_features);
        let running_var: Array<f64, Ix1> = Array::ones(num_features);

        Self {
            name: name.to_string(),
            scale: Box::new(NdarrayWrapper::new(scale)),
            offset: Box::new(NdarrayWrapper::new(offset)),
            running_mean: Box::new(NdarrayWrapper::new(running_mean)),
            running_var: Box::new(NdarrayWrapper::new(running_var)),
            epsilon: epsilon.unwrap_or(1e-5),
            training: true,
        }
    }
}

impl Layer for BatchNorm {
    fn layer_type(&self) -> &str {
        "BatchNorm"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::batch_norm(
            inputs,
            self.scale.as_ref(),
            self.offset.as_ref(),
            self.running_mean.as_ref(),
            self.running_var.as_ref(),
            self.epsilon,
        )
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        vec![self.scale.clone(), self.offset.clone()]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        vec![&mut self.scale, &mut self.offset]
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "scale" => {
                self.scale = value;
                Ok(())
            }
            "offset" => {
                self.offset = value;
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        vec!["scale".to_string(), "offset".to_string()]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

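/// A dropout layer.
///
/// Forwards the drop `rate`, the current training flag, and an optional
/// RNG `seed` to `ml_ops::dropout`, which is expected to mask activations
/// only in training mode.
///
/// ```ignore
/// // Drop half of the activations; seed fixed for reproducibility.
/// let dropout = Dropout::new("drop1", 0.5, Some(42));
/// ```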
pub struct Dropout {
    /// Layer name.
    name: String,

    /// Probability of dropping an activation.
    rate: f64,

    /// Optional RNG seed for reproducible masks.
    seed: Option<u64>,

    /// Whether the layer is in training mode.
    training: bool,
}

impl Dropout {
    /// Creates a new dropout layer.
    pub fn new(name: &str, rate: f64, seed: Option<u64>) -> Self {
        Self {
            name: name.to_string(),
            rate,
            seed,
            training: true,
        }
    }
}

impl Layer for Dropout {
    fn layer_type(&self) -> &str {
        "Dropout"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::dropout(inputs, self.rate, self.training, self.seed)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn update_parameter(
        &mut self,
        name: &str,
        _value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        Err(OperationError::Other(format!(
            "Dropout has no parameter: {name}"
        )))
    }

    fn parameter_names(&self) -> Vec<String> {
        Vec::new()
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

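/// A multi-head self-attention layer with learned projection matrices
/// `wq`, `wk`, `wv`, and `wo`, each of shape `(dmodel, dmodel)`.
///
/// ```ignore
/// // 8 heads over a 512-dimensional model; `dmodel` must be divisible
/// // by `num_heads`.
/// let attn = MultiHeadAttention::with_params("attn1", 8, 512);
/// ```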
pub struct MultiHeadAttention {
    /// Layer name.
    name: String,

    /// Query projection matrix.
    wq: Box<dyn ArrayProtocol>,

    /// Key projection matrix.
    wk: Box<dyn ArrayProtocol>,

    /// Value projection matrix.
    wv: Box<dyn ArrayProtocol>,

    /// Output projection matrix.
    wo: Box<dyn ArrayProtocol>,

    /// Number of attention heads.
    num_heads: usize,

    /// Model (embedding) dimension.
    dmodel: usize,

    /// Whether the layer is in training mode.
    training: bool,
}

impl MultiHeadAttention {
    /// Creates a multi-head attention layer from existing projection
    /// matrices.
    pub fn new(
        name: &str,
        wq: Box<dyn ArrayProtocol>,
        wk: Box<dyn ArrayProtocol>,
        wv: Box<dyn ArrayProtocol>,
        wo: Box<dyn ArrayProtocol>,
        num_heads: usize,
        dmodel: usize,
    ) -> Self {
        Self {
            name: name.to_string(),
            wq,
            wk,
            wv,
            wo,
            num_heads,
            dmodel,
            training: true,
        }
    }

    /// Creates a multi-head attention layer with randomly initialized
    /// projection matrices.
    ///
    /// # Panics
    ///
    /// Panics if `dmodel` is not divisible by `num_heads`.
    pub fn with_params(name: &str, num_heads: usize, dmodel: usize) -> Self {
        assert!(
            dmodel.is_multiple_of(num_heads),
            "dmodel must be divisible by num_heads"
        );

        // Uniform initialization: U(-scale, scale) with scale = 1/sqrt(dmodel).
        let scale = (1.0 / dmodel as f64).sqrt();
        let mut rng = rand::rng();

        let wq = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });

        let wk = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });

        let wv = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });

        let wo = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });

        Self {
            name: name.to_string(),
            wq: Box::new(NdarrayWrapper::new(wq)),
            wk: Box::new(NdarrayWrapper::new(wk)),
            wv: Box::new(NdarrayWrapper::new(wv)),
            wo: Box::new(NdarrayWrapper::new(wo)),
            num_heads,
            dmodel,
            training: true,
        }
    }
}

impl Layer for MultiHeadAttention {
    fn layer_type(&self) -> &str {
        "MultiHeadAttention"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Project the inputs into query, key, and value spaces.
        let queries = crate::array_protocol::matmul(self.wq.as_ref(), inputs)?;
        let keys = crate::array_protocol::matmul(self.wk.as_ref(), inputs)?;
        let values = crate::array_protocol::matmul(self.wv.as_ref(), inputs)?;

        // Self-attention; the per-head dimension is passed as the scale.
        let attention = crate::array_protocol::ml_ops::self_attention(
            queries.as_ref(),
            keys.as_ref(),
            values.as_ref(),
            None,
            Some((self.dmodel / self.num_heads) as f64),
        )?;

        // Final output projection.
        let output = crate::array_protocol::matmul(self.wo.as_ref(), attention.as_ref())?;

        Ok(output)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        vec![
            self.wq.clone(),
            self.wk.clone(),
            self.wv.clone(),
            self.wo.clone(),
        ]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        vec![&mut self.wq, &mut self.wk, &mut self.wv, &mut self.wo]
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "wq" => {
                self.wq = value;
                Ok(())
            }
            "wk" => {
                self.wk = value;
                Ok(())
            }
            "wv" => {
                self.wv = value;
                Ok(())
            }
            "wo" => {
                self.wo = value;
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        vec![
            "wq".to_string(),
            "wk".to_string(),
            "wv".to_string(),
            "wo".to_string(),
        ]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

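/// A sequential container that chains layers, feeding each layer's output
/// into the next.
///
/// A minimal sketch, assuming `crate::array_protocol::init()` has been
/// called:
///
/// ```ignore
/// let mut model = Sequential::new("mlp", Vec::new());
/// model.add_layer(Box::new(Linear::new_random(
///     "fc1", 4, 8, true, Some(ActivationFunc::ReLU),
/// )));
/// model.add_layer(Box::new(Linear::new_random("fc2", 8, 1, true, None)));
///
/// let x = NdarrayWrapper::new(ndarray::Array1::<f64>::ones(4));
/// let y = model.forward(&x)?;
/// ```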
pub struct Sequential {
    /// Model name.
    name: String,

    /// Layers applied in order.
    layers: Vec<Box<dyn Layer>>,

    /// Whether the model is in training mode.
    training: bool,
}

impl Sequential {
    /// Creates a new sequential model from an initial list of layers.
    pub fn new(name: &str, layers: Vec<Box<dyn Layer>>) -> Self {
        Self {
            name: name.to_string(),
            layers,
            training: true,
        }
    }

    /// Appends a layer to the end of the model.
    pub fn add_layer(&mut self, layer: Box<dyn Layer>) {
        self.layers.push(layer);
    }

    /// Runs the forward pass through all layers in order.
    pub fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut x: Box<dyn ArrayProtocol> = inputs.box_clone();

        for layer in &self.layers {
            x = layer.forward(x.as_ref())?;
        }

        Ok(x)
    }

    /// Returns clones of all layers' parameters.
    pub fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = Vec::new();

        for layer in &self.layers {
            params.extend(layer.parameters());
        }

        params
    }

    /// Puts the model and all of its layers into training mode.
    pub fn train(&mut self) {
        self.training = true;

        for layer in &mut self.layers {
            layer.train();
        }
    }

    /// Puts the model and all of its layers into evaluation mode.
    pub fn eval(&mut self) {
        self.training = false;

        for layer in &mut self.layers {
            layer.eval();
        }
    }

    /// Returns the model's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the model's layers.
    pub fn layers(&self) -> &[Box<dyn Layer>] {
        &self.layers
    }

    /// Computes gradients for a backward pass.
    ///
    /// This is currently a placeholder that returns an empty gradient
    /// dictionary.
    pub fn backward(
        &self,
        _output: &dyn ArrayProtocol,
        _target: &dyn ArrayProtocol,
    ) -> Result<crate::array_protocol::grad::GradientDict, crate::error::CoreError> {
        Ok(crate::array_protocol::grad::GradientDict::new())
    }

    /// Applies a gradient-descent step, `param -= learningrate * gradient`,
    /// to the parameter addressed as `"layer_index.param_name"`
    /// (e.g. `"0.weights"`).
    pub fn update_parameter(
        &mut self,
        param_name: &str,
        gradient: &dyn ArrayProtocol,
        learningrate: f64,
    ) -> Result<(), crate::error::CoreError> {
        // Parse the "layer_index.param_name" address.
        let parts: Vec<&str> = param_name.split('.').collect();
        if parts.len() != 2 {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Invalid parameter name format. Expected 'layer_index.param_name', got: {param_name}"
                )),
            ));
        }

        let layer_index: usize = parts[0].parse().map_err(|_| {
            crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                "Invalid layer index: {layer_idx}",
                layer_idx = parts[0]
            )))
        })?;

        let param_name = parts[1];

        if layer_index >= self.layers.len() {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Layer index {layer_index} out of bounds (model has {num_layers} layers)",
                    num_layers = self.layers.len()
                )),
            ));
        }

        let layer = &mut self.layers[layer_index];
        let current_params = layer.parameters();
        let param_names = layer.parameter_names();

        let param_idx = param_names
            .iter()
            .position(|name| name == param_name)
            .ok_or_else(|| {
                crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                    "Parameter '{param_name}' not found in layer {layer_index}"
                )))
            })?;

        let current_param = &current_params[param_idx];

        // Scale the gradient by the learning rate, then subtract it from
        // the current parameter value.
        let scaled_gradient =
            crate::array_protocol::operations::multiply_by_scalar_f64(gradient, learningrate)
                .map_err(|e| {
                    crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(
                        format!("Failed to scale gradient: {e}"),
                    ))
                })?;

        let updated_param = crate::array_protocol::operations::subtract(
            current_param.as_ref(),
            scaled_gradient.as_ref(),
        )
        .map_err(|e| {
            crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                "Failed to update parameter: {e}"
            )))
        })?;

        layer
            .update_parameter(param_name, updated_param)
            .map_err(|e| {
                crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                    "Failed to set parameter in layer: {e}"
                )))
            })?;

        Ok(())
    }

    /// Returns all parameter names in `"layer_index.param_name"` form.
    pub fn all_parameter_names(&self) -> Vec<String> {
        let mut all_names = Vec::new();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let layer_param_names = layer.parameter_names();
            for param_name in layer_param_names {
                all_names.push(format!("{layer_idx}.{param_name}"));
            }
        }
        all_names
    }

    /// Returns clones of all layers' parameters.
    pub fn all_parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut all_params = Vec::new();
        for layer in &self.layers {
            all_params.extend(layer.parameters());
        }
        all_params
    }
}

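/// Builds a small example CNN: two Conv2D + MaxPool2D blocks followed by a
/// fully connected head with dropout.
///
/// ```ignore
/// // 28x28 single-channel inputs, 10 output classes (MNIST-sized).
/// let model = create_simple_cnn((28, 28, 1), 10);
/// assert_eq!(model.layers().len(), 7);
/// ```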
#[allow(dead_code)]
pub fn create_simple_cnn(inputshape: (usize, usize, usize), num_classes: usize) -> Sequential {
    let (height, width, channels) = inputshape;

    let mut model = Sequential::new("SimpleCNN", Vec::new());

    // First convolution block: 3x3 filters, `channels` -> 32, padding 1.
    model.add_layer(Box::new(Conv2D::withshape(
        "conv1",
        3,
        3,
        channels,
        32,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));

    // 2x2 max pooling (stride defaults to the kernel size).
    model.add_layer(Box::new(MaxPool2D::new("pool1", (2, 2), None, (0, 0))));

    // Second convolution block: 3x3 filters, 32 -> 64, padding 1.
    model.add_layer(Box::new(Conv2D::withshape(
        "conv2",
        3,
        3,
        32,
        64,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(MaxPool2D::new("pool2", (2, 2), None, (0, 0))));

    // Fully connected head. Two 2x2 pools shrink each spatial dimension by
    // a factor of 4, leaving 64 channels of (height / 4) x (width / 4).
    model.add_layer(Box::new(Linear::new_random(
        "fc1",
        64 * (height / 4) * (width / 4),
        128,
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(Dropout::new("dropout", 0.5, None)));

    model.add_layer(Box::new(Linear::new_random(
        "fc2",
        128,
        num_classes,
        true,
        None,
    )));

    model
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::{Array1, Array2};

    #[test]
    fn test_linear_layer() {
        array_protocol::init();

        let weights = Array2::<f64>::eye(3);
        let bias = Array1::<f64>::ones(3);

        let layer = Linear::new(
            "linear",
            Box::new(NdarrayWrapper::new(weights)),
            Some(Box::new(NdarrayWrapper::new(bias))),
            Some(ActivationFunc::ReLU),
        );

        assert_eq!(layer.name(), "linear");
        assert!(layer.is_training());
    }

    #[test]
    fn test_sequential_model() {
        array_protocol::init();

        let mut model = Sequential::new("test_model", Vec::new());

        model.add_layer(Box::new(Linear::new_random(
            "fc1",
            3,
            2,
            true,
            Some(ActivationFunc::ReLU),
        )));

        model.add_layer(Box::new(Linear::new_random(
            "fc2",
            2,
            1,
            true,
            Some(ActivationFunc::Sigmoid),
        )));

        assert_eq!(model.name(), "test_model");
        assert_eq!(model.layers().len(), 2);
        assert!(model.training);
    }

    #[test]
    fn test_simple_cnn_creation() {
        array_protocol::init();

        let model = create_simple_cnn((28, 28, 1), 10);

        assert_eq!(model.layers().len(), 7);
        assert_eq!(model.name(), "SimpleCNN");

        let params = model.parameters();
        assert!(!params.is_empty());
    }
}