use ndarray::{Array, Ix1};
use rand::Rng;

use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};

/// A trait for neural network layers.
pub trait Layer: Send + Sync {
    /// Returns the type of the layer (e.g. "Linear", "Conv2D").
    fn layer_type(&self) -> &str;

    /// Performs a forward pass through the layer.
    fn forward(&self, inputs: &dyn ArrayProtocol)
        -> Result<Box<dyn ArrayProtocol>, OperationError>;

    /// Returns the layer's trainable parameters.
    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>>;

    /// Returns mutable references to the layer's trainable parameters.
    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>>;

    /// Updates a named parameter with a new value.
    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError>;

    /// Returns the names of the layer's parameters.
    fn parameter_names(&self) -> Vec<String>;

    /// Sets the layer to training mode.
    fn train(&mut self);

    /// Sets the layer to evaluation mode.
    fn eval(&mut self);

    /// Returns whether the layer is in training mode.
    fn is_training(&self) -> bool;

    /// Returns the name of this layer instance.
    fn name(&self) -> &str;
}

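/// A fully connected (dense) layer computing `activation(weights · x + bias)`.
///
/// # Example
///
/// An illustrative sketch with identity weights and a unit bias (mirrors the
/// unit test below; assumes `array_protocol::init()` has been called):
///
/// ```ignore
/// use ndarray::{Array1, Array2};
///
/// let weights = Array2::<f64>::eye(3);
/// let bias = Array1::<f64>::ones(3);
/// let layer = Linear::new(
///     "linear",
///     Box::new(NdarrayWrapper::new(weights)),
///     Some(Box::new(NdarrayWrapper::new(bias))),
///     Some(ActivationFunc::ReLU),
/// );
/// assert_eq!(layer.layer_type(), "Linear");
/// ```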
pub struct Linear {
    /// Name of the layer.
    name: String,

    /// Weight matrix with shape `(out_features, in_features)`.
    weights: Box<dyn ArrayProtocol>,

    /// Optional bias vector with shape `(out_features,)`.
    bias: Option<Box<dyn ArrayProtocol>>,

    /// Optional activation function applied after the affine transform.
    activation: Option<ActivationFunc>,

    /// Whether the layer is in training mode.
    training: bool,
}

impl Linear {
    /// Creates a new linear layer from existing weights and optional bias.
    pub fn new(
        name: &str,
        weights: Box<dyn ArrayProtocol>,
        bias: Option<Box<dyn ArrayProtocol>>,
        activation: Option<ActivationFunc>,
    ) -> Self {
        Self {
            name: name.to_string(),
            weights,
            bias,
            activation,
            training: true,
        }
    }

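    /// Creates a new linear layer with Xavier/Glorot-style uniform
    /// initialization: weights are drawn from `U(-scale, scale)` with
    /// `scale = sqrt(6 / (in_features + out_features))`.
    ///
    /// # Example
    ///
    /// Illustrative sketch (the layer name and sizes are arbitrary):
    ///
    /// ```ignore
    /// // A 784 -> 128 layer with bias and ReLU activation.
    /// let fc = Linear::new_random("fc1", 784, 128, true, Some(ActivationFunc::ReLU));
    /// assert_eq!(fc.parameter_names(), vec!["weights", "bias"]);
    /// ```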
    pub fn new_random(
        name: &str,
        in_features: usize,
        out_features: usize,
        withbias: bool,
        activation: Option<ActivationFunc>,
    ) -> Self {
        // Xavier/Glorot uniform initialization.
        let scale = (6.0 / (in_features + out_features) as f64).sqrt();
        let mut rng = rand::rng();
        let weights = Array::from_shape_fn((out_features, in_features), |_| {
            (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
        });

        let bias = if withbias {
            let bias_array: Array<f64, Ix1> = Array::zeros(out_features);
            Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
        } else {
            None
        };

        Self {
            name: name.to_string(),
            weights: Box::new(NdarrayWrapper::new(weights)),
            bias,
            activation,
            training: true,
        }
    }
}

impl Layer for Linear {
    fn layer_type(&self) -> &str {
        "Linear"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Affine transform: weights · inputs.
        let mut result = crate::array_protocol::matmul(self.weights.as_ref(), inputs)?;

        // Add the bias, if present.
        if let Some(bias) = &self.bias {
            result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
        }

        // Apply the activation function, if present.
        if let Some(act_fn) = self.activation {
            result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
        }

        Ok(result)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = vec![self.weights.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        let mut params = vec![&mut self.weights];
        if let Some(bias) = &mut self.bias {
            params.push(bias);
        }
        params
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "weights" => {
                self.weights = value;
                Ok(())
            }
            "bias" => {
                self.bias = Some(value);
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        let mut names = vec!["weights".to_string()];
        if self.bias.is_some() {
            names.push("bias".to_string());
        }
        names
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

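/// A 2D convolutional layer.
///
/// Filters are stored with shape `(filter_height, filter_width, in_channels,
/// out_channels)`, matching the layout used by [`Conv2D::withshape`]; the
/// convolution itself is delegated to `crate::array_protocol::ml_ops::conv2d`.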
pub struct Conv2D {
    /// Name of the layer.
    name: String,

    /// Filter tensor with shape `(filter_height, filter_width, in_channels, out_channels)`.
    filters: Box<dyn ArrayProtocol>,

    /// Optional bias vector with shape `(out_channels,)`.
    bias: Option<Box<dyn ArrayProtocol>>,

    /// Stride as `(stride_height, stride_width)`.
    stride: (usize, usize),

    /// Padding as `(padding_height, padding_width)`.
    padding: (usize, usize),

    /// Optional activation function applied after the convolution.
    activation: Option<ActivationFunc>,

    /// Whether the layer is in training mode.
    training: bool,
}

impl Conv2D {
    /// Creates a new Conv2D layer from existing filters and optional bias.
    pub fn new(
        name: &str,
        filters: Box<dyn ArrayProtocol>,
        bias: Option<Box<dyn ArrayProtocol>>,
        stride: (usize, usize),
        padding: (usize, usize),
        activation: Option<ActivationFunc>,
    ) -> Self {
        Self {
            name: name.to_string(),
            filters,
            bias,
            stride,
            padding,
            activation,
            training: true,
        }
    }

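    /// Creates a new Conv2D layer with He-style uniform initialization:
    /// filters are drawn from `U(-scale, scale)` with `scale = sqrt(2 / fan_in)`,
    /// where `fan_in = filter_height * filter_width * in_channels`.
    ///
    /// # Example
    ///
    /// Illustrative sketch (a 3x3, 1 -> 32 channel convolution with stride 1
    /// and padding 1):
    ///
    /// ```ignore
    /// let conv = Conv2D::withshape(
    ///     "conv1", 3, 3, 1, 32, (1, 1), (1, 1), true, Some(ActivationFunc::ReLU),
    /// );
    /// assert_eq!(conv.layer_type(), "Conv2D");
    /// ```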
    #[allow(clippy::too_many_arguments)]
    pub fn withshape(
        name: &str,
        filter_height: usize,
        filter_width: usize,
        in_channels: usize,
        out_channels: usize,
        stride: (usize, usize),
        padding: (usize, usize),
        withbias: bool,
        activation: Option<ActivationFunc>,
    ) -> Self {
        // He (Kaiming) uniform initialization.
        let fan_in = filter_height * filter_width * in_channels;
        let scale = (2.0 / fan_in as f64).sqrt();
        let mut rng = rand::rng();
        let filters = Array::from_shape_fn(
            (filter_height, filter_width, in_channels, out_channels),
            |_| (rng.random::<f64>() * 2.0_f64 - 1.0) * scale,
        );

        let bias = if withbias {
            let bias_array: Array<f64, Ix1> = Array::zeros(out_channels);
            Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
        } else {
            None
        };

        Self {
            name: name.to_string(),
            filters: Box::new(NdarrayWrapper::new(filters)),
            bias,
            stride,
            padding,
            activation,
            training: true,
        }
    }
}

impl Layer for Conv2D {
    fn layer_type(&self) -> &str {
        "Conv2D"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut result = crate::array_protocol::ml_ops::conv2d(
            inputs,
            self.filters.as_ref(),
            self.stride,
            self.padding,
        )?;

        if let Some(bias) = &self.bias {
            result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
        }

        if let Some(act_fn) = self.activation {
            result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
        }

        Ok(result)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = vec![self.filters.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        let mut params = vec![&mut self.filters];
        if let Some(bias) = &mut self.bias {
            params.push(bias);
        }
        params
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "filters" => {
                self.filters = value;
                Ok(())
            }
            "bias" => {
                self.bias = Some(value);
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        let mut names = vec!["filters".to_string()];
        if self.bias.is_some() {
            names.push("bias".to_string());
        }
        names
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

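/// A builder for [`Conv2D`] layers with sensible defaults: 3x3 filters, one
/// input and one output channel, stride 1, no padding, bias enabled, and no
/// activation.
///
/// # Example
///
/// Illustrative sketch:
///
/// ```ignore
/// let conv = Conv2DBuilder::new("conv1")
///     .filter_size(3, 3)
///     .channels(1, 32)
///     .stride((1, 1))
///     .padding((1, 1))
///     .activation(ActivationFunc::ReLU)
///     .build();
/// ```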
pub struct Conv2DBuilder {
    name: String,
    filter_height: usize,
    filter_width: usize,
    in_channels: usize,
    out_channels: usize,
    stride: (usize, usize),
    padding: (usize, usize),
    withbias: bool,
    activation: Option<ActivationFunc>,
}

impl Conv2DBuilder {
    /// Creates a new builder with default settings.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            filter_height: 3,
            filter_width: 3,
            in_channels: 1,
            out_channels: 1,
            stride: (1, 1),
            padding: (0, 0),
            withbias: true,
            activation: None,
        }
    }

    /// Sets the filter size.
    pub const fn filter_size(mut self, height: usize, width: usize) -> Self {
        self.filter_height = height;
        self.filter_width = width;
        self
    }

    /// Sets the input and output channel counts.
    pub const fn channels(mut self, input: usize, output: usize) -> Self {
        self.in_channels = input;
        self.out_channels = output;
        self
    }

    /// Sets the stride.
    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    /// Sets the padding.
    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    /// Sets whether the layer has a bias.
    pub fn withbias(mut self, withbias: bool) -> Self {
        self.withbias = withbias;
        self
    }

    /// Sets the activation function.
    pub fn activation(mut self, activation: ActivationFunc) -> Self {
        self.activation = Some(activation);
        self
    }

    /// Builds the [`Conv2D`] layer.
    pub fn build(self) -> Conv2D {
        Conv2D::withshape(
            &self.name,
            self.filter_height,
            self.filter_width,
            self.in_channels,
            self.out_channels,
            self.stride,
            self.padding,
            self.withbias,
            self.activation,
        )
    }
}

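/// A 2D max-pooling layer.
///
/// Under the usual pooling convention, an input of height `H` produces an
/// output of height `(H + 2 * padding.0 - kernel_size.0) / stride.0 + 1`
/// (and analogously for width). Note that `forward` is currently a stub
/// returning `NotImplemented`.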
#[allow(dead_code)]
pub struct MaxPool2D {
    /// Name of the layer.
    name: String,

    /// Pooling window size as `(height, width)`.
    kernel_size: (usize, usize),

    /// Stride as `(stride_height, stride_width)`.
    stride: (usize, usize),

    /// Padding as `(padding_height, padding_width)`.
    padding: (usize, usize),

    /// Whether the layer is in training mode.
    training: bool,
}

impl MaxPool2D {
    /// Creates a new max-pooling layer. If `stride` is `None`, it defaults
    /// to `kernel_size` (non-overlapping windows).
    pub fn new(
        name: &str,
        kernel_size: (usize, usize),
        stride: Option<(usize, usize)>,
        padding: (usize, usize),
    ) -> Self {
        let stride = stride.unwrap_or(kernel_size);

        Self {
            name: name.to_string(),
            kernel_size,
            stride,
            padding,
            training: true,
        }
    }
}

impl Layer for MaxPool2D {
    fn layer_type(&self) -> &str {
        "MaxPool2D"
    }

    fn forward(
        &self,
        _inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        Err(OperationError::NotImplemented(
            "max_pool2d not yet implemented".to_string(),
        ))
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn update_parameter(
        &mut self,
        name: &str,
        _value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        Err(OperationError::Other(format!(
            "MaxPool2D has no parameter: {name}"
        )))
    }

    fn parameter_names(&self) -> Vec<String> {
        Vec::new()
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

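/// A batch normalization layer.
///
/// Conceptually, batch normalization computes, per feature,
/// `y = scale * (x - mean) / sqrt(var + epsilon) + offset`; the exact
/// behavior here is delegated to `crate::array_protocol::ml_ops::batch_norm`.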
pub struct BatchNorm {
    /// Name of the layer.
    name: String,

    /// Learnable scale (gamma), one value per feature.
    scale: Box<dyn ArrayProtocol>,

    /// Learnable offset (beta), one value per feature.
    offset: Box<dyn ArrayProtocol>,

    /// Running mean of the input features.
    running_mean: Box<dyn ArrayProtocol>,

    /// Running variance of the input features.
    running_var: Box<dyn ArrayProtocol>,

    /// Small constant added to the variance for numerical stability.
    epsilon: f64,

    /// Whether the layer is in training mode.
    training: bool,
}

impl BatchNorm {
    /// Creates a new BatchNorm layer from existing parameters and statistics.
    pub fn new(
        name: &str,
        scale: Box<dyn ArrayProtocol>,
        offset: Box<dyn ArrayProtocol>,
        running_mean: Box<dyn ArrayProtocol>,
        running_var: Box<dyn ArrayProtocol>,
        epsilon: f64,
    ) -> Self {
        Self {
            name: name.to_string(),
            scale,
            offset,
            running_mean,
            running_var,
            epsilon,
            training: true,
        }
    }

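    /// Creates a new BatchNorm layer with default parameters: `scale` and
    /// `running_var` initialized to ones, `offset` and `running_mean` to
    /// zeros, and `epsilon` defaulting to `1e-5`. The momentum argument is
    /// currently unused.
    ///
    /// # Example
    ///
    /// Illustrative sketch:
    ///
    /// ```ignore
    /// let bn = BatchNorm::withshape("bn1", 64, None, None);
    /// assert_eq!(bn.parameter_names(), vec!["scale", "offset"]);
    /// ```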
    pub fn withshape(
        name: &str,
        num_features: usize,
        epsilon: Option<f64>,
        _momentum: Option<f64>,
    ) -> Self {
        let scale: Array<f64, Ix1> = Array::ones(num_features);
        let offset: Array<f64, Ix1> = Array::zeros(num_features);
        let running_mean: Array<f64, Ix1> = Array::zeros(num_features);
        let running_var: Array<f64, Ix1> = Array::ones(num_features);

        Self {
            name: name.to_string(),
            scale: Box::new(NdarrayWrapper::new(scale)),
            offset: Box::new(NdarrayWrapper::new(offset)),
            running_mean: Box::new(NdarrayWrapper::new(running_mean)),
            running_var: Box::new(NdarrayWrapper::new(running_var)),
            epsilon: epsilon.unwrap_or(1e-5),
            training: true,
        }
    }
}

impl Layer for BatchNorm {
    fn layer_type(&self) -> &str {
        "BatchNorm"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::batch_norm(
            inputs,
            self.scale.as_ref(),
            self.offset.as_ref(),
            self.running_mean.as_ref(),
            self.running_var.as_ref(),
            self.epsilon,
        )
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        vec![self.scale.clone(), self.offset.clone()]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        vec![&mut self.scale, &mut self.offset]
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "scale" => {
                self.scale = value;
                Ok(())
            }
            "offset" => {
                self.offset = value;
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        vec!["scale".to_string(), "offset".to_string()]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

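/// A dropout layer.
///
/// During training, elements are dropped with probability `rate`; in
/// evaluation mode the input passes through unchanged. The masking behavior
/// is delegated to `crate::array_protocol::ml_ops::dropout`.
///
/// # Example
///
/// Illustrative sketch with a fixed seed for reproducible masks:
///
/// ```ignore
/// let dropout = Dropout::new("dropout", 0.5, Some(42));
/// assert!(dropout.parameters().is_empty());
/// ```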
pub struct Dropout {
    /// Name of the layer.
    name: String,

    /// Probability of dropping each element during training.
    rate: f64,

    /// Optional RNG seed for reproducible masks.
    seed: Option<u64>,

    /// Whether the layer is in training mode.
    training: bool,
}

impl Dropout {
    /// Creates a new dropout layer.
    pub fn new(name: &str, rate: f64, seed: Option<u64>) -> Self {
        Self {
            name: name.to_string(),
            rate,
            seed,
            training: true,
        }
    }
}

impl Layer for Dropout {
    fn layer_type(&self) -> &str {
        "Dropout"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::dropout(inputs, self.rate, self.training, self.seed)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn update_parameter(
        &mut self,
        name: &str,
        _value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        Err(OperationError::Other(format!(
            "Dropout has no parameter: {name}"
        )))
    }

    fn parameter_names(&self) -> Vec<String> {
        Vec::new()
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

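/// A multi-head attention layer.
///
/// Conceptually computes `softmax(Q · K^T / sqrt(d_k)) · V`, where `Q`, `K`,
/// and `V` are projections of the input through `wq`, `wk`, and `wv`, and the
/// result is projected through `wo`; the attention computation itself is
/// delegated to `crate::array_protocol::ml_ops::self_attention`.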
pub struct MultiHeadAttention {
    /// Name of the layer.
    name: String,

    /// Query projection matrix, shape `(dmodel, dmodel)`.
    wq: Box<dyn ArrayProtocol>,

    /// Key projection matrix, shape `(dmodel, dmodel)`.
    wk: Box<dyn ArrayProtocol>,

    /// Value projection matrix, shape `(dmodel, dmodel)`.
    wv: Box<dyn ArrayProtocol>,

    /// Output projection matrix, shape `(dmodel, dmodel)`.
    wo: Box<dyn ArrayProtocol>,

    /// Number of attention heads.
    num_heads: usize,

    /// Model (embedding) dimension.
    dmodel: usize,

    /// Whether the layer is in training mode.
    training: bool,
}

impl MultiHeadAttention {
    /// Creates a new multi-head attention layer from existing projection
    /// matrices.
    pub fn new(
        name: &str,
        wq: Box<dyn ArrayProtocol>,
        wk: Box<dyn ArrayProtocol>,
        wv: Box<dyn ArrayProtocol>,
        wo: Box<dyn ArrayProtocol>,
        num_heads: usize,
        dmodel: usize,
    ) -> Self {
        Self {
            name: name.to_string(),
            wq,
            wk,
            wv,
            wo,
            num_heads,
            dmodel,
            training: true,
        }
    }

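    /// Creates a multi-head attention layer with randomly initialized
    /// projection matrices of shape `(dmodel, dmodel)`, drawn from
    /// `U(-scale, scale)` with `scale = 1 / sqrt(dmodel)`.
    ///
    /// # Panics
    ///
    /// Panics if `dmodel` is not divisible by `num_heads`.
    ///
    /// # Example
    ///
    /// Illustrative sketch (8 heads over a 512-dimensional model):
    ///
    /// ```ignore
    /// let mha = MultiHeadAttention::with_params("attn", 8, 512);
    /// assert_eq!(mha.parameter_names().len(), 4);
    /// ```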
    pub fn with_params(name: &str, num_heads: usize, dmodel: usize) -> Self {
        assert!(
            dmodel % num_heads == 0,
            "dmodel must be divisible by num_heads"
        );

        let scale = (1.0_f64 / dmodel as f64).sqrt();
        let mut rng = rand::rng();

        let wq = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
        });

        let wk = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
        });

        let wv = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
        });

        let wo = Array::from_shape_fn((dmodel, dmodel), |_| {
            (rng.random::<f64>() * 2.0_f64 - 1.0) * scale
        });

        Self {
            name: name.to_string(),
            wq: Box::new(NdarrayWrapper::new(wq)),
            wk: Box::new(NdarrayWrapper::new(wk)),
            wv: Box::new(NdarrayWrapper::new(wv)),
            wo: Box::new(NdarrayWrapper::new(wo)),
            num_heads,
            dmodel,
            training: true,
        }
    }
}

impl Layer for MultiHeadAttention {
    fn layer_type(&self) -> &str {
        "MultiHeadAttention"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Project the inputs into query, key, and value spaces.
        let queries = crate::array_protocol::matmul(self.wq.as_ref(), inputs)?;
        let keys = crate::array_protocol::matmul(self.wk.as_ref(), inputs)?;
        let values = crate::array_protocol::matmul(self.wv.as_ref(), inputs)?;

        // Scaled dot-product attention over the projected inputs.
        let attention = crate::array_protocol::ml_ops::self_attention(
            queries.as_ref(),
            keys.as_ref(),
            values.as_ref(),
            None,
            Some((self.dmodel / self.num_heads) as f64),
        )?;

        // Project the attention output back through the output matrix.
        let output = crate::array_protocol::matmul(self.wo.as_ref(), attention.as_ref())?;

        Ok(output)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        vec![
            self.wq.clone(),
            self.wk.clone(),
            self.wv.clone(),
            self.wo.clone(),
        ]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        vec![&mut self.wq, &mut self.wk, &mut self.wv, &mut self.wo]
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "wq" => {
                self.wq = value;
                Ok(())
            }
            "wk" => {
                self.wk = value;
                Ok(())
            }
            "wv" => {
                self.wv = value;
                Ok(())
            }
            "wo" => {
                self.wo = value;
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        vec![
            "wq".to_string(),
            "wk".to_string(),
            "wv".to_string(),
            "wo".to_string(),
        ]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

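/// A sequential container that chains layers, feeding each layer's output
/// into the next.
///
/// # Example
///
/// Illustrative sketch of a small two-layer MLP (mirrors the unit test below):
///
/// ```ignore
/// let mut model = Sequential::new("mlp", Vec::new());
/// model.add_layer(Box::new(Linear::new_random(
///     "fc1", 3, 2, true, Some(ActivationFunc::ReLU),
/// )));
/// model.add_layer(Box::new(Linear::new_random(
///     "fc2", 2, 1, true, Some(ActivationFunc::Sigmoid),
/// )));
/// assert_eq!(model.layers().len(), 2);
/// ```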
pub struct Sequential {
    /// Name of the model.
    name: String,

    /// Layers, applied in order.
    layers: Vec<Box<dyn Layer>>,

    /// Whether the model is in training mode.
    training: bool,
}

impl Sequential {
    /// Creates a new sequential model.
    pub fn new(name: &str, layers: Vec<Box<dyn Layer>>) -> Self {
        Self {
            name: name.to_string(),
            layers,
            training: true,
        }
    }

    /// Appends a layer to the end of the model.
    pub fn add_layer(&mut self, layer: Box<dyn Layer>) {
        self.layers.push(layer);
    }

    /// Performs a forward pass through all layers in order.
    pub fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut x: Box<dyn ArrayProtocol> = inputs.box_clone();

        for layer in &self.layers {
            let x_ref: &dyn ArrayProtocol = x.as_ref();
            x = layer.forward(x_ref)?;
        }

        Ok(x)
    }

    /// Collects the parameters of all layers.
    pub fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = Vec::new();

        for layer in &self.layers {
            params.extend(layer.parameters());
        }

        params
    }

    /// Sets the model and all of its layers to training mode.
    pub fn train(&mut self) {
        self.training = true;

        for layer in &mut self.layers {
            layer.train();
        }
    }

    /// Sets the model and all of its layers to evaluation mode.
    pub fn eval(&mut self) {
        self.training = false;

        for layer in &mut self.layers {
            layer.eval();
        }
    }

    /// Returns the name of the model.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the model's layers.
    pub fn layers(&self) -> &[Box<dyn Layer>] {
        &self.layers
    }

    /// Backward pass placeholder: gradient computation is not yet
    /// implemented, so this currently returns an empty gradient dictionary.
    pub fn backward(
        &self,
        _output: &dyn ArrayProtocol,
        _target: &dyn ArrayProtocol,
    ) -> Result<crate::array_protocol::grad::GradientDict, crate::error::CoreError> {
        Ok(crate::array_protocol::grad::GradientDict::new())
    }

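    /// Applies a single SGD-style update to one parameter:
    /// `param <- param - learningrate * gradient`.
    ///
    /// `param_name` must be of the form `"layer_index.param_name"`, e.g.
    /// `"0.weights"` for the weights of the first layer (see
    /// [`Sequential::all_parameter_names`]).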
    pub fn update_parameter(
        &mut self,
        param_name: &str,
        gradient: &dyn ArrayProtocol,
        learningrate: f64,
    ) -> Result<(), crate::error::CoreError> {
        // Parse "layer_index.param_name".
        let parts: Vec<&str> = param_name.split('.').collect();
        if parts.len() != 2 {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Invalid parameter name format. Expected 'layer_index.param_name', got: {param_name}"
                )),
            ));
        }

        let layer_index: usize = parts[0].parse().map_err(|_| {
            crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                "Invalid layer index: {layer_idx}",
                layer_idx = parts[0]
            )))
        })?;

        let param_name = parts[1];

        if layer_index >= self.layers.len() {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Layer index {layer_index} out of bounds (model has {num_layers} layers)",
                    num_layers = self.layers.len()
                )),
            ));
        }

        let layer = &mut self.layers[layer_index];
        let current_params = layer.parameters();
        let param_names = layer.parameter_names();

        let param_idx = param_names
            .iter()
            .position(|name| name == param_name)
            .ok_or_else(|| {
                crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                    "Parameter '{param_name}' not found in layer {layer_index}"
                )))
            })?;

        let current_param = &current_params[param_idx];

        // Scale the gradient by the learning rate.
        let scaled_gradient =
            crate::array_protocol::operations::multiply_by_scalar_f64(gradient, learningrate)
                .map_err(|e| {
                    crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(
                        format!("Failed to scale gradient: {e}"),
                    ))
                })?;

        // Gradient descent step: param - learningrate * gradient.
        let updated_param = crate::array_protocol::operations::subtract(
            current_param.as_ref(),
            scaled_gradient.as_ref(),
        )
        .map_err(|e| {
            crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                "Failed to update parameter: {e}"
            )))
        })?;

        layer
            .update_parameter(param_name, updated_param)
            .map_err(|e| {
                crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                    "Failed to set parameter in layer: {e}"
                )))
            })?;

        Ok(())
    }

    /// Returns all parameter names in `"layer_index.param_name"` form.
    pub fn all_parameter_names(&self) -> Vec<String> {
        let mut all_names = Vec::new();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            let layer_param_names = layer.parameter_names();
            for param_name in layer_param_names {
                all_names.push(format!("{layer_idx}.{param_name}"));
            }
        }
        all_names
    }

    /// Collects the parameters of all layers, in layer order.
    pub fn all_parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut all_params = Vec::new();
        for layer in &self.layers {
            all_params.extend(layer.parameters());
        }
        all_params
    }
}

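/// Creates a simple CNN: two Conv2D + MaxPool2D blocks followed by a fully
/// connected head with dropout.
///
/// `inputshape` is `(height, width, channels)`; the fully connected input
/// size assumes each 2x2 pooling stage halves the spatial dimensions.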
#[allow(dead_code)]
pub fn create_simple_cnn(inputshape: (usize, usize, usize), num_classes: usize) -> Sequential {
    let (height, width, channels) = inputshape;

    let mut model = Sequential::new("SimpleCNN", Vec::new());

    // First convolutional block: 3x3 conv, ReLU, then 2x2 max-pooling.
    model.add_layer(Box::new(Conv2D::withshape(
        "conv1",
        3,
        3,
        channels,
        32,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(MaxPool2D::new("pool1", (2, 2), None, (0, 0))));

    // Second convolutional block: 3x3 conv, ReLU, then 2x2 max-pooling.
    model.add_layer(Box::new(Conv2D::withshape(
        "conv2",
        3,
        3,
        32,
        64,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(MaxPool2D::new("pool2", (2, 2), None, (0, 0))));

    // Fully connected head; each pooling stage halves the spatial dimensions.
    model.add_layer(Box::new(Linear::new_random(
        "fc1",
        64 * (height / 4) * (width / 4),
        128,
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(Dropout::new("dropout", 0.5, None)));

    model.add_layer(Box::new(Linear::new_random(
        "fc2",
        128,
        num_classes,
        true,
        None,
    )));

    model
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::{Array1, Array2};

    #[test]
    fn test_linear_layer() {
        array_protocol::init();

        let weights = Array2::<f64>::eye(3);
        let bias = Array1::<f64>::ones(3);

        let layer = Linear::new(
            "linear",
            Box::new(NdarrayWrapper::new(weights)),
            Some(Box::new(NdarrayWrapper::new(bias))),
            Some(ActivationFunc::ReLU),
        );

        assert_eq!(layer.name(), "linear");
        assert!(layer.is_training());
    }

    #[test]
    fn test_sequential_model() {
        array_protocol::init();

        let mut model = Sequential::new("test_model", Vec::new());

        model.add_layer(Box::new(Linear::new_random(
            "fc1",
            3,
            2,
            true,
            Some(ActivationFunc::ReLU),
        )));

        model.add_layer(Box::new(Linear::new_random(
            "fc2",
            2,
            1,
            true,
            Some(ActivationFunc::Sigmoid),
        )));

        assert_eq!(model.name(), "test_model");
        assert_eq!(model.layers().len(), 2);
        assert!(model.training);
    }

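    // Illustrative test (added sketch): exercises Conv2DBuilder's defaults
    // and fluent setters; assertions follow from the builder and withshape
    // code above.
    #[test]
    fn test_conv2d_builder() {
        array_protocol::init();

        let conv = Conv2DBuilder::new("conv")
            .filter_size(3, 3)
            .channels(1, 16)
            .stride((1, 1))
            .padding((1, 1))
            .activation(ActivationFunc::ReLU)
            .build();

        assert_eq!(conv.layer_type(), "Conv2D");
        assert_eq!(conv.name(), "conv");
        assert_eq!(conv.parameter_names(), vec!["filters", "bias"]);
    }
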
    #[test]
    fn test_simple_cnn_creation() {
        array_protocol::init();

        let model = create_simple_cnn((28, 28, 1), 10);

        assert_eq!(model.layers().len(), 7);
        assert_eq!(model.name(), "SimpleCNN");

        let params = model.parameters();
        assert!(!params.is_empty());
    }
}