use std::fmt;

/// A learner that consumes one labelled example at a time.
pub trait OnlineLearner {
    /// Observe one `(features, label)` pair and update the model in place.
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError>;

    /// Predict a label for `features` without mutating the model.
    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError>;

    /// Number of `update` calls made so far.
    fn n_updates(&self) -> usize;

    /// Current weight vector (excluding any bias term).
    fn weights(&self) -> &[f64];
}

/// Per-step outcome reported by [`OnlineLearner::update`].
#[derive(Debug, Clone)]
pub struct OnlineUpdateResult {
    /// Loss incurred on this example before the update.
    pub loss: f64,
    /// L2 norm of the change applied to the weights (and bias).
    pub weight_delta_norm: f64,
    /// Whether the pre-update prediction disagreed with the label.
    pub was_mistake: bool,
}

#[derive(Debug)]
pub enum OnlineError {
    /// The feature vector length does not match the model dimension.
    DimensionMismatch { expected: usize, got: usize },
    /// A hyperparameter was outside its valid range.
    InvalidHyperparameter(String),
    /// The model has not seen any data yet.
    NotFitted,
}

impl fmt::Display for OnlineError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OnlineError::DimensionMismatch { expected, got } => write!(
                f,
                "dimension mismatch: expected {expected} features, got {got}"
            ),
            OnlineError::InvalidHyperparameter(msg) => {
                write!(f, "invalid hyperparameter: {msg}")
            }
            OnlineError::NotFitted => write!(f, "model has not been fitted yet"),
        }
    }
}

impl std::error::Error for OnlineError {}

/// Running aggregates maintained across updates.
#[derive(Debug, Clone, Default)]
pub struct OnlineStats {
    pub n_updates: usize,
    pub n_mistakes: usize,
    pub cumulative_loss: f64,
    pub mean_loss: f64,
    pub last_weight_norm: f64,
}

impl OnlineStats {
    /// Fraction of updates that were mistakes; 0.0 before any updates.
    pub fn mistake_rate(&self) -> f64 {
        if self.n_updates == 0 {
            0.0
        } else {
            self.n_mistakes as f64 / self.n_updates as f64
        }
    }

    /// Fold one update result into the running aggregates.
    pub fn update(&mut self, result: &OnlineUpdateResult) {
        self.n_updates += 1;
        if result.was_mistake {
            self.n_mistakes += 1;
        }
        self.cumulative_loss += result.loss;
        self.mean_loss = self.cumulative_loss / self.n_updates as f64;
    }
}
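
// Hedged sketch: `OnlineStats` aggregates results independently of any model.
// The numbers below are fabricated purely to show the bookkeeping.
#[allow(dead_code)]
fn online_stats_sketch() {
    let mut stats = OnlineStats::default();
    stats.update(&OnlineUpdateResult {
        loss: 1.0,
        weight_delta_norm: 0.5,
        was_mistake: true,
    });
    stats.update(&OnlineUpdateResult {
        loss: 0.0,
        weight_delta_norm: 0.0,
        was_mistake: false,
    });
    assert_eq!(stats.n_updates, 2);
    // One mistake out of two updates; mean loss is (1.0 + 0.0) / 2.
    assert!((stats.mistake_rate() - 0.5).abs() < 1e-12);
    assert!((stats.mean_loss - 0.5).abs() < 1e-12);
}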

/// Squared L2 norm of `v`.
#[inline]
fn l2_norm_sq(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum()
}

/// L2 norm of `v`.
#[inline]
fn l2_norm(v: &[f64]) -> f64 {
    l2_norm_sq(v).sqrt()
}

/// Dot product of two equal-length slices.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b.iter()).map(|(ai, bi)| ai * bi).sum()
}

/// Sign function mapping onto {-1.0, 0.0, 1.0}.
#[inline]
fn sign(x: f64) -> f64 {
    if x > 0.0 {
        1.0
    } else if x < 0.0 {
        -1.0
    } else {
        0.0
    }
}

/// Classic perceptron: updates the weights only on misclassified examples.
#[derive(Debug, Clone)]
pub struct Perceptron {
    weights: Vec<f64>,
    bias: f64,
    n_updates: usize,
    stats: OnlineStats,
    learning_rate: f64,
}

impl Perceptron {
    /// Create a zero-initialised perceptron over `n_features` inputs.
    pub fn new(n_features: usize) -> Self {
        Self {
            weights: vec![0.0; n_features],
            bias: 0.0,
            n_updates: 0,
            stats: OnlineStats::default(),
            learning_rate: 1.0,
        }
    }

    /// Set the learning rate (default 1.0).
    pub fn with_learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    pub fn bias(&self) -> f64 {
        self.bias
    }

    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    fn score(&self, features: &[f64]) -> f64 {
        dot(&self.weights, features) + self.bias
    }
}

impl OnlineLearner for Perceptron {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        let score = self.score(features);
        let predicted_sign = sign(score);
        let true_sign = sign(label);

        // Perceptron loss: max(0, -y * score).
        let margin = true_sign * score;
        let loss = if margin <= 0.0 { -margin } else { 0.0 };
        let was_mistake = predicted_sign != true_sign;

        let mut delta_sq = 0.0_f64;

        if was_mistake {
            let eta_y = self.learning_rate * true_sign;
            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
                let delta = eta_y * x;
                delta_sq += delta * delta;
                *w += delta;
            }
            let bias_delta = self.learning_rate * true_sign;
            delta_sq += bias_delta * bias_delta;
            self.bias += bias_delta;
        }

        self.n_updates += 1;

        let weight_delta_norm = delta_sq.sqrt();
        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm,
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        Ok(sign(self.score(features)))
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}
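
// Minimal usage sketch (illustrative data, not part of the original module):
// stream a few labelled points through a `Perceptron` and inspect the running
// mistake rate. `#[allow(dead_code)]` keeps the sketch out of the public API.
#[allow(dead_code)]
fn perceptron_usage_sketch() {
    let mut p = Perceptron::new(2).with_learning_rate(1.0);
    let stream = [([1.0, 0.5], 1.0), ([-1.0, -0.5], -1.0), ([0.8, 0.1], 1.0)];
    for (x, y) in stream {
        // `update` reports the per-step loss and whether a mistake occurred.
        let _result = p.update(&x, y).expect("feature dimension matches");
    }
    assert!(p.stats().mistake_rate() <= 1.0);
}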

/// Passive-Aggressive update variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PAVariant {
    /// Unbounded step size: tau = loss / ||x||^2.
    PA,
    /// Step size clipped at C: tau = min(C, loss / ||x||^2).
    PAI,
    /// Smoothed step size: tau = loss / (||x||^2 + 1/(2C)).
    PAII,
}

/// Passive-Aggressive linear classifier with hinge loss.
#[derive(Debug, Clone)]
pub struct PassiveAggressive {
    weights: Vec<f64>,
    bias: f64,
    n_updates: usize,
    stats: OnlineStats,
    aggressiveness: f64,
    variant: PAVariant,
}

impl PassiveAggressive {
    pub fn new(n_features: usize, variant: PAVariant) -> Self {
        Self {
            weights: vec![0.0; n_features],
            bias: 0.0,
            n_updates: 0,
            stats: OnlineStats::default(),
            aggressiveness: 1.0,
            variant,
        }
    }

    /// Set the aggressiveness parameter C; must be strictly positive.
    pub fn with_aggressiveness(mut self, c: f64) -> Result<Self, OnlineError> {
        if c <= 0.0 {
            return Err(OnlineError::InvalidHyperparameter(format!(
                "aggressiveness C must be > 0, got {c}"
            )));
        }
        self.aggressiveness = c;
        Ok(self)
    }

    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Closed-form step size for the chosen variant.
    fn compute_tau(&self, loss: f64, x_norm_sq: f64) -> f64 {
        match self.variant {
            PAVariant::PA => {
                if x_norm_sq == 0.0 {
                    0.0
                } else {
                    loss / x_norm_sq
                }
            }
            PAVariant::PAI => {
                let tau_unconstrained = if x_norm_sq == 0.0 {
                    0.0
                } else {
                    loss / x_norm_sq
                };
                tau_unconstrained.min(self.aggressiveness)
            }
            PAVariant::PAII => {
                let denom = x_norm_sq + 1.0 / (2.0 * self.aggressiveness);
                if denom == 0.0 {
                    0.0
                } else {
                    loss / denom
                }
            }
        }
    }
}

impl OnlineLearner for PassiveAggressive {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        let score = dot(&self.weights, features) + self.bias;
        let y = sign(label);

        // Hinge loss: max(0, 1 - y * score).
        let margin = y * score;
        let loss = (1.0 - margin).max(0.0);
        let was_mistake = sign(score) != y;

        let x_norm_sq = l2_norm_sq(features);
        let tau = self.compute_tau(loss, x_norm_sq);

        let mut delta_sq = 0.0_f64;
        if tau > 0.0 {
            let tau_y = tau * y;
            for (w, x) in self.weights.iter_mut().zip(features.iter()) {
                let delta = tau_y * x;
                delta_sq += delta * delta;
                *w += delta;
            }
            let bias_delta = tau * y;
            delta_sq += bias_delta * bias_delta;
            self.bias += bias_delta;
        }

        self.n_updates += 1;

        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm: delta_sq.sqrt(),
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        Ok(sign(dot(&self.weights, features) + self.bias))
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}
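
// Hedged sketch: the same toy stream under all three PA variants. The data is
// illustrative only; the variants differ in how aggressively each step moves
// the weights (unbounded, clipped at C, or smoothed by 1/(2C)).
#[allow(dead_code)]
fn passive_aggressive_usage_sketch() {
    for variant in [PAVariant::PA, PAVariant::PAI, PAVariant::PAII] {
        let mut pa = PassiveAggressive::new(2, variant)
            .with_aggressiveness(0.5)
            .expect("C > 0");
        for (x, y) in [([1.0, 0.0], 1.0), ([-1.0, 0.0], -1.0)] {
            pa.update(&x, y).expect("dims match");
        }
        // Every variant drives the first weight positive on this stream.
        assert!(pa.weights()[0] > 0.0);
    }
}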

/// Loss functions supported by [`OnlineGradientDescent`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OGDLoss {
    /// Squared error, for regression.
    Squared,
    /// Hinge loss, for classification.
    Hinge,
    /// Logistic loss, for classification.
    Logistic,
}

/// Online (stochastic) gradient descent over a linear model.
#[derive(Debug, Clone)]
pub struct OnlineGradientDescent {
    weights: Vec<f64>,
    bias: f64,
    n_updates: usize,
    stats: OnlineStats,
    initial_lr: f64,
    lr_decay: f64,
    l2_reg: f64,
    loss: OGDLoss,
}

impl OnlineGradientDescent {
    pub fn new(n_features: usize, loss: OGDLoss) -> Self {
        Self {
            weights: vec![0.0; n_features],
            bias: 0.0,
            n_updates: 0,
            stats: OnlineStats::default(),
            initial_lr: 0.1,
            lr_decay: 0.0,
            l2_reg: 0.0,
            loss,
        }
    }

    pub fn with_lr(mut self, lr: f64) -> Self {
        self.initial_lr = lr;
        self
    }

    pub fn with_l2(mut self, lambda: f64) -> Self {
        self.l2_reg = lambda;
        self
    }

    /// Any positive `decay` enables a 1/sqrt(t) learning-rate schedule.
    pub fn with_lr_decay(mut self, decay: f64) -> Self {
        self.lr_decay = decay;
        self
    }

    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Learning rate for the next update. Note that `lr_decay` acts as an
    /// on/off flag: any positive value selects eta_t = eta_0 / sqrt(t + 1).
    fn current_lr(&self) -> f64 {
        if self.lr_decay > 0.0 {
            self.initial_lr / ((self.n_updates as f64 + 1.0).sqrt())
        } else {
            self.initial_lr
        }
    }

    /// Returns (loss, weight gradient coefficient, bias gradient); the
    /// gradient with respect to weight i is `grad_coeff * x_i`.
    fn compute_loss_and_grad(&self, features: &[f64], label: f64) -> (f64, f64, f64) {
        let score = dot(&self.weights, features) + self.bias;
        match self.loss {
            OGDLoss::Squared => {
                let diff = score - label;
                let loss = 0.5 * diff * diff;
                (loss, diff, diff)
            }
            OGDLoss::Hinge => {
                let y = sign(label);
                let margin = y * score;
                if margin < 1.0 {
                    let loss = 1.0 - margin;
                    (loss, -y, -y)
                } else {
                    (0.0, 0.0, 0.0)
                }
            }
            OGDLoss::Logistic => {
                let y = sign(label);
                let ys = y * score;
                // sigma(-ys) = 1 / (1 + e^{ys}); loss = ln(1 + e^{-ys}).
                let sigma_neg = 1.0 / (1.0 + ys.exp());
                let loss = (1.0 + (-ys).exp()).ln();
                let grad_coeff = -y * sigma_neg;
                (loss, grad_coeff, grad_coeff)
            }
        }
    }
}

impl OnlineLearner for OnlineGradientDescent {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        let (loss, grad_coeff, bias_grad) = self.compute_loss_and_grad(features, label);
        let eta = self.current_lr();

        // Mistakes are only meaningful for the classification losses.
        let was_mistake = match self.loss {
            OGDLoss::Squared => false,
            OGDLoss::Hinge | OGDLoss::Logistic => {
                let score = dot(&self.weights, features) + self.bias;
                sign(score) != sign(label)
            }
        };

        let mut delta_sq = 0.0_f64;

        for (w, x) in self.weights.iter_mut().zip(features.iter()) {
            // L2 regularisation contributes lambda * w to the gradient.
            let grad = grad_coeff * x + self.l2_reg * (*w);
            let delta = -eta * grad;
            delta_sq += delta * delta;
            *w += delta;
        }
        // The bias is deliberately left unregularised.
        let bias_delta = -eta * bias_grad;
        delta_sq += bias_delta * bias_delta;
        self.bias += bias_delta;

        self.n_updates += 1;

        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm: delta_sq.sqrt(),
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        let score = dot(&self.weights, features) + self.bias;
        let prediction = match self.loss {
            OGDLoss::Squared => score,
            OGDLoss::Hinge | OGDLoss::Logistic => sign(score),
        };
        Ok(prediction)
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}
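
// Hedged sketch: OGD with squared loss as a streaming least-squares fit. The
// data is fabricated from y = 2x + 1; with a small constant learning rate the
// weights converge to the interpolating line.
#[allow(dead_code)]
fn ogd_regression_sketch() {
    let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.1);
    for _ in 0..500 {
        for x in [0.0_f64, 1.0, 2.0] {
            ogd.update(&[x], 2.0 * x + 1.0).expect("dims match");
        }
    }
    // With squared loss, `predict` returns the raw score rather than a sign.
    let pred = ogd.predict(&[1.0]).expect("dims match");
    assert!((pred - 3.0).abs() < 0.1, "should approach y = 2x + 1 at x = 1");
}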

/// FTRL-Proximal: follow-the-regularised-leader with per-coordinate adaptive
/// learning rates and L1/L2 regularisation. Weights are derived lazily from
/// the accumulators `z` and `n_vec`.
#[derive(Debug, Clone)]
pub struct Ftrl {
    weights: Vec<f64>,
    /// Per-coordinate accumulator of gradients minus proximal corrections.
    z: Vec<f64>,
    /// Per-coordinate sum of squared gradients.
    n_vec: Vec<f64>,
    n_updates: usize,
    stats: OnlineStats,
    alpha: f64,
    beta: f64,
    l1: f64,
    l2: f64,
}

impl Ftrl {
    pub fn new(n_features: usize) -> Self {
        Self {
            weights: vec![0.0; n_features],
            z: vec![0.0; n_features],
            n_vec: vec![0.0; n_features],
            n_updates: 0,
            stats: OnlineStats::default(),
            alpha: 0.1,
            beta: 1.0,
            l1: 0.0,
            l2: 0.0,
        }
    }

    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.alpha = alpha;
        self
    }

    pub fn with_l1_l2(mut self, l1: f64, l2: f64) -> Self {
        self.l1 = l1;
        self.l2 = l2;
        self
    }

    pub fn stats(&self) -> &OnlineStats {
        &self.stats
    }

    /// Closed-form weight for coordinate `i`; exactly zero when |z_i| <= l1.
    #[inline]
    fn compute_weight(&self, i: usize) -> f64 {
        let z_i = self.z[i];
        let n_i = self.n_vec[i];
        if z_i.abs() <= self.l1 {
            0.0
        } else {
            let numerator = -(z_i - sign(z_i) * self.l1);
            let denominator = (self.beta + n_i.sqrt()) / self.alpha + self.l2;
            if denominator == 0.0 {
                0.0
            } else {
                numerator / denominator
            }
        }
    }

    fn score(&self, features: &[f64]) -> f64 {
        features
            .iter()
            .enumerate()
            .map(|(i, x)| self.compute_weight(i) * x)
            .sum::<f64>()
    }

    #[inline]
    fn sigmoid(s: f64) -> f64 {
        1.0 / (1.0 + (-s).exp())
    }
}

impl OnlineLearner for Ftrl {
    fn update(&mut self, features: &[f64], label: f64) -> Result<OnlineUpdateResult, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }

        // Materialise the lazy weights before computing the score.
        for i in 0..n {
            self.weights[i] = self.compute_weight(i);
        }

        let score = dot(&self.weights, features);
        let p = Self::sigmoid(score);

        // Map the label onto {0, 1} so both {-1, +1} and {0, 1} labels work.
        let y_01 = if label > 0.0 { 1.0_f64 } else { 0.0_f64 };
        let grad_scale = p - y_01;
        // Log loss, clamped to avoid infinities when p saturates.
        let loss = if y_01 > 0.0 {
            -p.ln().max(-1e15)
        } else {
            -(1.0 - p).ln().max(-1e15)
        };

        let was_mistake = sign(score) != sign(label - 0.5);
        let old_weights = self.weights.clone();

        for (i, &feat_i) in features.iter().enumerate().take(n) {
            let g_i = grad_scale * feat_i;
            let n_i_old = self.n_vec[i];
            let n_i_new = n_i_old + g_i * g_i;

            // Per-coordinate learning-rate change used by the proximal term.
            let sigma_i = (n_i_new.sqrt() - n_i_old.sqrt()) / self.alpha;

            self.z[i] += g_i - sigma_i * self.weights[i];
            self.n_vec[i] = n_i_new;
            self.weights[i] = self.compute_weight(i);
        }

        let delta_norm = {
            let sq: f64 = self
                .weights
                .iter()
                .zip(old_weights.iter())
                .map(|(w_new, w_old)| {
                    let d = w_new - w_old;
                    d * d
                })
                .sum();
            sq.sqrt()
        };

        self.n_updates += 1;

        let result = OnlineUpdateResult {
            loss,
            weight_delta_norm: delta_norm,
            was_mistake,
        };
        self.stats.update(&result);
        self.stats.last_weight_norm = l2_norm(&self.weights);

        Ok(result)
    }

    fn predict(&self, features: &[f64]) -> Result<f64, OnlineError> {
        let n = self.weights.len();
        if features.len() != n {
            return Err(OnlineError::DimensionMismatch {
                expected: n,
                got: features.len(),
            });
        }
        let score = self.score(features);
        Ok(sign(score))
    }

    fn n_updates(&self) -> usize {
        self.n_updates
    }

    fn weights(&self) -> &[f64] {
        &self.weights
    }
}
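
// Hedged sketch: FTRL on a fabricated two-feature stream where only the first
// feature is informative. With a moderate L1 penalty the second weight stays
// exactly zero, which is the hallmark of FTRL-Proximal sparsity.
#[allow(dead_code)]
fn ftrl_sparsity_sketch() {
    let mut ftrl = Ftrl::new(2).with_alpha(0.5).with_l1_l2(0.1, 0.0);
    for _ in 0..100 {
        ftrl.update(&[1.0, 0.0], 1.0).expect("dims match");
        ftrl.update(&[-1.0, 0.0], -1.0).expect("dims match");
    }
    assert!(ftrl.weights()[0] > 0.0, "informative weight is active");
    // Feature 1 never fires, so z[1] stays at 0 and |z| <= l1 prunes it.
    assert_eq!(ftrl.weights()[1], 0.0, "uninformative weight pruned by L1");
}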

/// Run `learner` over `data` in sequence. With `train == true` this performs
/// progressive validation: each example is predicted first, then used for an
/// update. With `train == false` the model is left untouched and the returned
/// stats only track mistakes (losses are reported as zero).
pub fn online_evaluate(
    learner: &mut dyn OnlineLearner,
    data: &[(Vec<f64>, f64)],
    train: bool,
) -> Result<(Vec<f64>, OnlineStats), OnlineError> {
    let mut predictions = Vec::with_capacity(data.len());
    let mut stats = OnlineStats::default();

    for (features, label) in data {
        let pred = learner.predict(features)?;
        predictions.push(pred);

        if train {
            let result = learner.update(features, *label)?;
            stats.update(&result);
        } else {
            let was_mistake = sign(pred) != sign(*label);
            let pseudo_result = OnlineUpdateResult {
                loss: 0.0,
                weight_delta_norm: 0.0,
                was_mistake,
            };
            stats.update(&pseudo_result);
        }
    }

    Ok((predictions, stats))
}
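
// Hedged sketch: progressive validation of a perceptron on a toy stream. Any
// `OnlineLearner` works here, since `online_evaluate` takes a trait object.
#[allow(dead_code)]
fn online_evaluate_sketch() -> Result<(), OnlineError> {
    let data = vec![
        (vec![1.0, 0.0], 1.0),
        (vec![-1.0, 0.0], -1.0),
        (vec![0.5, 0.5], 1.0),
    ];
    let mut p = Perceptron::new(2);
    let (predictions, stats) = online_evaluate(&mut p, &data, true)?;
    assert_eq!(predictions.len(), data.len());
    // The first prediction is made before any update; a zero-initialised
    // model scores 0.0, and sign(0.0) == 0.0.
    assert_eq!(predictions[0], 0.0);
    let _ = stats.mistake_rate();
    Ok(())
}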

#[cfg(test)]
mod tests {
    use super::*;

    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // --- Perceptron ---

    #[test]
    fn test_perceptron_zero_init() {
        let p = Perceptron::new(4);
        assert_eq!(p.weights(), &[0.0_f64; 4]);
        assert_eq!(p.bias(), 0.0);
        assert_eq!(p.n_updates(), 0);
    }

    #[test]
    fn test_perceptron_update_on_mistake_positive() {
        let mut p = Perceptron::new(2).with_learning_rate(1.0);
        let x = vec![1.0, 0.5];
        let result = p.update(&x, 1.0).expect("update failed");
        assert!(result.was_mistake);
        assert!(approx_eq(p.weights()[0], 1.0, 1e-10));
        assert!(approx_eq(p.weights()[1], 0.5, 1e-10));
        assert!(approx_eq(p.bias(), 1.0, 1e-10));
    }

    #[test]
    fn test_perceptron_no_update_on_correct() {
        let mut p = Perceptron::new(2);
        let x = vec![1.0, 0.0];
        p.update(&x, 1.0).expect("update");
        let w_after_first = p.weights().to_vec();
        p.update(&x, 1.0).expect("update");
        assert_eq!(p.weights(), w_after_first.as_slice());
    }

    #[test]
    fn test_perceptron_linearly_separable_2d() {
        let data: Vec<(Vec<f64>, f64)> = vec![
            (vec![1.0, 0.2], 1.0),
            (vec![-1.0, 0.3], -1.0),
            (vec![2.0, -0.5], 1.0),
            (vec![-2.0, 0.1], -1.0),
            (vec![0.5, 0.5], 1.0),
            (vec![-0.5, -0.5], -1.0),
            (vec![1.5, -0.1], 1.0),
            (vec![-1.5, 0.4], -1.0),
            (vec![0.8, 0.0], 1.0),
            (vec![-0.8, 0.2], -1.0),
        ];
        let mut p = Perceptron::new(2);
        for _ in 0..20 {
            for (x, y) in &data {
                p.update(x, *y).expect("update");
            }
        }
        for (x, y) in &data {
            let pred = p.predict(x).expect("predict");
            assert_eq!(pred, *y, "misclassified {:?} (label {})", x, y);
        }
    }

    #[test]
    fn test_perceptron_n_updates_increments() {
        let mut p = Perceptron::new(2);
        for i in 0..5 {
            p.update(&[1.0, -1.0], 1.0).expect("update");
            assert_eq!(p.n_updates(), i + 1);
        }
    }

    #[test]
    fn test_perceptron_dimension_mismatch() {
        let mut p = Perceptron::new(3);
        let err = p.update(&[1.0, 2.0], 1.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
    }

    // --- PassiveAggressive ---

    #[test]
    fn test_pa_tau_basic() {
        let mut pa = PassiveAggressive::new(2, PAVariant::PA);
        let result = pa.update(&[1.0, 0.0], 1.0).expect("update");
        assert!(approx_eq(result.loss, 1.0, 1e-10));
        assert!(approx_eq(pa.weights()[0], 1.0, 1e-10));
    }

    #[test]
    fn test_pa1_tau_clamped() {
        let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
            .with_aggressiveness(0.3)
            .expect("valid C");
        let _r = pa.update(&[1.0, 0.0], 1.0).expect("update");
        assert!(approx_eq(pa.weights()[0], 0.3, 1e-10));
    }

    #[test]
    fn test_pa2_tau_formula() {
        let mut pa = PassiveAggressive::new(2, PAVariant::PAII)
            .with_aggressiveness(1.0)
            .expect("valid C");
        let _r = pa.update(&[1.0, 0.0], 1.0).expect("update");
        // tau = loss / (||x||^2 + 1/(2C)) = 1 / (1 + 0.5).
        let expected_tau = 1.0 / 1.5;
        assert!(
            approx_eq(pa.weights()[0], expected_tau, 1e-10),
            "expected {expected_tau}, got {}",
            pa.weights()[0]
        );
    }

    #[test]
    fn test_pa_negative_c_returns_err() {
        let res = PassiveAggressive::new(2, PAVariant::PA).with_aggressiveness(-1.0);
        assert!(res.is_err());
    }

    #[test]
    fn test_pa_dimension_mismatch() {
        let mut pa = PassiveAggressive::new(3, PAVariant::PA);
        let err = pa.update(&[1.0], 1.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 1
            })
        ));
    }

    // --- OnlineGradientDescent ---

    #[test]
    fn test_ogd_squared_loss_gradient() {
        let mut ogd = OnlineGradientDescent::new(2, OGDLoss::Squared).with_lr(0.1);
        // diff = -3, loss = 0.5 * 9 = 4.5; w0 += -0.1 * (-3 * 2) = 0.6.
        let result = ogd.update(&[2.0, 0.0], 3.0).expect("update");
        assert!(approx_eq(result.loss, 4.5, 1e-10));
        assert!(approx_eq(ogd.weights()[0], 0.6, 1e-10));
    }

    #[test]
    fn test_ogd_hinge_no_update_when_margin_ok() {
        let mut ogd = OnlineGradientDescent::new(2, OGDLoss::Hinge).with_lr(1.0);
        for _ in 0..20 {
            ogd.update(&[10.0, 0.0], 1.0).expect("update");
        }
        let w_before = ogd.weights().to_vec();
        let result = ogd.update(&[10.0, 0.0], 1.0).expect("update");
        assert_eq!(result.loss, 0.0, "expected zero hinge loss");
        assert_eq!(result.weight_delta_norm, 0.0);
        assert_eq!(ogd.weights(), w_before.as_slice());
    }

    #[test]
    fn test_ogd_lr_decay_reduces_lr() {
        let mut ogd_decay = OnlineGradientDescent::new(1, OGDLoss::Squared)
            .with_lr(1.0)
            .with_lr_decay(1.0);

        let mut ogd_nodecay = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(1.0);

        for _ in 0..5 {
            ogd_decay.update(&[0.0], 1.0).expect("update");
            ogd_nodecay.update(&[0.0], 1.0).expect("update");
        }
        assert!(
            ogd_decay.bias.abs() <= ogd_nodecay.bias.abs() + 1e-9,
            "decaying lr should not exceed constant lr convergence; decay_bias={}, nodecay_bias={}",
            ogd_decay.bias,
            ogd_nodecay.bias
        );

        let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared)
            .with_lr(1.0)
            .with_lr_decay(1.0);
        for _ in 0..9 {
            ogd.update(&[0.0], 0.0).expect("update");
        }
        let lr_at_t9 = ogd.current_lr();
        assert!(
            lr_at_t9 < 0.5,
            "lr at t=9 should be 1/√10 ≈ 0.316, got {lr_at_t9}"
        );
        assert!(
            approx_eq(lr_at_t9, 1.0 / 10_f64.sqrt(), 1e-10),
            "expected 1/√10, got {lr_at_t9}"
        );
    }

    #[test]
    fn test_ogd_l2_penalises_large_weights() {
        let mut ogd_no_reg = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.5);
        let mut ogd_l2 = OnlineGradientDescent::new(1, OGDLoss::Squared)
            .with_lr(0.5)
            .with_l2(0.5);

        for _ in 0..30 {
            ogd_no_reg.update(&[1.0], 1.0).expect("update");
            ogd_l2.update(&[1.0], 1.0).expect("update");
        }
        assert!(
            ogd_l2.weights()[0].abs() < ogd_no_reg.weights()[0].abs(),
            "l2 reg should shrink weights; no_reg={}, l2={}",
            ogd_no_reg.weights()[0],
            ogd_l2.weights()[0]
        );
    }

    #[test]
    fn test_ogd_dimension_mismatch() {
        let mut ogd = OnlineGradientDescent::new(3, OGDLoss::Squared);
        let err = ogd.update(&[1.0, 2.0], 0.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
    }

    // --- Ftrl ---

    #[test]
    fn test_ftrl_l1_sparsity() {
        let mut ftrl = Ftrl::new(2).with_alpha(0.1).with_l1_l2(10.0, 0.0);
        ftrl.update(&[1.0, 0.0], 1.0).expect("update");
        assert_eq!(ftrl.weights()[0], 0.0, "weight should be zero due to L1");
    }

    #[test]
    fn test_ftrl_adaptive_per_feature() {
        let mut ftrl = Ftrl::new(2).with_alpha(0.1);
        for _ in 0..50 {
            ftrl.update(&[1.0, 0.0], 1.0).expect("update");
        }
        assert!(ftrl.n_vec[0] > ftrl.n_vec[1]);
    }

    #[test]
    fn test_ftrl_l1_zero_l2_zero_adagrad_like() {
        let mut ftrl = Ftrl::new(1).with_alpha(1.0).with_l1_l2(0.0, 0.0);
        for _ in 0..10 {
            ftrl.update(&[1.0], 1.0).expect("update");
        }
        assert!(
            ftrl.weights()[0] > 0.0,
            "weight should be positive; got {}",
            ftrl.weights()[0]
        );
    }

    #[test]
    fn test_ftrl_dimension_mismatch() {
        let mut ftrl = Ftrl::new(3);
        let err = ftrl.update(&[1.0, 2.0], 1.0);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 2
            })
        ));
    }

    #[test]
    fn test_ftrl_predict_dimension_mismatch() {
        let ftrl = Ftrl::new(3);
        let err = ftrl.predict(&[1.0]);
        assert!(matches!(
            err,
            Err(OnlineError::DimensionMismatch {
                expected: 3,
                got: 1
            })
        ));
    }

    // --- OnlineStats ---

    #[test]
    fn test_online_stats_mistake_rate_zero_updates() {
        let stats = OnlineStats::default();
        assert_eq!(stats.mistake_rate(), 0.0);
    }

    #[test]
    fn test_online_stats_mistake_rate_computation() {
        let mut stats = OnlineStats::default();
        let mistake = OnlineUpdateResult {
            loss: 1.0,
            weight_delta_norm: 0.5,
            was_mistake: true,
        };
        let correct = OnlineUpdateResult {
            loss: 0.0,
            weight_delta_norm: 0.0,
            was_mistake: false,
        };
        stats.update(&mistake);
        stats.update(&correct);
        stats.update(&mistake);
        assert!(approx_eq(stats.mistake_rate(), 2.0 / 3.0, 1e-10));
    }

    #[test]
    fn test_online_stats_cumulative_loss() {
        let mut stats = OnlineStats::default();
        for loss_val in [0.5, 1.0, 1.5] {
            let r = OnlineUpdateResult {
                loss: loss_val,
                weight_delta_norm: 0.0,
                was_mistake: false,
            };
            stats.update(&r);
        }
        assert!(approx_eq(stats.cumulative_loss, 3.0, 1e-10));
        assert!(approx_eq(stats.mean_loss, 1.0, 1e-10));
    }

    // --- online_evaluate ---

    #[test]
    fn test_online_evaluate_train_true_updates_model() {
        let mut p = Perceptron::new(2);
        let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
        let (preds, _stats) = online_evaluate(&mut p, &data, true).expect("evaluate");
        assert_eq!(preds.len(), 2);
        assert_eq!(p.n_updates(), 2);
    }

    #[test]
    fn test_online_evaluate_train_false_no_update() {
        let mut p = Perceptron::new(2);
        let data = vec![(vec![1.0, 0.0], 1.0), (vec![-1.0, 0.0], -1.0)];
        let (preds, _stats) = online_evaluate(&mut p, &data, false).expect("evaluate");
        assert_eq!(preds.len(), 2);
        assert_eq!(p.n_updates(), 0);
    }

    // --- Convergence and error display ---

    #[test]
    fn test_perceptron_converges_linearly_separable_10_samples() {
        let data: Vec<(Vec<f64>, f64)> = vec![
            (vec![2.0, 1.0], 1.0),
            (vec![1.5, 0.8], 1.0),
            (vec![1.0, 0.5], 1.0),
            (vec![0.5, 0.2], 1.0),
            (vec![0.2, 0.1], 1.0),
            (vec![-0.2, -0.1], -1.0),
            (vec![-0.5, -0.3], -1.0),
            (vec![-1.0, -0.5], -1.0),
            (vec![-1.5, -0.7], -1.0),
            (vec![-2.0, -1.0], -1.0),
        ];
        let mut p = Perceptron::new(2);
        for _ in 0..50 {
            for (x, y) in &data {
                p.update(x, *y).expect("update");
            }
        }
        let mut correct = 0;
        for (x, y) in &data {
            let pred = p.predict(x).expect("predict");
            if pred == *y {
                correct += 1;
            }
        }
        assert_eq!(
            correct, 10,
            "Perceptron should converge on linearly separable data"
        );
    }

    #[test]
    fn test_pa_converges_linearly_separable() {
        let data: Vec<(Vec<f64>, f64)> = vec![
            (vec![1.0, 0.5], 1.0),
            (vec![-1.0, -0.5], -1.0),
            (vec![2.0, 1.0], 1.0),
            (vec![-2.0, -1.0], -1.0),
        ];
        let mut pa = PassiveAggressive::new(2, PAVariant::PAI)
            .with_aggressiveness(1.0)
            .expect("valid C");
        for _ in 0..30 {
            for (x, y) in &data {
                pa.update(x, *y).expect("update");
            }
        }
        for (x, y) in &data {
            let pred = pa.predict(x).expect("predict");
            assert_eq!(pred, *y);
        }
    }

    #[test]
    fn test_ogd_squared_converges_to_constant() {
        let mut ogd = OnlineGradientDescent::new(1, OGDLoss::Squared).with_lr(0.3);
        let x = vec![1.0];
        for _ in 0..200 {
            ogd.update(&x, 2.0).expect("update");
        }
        let pred = ogd.predict(&x).expect("predict");
        assert!(
            approx_eq(pred, 2.0, 0.1),
            "OGD should converge near 2.0, got {pred}"
        );
    }

    #[test]
    fn test_ftrl_n_updates_increments() {
        let mut ftrl = Ftrl::new(2);
        for i in 0..7 {
            ftrl.update(&[1.0, 0.5], 1.0).expect("update");
            assert_eq!(ftrl.n_updates(), i + 1);
        }
    }

    #[test]
    fn test_online_error_display() {
        let e = OnlineError::DimensionMismatch {
            expected: 5,
            got: 3,
        };
        let s = e.to_string();
        assert!(s.contains("5") && s.contains("3"));

        let e2 = OnlineError::InvalidHyperparameter("C must be positive".to_string());
        assert!(e2.to_string().contains("C must be positive"));

        let e3 = OnlineError::NotFitted;
        assert!(e3.to_string().contains("fitted"));
    }
}