1use crate::Layer;
10use crate::{Error, Result};
11
/// A feed-forward multi-layer perceptron: an ordered stack of [`Layer`]s.
///
/// Construction happens elsewhere (`from_layers` / the builder); this type
/// stores the layers and drives the forward and backward passes.
#[derive(Debug, Clone)]
pub struct Mlp {
    // Layers in forward order. Accessors such as `input_dim`/`output_dim`
    // `expect` at least one layer.
    layers: Vec<Layer>,
}
17
/// Reusable single-sample forward-pass workspace: one buffer per layer,
/// each `out_dim` long, so [`Mlp::forward`] allocates nothing per call.
#[derive(Debug, Clone)]
pub struct Scratch {
    // layer_outputs[i] holds layer i's activations from the latest forward.
    layer_outputs: Vec<Vec<f32>>,
}
25
/// Reusable batched forward-pass workspace: one row-major buffer per layer,
/// each `batch_size * out_dim` long.
#[derive(Debug, Clone)]
pub struct BatchScratch {
    // Number of samples this scratch was sized for.
    batch_size: usize,
    // layer_outputs[i] holds layer i's activations for the whole batch,
    // row-major (`batch_size` rows of `out_dim`).
    layer_outputs: Vec<Vec<f32>>,
}
35
/// Ping-pong buffers for batched backpropagation. Each buffer holds
/// `batch_size * max_dim` floats so any layer's activation gradient fits.
#[derive(Debug, Clone)]
pub struct BatchBackpropScratch {
    // Number of samples this scratch was sized for.
    batch_size: usize,
    // Widest in/out dimension across the model at allocation time.
    max_dim: usize,
    // The two buffers alternate roles as current / next gradient storage
    // during the backward sweep.
    buf0: Vec<f32>,
    buf1: Vec<f32>,
}
46
47impl BatchBackpropScratch {
48 pub fn new(mlp: &Mlp, batch_size: usize) -> Self {
50 assert!(batch_size > 0, "batch_size must be > 0");
51
52 let mut max_dim = mlp.input_dim();
53 for layer in &mlp.layers {
54 max_dim = max_dim.max(layer.in_dim());
55 max_dim = max_dim.max(layer.out_dim());
56 }
57
58 let len = batch_size * max_dim;
59 Self {
60 batch_size,
61 max_dim,
62 buf0: vec![0.0; len],
63 buf1: vec![0.0; len],
64 }
65 }
66}
67
/// Gradient storage for an [`Mlp`]: parameter gradients plus the activation
/// gradients used as intermediates during single-sample backprop.
#[derive(Debug, Clone)]
pub struct Gradients {
    // Per-layer weight gradients, each `in_dim * out_dim` long.
    d_weights: Vec<Vec<f32>>,
    // Per-layer bias gradients, each `out_dim` long.
    d_biases: Vec<Vec<f32>>,

    // dL/d(activation) per layer. The last entry doubles as the
    // caller-filled dL/d(output) seed for `backward`.
    d_layer_outputs: Vec<Vec<f32>>,

    // dL/d(input) from the most recent single-sample backward pass.
    d_input: Vec<f32>,
}
83
84impl Mlp {
    /// Wraps an already-built stack of layers; callers are expected to pass a
    /// non-empty, dimension-consistent stack (the builder enforces this).
    pub(crate) fn from_layers(layers: Vec<Layer>) -> Self {
        Self { layers }
    }
88
89 #[inline]
91 pub fn input_dim(&self) -> usize {
92 self.layers
93 .first()
94 .expect("mlp must have at least one layer")
95 .in_dim()
96 }
97
98 #[inline]
100 pub fn output_dim(&self) -> usize {
101 self.layers
102 .last()
103 .expect("mlp must have at least one layer")
104 .out_dim()
105 }
106
    /// Number of layers in the model.
    #[inline]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
112
    /// Borrows layer `idx`, or `None` if `idx` is out of range.
    #[inline]
    pub fn layer(&self, idx: usize) -> Option<&Layer> {
        self.layers.get(idx)
    }
120
    /// Mutably borrows layer `idx`, or `None` if `idx` is out of range.
    /// Crate-private: external code mutates parameters via the training APIs.
    #[inline]
    pub(crate) fn layer_mut(&mut self, idx: usize) -> Option<&mut Layer> {
        self.layers.get_mut(idx)
    }
125
    /// Allocates a single-sample forward-pass workspace shaped for this model.
    pub fn scratch(&self) -> Scratch {
        Scratch::new(self)
    }
130
    /// Allocates a batched forward-pass workspace shaped for this model.
    ///
    /// Panics (via `BatchScratch::new`) if `batch_size` is zero.
    pub fn scratch_batch(&self, batch_size: usize) -> BatchScratch {
        BatchScratch::new(self, batch_size)
    }
135
    /// Allocates ping-pong buffers for batched backprop shaped for this model.
    ///
    /// Panics (via `BatchBackpropScratch::new`) if `batch_size` is zero.
    pub fn backprop_scratch_batch(&self, batch_size: usize) -> BatchBackpropScratch {
        BatchBackpropScratch::new(self, batch_size)
    }
140
    /// Allocates zeroed gradient buffers shaped for this model.
    pub fn gradients(&self) -> Gradients {
        Gradients::new(self)
    }
145
    /// Bundles a [`Scratch`] and a [`Gradients`] shaped for this model.
    #[inline]
    pub fn trainer(&self) -> Trainer {
        Trainer::new(self)
    }
151
152 pub fn forward<'a>(&self, input: &[f32], scratch: &'a mut Scratch) -> &'a [f32] {
160 assert_eq!(
161 input.len(),
162 self.input_dim(),
163 "input len {} does not match model input_dim {}",
164 input.len(),
165 self.input_dim()
166 );
167 assert_eq!(
168 scratch.layer_outputs.len(),
169 self.layers.len(),
170 "scratch has {} layer outputs, model has {} layers",
171 scratch.layer_outputs.len(),
172 self.layers.len()
173 );
174
175 for (idx, layer) in self.layers.iter().enumerate() {
176 if idx == 0 {
177 let out = &mut scratch.layer_outputs[0];
178 assert_eq!(
179 out.len(),
180 layer.out_dim(),
181 "scratch layer 0 output len {} does not match layer out_dim {}",
182 out.len(),
183 layer.out_dim()
184 );
185 layer.forward(input, out);
186 } else {
187 let (left, right) = scratch.layer_outputs.split_at_mut(idx);
189 let prev = &left[idx - 1];
190 let out = &mut right[0];
191 assert_eq!(
192 out.len(),
193 layer.out_dim(),
194 "scratch layer {idx} output len {} does not match layer out_dim {}",
195 out.len(),
196 layer.out_dim()
197 );
198 layer.forward(prev, out);
199 }
200 }
201
202 scratch.output()
203 }
204
205 pub fn forward_batch<'a>(&self, inputs: &[f32], scratch: &'a mut BatchScratch) -> &'a [f32] {
214 let batch_size = scratch.batch_size;
215 assert!(batch_size > 0, "batch_size must be > 0");
216 assert_eq!(
217 inputs.len(),
218 batch_size * self.input_dim(),
219 "inputs len {} does not match batch_size * input_dim ({} * {})",
220 inputs.len(),
221 batch_size,
222 self.input_dim()
223 );
224 assert_eq!(
225 scratch.layer_outputs.len(),
226 self.layers.len(),
227 "batch scratch has {} layer outputs, model has {} layers",
228 scratch.layer_outputs.len(),
229 self.layers.len()
230 );
231
232 for (idx, layer) in self.layers.iter().enumerate() {
233 let out_dim = layer.out_dim();
234 let in_dim = layer.in_dim();
235
236 if idx == 0 {
237 let out = &mut scratch.layer_outputs[0];
238 assert_eq!(
239 out.len(),
240 batch_size * out_dim,
241 "batch scratch layer 0 output len {} does not match batch_size * out_dim ({} * {})",
242 out.len(),
243 batch_size,
244 out_dim
245 );
246
247 crate::matmul::gemm_f32(
251 batch_size,
252 out_dim,
253 in_dim,
254 1.0,
255 inputs,
256 in_dim,
257 1,
258 layer.weights(),
259 1,
260 in_dim,
261 0.0,
262 out,
263 out_dim,
264 1,
265 );
266
267 let activation = layer.activation();
268 let b = layer.biases();
269 debug_assert_eq!(b.len(), out_dim);
270 for row in 0..batch_size {
271 let o0 = row * out_dim;
272 for o in 0..out_dim {
273 let z = out[o0 + o] + b[o];
274 out[o0 + o] = activation.forward(z);
275 }
276 }
277 } else {
278 let (left, right) = scratch.layer_outputs.split_at_mut(idx);
280 let prev = &left[idx - 1];
281 let out = &mut right[0];
282
283 assert_eq!(
284 prev.len(),
285 batch_size * in_dim,
286 "batch scratch layer {} input len {} does not match batch_size * in_dim ({} * {})",
287 idx - 1,
288 prev.len(),
289 batch_size,
290 in_dim
291 );
292 assert_eq!(
293 out.len(),
294 batch_size * out_dim,
295 "batch scratch layer {idx} output len {} does not match batch_size * out_dim ({} * {})",
296 out.len(),
297 batch_size,
298 out_dim
299 );
300
301 crate::matmul::gemm_f32(
302 batch_size,
303 out_dim,
304 in_dim,
305 1.0,
306 prev,
307 in_dim,
308 1,
309 layer.weights(),
310 1,
311 in_dim,
312 0.0,
313 out,
314 out_dim,
315 1,
316 );
317
318 let activation = layer.activation();
319 let b = layer.biases();
320 debug_assert_eq!(b.len(), out_dim);
321 for row in 0..batch_size {
322 let o0 = row * out_dim;
323 for o in 0..out_dim {
324 let z = out[o0 + o] + b[o];
325 out[o0 + o] = activation.forward(z);
326 }
327 }
328 }
329 }
330
331 scratch.output()
332 }
333
334 pub fn backward<'a>(
346 &self,
347 input: &[f32],
348 scratch: &Scratch,
349 grads: &'a mut Gradients,
350 ) -> &'a [f32] {
351 assert_eq!(
352 input.len(),
353 self.input_dim(),
354 "input len {} does not match model input_dim {}",
355 input.len(),
356 self.input_dim()
357 );
358 assert_eq!(
359 scratch.layer_outputs.len(),
360 self.layers.len(),
361 "scratch has {} layer outputs, model has {} layers",
362 scratch.layer_outputs.len(),
363 self.layers.len()
364 );
365
366 assert_eq!(
367 grads.d_weights.len(),
368 self.layers.len(),
369 "grads has {} d_weights entries, model has {} layers",
370 grads.d_weights.len(),
371 self.layers.len()
372 );
373 assert_eq!(
374 grads.d_biases.len(),
375 self.layers.len(),
376 "grads has {} d_biases entries, model has {} layers",
377 grads.d_biases.len(),
378 self.layers.len()
379 );
380 assert_eq!(
381 grads.d_layer_outputs.len(),
382 self.layers.len(),
383 "grads has {} d_layer_outputs entries, model has {} layers",
384 grads.d_layer_outputs.len(),
385 self.layers.len()
386 );
387 assert_eq!(
388 grads.d_input.len(),
389 self.input_dim(),
390 "grads d_input len {} does not match model input_dim {}",
391 grads.d_input.len(),
392 self.input_dim()
393 );
394
395 let last = self.layers.len() - 1;
396 assert_eq!(
397 grads.d_layer_outputs[last].len(),
398 self.output_dim(),
399 "grads d_output len {} does not match model output_dim {}",
400 grads.d_layer_outputs[last].len(),
401 self.output_dim()
402 );
403
404 for idx in (0..self.layers.len()).rev() {
405 let layer = &self.layers[idx];
406
407 let layer_input: &[f32] = if idx == 0 {
408 input
409 } else {
410 &scratch.layer_outputs[idx - 1]
411 };
412
413 let layer_output: &[f32] = &scratch.layer_outputs[idx];
414 assert_eq!(
415 layer_output.len(),
416 layer.out_dim(),
417 "scratch layer {idx} output len {} does not match layer out_dim {}",
418 layer_output.len(),
419 layer.out_dim()
420 );
421
422 if idx == 0 {
423 let d_outputs = &grads.d_layer_outputs[0];
424 layer.backward(
425 layer_input,
426 layer_output,
427 d_outputs,
428 &mut grads.d_input,
429 &mut grads.d_weights[0],
430 &mut grads.d_biases[0],
431 );
432 } else {
433 let (left, right) = grads.d_layer_outputs.split_at_mut(idx);
437 let d_inputs_prev = &mut left[idx - 1];
438 let d_outputs = &right[0];
439 layer.backward(
440 layer_input,
441 layer_output,
442 d_outputs,
443 d_inputs_prev,
444 &mut grads.d_weights[idx],
445 &mut grads.d_biases[idx],
446 );
447 }
448 }
449
450 &grads.d_input
451 }
452
453 pub fn backward_accumulate<'a>(
465 &self,
466 input: &[f32],
467 scratch: &Scratch,
468 grads: &'a mut Gradients,
469 ) -> &'a [f32] {
470 assert_eq!(
471 input.len(),
472 self.input_dim(),
473 "input len {} does not match model input_dim {}",
474 input.len(),
475 self.input_dim()
476 );
477 assert_eq!(
478 scratch.layer_outputs.len(),
479 self.layers.len(),
480 "scratch has {} layer outputs, model has {} layers",
481 scratch.layer_outputs.len(),
482 self.layers.len()
483 );
484
485 assert_eq!(
486 grads.d_weights.len(),
487 self.layers.len(),
488 "grads has {} d_weights entries, model has {} layers",
489 grads.d_weights.len(),
490 self.layers.len()
491 );
492 assert_eq!(
493 grads.d_biases.len(),
494 self.layers.len(),
495 "grads has {} d_biases entries, model has {} layers",
496 grads.d_biases.len(),
497 self.layers.len()
498 );
499 assert_eq!(
500 grads.d_layer_outputs.len(),
501 self.layers.len(),
502 "grads has {} d_layer_outputs entries, model has {} layers",
503 grads.d_layer_outputs.len(),
504 self.layers.len()
505 );
506 assert_eq!(
507 grads.d_input.len(),
508 self.input_dim(),
509 "grads d_input len {} does not match model input_dim {}",
510 grads.d_input.len(),
511 self.input_dim()
512 );
513
514 let last = self.layers.len() - 1;
515 assert_eq!(
516 grads.d_layer_outputs[last].len(),
517 self.output_dim(),
518 "grads d_output len {} does not match model output_dim {}",
519 grads.d_layer_outputs[last].len(),
520 self.output_dim()
521 );
522
523 for idx in (0..self.layers.len()).rev() {
524 let layer = &self.layers[idx];
525
526 let layer_input: &[f32] = if idx == 0 {
527 input
528 } else {
529 &scratch.layer_outputs[idx - 1]
530 };
531
532 let layer_output: &[f32] = &scratch.layer_outputs[idx];
533 assert_eq!(
534 layer_output.len(),
535 layer.out_dim(),
536 "scratch layer {idx} output len {} does not match layer out_dim {}",
537 layer_output.len(),
538 layer.out_dim()
539 );
540
541 if idx == 0 {
542 let d_outputs = &grads.d_layer_outputs[0];
543 layer.backward_accumulate(
544 layer_input,
545 layer_output,
546 d_outputs,
547 &mut grads.d_input,
548 &mut grads.d_weights[0],
549 &mut grads.d_biases[0],
550 );
551 } else {
552 let (left, right) = grads.d_layer_outputs.split_at_mut(idx);
553 let d_inputs_prev = &mut left[idx - 1];
554 let d_outputs = &right[0];
555 layer.backward_accumulate(
556 layer_input,
557 layer_output,
558 d_outputs,
559 d_inputs_prev,
560 &mut grads.d_weights[idx],
561 &mut grads.d_biases[idx],
562 );
563 }
564 }
565
566 &grads.d_input
567 }
568
569 pub fn backward_batch(
578 &self,
579 inputs: &[f32],
580 scratch: &BatchScratch,
581 d_outputs: &[f32],
582 grads: &mut Gradients,
583 backprop_scratch: &mut BatchBackpropScratch,
584 ) {
585 let batch_size = scratch.batch_size;
586 assert!(batch_size > 0, "batch_size must be > 0");
587 assert_eq!(
588 backprop_scratch.batch_size, batch_size,
589 "backprop scratch batch_size {} does not match scratch batch_size {}",
590 backprop_scratch.batch_size, batch_size
591 );
592 assert_eq!(
593 inputs.len(),
594 batch_size * self.input_dim(),
595 "inputs len {} does not match batch_size * input_dim ({} * {})",
596 inputs.len(),
597 batch_size,
598 self.input_dim()
599 );
600 assert_eq!(
601 d_outputs.len(),
602 batch_size * self.output_dim(),
603 "d_outputs len {} does not match batch_size * output_dim ({} * {})",
604 d_outputs.len(),
605 batch_size,
606 self.output_dim()
607 );
608 assert_eq!(
609 scratch.layer_outputs.len(),
610 self.layers.len(),
611 "batch scratch has {} layer outputs, model has {} layers",
612 scratch.layer_outputs.len(),
613 self.layers.len()
614 );
615
616 for (idx, (buf, layer)) in scratch.layer_outputs.iter().zip(&self.layers).enumerate() {
617 assert_eq!(
618 buf.len(),
619 batch_size * layer.out_dim(),
620 "batch scratch layer {idx} output len {} does not match batch_size * out_dim ({} * {})",
621 buf.len(),
622 batch_size,
623 layer.out_dim()
624 );
625 }
626
627 for (idx, layer) in self.layers.iter().enumerate() {
629 let out_dim = layer.out_dim();
630 let in_dim = layer.in_dim();
631 grads.d_weights[idx].fill(0.0);
632 grads.d_biases[idx].fill(0.0);
633 debug_assert_eq!(grads.d_weights[idx].len(), out_dim * in_dim);
634 debug_assert_eq!(grads.d_biases[idx].len(), out_dim);
635 }
636
637 let needed = batch_size * backprop_scratch.max_dim;
639 assert!(
640 backprop_scratch.buf0.len() >= needed && backprop_scratch.buf1.len() >= needed,
641 "backprop scratch buffers are too small"
642 );
643
644 let inv_batch = 1.0 / batch_size as f32;
645
646 let mut cur_dim = self.output_dim();
648 let cur_len = batch_size * cur_dim;
649 backprop_scratch.buf0[..cur_len].copy_from_slice(d_outputs);
650 let mut cur_in_buf0 = true;
651
652 for idx in (0..self.layers.len()).rev() {
653 let layer = &self.layers[idx];
654 let out_dim = layer.out_dim();
655 let in_dim = layer.in_dim();
656
657 debug_assert_eq!(cur_dim, out_dim);
658
659 let (cur_buf, other_buf) = if cur_in_buf0 {
660 (&mut backprop_scratch.buf0, &mut backprop_scratch.buf1)
661 } else {
662 (&mut backprop_scratch.buf1, &mut backprop_scratch.buf0)
663 };
664
665 let d_cur: &mut [f32] = &mut cur_buf[..batch_size * out_dim];
666
667 let y = &scratch.layer_outputs[idx];
668 debug_assert_eq!(y.len(), batch_size * out_dim);
669
670 let activation = layer.activation();
672 for i in 0..d_cur.len() {
673 d_cur[i] *= activation.grad_from_output(y[i]);
674 }
675
676 let db = &mut grads.d_biases[idx];
678 assert_eq!(db.len(), out_dim);
679 db.fill(0.0);
680 for b in 0..batch_size {
681 let row0 = b * out_dim;
682 for o in 0..out_dim {
683 db[o] += d_cur[row0 + o];
684 }
685 }
686 for v in db.iter_mut() {
687 *v *= inv_batch;
688 }
689
690 let x: &[f32] = if idx == 0 {
692 inputs
693 } else {
694 &scratch.layer_outputs[idx - 1]
695 };
696 assert_eq!(x.len(), batch_size * in_dim);
697
698 let dw = &mut grads.d_weights[idx];
699 assert_eq!(dw.len(), out_dim * in_dim);
700 crate::matmul::gemm_f32(
701 out_dim, in_dim, batch_size, inv_batch, d_cur, 1, out_dim, x, in_dim, 1, 0.0, dw,
702 in_dim, 1,
703 );
704
705 if idx == 0 {
706 break;
707 }
708
709 let d_x: &mut [f32] = &mut other_buf[..batch_size * in_dim];
711 crate::matmul::gemm_f32(
712 batch_size,
713 in_dim,
714 out_dim,
715 1.0,
716 d_cur,
717 out_dim,
718 1,
719 layer.weights(),
720 in_dim,
721 1,
722 0.0,
723 d_x,
724 in_dim,
725 1,
726 );
727
728 cur_in_buf0 = !cur_in_buf0;
729 cur_dim = in_dim;
730 }
731 }
732
733 #[inline]
735 pub fn sgd_step(&mut self, grads: &Gradients, lr: f32) {
736 assert!(
737 lr.is_finite() && lr > 0.0,
738 "learning rate must be finite and > 0"
739 );
740 assert_eq!(
741 self.layers.len(),
742 grads.d_weights.len(),
743 "grads has {} d_weights entries, model has {} layers",
744 grads.d_weights.len(),
745 self.layers.len()
746 );
747 assert_eq!(
748 self.layers.len(),
749 grads.d_biases.len(),
750 "grads has {} d_biases entries, model has {} layers",
751 grads.d_biases.len(),
752 self.layers.len()
753 );
754
755 for i in 0..self.layers.len() {
756 self.layers[i].sgd_step(&grads.d_weights[i], &grads.d_biases[i], lr);
757 }
758 }
759
760 pub(crate) fn apply_weight_decay(&mut self, lr: f32, weight_decay: f32) {
764 assert!(
765 lr.is_finite() && lr > 0.0,
766 "learning rate must be finite and > 0"
767 );
768 assert!(
769 weight_decay.is_finite() && weight_decay >= 0.0,
770 "weight_decay must be finite and >= 0"
771 );
772
773 if weight_decay == 0.0 {
774 return;
775 }
776
777 for layer in &mut self.layers {
778 layer.apply_weight_decay(lr, weight_decay);
779 }
780 }
781
782 pub fn predict_into(
787 &self,
788 input: &[f32],
789 scratch: &mut Scratch,
790 out: &mut [f32],
791 ) -> Result<()> {
792 if input.len() != self.input_dim() {
793 return Err(Error::InvalidData(format!(
794 "input len {} does not match model input_dim {}",
795 input.len(),
796 self.input_dim()
797 )));
798 }
799 if out.len() != self.output_dim() {
800 return Err(Error::InvalidData(format!(
801 "out len {} does not match model output_dim {}",
802 out.len(),
803 self.output_dim()
804 )));
805 }
806 if scratch.layer_outputs.len() != self.layers.len() {
807 return Err(Error::InvalidData(format!(
808 "scratch has {} layer outputs, model has {} layers",
809 scratch.layer_outputs.len(),
810 self.layers.len()
811 )));
812 }
813 for (idx, (buf, layer)) in scratch.layer_outputs.iter().zip(&self.layers).enumerate() {
814 if buf.len() != layer.out_dim() {
815 return Err(Error::InvalidData(format!(
816 "scratch layer {idx} output len {} does not match layer out_dim {}",
817 buf.len(),
818 layer.out_dim()
819 )));
820 }
821 }
822
823 let y = self.forward(input, scratch);
824 out.copy_from_slice(y);
825 Ok(())
826 }
827
    /// Alias for [`Mlp::predict_into`], kept for API symmetry with the
    /// batched entry points.
    #[inline]
    pub fn predict_one_into(
        &self,
        input: &[f32],
        scratch: &mut Scratch,
        out: &mut [f32],
    ) -> Result<()> {
        self.predict_into(input, scratch, out)
    }
840}
841
/// Convenience bundle of the per-model buffers needed for a training loop.
#[derive(Debug, Clone)]
pub struct Trainer {
    /// Forward-pass workspace shaped for the model.
    pub scratch: Scratch,
    /// Gradient buffers shaped for the model.
    pub grads: Gradients,
}
850
851impl Trainer {
852 pub fn new(mlp: &Mlp) -> Self {
854 Self {
855 scratch: Scratch::new(mlp),
856 grads: Gradients::new(mlp),
857 }
858 }
859}
860
861impl Scratch {
862 pub fn new(mlp: &Mlp) -> Self {
864 let mut layer_outputs = Vec::with_capacity(mlp.layers.len());
865 for layer in &mlp.layers {
866 layer_outputs.push(vec![0.0; layer.out_dim()]);
867 }
868 Self { layer_outputs }
869 }
870
871 #[inline]
872 pub fn output(&self) -> &[f32] {
874 self.layer_outputs
875 .last()
876 .expect("scratch must have at least one layer output")
877 .as_slice()
878 }
879}
880
881impl BatchScratch {
882 pub fn new(mlp: &Mlp, batch_size: usize) -> Self {
884 assert!(batch_size > 0, "batch_size must be > 0");
885
886 let mut layer_outputs = Vec::with_capacity(mlp.layers.len());
887 for layer in &mlp.layers {
888 layer_outputs.push(vec![0.0; batch_size * layer.out_dim()]);
889 }
890 Self {
891 batch_size,
892 layer_outputs,
893 }
894 }
895
896 #[inline]
897 pub fn batch_size(&self) -> usize {
899 self.batch_size
900 }
901
902 #[inline]
903 pub fn output(&self) -> &[f32] {
907 self.layer_outputs
908 .last()
909 .expect("batch scratch must have at least one layer output")
910 .as_slice()
911 }
912
913 #[inline]
914 pub fn output_row(&self, idx: usize) -> &[f32] {
918 assert!(idx < self.batch_size, "batch index out of bounds");
919
920 let out = self
921 .layer_outputs
922 .last()
923 .expect("batch scratch must have at least one layer output");
924 let out_dim = out.len() / self.batch_size;
925 let start = idx * out_dim;
926 &out[start..start + out_dim]
927 }
928}
929
930impl Gradients {
931 pub fn new(mlp: &Mlp) -> Self {
933 let mut d_weights = Vec::with_capacity(mlp.layers.len());
934 let mut d_biases = Vec::with_capacity(mlp.layers.len());
935 let mut d_layer_outputs = Vec::with_capacity(mlp.layers.len());
936
937 for layer in &mlp.layers {
938 d_weights.push(vec![0.0; layer.in_dim() * layer.out_dim()]);
939 d_biases.push(vec![0.0; layer.out_dim()]);
940 d_layer_outputs.push(vec![0.0; layer.out_dim()]);
941 }
942
943 let d_input = vec![0.0; mlp.input_dim()];
944
945 Self {
946 d_weights,
947 d_biases,
948 d_layer_outputs,
949 d_input,
950 }
951 }
952
953 #[inline]
960 pub fn d_output_mut(&mut self) -> &mut [f32] {
961 self.d_layer_outputs
962 .last_mut()
963 .expect("mlp must have at least one layer")
964 .as_mut_slice()
965 }
966
967 #[inline]
968 pub fn d_output(&self) -> &[f32] {
970 self.d_layer_outputs
971 .last()
972 .expect("mlp must have at least one layer")
973 .as_slice()
974 }
975
976 #[inline]
977 pub fn d_input(&self) -> &[f32] {
979 &self.d_input
980 }
981
982 #[inline]
984 pub fn d_weights(&self, layer_idx: usize) -> &[f32] {
985 &self.d_weights[layer_idx]
986 }
987
988 #[inline]
990 pub fn d_weights_mut(&mut self, layer_idx: usize) -> &mut [f32] {
991 &mut self.d_weights[layer_idx]
992 }
993
994 #[inline]
996 pub fn d_biases(&self, layer_idx: usize) -> &[f32] {
997 &self.d_biases[layer_idx]
998 }
999
1000 #[inline]
1002 pub fn d_biases_mut(&mut self, layer_idx: usize) -> &mut [f32] {
1003 &mut self.d_biases[layer_idx]
1004 }
1005
1006 #[inline]
1008 pub fn zero_params(&mut self) {
1009 for w in &mut self.d_weights {
1010 w.fill(0.0);
1011 }
1012 for b in &mut self.d_biases {
1013 b.fill(0.0);
1014 }
1015 }
1016
1017 #[inline]
1019 pub fn scale_params(&mut self, scale: f32) {
1020 assert!(scale.is_finite(), "scale must be finite");
1021
1022 for w in &mut self.d_weights {
1023 for v in w.iter_mut() {
1024 *v *= scale;
1025 }
1026 }
1027 for b in &mut self.d_biases {
1028 for v in b.iter_mut() {
1029 *v *= scale;
1030 }
1031 }
1032 }
1033
1034 pub fn global_l2_norm_params(&self) -> f32 {
1036 let mut sum_sq = 0.0_f32;
1037 for w in &self.d_weights {
1038 for &v in w {
1039 sum_sq = v.mul_add(v, sum_sq);
1040 }
1041 }
1042 for b in &self.d_biases {
1043 for &v in b {
1044 sum_sq = v.mul_add(v, sum_sq);
1045 }
1046 }
1047 sum_sq.sqrt()
1048 }
1049
1050 pub fn clip_global_norm_params(&mut self, max_norm: f32) -> f32 {
1054 assert!(
1055 max_norm.is_finite() && max_norm > 0.0,
1056 "max_norm must be finite and > 0"
1057 );
1058
1059 let norm = self.global_l2_norm_params();
1060 if norm > max_norm && norm > 0.0 {
1061 self.scale_params(max_norm / norm);
1062 }
1063 norm
1064 }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    use crate::{Activation, MlpBuilder};

    // MSE loss of `mlp` on one (input, target) pair, reusing `scratch`.
    fn loss_for_mlp(mlp: &Mlp, input: &[f32], target: &[f32], scratch: &mut Scratch) -> f32 {
        mlp.forward(input, scratch);
        crate::loss::mse(scratch.output(), target)
    }

    // Asserts two gradients agree within an absolute OR relative tolerance.
    fn assert_close(analytic: f32, numeric: f32, abs_tol: f32, rel_tol: f32) {
        let diff = (analytic - numeric).abs();
        // Relative error is scaled by the larger magnitude, floored at 1.0
        // so tiny gradients fall back to the absolute tolerance.
        let scale = analytic.abs().max(numeric.abs()).max(1.0);
        assert!(
            diff <= abs_tol || diff / scale <= rel_tol,
            "analytic={analytic} numeric={numeric} diff={diff}"
        );
    }

    #[test]
    fn predict_into_validates_shapes() {
        let mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        let mut scratch = mlp.scratch();
        let mut out = [0.0_f32; 1];

        // Correct shapes succeed.
        let ok = mlp.predict_into(&[0.1, 0.2], &mut scratch, &mut out);
        assert!(ok.is_ok());

        // A wrong input length is rejected with an Err, not a panic.
        let err = mlp.predict_into(&[0.1_f32], &mut scratch, &mut out);
        assert!(err.is_err());
    }

    // Validates analytic backprop against central-difference numeric
    // gradients for weights, biases, and the input.
    #[test]
    fn backward_matches_numeric_gradients_for_tanh() {
        let mut mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Tanh)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        let mut scratch = mlp.scratch();
        let mut grads = mlp.gradients();

        let input = [0.3_f32, -0.7_f32];
        let target = [0.2_f32];

        // Analytic gradients via one forward + backward pass.
        mlp.forward(&input, &mut scratch);
        let _loss = crate::loss::mse_backward(scratch.output(), &target, grads.d_output_mut());
        let d_input = mlp.backward(&input, &scratch, &mut grads).to_vec();

        let eps = 1e-3_f32;
        let abs_tol = 1e-3_f32;
        let rel_tol = 1e-2_f32;

        let mut scratch_tmp = mlp.scratch();

        for layer_idx in 0..mlp.layers.len() {
            // --- weights: perturb each entry by ±eps and compare ---
            let w_len = mlp.layers[layer_idx].in_dim() * mlp.layers[layer_idx].out_dim();
            debug_assert_eq!(w_len, grads.d_weights(layer_idx).len());

            for p in 0..w_len {
                // Perturb +eps, remembering the original value.
                let orig = {
                    let w = mlp.layers[layer_idx].weights_mut();
                    let orig = w[p];
                    w[p] = orig + eps;
                    orig
                };
                let loss_plus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                // Perturb -eps.
                {
                    let w = mlp.layers[layer_idx].weights_mut();
                    w[p] = orig - eps;
                }
                let loss_minus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                // Restore the original parameter.
                {
                    let w = mlp.layers[layer_idx].weights_mut();
                    w[p] = orig;
                }

                let numeric = (loss_plus - loss_minus) / (2.0 * eps);
                let analytic = grads.d_weights(layer_idx)[p];
                assert_close(analytic, numeric, abs_tol, rel_tol);
            }

            // --- biases: same central-difference scheme ---
            let b_len = mlp.layers[layer_idx].out_dim();
            debug_assert_eq!(b_len, grads.d_biases(layer_idx).len());

            for p in 0..b_len {
                let orig = {
                    let b = mlp.layers[layer_idx].biases_mut();
                    let orig = b[p];
                    b[p] = orig + eps;
                    orig
                };
                let loss_plus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                {
                    let b = mlp.layers[layer_idx].biases_mut();
                    b[p] = orig - eps;
                }
                let loss_minus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                {
                    let b = mlp.layers[layer_idx].biases_mut();
                    b[p] = orig;
                }

                let numeric = (loss_plus - loss_minus) / (2.0 * eps);
                let analytic = grads.d_biases(layer_idx)[p];
                assert_close(analytic, numeric, abs_tol, rel_tol);
            }
        }

        // --- input gradient: perturb each input component by ±eps ---
        let mut input_var = input;
        for i in 0..input_var.len() {
            let orig = input_var[i];

            input_var[i] = orig + eps;
            let loss_plus = loss_for_mlp(&mlp, &input_var, &target, &mut scratch_tmp);

            input_var[i] = orig - eps;
            let loss_minus = loss_for_mlp(&mlp, &input_var, &target, &mut scratch_tmp);

            input_var[i] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            let analytic = d_input[i];
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }
    }

    // `forward` asserts the input length; a 3-wide input into a 2-wide model
    // must panic.
    #[test]
    #[should_panic]
    fn forward_panics_on_input_shape_mismatch() {
        let mut rng = StdRng::seed_from_u64(0);
        let mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Tanh)
            .unwrap()
            .build_with_rng(&mut rng)
            .unwrap();
        let mut scratch = mlp.scratch();
        let input = [0.0_f32; 3];
        mlp.forward(&input, &mut scratch);
    }
}