use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Predict, Untrained},
    types::Float,
};
use std::collections::HashMap;

/// Configuration for early stopping of iterative training.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Minimum change in the monitored value that counts as an improvement.
    pub min_delta: Float,
    /// Number of consecutive non-improving iterations tolerated before stopping.
    pub patience: usize,
    /// Name of the monitored metric (informational; the caller supplies the value to `update`).
    pub monitor: String,
    /// If `true`, larger monitored values are better (maximize mode).
    pub mode_max: bool,
    /// If `true`, the best parameters seen are restored when stopping early.
    pub restore_best_weights: bool,
}

impl Default for EarlyStoppingConfig {
    fn default() -> Self {
        Self {
            min_delta: 1e-4,
            patience: 10,
            monitor: "loss".to_string(),
            mode_max: false,
            restore_best_weights: true,
        }
    }
}

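/// Tracks a monitored metric across iterations and signals when training
/// should stop. Illustrative sketch of the update protocol (not run as a
/// doctest; assumes the types in this module are in scope):
///
/// ```ignore
/// let mut es = EarlyStopping::new(EarlyStoppingConfig {
///     patience: 2,
///     ..Default::default()
/// });
/// assert!(!es.update(1.0, 0)); // first value becomes the best
/// assert!(!es.update(1.0, 1)); // no improvement: wait = 1
/// assert!(es.update(1.0, 2)); // wait reaches patience -> stop
/// ```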
#[derive(Debug, Clone)]
pub struct EarlyStopping {
    config: EarlyStoppingConfig,
    best_value: Option<Float>,
    best_iteration: usize,
    wait_count: usize,
    should_stop: bool,
}

impl EarlyStopping {
    /// Creates a new tracker from the given configuration.
    pub fn new(config: EarlyStoppingConfig) -> Self {
        Self {
            config,
            best_value: None,
            best_iteration: 0,
            wait_count: 0,
            should_stop: false,
        }
    }

    /// Records a new metric value and returns `true` if training should stop.
    pub fn update(&mut self, value: Float, iteration: usize) -> bool {
        match self.best_value {
            None => {
                // The first observed value becomes the initial best.
                self.best_value = Some(value);
                self.best_iteration = iteration;
                false
            }
            Some(best) => {
                // An improvement must beat the current best by at least min_delta.
                let is_improvement = if self.config.mode_max {
                    value > best + self.config.min_delta
                } else {
                    value < best - self.config.min_delta
                };

                if is_improvement {
                    self.best_value = Some(value);
                    self.best_iteration = iteration;
                    self.wait_count = 0;
                    false
                } else {
                    self.wait_count += 1;
                    if self.wait_count >= self.config.patience {
                        self.should_stop = true;
                        true
                    } else {
                        false
                    }
                }
            }
        }
    }

    /// Returns `true` once the patience budget has been exhausted.
    pub fn should_stop(&self) -> bool {
        self.should_stop
    }

    /// Best monitored value seen so far, if any.
    pub fn best_value(&self) -> Option<Float> {
        self.best_value
    }

    /// Iteration at which the best value was observed.
    pub fn best_iteration(&self) -> usize {
        self.best_iteration
    }
}

/// Configuration for [`WarmStartRegressor`].
#[derive(Debug, Clone)]
pub struct WarmStartRegressorConfig {
    /// Maximum number of training iterations (full passes over the data).
    pub max_iter: usize,
    /// Step size for the stochastic gradient updates.
    pub learning_rate: Float,
    /// L2 (ridge) regularization strength applied to the coefficients.
    pub alpha: Float,
    /// Convergence tolerance on the change in average loss between iterations.
    pub tol: Float,
    /// Optional early-stopping configuration; `None` disables early stopping.
    pub early_stopping: Option<EarlyStoppingConfig>,
    /// If `true`, progress is printed during training.
    pub verbose: bool,
}

impl Default for WarmStartRegressorConfig {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            learning_rate: 0.01,
            alpha: 0.0001,
            tol: 1e-4,
            early_stopping: Some(EarlyStoppingConfig::default()),
            verbose: false,
        }
    }
}

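/// Multi-output linear regressor trained by stochastic gradient descent with
/// an L2 penalty, supporting warm starts: a fitted model can resume training
/// via `continue_training`. Illustrative sketch (not run as a doctest;
/// assumes the types in this module are in scope and `X`, `y` are
/// `Array2<Float>` values with matching row counts):
///
/// ```ignore
/// let model = WarmStartRegressor::new()
///     .max_iter(100)
///     .learning_rate(0.05);
/// let fitted = model.fit(&X.view(), &y.view())?;
/// // Resume from the learned coefficients for 50 more iterations.
/// let fitted = fitted.continue_training(&X.view(), &y.view(), 50)?;
/// let predictions = fitted.predict(&X.view())?;
/// ```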
#[derive(Debug, Clone)]
pub struct WarmStartRegressor<S = Untrained> {
    state: S,
    config: WarmStartRegressorConfig,
}

/// Trained state of a [`WarmStartRegressor`]: learned parameters plus training history.
#[derive(Debug, Clone)]
pub struct WarmStartRegressorTrained {
    /// Learned coefficients with shape `(n_features, n_outputs)`.
    pub coef: Array2<Float>,
    /// Learned intercepts, one per output.
    pub intercept: Array1<Float>,
    /// Number of input features seen during fitting.
    pub n_features: usize,
    /// Number of output targets seen during fitting.
    pub n_outputs: usize,
    /// Total number of training iterations performed so far.
    pub n_iter: usize,
    /// Average loss recorded at each iteration.
    pub loss_history: Vec<Float>,
    /// Lowest average loss observed during training.
    pub best_loss: Float,
    /// Iteration at which the lowest loss was observed.
    pub best_iter: usize,
    /// Snapshot of the coefficients at the best iteration (kept only when early stopping is enabled).
    pub best_coef: Option<Array2<Float>>,
    /// Snapshot of the intercepts at the best iteration (kept only when early stopping is enabled).
    pub best_intercept: Option<Array1<Float>>,
    /// Whether training converged within tolerance before exhausting iterations.
    pub converged: bool,
    /// Configuration used for training.
    pub config: WarmStartRegressorConfig,
}

impl WarmStartRegressor<Untrained> {
    /// Creates an untrained regressor with the default configuration.
    pub fn new() -> Self {
        Self {
            state: Untrained,
            config: WarmStartRegressorConfig::default(),
        }
    }

    /// Replaces the entire configuration.
    pub fn config(mut self, config: WarmStartRegressorConfig) -> Self {
        self.config = config;
        self
    }

    /// Sets the maximum number of training iterations.
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Sets the learning rate for the gradient updates.
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.config.learning_rate = lr;
        self
    }

    /// Enables early stopping with the given configuration.
    pub fn early_stopping(mut self, config: EarlyStoppingConfig) -> Self {
        self.config.early_stopping = Some(config);
        self
    }
}

impl Default for WarmStartRegressor<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

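// Training performs per-sample stochastic gradient descent on the squared
// error with an L2 penalty on the coefficients. For each sample the updates
// are
//     W <- W - lr * (alpha * W - error * x^T)    with error = y - (W^T x + b)
//     b <- b + lr * error
// `loss_history` records the mean squared error per iteration (without the
// penalty term); that quantity drives both the tolerance-based convergence
// check and the optional early stopping.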
impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for WarmStartRegressor<Untrained> {
    type Fitted = WarmStartRegressor<WarmStartRegressorTrained>;

    fn fit(self, X: &ArrayView2<Float>, y: &ArrayView2<Float>) -> SklResult<Self::Fitted> {
        if X.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        let n_samples = X.nrows();
        let n_features = X.ncols();
        let n_outputs = y.ncols();

        // Parameters start at zero; coef has shape (n_features, n_outputs).
        let mut coef = Array2::zeros((n_features, n_outputs));
        let mut intercept = Array1::zeros(n_outputs);

        let mut loss_history = Vec::new();
        let mut best_loss = Float::INFINITY;
        let mut best_iter = 0;
        let mut best_coef = None;
        let mut best_intercept = None;

        let mut early_stopping = self
            .config
            .early_stopping
            .as_ref()
            .map(|cfg| EarlyStopping::new(cfg.clone()));

        let mut converged = false;

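        // Main SGD loop: one full pass over the samples per iteration,
        // updating the parameters after every individual sample.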
        for iter in 0..self.config.max_iter {
            let mut total_loss = 0.0;

            for i in 0..n_samples {
                let x_i = X.row(i);
                let y_i = y.row(i);

                // Prediction for this sample: W^T x + b.
                let pred = coef.t().dot(&x_i) + &intercept;

                let error = &y_i - &pred;
                total_loss += error.mapv(|x| x.powi(2)).sum();

                // Gradient step on the coefficients (squared error plus L2 penalty).
                for j in 0..n_features {
                    for k in 0..n_outputs {
                        let gradient = -error[k] * x_i[j] + self.config.alpha * coef[[j, k]];
                        coef[[j, k]] -= self.config.learning_rate * gradient;
                    }
                }

                // Gradient step on the intercepts (no penalty on the bias).
                for k in 0..n_outputs {
                    intercept[k] += self.config.learning_rate * error[k];
                }
            }

            let avg_loss = total_loss / (n_samples as Float * n_outputs as Float);
            loss_history.push(avg_loss);

            // Track the best loss; snapshot the parameters only when early
            // stopping may need to restore them later.
            if avg_loss < best_loss {
                best_loss = avg_loss;
                best_iter = iter;
                if self.config.early_stopping.is_some() {
                    best_coef = Some(coef.clone());
                    best_intercept = Some(intercept.clone());
                }
            }

            // Tolerance-based convergence: stop when the loss changes by less
            // than `tol` between consecutive iterations.
            if iter > 0 && (loss_history[iter - 1] - avg_loss).abs() < self.config.tol {
                converged = true;
                if self.config.verbose {
                    println!("Converged at iteration {}", iter);
                }
                break;
            }

            if let Some(ref mut es) = early_stopping {
                if es.update(avg_loss, iter) {
                    if self.config.verbose {
                        println!("Early stopping at iteration {}", iter);
                    }
                    break;
                }
            }

            if self.config.verbose && iter % 100 == 0 {
                println!("Iteration {}: loss = {:.6}", iter, avg_loss);
            }
        }

        // Optionally restore the best parameters seen during training.
        if let Some(cfg) = &self.config.early_stopping {
            if cfg.restore_best_weights {
                if let Some(ref best_c) = best_coef {
                    coef = best_c.clone();
                }
                if let Some(ref best_i) = best_intercept {
                    intercept = best_i.clone();
                }
            }
        }

        let config = self.config.clone();
        Ok(WarmStartRegressor {
            state: WarmStartRegressorTrained {
                coef,
                intercept,
                n_features,
                n_outputs,
                n_iter: loss_history.len(),
                loss_history,
                best_loss,
                best_iter,
                best_coef,
                best_intercept,
                converged,
                config: self.config,
            },
            config,
        })
    }
}

impl WarmStartRegressor<WarmStartRegressorTrained> {
    /// Resumes training from the current coefficients for up to
    /// `additional_iterations` more iterations (warm start).
    pub fn continue_training(
        mut self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<Float>,
        additional_iterations: usize,
    ) -> SklResult<Self> {
        if X.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        if X.ncols() != self.state.n_features || y.ncols() != self.state.n_outputs {
            return Err(SklearsError::InvalidInput(
                "Feature or output dimensions do not match".to_string(),
            ));
        }

        let n_samples = X.nrows();

        let mut early_stopping = self
            .state
            .config
            .early_stopping
            .as_ref()
            .map(|cfg| EarlyStopping::new(cfg.clone()));

        for iter in 0..additional_iterations {
            let mut total_loss = 0.0;

            for i in 0..n_samples {
                let x_i = X.row(i);
                let y_i = y.row(i);

                let pred = self.state.coef.t().dot(&x_i) + &self.state.intercept;
                let error = &y_i - &pred;
                total_loss += error.mapv(|x| x.powi(2)).sum();

                // Same per-sample update rule as in `fit`.
                for j in 0..self.state.n_features {
                    for k in 0..self.state.n_outputs {
                        let gradient =
                            -error[k] * x_i[j] + self.state.config.alpha * self.state.coef[[j, k]];
                        self.state.coef[[j, k]] -= self.state.config.learning_rate * gradient;
                    }
                }

                for k in 0..self.state.n_outputs {
                    self.state.intercept[k] += self.state.config.learning_rate * error[k];
                }
            }

            let avg_loss = total_loss / (n_samples as Float * self.state.n_outputs as Float);
            self.state.loss_history.push(avg_loss);

            if avg_loss < self.state.best_loss {
                self.state.best_loss = avg_loss;
                self.state.best_iter = self.state.n_iter + iter;
                if self.state.config.early_stopping.is_some() {
                    self.state.best_coef = Some(self.state.coef.clone());
                    self.state.best_intercept = Some(self.state.intercept.clone());
                }
            }

            let loss_len = self.state.loss_history.len();
            if loss_len > 1 {
                let prev_loss = self.state.loss_history[loss_len - 2];
                if (prev_loss - avg_loss).abs() < self.state.config.tol {
                    self.state.converged = true;
                    break;
                }
            }

            if let Some(ref mut es) = early_stopping {
                if es.update(avg_loss, self.state.n_iter + iter) {
                    break;
                }
            }
        }

        // Count only the iterations actually performed (the loop may exit
        // early), keeping n_iter in sync with loss_history.
        self.state.n_iter = self.state.loss_history.len();
        Ok(self)
    }

    /// Average loss recorded at each completed iteration.
    pub fn loss_history(&self) -> &[Float] {
        &self.state.loss_history
    }

    /// Lowest average loss observed so far.
    pub fn best_loss(&self) -> Float {
        self.state.best_loss
    }

    /// Whether training converged within tolerance.
    pub fn converged(&self) -> bool {
        self.state.converged
    }

    /// Learned coefficients with shape `(n_features, n_outputs)`.
    pub fn coef(&self) -> &Array2<Float> {
        &self.state.coef
    }

    /// Total number of training iterations performed.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }
}

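// Prediction applies the learned linear map row by row: y_hat_i = W^T x_i + b.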
impl Predict<ArrayView2<'_, Float>, Array2<Float>>
    for WarmStartRegressor<WarmStartRegressorTrained>
{
    fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(format!(
                "Expected {} features, got {}",
                self.state.n_features,
                X.ncols()
            )));
        }

        let n_samples = X.nrows();
        let mut predictions = Array2::zeros((n_samples, self.state.n_outputs));

        for i in 0..n_samples {
            let x_i = X.row(i);
            let pred = self.state.coef.t().dot(&x_i) + &self.state.intercept;
            predictions.row_mut(i).assign(&pred);
        }

        Ok(predictions)
    }
}

impl Estimator for WarmStartRegressor<Untrained> {
    type Config = WarmStartRegressorConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

impl Estimator for WarmStartRegressor<WarmStartRegressorTrained> {
    type Config = WarmStartRegressorConfig;
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &self.state.config
    }
}

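/// Caches predictions keyed by a hash of the input matrix. When full, an
/// arbitrary entry is evicted (HashMap iteration order), not the least
/// recently used one. Illustrative sketch (not run as a doctest; assumes a
/// fitted `model` implementing `Predict` and an `Array2<Float>` input `X`):
///
/// ```ignore
/// let mut cache = PredictionCache::new(128);
/// if cache.get(&X.view()).is_none() {
///     let pred = model.predict(&X.view())?;
///     cache.put(&X.view(), pred);
/// }
/// let (hits, misses, hit_rate) = cache.stats();
/// ```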
#[derive(Debug, Clone)]
pub struct PredictionCache {
    /// Cached predictions keyed by input hash.
    cache: HashMap<u64, Array2<Float>>,
    /// Maximum number of entries retained.
    max_size: usize,
    /// Number of cache hits recorded.
    hits: usize,
    /// Number of cache misses recorded.
    misses: usize,
}

impl PredictionCache {
    /// Creates an empty cache holding at most `max_size` entries.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            hits: 0,
            misses: 0,
        }
    }

    /// Returns the cached prediction for `X`, if present, updating hit/miss counts.
    pub fn get(&mut self, X: &ArrayView2<Float>) -> Option<Array2<Float>> {
        let hash = self.hash_input(X);
        if let Some(pred) = self.cache.get(&hash) {
            self.hits += 1;
            Some(pred.clone())
        } else {
            self.misses += 1;
            None
        }
    }

    /// Stores a prediction for `X`, evicting an entry if the cache is full.
    pub fn put(&mut self, X: &ArrayView2<Float>, prediction: Array2<Float>) {
        if self.cache.len() >= self.max_size {
            // Evict an arbitrary entry; HashMap iteration order is unspecified.
            if let Some(first_key) = self.cache.keys().next().copied() {
                self.cache.remove(&first_key);
            }
        }
        let hash = self.hash_input(X);
        self.cache.insert(hash, prediction);
    }

    /// Removes all cached entries.
    pub fn clear(&mut self) {
        self.cache.clear();
    }

    /// Returns `(hits, misses, hit_rate)` for the lifetime of the cache.
    pub fn stats(&self) -> (usize, usize, Float) {
        let total = self.hits + self.misses;
        let hit_rate = if total > 0 {
            self.hits as Float / total as Float
        } else {
            0.0
        };
        (self.hits, self.misses, hit_rate)
    }

    /// Hashes the bit patterns of all elements of `X`.
    fn hash_input(&self, X: &ArrayView2<Float>) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        for &val in X.iter() {
            val.to_bits().hash(&mut hasher);
        }
        hasher.finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_early_stopping_basic() {
        let config = EarlyStoppingConfig {
            min_delta: 0.1,
            patience: 3,
            mode_max: false,
            ..Default::default()
        };

        let mut es = EarlyStopping::new(config);

        assert!(!es.update(1.0, 0)); // first value becomes the best
        assert!(!es.update(0.8, 1)); // improvement beyond min_delta
        assert!(!es.update(0.79, 2)); // within min_delta: wait = 1
        assert!(!es.update(0.78, 3)); // wait = 2
        assert!(es.update(0.77, 4)); // wait = 3 reaches patience -> stop
    }

    #[test]
    fn test_early_stopping_mode_max() {
        let config = EarlyStoppingConfig {
            min_delta: 0.01,
            patience: 2,
            mode_max: true,
            ..Default::default()
        };

        let mut es = EarlyStopping::new(config);

        assert!(!es.update(0.5, 0)); // first value becomes the best
        assert!(!es.update(0.6, 1)); // improvement in maximize mode
        assert!(!es.update(0.59, 2)); // no improvement: wait = 1
        assert!(es.update(0.58, 3)); // wait = 2 reaches patience -> stop
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_warm_start_regressor_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let model = WarmStartRegressor::new().max_iter(100).learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        let predictions = trained.predict(&X.view()).unwrap();

        assert_eq!(predictions.dim(), (3, 2));
        assert!(trained.n_iter() > 0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_warm_start_continue_training() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0]];

        let model = WarmStartRegressor::new().max_iter(10).learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();
        let initial_iter = trained.n_iter();
        let initial_loss = trained.loss_history().last().copied().unwrap();

        let continued = trained.continue_training(&X.view(), &y.view(), 20).unwrap();
        let final_loss = continued.loss_history().last().copied().unwrap();

        assert!(continued.n_iter() > initial_iter);
        // Loose bound: continued training should not blow the loss up.
        assert!(final_loss <= initial_loss + 1.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_warm_start_with_early_stopping() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];
        let y = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]];

        let es_config = EarlyStoppingConfig {
            patience: 5,
            min_delta: 1e-6,
            ..Default::default()
        };

        let model = WarmStartRegressor::new()
            .max_iter(1000)
            .early_stopping(es_config)
            .learning_rate(0.1);

        let trained = model.fit(&X.view(), &y.view()).unwrap();

        // Should stop (via convergence or early stopping) well before max_iter.
        assert!(trained.n_iter() < 1000);
        assert!(trained.best_loss() < Float::INFINITY);
    }

    #[test]
    fn test_prediction_cache_basic() {
        let mut cache = PredictionCache::new(10);

        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let pred = array![[1.0, 2.0], [2.0, 3.0]];

        assert!(cache.get(&X.view()).is_none()); // miss on an empty cache

        cache.put(&X.view(), pred.clone());
        let cached = cache.get(&X.view()).unwrap(); // hit after insertion

        assert_eq!(cached.dim(), pred.dim());
        assert_eq!(cache.stats().0, 1); // one hit
        assert_eq!(cache.stats().1, 1); // one miss
    }

    #[test]
    fn test_prediction_cache_eviction() {
        let mut cache = PredictionCache::new(2);

        let X1 = array![[1.0, 2.0]];
        let X2 = array![[2.0, 3.0]];
        let X3 = array![[3.0, 4.0]];
        let pred = array![[1.0, 2.0]];

        cache.put(&X1.view(), pred.clone());
        cache.put(&X2.view(), pred.clone());
        cache.put(&X3.view(), pred.clone()); // third insert evicts one entry

        assert_eq!(cache.cache.len(), 2);
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = PredictionCache::new(10);

        let X = array![[1.0, 2.0]];
        let pred = array![[1.0, 2.0]];

        cache.get(&X.view()); // miss
        cache.put(&X.view(), pred);
        cache.get(&X.view()); // hit
        cache.get(&X.view()); // hit

        let (hits, misses, hit_rate) = cache.stats();
        assert_eq!(hits, 2);
        assert_eq!(misses, 1);
        assert_abs_diff_eq!(hit_rate, 2.0 / 3.0, epsilon = 1e-6);
    }
}