use ndarray::{Array, Ix1};

use rand::Rng;

use crate::array_protocol::ml_ops::ActivationFunc;
use crate::array_protocol::operations::OperationError;
use crate::array_protocol::{ArrayProtocol, NdarrayWrapper};

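/// A neural network layer that consumes and produces `ArrayProtocol` arrays.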
pub trait Layer: Send + Sync {
    fn layer_type(&self) -> &str;

    fn forward(&self, inputs: &dyn ArrayProtocol)
        -> Result<Box<dyn ArrayProtocol>, OperationError>;

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>>;

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>>;

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError>;

    fn parameter_names(&self) -> Vec<String>;

    fn train(&mut self);

    fn eval(&mut self);

    fn is_training(&self) -> bool;

    fn name(&self) -> &str;
}

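/// A fully connected (dense) layer computing `activation(weights · x + bias)`.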
pub struct Linear {
    name: String,
    weights: Box<dyn ArrayProtocol>,
    bias: Option<Box<dyn ArrayProtocol>>,
    activation: Option<ActivationFunc>,
    training: bool,
}

impl Linear {
    pub fn new(
        name: &str,
        weights: Box<dyn ArrayProtocol>,
        bias: Option<Box<dyn ArrayProtocol>>,
        activation: Option<ActivationFunc>,
    ) -> Self {
        Self {
            name: name.to_string(),
            weights,
            bias,
            activation,
            training: true,
        }
    }

    pub fn new_random(
        name: &str,
        in_features: usize,
        out_features: usize,
        with_bias: bool,
        activation: Option<ActivationFunc>,
    ) -> Self {
        // Xavier/Glorot uniform initialization: bound = sqrt(6 / (fan_in + fan_out)).
        let scale = (6.0 / (in_features + out_features) as f64).sqrt();
        let mut rng = rand::rng();
        let weights = Array::from_shape_fn((out_features, in_features), |_| {
            (rng.random::<f64>() * 2.0 - 1.0) * scale
        });

        let bias = if with_bias {
            let bias_array: Array<f64, Ix1> = Array::zeros(out_features);
            Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
        } else {
            None
        };

        Self {
            name: name.to_string(),
            weights: Box::new(NdarrayWrapper::new(weights)),
            bias,
            activation,
            training: true,
        }
    }
}

impl Layer for Linear {
    fn layer_type(&self) -> &str {
        "Linear"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut result = crate::array_protocol::matmul(self.weights.as_ref(), inputs)?;

        if let Some(bias) = &self.bias {
            result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
        }

        if let Some(act_fn) = self.activation {
            result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
        }

        Ok(result)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = vec![self.weights.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        let mut params = vec![&mut self.weights];
        if let Some(bias) = &mut self.bias {
            params.push(bias);
        }
        params
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "weights" => {
                self.weights = value;
                Ok(())
            }
            "bias" => {
                self.bias = Some(value);
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        let mut names = vec!["weights".to_string()];
        if self.bias.is_some() {
            names.push("bias".to_string());
        }
        names
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

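/// A 2D convolutional layer with configurable stride, padding, and an
/// optional bias and activation.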
pub struct Conv2D {
    name: String,
    filters: Box<dyn ArrayProtocol>,
    bias: Option<Box<dyn ArrayProtocol>>,
    stride: (usize, usize),
    padding: (usize, usize),
    activation: Option<ActivationFunc>,
    training: bool,
}

impl Conv2D {
    pub fn new(
        name: &str,
        filters: Box<dyn ArrayProtocol>,
        bias: Option<Box<dyn ArrayProtocol>>,
        stride: (usize, usize),
        padding: (usize, usize),
        activation: Option<ActivationFunc>,
    ) -> Self {
        Self {
            name: name.to_string(),
            filters,
            bias,
            stride,
            padding,
            activation,
            training: true,
        }
    }

    #[allow(clippy::too_many_arguments)]
    pub fn with_shape(
        name: &str,
        filter_height: usize,
        filter_width: usize,
        in_channels: usize,
        out_channels: usize,
        stride: (usize, usize),
        padding: (usize, usize),
        with_bias: bool,
        activation: Option<ActivationFunc>,
    ) -> Self {
        // He (Kaiming) initialization: scale = sqrt(2 / fan_in).
        let fan_in = filter_height * filter_width * in_channels;
        let scale = (2.0 / fan_in as f64).sqrt();
        let mut rng = rand::rng();
        let filters = Array::from_shape_fn(
            (filter_height, filter_width, in_channels, out_channels),
            |_| (rng.random::<f64>() * 2.0 - 1.0) * scale,
        );

        let bias = if with_bias {
            let bias_array: Array<f64, Ix1> = Array::zeros(out_channels);
            Some(Box::new(NdarrayWrapper::new(bias_array)) as Box<dyn ArrayProtocol>)
        } else {
            None
        };

        Self {
            name: name.to_string(),
            filters: Box::new(NdarrayWrapper::new(filters)),
            bias,
            stride,
            padding,
            activation,
            training: true,
        }
    }
}

impl Layer for Conv2D {
    fn layer_type(&self) -> &str {
        "Conv2D"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut result = crate::array_protocol::ml_ops::conv2d(
            inputs,
            self.filters.as_ref(),
            self.stride,
            self.padding,
        )?;

        if let Some(bias) = &self.bias {
            result = crate::array_protocol::add(result.as_ref(), bias.as_ref())?;
        }

        if let Some(act_fn) = self.activation {
            result = crate::array_protocol::ml_ops::activation(result.as_ref(), act_fn)?;
        }

        Ok(result)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = vec![self.filters.clone()];
        if let Some(bias) = &self.bias {
            params.push(bias.clone());
        }
        params
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        let mut params = vec![&mut self.filters];
        if let Some(bias) = &mut self.bias {
            params.push(bias);
        }
        params
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "filters" => {
                self.filters = value;
                Ok(())
            }
            "bias" => {
                self.bias = Some(value);
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        let mut names = vec!["filters".to_string()];
        if self.bias.is_some() {
            names.push("bias".to_string());
        }
        names
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

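/// Builder for [`Conv2D`] layers. Defaults to 3x3 filters, one input and
/// one output channel, stride `(1, 1)`, no padding, and bias enabled.
/// A sketch of typical usage (the layer name is arbitrary):
///
/// ```ignore
/// let conv = Conv2DBuilder::new("conv1")
///     .filter_size(3, 3)
///     .channels(1, 32)
///     .activation(ActivationFunc::ReLU)
///     .build();
/// ```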
pub struct Conv2DBuilder {
    name: String,
    filter_height: usize,
    filter_width: usize,
    in_channels: usize,
    out_channels: usize,
    stride: (usize, usize),
    padding: (usize, usize),
    with_bias: bool,
    activation: Option<ActivationFunc>,
}

impl Conv2DBuilder {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            filter_height: 3,
            filter_width: 3,
            in_channels: 1,
            out_channels: 1,
            stride: (1, 1),
            padding: (0, 0),
            with_bias: true,
            activation: None,
        }
    }

    pub const fn filter_size(mut self, height: usize, width: usize) -> Self {
        self.filter_height = height;
        self.filter_width = width;
        self
    }

    pub const fn channels(mut self, input: usize, output: usize) -> Self {
        self.in_channels = input;
        self.out_channels = output;
        self
    }

    pub fn stride(mut self, stride: (usize, usize)) -> Self {
        self.stride = stride;
        self
    }

    pub fn padding(mut self, padding: (usize, usize)) -> Self {
        self.padding = padding;
        self
    }

    pub fn with_bias(mut self, with_bias: bool) -> Self {
        self.with_bias = with_bias;
        self
    }

    pub fn activation(mut self, activation: ActivationFunc) -> Self {
        self.activation = Some(activation);
        self
    }

    pub fn build(self) -> Conv2D {
        Conv2D::with_shape(
            &self.name,
            self.filter_height,
            self.filter_width,
            self.in_channels,
            self.out_channels,
            self.stride,
            self.padding,
            self.with_bias,
            self.activation,
        )
    }
}

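/// A 2D max-pooling layer. It has no trainable parameters; when no stride
/// is given, the stride defaults to the kernel size.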
#[allow(dead_code)]
pub struct MaxPool2D {
    name: String,
    kernel_size: (usize, usize),
    stride: (usize, usize),
    padding: (usize, usize),
    training: bool,
}

impl MaxPool2D {
    pub fn new(
        name: &str,
        kernel_size: (usize, usize),
        stride: Option<(usize, usize)>,
        padding: (usize, usize),
    ) -> Self {
        let stride = stride.unwrap_or(kernel_size);

        Self {
            name: name.to_string(),
            kernel_size,
            stride,
            padding,
            training: true,
        }
    }
}

impl Layer for MaxPool2D {
    fn layer_type(&self) -> &str {
        "MaxPool2D"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::max_pool2d(
            inputs,
            self.kernel_size,
            self.stride,
            self.padding,
        )
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn update_parameter(
        &mut self,
        name: &str,
        _value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        Err(OperationError::Other(format!(
            "MaxPool2D has no parameter: {name}"
        )))
    }

    fn parameter_names(&self) -> Vec<String> {
        Vec::new()
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

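/// A batch normalization layer. `scale` and `offset` are the trainable
/// parameters; `running_mean` and `running_var` hold the tracked statistics.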
pub struct BatchNorm {
    name: String,
    scale: Box<dyn ArrayProtocol>,
    offset: Box<dyn ArrayProtocol>,
    running_mean: Box<dyn ArrayProtocol>,
    running_var: Box<dyn ArrayProtocol>,
    epsilon: f64,
    training: bool,
}

impl BatchNorm {
    pub fn new(
        name: &str,
        scale: Box<dyn ArrayProtocol>,
        offset: Box<dyn ArrayProtocol>,
        running_mean: Box<dyn ArrayProtocol>,
        running_var: Box<dyn ArrayProtocol>,
        epsilon: f64,
    ) -> Self {
        Self {
            name: name.to_string(),
            scale,
            offset,
            running_mean,
            running_var,
            epsilon,
            training: true,
        }
    }

    pub fn with_shape(
        name: &str,
        num_features: usize,
        epsilon: Option<f64>,
        _momentum: Option<f64>,
    ) -> Self {
        let scale: Array<f64, Ix1> = Array::ones(num_features);
        let offset: Array<f64, Ix1> = Array::zeros(num_features);
        let running_mean: Array<f64, Ix1> = Array::zeros(num_features);
        let running_var: Array<f64, Ix1> = Array::ones(num_features);

        Self {
            name: name.to_string(),
            scale: Box::new(NdarrayWrapper::new(scale)),
            offset: Box::new(NdarrayWrapper::new(offset)),
            running_mean: Box::new(NdarrayWrapper::new(running_mean)),
            running_var: Box::new(NdarrayWrapper::new(running_var)),
            epsilon: epsilon.unwrap_or(1e-5),
            training: true,
        }
    }
}

impl Layer for BatchNorm {
    fn layer_type(&self) -> &str {
        "BatchNorm"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::batch_norm(
            inputs,
            self.scale.as_ref(),
            self.offset.as_ref(),
            self.running_mean.as_ref(),
            self.running_var.as_ref(),
            self.epsilon,
        )
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        vec![self.scale.clone(), self.offset.clone()]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        vec![&mut self.scale, &mut self.offset]
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "scale" => {
                self.scale = value;
                Ok(())
            }
            "offset" => {
                self.offset = value;
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        vec!["scale".to_string(), "offset".to_string()]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

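/// A dropout layer. The `rate`, current `training` flag, and optional RNG
/// `seed` are forwarded to `ml_ops::dropout` on each forward pass.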
pub struct Dropout {
    name: String,
    rate: f64,
    seed: Option<u64>,
    training: bool,
}

impl Dropout {
    pub fn new(name: &str, rate: f64, seed: Option<u64>) -> Self {
        Self {
            name: name.to_string(),
            rate,
            seed,
            training: true,
        }
    }
}

impl Layer for Dropout {
    fn layer_type(&self) -> &str {
        "Dropout"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        crate::array_protocol::ml_ops::dropout(inputs, self.rate, self.training, self.seed)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        Vec::new()
    }

    fn update_parameter(
        &mut self,
        name: &str,
        _value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        Err(OperationError::Other(format!(
            "Dropout has no parameter: {name}"
        )))
    }

    fn parameter_names(&self) -> Vec<String> {
        Vec::new()
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

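/// A multi-head attention layer with learned query, key, value, and output
/// projections (`wq`, `wk`, `wv`, `wo`), each of shape `(d_model, d_model)`.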
pub struct MultiHeadAttention {
    name: String,
    wq: Box<dyn ArrayProtocol>,
    wk: Box<dyn ArrayProtocol>,
    wv: Box<dyn ArrayProtocol>,
    wo: Box<dyn ArrayProtocol>,
    num_heads: usize,
    d_model: usize,
    training: bool,
}

impl MultiHeadAttention {
    pub fn new(
        name: &str,
        wq: Box<dyn ArrayProtocol>,
        wk: Box<dyn ArrayProtocol>,
        wv: Box<dyn ArrayProtocol>,
        wo: Box<dyn ArrayProtocol>,
        num_heads: usize,
        d_model: usize,
    ) -> Self {
        Self {
            name: name.to_string(),
            wq,
            wk,
            wv,
            wo,
            num_heads,
            d_model,
            training: true,
        }
    }

    pub fn with_params(name: &str, num_heads: usize, d_model: usize) -> Self {
        assert!(
            d_model % num_heads == 0,
            "d_model must be divisible by num_heads"
        );

        // Initialize each projection matrix uniformly in [-scale, scale].
        let scale = (1.0_f64 / d_model as f64).sqrt();
        let mut rng = rand::rng();
        let mut random_weight = || {
            Array::from_shape_fn((d_model, d_model), |_| {
                (rng.random::<f64>() * 2.0 - 1.0) * scale
            })
        };

        let wq = random_weight();
        let wk = random_weight();
        let wv = random_weight();
        let wo = random_weight();

        Self {
            name: name.to_string(),
            wq: Box::new(NdarrayWrapper::new(wq)),
            wk: Box::new(NdarrayWrapper::new(wk)),
            wv: Box::new(NdarrayWrapper::new(wv)),
            wo: Box::new(NdarrayWrapper::new(wo)),
            num_heads,
            d_model,
            training: true,
        }
    }
}

impl Layer for MultiHeadAttention {
    fn layer_type(&self) -> &str {
        "MultiHeadAttention"
    }

    fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        // Project the inputs into query, key, and value spaces.
        let queries = crate::array_protocol::matmul(self.wq.as_ref(), inputs)?;
        let keys = crate::array_protocol::matmul(self.wk.as_ref(), inputs)?;
        let values = crate::array_protocol::matmul(self.wv.as_ref(), inputs)?;

        let attention = crate::array_protocol::ml_ops::self_attention(
            queries.as_ref(),
            keys.as_ref(),
            values.as_ref(),
            None,
            Some((self.d_model / self.num_heads) as f64),
        )?;

        // Apply the output projection.
        let output = crate::array_protocol::matmul(self.wo.as_ref(), attention.as_ref())?;

        Ok(output)
    }

    fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        vec![
            self.wq.clone(),
            self.wk.clone(),
            self.wv.clone(),
            self.wo.clone(),
        ]
    }

    fn parameters_mut(&mut self) -> Vec<&mut Box<dyn ArrayProtocol>> {
        vec![&mut self.wq, &mut self.wk, &mut self.wv, &mut self.wo]
    }

    fn update_parameter(
        &mut self,
        name: &str,
        value: Box<dyn ArrayProtocol>,
    ) -> Result<(), OperationError> {
        match name {
            "wq" => {
                self.wq = value;
                Ok(())
            }
            "wk" => {
                self.wk = value;
                Ok(())
            }
            "wv" => {
                self.wv = value;
                Ok(())
            }
            "wo" => {
                self.wo = value;
                Ok(())
            }
            _ => Err(OperationError::Other(format!("Unknown parameter: {name}"))),
        }
    }

    fn parameter_names(&self) -> Vec<String> {
        vec![
            "wq".to_string(),
            "wk".to_string(),
            "wv".to_string(),
            "wo".to_string(),
        ]
    }

    fn train(&mut self) {
        self.training = true;
    }

    fn eval(&mut self) {
        self.training = false;
    }

    fn is_training(&self) -> bool {
        self.training
    }

    fn name(&self) -> &str {
        &self.name
    }
}

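/// A sequential container that feeds each layer's output into the next.
/// A minimal sketch of typical usage, assuming `input` is some
/// `ArrayProtocol` value (layer sizes are illustrative):
///
/// ```ignore
/// let mut model = Sequential::new("mlp", Vec::new());
/// model.add_layer(Box::new(Linear::new_random(
///     "fc1", 784, 128, true, Some(ActivationFunc::ReLU),
/// )));
/// model.add_layer(Box::new(Linear::new_random("fc2", 128, 10, true, None)));
/// let output = model.forward(&input)?;
/// ```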
pub struct Sequential {
    name: String,
    layers: Vec<Box<dyn Layer>>,
    training: bool,
}

impl Sequential {
    pub fn new(name: &str, layers: Vec<Box<dyn Layer>>) -> Self {
        Self {
            name: name.to_string(),
            layers,
            training: true,
        }
    }

    pub fn add_layer(&mut self, layer: Box<dyn Layer>) {
        self.layers.push(layer);
    }

    pub fn forward(
        &self,
        inputs: &dyn ArrayProtocol,
    ) -> Result<Box<dyn ArrayProtocol>, OperationError> {
        let mut x: Box<dyn ArrayProtocol> = inputs.box_clone();

        for layer in &self.layers {
            x = layer.forward(x.as_ref())?;
        }

        Ok(x)
    }

    pub fn parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut params = Vec::new();

        for layer in &self.layers {
            params.extend(layer.parameters());
        }

        params
    }

    pub fn train(&mut self) {
        self.training = true;

        for layer in &mut self.layers {
            layer.train();
        }
    }

    pub fn eval(&mut self) {
        self.training = false;

        for layer in &mut self.layers {
            layer.eval();
        }
    }

    pub fn name(&self) -> &str {
        &self.name
    }

    pub fn layers(&self) -> &[Box<dyn Layer>] {
        &self.layers
    }

    pub fn backward(
        &self,
        _output: &dyn ArrayProtocol,
        _target: &dyn ArrayProtocol,
    ) -> Result<crate::array_protocol::grad::GradientDict, crate::error::CoreError> {
        // Gradient computation is not implemented here; an empty dictionary is returned.
        Ok(crate::array_protocol::grad::GradientDict::new())
    }

    pub fn update_parameter(
        &mut self,
        param_name: &str,
        gradient: &dyn ArrayProtocol,
        learning_rate: f64,
    ) -> Result<(), crate::error::CoreError> {
        // Parameter names take the form "layer_index.param_name", e.g. "0.weights".
        let parts: Vec<&str> = param_name.split('.').collect();
        if parts.len() != 2 {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Invalid parameter name format. Expected 'layer_index.param_name', got: {param_name}"
                )),
            ));
        }

        let layer_index: usize = parts[0].parse().map_err(|_| {
            crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                "Invalid layer index: {layer_idx}",
                layer_idx = parts[0]
            )))
        })?;

        let param_name = parts[1];

        if layer_index >= self.layers.len() {
            return Err(crate::error::CoreError::ValueError(
                crate::error::ErrorContext::new(format!(
                    "Layer index {layer_index} out of bounds (model has {num_layers} layers)",
                    num_layers = self.layers.len()
                )),
            ));
        }

        let layer = &mut self.layers[layer_index];
        let current_params = layer.parameters();
        let param_names = layer.parameter_names();

        let param_idx = param_names
            .iter()
            .position(|name| name == param_name)
            .ok_or_else(|| {
                crate::error::CoreError::ValueError(crate::error::ErrorContext::new(format!(
                    "Parameter '{param_name}' not found in layer {layer_index}"
                )))
            })?;

        let current_param = &current_params[param_idx];

        // SGD step: param <- param - learning_rate * gradient.
        let scaled_gradient =
            crate::array_protocol::operations::multiply_by_scalar_f64(gradient, learning_rate)
                .map_err(|e| {
                    crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(
                        format!("Failed to scale gradient: {e}"),
                    ))
                })?;

        let updated_param = crate::array_protocol::operations::subtract(
            current_param.as_ref(),
            scaled_gradient.as_ref(),
        )
        .map_err(|e| {
            crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                "Failed to update parameter: {e}"
            )))
        })?;

        layer
            .update_parameter(param_name, updated_param)
            .map_err(|e| {
                crate::error::CoreError::ComputationError(crate::error::ErrorContext::new(format!(
                    "Failed to set parameter in layer: {e}"
                )))
            })?;

        Ok(())
    }

    pub fn all_parameter_names(&self) -> Vec<String> {
        let mut all_names = Vec::new();
        for (layer_idx, layer) in self.layers.iter().enumerate() {
            for param_name in layer.parameter_names() {
                all_names.push(format!("{layer_idx}.{param_name}"));
            }
        }
        all_names
    }

    pub fn all_parameters(&self) -> Vec<Box<dyn ArrayProtocol>> {
        let mut all_params = Vec::new();
        for layer in &self.layers {
            all_params.extend(layer.parameters());
        }
        all_params
    }
}

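/// Builds a small CNN for inputs of shape `(height, width, channels)`:
/// two convolution/pooling stages followed by two fully connected layers
/// with dropout in between.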
#[allow(dead_code)]
pub fn create_simple_cnn(input_shape: (usize, usize, usize), num_classes: usize) -> Sequential {
    let (height, width, channels) = input_shape;

    let mut model = Sequential::new("SimpleCNN", Vec::new());

    model.add_layer(Box::new(Conv2D::with_shape(
        "conv1",
        3,
        3,
        channels,
        32,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(MaxPool2D::new("pool1", (2, 2), None, (0, 0))));

    model.add_layer(Box::new(Conv2D::with_shape(
        "conv2",
        3,
        3,
        32,
        64,
        (1, 1),
        (1, 1),
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(MaxPool2D::new("pool2", (2, 2), None, (0, 0))));

    // After two 2x2 poolings the spatial dimensions shrink by a factor of 4.
    model.add_layer(Box::new(Linear::new_random(
        "fc1",
        64 * (height / 4) * (width / 4),
        128,
        true,
        Some(ActivationFunc::ReLU),
    )));

    model.add_layer(Box::new(Dropout::new("dropout", 0.5, None)));

    model.add_layer(Box::new(Linear::new_random(
        "fc2",
        128,
        num_classes,
        true,
        None,
    )));

    model
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::array_protocol::{self, NdarrayWrapper};
    use ndarray::{Array1, Array2};

    #[test]
    fn test_linear_layer() {
        array_protocol::init();

        let weights = Array2::<f64>::eye(3);
        let bias = Array1::<f64>::ones(3);

        let layer = Linear::new(
            "linear",
            Box::new(NdarrayWrapper::new(weights)),
            Some(Box::new(NdarrayWrapper::new(bias))),
            Some(ActivationFunc::ReLU),
        );

        assert_eq!(layer.name(), "linear");
        assert!(layer.is_training());
    }

    #[test]
    fn test_sequential_model() {
        array_protocol::init();

        let mut model = Sequential::new("test_model", Vec::new());

        model.add_layer(Box::new(Linear::new_random(
            "fc1",
            3,
            2,
            true,
            Some(ActivationFunc::ReLU),
        )));

        model.add_layer(Box::new(Linear::new_random(
            "fc2",
            2,
            1,
            true,
            Some(ActivationFunc::Sigmoid),
        )));

        assert_eq!(model.name(), "test_model");
        assert_eq!(model.layers().len(), 2);
        assert!(model.training);
    }

    #[test]
    fn test_simple_cnn_creation() {
        array_protocol::init();

        let model = create_simple_cnn((28, 28, 1), 10);

        assert_eq!(model.layers().len(), 7);
        assert_eq!(model.name(), "SimpleCNN");

        let params = model.parameters();
        assert!(!params.is_empty());
    }
}