1use crate::conv::{Conv, conv_out_dim};
2use crate::{ConvGeometryIsValid, Float, Sample};
3use rand::{Rng, SeedableRng, rngs::StdRng, seq::SliceRandom};
4use std::fmt;
5
/// Strategy for filling a flat parameter buffer with initial values.
pub trait Initializer {
    /// Fills `values` in place using `rng`.
    ///
    /// `fan_in`/`fan_out` describe the layer geometry so scale-aware schemes
    /// (Xavier, Kaiming) can size their sampling bounds; fixed-range
    /// initializers ignore them.
    fn fill<R: Rng + ?Sized>(
        &self,
        values: &mut [Float],
        fan_in: usize,
        fan_out: usize,
        rng: &mut R,
    );
}
15
/// Initializer that draws each value uniformly from the half-open
/// range `[low, high)`.
#[derive(Debug, Clone, Copy)]
pub struct Uniform {
    pub low: Float,  // inclusive lower bound
    pub high: Float, // exclusive upper bound
}

impl Uniform {
    /// Creates a uniform initializer over `[low, high)`.
    pub const fn new(low: Float, high: Float) -> Self {
        Self { low, high }
    }
}
27
28impl Initializer for Uniform {
29 fn fill<R: Rng + ?Sized>(
30 &self,
31 values: &mut [Float],
32 _fan_in: usize,
33 _fan_out: usize,
34 rng: &mut R,
35 ) {
36 for value in values {
37 *value = rng.random_range(self.low..self.high);
38 }
39 }
40}
41
42#[derive(Debug, Clone, Copy, Default)]
43pub struct XavierUniform;
44
45impl Initializer for XavierUniform {
46 fn fill<R: Rng + ?Sized>(
47 &self,
48 values: &mut [Float],
49 fan_in: usize,
50 fan_out: usize,
51 rng: &mut R,
52 ) {
53 let denom = (fan_in + fan_out).max(1) as Float;
54 let bound = (6.0 / denom).sqrt();
55 Uniform::new(-bound, bound).fill(values, fan_in, fan_out, rng);
56 }
57}
58
59#[derive(Debug, Clone, Copy, Default)]
60pub struct KaimingUniform;
61
62impl Initializer for KaimingUniform {
63 fn fill<R: Rng + ?Sized>(
64 &self,
65 values: &mut [Float],
66 fan_in: usize,
67 fan_out: usize,
68 rng: &mut R,
69 ) {
70 let denom = fan_in.max(1) as Float;
71 let bound = (6.0 / denom).sqrt();
72 Uniform::new(-bound, bound).fill(values, fan_in, fan_out, rng);
73 }
74}
75
/// A gradient-descent update rule applied per parameter tensor.
pub trait Optimizer: fmt::Debug {
    /// Called once per optimizer step, before any `update_parameter` call.
    /// Stateful optimizers (e.g. Adam) advance their time step here.
    fn begin_step(&mut self) {}
    /// Updates `params` in place from `grads * scale`.
    ///
    /// `slot` is a stable index identifying this parameter tensor across
    /// steps, so optimizers can keep per-tensor state (moment buffers).
    fn update_parameter(
        &mut self,
        slot: usize,
        params: &mut [Float],
        grads: &[Float],
        scale: Float,
    );
}
86
/// Plain stochastic gradient descent with optional L2 weight decay.
#[derive(Debug, Clone, Copy)]
pub struct Sgd {
    pub lr: Float,           // learning rate
    pub weight_decay: Float, // L2 coefficient added to the gradient
}

impl Sgd {
    /// Creates SGD with the given learning rate and no weight decay.
    pub const fn new(lr: Float) -> Self {
        Self {
            lr,
            weight_decay: 0.0,
        }
    }

    /// Builder-style setter for the L2 weight-decay coefficient.
    pub const fn with_weight_decay(mut self, weight_decay: Float) -> Self {
        self.weight_decay = weight_decay;
        self
    }
}
106
107impl Optimizer for Sgd {
108 fn update_parameter(
109 &mut self,
110 _slot: usize,
111 params: &mut [Float],
112 grads: &[Float],
113 scale: Float,
114 ) {
115 for (param, grad) in params.iter_mut().zip(grads.iter()) {
116 let update = *grad * scale + self.weight_decay * *param;
117 *param -= self.lr * update;
118 }
119 }
120}
121
/// Adam optimizer with bias-corrected first/second moment estimates and
/// optional L2 weight decay (classic Adam, not AdamW decoupling).
#[derive(Debug, Clone)]
pub struct Adam {
    pub lr: Float,
    pub beta1: Float,        // first-moment decay rate
    pub beta2: Float,        // second-moment decay rate
    pub epsilon: Float,      // denominator stabilizer
    pub weight_decay: Float, // L2 coefficient added to the gradient
    step: usize,             // shared time step t, advanced by begin_step
    first_moment: Vec<Box<[Float]>>,  // per-slot m buffers
    second_moment: Vec<Box<[Float]>>, // per-slot v buffers
}
133
134impl Adam {
135 pub fn new(lr: Float) -> Self {
136 Self {
137 lr,
138 beta1: 0.9,
139 beta2: 0.999,
140 epsilon: 1e-8,
141 weight_decay: 0.0,
142 step: 0,
143 first_moment: Vec::new(),
144 second_moment: Vec::new(),
145 }
146 }
147
148 pub const fn with_weight_decay(mut self, weight_decay: Float) -> Self {
149 self.weight_decay = weight_decay;
150 self
151 }
152
153 fn ensure_slot(&mut self, slot: usize, len: usize) {
154 while self.first_moment.len() <= slot {
155 self.first_moment.push(Vec::new().into_boxed_slice());
156 self.second_moment.push(Vec::new().into_boxed_slice());
157 }
158 if self.first_moment[slot].len() != len {
159 self.first_moment[slot] = vec![0.0; len].into_boxed_slice();
160 self.second_moment[slot] = vec![0.0; len].into_boxed_slice();
161 }
162 }
163}
164
165impl Optimizer for Adam {
166 fn begin_step(&mut self) {
167 self.step += 1;
168 }
169
170 fn update_parameter(
171 &mut self,
172 slot: usize,
173 params: &mut [Float],
174 grads: &[Float],
175 scale: Float,
176 ) {
177 self.ensure_slot(slot, params.len());
178 let bias_correction1 = 1.0 - self.beta1.powi(self.step as i32);
179 let bias_correction2 = 1.0 - self.beta2.powi(self.step as i32);
180
181 let first = &mut self.first_moment[slot];
182 let second = &mut self.second_moment[slot];
183 for i in 0..params.len() {
184 let grad = grads[i] * scale + self.weight_decay * params[i];
185 first[i] = self.beta1 * first[i] + (1.0 - self.beta1) * grad;
186 second[i] = self.beta2 * second[i] + (1.0 - self.beta2) * grad * grad;
187 let m_hat = first[i] / bias_correction1.max(f64::EPSILON);
188 let v_hat = second[i] / bias_correction2.max(f64::EPSILON);
189 params[i] -= self.lr * m_hat / (v_hat.sqrt() + self.epsilon);
190 }
191 }
192}
193
/// Training hyper-parameters plus the (boxed, type-erased) optimizer.
pub struct TrainConfig {
    optimizer: Box<dyn Optimizer>,
    pub epochs: usize,
    pub batch_size: usize,
    // When Some, sample order is reshuffled each epoch with this seed;
    // when None, samples are visited in their given order.
    pub shuffle_seed: Option<u64>,
}
200
201impl fmt::Debug for TrainConfig {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 f.debug_struct("TrainConfig")
204 .field("optimizer", &self.optimizer)
205 .field("epochs", &self.epochs)
206 .field("batch_size", &self.batch_size)
207 .field("shuffle_seed", &self.shuffle_seed)
208 .finish()
209 }
210}
211
impl TrainConfig {
    /// Wraps any optimizer with default hyper-parameters
    /// (1 epoch, batch size 1, no shuffling).
    pub fn new<O: Optimizer + 'static>(optimizer: O) -> Self {
        Self {
            optimizer: Box::new(optimizer),
            epochs: 1,
            batch_size: 1,
            shuffle_seed: None,
        }
    }

    /// Convenience constructor using SGD at the given learning rate.
    pub fn sgd(lr: Float) -> Self {
        Self::new(Sgd::new(lr))
    }

    /// Convenience constructor using Adam at the given learning rate.
    pub fn adam(lr: Float) -> Self {
        Self::new(Adam::new(lr))
    }

    /// Builder-style setter for the epoch count.
    pub fn epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }

    /// Builder-style setter for the batch size (clamped to at least 1).
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.batch_size = batch_size.max(1);
        self
    }

    /// Builder-style setter enabling per-epoch shuffling with a fixed seed.
    pub fn shuffle_seed(mut self, shuffle_seed: u64) -> Self {
        self.shuffle_seed = Some(shuffle_seed);
        self
    }

    // Internal accessor used by the training loop.
    fn optimizer_mut(&mut self) -> &mut dyn Optimizer {
        self.optimizer.as_mut()
    }
}
249
impl Default for TrainConfig {
    /// Adam at lr = 1e-3, one epoch, batch size 1, no shuffling.
    fn default() -> Self {
        Self::adam(1e-3)
    }
}
255
/// A differentiable loss over fixed-size output vectors.
pub trait LossFunction<const N: usize>: fmt::Debug {
    /// Returns the scalar loss and writes d(loss)/d(output) into `grad`.
    fn loss_and_grad(
        &self,
        output: &[Float; N],
        target: &[Float; N],
        grad: &mut [Float; N],
    ) -> Float;
}
264
/// Mean-squared-error loss (see [`mse_loss`]).
#[derive(Debug, Clone, Copy, Default)]
pub struct MeanSquaredError;
267
268pub fn mse_loss<const N: usize>(
269 output: &[Float; N],
270 target: &[Float; N],
271 grad: &mut [Float; N],
272) -> Float {
273 let scale = 2.0 / N as Float;
274 let loss = output
275 .iter()
276 .zip(target.iter())
277 .zip(grad.iter_mut())
278 .map(|((&o, &t), g)| {
279 let diff = o - t;
280 *g = diff * scale;
281 diff * diff
282 })
283 .sum::<Float>();
284 loss / N as Float
285}
286
impl<const N: usize> LossFunction<N> for MeanSquaredError {
    /// Delegates to the free function [`mse_loss`].
    fn loss_and_grad(
        &self,
        output: &[Float; N],
        target: &[Float; N],
        grad: &mut [Float; N],
    ) -> Float {
        mse_loss(output, target, grad)
    }
}
297
/// A network layer mapping `IN` inputs to `OUT` outputs, with
/// gradient accumulation and optimizer hooks.
pub trait Layer<const IN: usize, const OUT: usize> {
    /// Computes `output` from `input` without mutating the layer.
    fn forward(&self, input: &[Float; IN], output: &mut [Float; OUT]);
    /// Writes d(loss)/d(input) into `input_grad` and accumulates parameter
    /// gradients internally. `output` is the forward activation, for layers
    /// whose derivative is cheaper in terms of the output (e.g. sigmoid).
    fn backward(
        &mut self,
        input: &[Float; IN],
        output: &[Float; OUT],
        output_grad: &[Float; OUT],
        input_grad: &mut [Float; IN],
    );

    /// Resets accumulated parameter gradients; no-op for stateless layers.
    fn zero_grad(&mut self) {}

    /// Applies accumulated gradients via `optimizer`, consuming one `slot`
    /// per parameter tensor. Default no-op for parameterless layers.
    fn apply_gradients(
        &mut self,
        _optimizer: &mut dyn Optimizer,
        _slot: &mut usize,
        _scale: Float,
    ) {
    }
}
318
/// Compile-time input/output widths of a layer, as associated constants.
pub trait LayerDims {
    const INPUT: usize;
    const OUTPUT: usize;
}
323
/// Fully-connected layer: `output = W * input + b`.
#[derive(Debug)]
pub struct DenseLayer<const IN: usize, const OUT: usize> {
    // Row-major, OUT rows of IN weights each (row o = weights of output o).
    weights: Box<[Float]>,
    biases: Box<[Float; OUT]>,
    // Gradient accumulators, same layout as the parameters above.
    weight_grads: Box<[Float]>,
    bias_grads: Box<[Float; OUT]>,
}
331
/// Elementwise rectified linear activation, `max(x, 0)`.
#[derive(Debug)]
pub struct ReLU<const N: usize>;

/// Elementwise logistic sigmoid activation, `1 / (1 + e^{-x})`.
#[derive(Debug)]
pub struct Sigmoid<const N: usize>;

/// Identity layer; exists for shape bookkeeping in the builder API.
#[derive(Debug)]
pub struct Flatten<const N: usize>;
340
341impl<const IN: usize, const OUT: usize> DenseLayer<IN, OUT> {
342 pub fn init() -> Self {
343 Self::with_initializer(XavierUniform)
344 }
345
346 pub fn seeded(seed: u64) -> Self {
347 Self::with_initializer_and_seed(XavierUniform, seed)
348 }
349
350 pub fn with_initializer<I: Initializer>(initializer: I) -> Self {
351 let mut rng = rand::rng();
352 Self::with_initializer_and_rng(initializer, &mut rng)
353 }
354
355 pub fn with_initializer_and_seed<I: Initializer>(initializer: I, seed: u64) -> Self {
356 let mut rng = StdRng::seed_from_u64(seed);
357 Self::with_initializer_and_rng(initializer, &mut rng)
358 }
359
360 pub fn with_initializer_and_rng<I: Initializer, R: Rng + ?Sized>(
361 initializer: I,
362 rng: &mut R,
363 ) -> Self {
364 let mut weights = vec![0.0; IN * OUT].into_boxed_slice();
365 initializer.fill(&mut weights, IN, OUT, rng);
366 Self {
367 weights,
368 biases: Box::new([0.0; OUT]),
369 weight_grads: vec![0.0; IN * OUT].into_boxed_slice(),
370 bias_grads: Box::new([0.0; OUT]),
371 }
372 }
373
374 pub fn forward(&self, input: &[Float; IN], output: &mut [Float; OUT]) {
375 for (o, out) in output.iter_mut().enumerate() {
376 let row = &self.weights[o * IN..(o + 1) * IN];
377 let mut sum = self.biases[o];
378 for (weight, inp) in row.iter().zip(input.iter()) {
379 sum += *weight * *inp;
380 }
381 *out = sum;
382 }
383 }
384
385 pub fn backward(
386 &mut self,
387 input: &[Float; IN],
388 _output: &[Float; OUT],
389 output_grad: &[Float; OUT],
390 input_grad: &mut [Float; IN],
391 ) {
392 input_grad.fill(0.0);
393
394 for (o, &grad) in output_grad.iter().enumerate() {
395 let row = &self.weights[o * IN..(o + 1) * IN];
396 for (in_grad, weight) in input_grad.iter_mut().zip(row.iter()) {
397 *in_grad += *weight * grad;
398 }
399 }
400
401 for (o, &grad) in output_grad.iter().enumerate() {
402 self.bias_grads[o] += grad;
403 let row_grads = &mut self.weight_grads[o * IN..(o + 1) * IN];
404 for (weight_grad, inp) in row_grads.iter_mut().zip(input.iter()) {
405 *weight_grad += grad * *inp;
406 }
407 }
408 }
409}
410
impl<const IN: usize, const OUT: usize> LayerDims for DenseLayer<IN, OUT> {
    const INPUT: usize = IN;
    const OUTPUT: usize = OUT;
}
415
impl<const IN: usize, const OUT: usize> Layer<IN, OUT> for DenseLayer<IN, OUT> {
    fn forward(&self, input: &[Float; IN], output: &mut [Float; OUT]) {
        DenseLayer::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; IN],
        output: &[Float; OUT],
        output_grad: &[Float; OUT],
        input_grad: &mut [Float; IN],
    ) {
        DenseLayer::backward(self, input, output, output_grad, input_grad);
    }

    fn zero_grad(&mut self) {
        self.weight_grads.fill(0.0);
        self.bias_grads.fill(0.0);
    }

    // Consumes two optimizer slots in a fixed order (weights, then biases)
    // so stateful optimizers keep per-tensor moment buffers aligned across
    // steps; gradients are cleared afterwards for the next accumulation.
    fn apply_gradients(&mut self, optimizer: &mut dyn Optimizer, slot: &mut usize, scale: Float) {
        optimizer.update_parameter(*slot, &mut self.weights, &self.weight_grads, scale);
        *slot += 1;
        optimizer.update_parameter(
            *slot,
            self.biases.as_mut_slice(),
            self.bias_grads.as_slice(),
            scale,
        );
        *slot += 1;
        self.zero_grad();
    }
}
449
450impl<const N: usize> ReLU<N> {
451 pub fn init() -> Self {
452 ReLU
453 }
454
455 pub fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
456 for i in 0..N {
457 output[i] = input[i].max(0.0);
458 }
459 }
460
461 pub fn backward(
462 &self,
463 input: &[Float; N],
464 _output: &[Float; N],
465 output_grad: &[Float; N],
466 input_grad: &mut [Float; N],
467 ) {
468 for i in 0..N {
469 input_grad[i] = if input[i] > 0.0 { output_grad[i] } else { 0.0 };
470 }
471 }
472}
473
impl<const N: usize> LayerDims for ReLU<N> {
    const INPUT: usize = N;
    const OUTPUT: usize = N;
}
478
impl<const N: usize> Layer<N, N> for ReLU<N> {
    fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
        ReLU::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; N],
        output: &[Float; N],
        output_grad: &[Float; N],
        input_grad: &mut [Float; N],
    ) {
        ReLU::backward(self, input, output, output_grad, input_grad);
    }
}
494
495impl<const N: usize> Sigmoid<N> {
496 pub fn init() -> Self {
497 Sigmoid
498 }
499
500 pub fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
501 for i in 0..N {
502 output[i] = 1.0 / (1.0 + (-input[i]).exp());
503 }
504 }
505
506 pub fn backward(
507 &self,
508 _input: &[Float; N],
509 output: &[Float; N],
510 output_grad: &[Float; N],
511 input_grad: &mut [Float; N],
512 ) {
513 for i in 0..N {
514 let y = output[i];
515 input_grad[i] = output_grad[i] * y * (1.0 - y);
516 }
517 }
518}
519
impl<const N: usize> LayerDims for Sigmoid<N> {
    const INPUT: usize = N;
    const OUTPUT: usize = N;
}
524
impl<const N: usize> Layer<N, N> for Sigmoid<N> {
    fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
        Sigmoid::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; N],
        output: &[Float; N],
        output_grad: &[Float; N],
        input_grad: &mut [Float; N],
    ) {
        Sigmoid::backward(self, input, output, output_grad, input_grad);
    }
}
540
541impl<const N: usize> Flatten<N> {
542 pub fn init() -> Self {
543 Flatten
544 }
545
546 pub fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
547 output.copy_from_slice(input);
548 }
549
550 pub fn backward(
551 &self,
552 _input: &[Float; N],
553 _output: &[Float; N],
554 output_grad: &[Float; N],
555 input_grad: &mut [Float; N],
556 ) {
557 input_grad.copy_from_slice(output_grad);
558 }
559}
560
impl<const N: usize> LayerDims for Flatten<N> {
    const INPUT: usize = N;
    const OUTPUT: usize = N;
}
565
impl<const N: usize> Layer<N, N> for Flatten<N> {
    fn forward(&self, input: &[Float; N], output: &mut [Float; N]) {
        Flatten::forward(self, input, output);
    }

    fn backward(
        &mut self,
        input: &[Float; N],
        output: &[Float; N],
        output_grad: &[Float; N],
        input_grad: &mut [Float; N],
    ) {
        Flatten::backward(self, input, output, output_grad, input_grad);
    }
}
581
// Sealed module: the type-level layer chain, its workspaces, and the
// training loop. Nothing here is nameable outside the crate, which keeps
// `Sequential`'s public API small while the builder wires these types up.
mod private {
    use super::*;

    /// Terminator of a layer chain (type-level nil).
    #[derive(Debug, Clone, Copy, Default)]
    pub struct End;

    /// Type-level cons cell: `head` feeds into `tail`.
    /// `MID` is the width of `head`'s output (= `tail`'s input).
    #[derive(Debug)]
    pub struct Chain<Head, Tail, const MID: usize> {
        pub(super) head: Head,
        pub(super) tail: Tail,
    }

    impl<Head, Tail, const MID: usize> Chain<Head, Tail, MID> {
        pub const fn new(head: Head, tail: Tail) -> Self {
            Self { head, tail }
        }
    }

    /// Appends a layer to the *end* of a chain, recursing through the tail.
    /// `NEXT_OUTPUT` is the appended layer's output width.
    pub trait AppendLayer<Next, const NEXT_OUTPUT: usize>: Sized {
        type Output;
        fn then(self, next: Next) -> Self::Output;
    }

    // Base case: appending to the empty chain yields a one-layer chain.
    impl<Next, const NEXT_OUTPUT: usize> AppendLayer<Next, NEXT_OUTPUT> for End {
        type Output = Chain<Next, End, NEXT_OUTPUT>;

        fn then(self, next: Next) -> Self::Output {
            Chain::new(next, End)
        }
    }

    // Recursive case: keep the head, append into the tail.
    impl<Head, Tail, const MID: usize, Next, const NEXT_OUTPUT: usize>
        AppendLayer<Next, NEXT_OUTPUT> for Chain<Head, Tail, MID>
    where
        Tail: AppendLayer<Next, NEXT_OUTPUT>,
    {
        type Output = Chain<Head, <Tail as AppendLayer<Next, NEXT_OUTPUT>>::Output, MID>;

        fn then(self, next: Next) -> Self::Output {
            Chain::new(self.head, self.tail.then(next))
        }
    }

    /// Workspace for the last link: holds the network output activation and
    /// the loss gradient with respect to it.
    #[derive(Debug)]
    pub struct TerminalWorkspace<const OUT: usize> {
        activation: Box<[Float; OUT]>,
        gradient: Box<[Float; OUT]>,
    }

    /// Workspace for an inner link: this link's activation/gradient plus the
    /// recursively nested tail workspace.
    #[derive(Debug)]
    pub struct ChainWorkspace<const MID: usize, TailWorkspace> {
        activation: Box<[Float; MID]>,
        gradient: Box<[Float; MID]>,
        tail: TailWorkspace,
    }

    /// Top-level workspace: the chain's buffers plus a scratch buffer for
    /// the (discarded) gradient with respect to the network input.
    #[derive(Debug)]
    pub struct StackWorkspace<BodyWorkspace, const INPUT: usize> {
        body: BodyWorkspace,
        input_grad: Box<[Float; INPUT]>,
    }

    /// Recursive forward/backward over a whole chain, with all intermediate
    /// activations stored in an external `Workspace` (so `forward` can take
    /// `&self` and buffers are reused across samples).
    pub trait ModuleChain<const INPUT: usize, const OUTPUT: usize> {
        type Workspace;

        fn workspace(&self) -> Self::Workspace;
        fn forward_with_workspace(&self, input: &[Float; INPUT], workspace: &mut Self::Workspace);
        fn output(workspace: &Self::Workspace) -> &[Float; OUTPUT];
        fn set_output_grad(workspace: &mut Self::Workspace, grad: &[Float; OUTPUT]);
        fn backward_with_workspace(
            &mut self,
            input: &[Float; INPUT],
            input_grad: &mut [Float; INPUT],
            workspace: &mut Self::Workspace,
        );
        fn zero_grad(&mut self);
        fn apply_gradients(
            &mut self,
            optimizer: &mut dyn Optimizer,
            slot: &mut usize,
            scale: Float,
        );
    }

    // Base case: a single-layer chain; its activation IS the network output.
    impl<Head, const INPUT: usize, const OUTPUT: usize> ModuleChain<INPUT, OUTPUT>
        for Chain<Head, End, OUTPUT>
    where
        Head: Layer<INPUT, OUTPUT>,
    {
        type Workspace = TerminalWorkspace<OUTPUT>;

        fn workspace(&self) -> Self::Workspace {
            TerminalWorkspace {
                activation: Box::new([0.0; OUTPUT]),
                gradient: Box::new([0.0; OUTPUT]),
            }
        }

        fn forward_with_workspace(&self, input: &[Float; INPUT], workspace: &mut Self::Workspace) {
            self.head.forward(input, workspace.activation.as_mut());
        }

        fn output(workspace: &Self::Workspace) -> &[Float; OUTPUT] {
            workspace.activation.as_ref()
        }

        fn set_output_grad(workspace: &mut Self::Workspace, grad: &[Float; OUTPUT]) {
            workspace.gradient.copy_from_slice(grad);
        }

        fn backward_with_workspace(
            &mut self,
            input: &[Float; INPUT],
            input_grad: &mut [Float; INPUT],
            workspace: &mut Self::Workspace,
        ) {
            self.head.backward(
                input,
                workspace.activation.as_ref(),
                workspace.gradient.as_ref(),
                input_grad,
            );
        }

        fn zero_grad(&mut self) {
            self.head.zero_grad();
        }

        fn apply_gradients(
            &mut self,
            optimizer: &mut dyn Optimizer,
            slot: &mut usize,
            scale: Float,
        ) {
            self.head.apply_gradients(optimizer, slot, scale);
        }
    }

    // Recursive case: run the head, then recurse into the tail. Backward
    // runs tail-first so the head sees its output gradient already filled.
    impl<Head, Tail, const INPUT: usize, const MID: usize, const OUTPUT: usize>
        ModuleChain<INPUT, OUTPUT> for Chain<Head, Tail, MID>
    where
        Head: Layer<INPUT, MID>,
        Tail: ModuleChain<MID, OUTPUT>,
    {
        type Workspace = ChainWorkspace<MID, Tail::Workspace>;

        fn workspace(&self) -> Self::Workspace {
            ChainWorkspace {
                activation: Box::new([0.0; MID]),
                gradient: Box::new([0.0; MID]),
                tail: self.tail.workspace(),
            }
        }

        fn forward_with_workspace(&self, input: &[Float; INPUT], workspace: &mut Self::Workspace) {
            self.head.forward(input, workspace.activation.as_mut());
            self.tail
                .forward_with_workspace(workspace.activation.as_ref(), &mut workspace.tail);
        }

        fn output(workspace: &Self::Workspace) -> &[Float; OUTPUT] {
            Tail::output(&workspace.tail)
        }

        fn set_output_grad(workspace: &mut Self::Workspace, grad: &[Float; OUTPUT]) {
            Tail::set_output_grad(&mut workspace.tail, grad);
        }

        fn backward_with_workspace(
            &mut self,
            input: &[Float; INPUT],
            input_grad: &mut [Float; INPUT],
            workspace: &mut Self::Workspace,
        ) {
            // Tail first: fills `workspace.gradient` (dL/d mid-activation)…
            self.tail.backward_with_workspace(
                workspace.activation.as_ref(),
                workspace.gradient.as_mut(),
                &mut workspace.tail,
            );
            // …then the head consumes it to produce dL/d input.
            self.head.backward(
                input,
                workspace.activation.as_ref(),
                workspace.gradient.as_ref(),
                input_grad,
            );
        }

        fn zero_grad(&mut self) {
            self.head.zero_grad();
            self.tail.zero_grad();
        }

        fn apply_gradients(
            &mut self,
            optimizer: &mut dyn Optimizer,
            slot: &mut usize,
            scale: Float,
        ) {
            // Head-first slot assignment keeps slot numbering stable per run.
            self.head.apply_gradients(optimizer, slot, scale);
            self.tail.apply_gradients(optimizer, slot, scale);
        }
    }

    /// A complete, trainable network: a layer chain plus nothing else.
    #[derive(Debug)]
    pub struct Stack<Layers, const INPUT: usize, const OUTPUT: usize>
    where
        Layers: ModuleChain<INPUT, OUTPUT>,
    {
        layers: Layers,
    }

    impl<Layers, const INPUT: usize, const OUTPUT: usize> Stack<Layers, INPUT, OUTPUT>
    where
        Layers: ModuleChain<INPUT, OUTPUT>,
    {
        pub const fn new(layers: Layers) -> Self {
            Self { layers }
        }

        /// Runs a forward pass and copies out the final activation.
        /// Allocates a fresh workspace per call.
        pub fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
            let mut workspace = StackWorkspace {
                body: self.layers.workspace(),
                input_grad: Box::new([0.0; INPUT]),
            };
            self.layers
                .forward_with_workspace(input, &mut workspace.body);
            let mut result = [0.0; OUTPUT];
            result.copy_from_slice(Layers::output(&workspace.body));
            result
        }

        /// Mini-batch training loop. Returns the mean per-batch training
        /// loss over all optimizer steps (0.0 if there is nothing to do).
        pub fn fit_with_loss(
            &mut self,
            samples: &[Sample<INPUT, OUTPUT>],
            loss_fn: &dyn LossFunction<OUTPUT>,
            mut config: TrainConfig,
        ) -> Float {
            if samples.is_empty() || config.epochs == 0 {
                return 0.0;
            }

            let batch_size = config.batch_size.max(1);
            // One workspace reused for every sample in every epoch.
            let mut workspace = StackWorkspace {
                body: self.layers.workspace(),
                input_grad: Box::new([0.0; INPUT]),
            };
            let mut order = (0..samples.len()).collect::<Vec<_>>();
            // Shuffling only happens when a seed was provided.
            let mut shuffler = config.shuffle_seed.map(StdRng::seed_from_u64);
            let mut total_loss = 0.0;
            let mut steps = 0usize;

            for _ in 0..config.epochs {
                if let Some(rng) = shuffler.as_mut() {
                    order.shuffle(rng);
                }

                for batch in order.chunks(batch_size) {
                    self.layers.zero_grad();
                    let mut batch_loss = 0.0;

                    // Accumulate gradients over the whole batch, then apply
                    // them once, scaled by 1/batch_len (gradient averaging).
                    for &sample_idx in batch {
                        let sample = &samples[sample_idx];
                        self.layers
                            .forward_with_workspace(&sample.input, &mut workspace.body);
                        let mut grad = [0.0; OUTPUT];
                        let loss = loss_fn.loss_and_grad(
                            Layers::output(&workspace.body),
                            &sample.target,
                            &mut grad,
                        );
                        Layers::set_output_grad(&mut workspace.body, &grad);
                        self.layers.backward_with_workspace(
                            &sample.input,
                            workspace.input_grad.as_mut(),
                            &mut workspace.body,
                        );
                        batch_loss += loss;
                    }

                    config.optimizer_mut().begin_step();
                    let mut slot = 0usize;
                    self.layers.apply_gradients(
                        config.optimizer_mut(),
                        &mut slot,
                        1.0 / batch.len() as Float,
                    );
                    total_loss += batch_loss / batch.len() as Float;
                    steps += 1;
                }
            }

            // steps >= 1 here: samples is non-empty and epochs > 0.
            total_loss / steps as Float
        }
    }

    /// Object-safe facade over `Stack` so `Sequential` can type-erase the
    /// concrete layer chain behind a `Box<dyn ModelRuntime>`.
    pub trait ModelRuntime<const INPUT: usize, const OUTPUT: usize>: fmt::Debug {
        fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT];
        fn fit_with_loss(
            &mut self,
            samples: &[Sample<INPUT, OUTPUT>],
            loss_fn: &dyn LossFunction<OUTPUT>,
            config: TrainConfig,
        ) -> Float;
    }

    impl<Layers, const INPUT: usize, const OUTPUT: usize> ModelRuntime<INPUT, OUTPUT>
        for Stack<Layers, INPUT, OUTPUT>
    where
        Layers: ModuleChain<INPUT, OUTPUT> + fmt::Debug + 'static,
    {
        fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
            Stack::predict(self, input)
        }

        fn fit_with_loss(
            &mut self,
            samples: &[Sample<INPUT, OUTPUT>],
            loss_fn: &dyn LossFunction<OUTPUT>,
            config: TrainConfig,
        ) -> Float {
            Stack::fit_with_loss(self, samples, loss_fn, config)
        }
    }
}
906
/// A trained/trainable feed-forward model with type-erased internals;
/// built via [`ModelBuilder`].
pub struct Sequential<const INPUT: usize, const OUTPUT: usize> {
    inner: Box<dyn private::ModelRuntime<INPUT, OUTPUT>>,
}
910
911impl<const INPUT: usize, const OUTPUT: usize> fmt::Debug for Sequential<INPUT, OUTPUT> {
912 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
913 f.debug_struct("Sequential")
914 .field("input", &INPUT)
915 .field("output", &OUTPUT)
916 .finish()
917 }
918}
919
impl<const INPUT: usize, const OUTPUT: usize> Sequential<INPUT, OUTPUT> {
    // Type-erasing constructor used by the builders.
    fn from_runtime<R>(runtime: R) -> Self
    where
        R: private::ModelRuntime<INPUT, OUTPUT> + 'static,
    {
        Self {
            inner: Box::new(runtime),
        }
    }

    /// Runs a forward pass and returns the output activations.
    pub fn predict(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
        self.inner.predict(input)
    }

    // NOTE(review): despite the name this simply delegates to `predict` and
    // still allocates a workspace internally — confirm whether a true
    // in-place variant was intended.
    pub fn predict_in_place(&self, input: &[Float; INPUT]) -> [Float; OUTPUT] {
        self.predict(input)
    }

    /// Trains with mean-squared-error loss; returns the mean training loss
    /// per optimizer step.
    pub fn fit(&mut self, samples: &[Sample<INPUT, OUTPUT>], config: TrainConfig) -> Float {
        self.fit_with_loss(samples, &MeanSquaredError, config)
    }

    /// Trains with a caller-supplied loss; returns the mean training loss
    /// per optimizer step.
    pub fn fit_with_loss(
        &mut self,
        samples: &[Sample<INPUT, OUTPUT>],
        loss_fn: &dyn LossFunction<OUTPUT>,
        config: TrainConfig,
    ) -> Float {
        self.inner.fit_with_loss(samples, loss_fn, config)
    }
}
951
/// Entry point of the type-safe model builder DSL:
/// `ModelBuilder::new().input::<N>().dense::<M>()...build()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ModelBuilder;

impl ModelBuilder {
    pub const fn new() -> Self {
        Self
    }

    /// Starts a model over flat `N`-element input vectors.
    pub fn input<const N: usize>(self) -> VectorBuilder<private::End, N, N> {
        VectorBuilder {
            layers: private::End,
        }
    }

    /// Starts a model over `C`x`H`x`W` image input (stored flattened).
    /// The `{ C * H * W }` bound relies on the `generic_const_exprs`
    /// nightly feature used throughout this builder.
    pub fn image_input<const C: usize, const H: usize, const W: usize>(
        self,
    ) -> ImageBuilder<private::End, { C * H * W }, C, H, W>
    where
        [(); C * H * W]:,
    {
        ImageBuilder {
            layers: private::End,
        }
    }
}
977
/// Builder state for flat-vector models: `INPUT` is the model input width,
/// `CURRENT` the output width of the last layer appended so far.
pub struct VectorBuilder<Layers, const INPUT: usize, const CURRENT: usize> {
    layers: Layers,
}
981
// Debug shows only the tracked widths, not the layer chain itself.
impl<Layers, const INPUT: usize, const CURRENT: usize> fmt::Debug
    for VectorBuilder<Layers, INPUT, CURRENT>
where
    Layers: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("VectorBuilder")
            .field("input", &INPUT)
            .field("current", &CURRENT)
            .finish()
    }
}
994
impl<Layers, const INPUT: usize, const CURRENT: usize> VectorBuilder<Layers, INPUT, CURRENT> {
    /// No-op on an already-flat builder; kept so vector and image pipelines
    /// can share the same call sequence.
    pub const fn flatten(self) -> Self {
        self
    }

    /// Appends a fully-connected `CURRENT -> NEXT` layer
    /// (Xavier-initialized via `DenseLayer::init`).
    pub fn dense<const NEXT: usize>(
        self,
    ) -> VectorBuilder<
        <Layers as private::AppendLayer<DenseLayer<CURRENT, NEXT>, NEXT>>::Output,
        INPUT,
        NEXT,
    >
    where
        Layers: private::AppendLayer<DenseLayer<CURRENT, NEXT>, NEXT>,
    {
        VectorBuilder {
            layers: self.layers.then(DenseLayer::<CURRENT, NEXT>::init()),
        }
    }

    /// Appends an elementwise ReLU activation (width unchanged).
    pub fn relu(
        self,
    ) -> VectorBuilder<
        <Layers as private::AppendLayer<ReLU<CURRENT>, CURRENT>>::Output,
        INPUT,
        CURRENT,
    >
    where
        Layers: private::AppendLayer<ReLU<CURRENT>, CURRENT>,
    {
        VectorBuilder {
            layers: self.layers.then(ReLU::<CURRENT>::init()),
        }
    }

    /// Appends an elementwise sigmoid activation (width unchanged).
    pub fn sigmoid(
        self,
    ) -> VectorBuilder<
        <Layers as private::AppendLayer<Sigmoid<CURRENT>, CURRENT>>::Output,
        INPUT,
        CURRENT,
    >
    where
        Layers: private::AppendLayer<Sigmoid<CURRENT>, CURRENT>,
    {
        VectorBuilder {
            layers: self.layers.then(Sigmoid::<CURRENT>::init()),
        }
    }

    /// Finalizes the chain into a type-erased [`Sequential`] model.
    pub fn build(self) -> Sequential<INPUT, CURRENT>
    where
        Layers: private::ModuleChain<INPUT, CURRENT> + fmt::Debug + 'static,
    {
        Sequential::from_runtime(private::Stack::new(self.layers))
    }
}
1052
/// Builder state for image models: `INPUT` is the flattened model input
/// size, `C`/`H`/`W` the current feature-map shape.
pub struct ImageBuilder<Layers, const INPUT: usize, const C: usize, const H: usize, const W: usize>
{
    layers: Layers,
}
1057
// Debug shows only the tracked shape, not the layer chain itself.
impl<Layers, const INPUT: usize, const C: usize, const H: usize, const W: usize> fmt::Debug
    for ImageBuilder<Layers, INPUT, C, H, W>
where
    Layers: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ImageBuilder")
            .field("input", &INPUT)
            .field("channels", &C)
            .field("height", &H)
            .field("width", &W)
            .finish()
    }
}
1072
impl<Layers, const INPUT: usize, const C: usize, const H: usize, const W: usize>
    ImageBuilder<Layers, INPUT, C, H, W>
where
    [(); INPUT]:,
{
    /// Appends an elementwise ReLU over the flattened feature map
    /// (shape unchanged).
    pub fn relu(
        self,
    ) -> ImageBuilder<
        <Layers as private::AppendLayer<ReLU<{ C * H * W }>, { C * H * W }>>::Output,
        INPUT,
        C,
        H,
        W,
    >
    where
        [(); C * H * W]:,
        Layers: private::AppendLayer<ReLU<{ C * H * W }>, { C * H * W }>,
    {
        ImageBuilder {
            layers: self.layers.then(ReLU::<{ C * H * W }>::init()),
        }
    }

    /// Appends an elementwise sigmoid over the flattened feature map
    /// (shape unchanged).
    pub fn sigmoid(
        self,
    ) -> ImageBuilder<
        <Layers as private::AppendLayer<Sigmoid<{ C * H * W }>, { C * H * W }>>::Output,
        INPUT,
        C,
        H,
        W,
    >
    where
        [(); C * H * W]:,
        Layers: private::AppendLayer<Sigmoid<{ C * H * W }>, { C * H * W }>,
    {
        ImageBuilder {
            layers: self.layers.then(Sigmoid::<{ C * H * W }>::init()),
        }
    }

    /// Appends a convolution with `OC` output channels, an `FH`x`FW` kernel,
    /// stride `S` and padding `P`; the tracked H/W shrink/grow according to
    /// `conv_out_dim`, and `ConvGeometryIsValid` rejects impossible
    /// geometries at compile time.
    pub fn conv<const OC: usize, const FH: usize, const FW: usize, const S: usize, const P: usize>(
        self,
    ) -> ImageBuilder<
        <Layers as private::AppendLayer<
            Conv<W, H, C, FH, FW, OC, S, P>,
            { OC * conv_out_dim(H, P, FH, S) * conv_out_dim(W, P, FW, S) },
        >>::Output,
        INPUT,
        OC,
        { conv_out_dim(H, P, FH, S) },
        { conv_out_dim(W, P, FW, S) },
    >
    where
        [(); C * H * W]:,
        [(); OC * conv_out_dim(H, P, FH, S) * conv_out_dim(W, P, FW, S)]:,
        (): ConvGeometryIsValid<H, W, FH, FW, S, P>,
        Layers: private::AppendLayer<
            Conv<W, H, C, FH, FW, OC, S, P>,
            { OC * conv_out_dim(H, P, FH, S) * conv_out_dim(W, P, FW, S) },
        >,
    {
        ImageBuilder {
            layers: self.layers.then(Conv::<W, H, C, FH, FW, OC, S, P>::init()),
        }
    }

    /// Switches to the flat-vector builder; purely a type-level reshape,
    /// no layer is added (data is already stored flattened).
    pub fn flatten(self) -> VectorBuilder<Layers, INPUT, { C * H * W }>
    where
        [(); C * H * W]:,
    {
        VectorBuilder {
            layers: self.layers,
        }
    }

    /// Finalizes the chain into a type-erased [`Sequential`] model.
    pub fn build(self) -> Sequential<INPUT, { C * H * W }>
    where
        [(); C * H * W]:,
        Layers: private::ModuleChain<INPUT, { C * H * W }> + fmt::Debug + 'static,
    {
        Sequential::from_runtime(private::Stack::new(self.layers))
    }
}
1157
#[cfg(test)]
mod tests {
    use super::*;

    // Asserts |a - b| <= eps with a descriptive failure message.
    fn approx_eq(a: Float, b: Float, eps: Float) {
        let diff = (a - b).abs();
        assert!(diff <= eps, "expected {a} ~= {b} (diff={diff}, eps={eps})");
    }

    #[test]
    fn mse_loss_matches_manual_computation() {
        // loss = ((2-1)^2 + (-1-1)^2) / 2 = 2.5; grad_i = 2*(o_i - t_i)/2.
        let output = [2.0, -1.0];
        let target = [1.0, 1.0];
        let mut grad = [0.0; 2];
        let loss = mse_loss(&output, &target, &mut grad);
        approx_eq(loss, 2.5, 1e-12);
        assert_eq!(grad, [1.0, -2.0]);
    }

    #[test]
    fn dense_input_gradient_matches_finite_difference() {
        // Fixed weights/biases make the analytic gradient deterministic.
        let mut layer = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 7);
        layer.weights.copy_from_slice(&[0.4, -0.2, 0.1, 0.3]);
        *layer.biases = [0.05, -0.1];

        let input = [0.7, -1.2];
        let output_grad = [0.8, -0.4];
        let mut output = [0.0; 2];
        let mut input_grad = [0.0; 2];

        layer.zero_grad();
        layer.forward(&input, &mut output);
        layer.backward(&input, &output, &output_grad, &mut input_grad);

        // Central finite difference of the scalar objective <output, g>
        // with respect to each input coordinate.
        let eps = 1e-7;
        for i in 0..2 {
            let mut plus = input;
            let mut minus = input;
            plus[i] += eps;
            minus[i] -= eps;

            let mut plus_out = [0.0; 2];
            let mut minus_out = [0.0; 2];
            layer.forward(&plus, &mut plus_out);
            layer.forward(&minus, &mut minus_out);
            let objective_plus = plus_out
                .iter()
                .zip(output_grad.iter())
                .map(|(o, g)| o * g)
                .sum::<Float>();
            let objective_minus = minus_out
                .iter()
                .zip(output_grad.iter())
                .map(|(o, g)| o * g)
                .sum::<Float>();
            let numeric = (objective_plus - objective_minus) / (2.0 * eps);
            approx_eq(input_grad[i], numeric, 1e-6);
        }
    }

    #[test]
    fn dense_weight_gradient_matches_finite_difference() {
        let mut layer = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 11);
        layer.weights.copy_from_slice(&[0.4, -0.2, 0.1, 0.3]);
        *layer.biases = [0.05, -0.1];

        let input = [0.7, -1.2];
        let output_grad = [0.8, -0.4];
        let mut output = [0.0; 2];
        let mut input_grad = [0.0; 2];

        layer.zero_grad();
        layer.forward(&input, &mut output);
        layer.backward(&input, &output, &output_grad, &mut input_grad);

        // Perturb a single weight in two cloned layers and compare the
        // finite-difference slope of <output, g> against the accumulated
        // analytic weight gradient.
        let weight_idx = 1;
        let eps = 1e-7;
        let mut plus = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 0);
        plus.weights.copy_from_slice(&layer.weights);
        plus.biases.copy_from_slice(layer.biases.as_ref());
        plus.weights[weight_idx] += eps;
        let mut minus = DenseLayer::<2, 2>::with_initializer_and_seed(Uniform::new(-0.3, 0.3), 0);
        minus.weights.copy_from_slice(&layer.weights);
        minus.biases.copy_from_slice(layer.biases.as_ref());
        minus.weights[weight_idx] -= eps;

        let mut plus_out = [0.0; 2];
        let mut minus_out = [0.0; 2];
        plus.forward(&input, &mut plus_out);
        minus.forward(&input, &mut minus_out);
        let objective_plus = plus_out
            .iter()
            .zip(output_grad.iter())
            .map(|(o, g)| o * g)
            .sum::<Float>();
        let objective_minus = minus_out
            .iter()
            .zip(output_grad.iter())
            .map(|(o, g)| o * g)
            .sum::<Float>();
        let numeric = (objective_plus - objective_minus) / (2.0 * eps);

        approx_eq(layer.weight_grads[weight_idx], numeric, 1e-6);
    }

    #[test]
    fn seeded_initialization_is_reproducible() {
        // Same seed must produce bit-identical parameters.
        let a = DenseLayer::<3, 2>::seeded(42);
        let b = DenseLayer::<3, 2>::seeded(42);
        assert_eq!(&*a.weights, &*b.weights);
        assert_eq!(&*a.biases, &*b.biases);
    }

    #[test]
    fn builder_training_decreases_loss_with_seeded_shuffle() {
        // Small MLP regressing y = 2x - 0.5; loss should drop sharply.
        let mut model = ModelBuilder::new()
            .input::<1>()
            .dense::<8>()
            .relu()
            .dense::<1>()
            .build();
        let samples = (-20..=20)
            .map(|i| {
                let x = i as Float / 10.0;
                Sample::new([x], [2.0 * x - 0.5])
            })
            .collect::<Vec<_>>();
        let config = TrainConfig::adam(0.03)
            .epochs(250)
            .batch_size(8)
            .shuffle_seed(9);

        let before = samples
            .iter()
            .map(|sample| {
                let output = model.predict(&sample.input);
                let mut grad = [0.0; 1];
                MeanSquaredError.loss_and_grad(&output, &sample.target, &mut grad)
            })
            .sum::<Float>()
            / samples.len() as Float;
        let during = model.fit(&samples, config);
        let after = samples
            .iter()
            .map(|sample| {
                let output = model.predict(&sample.input);
                let mut grad = [0.0; 1];
                MeanSquaredError.loss_and_grad(&output, &sample.target, &mut grad)
            })
            .sum::<Float>()
            / samples.len() as Float;

        assert!(during < before, "training step average should improve");
        assert!(after < before * 0.2, "expected loss to fall sharply");
    }
}