#[cfg(feature = "early-stopping")]
use crate::early_stopping::{train_validation_split, EarlyStopping, EarlyStoppingConfig};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

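/// Soft-thresholding operator used by the coordinate descent updates:
/// returns `sign(x) * max(|x| - lambda, 0)`.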
#[inline]
fn soft_threshold(x: Float, lambda: Float) -> Float {
    if x > lambda {
        x - lambda
    } else if x < -lambda {
        x + lambda
    } else {
        0.0
    }
}

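/// Coordinate descent solver for L1- and L1/L2-regularized least squares
/// (Lasso and Elastic Net).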
pub struct CoordinateDescentSolver {
    /// Maximum number of coordinate descent iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the sum of absolute coefficient changes.
    pub tol: Float,
    /// Whether to use cyclic (in-order) coordinate updates.
    pub cyclic: bool,
    /// Optional early stopping configuration (requires the `early-stopping` feature).
    #[cfg(feature = "early-stopping")]
    pub early_stopping_config: Option<EarlyStoppingConfig>,
}

impl Default for CoordinateDescentSolver {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-4,
            cyclic: true,
            #[cfg(feature = "early-stopping")]
            early_stopping_config: None,
        }
    }
}

impl CoordinateDescentSolver {
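    /// Solves Lasso regression with cyclic coordinate descent, minimizing
    /// `(1 / (2 * n_samples)) * ||y - X w||^2 + alpha * ||w||_1`.
    ///
    /// Returns the coefficient vector and, when `fit_intercept` is true, the fitted
    /// intercept.
    ///
    /// # Example
    ///
    /// A minimal sketch (marked `ignore` because the path this solver is exported
    /// under is assumed rather than taken from the crate's public API):
    ///
    /// ```ignore
    /// use scirs2_core::ndarray::array;
    ///
    /// let x = array![[1.0], [2.0], [3.0]];
    /// let y = array![2.0, 4.0, 6.0];
    /// let solver = CoordinateDescentSolver::default();
    /// let (coef, intercept) = solver.solve_lasso(&x, &y, 0.01, false).unwrap();
    /// assert!(intercept.is_none());
    /// ```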
    pub fn solve_lasso(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        self.solve_lasso_with_warm_start(x, y, alpha, fit_intercept, None, None)
    }

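    /// Same as [`solve_lasso`](Self::solve_lasso), but optionally starts from a
    /// previous solution (warm start), which can speed up convergence when solving
    /// for a sequence of related `alpha` values.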
    pub fn solve_lasso_with_warm_start(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        let n_samples = x.nrows() as Float;
        let n_features = x.ncols();

        let mut coef = match initial_coef {
            Some(init_coef) => {
                if init_coef.len() != n_features {
                    return Err(SklearsError::FeatureMismatch {
                        expected: n_features,
                        actual: init_coef.len(),
                    });
                }
                init_coef.clone()
            }
            None => Array1::zeros(n_features),
        };

        let mut intercept = if fit_intercept {
            initial_intercept.unwrap_or_else(|| y.mean().unwrap_or(0.0))
        } else {
            0.0
        };

        // Per-feature squared norms, scaled by 1 / n_samples.
        let feature_norms: Array1<Float> = x
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples)
            .collect();

        let mut converged = false;
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // The intercept is not penalized: its optimal value given the current
            // coefficients is the mean of y - X * coef.
            if fit_intercept {
                let residuals = y - &x.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            for j in 0..n_features {
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                // Partial residuals with feature j's contribution added back in.
                let mut residuals = y - &x.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x.column(j).to_owned() * coef[j];

                let gradient = x.column(j).dot(&residuals) / n_samples;

                coef[j] = soft_threshold(gradient, alpha) / feature_norms[j];
            }

            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
                break;
            }
        }

        if !converged {
            eprintln!(
                "Warning: Coordinate descent did not converge. Consider increasing max_iter."
            );
        }

        let intercept_opt = if fit_intercept { Some(intercept) } else { None };
        Ok((coef, intercept_opt))
    }

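    /// Solves Elastic Net regression with cyclic coordinate descent, minimizing
    /// `(1 / (2 * n_samples)) * ||y - X w||^2 + alpha * l1_ratio * ||w||_1
    /// + 0.5 * alpha * (1 - l1_ratio) * ||w||^2`.
    ///
    /// `l1_ratio = 1` reduces to the Lasso; `l1_ratio = 0` gives a pure L2 (ridge-style) penalty.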
    pub fn solve_elastic_net(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        self.solve_elastic_net_with_warm_start(x, y, alpha, l1_ratio, fit_intercept, None, None)
    }

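    /// Same as [`solve_elastic_net`](Self::solve_elastic_net), but optionally starts
    /// from a previous solution (warm start).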
    #[allow(clippy::too_many_arguments)]
    pub fn solve_elastic_net_with_warm_start(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(SklearsError::InvalidParameter {
                name: "l1_ratio".to_string(),
                reason: "must be between 0 and 1".to_string(),
            });
        }

        let n_samples = x.nrows() as Float;
        let n_features = x.ncols();

        let l1_reg = alpha * l1_ratio;
        let l2_reg = alpha * (1.0 - l1_ratio);

        let mut coef = match initial_coef {
            Some(init_coef) => {
                if init_coef.len() != n_features {
                    return Err(SklearsError::FeatureMismatch {
                        expected: n_features,
                        actual: init_coef.len(),
                    });
                }
                init_coef.clone()
            }
            None => Array1::zeros(n_features),
        };

        let mut intercept = if fit_intercept {
            initial_intercept.unwrap_or_else(|| y.mean().unwrap_or(0.0))
        } else {
            0.0
        };

        let feature_norms: Array1<Float> = x
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples + l2_reg)
            .collect();

        let mut converged = false;
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // The intercept is not penalized: its optimal value given the current
            // coefficients is the mean of y - X * coef.
            if fit_intercept {
                let residuals = y - &x.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            for j in 0..n_features {
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                let mut residuals = y - &x.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x.column(j).to_owned() * coef[j];

                let gradient = x.column(j).dot(&residuals) / n_samples;

                coef[j] = soft_threshold(gradient, l1_reg) / feature_norms[j];
            }

            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
                break;
            }
        }

        if !converged {
            eprintln!(
                "Warning: Coordinate descent did not converge. Consider increasing max_iter."
            );
        }

        let intercept_opt = if fit_intercept { Some(intercept) } else { None };
        Ok((coef, intercept_opt))
    }

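    /// Builder-style helper that enables early stopping with the given configuration
    /// (requires the `early-stopping` feature).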
    #[cfg(feature = "early-stopping")]
    pub fn with_early_stopping(mut self, config: EarlyStoppingConfig) -> Self {
        self.early_stopping_config = Some(config);
        self
    }

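    /// Solves Lasso with early stopping: the data is split into train/validation sets
    /// according to the early stopping configuration, and the validation R^2 score is
    /// monitored after each outer iteration.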
    #[cfg(feature = "early-stopping")]
    pub fn solve_lasso_with_early_stopping(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        let (x_train, y_train, x_val, y_val) = train_validation_split(
            x,
            y,
            early_stopping_config.validation_split,
            early_stopping_config.shuffle,
            early_stopping_config.random_state,
        )?;

        self.solve_lasso_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            alpha,
            fit_intercept,
        )
    }

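    /// Same as [`solve_lasso_with_early_stopping`](Self::solve_lasso_with_early_stopping),
    /// but operates on a caller-provided train/validation split.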
    #[cfg(feature = "early-stopping")]
    pub fn solve_lasso_with_early_stopping_split(
        &self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        let mut early_stopping = EarlyStopping::new(early_stopping_config.clone());

        let n_samples = x_train.nrows() as Float;
        let n_features = x_train.ncols();

        let mut coef = Array1::zeros(n_features);
        let mut intercept = if fit_intercept {
            y_train.mean().unwrap_or(0.0)
        } else {
            0.0
        };

        let mut best_coef = coef.clone();
        let mut best_intercept = intercept;

        let feature_norms: Array1<Float> = x_train
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples)
            .collect();

        let mut validation_scores = Vec::new();
        let mut converged = false;

        for iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // The intercept is not penalized: its optimal value given the current
            // coefficients is the mean of y_train - X_train * coef.
            if fit_intercept {
                let residuals = y_train - &x_train.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            for j in 0..n_features {
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                let mut residuals = y_train - &x_train.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x_train.column(j).to_owned() * coef[j];

                let gradient = x_train.column(j).dot(&residuals) / n_samples;

                coef[j] = soft_threshold(gradient, alpha) / feature_norms[j];
            }

            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
            }

            // Score the current model on the validation set.
            let val_predictions = if fit_intercept {
                x_val.dot(&coef) + intercept
            } else {
                x_val.dot(&coef)
            };

            let r2_score = compute_r2_score(&val_predictions, y_val);
            validation_scores.push(r2_score);

            let should_continue = early_stopping.update(r2_score);

            // Snapshot the weights whenever this iteration is the best seen so far.
            if early_stopping_config.restore_best_weights
                && early_stopping.best_iteration() == iter + 1
            {
                best_coef = coef.clone();
                best_intercept = intercept;
            }

            if !should_continue || converged {
                break;
            }
        }

        let (final_coef, final_intercept) = if early_stopping_config.restore_best_weights {
            (best_coef, best_intercept)
        } else {
            (coef, intercept)
        };

        let validation_info = ValidationInfo {
            validation_scores,
            best_score: early_stopping.best_score(),
            best_iteration: early_stopping.best_iteration(),
            stopped_early: early_stopping.should_stop(),
            converged,
        };

        let intercept_opt = if fit_intercept {
            Some(final_intercept)
        } else {
            None
        };
        Ok((final_coef, intercept_opt, validation_info))
    }

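    /// Solves Elastic Net with early stopping: the data is split into train/validation
    /// sets according to the early stopping configuration, and the validation R^2 score
    /// is monitored after each outer iteration.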
    #[cfg(feature = "early-stopping")]
    pub fn solve_elastic_net_with_early_stopping(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        let (x_train, y_train, x_val, y_val) = train_validation_split(
            x,
            y,
            early_stopping_config.validation_split,
            early_stopping_config.shuffle,
            early_stopping_config.random_state,
        )?;

        self.solve_elastic_net_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            alpha,
            l1_ratio,
            fit_intercept,
        )
    }

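    /// Same as
    /// [`solve_elastic_net_with_early_stopping`](Self::solve_elastic_net_with_early_stopping),
    /// but operates on a caller-provided train/validation split.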
    #[cfg(feature = "early-stopping")]
    #[allow(clippy::too_many_arguments)]
    pub fn solve_elastic_net_with_early_stopping_split(
        &self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(SklearsError::InvalidParameter {
                name: "l1_ratio".to_string(),
                reason: "must be between 0 and 1".to_string(),
            });
        }

        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        let mut early_stopping = EarlyStopping::new(early_stopping_config.clone());

        let n_samples = x_train.nrows() as Float;
        let n_features = x_train.ncols();

        let l1_reg = alpha * l1_ratio;
        let l2_reg = alpha * (1.0 - l1_ratio);

        let mut coef = Array1::zeros(n_features);
        let mut intercept = if fit_intercept {
            y_train.mean().unwrap_or(0.0)
        } else {
            0.0
        };

        let mut best_coef = coef.clone();
        let mut best_intercept = intercept;

        let feature_norms: Array1<Float> = x_train
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples + l2_reg)
            .collect();

        let mut validation_scores = Vec::new();
        let mut converged = false;

        for iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // The intercept is not penalized: its optimal value given the current
            // coefficients is the mean of y_train - X_train * coef.
            if fit_intercept {
                let residuals = y_train - &x_train.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            for j in 0..n_features {
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                let mut residuals = y_train - &x_train.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x_train.column(j).to_owned() * coef[j];

                let gradient = x_train.column(j).dot(&residuals) / n_samples;

                coef[j] = soft_threshold(gradient, l1_reg) / feature_norms[j];
            }

            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
            }

            let val_predictions = if fit_intercept {
                x_val.dot(&coef) + intercept
            } else {
                x_val.dot(&coef)
            };

            let r2_score = compute_r2_score(&val_predictions, y_val);
            validation_scores.push(r2_score);

            let should_continue = early_stopping.update(r2_score);

            if early_stopping_config.restore_best_weights
                && early_stopping.best_iteration() == iter + 1
            {
                best_coef = coef.clone();
                best_intercept = intercept;
            }

            if !should_continue || converged {
                break;
            }
        }

        let (final_coef, final_intercept) = if early_stopping_config.restore_best_weights {
            (best_coef, best_intercept)
        } else {
            (coef, intercept)
        };

        let validation_info = ValidationInfo {
            validation_scores,
            best_score: early_stopping.best_score(),
            best_iteration: early_stopping.best_iteration(),
            stopped_early: early_stopping.should_stop(),
            converged,
        };

        let intercept_opt = if fit_intercept {
            Some(final_intercept)
        } else {
            None
        };
        Ok((final_coef, intercept_opt, validation_info))
    }
}

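/// Summary of the validation history collected while fitting with early stopping.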
#[derive(Debug, Clone)]
pub struct ValidationInfo {
    /// Validation score recorded after each outer iteration.
    pub validation_scores: Vec<Float>,
    /// Best validation score observed, if any iterations ran.
    pub best_score: Option<Float>,
    /// Iteration at which the best validation score was observed.
    pub best_iteration: usize,
    /// Whether the early stopping criterion fired before `max_iter` was reached.
    pub stopped_early: bool,
    /// Whether the coefficient updates converged within the solver's tolerance.
    pub converged: bool,
}

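/// Computes the coefficient of determination `R^2 = 1 - SS_res / SS_tot`.
/// Returns 1.0 when `y_true` has zero variance (`SS_tot == 0`).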
pub fn compute_r2_score(y_pred: &Array1<Float>, y_true: &Array1<Float>) -> Float {
    let y_mean = y_true.mean().unwrap_or(0.0);
    let ss_res = (y_pred - y_true).mapv(|x| x * x).sum();
    let ss_tot = y_true.mapv(|yi| (yi - y_mean).powi(2)).sum();

    if ss_tot == 0.0 {
        1.0
    } else {
        1.0 - (ss_res / ss_tot)
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(soft_threshold(-0.5, 1.0), 0.0);
    }

    #[test]
    fn test_lasso_simple() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let solver = CoordinateDescentSolver::default();

        // With a very small penalty the solution should be close to the true slope of 2.
        let (coef, intercept) = solver.solve_lasso(&x, &y, 0.01, false).unwrap();
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 0.1);
        assert_eq!(intercept, None);

        // A stronger penalty shrinks the coefficient towards zero without flipping its sign.
        let (coef, _intercept) = solver.solve_lasso(&x, &y, 1.0, false).unwrap();
        assert!(coef[0] < 2.0);
        assert!(coef[0] > 0.0);
    }

    #[test]
    fn test_elastic_net() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let solver = CoordinateDescentSolver::default();

        let (coef, _) = solver.solve_elastic_net(&x, &y, 0.1, 0.5, false).unwrap();

        let (lasso_coef, _) = solver.solve_lasso(&x, &y, 0.1, false).unwrap();

        // The added L2 penalty should shrink differently than pure Lasso.
        assert!(coef[0] != lasso_coef[0]);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_lasso() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let n_samples = 100;
        let n_features = 5;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i * j + 1) as Float / 10.0;
            }
            y[i] = 2.0 * x[[i, 0]] + 1.5 * x[[i, 1]] + 0.5 * (i as Float % 3.0);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(5),
            validation_split: 0.2,
            shuffle: true,
            random_state: Some(42),
            higher_is_better: true,
            min_iterations: 3,
            restore_best_weights: true,
        };

        let solver = CoordinateDescentSolver::default().with_early_stopping(early_stopping_config);

        let result = solver.solve_lasso_with_early_stopping(&x, &y, 0.01, true);
        assert!(result.is_ok());

        let (coef, intercept, validation_info) = result.unwrap();

        assert_eq!(coef.len(), n_features);
        assert!(intercept.is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());

        assert!(validation_info.stopped_early || validation_info.converged);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_elastic_net() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let n_samples = 80;
        let n_features = 4;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i + j) as Float / 10.0;
            }
            y[i] = 1.0 * x[[i, 0]] + 2.0 * x[[i, 1]] + 0.1 * (i as Float);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TolerancePatience {
                tolerance: 0.01,
                patience: 3,
            },
            validation_split: 0.25,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 2,
            restore_best_weights: false,
        };

        let solver = CoordinateDescentSolver::default().with_early_stopping(early_stopping_config);

        let result = solver.solve_elastic_net_with_early_stopping(&x, &y, 0.1, 0.5, true);
        assert!(result.is_ok());

        let (coef, intercept, validation_info) = result.unwrap();

        assert_eq!(coef.len(), n_features);
        assert!(intercept.is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());

        assert!(validation_info.best_iteration > 0);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_with_presplit_data() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y_train = array![3.0, 5.0, 7.0, 9.0];
        let x_val = array![[5.0, 6.0], [6.0, 7.0]];
        let y_val = array![11.0, 13.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TargetScore(0.8),
            // validation_split is not consulted when the split is provided explicitly.
            validation_split: 0.2,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        let solver = CoordinateDescentSolver::default().with_early_stopping(early_stopping_config);

        let result = solver
            .solve_lasso_with_early_stopping_split(&x_train, &y_train, &x_val, &y_val, 0.001, true);
        assert!(result.is_ok());

        let (coef, intercept, validation_info) = result.unwrap();

        assert_eq!(coef.len(), 2);
        assert!(intercept.is_some());
        assert!(!validation_info.validation_scores.is_empty());
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_validation_info_structure() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(2),
            validation_split: 0.25,
            shuffle: false,
            random_state: Some(123),
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        let solver = CoordinateDescentSolver {
            max_iter: 10,
            tol: 1e-6,
            cyclic: true,
            early_stopping_config: Some(early_stopping_config),
        };

        let result = solver.solve_lasso_with_early_stopping(&x, &y, 0.01, false);
        assert!(result.is_ok());

        let (_coef, _intercept, validation_info) = result.unwrap();

        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
        assert!(validation_info.best_iteration >= 1);

        assert!(validation_info.stopped_early || validation_info.converged);

        for score in &validation_info.validation_scores {
            assert!(score.is_finite());
        }
    }

    #[test]
    fn test_r2_score_computation() {
        let y_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_pred = array![1.1, 1.9, 3.1, 3.9, 5.1];

        let r2 = compute_r2_score(&y_pred, &y_true);
        assert!(r2 > 0.9);
        assert!(r2 <= 1.0);

        let perfect_pred = y_true.clone();
        let r2_perfect = compute_r2_score(&perfect_pred, &y_true);
        assert!((r2_perfect - 1.0).abs() < 1e-10);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_without_config() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        // No early stopping config set, so the call should fail with InvalidInput.
        let solver = CoordinateDescentSolver::default();
        let result = solver.solve_lasso_with_early_stopping(&x, &y, 0.1, false);
        assert!(result.is_err());

        let error = result.unwrap_err();
        match error {
            SklearsError::InvalidInput(msg) => {
                assert!(msg.contains("Early stopping config not set"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }
}