Skip to main content

rust_mlp/
mlp.rs

1//! Multi-layer perceptron (MLP) core.
2//!
3//! The low-level API is intentionally allocation-free:
4//! - `Mlp::forward` writes activations into a reusable `Scratch` and returns a slice.
5//! - `Mlp::backward` writes gradients into a reusable `Gradients`.
6//!
7//! Shape mismatches are treated as programmer error and will panic via `assert!`.
8
9use crate::Layer;
10use crate::{Error, Result};
11
/// A feed-forward multi-layer perceptron composed of dense layers.
///
/// Invariant: a constructed `Mlp` holds at least one layer
/// (`input_dim`/`output_dim` `expect` this).
#[derive(Debug, Clone)]
pub struct Mlp {
    // Dense layers applied in order; layer i's output feeds layer i + 1.
    layers: Vec<Layer>,
}
17
/// Reusable buffers for `Mlp::forward`.
///
/// The output of the most recent forward pass lives inside `Scratch`.
#[derive(Debug, Clone)]
pub struct Scratch {
    // One activation buffer per layer, each of length `out_dim`; the last
    // entry is the model output (see `Scratch::output`).
    layer_outputs: Vec<Vec<f32>>,
}
25
/// Reusable buffers for `Mlp::forward_batch`.
///
/// This stores per-layer outputs for all samples in the batch in flat row-major layout:
/// - each layer buffer has shape `(batch_size, out_dim)`.
#[derive(Debug, Clone)]
pub struct BatchScratch {
    // Fixed batch size the buffers were allocated for.
    batch_size: usize,
    // One flat `(batch_size, out_dim)` row-major buffer per layer.
    layer_outputs: Vec<Vec<f32>>,
}
35
/// Reusable buffers for `Mlp::backward_batch`.
///
/// This stores two work buffers sized for the maximum layer dimension.
#[derive(Debug, Clone)]
pub struct BatchBackpropScratch {
    // Batch size the buffers were allocated for.
    batch_size: usize,
    // Largest in/out dimension across the model (including the input dim).
    max_dim: usize,
    // Ping-pong buffers holding dL/d(activation); `backward_batch` alternates
    // between them as it walks the layers backwards.
    buf0: Vec<f32>,
    buf1: Vec<f32>,
}
46
47impl BatchBackpropScratch {
48    /// Allocate a backprop scratch buffer suitable for `mlp` and `batch_size`.
49    pub fn new(mlp: &Mlp, batch_size: usize) -> Self {
50        assert!(batch_size > 0, "batch_size must be > 0");
51
52        let mut max_dim = mlp.input_dim();
53        for layer in &mlp.layers {
54            max_dim = max_dim.max(layer.in_dim());
55            max_dim = max_dim.max(layer.out_dim());
56        }
57
58        let len = batch_size * max_dim;
59        Self {
60            batch_size,
61            max_dim,
62            buf0: vec![0.0; len],
63            buf1: vec![0.0; len],
64        }
65    }
66}
67
/// Parameter gradients for an `Mlp` (overwrite semantics).
///
/// Allocate once via `Mlp::gradients()` and reuse across training steps.
#[derive(Debug, Clone)]
pub struct Gradients {
    // Per-layer weight gradients, each flat `(out_dim, in_dim)` row-major.
    d_weights: Vec<Vec<f32>>,
    // Per-layer bias gradients, each of length `out_dim`.
    d_biases: Vec<Vec<f32>>,

    // Backprop intermediate: gradient w.r.t each layer output.
    // This includes the final layer output; the caller writes the upstream
    // `d_output` into the last entry (via `d_output_mut`) so `Mlp::backward`
    // can uniformly backprop layer-by-layer.
    d_layer_outputs: Vec<Vec<f32>>,

    // Gradient w.r.t. the model input, written by `Mlp::backward`.
    d_input: Vec<f32>,
}
83
84impl Mlp {
    /// Build an `Mlp` from pre-constructed layers (crate-internal).
    ///
    /// NOTE(review): callers are expected to pass a non-empty `Vec`;
    /// `input_dim`/`output_dim` panic on an empty model.
    pub(crate) fn from_layers(layers: Vec<Layer>) -> Self {
        Self { layers }
    }
88
89    /// Returns the expected input dimension.
90    #[inline]
91    pub fn input_dim(&self) -> usize {
92        self.layers
93            .first()
94            .expect("mlp must have at least one layer")
95            .in_dim()
96    }
97
98    /// Returns the produced output dimension.
99    #[inline]
100    pub fn output_dim(&self) -> usize {
101        self.layers
102            .last()
103            .expect("mlp must have at least one layer")
104            .out_dim()
105    }
106
    /// Returns the number of layers.
    ///
    /// Always at least 1 for a constructed model (see `input_dim`).
    #[inline]
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }
112
    /// Returns a reference to a layer by index, or `None` if out of range.
    ///
    /// This is primarily useful for inspection and debugging.
    #[inline]
    pub fn layer(&self, idx: usize) -> Option<&Layer> {
        self.layers.get(idx)
    }
120
    /// Returns a mutable reference to a layer by index (crate-internal).
    #[inline]
    pub(crate) fn layer_mut(&mut self, idx: usize) -> Option<&mut Layer> {
        self.layers.get_mut(idx)
    }
125
    /// Allocate a `Scratch` buffer suitable for this model.
    ///
    /// Shorthand for `Scratch::new(self)`.
    pub fn scratch(&self) -> Scratch {
        Scratch::new(self)
    }
130
    /// Allocate a `BatchScratch` buffer suitable for this model and a fixed batch size.
    ///
    /// Shorthand for `BatchScratch::new(self, batch_size)`.
    pub fn scratch_batch(&self, batch_size: usize) -> BatchScratch {
        BatchScratch::new(self, batch_size)
    }
135
    /// Allocate a `BatchBackpropScratch` buffer suitable for this model and a fixed batch size.
    ///
    /// Shorthand for `BatchBackpropScratch::new(self, batch_size)`.
    pub fn backprop_scratch_batch(&self, batch_size: usize) -> BatchBackpropScratch {
        BatchBackpropScratch::new(self, batch_size)
    }
140
    /// Allocate a `Gradients` buffer suitable for this model.
    ///
    /// Shorthand for `Gradients::new(self)`.
    pub fn gradients(&self) -> Gradients {
        Gradients::new(self)
    }
145
    /// Convenience constructor: allocate all training buffers.
    ///
    /// Equivalent to `Trainer::new(self)` (scratch + gradients).
    #[inline]
    pub fn trainer(&self) -> Trainer {
        Trainer::new(self)
    }
151
152    /// Forward pass for a single sample.
153    ///
154    /// Writes intermediate activations into `scratch` and returns the final output slice.
155    ///
156    /// Shape contract:
157    /// - `input.len() == self.input_dim()`
158    /// - `scratch` must be built for this `Mlp` (same layer count and output sizes)
159    pub fn forward<'a>(&self, input: &[f32], scratch: &'a mut Scratch) -> &'a [f32] {
160        assert_eq!(
161            input.len(),
162            self.input_dim(),
163            "input len {} does not match model input_dim {}",
164            input.len(),
165            self.input_dim()
166        );
167        assert_eq!(
168            scratch.layer_outputs.len(),
169            self.layers.len(),
170            "scratch has {} layer outputs, model has {} layers",
171            scratch.layer_outputs.len(),
172            self.layers.len()
173        );
174
175        for (idx, layer) in self.layers.iter().enumerate() {
176            if idx == 0 {
177                let out = &mut scratch.layer_outputs[0];
178                assert_eq!(
179                    out.len(),
180                    layer.out_dim(),
181                    "scratch layer 0 output len {} does not match layer out_dim {}",
182                    out.len(),
183                    layer.out_dim()
184                );
185                layer.forward(input, out);
186            } else {
187                // Borrow the previous output immutably and the current output mutably.
188                let (left, right) = scratch.layer_outputs.split_at_mut(idx);
189                let prev = &left[idx - 1];
190                let out = &mut right[0];
191                assert_eq!(
192                    out.len(),
193                    layer.out_dim(),
194                    "scratch layer {idx} output len {} does not match layer out_dim {}",
195                    out.len(),
196                    layer.out_dim()
197                );
198                layer.forward(prev, out);
199            }
200        }
201
202        scratch.output()
203    }
204
205    /// Forward pass for a contiguous batch.
206    ///
207    /// Writes intermediate activations into `scratch` and returns the final output buffer
208    /// for the whole batch (flat row-major).
209    ///
210    /// Shape contract:
211    /// - `inputs.len() == batch_size * self.input_dim()`
212    /// - `scratch` must be built for this `Mlp` and the same `batch_size`
213    pub fn forward_batch<'a>(&self, inputs: &[f32], scratch: &'a mut BatchScratch) -> &'a [f32] {
214        let batch_size = scratch.batch_size;
215        assert!(batch_size > 0, "batch_size must be > 0");
216        assert_eq!(
217            inputs.len(),
218            batch_size * self.input_dim(),
219            "inputs len {} does not match batch_size * input_dim ({} * {})",
220            inputs.len(),
221            batch_size,
222            self.input_dim()
223        );
224        assert_eq!(
225            scratch.layer_outputs.len(),
226            self.layers.len(),
227            "batch scratch has {} layer outputs, model has {} layers",
228            scratch.layer_outputs.len(),
229            self.layers.len()
230        );
231
232        for (idx, layer) in self.layers.iter().enumerate() {
233            let out_dim = layer.out_dim();
234            let in_dim = layer.in_dim();
235
236            if idx == 0 {
237                let out = &mut scratch.layer_outputs[0];
238                assert_eq!(
239                    out.len(),
240                    batch_size * out_dim,
241                    "batch scratch layer 0 output len {} does not match batch_size * out_dim ({} * {})",
242                    out.len(),
243                    batch_size,
244                    out_dim
245                );
246
247                // out = inputs * weights^T
248                // inputs: (batch_size, in_dim) row-major
249                // weights: (out_dim, in_dim) row-major, so weights^T is represented by strides.
250                crate::matmul::gemm_f32(
251                    batch_size,
252                    out_dim,
253                    in_dim,
254                    1.0,
255                    inputs,
256                    in_dim,
257                    1,
258                    layer.weights(),
259                    1,
260                    in_dim,
261                    0.0,
262                    out,
263                    out_dim,
264                    1,
265                );
266
267                let activation = layer.activation();
268                let b = layer.biases();
269                debug_assert_eq!(b.len(), out_dim);
270                for row in 0..batch_size {
271                    let o0 = row * out_dim;
272                    for o in 0..out_dim {
273                        let z = out[o0 + o] + b[o];
274                        out[o0 + o] = activation.forward(z);
275                    }
276                }
277            } else {
278                // Borrow the previous output immutably and the current output mutably.
279                let (left, right) = scratch.layer_outputs.split_at_mut(idx);
280                let prev = &left[idx - 1];
281                let out = &mut right[0];
282
283                assert_eq!(
284                    prev.len(),
285                    batch_size * in_dim,
286                    "batch scratch layer {} input len {} does not match batch_size * in_dim ({} * {})",
287                    idx - 1,
288                    prev.len(),
289                    batch_size,
290                    in_dim
291                );
292                assert_eq!(
293                    out.len(),
294                    batch_size * out_dim,
295                    "batch scratch layer {idx} output len {} does not match batch_size * out_dim ({} * {})",
296                    out.len(),
297                    batch_size,
298                    out_dim
299                );
300
301                crate::matmul::gemm_f32(
302                    batch_size,
303                    out_dim,
304                    in_dim,
305                    1.0,
306                    prev,
307                    in_dim,
308                    1,
309                    layer.weights(),
310                    1,
311                    in_dim,
312                    0.0,
313                    out,
314                    out_dim,
315                    1,
316                );
317
318                let activation = layer.activation();
319                let b = layer.biases();
320                debug_assert_eq!(b.len(), out_dim);
321                for row in 0..batch_size {
322                    let o0 = row * out_dim;
323                    for o in 0..out_dim {
324                        let z = out[o0 + o] + b[o];
325                        out[o0 + o] = activation.forward(z);
326                    }
327                }
328            }
329        }
330
331        scratch.output()
332    }
333
    /// Backward pass for a single sample, using the internal `d_output` buffer.
    ///
    /// You must call `forward` first using the same `input` and `scratch`.
    ///
    /// Before calling this, write the upstream gradient `dL/d(output)` into
    /// `grads.d_output_mut()`.
    ///
    /// Overwrite semantics:
    /// - `grads` is overwritten with gradients for this sample.
    ///
    /// Returns dL/d(input).
    ///
    /// # Panics
    /// Panics if any buffer in `scratch` or `grads` does not match this
    /// model's layer count or dimensions.
    pub fn backward<'a>(
        &self,
        input: &[f32],
        scratch: &Scratch,
        grads: &'a mut Gradients,
    ) -> &'a [f32] {
        // Validate all shapes up front; the backprop loop below indexes freely.
        assert_eq!(
            input.len(),
            self.input_dim(),
            "input len {} does not match model input_dim {}",
            input.len(),
            self.input_dim()
        );
        assert_eq!(
            scratch.layer_outputs.len(),
            self.layers.len(),
            "scratch has {} layer outputs, model has {} layers",
            scratch.layer_outputs.len(),
            self.layers.len()
        );

        assert_eq!(
            grads.d_weights.len(),
            self.layers.len(),
            "grads has {} d_weights entries, model has {} layers",
            grads.d_weights.len(),
            self.layers.len()
        );
        assert_eq!(
            grads.d_biases.len(),
            self.layers.len(),
            "grads has {} d_biases entries, model has {} layers",
            grads.d_biases.len(),
            self.layers.len()
        );
        assert_eq!(
            grads.d_layer_outputs.len(),
            self.layers.len(),
            "grads has {} d_layer_outputs entries, model has {} layers",
            grads.d_layer_outputs.len(),
            self.layers.len()
        );
        assert_eq!(
            grads.d_input.len(),
            self.input_dim(),
            "grads d_input len {} does not match model input_dim {}",
            grads.d_input.len(),
            self.input_dim()
        );

        // The last entry of `d_layer_outputs` holds the caller-supplied
        // upstream gradient dL/d(output).
        let last = self.layers.len() - 1;
        assert_eq!(
            grads.d_layer_outputs[last].len(),
            self.output_dim(),
            "grads d_output len {} does not match model output_dim {}",
            grads.d_layer_outputs[last].len(),
            self.output_dim()
        );

        // Walk layers from output to input, turning dL/d(layer output) into
        // dL/d(layer input) one layer at a time.
        for idx in (0..self.layers.len()).rev() {
            let layer = &self.layers[idx];

            // Layer 0 consumed the caller's `input`; every other layer
            // consumed the previous layer's activation.
            let layer_input: &[f32] = if idx == 0 {
                input
            } else {
                &scratch.layer_outputs[idx - 1]
            };

            let layer_output: &[f32] = &scratch.layer_outputs[idx];
            assert_eq!(
                layer_output.len(),
                layer.out_dim(),
                "scratch layer {idx} output len {} does not match layer out_dim {}",
                layer_output.len(),
                layer.out_dim()
            );

            if idx == 0 {
                // The first layer's input gradient goes to the dedicated
                // `d_input` buffer rather than a `d_layer_outputs` slot.
                let d_outputs = &grads.d_layer_outputs[0];
                layer.backward(
                    layer_input,
                    layer_output,
                    d_outputs,
                    &mut grads.d_input,
                    &mut grads.d_weights[0],
                    &mut grads.d_biases[0],
                );
            } else {
                // We need two different layer gradient buffers:
                // - `d_outputs` for the current layer (read-only)
                // - `d_inputs` for the current layer, which becomes `d_outputs` of the previous
                // `split_at_mut` lets us borrow both disjointly.
                let (left, right) = grads.d_layer_outputs.split_at_mut(idx);
                let d_inputs_prev = &mut left[idx - 1];
                let d_outputs = &right[0];
                layer.backward(
                    layer_input,
                    layer_output,
                    d_outputs,
                    d_inputs_prev,
                    &mut grads.d_weights[idx],
                    &mut grads.d_biases[idx],
                );
            }
        }

        &grads.d_input
    }
452
    /// Backward pass for a single sample (parameter accumulation semantics).
    ///
    /// This is identical to `backward` except that parameter gradients are *accumulated*:
    /// - `grads.d_weights` and `grads.d_biases` are accumulated into (`+=`)
    /// - `grads.d_layer_outputs` and `grads.d_input` are overwritten
    ///
    /// This is useful for mini-batch training.
    ///
    /// You must call `forward` first using the same `input` and `scratch`.
    /// Before calling this, write the upstream gradient `dL/d(output)` into
    /// `grads.d_output_mut()`.
    ///
    /// # Panics
    /// Panics if any buffer in `scratch` or `grads` does not match this
    /// model's layer count or dimensions.
    pub fn backward_accumulate<'a>(
        &self,
        input: &[f32],
        scratch: &Scratch,
        grads: &'a mut Gradients,
    ) -> &'a [f32] {
        // Validate all shapes up front; the backprop loop below indexes freely.
        assert_eq!(
            input.len(),
            self.input_dim(),
            "input len {} does not match model input_dim {}",
            input.len(),
            self.input_dim()
        );
        assert_eq!(
            scratch.layer_outputs.len(),
            self.layers.len(),
            "scratch has {} layer outputs, model has {} layers",
            scratch.layer_outputs.len(),
            self.layers.len()
        );

        assert_eq!(
            grads.d_weights.len(),
            self.layers.len(),
            "grads has {} d_weights entries, model has {} layers",
            grads.d_weights.len(),
            self.layers.len()
        );
        assert_eq!(
            grads.d_biases.len(),
            self.layers.len(),
            "grads has {} d_biases entries, model has {} layers",
            grads.d_biases.len(),
            self.layers.len()
        );
        assert_eq!(
            grads.d_layer_outputs.len(),
            self.layers.len(),
            "grads has {} d_layer_outputs entries, model has {} layers",
            grads.d_layer_outputs.len(),
            self.layers.len()
        );
        assert_eq!(
            grads.d_input.len(),
            self.input_dim(),
            "grads d_input len {} does not match model input_dim {}",
            grads.d_input.len(),
            self.input_dim()
        );

        // The last entry of `d_layer_outputs` holds the caller-supplied
        // upstream gradient dL/d(output).
        let last = self.layers.len() - 1;
        assert_eq!(
            grads.d_layer_outputs[last].len(),
            self.output_dim(),
            "grads d_output len {} does not match model output_dim {}",
            grads.d_layer_outputs[last].len(),
            self.output_dim()
        );

        // Walk layers from output to input; identical to `backward` except
        // `layer.backward_accumulate` adds into the parameter gradients.
        for idx in (0..self.layers.len()).rev() {
            let layer = &self.layers[idx];

            // Layer 0 consumed the caller's `input`; every other layer
            // consumed the previous layer's activation.
            let layer_input: &[f32] = if idx == 0 {
                input
            } else {
                &scratch.layer_outputs[idx - 1]
            };

            let layer_output: &[f32] = &scratch.layer_outputs[idx];
            assert_eq!(
                layer_output.len(),
                layer.out_dim(),
                "scratch layer {idx} output len {} does not match layer out_dim {}",
                layer_output.len(),
                layer.out_dim()
            );

            if idx == 0 {
                // The first layer's input gradient goes to the dedicated
                // `d_input` buffer rather than a `d_layer_outputs` slot.
                let d_outputs = &grads.d_layer_outputs[0];
                layer.backward_accumulate(
                    layer_input,
                    layer_output,
                    d_outputs,
                    &mut grads.d_input,
                    &mut grads.d_weights[0],
                    &mut grads.d_biases[0],
                );
            } else {
                // Borrow the current layer's d_outputs (read) and the previous
                // layer's slot (write) disjointly via `split_at_mut`.
                let (left, right) = grads.d_layer_outputs.split_at_mut(idx);
                let d_inputs_prev = &mut left[idx - 1];
                let d_outputs = &right[0];
                layer.backward_accumulate(
                    layer_input,
                    layer_output,
                    d_outputs,
                    d_inputs_prev,
                    &mut grads.d_weights[idx],
                    &mut grads.d_biases[idx],
                );
            }
        }

        &grads.d_input
    }
568
569    /// Backward pass for a contiguous batch.
570    ///
571    /// This overwrites `grads` with the *mean* parameter gradients over the batch.
572    ///
573    /// Inputs:
574    /// - `inputs`: flat row-major with shape `(batch_size, input_dim)`
575    /// - `scratch`: activations from `forward_batch`
576    /// - `d_outputs`: flat row-major upstream gradients with shape `(batch_size, output_dim)`
577    pub fn backward_batch(
578        &self,
579        inputs: &[f32],
580        scratch: &BatchScratch,
581        d_outputs: &[f32],
582        grads: &mut Gradients,
583        backprop_scratch: &mut BatchBackpropScratch,
584    ) {
585        let batch_size = scratch.batch_size;
586        assert!(batch_size > 0, "batch_size must be > 0");
587        assert_eq!(
588            backprop_scratch.batch_size, batch_size,
589            "backprop scratch batch_size {} does not match scratch batch_size {}",
590            backprop_scratch.batch_size, batch_size
591        );
592        assert_eq!(
593            inputs.len(),
594            batch_size * self.input_dim(),
595            "inputs len {} does not match batch_size * input_dim ({} * {})",
596            inputs.len(),
597            batch_size,
598            self.input_dim()
599        );
600        assert_eq!(
601            d_outputs.len(),
602            batch_size * self.output_dim(),
603            "d_outputs len {} does not match batch_size * output_dim ({} * {})",
604            d_outputs.len(),
605            batch_size,
606            self.output_dim()
607        );
608        assert_eq!(
609            scratch.layer_outputs.len(),
610            self.layers.len(),
611            "batch scratch has {} layer outputs, model has {} layers",
612            scratch.layer_outputs.len(),
613            self.layers.len()
614        );
615
616        for (idx, (buf, layer)) in scratch.layer_outputs.iter().zip(&self.layers).enumerate() {
617            assert_eq!(
618                buf.len(),
619                batch_size * layer.out_dim(),
620                "batch scratch layer {idx} output len {} does not match batch_size * out_dim ({} * {})",
621                buf.len(),
622                batch_size,
623                layer.out_dim()
624            );
625        }
626
627        // Overwrite semantics.
628        for (idx, layer) in self.layers.iter().enumerate() {
629            let out_dim = layer.out_dim();
630            let in_dim = layer.in_dim();
631            grads.d_weights[idx].fill(0.0);
632            grads.d_biases[idx].fill(0.0);
633            debug_assert_eq!(grads.d_weights[idx].len(), out_dim * in_dim);
634            debug_assert_eq!(grads.d_biases[idx].len(), out_dim);
635        }
636
637        // Ensure work buffers are large enough.
638        let needed = batch_size * backprop_scratch.max_dim;
639        assert!(
640            backprop_scratch.buf0.len() >= needed && backprop_scratch.buf1.len() >= needed,
641            "backprop scratch buffers are too small"
642        );
643
644        let inv_batch = 1.0 / batch_size as f32;
645
646        // d_cur holds dL/dy for the current layer output, then overwritten to dL/dz.
647        let mut cur_dim = self.output_dim();
648        let cur_len = batch_size * cur_dim;
649        backprop_scratch.buf0[..cur_len].copy_from_slice(d_outputs);
650        let mut cur_in_buf0 = true;
651
652        for idx in (0..self.layers.len()).rev() {
653            let layer = &self.layers[idx];
654            let out_dim = layer.out_dim();
655            let in_dim = layer.in_dim();
656
657            debug_assert_eq!(cur_dim, out_dim);
658
659            let (cur_buf, other_buf) = if cur_in_buf0 {
660                (&mut backprop_scratch.buf0, &mut backprop_scratch.buf1)
661            } else {
662                (&mut backprop_scratch.buf1, &mut backprop_scratch.buf0)
663            };
664
665            let d_cur: &mut [f32] = &mut cur_buf[..batch_size * out_dim];
666
667            let y = &scratch.layer_outputs[idx];
668            debug_assert_eq!(y.len(), batch_size * out_dim);
669
670            // dZ = dY * activation'(y)
671            let activation = layer.activation();
672            for i in 0..d_cur.len() {
673                d_cur[i] *= activation.grad_from_output(y[i]);
674            }
675
676            // db = mean over batch of dZ
677            let db = &mut grads.d_biases[idx];
678            assert_eq!(db.len(), out_dim);
679            db.fill(0.0);
680            for b in 0..batch_size {
681                let row0 = b * out_dim;
682                for o in 0..out_dim {
683                    db[o] += d_cur[row0 + o];
684                }
685            }
686            for v in db.iter_mut() {
687                *v *= inv_batch;
688            }
689
690            // dW = mean over batch of dZ^T * X
691            let x: &[f32] = if idx == 0 {
692                inputs
693            } else {
694                &scratch.layer_outputs[idx - 1]
695            };
696            assert_eq!(x.len(), batch_size * in_dim);
697
698            let dw = &mut grads.d_weights[idx];
699            assert_eq!(dw.len(), out_dim * in_dim);
700            crate::matmul::gemm_f32(
701                out_dim, in_dim, batch_size, inv_batch, d_cur, 1, out_dim, x, in_dim, 1, 0.0, dw,
702                in_dim, 1,
703            );
704
705            if idx == 0 {
706                break;
707            }
708
709            // dX = dZ * W into the other buffer.
710            let d_x: &mut [f32] = &mut other_buf[..batch_size * in_dim];
711            crate::matmul::gemm_f32(
712                batch_size,
713                in_dim,
714                out_dim,
715                1.0,
716                d_cur,
717                out_dim,
718                1,
719                layer.weights(),
720                in_dim,
721                1,
722                0.0,
723                d_x,
724                in_dim,
725                1,
726            );
727
728            cur_in_buf0 = !cur_in_buf0;
729            cur_dim = in_dim;
730        }
731    }
732
733    /// Applies an SGD update to all layers.
734    #[inline]
735    pub fn sgd_step(&mut self, grads: &Gradients, lr: f32) {
736        assert!(
737            lr.is_finite() && lr > 0.0,
738            "learning rate must be finite and > 0"
739        );
740        assert_eq!(
741            self.layers.len(),
742            grads.d_weights.len(),
743            "grads has {} d_weights entries, model has {} layers",
744            grads.d_weights.len(),
745            self.layers.len()
746        );
747        assert_eq!(
748            self.layers.len(),
749            grads.d_biases.len(),
750            "grads has {} d_biases entries, model has {} layers",
751            grads.d_biases.len(),
752            self.layers.len()
753        );
754
755        for i in 0..self.layers.len() {
756            self.layers[i].sgd_step(&grads.d_weights[i], &grads.d_biases[i], lr);
757        }
758    }
759
760    /// Apply decoupled weight decay to all layer weights.
761    ///
762    /// This updates weights only (biases are not decayed): `w -= lr * weight_decay * w`.
763    pub(crate) fn apply_weight_decay(&mut self, lr: f32, weight_decay: f32) {
764        assert!(
765            lr.is_finite() && lr > 0.0,
766            "learning rate must be finite and > 0"
767        );
768        assert!(
769            weight_decay.is_finite() && weight_decay >= 0.0,
770            "weight_decay must be finite and >= 0"
771        );
772
773        if weight_decay == 0.0 {
774            return;
775        }
776
777        for layer in &mut self.layers {
778            layer.apply_weight_decay(lr, weight_decay);
779        }
780    }
781
782    /// Shape-safe, non-allocating inference.
783    ///
784    /// This validates shapes and returns `Result` instead of panicking.
785    /// Internally it uses the low-level `forward` hot path.
786    pub fn predict_into(
787        &self,
788        input: &[f32],
789        scratch: &mut Scratch,
790        out: &mut [f32],
791    ) -> Result<()> {
792        if input.len() != self.input_dim() {
793            return Err(Error::InvalidData(format!(
794                "input len {} does not match model input_dim {}",
795                input.len(),
796                self.input_dim()
797            )));
798        }
799        if out.len() != self.output_dim() {
800            return Err(Error::InvalidData(format!(
801                "out len {} does not match model output_dim {}",
802                out.len(),
803                self.output_dim()
804            )));
805        }
806        if scratch.layer_outputs.len() != self.layers.len() {
807            return Err(Error::InvalidData(format!(
808                "scratch has {} layer outputs, model has {} layers",
809                scratch.layer_outputs.len(),
810                self.layers.len()
811            )));
812        }
813        for (idx, (buf, layer)) in scratch.layer_outputs.iter().zip(&self.layers).enumerate() {
814            if buf.len() != layer.out_dim() {
815                return Err(Error::InvalidData(format!(
816                    "scratch layer {idx} output len {} does not match layer out_dim {}",
817                    buf.len(),
818                    layer.out_dim()
819                )));
820            }
821        }
822
823        let y = self.forward(input, scratch);
824        out.copy_from_slice(y);
825        Ok(())
826    }
827
    /// Shape-safe, non-allocating inference for a single input.
    ///
    /// Alias of [`Mlp::predict_into`]; forwards all arguments unchanged.
    #[inline]
    pub fn predict_one_into(
        &self,
        input: &[f32],
        scratch: &mut Scratch,
        out: &mut [f32],
    ) -> Result<()> {
        self.predict_into(input, scratch, out)
    }
840}
841
/// Reusable buffers for training a specific `Mlp`.
///
/// This is the ergonomic wrapper around `Scratch` + `Gradients`.
#[derive(Debug, Clone)]
pub struct Trainer {
    // Forward-pass activation buffers.
    pub scratch: Scratch,
    // Parameter/backprop gradient buffers.
    pub grads: Gradients,
}
850
impl Trainer {
    /// Allocate a `Trainer` (scratch + gradients) for `mlp`.
    ///
    /// Both buffers are sized to `mlp`'s current layer dimensions.
    pub fn new(mlp: &Mlp) -> Self {
        Self {
            scratch: Scratch::new(mlp),
            grads: Gradients::new(mlp),
        }
    }
}
860
861impl Scratch {
862    /// Allocate a scratch buffer suitable for `mlp`.
863    pub fn new(mlp: &Mlp) -> Self {
864        let mut layer_outputs = Vec::with_capacity(mlp.layers.len());
865        for layer in &mlp.layers {
866            layer_outputs.push(vec![0.0; layer.out_dim()]);
867        }
868        Self { layer_outputs }
869    }
870
871    #[inline]
872    /// Returns the final model output slice from the last `forward` call.
873    pub fn output(&self) -> &[f32] {
874        self.layer_outputs
875            .last()
876            .expect("scratch must have at least one layer output")
877            .as_slice()
878    }
879}
880
881impl BatchScratch {
882    /// Allocate a batch scratch buffer suitable for `mlp` and `batch_size`.
883    pub fn new(mlp: &Mlp, batch_size: usize) -> Self {
884        assert!(batch_size > 0, "batch_size must be > 0");
885
886        let mut layer_outputs = Vec::with_capacity(mlp.layers.len());
887        for layer in &mlp.layers {
888            layer_outputs.push(vec![0.0; batch_size * layer.out_dim()]);
889        }
890        Self {
891            batch_size,
892            layer_outputs,
893        }
894    }
895
896    #[inline]
897    /// Returns the fixed batch size.
898    pub fn batch_size(&self) -> usize {
899        self.batch_size
900    }
901
902    #[inline]
903    /// Returns the final model output buffer from the last `forward_batch` call.
904    ///
905    /// Shape: `(batch_size * output_dim,)`.
906    pub fn output(&self) -> &[f32] {
907        self.layer_outputs
908            .last()
909            .expect("batch scratch must have at least one layer output")
910            .as_slice()
911    }
912
913    #[inline]
914    /// Returns the `idx`-th output row (shape: `(output_dim,)`).
915    ///
916    /// Panics if `idx >= batch_size`.
917    pub fn output_row(&self, idx: usize) -> &[f32] {
918        assert!(idx < self.batch_size, "batch index out of bounds");
919
920        let out = self
921            .layer_outputs
922            .last()
923            .expect("batch scratch must have at least one layer output");
924        let out_dim = out.len() / self.batch_size;
925        let start = idx * out_dim;
926        &out[start..start + out_dim]
927    }
928}
929
930impl Gradients {
931    /// Allocate gradient buffers suitable for `mlp`.
932    pub fn new(mlp: &Mlp) -> Self {
933        let mut d_weights = Vec::with_capacity(mlp.layers.len());
934        let mut d_biases = Vec::with_capacity(mlp.layers.len());
935        let mut d_layer_outputs = Vec::with_capacity(mlp.layers.len());
936
937        for layer in &mlp.layers {
938            d_weights.push(vec![0.0; layer.in_dim() * layer.out_dim()]);
939            d_biases.push(vec![0.0; layer.out_dim()]);
940            d_layer_outputs.push(vec![0.0; layer.out_dim()]);
941        }
942
943        let d_input = vec![0.0; mlp.input_dim()];
944
945        Self {
946            d_weights,
947            d_biases,
948            d_layer_outputs,
949            d_input,
950        }
951    }
952
953    /// Mutable view of the upstream gradient buffer for the final model output.
954    ///
955    /// Typical training flow:
956    /// - `mlp.forward(input, &mut scratch)`
957    /// - loss writes `dL/d(output)` into `grads.d_output_mut()`
958    /// - `mlp.backward(input, &scratch, &mut grads)`
959    #[inline]
960    pub fn d_output_mut(&mut self) -> &mut [f32] {
961        self.d_layer_outputs
962            .last_mut()
963            .expect("mlp must have at least one layer")
964            .as_mut_slice()
965    }
966
967    #[inline]
968    /// Immutable view of the final upstream gradient buffer.
969    pub fn d_output(&self) -> &[f32] {
970        self.d_layer_outputs
971            .last()
972            .expect("mlp must have at least one layer")
973            .as_slice()
974    }
975
976    #[inline]
977    /// Returns dL/d(input) computed by the most recent `backward` call.
978    pub fn d_input(&self) -> &[f32] {
979        &self.d_input
980    }
981
982    /// Returns the weight gradient for the given layer (row-major `(out_dim, in_dim)`).
983    #[inline]
984    pub fn d_weights(&self, layer_idx: usize) -> &[f32] {
985        &self.d_weights[layer_idx]
986    }
987
988    /// Mutable weight gradient for the given layer.
989    #[inline]
990    pub fn d_weights_mut(&mut self, layer_idx: usize) -> &mut [f32] {
991        &mut self.d_weights[layer_idx]
992    }
993
994    /// Returns the bias gradient for the given layer (length `out_dim`).
995    #[inline]
996    pub fn d_biases(&self, layer_idx: usize) -> &[f32] {
997        &self.d_biases[layer_idx]
998    }
999
1000    /// Mutable bias gradient for the given layer.
1001    #[inline]
1002    pub fn d_biases_mut(&mut self, layer_idx: usize) -> &mut [f32] {
1003        &mut self.d_biases[layer_idx]
1004    }
1005
1006    /// Zero the parameter gradient buffers (`d_weights` and `d_biases`).
1007    #[inline]
1008    pub fn zero_params(&mut self) {
1009        for w in &mut self.d_weights {
1010            w.fill(0.0);
1011        }
1012        for b in &mut self.d_biases {
1013            b.fill(0.0);
1014        }
1015    }
1016
1017    /// Scale parameter gradients (`d_weights` and `d_biases`) in place.
1018    #[inline]
1019    pub fn scale_params(&mut self, scale: f32) {
1020        assert!(scale.is_finite(), "scale must be finite");
1021
1022        for w in &mut self.d_weights {
1023            for v in w.iter_mut() {
1024                *v *= scale;
1025            }
1026        }
1027        for b in &mut self.d_biases {
1028            for v in b.iter_mut() {
1029                *v *= scale;
1030            }
1031        }
1032    }
1033
1034    /// Compute the global L2 norm of parameter gradients.
1035    pub fn global_l2_norm_params(&self) -> f32 {
1036        let mut sum_sq = 0.0_f32;
1037        for w in &self.d_weights {
1038            for &v in w {
1039                sum_sq = v.mul_add(v, sum_sq);
1040            }
1041        }
1042        for b in &self.d_biases {
1043            for &v in b {
1044                sum_sq = v.mul_add(v, sum_sq);
1045            }
1046        }
1047        sum_sq.sqrt()
1048    }
1049
1050    /// Clip parameter gradients by global norm.
1051    ///
1052    /// Returns the pre-clip norm.
1053    pub fn clip_global_norm_params(&mut self, max_norm: f32) -> f32 {
1054        assert!(
1055            max_norm.is_finite() && max_norm > 0.0,
1056            "max_norm must be finite and > 0"
1057        );
1058
1059        let norm = self.global_l2_norm_params();
1060        if norm > max_norm && norm > 0.0 {
1061            self.scale_params(max_norm / norm);
1062        }
1063        norm
1064    }
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand::rngs::StdRng;

    use crate::{Activation, MlpBuilder};

    /// Forward `input` through `mlp` and return the MSE loss against `target`.
    ///
    /// Shared by the finite-difference checks below so every perturbed
    /// evaluation reuses the same scratch buffer.
    fn loss_for_mlp(mlp: &Mlp, input: &[f32], target: &[f32], scratch: &mut Scratch) -> f32 {
        mlp.forward(input, scratch);
        crate::loss::mse(scratch.output(), target)
    }

    /// Assert `analytic` and `numeric` agree within an absolute tolerance OR
    /// a relative tolerance; the relative scale is floored at 1.0 so tiny
    /// gradients fall back to the absolute check.
    fn assert_close(analytic: f32, numeric: f32, abs_tol: f32, rel_tol: f32) {
        let diff = (analytic - numeric).abs();
        let scale = analytic.abs().max(numeric.abs()).max(1.0);
        assert!(
            diff <= abs_tol || diff / scale <= rel_tol,
            "analytic={analytic} numeric={numeric} diff={diff}"
        );
    }

    /// `predict_into` must return `Ok` for a correctly-shaped input and `Err`
    /// (not panic) for a shape mismatch.
    #[test]
    fn predict_into_validates_shapes() {
        let mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Identity)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        let mut scratch = mlp.scratch();
        let mut out = [0.0_f32; 1];

        // Input of the declared dimension (2) succeeds.
        let ok = mlp.predict_into(&[0.1, 0.2], &mut scratch, &mut out);
        assert!(ok.is_ok());

        // A 1-element input into a 2-input model is rejected.
        let err = mlp.predict_into(&[0.1_f32], &mut scratch, &mut out);
        assert!(err.is_err());
    }

    /// Gradient check: compare the analytic gradients from `backward` against
    /// central finite differences `(f(x+eps) - f(x-eps)) / (2*eps)` for every
    /// weight, bias, and input component of a small 2-3-1 tanh network.
    #[test]
    fn backward_matches_numeric_gradients_for_tanh() {
        let mut mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Tanh)
            .unwrap()
            .build_with_seed(0)
            .unwrap();

        let mut scratch = mlp.scratch();
        let mut grads = mlp.gradients();

        let input = [0.3_f32, -0.7_f32];
        let target = [0.2_f32];

        // One analytic forward/backward pass; parameter gradients land in
        // `grads`, the input gradient is copied out before `grads` is reused.
        mlp.forward(&input, &mut scratch);
        let _loss = crate::loss::mse_backward(scratch.output(), &target, grads.d_output_mut());
        let d_input = mlp.backward(&input, &scratch, &mut grads).to_vec();

        // Central differences have O(eps^2) truncation error, so these
        // tolerances leave headroom for f32 round-off.
        let eps = 1e-3_f32;
        let abs_tol = 1e-3_f32;
        let rel_tol = 1e-2_f32;

        let mut scratch_tmp = mlp.scratch();

        // Parameters.
        for layer_idx in 0..mlp.layers.len() {
            // Weights.
            let w_len = mlp.layers[layer_idx].in_dim() * mlp.layers[layer_idx].out_dim();
            debug_assert_eq!(w_len, grads.d_weights(layer_idx).len());

            for p in 0..w_len {
                // Perturb upward, remembering the original value; the scoped
                // blocks keep each `weights_mut` borrow short-lived.
                let orig = {
                    let w = mlp.layers[layer_idx].weights_mut();
                    let orig = w[p];
                    w[p] = orig + eps;
                    orig
                };
                let loss_plus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                // Perturb downward.
                {
                    let w = mlp.layers[layer_idx].weights_mut();
                    w[p] = orig - eps;
                }
                let loss_minus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                // Restore so later iterations see the unperturbed network.
                {
                    let w = mlp.layers[layer_idx].weights_mut();
                    w[p] = orig;
                }

                let numeric = (loss_plus - loss_minus) / (2.0 * eps);
                let analytic = grads.d_weights(layer_idx)[p];
                assert_close(analytic, numeric, abs_tol, rel_tol);
            }

            // Biases. Same perturb / evaluate / restore pattern as weights.
            let b_len = mlp.layers[layer_idx].out_dim();
            debug_assert_eq!(b_len, grads.d_biases(layer_idx).len());

            for p in 0..b_len {
                let orig = {
                    let b = mlp.layers[layer_idx].biases_mut();
                    let orig = b[p];
                    b[p] = orig + eps;
                    orig
                };
                let loss_plus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                {
                    let b = mlp.layers[layer_idx].biases_mut();
                    b[p] = orig - eps;
                }
                let loss_minus = loss_for_mlp(&mlp, &input, &target, &mut scratch_tmp);

                {
                    let b = mlp.layers[layer_idx].biases_mut();
                    b[p] = orig;
                }

                let numeric = (loss_plus - loss_minus) / (2.0 * eps);
                let analytic = grads.d_biases(layer_idx)[p];
                assert_close(analytic, numeric, abs_tol, rel_tol);
            }
        }

        // Inputs: perturb a mutable copy so the original `input` used for the
        // analytic pass stays untouched.
        let mut input_var = input;
        for i in 0..input_var.len() {
            let orig = input_var[i];

            input_var[i] = orig + eps;
            let loss_plus = loss_for_mlp(&mlp, &input_var, &target, &mut scratch_tmp);

            input_var[i] = orig - eps;
            let loss_minus = loss_for_mlp(&mlp, &input_var, &target, &mut scratch_tmp);

            input_var[i] = orig;

            let numeric = (loss_plus - loss_minus) / (2.0 * eps);
            let analytic = d_input[i];
            assert_close(analytic, numeric, abs_tol, rel_tol);
        }
    }

    /// The low-level `forward` treats shape mismatch as programmer error and
    /// panics (unlike `predict_into`, which returns `Err`).
    #[test]
    #[should_panic]
    fn forward_panics_on_input_shape_mismatch() {
        let mut rng = StdRng::seed_from_u64(0);
        let mlp = MlpBuilder::new(2)
            .unwrap()
            .add_layer(3, Activation::Tanh)
            .unwrap()
            .add_layer(1, Activation::Tanh)
            .unwrap()
            .build_with_rng(&mut rng)
            .unwrap();
        let mut scratch = mlp.scratch();
        // 3 elements into a 2-input model.
        let input = [0.0_f32; 3];
        mlp.forward(&input, &mut scratch);
    }
}