sklears_linear/coordinate_descent.rs

//! Coordinate Descent solver for Lasso and ElasticNet regression
//!
//! Each sweep updates one coefficient at a time via the closed-form
//! soft-thresholding update βⱼ ← S(xⱼᵀr₋ⱼ/n, α) / (xⱼᵀxⱼ/n), where r₋ⱼ is
//! the residual excluding feature j's contribution.

#[cfg(feature = "early-stopping")]
use crate::early_stopping::{train_validation_split, EarlyStopping, EarlyStoppingConfig};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};

/// Soft thresholding operator for L1 regularization
///
/// Computes `S(x, λ) = sign(x) * max(|x| - λ, 0)`, the proximal operator of
/// the L1 penalty; e.g. `S(2.0, 1.0) = 1.0` and `S(0.5, 1.0) = 0.0`.
#[inline]
fn soft_threshold(x: Float, lambda: Float) -> Float {
    if x > lambda {
        x - lambda
    } else if x < -lambda {
        x + lambda
    } else {
        0.0
    }
}

/// Coordinate Descent solver for Lasso and ElasticNet
pub struct CoordinateDescentSolver {
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance on the L1 norm of the coefficient change
    pub tol: Float,
    /// Whether to use cyclic coordinate selection (random selection is not
    /// yet implemented; coordinates are always swept in order)
    pub cyclic: bool,
    /// Early stopping configuration (optional)
    #[cfg(feature = "early-stopping")]
    pub early_stopping_config: Option<EarlyStoppingConfig>,
}

impl Default for CoordinateDescentSolver {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            tol: 1e-4,
            cyclic: true,
            #[cfg(feature = "early-stopping")]
            early_stopping_config: None,
        }
    }
}

impl CoordinateDescentSolver {
    /// Solve Lasso regression using coordinate descent
    ///
    /// Minimizes: (1/2n) ||y - Xβ||² + α||β||₁
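    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest; assumes `x` is an
    /// `Array2<Float>` design matrix and `y` an `Array1<Float>` target):
    ///
    /// ```ignore
    /// let solver = CoordinateDescentSolver::default();
    /// let (coef, intercept) = solver.solve_lasso(&x, &y, 0.1, true)?;
    /// assert_eq!(coef.len(), x.ncols());
    /// ```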
    pub fn solve_lasso(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        self.solve_lasso_with_warm_start(x, y, alpha, fit_intercept, None, None)
    }

    /// Solve Lasso regression using coordinate descent with warm start
    ///
    /// Minimizes: (1/2n) ||y - Xβ||² + α||β||₁
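    ///
    /// # Example
    ///
    /// Warm starting is most useful when solving along a regularization path;
    /// a sketch (not compiled as a doctest; `x` and `y` assumed as above):
    ///
    /// ```ignore
    /// let solver = CoordinateDescentSolver::default();
    /// let mut coef = Array1::zeros(x.ncols());
    /// for &alpha in &[1.0, 0.5, 0.1, 0.01] {
    ///     let (c, _) = solver.solve_lasso_with_warm_start(
    ///         &x, &y, alpha, false, Some(&coef), None,
    ///     )?;
    ///     coef = c; // reuse the previous solution as the next starting point
    /// }
    /// ```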
    pub fn solve_lasso_with_warm_start(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        let n_samples = x.nrows() as Float;
        let n_features = x.ncols();

        // Initialize coefficients (with warm start if provided)
        let mut coef = match initial_coef {
            Some(init_coef) => {
                if init_coef.len() != n_features {
                    return Err(SklearsError::FeatureMismatch {
                        expected: n_features,
                        actual: init_coef.len(),
                    });
                }
                init_coef.clone()
            }
            None => Array1::zeros(n_features),
        };

        let mut intercept = if fit_intercept {
            initial_intercept.unwrap_or_else(|| y.mean().unwrap_or(0.0))
        } else {
            0.0
        };

        // Precompute squared norms (scaled by 1/n) for each feature
        let feature_norms: Array1<Float> = x
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples)
            .collect();

        // Coordinate descent iterations
        let mut converged = false;
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // Update intercept if needed. The unpenalized intercept that
            // minimizes the objective is the mean of y - Xβ.
            if fit_intercept {
                let residuals = y - &x.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            // Update each coordinate
            for j in 0..n_features {
                // Skip if feature norm is zero (all-zero column)
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                // Compute partial residual (excluding the j-th feature).
                // Recomputing the full residual here keeps the code simple at
                // the cost of an extra O(n·p) per coordinate; an incremental
                // residual update would be faster.
                let mut residuals = y - &x.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x.column(j).to_owned() * coef[j];

                // Compute gradient for the j-th feature
                let gradient = x.column(j).dot(&residuals) / n_samples;

                // Apply soft thresholding
                coef[j] = soft_threshold(gradient, alpha) / feature_norms[j];
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
                break;
            }
        }

        if !converged {
            eprintln!(
                "Warning: Coordinate descent did not converge. Consider increasing max_iter."
            );
        }

        let intercept_opt = if fit_intercept { Some(intercept) } else { None };
        Ok((coef, intercept_opt))
    }

    /// Solve ElasticNet regression using coordinate descent
    ///
    /// Minimizes: (1/2n) ||y - Xβ||² + α * ρ * ||β||₁ + α * (1-ρ)/2 * ||β||²
    /// where ρ is l1_ratio
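    ///
    /// # Example
    ///
    /// A minimal usage sketch (not compiled as a doctest; `x` and `y` assumed
    /// as for [`solve_lasso`](Self::solve_lasso)):
    ///
    /// ```ignore
    /// let solver = CoordinateDescentSolver::default();
    /// // l1_ratio = 1.0 recovers Lasso; l1_ratio = 0.0 is a pure L2 penalty
    /// let (coef, intercept) = solver.solve_elastic_net(&x, &y, 0.1, 0.5, true)?;
    /// ```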
    pub fn solve_elastic_net(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        self.solve_elastic_net_with_warm_start(x, y, alpha, l1_ratio, fit_intercept, None, None)
    }

    /// Solve ElasticNet regression using coordinate descent with warm start
    ///
    /// Minimizes: (1/2n) ||y - Xβ||² + α * ρ * ||β||₁ + α * (1-ρ)/2 * ||β||²
    /// where ρ is l1_ratio
    #[allow(clippy::too_many_arguments)]
    pub fn solve_elastic_net_with_warm_start(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<(Array1<Float>, Option<Float>)> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(SklearsError::InvalidParameter {
                name: "l1_ratio".to_string(),
                reason: "must be between 0 and 1".to_string(),
            });
        }

        let n_samples = x.nrows() as Float;
        let n_features = x.ncols();

        // Regularization parameters
        let l1_reg = alpha * l1_ratio;
        let l2_reg = alpha * (1.0 - l1_ratio);

        // Initialize coefficients (with warm start if provided)
        let mut coef = match initial_coef {
            Some(init_coef) => {
                if init_coef.len() != n_features {
                    return Err(SklearsError::FeatureMismatch {
                        expected: n_features,
                        actual: init_coef.len(),
                    });
                }
                init_coef.clone()
            }
            None => Array1::zeros(n_features),
        };

        let mut intercept = if fit_intercept {
            initial_intercept.unwrap_or_else(|| y.mean().unwrap_or(0.0))
        } else {
            0.0
        };

        // Precompute squared norms (scaled by 1/n) for each feature,
        // augmented with the L2 penalty term
        let feature_norms: Array1<Float> = x
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples + l2_reg)
            .collect();

        // Coordinate descent iterations
        let mut converged = false;
        for _iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // Update intercept if needed. The unpenalized intercept that
            // minimizes the objective is the mean of y - Xβ.
            if fit_intercept {
                let residuals = y - &x.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            // Update each coordinate
            for j in 0..n_features {
                // Skip if feature norm is zero
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                // Compute partial residual (excluding the j-th feature)
                let mut residuals = y - &x.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x.column(j).to_owned() * coef[j];

                // Compute gradient for the j-th feature
                let gradient = x.column(j).dot(&residuals) / n_samples;

                // Apply soft thresholding for L1; the L2 term is folded into
                // the denominator via feature_norms
                coef[j] = soft_threshold(gradient, l1_reg) / feature_norms[j];
            }

            // Check convergence
            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
                break;
            }
        }

        if !converged {
            eprintln!(
                "Warning: Coordinate descent did not converge. Consider increasing max_iter."
            );
        }

        let intercept_opt = if fit_intercept { Some(intercept) } else { None };
        Ok((coef, intercept_opt))
    }

    /// Configure early stopping for the solver
    #[cfg(feature = "early-stopping")]
    pub fn with_early_stopping(mut self, config: EarlyStoppingConfig) -> Self {
        self.early_stopping_config = Some(config);
        self
    }

    /// Solve Lasso regression with early stopping based on validation metrics
    ///
    /// This method automatically splits the data into training and validation sets
    /// and uses validation performance to determine when to stop training.
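    ///
    /// # Example
    ///
    /// A sketch (not compiled as a doctest; `x` and `y` assumed as above, and
    /// the config fields mirror those used in this module's tests):
    ///
    /// ```ignore
    /// use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
    ///
    /// let config = EarlyStoppingConfig {
    ///     criterion: StoppingCriterion::Patience(5),
    ///     validation_split: 0.2,
    ///     shuffle: true,
    ///     random_state: Some(42),
    ///     higher_is_better: true,
    ///     min_iterations: 3,
    ///     restore_best_weights: true,
    /// };
    /// let solver = CoordinateDescentSolver::default().with_early_stopping(config);
    /// let (coef, intercept, info) = solver
    ///     .solve_lasso_with_early_stopping(&x, &y, 0.01, true)?;
    /// println!("stopped early: {}", info.stopped_early);
    /// ```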
    #[cfg(feature = "early-stopping")]
    pub fn solve_lasso_with_early_stopping(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        // Split data into training and validation sets
        let (x_train, y_train, x_val, y_val) = train_validation_split(
            x,
            y,
            early_stopping_config.validation_split,
            early_stopping_config.shuffle,
            early_stopping_config.random_state,
        )?;

        self.solve_lasso_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            alpha,
            fit_intercept,
        )
    }

    /// Solve Lasso regression with early stopping using pre-split data
    #[cfg(feature = "early-stopping")]
    pub fn solve_lasso_with_early_stopping_split(
        &self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        alpha: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        let mut early_stopping = EarlyStopping::new(early_stopping_config.clone());

        let n_samples = x_train.nrows() as Float;
        let n_features = x_train.ncols();

        // Initialize coefficients
        let mut coef = Array1::zeros(n_features);
        let mut intercept = if fit_intercept {
            y_train.mean().unwrap_or(0.0)
        } else {
            0.0
        };

        // Store best coefficients if restore_best_weights is enabled
        let mut best_coef = coef.clone();
        let mut best_intercept = intercept;

        // Precompute squared norms (scaled by 1/n) for each feature
        let feature_norms: Array1<Float> = x_train
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples)
            .collect();

        let mut validation_scores = Vec::new();
        let mut converged = false;

        // Coordinate descent iterations with early stopping
        for iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // Update intercept if needed. The unpenalized intercept that
            // minimizes the objective is the mean of y - Xβ.
            if fit_intercept {
                let residuals = y_train - &x_train.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            // Update each coordinate
            for j in 0..n_features {
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                // Compute partial residual (excluding the j-th feature)
                let mut residuals = y_train - &x_train.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x_train.column(j).to_owned() * coef[j];

                // Compute gradient for the j-th feature
                let gradient = x_train.column(j).dot(&residuals) / n_samples;

                // Apply soft thresholding
                coef[j] = soft_threshold(gradient, alpha) / feature_norms[j];
            }

            // Check convergence based on coefficient change
            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
            }

            // Compute validation score (R² score)
            let val_predictions = if fit_intercept {
                x_val.dot(&coef) + intercept
            } else {
                x_val.dot(&coef)
            };

            let r2_score = compute_r2_score(&val_predictions, y_val);
            validation_scores.push(r2_score);

            // Check early stopping
            let should_continue = early_stopping.update(r2_score);

            // Store best weights if this is the best iteration so far
            if early_stopping_config.restore_best_weights
                && early_stopping.best_iteration() == iter + 1
            {
                best_coef = coef.clone();
                best_intercept = intercept;
            }

            if !should_continue || converged {
                break;
            }
        }

        // Restore best weights if configured
        let (final_coef, final_intercept) = if early_stopping_config.restore_best_weights {
            (best_coef, best_intercept)
        } else {
            (coef, intercept)
        };

        let validation_info = ValidationInfo {
            validation_scores,
            best_score: early_stopping.best_score(),
            best_iteration: early_stopping.best_iteration(),
            stopped_early: early_stopping.should_stop(),
            converged,
        };

        let intercept_opt = if fit_intercept {
            Some(final_intercept)
        } else {
            None
        };
        Ok((final_coef, intercept_opt, validation_info))
    }

    /// Solve ElasticNet regression with early stopping based on validation metrics
    #[cfg(feature = "early-stopping")]
    pub fn solve_elastic_net_with_early_stopping(
        &self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        // Split data into training and validation sets
        let (x_train, y_train, x_val, y_val) = train_validation_split(
            x,
            y,
            early_stopping_config.validation_split,
            early_stopping_config.shuffle,
            early_stopping_config.random_state,
        )?;

        self.solve_elastic_net_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            alpha,
            l1_ratio,
            fit_intercept,
        )
    }

    /// Solve ElasticNet regression with early stopping using pre-split data
    #[cfg(feature = "early-stopping")]
    #[allow(clippy::too_many_arguments)]
    pub fn solve_elastic_net_with_early_stopping_split(
        &self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        alpha: Float,
        l1_ratio: Float,
        fit_intercept: bool,
    ) -> Result<(Array1<Float>, Option<Float>, ValidationInfo)> {
        if !(0.0..=1.0).contains(&l1_ratio) {
            return Err(SklearsError::InvalidParameter {
                name: "l1_ratio".to_string(),
                reason: "must be between 0 and 1".to_string(),
            });
        }

        let early_stopping_config = self.early_stopping_config.as_ref().ok_or_else(|| {
            SklearsError::InvalidInput(
                "Early stopping config not set. Use with_early_stopping() first.".to_string(),
            )
        })?;

        let mut early_stopping = EarlyStopping::new(early_stopping_config.clone());

        let n_samples = x_train.nrows() as Float;
        let n_features = x_train.ncols();

        // Regularization parameters
        let l1_reg = alpha * l1_ratio;
        let l2_reg = alpha * (1.0 - l1_ratio);

        // Initialize coefficients
        let mut coef = Array1::zeros(n_features);
        let mut intercept = if fit_intercept {
            y_train.mean().unwrap_or(0.0)
        } else {
            0.0
        };

        // Store best coefficients if restore_best_weights is enabled
        let mut best_coef = coef.clone();
        let mut best_intercept = intercept;

        // Precompute squared norms (scaled by 1/n) for each feature,
        // augmented with the L2 penalty term
        let feature_norms: Array1<Float> = x_train
            .axis_iter(Axis(1))
            .map(|col| col.dot(&col) / n_samples + l2_reg)
            .collect();

        let mut validation_scores = Vec::new();
        let mut converged = false;

        // Coordinate descent iterations with early stopping
        for iter in 0..self.max_iter {
            let old_coef = coef.clone();

            // Update intercept if needed. The unpenalized intercept that
            // minimizes the objective is the mean of y - Xβ.
            if fit_intercept {
                let residuals = y_train - &x_train.dot(&coef);
                intercept = residuals.mean().unwrap_or(0.0);
            }

            // Update each coordinate
            for j in 0..n_features {
                if feature_norms[j] == 0.0 {
                    coef[j] = 0.0;
                    continue;
                }

                // Compute partial residual (excluding the j-th feature)
                let mut residuals = y_train - &x_train.dot(&coef);
                if fit_intercept {
                    residuals -= intercept;
                }
                residuals = residuals + x_train.column(j).to_owned() * coef[j];

                // Compute gradient for the j-th feature
                let gradient = x_train.column(j).dot(&residuals) / n_samples;

                // Apply soft thresholding for L1; the L2 term is folded into
                // the denominator via feature_norms
                coef[j] = soft_threshold(gradient, l1_reg) / feature_norms[j];
            }

            // Check convergence based on coefficient change
            let coef_change = (&coef - &old_coef).mapv(Float::abs).sum();
            if coef_change < self.tol {
                converged = true;
            }

            // Compute validation score (R² score)
            let val_predictions = if fit_intercept {
                x_val.dot(&coef) + intercept
            } else {
                x_val.dot(&coef)
            };

            let r2_score = compute_r2_score(&val_predictions, y_val);
            validation_scores.push(r2_score);

            // Check early stopping
            let should_continue = early_stopping.update(r2_score);

            // Store best weights if this is the best iteration so far
            if early_stopping_config.restore_best_weights
                && early_stopping.best_iteration() == iter + 1
            {
                best_coef = coef.clone();
                best_intercept = intercept;
            }

            if !should_continue || converged {
                break;
            }
        }

        // Restore best weights if configured
        let (final_coef, final_intercept) = if early_stopping_config.restore_best_weights {
            (best_coef, best_intercept)
        } else {
            (coef, intercept)
        };

        let validation_info = ValidationInfo {
            validation_scores,
            best_score: early_stopping.best_score(),
            best_iteration: early_stopping.best_iteration(),
            stopped_early: early_stopping.should_stop(),
            converged,
        };

        let intercept_opt = if fit_intercept {
            Some(final_intercept)
        } else {
            None
        };
        Ok((final_coef, intercept_opt, validation_info))
    }
}

/// Information about the validation process during early stopping
#[cfg(feature = "early-stopping")]
#[derive(Debug, Clone)]
pub struct ValidationInfo {
    /// Validation scores for each iteration
    pub validation_scores: Vec<Float>,
    /// Best validation score achieved
    pub best_score: Option<Float>,
    /// Iteration where best score was achieved
    pub best_iteration: usize,
    /// Whether training was stopped early
    pub stopped_early: bool,
    /// Whether convergence was achieved
    pub converged: bool,
}

/// Compute R² score for regression validation
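///
/// Computes R² = 1 - SS_res / SS_tot, where SS_res = Σ(ŷᵢ - yᵢ)² and
/// SS_tot = Σ(yᵢ - ȳ)²; a constant target (SS_tot = 0) is scored as 1.0.
///
/// # Example
///
/// A sketch (not compiled as a doctest); note the argument order is
/// `(y_pred, y_true)`:
///
/// ```ignore
/// let y_true = array![1.0, 2.0, 3.0];
/// let r2 = compute_r2_score(&y_true.clone(), &y_true);
/// assert!((r2 - 1.0).abs() < 1e-12); // perfect predictions give R² = 1
/// ```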
pub fn compute_r2_score(y_pred: &Array1<Float>, y_true: &Array1<Float>) -> Float {
    let y_mean = y_true.mean().unwrap_or(0.0);
    let ss_res = (y_pred - y_true).mapv(|x| x * x).sum();
    let ss_tot = y_true.mapv(|yi| (yi - y_mean).powi(2)).sum();

    if ss_tot == 0.0 {
        1.0
    } else {
        1.0 - (ss_res / ss_tot)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(soft_threshold(-0.5, 1.0), 0.0);
    }

    #[test]
    fn test_lasso_simple() {
        // Simple test case: y = 2x (no noise)
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let solver = CoordinateDescentSolver::default();

        // With small alpha, should be close to the OLS solution
        let (coef, intercept) = solver.solve_lasso(&x, &y, 0.01, false).unwrap();
        assert_abs_diff_eq!(coef[0], 2.0, epsilon = 0.1);
        assert_eq!(intercept, None);

        // With large alpha, the coefficient should shrink
        let (coef, _intercept) = solver.solve_lasso(&x, &y, 1.0, false).unwrap();
        assert!(coef[0] < 2.0);
        assert!(coef[0] > 0.0);
    }

    #[test]
    fn test_elastic_net() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let solver = CoordinateDescentSolver::default();

        // ElasticNet with l1_ratio=0.5
        let (coef, _) = solver.solve_elastic_net(&x, &y, 0.1, 0.5, false).unwrap();

        // Compare against the pure Lasso solution at the same alpha
        let (lasso_coef, _) = solver.solve_lasso(&x, &y, 0.1, false).unwrap();

        // The added L2 penalty should make the ElasticNet coefficient differ
        assert!(coef[0] != lasso_coef[0]);
    }

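    // A hypothetical warm-start check (not in the original suite): solving
    // along a decreasing regularization path, reusing each solution as the
    // next starting point, should land near the cold-start solution at the
    // final alpha.
    #[test]
    fn test_lasso_warm_start_path() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let solver = CoordinateDescentSolver::default();

        // Warm start down a path of alphas
        let mut coef = Array1::zeros(1);
        for &alpha in &[1.0, 0.1, 0.01] {
            let (c, _) = solver
                .solve_lasso_with_warm_start(&x, &y, alpha, false, Some(&coef), None)
                .unwrap();
            coef = c;
        }

        // Cold start at the final alpha for comparison
        let (cold_coef, _) = solver.solve_lasso(&x, &y, 0.01, false).unwrap();
        assert_abs_diff_eq!(coef[0], cold_coef[0], epsilon = 1e-3);
    }
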
    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_lasso() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        // Create synthetic data with enough samples for a meaningful validation split
        let n_samples = 100;
        let n_features = 5;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        // Generate a simple linear relationship
        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i * j + 1) as Float / 10.0;
            }
            y[i] = 2.0 * x[[i, 0]] + 1.5 * x[[i, 1]] + 0.5 * (i as Float % 3.0);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(5),
            validation_split: 0.2,
            shuffle: true,
            random_state: Some(42),
            higher_is_better: true,
            min_iterations: 3,
            restore_best_weights: true,
        };

        let solver = CoordinateDescentSolver::default().with_early_stopping(early_stopping_config);

        let result = solver.solve_lasso_with_early_stopping(&x, &y, 0.01, true);
        assert!(result.is_ok());

        let (coef, intercept, validation_info) = result.unwrap();

        // Check that we have reasonable results
        assert_eq!(coef.len(), n_features);
        assert!(intercept.is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());

        // Early stopping should have triggered or converged
        assert!(validation_info.stopped_early || validation_info.converged);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_elastic_net() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        // Create synthetic data
        let n_samples = 80;
        let n_features = 4;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i + j) as Float / 10.0;
            }
            y[i] = 1.0 * x[[i, 0]] + 2.0 * x[[i, 1]] + 0.1 * (i as Float);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TolerancePatience {
                tolerance: 0.01,
                patience: 3,
            },
            validation_split: 0.25,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 2,
            restore_best_weights: false,
        };

        let solver = CoordinateDescentSolver::default().with_early_stopping(early_stopping_config);

        let result = solver.solve_elastic_net_with_early_stopping(&x, &y, 0.1, 0.5, true);
        assert!(result.is_ok());

        let (coef, intercept, validation_info) = result.unwrap();

        // Check results
        assert_eq!(coef.len(), n_features);
        assert!(intercept.is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());

        // Validation info should be meaningful
        assert!(validation_info.best_iteration > 0);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_with_presplit_data() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        // Create training and validation data
        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y_train = array![3.0, 5.0, 7.0, 9.0]; // y = x1 + x2
        let x_val = array![[5.0, 6.0], [6.0, 7.0]];
        let y_val = array![11.0, 13.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TargetScore(0.8),
            validation_split: 0.2, // Not used since we provide pre-split data
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        let solver = CoordinateDescentSolver::default().with_early_stopping(early_stopping_config);

        let result = solver
            .solve_lasso_with_early_stopping_split(&x_train, &y_train, &x_val, &y_val, 0.001, true);
        assert!(result.is_ok());

        let (coef, intercept, validation_info) = result.unwrap();

        assert_eq!(coef.len(), 2);
        assert!(intercept.is_some());
        assert!(!validation_info.validation_scores.is_empty());
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_validation_info_structure() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(2),
            validation_split: 0.25,
            shuffle: false,
            random_state: Some(123),
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        let solver = CoordinateDescentSolver {
            max_iter: 10,
            tol: 1e-6,
            cyclic: true,
            early_stopping_config: Some(early_stopping_config),
        };

        let result = solver.solve_lasso_with_early_stopping(&x, &y, 0.01, false);
        assert!(result.is_ok());

        let (_coef, _intercept, validation_info) = result.unwrap();

        // Validation info should have meaningful structure
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
        assert!(validation_info.best_iteration >= 1);

        // Should have stopped early or converged
        assert!(validation_info.stopped_early || validation_info.converged);

        // Validation scores should be finite
        for score in &validation_info.validation_scores {
            assert!(score.is_finite());
        }
    }

    #[test]
    fn test_r2_score_computation() {
        let y_true = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let y_pred = array![1.1, 1.9, 3.1, 3.9, 5.1];

        let r2 = compute_r2_score(&y_pred, &y_true);
        assert!(r2 > 0.9); // Should be high for good predictions
        assert!(r2 <= 1.0);

        // Perfect predictions should give R² = 1
        let perfect_pred = y_true.clone();
        let r2_perfect = compute_r2_score(&perfect_pred, &y_true);
        assert!((r2_perfect - 1.0).abs() < 1e-10);
    }

    #[cfg(feature = "early-stopping")]
    #[test]
    fn test_early_stopping_without_config() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let solver = CoordinateDescentSolver::default(); // No early stopping config

        let result = solver.solve_lasso_with_early_stopping(&x, &y, 0.1, false);
        assert!(result.is_err());

        let error = result.unwrap_err();
        match error {
            SklearsError::InvalidInput(msg) => {
                assert!(msg.contains("Early stopping config not set"));
            }
            _ => panic!("Expected InvalidInput error"),
        }
    }
}