//! Training helpers for `yscv_model` (`train.rs`): per-step and per-epoch
//! supervised training loops, scheduler-driven multi-epoch drivers, high-level
//! CNN helpers, and gradient-accumulation utilities.
1use yscv_autograd::{Graph, NodeId};
2use yscv_optim::{Adam, AdamW, LearningRate, LrScheduler, RmsProp, Sgd};
3use yscv_tensor::Tensor;
4
5use crate::{
6    BatchIterOptions, GradientAggregator, ModelError, SequentialModel, SupervisedDataset, bce_loss,
7    cross_entropy_loss, hinge_loss, huber_loss, mae_loss, mse_loss, nll_loss,
8};
9
/// Internal abstraction over optimizers that can update a single graph node.
///
/// Lets the generic train-step/epoch helpers accept any supported optimizer
/// (SGD, Adam, AdamW, RMSProp) through one uniform entry point.
trait GraphOptimizer {
    /// Applies one optimizer update to `node` using its gradient stored in `graph`.
    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError>;
}
13
/// Configures supervised-loss function used by train-step and train-epoch helpers.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum SupervisedLoss {
    /// Mean squared error (the default).
    #[default]
    Mse,
    /// Mean absolute error.
    Mae,
    /// Huber loss with transition threshold `delta`.
    Huber {
        delta: f32,
    },
    /// Hinge loss with the given `margin`.
    Hinge {
        margin: f32,
    },
    /// Binary cross-entropy.
    Bce,
    /// Negative log-likelihood.
    Nll,
    /// Cross-entropy.
    CrossEntropy,
}
30
/// Appends the configured loss computation to `graph` and returns the loss node.
///
/// Dispatches to the matching loss constructor from the crate root; variant
/// payloads (`delta`, `margin`) are forwarded to the corresponding builder.
fn build_loss_node(
    graph: &mut Graph,
    prediction: NodeId,
    target: NodeId,
    loss: SupervisedLoss,
) -> Result<NodeId, ModelError> {
    match loss {
        SupervisedLoss::Mse => mse_loss(graph, prediction, target),
        SupervisedLoss::Mae => mae_loss(graph, prediction, target),
        SupervisedLoss::Huber { delta } => huber_loss(graph, prediction, target, delta),
        SupervisedLoss::Hinge { margin } => hinge_loss(graph, prediction, target, margin),
        SupervisedLoss::Bce => bce_loss(graph, prediction, target),
        SupervisedLoss::Nll => nll_loss(graph, prediction, target),
        SupervisedLoss::CrossEntropy => cross_entropy_loss(graph, prediction, target),
    }
}
47
48impl GraphOptimizer for Sgd {
49    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
50        Sgd::step_graph_node(self, graph, node).map_err(Into::into)
51    }
52}
53
54impl GraphOptimizer for Adam {
55    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
56        Adam::step_graph_node(self, graph, node).map_err(Into::into)
57    }
58}
59
60impl GraphOptimizer for AdamW {
61    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
62        AdamW::step_graph_node(self, graph, node).map_err(Into::into)
63    }
64}
65
66impl GraphOptimizer for RmsProp {
67    fn step_graph_node(&mut self, graph: &mut Graph, node: NodeId) -> Result<(), ModelError> {
68        RmsProp::step_graph_node(self, graph, node).map_err(Into::into)
69    }
70}
71
72fn train_step_with_optimizer<O: GraphOptimizer>(
73    graph: &mut Graph,
74    optimizer: &mut O,
75    prediction: NodeId,
76    target: NodeId,
77    trainable_nodes: &[NodeId],
78    loss: SupervisedLoss,
79) -> Result<f32, ModelError> {
80    let loss_node = build_loss_node(graph, prediction, target, loss)?;
81    graph.backward(loss_node)?;
82
83    let loss_value = graph.value(loss_node)?.data()[0];
84    for node in trainable_nodes {
85        optimizer.step_graph_node(graph, *node)?;
86    }
87    Ok(loss_value)
88}
89
90/// Runs one full train step: loss forward, backward, and SGD updates.
91pub fn train_step_sgd(
92    graph: &mut Graph,
93    optimizer: &mut Sgd,
94    prediction: NodeId,
95    target: NodeId,
96    trainable_nodes: &[NodeId],
97) -> Result<f32, ModelError> {
98    train_step_sgd_with_loss(
99        graph,
100        optimizer,
101        prediction,
102        target,
103        trainable_nodes,
104        SupervisedLoss::Mse,
105    )
106}
107
/// Runs one full train step: configured loss forward, backward, and SGD updates.
///
/// Returns the scalar loss value for this step.
///
/// # Errors
/// Propagates loss-construction, backward-pass, and optimizer-step failures.
pub fn train_step_sgd_with_loss(
    graph: &mut Graph,
    optimizer: &mut Sgd,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}
119
120/// Runs one full train step: loss forward, backward, and Adam updates.
121pub fn train_step_adam(
122    graph: &mut Graph,
123    optimizer: &mut Adam,
124    prediction: NodeId,
125    target: NodeId,
126    trainable_nodes: &[NodeId],
127) -> Result<f32, ModelError> {
128    train_step_adam_with_loss(
129        graph,
130        optimizer,
131        prediction,
132        target,
133        trainable_nodes,
134        SupervisedLoss::Mse,
135    )
136}
137
/// Runs one full train step: configured loss forward, backward, and Adam updates.
///
/// Returns the scalar loss value for this step.
///
/// # Errors
/// Propagates loss-construction, backward-pass, and optimizer-step failures.
pub fn train_step_adam_with_loss(
    graph: &mut Graph,
    optimizer: &mut Adam,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}
149
150/// Runs one full train step: loss forward, backward, and AdamW updates.
151pub fn train_step_adamw(
152    graph: &mut Graph,
153    optimizer: &mut AdamW,
154    prediction: NodeId,
155    target: NodeId,
156    trainable_nodes: &[NodeId],
157) -> Result<f32, ModelError> {
158    train_step_adamw_with_loss(
159        graph,
160        optimizer,
161        prediction,
162        target,
163        trainable_nodes,
164        SupervisedLoss::Mse,
165    )
166}
167
/// Runs one full train step: configured loss forward, backward, and AdamW updates.
///
/// Returns the scalar loss value for this step.
///
/// # Errors
/// Propagates loss-construction, backward-pass, and optimizer-step failures.
pub fn train_step_adamw_with_loss(
    graph: &mut Graph,
    optimizer: &mut AdamW,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}
179
180/// Runs one full train step: loss forward, backward, and RMSProp updates.
181pub fn train_step_rmsprop(
182    graph: &mut Graph,
183    optimizer: &mut RmsProp,
184    prediction: NodeId,
185    target: NodeId,
186    trainable_nodes: &[NodeId],
187) -> Result<f32, ModelError> {
188    train_step_rmsprop_with_loss(
189        graph,
190        optimizer,
191        prediction,
192        target,
193        trainable_nodes,
194        SupervisedLoss::Mse,
195    )
196}
197
/// Runs one full train step: configured loss forward, backward, and RMSProp updates.
///
/// Returns the scalar loss value for this step.
///
/// # Errors
/// Propagates loss-construction, backward-pass, and optimizer-step failures.
pub fn train_step_rmsprop_with_loss(
    graph: &mut Graph,
    optimizer: &mut RmsProp,
    prediction: NodeId,
    target: NodeId,
    trainable_nodes: &[NodeId],
    loss: SupervisedLoss,
) -> Result<f32, ModelError> {
    train_step_with_optimizer(graph, optimizer, prediction, target, trainable_nodes, loss)
}
209
/// Metrics for one training epoch.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EpochMetrics {
    /// Mean loss over all optimizer steps performed in the epoch.
    pub mean_loss: f32,
    /// Number of optimizer steps (mini-batches) performed.
    pub steps: usize,
}
216
/// Metrics for one scheduler-driven epoch.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ScheduledEpochMetrics {
    /// 1-based epoch index.
    pub epoch: usize,
    /// Mean loss over all optimizer steps in this epoch.
    pub mean_loss: f32,
    /// Number of optimizer steps (mini-batches) performed in this epoch.
    pub steps: usize,
    /// Learning rate reported by the scheduler after this epoch.
    pub learning_rate: f32,
}
225
/// Epoch-level training controls for batch order and preprocessing.
#[derive(Debug, Clone, PartialEq)]
pub struct EpochTrainOptions {
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Options forwarded to `SupervisedDataset::batches_with_options`.
    pub batch_iter_options: BatchIterOptions,
}
232
233impl Default for EpochTrainOptions {
234    fn default() -> Self {
235        Self {
236            batch_size: 1,
237            batch_iter_options: BatchIterOptions::default(),
238        }
239    }
240}
241
/// Scheduler-driven epoch training controls.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct SchedulerTrainOptions {
    /// Per-epoch batch controls (size, iteration options).
    pub epoch_options: EpochTrainOptions,
    /// Supervised loss applied every epoch.
    pub loss: SupervisedLoss,
}
248
249/// Deterministic one-epoch train loop over sequential mini-batches.
250pub fn train_epoch_sgd(
251    graph: &mut Graph,
252    model: &SequentialModel,
253    optimizer: &mut Sgd,
254    dataset: &SupervisedDataset,
255    batch_size: usize,
256) -> Result<EpochMetrics, ModelError> {
257    train_epoch_sgd_with_loss(
258        graph,
259        model,
260        optimizer,
261        dataset,
262        batch_size,
263        SupervisedLoss::Mse,
264    )
265}
266
267/// Deterministic one-epoch train loop with configurable supervised loss.
268pub fn train_epoch_sgd_with_loss(
269    graph: &mut Graph,
270    model: &SequentialModel,
271    optimizer: &mut Sgd,
272    dataset: &SupervisedDataset,
273    batch_size: usize,
274    loss: SupervisedLoss,
275) -> Result<EpochMetrics, ModelError> {
276    train_epoch_sgd_with_options_and_loss(
277        graph,
278        model,
279        optimizer,
280        dataset,
281        EpochTrainOptions {
282            batch_size,
283            batch_iter_options: BatchIterOptions::default(),
284        },
285        loss,
286    )
287}
288
289/// Deterministic one-epoch Adam train loop over sequential mini-batches.
290pub fn train_epoch_adam(
291    graph: &mut Graph,
292    model: &SequentialModel,
293    optimizer: &mut Adam,
294    dataset: &SupervisedDataset,
295    batch_size: usize,
296) -> Result<EpochMetrics, ModelError> {
297    train_epoch_adam_with_loss(
298        graph,
299        model,
300        optimizer,
301        dataset,
302        batch_size,
303        SupervisedLoss::Mse,
304    )
305}
306
307/// Deterministic one-epoch Adam train loop with configurable supervised loss.
308pub fn train_epoch_adam_with_loss(
309    graph: &mut Graph,
310    model: &SequentialModel,
311    optimizer: &mut Adam,
312    dataset: &SupervisedDataset,
313    batch_size: usize,
314    loss: SupervisedLoss,
315) -> Result<EpochMetrics, ModelError> {
316    train_epoch_adam_with_options_and_loss(
317        graph,
318        model,
319        optimizer,
320        dataset,
321        EpochTrainOptions {
322            batch_size,
323            batch_iter_options: BatchIterOptions::default(),
324        },
325        loss,
326    )
327}
328
329/// Deterministic one-epoch AdamW train loop over sequential mini-batches.
330pub fn train_epoch_adamw(
331    graph: &mut Graph,
332    model: &SequentialModel,
333    optimizer: &mut AdamW,
334    dataset: &SupervisedDataset,
335    batch_size: usize,
336) -> Result<EpochMetrics, ModelError> {
337    train_epoch_adamw_with_loss(
338        graph,
339        model,
340        optimizer,
341        dataset,
342        batch_size,
343        SupervisedLoss::Mse,
344    )
345}
346
347/// Deterministic one-epoch AdamW train loop with configurable supervised loss.
348pub fn train_epoch_adamw_with_loss(
349    graph: &mut Graph,
350    model: &SequentialModel,
351    optimizer: &mut AdamW,
352    dataset: &SupervisedDataset,
353    batch_size: usize,
354    loss: SupervisedLoss,
355) -> Result<EpochMetrics, ModelError> {
356    train_epoch_adamw_with_options_and_loss(
357        graph,
358        model,
359        optimizer,
360        dataset,
361        EpochTrainOptions {
362            batch_size,
363            batch_iter_options: BatchIterOptions::default(),
364        },
365        loss,
366    )
367}
368
369/// Deterministic one-epoch RMSProp train loop over sequential mini-batches.
370pub fn train_epoch_rmsprop(
371    graph: &mut Graph,
372    model: &SequentialModel,
373    optimizer: &mut RmsProp,
374    dataset: &SupervisedDataset,
375    batch_size: usize,
376) -> Result<EpochMetrics, ModelError> {
377    train_epoch_rmsprop_with_loss(
378        graph,
379        model,
380        optimizer,
381        dataset,
382        batch_size,
383        SupervisedLoss::Mse,
384    )
385}
386
387/// Deterministic one-epoch RMSProp train loop with configurable supervised loss.
388pub fn train_epoch_rmsprop_with_loss(
389    graph: &mut Graph,
390    model: &SequentialModel,
391    optimizer: &mut RmsProp,
392    dataset: &SupervisedDataset,
393    batch_size: usize,
394    loss: SupervisedLoss,
395) -> Result<EpochMetrics, ModelError> {
396    train_epoch_rmsprop_with_options_and_loss(
397        graph,
398        model,
399        optimizer,
400        dataset,
401        EpochTrainOptions {
402            batch_size,
403            batch_iter_options: BatchIterOptions::default(),
404        },
405        loss,
406    )
407}
408
409/// Deterministic one-epoch train loop with configurable batch iterator options.
410pub fn train_epoch_sgd_with_options(
411    graph: &mut Graph,
412    model: &SequentialModel,
413    optimizer: &mut Sgd,
414    dataset: &SupervisedDataset,
415    options: EpochTrainOptions,
416) -> Result<EpochMetrics, ModelError> {
417    train_epoch_sgd_with_options_and_loss(
418        graph,
419        model,
420        optimizer,
421        dataset,
422        options,
423        SupervisedLoss::Mse,
424    )
425}
426
/// Deterministic one-epoch train loop with configurable batch iterator options and loss.
///
/// # Errors
/// Returns `ModelError::EmptyDataset` for an empty dataset and propagates
/// per-batch training failures.
pub fn train_epoch_sgd_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}
438
439/// Deterministic one-epoch Adam train loop with configurable batch iterator options.
440pub fn train_epoch_adam_with_options(
441    graph: &mut Graph,
442    model: &SequentialModel,
443    optimizer: &mut Adam,
444    dataset: &SupervisedDataset,
445    options: EpochTrainOptions,
446) -> Result<EpochMetrics, ModelError> {
447    train_epoch_adam_with_options_and_loss(
448        graph,
449        model,
450        optimizer,
451        dataset,
452        options,
453        SupervisedLoss::Mse,
454    )
455}
456
/// Deterministic one-epoch Adam train loop with configurable batch iterator options and loss.
///
/// # Errors
/// Returns `ModelError::EmptyDataset` for an empty dataset and propagates
/// per-batch training failures.
pub fn train_epoch_adam_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}
468
469/// Deterministic one-epoch AdamW train loop with configurable batch iterator options.
470pub fn train_epoch_adamw_with_options(
471    graph: &mut Graph,
472    model: &SequentialModel,
473    optimizer: &mut AdamW,
474    dataset: &SupervisedDataset,
475    options: EpochTrainOptions,
476) -> Result<EpochMetrics, ModelError> {
477    train_epoch_adamw_with_options_and_loss(
478        graph,
479        model,
480        optimizer,
481        dataset,
482        options,
483        SupervisedLoss::Mse,
484    )
485}
486
/// Deterministic one-epoch AdamW train loop with configurable batch iterator options and loss.
///
/// # Errors
/// Returns `ModelError::EmptyDataset` for an empty dataset and propagates
/// per-batch training failures.
pub fn train_epoch_adamw_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}
498
499/// Deterministic one-epoch RMSProp train loop with configurable batch iterator options.
500pub fn train_epoch_rmsprop_with_options(
501    graph: &mut Graph,
502    model: &SequentialModel,
503    optimizer: &mut RmsProp,
504    dataset: &SupervisedDataset,
505    options: EpochTrainOptions,
506) -> Result<EpochMetrics, ModelError> {
507    train_epoch_rmsprop_with_options_and_loss(
508        graph,
509        model,
510        optimizer,
511        dataset,
512        options,
513        SupervisedLoss::Mse,
514    )
515}
516
/// Deterministic one-epoch RMSProp train loop with configurable batch iterator options and loss.
///
/// # Errors
/// Returns `ModelError::EmptyDataset` for an empty dataset and propagates
/// per-batch training failures.
pub fn train_epoch_rmsprop_with_options_and_loss(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    dataset: &SupervisedDataset,
    options: EpochTrainOptions,
    loss: SupervisedLoss,
) -> Result<EpochMetrics, ModelError> {
    train_epoch_with_options(graph, model, optimizer, dataset, options, loss)
}
528
529/// Runs multiple SGD epochs and advances scheduler after each epoch.
530pub fn train_epochs_sgd_with_scheduler<S: LrScheduler>(
531    graph: &mut Graph,
532    model: &SequentialModel,
533    optimizer: &mut Sgd,
534    scheduler: &mut S,
535    dataset: &SupervisedDataset,
536    epochs: usize,
537    options: EpochTrainOptions,
538) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
539    train_epochs_sgd_with_scheduler_and_loss(
540        graph,
541        model,
542        optimizer,
543        scheduler,
544        dataset,
545        epochs,
546        SchedulerTrainOptions {
547            epoch_options: options,
548            loss: SupervisedLoss::Mse,
549        },
550    )
551}
552
/// Runs multiple SGD epochs with configurable supervised loss and advances scheduler after each epoch.
///
/// # Errors
/// Returns `ModelError::InvalidEpochCount` when `epochs == 0`; propagates
/// per-epoch training and scheduler failures.
pub fn train_epochs_sgd_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Sgd,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}
565
566/// Runs multiple Adam epochs and advances scheduler after each epoch.
567pub fn train_epochs_adam_with_scheduler<S: LrScheduler>(
568    graph: &mut Graph,
569    model: &SequentialModel,
570    optimizer: &mut Adam,
571    scheduler: &mut S,
572    dataset: &SupervisedDataset,
573    epochs: usize,
574    options: EpochTrainOptions,
575) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
576    train_epochs_adam_with_scheduler_and_loss(
577        graph,
578        model,
579        optimizer,
580        scheduler,
581        dataset,
582        epochs,
583        SchedulerTrainOptions {
584            epoch_options: options,
585            loss: SupervisedLoss::Mse,
586        },
587    )
588}
589
/// Runs multiple Adam epochs with configurable supervised loss and advances scheduler after each epoch.
///
/// # Errors
/// Returns `ModelError::InvalidEpochCount` when `epochs == 0`; propagates
/// per-epoch training and scheduler failures.
pub fn train_epochs_adam_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut Adam,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}
602
603/// Runs multiple AdamW epochs and advances scheduler after each epoch.
604pub fn train_epochs_adamw_with_scheduler<S: LrScheduler>(
605    graph: &mut Graph,
606    model: &SequentialModel,
607    optimizer: &mut AdamW,
608    scheduler: &mut S,
609    dataset: &SupervisedDataset,
610    epochs: usize,
611    options: EpochTrainOptions,
612) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
613    train_epochs_adamw_with_scheduler_and_loss(
614        graph,
615        model,
616        optimizer,
617        scheduler,
618        dataset,
619        epochs,
620        SchedulerTrainOptions {
621            epoch_options: options,
622            loss: SupervisedLoss::Mse,
623        },
624    )
625}
626
/// Runs multiple AdamW epochs with configurable supervised loss and advances scheduler after each epoch.
///
/// # Errors
/// Returns `ModelError::InvalidEpochCount` when `epochs == 0`; propagates
/// per-epoch training and scheduler failures.
pub fn train_epochs_adamw_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut AdamW,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}
639
640/// Runs multiple RMSProp epochs and advances scheduler after each epoch.
641pub fn train_epochs_rmsprop_with_scheduler<S: LrScheduler>(
642    graph: &mut Graph,
643    model: &SequentialModel,
644    optimizer: &mut RmsProp,
645    scheduler: &mut S,
646    dataset: &SupervisedDataset,
647    epochs: usize,
648    options: EpochTrainOptions,
649) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
650    train_epochs_rmsprop_with_scheduler_and_loss(
651        graph,
652        model,
653        optimizer,
654        scheduler,
655        dataset,
656        epochs,
657        SchedulerTrainOptions {
658            epoch_options: options,
659            loss: SupervisedLoss::Mse,
660        },
661    )
662}
663
/// Runs multiple RMSProp epochs with configurable supervised loss and advances scheduler after each epoch.
///
/// # Errors
/// Returns `ModelError::InvalidEpochCount` when `epochs == 0`; propagates
/// per-epoch training and scheduler failures.
pub fn train_epochs_rmsprop_with_scheduler_and_loss<S: LrScheduler>(
    graph: &mut Graph,
    model: &SequentialModel,
    optimizer: &mut RmsProp,
    scheduler: &mut S,
    dataset: &SupervisedDataset,
    epochs: usize,
    options: SchedulerTrainOptions,
) -> Result<Vec<ScheduledEpochMetrics>, ModelError> {
    train_epochs_with_scheduler(graph, model, optimizer, scheduler, dataset, epochs, options)
}
676
677fn train_epoch_with_options<O: GraphOptimizer>(
678    graph: &mut Graph,
679    model: &SequentialModel,
680    optimizer: &mut O,
681    dataset: &SupervisedDataset,
682    options: EpochTrainOptions,
683    loss: SupervisedLoss,
684) -> Result<EpochMetrics, ModelError> {
685    if dataset.is_empty() {
686        return Err(ModelError::EmptyDataset);
687    }
688    let batches = dataset.batches_with_options(options.batch_size, options.batch_iter_options)?;
689    let trainable_nodes = model.trainable_nodes();
690
691    let mut loss_sum = 0.0f32;
692    let mut steps = 0usize;
693    for batch in batches {
694        graph.truncate(model.persistent_node_count())?;
695
696        let input = graph.constant(batch.inputs);
697        let target = graph.constant(batch.targets);
698        let prediction = model.forward(graph, input)?;
699        let loss_value = train_step_with_optimizer(
700            graph,
701            optimizer,
702            prediction,
703            target,
704            &trainable_nodes,
705            loss,
706        )?;
707        loss_sum += loss_value;
708        steps += 1;
709    }
710    if steps == 0 {
711        return Err(ModelError::EmptyDataset);
712    }
713
714    Ok(EpochMetrics {
715        mean_loss: loss_sum / steps as f32,
716        steps,
717    })
718}
719
720fn train_epochs_with_scheduler<O, S>(
721    graph: &mut Graph,
722    model: &SequentialModel,
723    optimizer: &mut O,
724    scheduler: &mut S,
725    dataset: &SupervisedDataset,
726    epochs: usize,
727    options: SchedulerTrainOptions,
728) -> Result<Vec<ScheduledEpochMetrics>, ModelError>
729where
730    O: GraphOptimizer + LearningRate,
731    S: LrScheduler,
732{
733    if epochs == 0 {
734        return Err(ModelError::InvalidEpochCount { epochs });
735    }
736
737    let mut all_metrics = Vec::with_capacity(epochs);
738    for epoch_index in 0..epochs {
739        let epoch_metrics = train_epoch_with_options(
740            graph,
741            model,
742            optimizer,
743            dataset,
744            options.epoch_options.clone(),
745            options.loss,
746        )?;
747        let learning_rate = scheduler.step(optimizer)?;
748        all_metrics.push(ScheduledEpochMetrics {
749            epoch: epoch_index + 1,
750            mean_loss: epoch_metrics.mean_loss,
751            steps: epoch_metrics.steps,
752            learning_rate,
753        });
754    }
755    Ok(all_metrics)
756}
757
758// ── High-level CNN training and inference helpers ──────────────────
759
/// Configuration for high-level CNN training.
#[derive(Debug, Clone)]
pub struct CnnTrainConfig {
    /// Optimizer learning rate.
    pub lr: f32,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Supervised loss applied per batch.
    pub loss: SupervisedLoss,
    /// Options forwarded to `SupervisedDataset::batches_with_options`.
    pub batch_iter_options: BatchIterOptions,
}
768
impl Default for CnnTrainConfig {
    /// Defaults: learning rate 0.01, batch size 16, cross-entropy loss,
    /// default batch-iteration behavior.
    fn default() -> Self {
        Self {
            lr: 0.01,
            batch_size: 16,
            loss: SupervisedLoss::CrossEntropy,
            batch_iter_options: BatchIterOptions::default(),
        }
    }
}
779
780/// One-call CNN training epoch: register params, forward, loss, backward, update, sync.
781///
782/// Handles the full graph-mode CNN lifecycle for one epoch with SGD.
783pub fn train_cnn_epoch_sgd(
784    graph: &mut Graph,
785    model: &mut SequentialModel,
786    dataset: &SupervisedDataset,
787    config: &CnnTrainConfig,
788) -> Result<EpochMetrics, ModelError> {
789    let mut optimizer = yscv_optim::Sgd::new(config.lr)?;
790    train_cnn_epoch_with_optimizer(graph, model, dataset, &mut optimizer, config)
791}
792
793/// One-call CNN training epoch with Adam optimizer.
794pub fn train_cnn_epoch_adam(
795    graph: &mut Graph,
796    model: &mut SequentialModel,
797    dataset: &SupervisedDataset,
798    config: &CnnTrainConfig,
799) -> Result<EpochMetrics, ModelError> {
800    let mut optimizer = Adam::new(config.lr)?;
801    train_cnn_epoch_with_optimizer(graph, model, dataset, &mut optimizer, config)
802}
803
804/// One-call CNN training epoch with AdamW optimizer.
805pub fn train_cnn_epoch_adamw(
806    graph: &mut Graph,
807    model: &mut SequentialModel,
808    dataset: &SupervisedDataset,
809    config: &CnnTrainConfig,
810) -> Result<EpochMetrics, ModelError> {
811    let mut optimizer = AdamW::new(config.lr)?;
812    train_cnn_epoch_with_optimizer(graph, model, dataset, &mut optimizer, config)
813}
814
/// Shared CNN epoch driver: registers CNN params on the graph, iterates
/// mini-batches, performs one train step per batch, and syncs updated
/// parameters back into the model after every step.
///
/// Unlike `train_epoch_with_options`, an empty batch iteration yields
/// zero-step metrics (`mean_loss` 0.0) rather than an error.
fn train_cnn_epoch_with_optimizer<O: GraphOptimizer>(
    graph: &mut Graph,
    model: &mut SequentialModel,
    dataset: &SupervisedDataset,
    optimizer: &mut O,
    config: &CnnTrainConfig,
) -> Result<EpochMetrics, ModelError> {
    // NOTE(review): called once per epoch by the multi-epoch driver; assumes
    // re-registering params across epochs is idempotent — TODO confirm.
    model.register_cnn_params(graph);
    let param_nodes = model.trainable_nodes();
    let persistent = model.persistent_node_count();
    let iter =
        dataset.batches_with_options(config.batch_size, config.batch_iter_options.clone())?;

    let mut total_loss = 0.0f32;
    let mut steps = 0usize;

    for batch in iter {
        // Discard per-batch nodes from the previous iteration, keeping params.
        graph.truncate(persistent)?;
        let input_node = graph.variable(batch.inputs);
        let target_node = graph.variable(batch.targets);
        let prediction = model.forward(graph, input_node)?;
        let loss_val = train_step_with_optimizer(
            graph,
            optimizer,
            prediction,
            target_node,
            &param_nodes,
            config.loss,
        )?;
        // Mirror updated graph parameters back into the model state.
        model.sync_cnn_from_graph(graph)?;
        total_loss += loss_val;
        steps += 1;
    }

    Ok(EpochMetrics {
        mean_loss: if steps > 0 {
            total_loss / steps as f32
        } else {
            0.0
        },
        steps,
    })
}
858
/// Optimizer selection for multi-epoch CNN training (see `train_cnn_epochs`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptimizerType {
    /// Stochastic gradient descent.
    Sgd,
    /// Adam.
    Adam,
    /// AdamW (decoupled weight decay).
    AdamW,
}
866
867/// Runs multiple CNN training epochs, returning per-epoch metrics.
868pub fn train_cnn_epochs(
869    graph: &mut Graph,
870    model: &mut SequentialModel,
871    dataset: &SupervisedDataset,
872    epochs: usize,
873    config: &CnnTrainConfig,
874    optimizer_type: OptimizerType,
875) -> Result<Vec<EpochMetrics>, ModelError> {
876    if epochs == 0 {
877        return Err(ModelError::InvalidEpochCount { epochs });
878    }
879    let mut all = Vec::with_capacity(epochs);
880    for _ in 0..epochs {
881        let metrics = match optimizer_type {
882            OptimizerType::Sgd => train_cnn_epoch_sgd(graph, model, dataset, config)?,
883            OptimizerType::Adam => train_cnn_epoch_adam(graph, model, dataset, config)?,
884            OptimizerType::AdamW => train_cnn_epoch_adamw(graph, model, dataset, config)?,
885        };
886        all.push(metrics);
887    }
888    Ok(all)
889}
890
891// ── Gradient accumulation helpers ──────────────────────────────────
892
893/// Scales gradients of the given nodes by a scalar factor.
894///
895/// Nodes without computed gradients are skipped.
896pub fn scale_gradients(graph: &mut Graph, nodes: &[NodeId], scale: f32) -> Result<(), ModelError> {
897    for &node in nodes {
898        if let Some(grad) = graph.grad_mut(node)? {
899            let scaled = grad.scale(scale);
900            *grad = scaled;
901        }
902    }
903    Ok(())
904}
905
906/// Adds source gradients into the existing gradients of the given nodes.
907///
908/// For each node, if the node already has a gradient, the corresponding source
909/// gradient is added to it element-wise.  If the node has no gradient yet, the
910/// source gradient is cloned and set directly.  Source entries that are `None`
911/// are skipped.
912///
913/// `nodes` and `source_grads` must have the same length.
914pub fn accumulate_gradients(
915    graph: &mut Graph,
916    nodes: &[NodeId],
917    source_grads: &[Option<Tensor>],
918) -> Result<(), ModelError> {
919    assert_eq!(
920        nodes.len(),
921        source_grads.len(),
922        "nodes and source_grads must have the same length"
923    );
924    for (i, &node) in nodes.iter().enumerate() {
925        if let Some(src) = &source_grads[i] {
926            let existing = graph.grad(node)?;
927            let new_grad = match existing {
928                Some(current) => current.add(src)?,
929                None => src.clone(),
930            };
931            graph.set_grad(node, new_grad)?;
932        }
933    }
934    Ok(())
935}
936
937/// Collects the current gradients for a set of nodes as owned tensors.
938///
939/// Returns a `Vec` where each entry is `Some(grad.clone())` if a gradient
940/// exists for the corresponding node, or `None` otherwise.
941pub fn collect_gradients(
942    graph: &Graph,
943    nodes: &[NodeId],
944) -> Result<Vec<Option<Tensor>>, ModelError> {
945    let mut grads = Vec::with_capacity(nodes.len());
946    for &node in nodes {
947        grads.push(graph.grad(node)?.cloned());
948    }
949    Ok(grads)
950}
951
/// Runs one training step with gradient accumulation across multiple
/// micro-batches.
///
/// This is the SGD variant.  The caller supplies a closure that, given a
/// mutable `Graph` reference, creates a fresh micro-batch forward pass and
/// returns `(prediction_node, target_node)`.  The closure is called
/// `accumulation_steps` times.
///
/// For each micro-batch the loss is scaled by `1 / accumulation_steps` so
/// that the accumulated gradients approximate the gradient over the
/// effective (large) batch.  The optimizer is stepped only once, after all
/// micro-batches.
///
/// Returns the average loss across the micro-batches.
pub fn train_step_sgd_with_accumulation<F>(
    graph: &mut Graph,
    optimizer: &mut Sgd,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    mut micro_batch_fn: F,
) -> Result<f32, ModelError>
where
    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
{
    // Forward the generic closure as `&mut dyn FnMut` to the shared impl.
    train_step_with_accumulation_impl(
        graph,
        optimizer,
        trainable_nodes,
        accumulation_steps,
        loss_fn,
        &mut micro_batch_fn,
    )
}
986
/// Runs one training step with gradient accumulation across multiple
/// micro-batches using the Adam optimizer.
///
/// See [`train_step_sgd_with_accumulation`] for the accumulation contract:
/// `micro_batch_fn` is invoked `accumulation_steps` times and must return a
/// fresh `(prediction_node, target_node)` pair each time; the optimizer is
/// stepped once at the end. Returns the average micro-batch loss.
pub fn train_step_adam_with_accumulation<F>(
    graph: &mut Graph,
    optimizer: &mut Adam,
    trainable_nodes: &[NodeId],
    accumulation_steps: usize,
    loss_fn: SupervisedLoss,
    mut micro_batch_fn: F,
) -> Result<f32, ModelError>
where
    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
{
    // Forward the generic closure as `&mut dyn FnMut` to the shared impl.
    train_step_with_accumulation_impl(
        graph,
        optimizer,
        trainable_nodes,
        accumulation_steps,
        loss_fn,
        &mut micro_batch_fn,
    )
}
1009
1010/// Runs one training step with gradient accumulation across multiple
1011/// micro-batches using the AdamW optimizer.
1012pub fn train_step_adamw_with_accumulation<F>(
1013    graph: &mut Graph,
1014    optimizer: &mut AdamW,
1015    trainable_nodes: &[NodeId],
1016    accumulation_steps: usize,
1017    loss_fn: SupervisedLoss,
1018    mut micro_batch_fn: F,
1019) -> Result<f32, ModelError>
1020where
1021    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
1022{
1023    train_step_with_accumulation_impl(
1024        graph,
1025        optimizer,
1026        trainable_nodes,
1027        accumulation_steps,
1028        loss_fn,
1029        &mut micro_batch_fn,
1030    )
1031}
1032
1033/// Runs one training step with gradient accumulation across multiple
1034/// micro-batches using the RMSProp optimizer.
1035pub fn train_step_rmsprop_with_accumulation<F>(
1036    graph: &mut Graph,
1037    optimizer: &mut RmsProp,
1038    trainable_nodes: &[NodeId],
1039    accumulation_steps: usize,
1040    loss_fn: SupervisedLoss,
1041    mut micro_batch_fn: F,
1042) -> Result<f32, ModelError>
1043where
1044    F: FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
1045{
1046    train_step_with_accumulation_impl(
1047        graph,
1048        optimizer,
1049        trainable_nodes,
1050        accumulation_steps,
1051        loss_fn,
1052        &mut micro_batch_fn,
1053    )
1054}
1055
1056#[allow(clippy::type_complexity)]
1057fn train_step_with_accumulation_impl<O: GraphOptimizer>(
1058    graph: &mut Graph,
1059    optimizer: &mut O,
1060    trainable_nodes: &[NodeId],
1061    accumulation_steps: usize,
1062    loss_fn: SupervisedLoss,
1063    micro_batch_fn: &mut dyn FnMut(&mut Graph) -> Result<(NodeId, NodeId), ModelError>,
1064) -> Result<f32, ModelError> {
1065    if accumulation_steps == 0 {
1066        return Err(ModelError::InvalidAccumulationSteps {
1067            steps: accumulation_steps,
1068        });
1069    }
1070
1071    let scale = 1.0 / accumulation_steps as f32;
1072    let mut accumulated: Vec<Option<Tensor>> = vec![None; trainable_nodes.len()];
1073    let mut total_loss = 0.0f32;
1074
1075    for _ in 0..accumulation_steps {
1076        // Zero grads before each micro-batch backward.
1077        graph.zero_grads();
1078
1079        let (prediction, target) = micro_batch_fn(graph)?;
1080        let loss_node = build_loss_node(graph, prediction, target, loss_fn)?;
1081        graph.backward(loss_node)?;
1082
1083        let loss_value = graph.value(loss_node)?.data()[0];
1084        total_loss += loss_value;
1085
1086        // Collect this micro-batch's gradients and accumulate with scaling.
1087        for (i, &node) in trainable_nodes.iter().enumerate() {
1088            if let Some(grad) = graph.grad(node)? {
1089                let scaled = grad.scale(scale);
1090                accumulated[i] = Some(match accumulated[i].take() {
1091                    Some(acc) => acc.add(&scaled)?,
1092                    None => scaled,
1093                });
1094            }
1095        }
1096    }
1097
1098    // Write the accumulated gradients back and step the optimizer.
1099    for (i, &node) in trainable_nodes.iter().enumerate() {
1100        if let Some(grad) = accumulated[i].take() {
1101            graph.set_grad(node, grad)?;
1102        }
1103    }
1104
1105    for &node in trainable_nodes {
1106        optimizer.step_graph_node(graph, node)?;
1107    }
1108
1109    Ok(total_loss / accumulation_steps as f32)
1110}
1111
1112/// Batch inference on a SequentialModel (tensor mode, no autograd graph).
1113pub fn infer_batch(
1114    model: &SequentialModel,
1115    input: &yscv_tensor::Tensor,
1116) -> Result<yscv_tensor::Tensor, ModelError> {
1117    model.forward_inference(input)
1118}
1119
1120/// Runs inference through the autograd graph and returns the output tensor value.
1121pub fn infer_batch_graph(
1122    graph: &mut Graph,
1123    model: &SequentialModel,
1124    input: yscv_tensor::Tensor,
1125) -> Result<yscv_tensor::Tensor, ModelError> {
1126    let persistent = model.persistent_node_count();
1127    graph.truncate(persistent)?;
1128    let input_node = graph.variable(input);
1129    let output_node = model.forward(graph, input_node)?;
1130    Ok(graph.value(output_node)?.clone())
1131}
1132
1133// ── Distributed training epoch ─────────────────────────────────────
1134
1135/// Train one epoch with distributed gradient synchronization.
1136///
1137/// After each batch's backward pass, gradients are collected from the
1138/// trainable parameter nodes, aggregated across all ranks using the
1139/// provided [`GradientAggregator`] (e.g. `AllReduceAggregator` or
1140/// `LocalAggregator` for single-rank), written back, and then the
1141/// optimizer is stepped.  This is the data-parallel training pattern.
1142///
1143/// The caller supplies a closure `train_batch_fn` that, given the graph
1144/// and a batch index, must:
1145///   1. Set up the forward pass for the batch (feed inputs, compute
1146///      prediction).
1147///   2. Compute the loss and call `graph.backward(loss_node)`.
1148///   3. Return `Ok(loss_scalar)`.
1149///
1150/// The function returns the mean loss across all batches.
1151#[allow(private_bounds)]
1152pub fn train_epoch_distributed<F, O: GraphOptimizer>(
1153    graph: &mut Graph,
1154    optimizer: &mut O,
1155    aggregator: &mut dyn GradientAggregator,
1156    trainable_nodes: &[NodeId],
1157    num_batches: usize,
1158    train_batch_fn: &mut F,
1159) -> Result<EpochMetrics, ModelError>
1160where
1161    F: FnMut(&mut Graph, usize) -> Result<f32, ModelError>,
1162{
1163    if num_batches == 0 {
1164        return Err(ModelError::EmptyDataset);
1165    }
1166
1167    let mut loss_sum = 0.0f32;
1168
1169    for batch_idx in 0..num_batches {
1170        // 1. Run the user's forward+backward pass for this batch.
1171        let loss_value = train_batch_fn(graph, batch_idx)?;
1172        loss_sum += loss_value;
1173
1174        // 2. Collect gradients from the trainable nodes into tensors.
1175        let mut local_grads = Vec::with_capacity(trainable_nodes.len());
1176        for &node in trainable_nodes {
1177            let grad = match graph.grad(node)?.cloned() {
1178                Some(g) => g,
1179                None => {
1180                    // If a node has no gradient, use a zero tensor with
1181                    // the same shape as the parameter value.
1182                    let val = graph.value(node)?;
1183                    Tensor::zeros(val.shape().to_vec())?
1184                }
1185            };
1186            local_grads.push(grad);
1187        }
1188
1189        // 3. Aggregate gradients across all ranks.
1190        let aggregated = aggregator.aggregate(&local_grads)?;
1191
1192        // 4. Write the aggregated gradients back and step the optimizer.
1193        for (i, &node) in trainable_nodes.iter().enumerate() {
1194            graph.set_grad(node, aggregated[i].clone())?;
1195        }
1196
1197        for &node in trainable_nodes {
1198            optimizer.step_graph_node(graph, node)?;
1199        }
1200    }
1201
1202    Ok(EpochMetrics {
1203        mean_loss: loss_sum / num_batches as f32,
1204        steps: num_batches,
1205    })
1206}
1207
1208/// Convenience wrapper: train one distributed epoch over a
1209/// [`SequentialModel`] and [`SupervisedDataset`] with SGD.
1210///
1211/// This mirrors [`train_epoch_sgd`] but inserts an aggregation step
1212/// between backward and optimizer update on every batch.
1213pub fn train_epoch_distributed_sgd(
1214    graph: &mut Graph,
1215    model: &SequentialModel,
1216    optimizer: &mut Sgd,
1217    aggregator: &mut dyn GradientAggregator,
1218    dataset: &SupervisedDataset,
1219    batch_size: usize,
1220    loss: SupervisedLoss,
1221) -> Result<EpochMetrics, ModelError> {
1222    if dataset.is_empty() {
1223        return Err(ModelError::EmptyDataset);
1224    }
1225    let batches: Vec<_> = dataset
1226        .batches_with_options(batch_size, BatchIterOptions::default())?
1227        .collect();
1228    let trainable_nodes = model.trainable_nodes();
1229    let persistent = model.persistent_node_count();
1230    let num_batches = batches.len();
1231
1232    let mut batch_iter = batches.into_iter();
1233
1234    train_epoch_distributed(
1235        graph,
1236        optimizer,
1237        aggregator,
1238        &trainable_nodes,
1239        num_batches,
1240        &mut |g, _batch_idx| {
1241            let batch = batch_iter.next().ok_or(ModelError::EmptyDataset)?;
1242            g.truncate(persistent)?;
1243            let input = g.constant(batch.inputs);
1244            let target = g.constant(batch.targets);
1245            let prediction = model.forward(g, input)?;
1246            let loss_node = build_loss_node(g, prediction, target, loss)?;
1247            g.backward(loss_node)?;
1248            let loss_value = g.value(loss_node)?.data()[0];
1249            Ok(loss_value)
1250        },
1251    )
1252}