tensorlogic_train/online_learning.rs

//! Online learning algorithms: Perceptron, Passive-Aggressive, OGD, and FTRL.
//!
//! All algorithms process one sample at a time with O(d) memory where d is
//! the number of features, making them suitable for streaming and large-scale
//! applications where the full dataset cannot be held in memory.
//!
//! # Algorithms
//!
//! - [`Perceptron`]: Classic binary classifier (Rosenblatt 1958)
//! - [`PassiveAggressive`]: PA, PA-I, PA-II variants (Crammer et al. 2006)
//! - [`OnlineGradientDescent`]: OGD with squared/hinge/logistic losses
//! - [`Ftrl`]: Follow-The-Regularized-Leader Proximal (FTRL-Proximal, McMahan et al. 2013)
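//!
//! # Example
//!
//! A minimal sketch of streaming updates (the crate path
//! `tensorlogic_train::online_learning` is assumed from this file's location):
//!
//! ```
//! use tensorlogic_train::online_learning::{OnlineLearner, Perceptron};
//!
//! let mut model = Perceptron::new(2);
//! // Labels must be in {-1, +1}; each call is a single O(d) update.
//! model.update(&[1.0, 0.5], 1.0).unwrap();
//! model.update(&[-1.0, -0.5], -1.0).unwrap();
//! assert_eq!(model.predict(&[2.0, 1.0]).unwrap(), 1.0);
//! ```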

use std::fmt;

// ---------------------------------------------------------------------------
// Core trait and result types
// ---------------------------------------------------------------------------

/// Trait for online learners that update one sample at a time.
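///
/// A sketch of generic code written against the trait; any learner in this
/// module can be substituted for [`Perceptron`]:
///
/// ```
/// use tensorlogic_train::online_learning::{OnlineLearner, Perceptron};
///
/// fn train_on(learner: &mut dyn OnlineLearner, stream: &[(Vec<f64>, f64)]) {
///     for (x, y) in stream {
///         learner.update(x, *y).expect("dimensions match");
///     }
/// }
///
/// let mut p = Perceptron::new(2);
/// train_on(&mut p, &[(vec![1.0, 0.0], 1.0)]);
/// assert_eq!(p.n_updates(), 1);
/// ```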
pub trait OnlineLearner {
    /// Update model on a single (features, label) pair. Returns update stats.
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError>;

    /// Predict class label or regression value for features.
    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError>;

    /// Number of updates seen so far.
    fn n_updates(&self) -> usize;

    /// Current weight vector (without bias).
    fn weights(&self) -> &[f64];
}

/// Result of a single online update step.
#[derive(Debug, Clone)]
pub struct OnlineUpdateResult {
    /// Loss on this sample computed *before* the update.
    pub loss: f64,
    /// L2 norm of the weight change vector (||Δw||); learners with a bias
    /// term include the bias change in this norm.
    pub weight_delta_norm: f64,
    /// For classifiers: whether the prediction was incorrect before the update.
    pub was_mistake: bool,
}

/// Errors that can arise in online learning routines.
#[derive(Debug)]
pub enum OnlineError {
    /// Feature dimensionality does not match model dimensionality.
    DimensionMismatch { expected: usize, got: usize },
    /// A hyperparameter has an invalid value.
    InvalidHyperparameter(String),
    /// Prediction attempted before the model has received any data.
    NotFitted,
}

impl fmt::Display for OnlineError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OnlineError::DimensionMismatch { expected, got } => write!(
                f,
                "dimension mismatch: expected {expected} features, got {got}"
            ),
            OnlineError::InvalidHyperparameter(msg) => {
                write!(f, "invalid hyperparameter: {msg}")
            }
            OnlineError::NotFitted => write!(f, "model has not been fitted yet"),
        }
    }
}

impl std::error::Error for OnlineError {}

// ---------------------------------------------------------------------------
// Running statistics
// ---------------------------------------------------------------------------

/// Cumulative statistics for an online learning session.
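///
/// A small sketch of manual accumulation; the learners in this module keep
/// these stats internally and expose them through their `stats()` methods:
///
/// ```
/// use tensorlogic_train::online_learning::{OnlineStats, OnlineUpdateResult};
///
/// let mut stats = OnlineStats::default();
/// stats.update(&OnlineUpdateResult { loss: 1.0, weight_delta_norm: 0.5, was_mistake: true });
/// stats.update(&OnlineUpdateResult { loss: 0.0, weight_delta_norm: 0.0, was_mistake: false });
/// assert_eq!(stats.mistake_rate(), 0.5);
/// assert_eq!(stats.mean_loss, 0.5);
/// ```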
#[derive(Debug, Clone, Default)]
pub struct OnlineStats {
    /// Total number of update calls.
    pub n_updates: usize,
    /// Total number of incorrect predictions (classification only).
    pub n_mistakes: usize,
    /// Sum of per-sample losses.
    pub cumulative_loss: f64,
    /// Running mean loss: cumulative_loss / n_updates.
    pub mean_loss: f64,
    /// ||w|| after the most recent update.
    pub last_weight_norm: f64,
}

impl OnlineStats {
    /// Fraction of updates that resulted in a mistake (classification).
    ///
    /// Returns 0.0 when no updates have been performed.
    pub fn mistake_rate(&self) -> f64 {
        if self.n_updates == 0 {
            0.0
        } else {
            self.n_mistakes as f64 / self.n_updates as f64
        }
    }

    /// Incorporate the result of one update step into running statistics.
    pub fn update(&mut self, result: &OnlineUpdateResult) {
        self.n_updates += 1;
        if result.was_mistake {
            self.n_mistakes += 1;
        }
        self.cumulative_loss += result.loss;
        self.mean_loss = self.cumulative_loss / self.n_updates as f64;
    }
}

// ---------------------------------------------------------------------------
// Internal helpers
// ---------------------------------------------------------------------------

/// Compute the squared L2 norm of a slice.
#[inline]
fn l2_norm_sq(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum()
}

/// Compute the L2 norm of a slice.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    l2_norm_sq(v).sqrt()
}

/// Dot product of two equal-length slices.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}

/// Sign function returning -1.0, 0.0, or +1.0.
#[inline]
fn sign(x: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}

// ---------------------------------------------------------------------------
// Perceptron
// ---------------------------------------------------------------------------

/// Binary Perceptron classifier (Rosenblatt 1958).
///
/// Labels must be in {−1, +1}. The update rule fires only when a prediction
/// is wrong: `w ← w + η·y·x` and `bias ← bias + η·y`.
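///
/// # Example
///
/// A minimal sketch: the first sample from `w = 0` scores 0 and counts as a
/// mistake, which triggers the update.
///
/// ```
/// use tensorlogic_train::online_learning::{OnlineLearner, Perceptron};
///
/// let mut p = Perceptron::new(2).with_learning_rate(1.0);
/// let r = p.update(&[1.0, 0.5], 1.0).unwrap();
/// assert!(r.was_mistake);
/// assert_eq!(p.weights(), &[1.0, 0.5]);
/// assert_eq!(p.predict(&[1.0, 0.5]).unwrap(), 1.0);
/// ```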
#[derive(Debug, Clone)]
pub struct Perceptron {
    weights: Vec<f64>,
    bias: f64,
    n_updates: usize,
    stats: OnlineStats,
    learning_rate: f64,
}

impl Perceptron {
    /// Create a new Perceptron with `n_features` dimensions and default `η = 1.0`.
    pub fn new(n_features: usize) -> Self {
        Self {
            weights: vec![0.0; n_features],
            bias: 0.0,
            n_updates: 0,
            stats: OnlineStats::default(),
            learning_rate: 1.0,
        }
    }

    /// Set the per-mistake learning rate (η).
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Current bias term.
    pub fn bias(&self) -> f64 {
        self.bias
    }

    /// Reference to running statistics.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Raw score w·x + bias.
    fn score(&self, features: &[f64]) -> f64 {
        dot(&self.weights, features) + self.bias
    }
}

impl OnlineLearner for Perceptron {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        let score = self.score(features);
        let predicted_sign = sign(score);
        let true_sign = sign(label);

        // Hinge-like loss: max(0, -y * score).
        let margin = true_sign * score;
        let loss = if margin <= 0.0 { -margin } else { 0.0 };
        let was_mistake = predicted_sign != true_sign;

        let mut delta_sq = 0.0_f64;

        if was_mistake {
            let eta_y = self.learning_rate * true_sign;
            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
                let delta = eta_y * x;
                delta_sq += delta * delta;
                *w += delta;
            }
            let bias_delta = self.learning_rate * true_sign;
            delta_sq += bias_delta * bias_delta;
            self.bias += bias_delta;
        }

        self.n_updates += 1;

        // Record the step size and refresh running statistics.
        let weight_delta_norm = delta_sq.sqrt();
        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm,
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        Ok(sign(self.score(features)))
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

// ---------------------------------------------------------------------------
// Passive-Aggressive
// ---------------------------------------------------------------------------

/// Selects the PA update variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PAVariant {
    /// Unconstrained PA: τ = loss / ||x||².
    PA,
    /// PA-I: τ = min(C, loss / ||x||²).
    PAI,
    /// PA-II: τ = loss / (||x||² + 1 / (2C)).
    PAII,
}

/// Passive-Aggressive classifier (Crammer et al. 2006).
///
/// Labels must be in {−1, +1}. The PA family uses the hinge loss
/// `ℓ = max(0, 1 − y(w·x + b))` and then computes a closed-form
/// minimal weight update.
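///
/// # Example
///
/// A sketch of the PA-I step from `w = 0`: hinge loss is 1 and ||x||² = 1,
/// so τ = min(C, 1) and the weights move by τ·y·x.
///
/// ```
/// use tensorlogic_train::online_learning::{OnlineLearner, PAVariant, PassiveAggressive};
///
/// let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
///     .with_aggressiveness(0.5)
///     .unwrap();
/// pa.update(&[1.0, 0.0], 1.0).unwrap();
/// assert!((pa.weights()[0] - 0.5).abs() < 1e-12);
/// ```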
#[derive(Debug, Clone)]
pub struct PassiveAggressive {
    weights: Vec<f64>,
    bias: f64,
    n_updates: usize,
    stats: OnlineStats,
    aggressiveness: f64,
    variant: PAVariant,
}

impl PassiveAggressive {
    /// Create a new PA classifier. `variant` controls the update rule.
    pub fn new(n_features: usize, variant: PAVariant) -> Self {
        Self {
            weights: vec![0.0; n_features],
            bias: 0.0,
            n_updates: 0,
            stats: OnlineStats::default(),
            aggressiveness: 1.0,
            variant,
        }
    }

    /// Set the aggressiveness parameter C (must be positive).
    pub fn with_aggressiveness(mut self, c: f64) -> Result<Self, OnlineError> {
        if c <= 0.0 {
            return Err(OnlineError::InvalidHyperparameter(format!(
                "aggressiveness C must be > 0, got {c}"
            )));
        }
        self.aggressiveness = c;
        Ok(self)
    }

    /// Reference to running statistics.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Compute τ (step size) for the current sample.
    fn compute_tau(&self, loss: f64, x_norm_sq: f64) -> f64 {
        match self.variant {
            PAVariant::PA => {
                if x_norm_sq == 0.0 {
                    0.0
                } else {
                    loss / x_norm_sq
                }
            }
            PAVariant::PAI => {
                let tau_unconstrained = if x_norm_sq == 0.0 {
                    0.0
                } else {
                    loss / x_norm_sq
                };
                tau_unconstrained.min(self.aggressiveness)
            }
            PAVariant::PAII => {
                let denom = x_norm_sq + 1.0 / (2.0 * self.aggressiveness);
                if denom == 0.0 {
                    0.0
                } else {
                    loss / denom
                }
            }
        }
    }
}

impl OnlineLearner for PassiveAggressive {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        let score = dot(&self.weights, features) + self.bias;
        let y = sign(label);

        // Hinge loss: max(0, 1 - y * score).
        let margin = y * score;
        let loss = (1.0 - margin).max(0.0);
        let was_mistake = sign(score) != y;

        let x_norm_sq = l2_norm_sq(features);
        let tau = self.compute_tau(loss, x_norm_sq);

        let mut delta_sq = 0.0_f64;
        if tau > 0.0 {
            let tau_y = tau * y;
            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
                let delta = tau_y * x;
                delta_sq += delta * delta;
                *w += delta;
            }
            let bias_delta = tau * y;
            delta_sq += bias_delta * bias_delta;
            self.bias += bias_delta;
        }

        self.n_updates += 1;

        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm: delta_sq.sqrt(),
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        Ok(sign(dot(&self.weights, features) + self.bias))
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

// ---------------------------------------------------------------------------
// Online Gradient Descent
// ---------------------------------------------------------------------------

/// Loss function for [`OnlineGradientDescent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OGDLoss {
    /// Squared loss: ℓ = ½(pred − y)². Gradient: (pred − y)·x.
    Squared,
    /// Hinge loss: ℓ = max(0, 1 − y·score). Gradient: −y·x when margin < 1.
    Hinge,
    /// Logistic loss: ℓ = log(1 + exp(−y·score)). Gradient: −y·σ(−y·score)·x.
    Logistic,
}

/// Online Gradient Descent for convex losses.
///
/// The learning rate schedule is η_t = η_0 / √(t + 1) when `lr_decay > 0`
/// (the value of `lr_decay` only switches the schedule on; its magnitude is
/// not used), otherwise a constant η_0 is used. Optional L2 regularisation
/// applies weight decay at each step.
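///
/// # Example
///
/// A sketch of squared-loss regression with the decaying schedule switched on
/// (the hyperparameter values are illustrative, not recommended defaults):
///
/// ```
/// use tensorlogic_train::online_learning::{OGDLoss, OnlineGradientDescent, OnlineLearner};
///
/// let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared)
///     .with_lr(0.5)
///     .with_lr_decay(1.0);
/// for _ in 0..200 {
///     ogd.update(&[1.0], 2.0).unwrap();
/// }
/// assert!((ogd.predict(&[1.0]).unwrap() - 2.0).abs() < 0.1);
/// ```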
#[derive(Debug, Clone)]
pub struct OnlineGradientDescent {
    weights: Vec<f64>,
    bias: f64,
    n_updates: usize,
    stats: OnlineStats,
    initial_lr: f64,
    lr_decay: f64,
    l2_reg: f64,
    loss: OGDLoss,
}

impl OnlineGradientDescent {
    /// Create a new OGD learner for the given loss function.
    pub fn new(n_features: usize, loss: OGDLoss) -> Self {
        Self {
            weights: vec![0.0; n_features],
            bias: 0.0,
            n_updates: 0,
            stats: OnlineStats::default(),
            initial_lr: 0.1,
            lr_decay: 0.0,
            l2_reg: 0.0,
            loss,
        }
    }

    /// Set the initial learning rate η_0.
    pub fn with_lr(mut self, lr: f64) -> Self {
        self.initial_lr = lr;
        self
    }

    /// Set the L2 regularisation coefficient λ.
    pub fn with_l2(mut self, lambda: f64) -> Self {
        self.l2_reg = lambda;
        self
    }

    /// Enable learning rate decay. Any `decay > 0` switches on η_t = η_0 / √(t + 1);
    /// the magnitude of `decay` itself is not used.
    pub fn with_lr_decay(mut self, decay: f64) -> Self {
        self.lr_decay = decay;
        self
    }

    /// Reference to running statistics.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Effective learning rate at the current step.
    fn current_lr(&self) -> f64 {
        if self.lr_decay > 0.0 {
            self.initial_lr / ((self.n_updates as f64 + 1.0).sqrt())
        } else {
            self.initial_lr
        }
    }

    /// Compute loss and gradient coefficient `g` such that `∂ℓ/∂w = g·x`.
    /// Returns `(loss, grad_coeff, bias_grad)`.
    fn compute_loss_and_grad(&self, features: &[f64], label: f64) -> (f64, f64, f64) {
        let score = dot(&self.weights, features) + self.bias;
        match self.loss {
            OGDLoss::Squared => {
                let diff = score - label;
                let loss = 0.5 * diff * diff;
                (loss, diff, diff)
            }
            OGDLoss::Hinge => {
                let y = sign(label);
                let margin = y * score;
                if margin < 1.0 {
                    let loss = 1.0 - margin;
                    (loss, -y, -y)
                } else {
                    (0.0, 0.0, 0.0)
                }
            }
            OGDLoss::Logistic => {
                let y = sign(label);
                // σ(-y·s) = 1 / (1 + exp(y·s))
                let ys = y * score;
                let sigma_neg = 1.0 / (1.0 + ys.exp()); // σ(-y·s)
                let loss = (1.0 + (-ys).exp()).ln();
                let grad_coeff = -y * sigma_neg;
                (loss, grad_coeff, grad_coeff)
            }
        }
    }
}

impl OnlineLearner for OnlineGradientDescent {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        let (loss, grad_coeff, bias_grad) = self.compute_loss_and_grad(features, label);
        let eta = self.current_lr();

        let was_mistake = match self.loss {
            OGDLoss::Squared => false, // regression: no concept of "mistake"
            OGDLoss::Hinge | OGDLoss::Logistic => {
                let score = dot(&self.weights, features) + self.bias;
                sign(score) != sign(label)
            }
        };

        let mut delta_sq = 0.0_f64;

        // Gradient step + L2 regularisation (weight decay).
        for (w, x) in self.weights.iter_mut().zip(features.iter()) {
            let grad = grad_coeff * x + self.l2_reg * (*w);
            let delta = -eta * grad;
            delta_sq += delta * delta;
            *w += delta;
        }
        // Bias is not regularised.
        let bias_delta = -eta * bias_grad;
        delta_sq += bias_delta * bias_delta;
        self.bias += bias_delta;

        self.n_updates += 1;

        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm: delta_sq.sqrt(),
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        let score = dot(&self.weights, features) + self.bias;
        let prediction = match self.loss {
            OGDLoss::Squared => score,
            OGDLoss::Hinge | OGDLoss::Logistic => sign(score),
        };
        Ok(prediction)
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

// ---------------------------------------------------------------------------
// FTRL-Proximal
// ---------------------------------------------------------------------------

/// Follow-The-Regularized-Leader Proximal, or FTRL-Proximal (McMahan et al. 2013).
///
/// FTRL-Proximal maintains per-feature adaptive learning rates and supports
/// L1 + L2 regularization. L1 induces sparsity: features with |z_i| ≤ l1
/// are zeroed out, which makes FTRL popular for large-scale sparse models.
///
/// Update equations (per coordinate i):
/// ```text
/// g_i       = gradient of logistic loss on this sample
/// z_i      += g_i − (√(n_i + g_i²) − √n_i) / α · w_i
/// n_i      += g_i²
/// if |z_i| ≤ l1:
///     w_i = 0
/// else:
///     w_i = −(z_i − sign(z_i)·l1) / ((β + √n_i) / α + l2)
/// ```
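///
/// # Example
///
/// A sketch of L1-induced sparsity: with a large `l1`, one small gradient
/// leaves |z_0| below the threshold, so the weight stays exactly zero.
///
/// ```
/// use tensorlogic_train::online_learning::{Ftrl, OnlineLearner};
///
/// let mut ftrl = Ftrl::new(2).with_alpha(0.1).with_l1_l2(10.0, 0.0);
/// ftrl.update(&[1.0, 0.0], 1.0).unwrap();
/// assert_eq!(ftrl.weights()[0], 0.0);
/// ```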
#[derive(Debug, Clone)]
pub struct Ftrl {
    weights: Vec<f64>,
    /// Accumulated gradient vector (z in the FTRL paper).
    z: Vec<f64>,
    /// Accumulated squared gradient per feature (n in the FTRL paper).
    n_vec: Vec<f64>,
    n_updates: usize,
    stats: OnlineStats,
    alpha: f64,
    beta: f64,
    l1: f64,
    l2: f64,
}

impl Ftrl {
    /// Create a new FTRL learner with `n_features` dimensions.
    ///
    /// Defaults: α = 0.1, β = 1.0, l1 = 0.0, l2 = 0.0.
    pub fn new(n_features: usize) -> Self {
        Self {
            weights: vec![0.0; n_features],
            z: vec![0.0; n_features],
            n_vec: vec![0.0; n_features],
            n_updates: 0,
            stats: OnlineStats::default(),
            alpha: 0.1,
            beta: 1.0,
            l1: 0.0,
            l2: 0.0,
        }
    }

    /// Set the learning rate α.
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set L1 and L2 regularization coefficients.
    pub fn with_l1_l2(mut self, l1: f64, l2: f64) -> Self {
        self.l1 = l1;
        self.l2 = l2;
        self
    }

    /// Reference to running statistics.
    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Recompute weight from accumulated z and n for coordinate i.
    #[inline]
    fn compute_weight(&self, i: usize) -> f64 {
        let z_i = self.z[i];
        let n_i = self.n_vec[i];
        if z_i.abs() <= self.l1 {
            0.0
        } else {
            let numerator = -(z_i - sign(z_i) * self.l1);
            let denominator = (self.beta + n_i.sqrt()) / self.alpha + self.l2;
            if denominator == 0.0 {
                0.0
            } else {
                numerator / denominator
            }
        }
    }

    /// Compute raw score w·x using on-the-fly weight computation.
    fn score(&self, features: &[f64]) -> f64 {
        features
            .iter()
            .enumerate()
            .map(|(i, x)| self.compute_weight(i) * x)
            .sum::<f64>()
    }

    /// Logistic probability σ(s) = 1 / (1 + e^{-s}).
    #[inline]
    fn sigmoid(s: f64) -> f64 {
        1.0 / (1.0 + (-s).exp())
    }
}

impl OnlineLearner for Ftrl {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        // Sync weights from z/n before computing score.
        for i in 0..n {
            self.weights[i] = self.compute_weight(i);
        }

        let score = dot(&self.weights, features);
        let p = Self::sigmoid(score);

        // FTRL uses logistic loss; label is mapped to {0, 1} for gradient.
        // y_01 = 1 if label > 0 else 0.
        let y_01 = if label > 0.0 { 1.0_f64 } else { 0.0_f64 };
        let grad_scale = p - y_01; // ∂ℓ/∂score = p − y

        // Logistic loss: -y log p - (1-y) log(1-p).
        let loss = if y_01 > 0.0 {
            -p.ln().max(-1e15)
        } else {
            -(1.0 - p).ln().max(-1e15)
        };

        // `label - 0.5` maps either {0, 1} or {-1, +1} labels to a ±1 sign.
        let was_mistake = sign(score) != sign(label - 0.5);

        let old_weights = self.weights.clone();

        // FTRL update per coordinate.
        for (i, &feat_i) in features.iter().enumerate().take(n) {
            let g_i = grad_scale * feat_i;
            let n_i_old = self.n_vec[i];
            let n_i_new = n_i_old + g_i * g_i;

            // σ_i = (√n_i_new − √n_i_old) / α
            let sigma_i = (n_i_new.sqrt() - n_i_old.sqrt()) / self.alpha;

            self.z[i] += g_i - sigma_i * self.weights[i];
            self.n_vec[i] = n_i_new;
            self.weights[i] = self.compute_weight(i);
        }

        let delta_norm = {
            let sq: f64 = self
                .weights
                .iter()
                .zip(old_weights.iter())
                .map(|(w_new, w_old)| {
                    let d = w_new - w_old;
                    d * d
                })
                .sum();
            sq.sqrt()
        };

        self.n_updates += 1;

        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm: delta_norm,
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        let score = self.score(features);
        Ok(sign(score))
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}

// ---------------------------------------------------------------------------
// Batch evaluation helper
// ---------------------------------------------------------------------------

/// Evaluate an online learner sequentially on a dataset.
///
/// When `train = true`, each sample is used to update the model (prequential
/// evaluation: predict first, then learn). When `train = false`, only
/// predictions are made and the model is not updated; in that mode the
/// returned stats track the mistake rate, and every loss is recorded as 0.0.
///
/// Returns `(predictions, stats)` where `predictions[i]` is the prediction
/// made *before* any update on sample `i`.
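///
/// # Example
///
/// A sketch of prequential evaluation on a two-sample stream:
///
/// ```
/// use tensorlogic_train::online_learning::{online_evaluate, Perceptron};
///
/// let mut p = Perceptron::new(2);
/// let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
/// let (preds, stats) = online_evaluate(&mut p, &data, true).unwrap();
/// assert_eq!(preds.len(), 2);
/// assert_eq!(stats.n_updates, 2);
/// ```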
pub fn online_evaluate(
    learner: &mut dyn OnlineLearner,
    data: &[(Vec<f64>, f64)],
    train: bool,
) -> Result<(Vec<f64>, OnlineStats), OnlineError> {
    let mut predictions = Vec::with_capacity(data.len());
    let mut stats = OnlineStats::default();

    for (features, label) in data {
        let pred = learner.predict(features)?;
        predictions.push(pred);

        if train {
            let result = learner.update(features, *label)?;
            stats.update(&result);
        } else {
            // Still track prediction quality without updating.
            let was_mistake = sign(pred) != sign(*label);
            let pseudo_result = OnlineUpdateResult {
                loss: 0.0,
                weight_delta_norm: 0.0,
                was_mistake,
            };
            stats.update(&pseudo_result);
        }
    }

    Ok((predictions, stats))
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Helper utilities
    // -----------------------------------------------------------------------

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // -----------------------------------------------------------------------
    // Perceptron tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_perceptron_zero_init() {
        let p = Perceptron::new(4);
        assert_eq!(p.weights(), &[0.0_f64; 4]);
        assert_eq!(p.bias(), 0.0);
        assert_eq!(p.n_updates(), 0);
    }

    #[test]
    fn test_perceptron_update_on_mistake_positive() {
        // y = +1, w·x = 0 → score=0 → sign=0 → mistake
        let mut p = Perceptron::new(2).with_learning_rate(1.0);
        let x = vec![1.0, 0.5];
        let result = p.update(&x, 1.0).expect("update failed");
        assert!(result.was_mistake);
        // w should now be [1.0, 0.5]
        assert!(approx_eq(p.weights()[0], 1.0, 1e-10));
        assert!(approx_eq(p.weights()[1], 0.5, 1e-10));
        assert!(approx_eq(p.bias(), 1.0, 1e-10));
    }

    #[test]
    fn test_perceptron_no_update_on_correct() {
        let mut p = Perceptron::new(2);
        let x = vec![1.0, 0.0];
        // First update is a mistake (score 0), which sets w so that x scores > 0.
        p.update(&x, 1.0).expect("update");
        let w_after_first = p.weights().to_vec();
        // Second update: w·x + b = 2 > 0, so sign = +1 = y → no update.
        p.update(&x, 1.0).expect("update");
        assert_eq!(p.weights(), w_after_first.as_slice());
    }

    #[test]
    fn test_perceptron_linearly_separable_2d() {
        // 2D points: (+1 when x[0]>0, else -1).
        let data: Vec<(Vec<f64>, f64)> = vec![
            (vec![1.0, 0.2], 1.0),
            (vec![-1.0, 0.3], -1.0),
            (vec![2.0, -0.5], 1.0),
            (vec![-2.0, 0.1], -1.0),
            (vec![0.5, 0.5], 1.0),
            (vec![-0.5, -0.5], -1.0),
            (vec![1.5, -0.1], 1.0),
            (vec![-1.5, 0.4], -1.0),
            (vec![0.8, 0.0], 1.0),
            (vec![-0.8, 0.2], -1.0),
        ];
        let mut p = Perceptron::new(2);
        for _ in 0..20 {
            for (x, y) in &data {
                p.update(x, *y).expect("update");
            }
        }
        // After convergence every point must be correct.
        for (x, y) in &data {
            let pred = p.predict(x).expect("predict");
            assert_eq!(pred, *y, "misclassified {:?} (label {})", x, y);
        }
    }

    #[test]
    fn test_perceptron_n_updates_increments() {
        let mut p = Perceptron::new(2);
        for i in 0..5 {
            p.update(&[1.0, -1.0], 1.0).expect("update");
            assert_eq!(p.n_updates(), i + 1);
        }
    }

    #[test]
    fn test_perceptron_dimension_mismatch() {
        let mut p = Perceptron::new(3);
        let err = p.update(&[1.0, 2.0], 1.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
    }

    // -----------------------------------------------------------------------
    // Passive-Aggressive tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_pa_tau_basic() {
        // PA variant: τ = loss / ||x||²
        // Start from w=0, x=[1,0], y=+1 → score=0, loss = max(0,1-0)=1, ||x||²=1 → τ=1
        let mut pa = PassiveAggressive::new(2, PAVariant::PA);
        let result = pa.update(&[1.0, 0.0], 1.0).expect("update");
        assert!(approx_eq(result.loss, 1.0, 1e-10));
        // w = τ·y·x = 1·1·[1,0] = [1,0]
        assert!(approx_eq(pa.weights()[0], 1.0, 1e-10));
    }

    #[test]
    fn test_pa1_tau_clamped() {
        // PA-I: τ = min(C, loss/||x||²).  Set C=0.3 so τ should be clamped.
        let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
            .with_aggressiveness(0.3)
            .expect("valid C");
        // w=0, x=[1,0], y=+1 → loss=1, ||x||²=1 → unclamped τ=1.0 > C=0.3 → τ=0.3
        let _r = pa.update(&[1.0, 0.0], 1.0).expect("update");
        assert!(approx_eq(pa.weights()[0], 0.3, 1e-10));
    }

    #[test]
    fn test_pa2_tau_formula() {
        // PA-II: τ = loss / (||x||² + 1/(2C)), C=1.0 → denom = 1 + 0.5 = 1.5 → τ=1/1.5
        let mut pa = PassiveAggressive::new(2, PAVariant::PAII)
            .with_aggressiveness(1.0)
            .expect("valid C");
        let _r = pa.update(&[1.0, 0.0], 1.0).expect("update");
        let expected_tau = 1.0 / 1.5;
        assert!(
            approx_eq(pa.weights()[0], expected_tau, 1e-10),
            "expected {expected_tau}, got {}",
            pa.weights()[0]
        );
    }

    #[test]
    fn test_pa_negative_c_returns_err() {
        let res = PassiveAggressive::new(2, PAVariant::PA).with_aggressiveness(-1.0);
        assert!(res.is_err());
    }

    #[test]
    fn test_pa_dimension_mismatch() {
        let mut pa = PassiveAggressive::new(3, PAVariant::PA);
        let err = pa.update(&[1.0], 1.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 1
            })
        ));
    }

    // -----------------------------------------------------------------------
    // OGD tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_ogd_squared_loss_gradient() {
        // w=0, b=0, x=[2.0,0.0], y=3.0 → pred=0, loss=½·9=4.5, grad=(-3)·x → delta = +η·3·x
        let mut ogd = OnlineGradientDescent::new(2, OGDLoss::Squared).with_lr(0.1);
        let result = ogd.update(&[2.0, 0.0], 3.0).expect("update");
        // loss = ½(0-3)² = 4.5
        assert!(approx_eq(result.loss, 4.5, 1e-10));
        // weight[0] += -η * (0 - 3) * 2 = +0.6
        assert!(approx_eq(ogd.weights()[0], 0.6, 1e-10));
    }

    #[test]
    fn test_ogd_hinge_no_update_when_margin_ok() {
        // Manually set up a learner where y*score ≥ 1 → no gradient.
        let mut ogd = OnlineGradientDescent::new(2, OGDLoss::Hinge).with_lr(1.0);
        // Repeated y=+1 updates on x=[10, 0] quickly push y·score well past 1.
        for _ in 0..20 {
            ogd.update(&[10.0, 0.0], 1.0).expect("update");
        }
        let w_before = ogd.weights().to_vec();
        // Now y·score >> 1, no gradient.
        let result = ogd.update(&[10.0, 0.0], 1.0).expect("update");
        assert_eq!(result.loss, 0.0, "expected zero hinge loss");
        assert_eq!(result.weight_delta_norm, 0.0);
        assert_eq!(ogd.weights(), w_before.as_slice());
    }

    #[test]
    fn test_ogd_lr_decay_reduces_lr() {
        // Verify that the effective learning rate decreases over time.
        // current_lr() = initial_lr / sqrt(n_updates + 1) when lr_decay > 0:
        // at t=0 it is 1.0/√1 = 1.0, at t=9 it is 1.0/√10 ≈ 0.316.
        // We use x = [0.0] so that only the bias receives a gradient; for
        // squared loss the bias step is delta_bias = -eta * (score - label).

        let mut ogd_decay = OnlineGradientDescent::new(1, OGDLoss::Squared)
            .with_lr(1.0)
            .with_lr_decay(1.0);

        let mut ogd_nodecay = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(1.0);

        // Both start at w=0, b=0; use x=[0] so weight gets no gradient, only bias does.
        for _ in 0..5 {
            ogd_decay.update(&[0.0], 1.0).expect("update");
            ogd_nodecay.update(&[0.0], 1.0).expect("update");
        }
        // With η_0 = 1.0 the squared-loss bias update is b ← b − η(b − y), so
        // both models reach b = 1 on the very first step; we assert only that
        // the decayed model never makes more progress than the constant-lr one.
        assert!(
            ogd_decay.bias.abs() <= ogd_nodecay.bias.abs() + 1e-9,
            "decaying lr should not exceed constant lr convergence; decay_bias={}, nodecay_bias={}",
            ogd_decay.bias,
            ogd_nodecay.bias
        );

        // Verify n_updates drives the lr: after 9 updates the lr should be < lr at t=0.
        let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared)
            .with_lr(1.0)
            .with_lr_decay(1.0);
        // Drive n_updates to 9 (lr at t=9: 1/√10 ≈ 0.316).
        for _ in 0..9 {
            ogd.update(&[0.0], 0.0).expect("update"); // zero gradient, just increments counter
        }
        let lr_at_t9 = ogd.current_lr();
        assert!(
            lr_at_t9 < 0.5,
            "lr at t=9 should be 1/√10 ≈ 0.316, got {lr_at_t9}"
        );
        assert!(
            approx_eq(lr_at_t9, 1.0 / 10_f64.sqrt(), 1e-10),
            "expected 1/√10, got {lr_at_t9}"
        );
    }

    #[test]
    fn test_ogd_l2_penalises_large_weights() {
        // With l2_reg, the weight should be smaller after many updates than without.
        let mut ogd_no_reg = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.5);
        let mut ogd_l2 = OnlineGradientDescent::new(1, OGDLoss::Squared)
            .with_lr(0.5)
            .with_l2(0.5);

        for _ in 0..30 {
            ogd_no_reg.update(&[1.0], 1.0).expect("update");
            ogd_l2.update(&[1.0], 1.0).expect("update");
        }
        // The regularised model should have smaller weights.
        assert!(
            ogd_l2.weights()[0].abs() < ogd_no_reg.weights()[0].abs(),
            "l2 reg should shrink weights; no_reg={}, l2={}",
            ogd_no_reg.weights()[0],
            ogd_l2.weights()[0]
        );
    }

    #[test]
    fn test_ogd_dimension_mismatch() {
        let mut ogd = OnlineGradientDescent::new(3, OGDLoss::Squared);
        let err = ogd.update(&[1.0, 2.0], 0.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
    }

    // -----------------------------------------------------------------------
    // FTRL tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_ftrl_l1_sparsity() {
        // With l1 > 0, features whose accumulated gradient |z_i| ≤ l1 are zeroed.
        let mut ftrl = Ftrl::new(2).with_alpha(0.1).with_l1_l2(10.0, 0.0);
        // After one small update, z[i] should be well within l1 → w should be 0.
        ftrl.update(&[1.0, 0.0], 1.0).expect("update");
        // After just a few steps against a large l1, weights should remain zero.
        assert_eq!(ftrl.weights()[0], 0.0, "weight should be zero due to L1");
    }

    #[test]
    fn test_ftrl_adaptive_per_feature() {
        // Feature 0 appears frequently, feature 1 rarely → n_vec[0] >> n_vec[1]
        // meaning feature 0's effective lr should be smaller.
        let mut ftrl = Ftrl::new(2).with_alpha(0.1);
        for _ in 0..50 {
            ftrl.update(&[1.0, 0.0], 1.0).expect("update");
        }
        // n_vec[0] should be much larger than n_vec[1] (which stays at 0).
        assert!(ftrl.n_vec[0] > ftrl.n_vec[1]);
    }

    #[test]
    fn test_ftrl_l1_zero_l2_zero_adagrad_like() {
        // With l1=0, l2=0, FTRL reduces to a form of AdaGrad.
        // The weight update should be non-zero after enough iterations.
        let mut ftrl = Ftrl::new(1).with_alpha(1.0).with_l1_l2(0.0, 0.0);
        for _ in 0..10 {
            ftrl.update(&[1.0], 1.0).expect("update");
        }
        // With consistent positive label signals, weight should be positive.
        assert!(
            ftrl.weights()[0] > 0.0,
            "weight should be positive; got {}",
            ftrl.weights()[0]
        );
    }

    #[test]
    fn test_ftrl_dimension_mismatch() {
        let mut ftrl = Ftrl::new(3);
        let err = ftrl.update(&[1.0, 2.0], 1.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
    }

    #[test]
    fn test_ftrl_predict_dimension_mismatch() {
        let ftrl = Ftrl::new(3);
        let err = ftrl.predict(&[1.0]);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 1
            })
        ));
    }

    // -----------------------------------------------------------------------
    // OnlineStats tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_online_stats_mistake_rate_zero_updates() {
        let stats = OnlineStats::default();
        assert_eq!(stats.mistake_rate(), 0.0);
    }

    #[test]
    fn test_online_stats_mistake_rate_computation() {
        let mut stats = OnlineStats::default();
        let mistake = OnlineUpdateResult {
            loss: 1.0,
            weight_delta_norm: 0.5,
            was_mistake: true,
        };
        let correct = OnlineUpdateResult {
            loss: 0.0,
            weight_delta_norm: 0.0,
            was_mistake: false,
        };
        stats.update(&mistake);
        stats.update(&correct);
        stats.update(&mistake);
        // 2 mistakes out of 3.
        assert!(approx_eq(stats.mistake_rate(), 2.0 / 3.0, 1e-10));
    }

    #[test]
    fn test_online_stats_cumulative_loss() {
        let mut stats = OnlineStats::default();
        for loss_val in [0.5, 1.0, 1.5] {
            let r = OnlineUpdateResult {
                loss: loss_val,
                weight_delta_norm: 0.0,
                was_mistake: false,
            };
            stats.update(&r);
        }
        assert!(approx_eq(stats.cumulative_loss, 3.0, 1e-10));
        assert!(approx_eq(stats.mean_loss, 1.0, 1e-10));
    }

    // -----------------------------------------------------------------------
    // online_evaluate tests
    // -----------------------------------------------------------------------

    #[test]
    fn test_online_evaluate_train_true_updates_model() {
        let mut p = Perceptron::new(2);
        let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
        let (preds, _stats) = online_evaluate(&mut p, &data, true).expect("evaluate");
        assert_eq!(preds.len(), 2);
        // After processing both samples, n_updates should be 2.
        assert_eq!(p.n_updates(), 2);
    }

    #[test]
    fn test_online_evaluate_train_false_no_update() {
        let mut p = Perceptron::new(2);
        let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
        let (preds, _stats) = online_evaluate(&mut p, &data, false).expect("evaluate");
        assert_eq!(preds.len(), 2);
        // No updates when train=false.
        assert_eq!(p.n_updates(), 0);
    }

    // -----------------------------------------------------------------------
    // Convergence test
    // -----------------------------------------------------------------------

    #[test]
    fn test_perceptron_converges_linearly_separable_10_samples() {
        let data: Vec<(Vec<f64>, f64)> = vec![
            (vec![2.0, 1.0], 1.0),
            (vec![1.5, 0.8], 1.0),
            (vec![1.0, 0.5], 1.0),
            (vec![0.5, 0.2], 1.0),
            (vec![0.2, 0.1], 1.0),
            (vec![-0.2, -0.1], -1.0),
            (vec![-0.5, -0.3], -1.0),
            (vec![-1.0, -0.5], -1.0),
            (vec![-1.5, -0.7], -1.0),
            (vec![-2.0, -1.0], -1.0),
        ];
        let mut p = Perceptron::new(2);
        // Run multiple passes.
        for _ in 0..50 {
            for (x, y) in &data {
                p.update(x, *y).expect("update");
            }
        }
        let mut correct = 0;
        for (x, y) in &data {
            let pred = p.predict(x).expect("predict");
            if pred == *y {
                correct += 1;
            }
        }
        assert_eq!(
            correct, 10,
            "Perceptron should converge on linearly separable data"
        );
    }

    #[test]
    fn test_pa_converges_linearly_separable() {
        let data: Vec<(Vec<f64>, f64)> = vec![
            (vec![1.0, 0.5], 1.0),
            (vec![-1.0, -0.5], -1.0),
            (vec![2.0, 1.0], 1.0),
            (vec![-2.0, -1.0], -1.0),
        ];
        let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
            .with_aggressiveness(1.0)
            .expect("valid C");
        for _ in 0..30 {
            for (x, y) in &data {
                pa.update(x, *y).expect("update");
            }
        }
        for (x, y) in &data {
            let pred = pa.predict(x).expect("predict");
            assert_eq!(pred, *y);
        }
    }

    #[test]
    fn test_ogd_squared_converges_to_constant() {
        // All labels are 2.0; OGD squared loss should drive w·x toward 2.
        let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.3);
        let x = vec![1.0];
        for _ in 0..200 {
            ogd.update(&x, 2.0).expect("update");
        }
        let pred = ogd.predict(&x).expect("predict");
        assert!(
            approx_eq(pred, 2.0, 0.1),
            "OGD should converge near 2.0, got {pred}"
        );
    }

    #[test]
    fn test_ftrl_n_updates_increments() {
        let mut ftrl = Ftrl::new(2);
        for i in 0..7 {
            ftrl.update(&[1.0, 0.5], 1.0).expect("update");
            assert_eq!(ftrl.n_updates(), i + 1);
        }
    }

    #[test]
    fn test_online_error_display() {
        let e = OnlineError::DimensionMismatch {
            expected: 5,
            got: 3,
        };
        let s = e.to_string();
        assert!(s.contains("5") && s.contains("3"));

        let e2 = OnlineError::InvalidHyperparameter("C must be positive".to_string());
        assert!(e2.to_string().contains("C must be positive"));

        let e3 = OnlineError::NotFitted;
        assert!(e3.to_string().contains("fitted"));
    }
}