scry_learn/neural/classifier.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Multi-layer perceptron classifier.
//!
//! Sklearn-compatible API with builder pattern.
//!
//! ```ignore
//! let mut clf = MLPClassifier::new()
//!     .hidden_layers(&[100, 50])
//!     .activation(Activation::Relu)
//!     .optimizer(OptimizerKind::Adam)
//!     .learning_rate(0.001)
//!     .max_iter(200)
//!     .batch_size(32)
//!     .early_stopping(true)
//!     .seed(42);
//! clf.fit(&train_data)?;
//! let preds = clf.predict(&test_features)?;
//! ```

use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::neural::activation::Activation;
use crate::neural::callback::{
    self, CallbackAction, EpochMetrics, TrainingCallback, TrainingHistory,
};
use crate::neural::layer::FastRng;
use crate::neural::network::{self, Network};
use crate::neural::optimizer::{LearningRateSchedule, OptimizerKind, OptimizerState};
use crate::partial_fit::PartialFit;

/// Multi-layer perceptron classifier.
///
/// Trains a feedforward neural network for classification using
/// backpropagation with configurable optimizers and activations.
///
/// Defaults match sklearn: `hidden_layers=[100]`, Adam, lr=0.001,
/// `max_iter=200`, `batch_size=200` (clamped to the number of training
/// samples), `alpha=0.0001`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct MLPClassifier {
    hidden_layers: Vec<usize>,
    activation: Activation,
    optimizer_kind: OptimizerKind,
    learning_rate: f64,
    max_iter: usize,
    batch_size: usize,
    alpha: f64,
    tolerance: f64,
    early_stopping: bool,
    validation_fraction: f64,
    n_iter_no_change: usize,
    seed: u64,
    /// Dropout probability applied between hidden layers (0.0 = no dropout).
    dropout_rate: f64,
    /// Learning rate schedule.
    lr_schedule: LearningRateSchedule,
    // ── Fitted state ──
    fitted: bool,
    n_features: usize,
    n_classes: usize,
    class_labels: Vec<f64>,
    network_weights: Vec<(Vec<f64>, Vec<f64>)>,
    network_dims: Vec<(usize, usize)>,
    /// Training loss curve (one entry per epoch).
    pub loss_curve: Vec<f64>,
    /// Structured training history with per-epoch metrics.
    training_history: TrainingHistory,
    /// User-supplied training callbacks (not cloned — session-specific).
    #[cfg_attr(feature = "serde", serde(skip))]
    callbacks: Vec<Box<dyn TrainingCallback>>,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl Clone for MLPClassifier {
    fn clone(&self) -> Self {
        Self {
            hidden_layers: self.hidden_layers.clone(),
            activation: self.activation,
            optimizer_kind: self.optimizer_kind,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            batch_size: self.batch_size,
            alpha: self.alpha,
            tolerance: self.tolerance,
            early_stopping: self.early_stopping,
            validation_fraction: self.validation_fraction,
            n_iter_no_change: self.n_iter_no_change,
            seed: self.seed,
            dropout_rate: self.dropout_rate,
            lr_schedule: self.lr_schedule,
            fitted: self.fitted,
            n_features: self.n_features,
            n_classes: self.n_classes,
            class_labels: self.class_labels.clone(),
            network_weights: self.network_weights.clone(),
            network_dims: self.network_dims.clone(),
            loss_curve: self.loss_curve.clone(),
            training_history: self.training_history.clone(),
            // Callbacks are session-specific and not cloned.
            callbacks: Vec::new(),
            _schema_version: self._schema_version,
        }
    }
}

impl MLPClassifier {
    /// Create a new MLP classifier with sklearn defaults.
    pub fn new() -> Self {
        Self {
            hidden_layers: vec![100],
            activation: Activation::Relu,
            optimizer_kind: OptimizerKind::default(),
            learning_rate: 0.001,
            max_iter: 200,
            batch_size: 200,
            alpha: 0.0001,
            tolerance: 1e-4,
            early_stopping: false,
            validation_fraction: 0.1,
            n_iter_no_change: 10,
            seed: 42,
            dropout_rate: 0.0,
            lr_schedule: LearningRateSchedule::Constant,
            fitted: false,
            n_features: 0,
            n_classes: 0,
            class_labels: Vec::new(),
            network_weights: Vec::new(),
            network_dims: Vec::new(),
            loss_curve: Vec::new(),
            training_history: TrainingHistory::new(),
            callbacks: Vec::new(),
            _schema_version: 0,
        }
    }

    /// Set hidden layer sizes. Default: `&[100]`.
    pub fn hidden_layers(mut self, sizes: &[usize]) -> Self {
        self.hidden_layers = sizes.to_vec();
        self
    }

    /// Set activation function for hidden layers. Default: ReLU.
    pub fn activation(mut self, activation: Activation) -> Self {
        self.activation = activation;
        self
    }

    /// Set optimizer algorithm. Default: Adam.
    pub fn optimizer(mut self, kind: OptimizerKind) -> Self {
        self.optimizer_kind = kind;
        self
    }

    /// Set learning rate. Default: 0.001.
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set maximum training iterations (epochs). Default: 200.
    pub fn max_iter(mut self, n: usize) -> Self {
        self.max_iter = n;
        self
    }

    /// Set mini-batch size. Default: 200.
    pub fn batch_size(mut self, n: usize) -> Self {
        self.batch_size = n;
        self
    }

    /// Set L2 regularization strength. Default: 0.0001.
    pub fn alpha(mut self, a: f64) -> Self {
        self.alpha = a;
        self
    }

    /// Set convergence tolerance. Default: 1e-4.
    pub fn tolerance(mut self, tol: f64) -> Self {
        self.tolerance = tol;
        self
    }

    /// Alias for [`tolerance`](Self::tolerance) (sklearn convention).
    pub fn tol(self, t: f64) -> Self {
        self.tolerance(t)
    }

    /// Enable early stopping with validation split. Default: false.
    pub fn early_stopping(mut self, enable: bool) -> Self {
        self.early_stopping = enable;
        self
    }

    /// Set validation fraction for early stopping. Default: 0.1.
    pub fn validation_fraction(mut self, frac: f64) -> Self {
        self.validation_fraction = frac;
        self
    }

    /// Set patience for early stopping. Default: 10.
    pub fn n_iter_no_change(mut self, n: usize) -> Self {
        self.n_iter_no_change = n;
        self
    }

    /// Set random seed. Default: 42.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Set learning rate schedule. Default: [`LearningRateSchedule::Constant`].
    ///
    /// Use [`LearningRateSchedule::adaptive()`] for reduce-on-plateau behavior.
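    ///
    /// A minimal sketch of switching schedules, using the `adaptive()`
    /// constructor referenced above:
    ///
    /// ```ignore
    /// let clf = MLPClassifier::new()
    ///     .learning_rate_schedule(LearningRateSchedule::adaptive());
    /// ```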
    pub fn learning_rate_schedule(mut self, schedule: LearningRateSchedule) -> Self {
        self.lr_schedule = schedule;
        self
    }

    /// Set dropout probability applied between hidden layers.
    ///
    /// `p` is the fraction of activations to zero out (e.g. 0.5 for 50%).
    /// Applied only during training; inference is unaffected.
    /// Default: 0.0 (no dropout).
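    ///
    /// For example, to zero out half of the hidden activations during
    /// training:
    ///
    /// ```ignore
    /// let clf = MLPClassifier::new().dropout(0.5);
    /// ```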
    pub fn dropout(mut self, p: f64) -> Self {
        self.dropout_rate = p;
        self
    }

    /// Add a training callback (invoked after each epoch).
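    ///
    /// A hedged sketch of a custom callback. It assumes `on_epoch_end` and
    /// `on_training_end` are the trait's only required methods and that
    /// `CallbackAction` has a `Continue` variant:
    ///
    /// ```ignore
    /// struct StopBelow(f64);
    ///
    /// impl TrainingCallback for StopBelow {
    ///     fn on_epoch_end(&mut self, m: &EpochMetrics) -> CallbackAction {
    ///         // Halt training once the epoch loss drops below the threshold.
    ///         if m.train_loss < self.0 {
    ///             CallbackAction::Stop
    ///         } else {
    ///             CallbackAction::Continue
    ///         }
    ///     }
    ///     fn on_training_end(&mut self) {}
    /// }
    ///
    /// let clf = MLPClassifier::new().callback(Box::new(StopBelow(0.1)));
    /// ```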
    pub fn callback(mut self, cb: Box<dyn TrainingCallback>) -> Self {
        self.callbacks.push(cb);
        self
    }

    /// Train the classifier on a dataset.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        let n_samples = data.n_samples();
        let n_features = data.n_features();

        if n_samples == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        // Discover classes
        let mut class_labels: Vec<f64> = data.target.clone();
        class_labels.sort_by(|a, b| a.total_cmp(b));
        class_labels.dedup();
        let n_classes = class_labels.len();

        if n_classes < 2 {
            return Err(ScryLearnError::InvalidParameter(
                "need at least 2 classes".into(),
            ));
        }

        // Build row-major feature matrix
        let x = build_row_major(&data.features, n_samples, n_features);

        // Map targets to class indices
        let y: Vec<f64> = data
            .target
            .iter()
            .map(|&t| {
                class_labels
                    .iter()
                    .position(|&c| (c - t).abs() < f64::EPSILON)
                    .expect("target value must appear in class_labels") as f64
            })
            .collect();

        // Split train/val if early stopping
        let (train_x, train_y, val_x, val_y) = if self.early_stopping {
            let mut rng = FastRng::new(self.seed);
            // Clamp so at least one sample remains on each side of the split.
            let val_size = ((n_samples as f64 * self.validation_fraction).max(1.0) as usize)
                .min(n_samples - 1);
            let train_size = n_samples - val_size;
            let mut indices: Vec<usize> = (0..n_samples).collect();
            rng.shuffle(&mut indices);

            let mut tx = Vec::with_capacity(train_size * n_features);
            let mut ty = Vec::with_capacity(train_size);
            let mut vx = Vec::with_capacity(val_size * n_features);
            let mut vy = Vec::with_capacity(val_size);

            for &i in &indices[..train_size] {
                tx.extend_from_slice(&x[i * n_features..(i + 1) * n_features]);
                ty.push(y[i]);
            }
            for &i in &indices[train_size..] {
                vx.extend_from_slice(&x[i * n_features..(i + 1) * n_features]);
                vy.push(y[i]);
            }
            (tx, ty, Some(vx), Some(vy))
        } else {
            (x, y, None, None)
        };

        let train_n = train_y.len();

        // Build network
        let mut sizes = Vec::with_capacity(self.hidden_layers.len() + 2);
        sizes.push(n_features);
        sizes.extend_from_slice(&self.hidden_layers);
        sizes.push(n_classes);

        let mut net =
            Network::new_with_dropout(&sizes, self.activation, self.seed, self.dropout_rate);
        let param_sizes = net.param_group_sizes();
        let mut optimizer = OptimizerState::new_with_schedule(
            self.optimizer_kind,
            self.learning_rate,
            &param_sizes,
            self.lr_schedule,
        );

        let batch_size = self.batch_size.min(train_n);
        let mut rng = FastRng::new(self.seed.wrapping_add(1));
        let mut indices: Vec<usize> = (0..train_n).collect();

        self.loss_curve.clear();
        self.training_history = TrainingHistory::new();
        let mut best_val_loss = f64::INFINITY;
        let mut best_weights: Option<Vec<(Vec<f64>, Vec<f64>)>> = None;
        let mut no_improve = 0;

        // Take callbacks out so we can mutably borrow them during training.
        let mut callbacks = std::mem::take(&mut self.callbacks);

        for epoch_idx in 0..self.max_iter {
            let epoch_start = std::time::Instant::now();
            rng.shuffle(&mut indices);

            let mut epoch_loss = 0.0;
            let mut n_batches = 0;
            let mut last_grad_norm = 0.0;
            let mut epoch_correct = 0usize;
            let mut epoch_total = 0usize;

            for chunk in indices.chunks(batch_size) {
                let b = chunk.len();
                let mut batch_x = Vec::with_capacity(b * n_features);
                let mut batch_y = Vec::with_capacity(b);
                for &i in chunk {
                    batch_x.extend_from_slice(&train_x[i * n_features..(i + 1) * n_features]);
                    batch_y.push(train_y[i]);
                }

                let logits = net.forward(&batch_x, b, true);
                let (loss, grad) = network::cross_entropy_loss(&logits, &batch_y, b, n_classes);
                epoch_loss += loss;
                n_batches += 1;

                // Compute training accuracy for this mini-batch
                let preds = network::argmax_predictions(&logits, b, n_classes);
                for (p, t) in preds.iter().zip(batch_y.iter()) {
                    if (*p - *t).abs() < f64::EPSILON {
                        epoch_correct += 1;
                    }
                    epoch_total += 1;
                }

                let layer_grads = net.backward(&grad, self.alpha);
                last_grad_norm = callback::compute_grad_norm(&layer_grads);
                optimizer.tick();
                net.apply_gradients(&layer_grads, &mut optimizer);
            }

            let avg_loss = epoch_loss / n_batches as f64;
            self.loss_curve.push(avg_loss);

            // Adjust learning rate based on schedule.
            optimizer.adjust_lr(avg_loss);

            let train_accuracy = if epoch_total > 0 {
                Some(epoch_correct as f64 / epoch_total as f64)
            } else {
                None
            };

            // Early stopping check + validation metrics
            let mut val_loss_epoch = None;
            let mut val_metric_epoch = None;

            if self.early_stopping {
                if let (Some(vx), Some(vy)) = (&val_x, &val_y) {
                    let val_n = vy.len();
                    let val_logits = net.forward(vx, val_n, false);
                    let (val_loss, _) =
                        network::cross_entropy_loss(&val_logits, vy, val_n, n_classes);
                    val_loss_epoch = Some(val_loss);

                    // Validation accuracy
                    let val_preds = network::argmax_predictions(&val_logits, val_n, n_classes);
                    let val_correct = val_preds
                        .iter()
                        .zip(vy.iter())
                        .filter(|(p, t)| (**p - **t).abs() < f64::EPSILON)
                        .count();
                    val_metric_epoch = Some(val_correct as f64 / val_n as f64);

                    if val_loss < best_val_loss - self.tolerance {
                        best_val_loss = val_loss;
                        best_weights = Some(net.save_weights());
                        no_improve = 0;
                    } else {
                        no_improve += 1;
                    }
                }
            } else {
                // Check training loss convergence
                let n = self.loss_curve.len();
                if n >= 2 {
                    let improvement = self.loss_curve[n - 2] - self.loss_curve[n - 1];
                    if improvement.abs() < self.tolerance {
                        no_improve += 1;
                    } else {
                        no_improve = 0;
                    }
                }
            }

            let elapsed = epoch_start.elapsed();
            let metrics = EpochMetrics {
                epoch: epoch_idx,
                train_loss: avg_loss,
                val_loss: val_loss_epoch,
                train_metric: train_accuracy,
                val_metric: val_metric_epoch,
                learning_rate: optimizer.current_lr(),
                grad_norm: last_grad_norm,
                elapsed_ms: elapsed.as_millis() as u64,
            };

            // Invoke user callbacks.
            let mut cb_stop = false;
            for cb in &mut callbacks {
                if cb.on_epoch_end(&metrics) == CallbackAction::Stop {
                    cb_stop = true;
                }
            }

            self.training_history.push(metrics);

            if cb_stop {
                break;
            }

            if no_improve >= self.n_iter_no_change
                && (self.early_stopping || self.loss_curve.len() >= 2)
            {
                break;
            }
        }

        // Notify callbacks that training is done, then put them back.
        for cb in &mut callbacks {
            cb.on_training_end();
        }
        self.callbacks = callbacks;

        // Restore best weights if early stopping found an improvement
        if let Some(ref best) = best_weights {
            net.restore_weights(best);
        }

        // Save fitted state
        self.network_weights = net.save_weights();
        self.network_dims = net.layer_dims();
        self.n_features = n_features;
        self.n_classes = n_classes;
        self.class_labels = class_labels;
        self.fitted = true;

        Ok(())
    }

    /// Predict class labels for input samples.
    ///
    /// `features` is `&[Vec<f64>]` where each inner vec is one sample (row-major).
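    ///
    /// For example, with a two-feature model:
    ///
    /// ```ignore
    /// let preds = clf.predict(&[vec![0.1, 0.2], vec![0.9, 0.8]])?;
    /// ```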
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        let proba = self.predict_proba(features)?;
        let batch = features.len();
        let preds = network::argmax_predictions(&proba, batch, self.n_classes);
        // Map indices back to original class labels
        Ok(preds
            .iter()
            .map(|&i| self.class_labels[i as usize])
            .collect())
    }

    /// Predict class probabilities (softmax output).
    ///
    /// Returns a flat `[batch * n_classes]` row-major probability matrix.
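    ///
    /// A sketch of indexing into the flat output (sample `i`, class `j`):
    ///
    /// ```ignore
    /// let proba = clf.predict_proba(&samples)?;
    /// let p = proba[i * clf.n_classes() + j];
    /// ```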
    pub fn predict_proba(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }

        let batch = features.len();
        if batch == 0 {
            return Ok(Vec::new());
        }

        let n_feat = features[0].len();
        if n_feat != self.n_features {
            return Err(ScryLearnError::ShapeMismatch {
                expected: self.n_features,
                got: n_feat,
            });
        }

        let mut net = self.rebuild_network();
        let x: Vec<f64> = features
            .iter()
            .flat_map(|row| row.iter().copied())
            .collect();
        let logits = net.forward(&x, batch, false);
        Ok(network::softmax(&logits, batch, self.n_classes))
    }

    /// Number of classes discovered during fit.
    pub fn n_classes(&self) -> usize {
        self.n_classes
    }

    /// Number of features the model was trained on.
    pub fn n_features(&self) -> usize {
        self.n_features
    }

    /// Training loss per epoch.
    pub fn loss_curve(&self) -> &[f64] {
        &self.loss_curve
    }

    /// Structured training history with per-epoch metrics.
    ///
    /// Returns `None` if no epochs have been recorded (note that
    /// `partial_fit` does not populate the history).
    pub fn history(&self) -> Option<&TrainingHistory> {
        if self.training_history.is_empty() {
            None
        } else {
            Some(&self.training_history)
        }
    }

    /// Saved network weights (for visualization).
    pub fn weights(&self) -> &[(Vec<f64>, Vec<f64>)] {
        &self.network_weights
    }

    /// Layer dimensions (for visualization).
    pub fn layer_dims(&self) -> &[(usize, usize)] {
        &self.network_dims
    }

    /// Hidden-layer activation function.
    pub fn activation_fn(&self) -> Activation {
        self.activation
    }

    /// Rebuild a Network from saved weights.
    fn rebuild_network(&self) -> Network {
        let mut sizes = Vec::with_capacity(self.network_dims.len() + 1);
        sizes.push(self.network_dims[0].0);
        for &(_, out) in &self.network_dims {
            sizes.push(out);
        }
        let mut net = Network::new_with_dropout(&sizes, self.activation, 0, self.dropout_rate);
        net.restore_weights(&self.network_weights);
        net
    }
}

impl PartialFit for MLPClassifier {
    /// Run one epoch of mini-batch SGD on the given data.
    ///
    /// On the first call, initializes the network architecture from the data
    /// dimensions. Subsequent calls preserve network weights and continue
    /// training.
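    ///
    /// A sketch of streaming training; `stream_of_datasets` is a stand-in
    /// for any iterator of `Dataset` batches, and every class must appear
    /// in the first batch:
    ///
    /// ```ignore
    /// for batch in stream_of_datasets {
    ///     clf.partial_fit(&batch)?;
    /// }
    /// ```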
    fn partial_fit(&mut self, data: &Dataset) -> Result<()> {
        let n_samples = data.n_samples();
        let n_features = data.n_features();
        if n_samples == 0 {
            if self.is_initialized() {
                return Ok(());
            }
            return Err(ScryLearnError::EmptyDataset);
        }

        // Discover classes from this batch.
        let mut batch_labels: Vec<f64> = data.target.clone();
        batch_labels.sort_by(|a, b| a.total_cmp(b));
        batch_labels.dedup();

        if self.is_initialized() {
            if n_features != self.n_features {
                return Err(ScryLearnError::ShapeMismatch {
                    expected: self.n_features,
                    got: n_features,
                });
            }
            // Check for new classes not seen during initialization.
            for &label in &batch_labels {
                if !self
                    .class_labels
                    .iter()
                    .any(|&c| (c - label).abs() < f64::EPSILON)
                {
                    return Err(ScryLearnError::InvalidParameter(format!(
                        "partial_fit encountered new class {label} not seen during \
                         initialization (known classes: {:?}). MLPClassifier cannot add \
                         classes after network initialization — pass all possible classes \
                         in the first batch.",
                        self.class_labels
                    )));
                }
            }
        } else {
            let n_classes = batch_labels.len();
            if n_classes < 2 {
                return Err(ScryLearnError::InvalidParameter(
                    "need at least 2 classes".into(),
                ));
            }

            // Build and initialize network.
            let mut sizes = Vec::with_capacity(self.hidden_layers.len() + 2);
            sizes.push(n_features);
            sizes.extend_from_slice(&self.hidden_layers);
            sizes.push(n_classes);

            let net = Network::new(&sizes, self.activation, self.seed);
            self.network_weights = net.save_weights();
            self.network_dims = net.layer_dims();
            self.n_features = n_features;
            self.n_classes = n_classes;
            self.class_labels = batch_labels;
            self.loss_curve.clear();
        }

        // Build row-major data.
        let x = build_row_major(&data.features, n_samples, n_features);
        let y: Vec<f64> = data
            .target
            .iter()
            .map(|&t| {
                self.class_labels
                    .iter()
                    .position(|&c| (c - t).abs() < f64::EPSILON)
                    // Class membership was validated above, so this cannot fail.
                    .expect("target value must appear in class_labels") as f64
            })
            .collect();

        // Rebuild network from saved weights.
        let mut net = self.rebuild_network();
        let param_sizes = net.param_group_sizes();
        let mut optimizer =
            OptimizerState::new(self.optimizer_kind, self.learning_rate, &param_sizes);

        let batch_size = self.batch_size.min(n_samples);
        let mut rng = FastRng::new(self.seed.wrapping_add(self.loss_curve.len() as u64));
        let mut indices: Vec<usize> = (0..n_samples).collect();

        // One epoch.
        rng.shuffle(&mut indices);
        let mut epoch_loss = 0.0;
        let mut n_batches = 0;

        for chunk in indices.chunks(batch_size) {
            let b = chunk.len();
            let mut batch_x = Vec::with_capacity(b * n_features);
            let mut batch_y = Vec::with_capacity(b);
            for &i in chunk {
                batch_x.extend_from_slice(&x[i * n_features..(i + 1) * n_features]);
                batch_y.push(y[i]);
            }

            let logits = net.forward(&batch_x, b, true);
            let (loss, grad) = network::cross_entropy_loss(&logits, &batch_y, b, self.n_classes);
            epoch_loss += loss;
            n_batches += 1;

            let layer_grads = net.backward(&grad, self.alpha);
            optimizer.tick();
            net.apply_gradients(&layer_grads, &mut optimizer);
        }

        self.loss_curve.push(epoch_loss / n_batches as f64);
        self.network_weights = net.save_weights();
        self.fitted = true;
        Ok(())
    }

    fn is_initialized(&self) -> bool {
        !self.network_weights.is_empty()
    }
}

impl Default for MLPClassifier {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for MLPClassifier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MLPClassifier")
            .field("hidden_layers", &self.hidden_layers)
            .field("activation", &self.activation)
            .field("fitted", &self.fitted)
            .field("n_classes", &self.n_classes)
            .finish()
    }
}

/// Build row-major feature matrix from column-major Dataset.
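/// E.g. `features = [[a, c], [b, d]]` (two feature columns, two samples)
/// flattens to `[a, b, c, d]`.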
fn build_row_major(features: &[Vec<f64>], n_samples: usize, n_features: usize) -> Vec<f64> {
    let mut x = vec![0.0; n_samples * n_features];
    for j in 0..n_features {
        for i in 0..n_samples {
            x[i * n_features + j] = features[j][i];
        }
    }
    x
}

#[cfg(test)]
mod tests {
    use super::*;

    fn xor_dataset() -> Dataset {
        Dataset::new(
            vec![vec![0.0, 0.0, 1.0, 1.0], vec![0.0, 1.0, 0.0, 1.0]],
            vec![0.0, 1.0, 1.0, 0.0],
            vec!["x1".into(), "x2".into()],
            "xor",
        )
    }

    fn linearly_separable() -> Dataset {
        let mut f1 = Vec::new();
        let mut f2 = Vec::new();
        let mut target = Vec::new();
        for i in 0..50 {
            let v = i as f64 * 0.1;
            f1.push(v);
            f2.push(v + 0.5);
            target.push(0.0);
            f1.push(v + 5.0);
            f2.push(v + 5.5);
            target.push(1.0);
        }
        Dataset::new(
            vec![f1, f2],
            target,
            vec!["f1".into(), "f2".into()],
            "class",
        )
    }

    #[test]
    fn not_fitted_error() {
        let clf = MLPClassifier::new();
        let result = clf.predict(&[vec![1.0, 2.0]]);
        assert!(matches!(result, Err(ScryLearnError::NotFitted)));
    }

    #[test]
    fn xor_problem() {
        // XOR requires non-linear separation — proves the network works
        let data = xor_dataset();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[10, 10])
            .learning_rate(0.01)
            .max_iter(1000)
            .batch_size(4)
            .seed(42);
        clf.fit(&data).unwrap();

        let preds = clf
            .predict(&[
                vec![0.0, 0.0],
                vec![0.0, 1.0],
                vec![1.0, 0.0],
                vec![1.0, 1.0],
            ])
            .unwrap();

        let correct = preds
            .iter()
            .zip([0.0, 1.0, 1.0, 0.0].iter())
            .filter(|(p, t)| (**p - **t).abs() < f64::EPSILON)
            .count();

        assert!(
            correct >= 3,
            "XOR: got {correct}/4 correct, preds={preds:?}"
        );
    }

    #[test]
    fn linearly_separable_data() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .max_iter(200)
            .seed(42);
        clf.fit(&data).unwrap();

        let test_x = vec![vec![0.5, 1.0], vec![5.5, 6.0]];
        let preds = clf.predict(&test_x).unwrap();
        assert!((preds[0] - 0.0).abs() < f64::EPSILON);
        assert!((preds[1] - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn early_stopping_halts() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .max_iter(500)
            .early_stopping(true)
            .n_iter_no_change(5)
            .seed(42);
        clf.fit(&data).unwrap();

        // Should have stopped well before 500 epochs
        assert!(
            clf.loss_curve.len() < 500,
            "expected early stop, got {} epochs",
            clf.loss_curve.len()
        );
    }

    #[test]
    fn predict_proba_sums_to_one() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[10])
            .max_iter(50)
            .seed(42);
        clf.fit(&data).unwrap();

        let proba = clf.predict_proba(&[vec![1.0, 1.5]]).unwrap();
        let sum: f64 = proba.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn shape_mismatch_error() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[10])
            .max_iter(10)
            .seed(42);
        clf.fit(&data).unwrap();

        let result = clf.predict(&[vec![1.0, 2.0, 3.0]]); // 3 features, expected 2
        assert!(matches!(result, Err(ScryLearnError::ShapeMismatch { .. })));
    }

    #[test]
    fn loss_decreases() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .max_iter(50)
            .seed(42);
        clf.fit(&data).unwrap();

        let curve = clf.loss_curve();
        assert!(curve.len() >= 2);
        // First loss should be higher than last
        assert!(curve.first().unwrap() > curve.last().unwrap());
    }

    #[test]
    fn partial_fit_is_initialized() {
        let mut clf = MLPClassifier::new();
        assert!(!clf.is_initialized());

        let data = linearly_separable();
        clf.partial_fit(&data).unwrap();
        assert!(clf.is_initialized());
    }

    #[test]
    fn partial_fit_loss_decreases() {
        let data = linearly_separable();
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .learning_rate(0.01)
            .batch_size(32)
            .seed(42);

        // Run 10 partial_fit calls on the same data.
        for _ in 0..10 {
            clf.partial_fit(&data).unwrap();
        }

        let curve = clf.loss_curve();
        assert_eq!(curve.len(), 10);
        // Overall trend: first loss > last loss
        assert!(
            curve.first().unwrap() > curve.last().unwrap(),
            "loss should decrease: first={} last={}",
            curve.first().unwrap(),
            curve.last().unwrap()
        );
    }

    #[test]
    fn partial_fit_classifies_after_batches() {
        let mut clf = MLPClassifier::new()
            .hidden_layers(&[20])
            .learning_rate(0.01)
            .batch_size(32)
            .seed(42);

        let data = linearly_separable();
        for _ in 0..50 {
            clf.partial_fit(&data).unwrap();
        }

        let preds = clf.predict(&[vec![0.5, 1.0], vec![5.5, 6.0]]).unwrap();
        assert!(
            (preds[0] - 0.0).abs() < f64::EPSILON,
            "x=0.5 should be class 0"
        );
        assert!(
            (preds[1] - 1.0).abs() < f64::EPSILON,
            "x=5.5 should be class 1"
        );
    }
}