// sklears_linear/linear_regression.rs
1//! Linear Regression implementation
2
3use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{s, Array};
6use scirs2_linalg::compat::ArrayLinalgExt;
7// Removed SVD import - using ArrayLinalgExt for both solve and svd methods
8use sklears_core::{
9    error::{validate, Result, SklearsError},
10    traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11    types::{Array1, Array2, Float},
12};
13
14use crate::{Penalty, Solver};
15
16#[cfg(feature = "coordinate-descent")]
17use crate::coordinate_descent::CoordinateDescentSolver;
18
19#[cfg(feature = "coordinate-descent")]
20use crate::coordinate_descent::ValidationInfo;
21
22#[cfg(feature = "early-stopping")]
23use crate::early_stopping::EarlyStoppingConfig;
24
/// Configuration for Linear Regression.
///
/// Gathered into one struct so the builder methods on `LinearRegression`
/// can tweak fields individually and the trained model can carry its
/// fitting configuration along.
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Whether to fit the intercept term
    pub fit_intercept: bool,
    /// Regularization penalty (none, L1/Lasso, L2/Ridge, or ElasticNet)
    pub penalty: Penalty,
    /// Solver to use
    pub solver: Solver,
    /// Maximum iterations for iterative solvers (e.g. coordinate descent)
    pub max_iter: usize,
    /// Tolerance for convergence of iterative solvers
    pub tol: f64,
    /// Whether to use warm start (reuse previous solution as initialization)
    pub warm_start: bool,
    /// Enable GPU acceleration if available
    #[cfg(feature = "gpu")]
    pub use_gpu: bool,
    /// Minimum problem size (n_samples * n_features) to use GPU acceleration
    #[cfg(feature = "gpu")]
    pub gpu_min_size: usize,
}
47
impl Default for LinearRegressionConfig {
    /// Default configuration: intercept fitting on, no regularization,
    /// automatic solver selection, at most 1000 iterations at tolerance
    /// 1e-4, no warm start; with the `gpu` feature, GPU use is enabled for
    /// problems of at least 1000 elements.
    fn default() -> Self {
        Self {
            fit_intercept: true,
            penalty: Penalty::None,
            solver: Solver::Auto,
            max_iter: 1000,
            tol: 1e-4,
            warm_start: false,
            #[cfg(feature = "gpu")]
            use_gpu: true,
            #[cfg(feature = "gpu")]
            gpu_min_size: 1000,
        }
    }
}
64
/// Linear Regression model.
///
/// Uses the type-state pattern: `State` is `Untrained` until a fit method
/// is called, which returns a `LinearRegression<Trained>` with the
/// trained-state fields populated.
#[derive(Debug, Clone)]
pub struct LinearRegression<State = Untrained> {
    /// Fitting configuration, carried through both states
    config: LinearRegressionConfig,
    /// Zero-sized marker carrying the compile-time state
    state: PhantomData<State>,
    // Trained-state fields: `coef_` and `n_features_` are populated by the
    // fit methods; `intercept_` stays `None` when `fit_intercept` is false.
    coef_: Option<Array1<Float>>,
    intercept_: Option<Float>,
    n_features_: Option<usize>,
}
75
76impl LinearRegression<Untrained> {
77    /// Create a new Linear Regression model
78    pub fn new() -> Self {
79        Self {
80            config: LinearRegressionConfig::default(),
81            state: PhantomData,
82            coef_: None,
83            intercept_: None,
84            n_features_: None,
85        }
86    }
87
88    /// Set whether to fit intercept
89    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
90        self.config.fit_intercept = fit_intercept;
91        self
92    }
93
94    /// Set regularization (Ridge/L2)
95    pub fn regularization(mut self, alpha: f64) -> Self {
96        self.config.penalty = Penalty::L2(alpha);
97        self
98    }
99
100    /// Create a Lasso regression model (L1 penalty)
101    pub fn lasso(alpha: f64) -> Self {
102        Self::new()
103            .penalty(Penalty::L1(alpha))
104            .solver(Solver::CoordinateDescent)
105    }
106
107    /// Create an ElasticNet regression model (L1 + L2 penalty)
108    pub fn elastic_net(alpha: f64, l1_ratio: f64) -> Self {
109        Self::new()
110            .penalty(Penalty::ElasticNet { l1_ratio, alpha })
111            .solver(Solver::CoordinateDescent)
112    }
113
114    /// Set penalty
115    pub fn penalty(mut self, penalty: Penalty) -> Self {
116        self.config.penalty = penalty;
117        self
118    }
119
120    /// Set solver
121    pub fn solver(mut self, solver: Solver) -> Self {
122        self.config.solver = solver;
123        self
124    }
125
126    /// Set maximum iterations
127    pub fn max_iter(mut self, max_iter: usize) -> Self {
128        self.config.max_iter = max_iter;
129        self
130    }
131
132    /// Set whether to use warm start
133    pub fn warm_start(mut self, warm_start: bool) -> Self {
134        self.config.warm_start = warm_start;
135        self
136    }
137
138    /// Enable or disable GPU acceleration
139    #[cfg(feature = "gpu")]
140    pub fn use_gpu(mut self, use_gpu: bool) -> Self {
141        self.config.use_gpu = use_gpu;
142        self
143    }
144
145    /// Set minimum problem size for GPU acceleration
146    #[cfg(feature = "gpu")]
147    pub fn gpu_min_size(mut self, min_size: usize) -> Self {
148        self.config.gpu_min_size = min_size;
149        self
150    }
151}
152
impl Default for LinearRegression<Untrained> {
    /// Equivalent to [`LinearRegression::new`].
    fn default() -> Self {
        Self::new()
    }
}
158
impl Estimator for LinearRegression<Untrained> {
    type Config = LinearRegressionConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Borrow the current configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
168
impl Fit<Array2<Float>, Array1<Float>> for LinearRegression<Untrained> {
    type Fitted = LinearRegression<Trained>;

    /// Fit the model to `x` (n_samples x n_features) and targets `y`,
    /// consuming the untrained model and returning a trained one.
    ///
    /// Dispatches on `config.penalty`:
    /// * `None` — ordinary least squares via the normal equations, with an
    ///   optional GPU path behind the `gpu` feature (CPU fallback on GPU
    ///   failure);
    /// * `L2(alpha)` — closed-form ridge, `(X^T X + aI) b = X^T y`, leaving
    ///   the intercept column unpenalized;
    /// * `L1` / `ElasticNet` — coordinate descent (requires the
    ///   `coordinate-descent` feature; otherwise `InvalidParameter` is
    ///   returned).
    ///
    /// # Errors
    /// Fails when `x` and `y` have inconsistent lengths, when a linear
    /// solve fails numerically, or when the selected penalty requires a
    /// disabled cargo feature.
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_samples = x.nrows();
        let n_features = x.ncols();

        // Add intercept column if needed: prepend a column of ones so the
        // first parameter becomes the intercept.
        // NOTE(review): this augmented matrix is only consumed by the
        // `None`/`L2` arms below; the coordinate-descent arms pass the raw
        // `x`, so the clone built here is wasted work for L1/ElasticNet.
        let (x_with_intercept, n_params) = if self.config.fit_intercept {
            let mut x_new = Array::ones((n_samples, n_features + 1));
            x_new.slice_mut(s![.., 1..]).assign(x);
            (x_new, n_features + 1)
        } else {
            (x.clone(), n_features)
        };

        // Solve based on penalty type
        let params = match self.config.penalty {
            Penalty::None => {
                // Check if we should use GPU acceleration: only above the
                // configured problem-size threshold.
                #[cfg(feature = "gpu")]
                if self.config.use_gpu && n_samples * n_features >= self.config.gpu_min_size {
                    // Try GPU-accelerated OLS
                    match self.solve_ols_gpu(&x_with_intercept, y) {
                        Ok(params) => params,
                        Err(_) => {
                            // Fallback to CPU if GPU fails
                            self.solve_ols_cpu(&x_with_intercept, y)?
                        }
                    }
                } else {
                    self.solve_ols_cpu(&x_with_intercept, y)?
                }

                // Exactly one of these two cfg-gated expressions survives
                // compilation and becomes the arm's value.
                #[cfg(not(feature = "gpu"))]
                self.solve_ols_cpu(&x_with_intercept, y)?
            }
            Penalty::L2(alpha) => {
                // Ridge regression, closed form:
                // (X^T X + αI) β = X^T y
                let xtx = x_with_intercept.t().dot(&x_with_intercept);
                let xty = x_with_intercept.t().dot(y);

                // Add regularization to diagonal (except intercept if present)
                let mut regularized = xtx.clone();
                let start_idx = if self.config.fit_intercept { 1 } else { 0 };
                for i in start_idx..n_params {
                    regularized[[i, i]] += alpha;
                }

                regularized.solve(&xty).map_err(|e| {
                    SklearsError::NumericalError(format!("Failed to solve ridge regression: {}", e))
                })?
            }
            Penalty::L1(alpha) => {
                // Lasso regression using coordinate descent
                #[cfg(feature = "coordinate-descent")]
                {
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    // The solver takes the raw `x` and handles the intercept
                    // itself (returned separately as an Option).
                    let (coef, intercept) = cd_solver
                        .solve_lasso(x, y, alpha, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Need to add intercept to beginning of params for consistency
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "L1 regularization (Lasso) requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                // ElasticNet regression using coordinate descent
                #[cfg(feature = "coordinate-descent")]
                {
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net(x, y, alpha, l1_ratio, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Need to add intercept to beginning of params for consistency
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "ElasticNet regularization requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
        };

        // Extract coefficients and intercept: when an intercept was fitted,
        // params[0] holds it and the remaining entries are coefficients.
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}
328
impl LinearRegression<Untrained> {
    /// CPU-based OLS solver: solves the normal equations `X^T X β = X^T y`.
    ///
    /// NOTE(review): forming `X^T X` squares the condition number of `x`;
    /// an SVD/QR-based least-squares solve would be more robust for
    /// ill-conditioned inputs — confirm whether that matters for callers.
    fn solve_ols_cpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        // Ordinary Least Squares using scirs2
        // X^T X β = X^T y
        let xtx = x.t().dot(x);
        let xty = x.t().dot(y);

        // Use scirs2's linear solver
        xtx.solve(&xty).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to solve linear system: {}", e))
        })
    }

    /// GPU-based OLS solver.
    ///
    /// Builds the normal-equation terms with GPU matrix ops and solves the
    /// system on the GPU. All errors (including "GPU not available") are
    /// surfaced to the caller, which falls back to the CPU path.
    #[cfg(feature = "gpu")]
    fn solve_ols_gpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        use crate::gpu_acceleration::{GpuConfig, GpuLinearOps};

        // Initialize GPU operations
        let gpu_config = GpuConfig {
            device_id: 0,
            use_pinned_memory: true,
            min_problem_size: self.config.gpu_min_size,
            ..Default::default()
        };

        let gpu_ops = GpuLinearOps::new(gpu_config).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to initialize GPU operations: {}", e))
        })?;

        // Check if GPU is available
        if !gpu_ops.is_gpu_available() {
            return Err(SklearsError::NumericalError(
                "GPU not available, falling back to CPU".to_string(),
            ));
        }

        // Compute X^T X using GPU
        let xt = gpu_ops.matrix_transpose(x)?;
        let xtx = gpu_ops.matrix_multiply(&xt, x)?;

        // Compute X^T y using GPU
        let xty = gpu_ops.matrix_vector_multiply(&xt, y)?;

        // Solve linear system using GPU
        gpu_ops.solve_linear_system(&xtx, &xty)
    }

    /// Fit the linear regression model with warm start.
    ///
    /// Uses the provided coefficients and intercept as initialization for
    /// the coordinate-descent solver. Only regularized penalties are
    /// supported: `Penalty::None` is rejected because OLS is solved
    /// directly and has nothing to warm-start. L1 and L2 are mapped onto
    /// the elastic-net solver with `l1_ratio` 1.0 and 0.0 respectively.
    ///
    /// # Errors
    /// Fails on inconsistent input lengths, on `Penalty::None`, when the
    /// `coordinate-descent` feature is disabled, or when the solver fails
    /// numerically.
    pub fn fit_with_warm_start(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<LinearRegression<Trained>> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        // For warm start, we only support ElasticNet/Lasso methods (coordinate descent)
        let params: Array1<Float> = match self.config.penalty {
            Penalty::L1(_)
            | Penalty::L2(_)
            | Penalty::ElasticNet {
                alpha: _,
                l1_ratio: _,
            } => {
                #[cfg(feature = "coordinate-descent")]
                {
                    // Collapse all three penalties to (alpha, l1_ratio) for
                    // the elastic-net solver.
                    let (alpha_val, l1_ratio) = match self.config.penalty {
                        Penalty::L1(alpha) => (alpha, 1.0),
                        Penalty::L2(alpha) => (alpha, 0.0),
                        Penalty::ElasticNet { alpha, l1_ratio } => (alpha, l1_ratio),
                        _ => unreachable!(),
                    };

                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net_with_warm_start(
                            x,
                            y,
                            alpha_val,
                            l1_ratio,
                            self.config.fit_intercept,
                            initial_coef,
                            initial_intercept,
                        )
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        // Need to add intercept to beginning of params for consistency
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason: "Warm start requires the 'coordinate-descent' feature".to_string(),
                    });
                }
            }
            Penalty::None => {
                return Err(SklearsError::InvalidParameter {
                    name: "penalty".to_string(),
                    reason:
                        "Warm start only supported for regularized methods (L1, L2, ElasticNet)"
                            .to_string(),
                });
            }
        };

        // Extract coefficients and intercept (params[0] is the intercept
        // when one was fitted).
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}
481
482impl LinearRegression<Trained> {
483    /// Get the coefficients
484    pub fn coef(&self) -> &Array1<Float> {
485        self.coef_.as_ref().expect("Model is trained")
486    }
487
488    /// Get the intercept
489    pub fn intercept(&self) -> Option<Float> {
490        self.intercept_
491    }
492}
493
494impl Predict<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
495    fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
496        let n_features = self.n_features_.expect("Model is trained");
497        validate::check_n_features(x, n_features)?;
498
499        let coef = self.coef_.as_ref().expect("Model is trained");
500        let mut predictions = x.dot(coef);
501
502        if let Some(intercept) = self.intercept_ {
503            predictions += intercept;
504        }
505
506        Ok(predictions)
507    }
508}
509
510impl Score<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
511    type Float = Float;
512
513    fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
514        let predictions = self.predict(x)?;
515
516        // Calculate R² score using scirs2 metrics
517        let ss_res = (&predictions - y).mapv(|x| x * x).sum();
518        let y_mean = y.mean().unwrap_or(0.0);
519        let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
520
521        if ss_tot == 0.0 {
522            return Ok(1.0);
523        }
524
525        Ok(1.0 - (ss_res / ss_tot))
526    }
527}
528
impl LinearRegression<Untrained> {
    /// Fit the linear regression model with early stopping based on validation metrics.
    ///
    /// This method is particularly useful for regularized methods (Lasso, ElasticNet)
    /// where early stopping can prevent overfitting. For `L2` and `None`
    /// penalties the direct solver is used instead and the returned
    /// `ValidationInfo` is a placeholder (dummy score of 1.0).
    ///
    /// # Errors
    /// Fails on inconsistent input lengths or when the underlying solver
    /// fails.
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        // Validate inputs
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        // Early stopping is most beneficial for regularized methods
        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping(x, y, alpha, self.config.fit_intercept)?;

                // Discard any intercept the solver produced when the caller
                // asked for none.
                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping(
                        x,
                        y,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // For Ridge regression, we could use an iterative solver with
                // early stopping. For now, fall back to regular fit and
                // provide minimal (dummy) validation info.
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0], // Dummy score
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // For OLS, early stopping doesn't make much sense since it's
                // a direct solution; the validation info is a placeholder.
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0], // Dummy score
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
        }
    }

    /// Fit the linear regression model with early stopping using pre-split validation data.
    ///
    /// This gives you more control over the train/validation split compared to
    /// [`Self::fit_with_early_stopping`] which automatically splits the data.
    /// For `L2` and `None` penalties, the model is fitted directly and a
    /// single validation R² score is computed on the held-out set.
    ///
    /// # Errors
    /// Fails on inconsistent input lengths, when `x_val` has a different
    /// number of features than `x_train`, or when the solver fails.
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping_split(
        self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        // Validate inputs
        validate::check_consistent_length(x_train, y_train)?;
        validate::check_consistent_length(x_val, y_val)?;

        // Train and validation sets must agree on the feature count.
        let n_features = x_train.ncols();
        if x_val.ncols() != n_features {
            return Err(SklearsError::FeatureMismatch {
                expected: n_features,
                actual: x_val.ncols(),
            });
        }

        // Early stopping is most beneficial for regularized methods
        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        self.config.fit_intercept,
                    )?;

                // Discard any intercept when the caller asked for none.
                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // For Ridge regression, compute the validation score manually.
                // NOTE(review): rebuilding the model here keeps only `penalty`
                // and `fit_intercept`; `solver`, `max_iter`, `tol` and
                // `warm_start` from `self.config` are dropped. The direct
                // ridge solve in `fit` doesn't read them, but confirm before
                // relying on those fields in this path.
                let fitted_model = LinearRegression::new()
                    .penalty(self.config.penalty)
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                // Compute validation R² score
                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // For OLS, compute the validation score manually.
                // NOTE(review): as in the L2 arm, only `fit_intercept` is
                // carried over from `self.config` when rebuilding here.
                let fitted_model = LinearRegression::new()
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                // Compute validation R² score
                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
        }
    }
}
776
777#[allow(non_snake_case)]
778#[cfg(test)]
779mod tests {
780    use super::*;
781    use approx::assert_abs_diff_eq;
782    use scirs2_core::ndarray::array;
783
784    #[test]
785    fn test_linear_regression_simple() {
786        let x = array![[1.0], [2.0], [3.0], [4.0]];
787        let y = array![2.0, 4.0, 6.0, 8.0];
788
789        let model = LinearRegression::new()
790            .fit_intercept(false)
791            .fit(&x, &y)
792            .expect("operation should succeed");
793
794        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);
795
796        let predictions = model
797            .predict(&array![[5.0]])
798            .expect("prediction should succeed");
799        assert_abs_diff_eq!(predictions[0], 10.0, epsilon = 1e-10);
800    }
801
802    #[test]
803    fn test_linear_regression_with_intercept() {
804        let x = array![[1.0], [2.0], [3.0], [4.0]];
805        let y = array![3.0, 5.0, 7.0, 9.0]; // y = 2x + 1
806
807        let model = LinearRegression::new()
808            .fit_intercept(true)
809            .fit(&x, &y)
810            .expect("operation should succeed");
811
812        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);
813        assert_abs_diff_eq!(
814            model.intercept().expect("intercept should be available"),
815            1.0,
816            epsilon = 1e-10
817        );
818    }
819
820    #[test]
821    fn test_ridge_regression() {
822        let x = array![[1.0], [2.0], [3.0], [4.0]];
823        let y = array![2.0, 4.0, 6.0, 8.0];
824
825        let model = LinearRegression::new()
826            .fit_intercept(false)
827            .regularization(0.1)
828            .fit(&x, &y)
829            .expect("operation should succeed");
830
831        // With regularization, coefficient should be slightly less than 2.0
832        assert!(model.coef()[0] < 2.0);
833        assert!(model.coef()[0] > 1.9);
834    }
835
836    #[test]
837    fn test_lasso_regression() {
838        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
839        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
840
841        // Test with small alpha
842        let model = LinearRegression::lasso(0.01)
843            .fit_intercept(false)
844            .fit(&x, &y)
845            .expect("operation should succeed");
846
847        // Should be close to OLS solution (coef = 2.0)
848        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 0.1);
849
850        // Test with larger alpha
851        let model = LinearRegression::lasso(0.5)
852            .fit_intercept(false)
853            .fit(&x, &y)
854            .expect("operation should succeed");
855
856        // Coefficient should be shrunk
857        assert!(model.coef()[0] < 2.0);
858        assert!(model.coef()[0] > 1.0);
859    }
860
861    #[test]
862    fn test_elastic_net_regression() {
863        let x = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
864        let y = array![3.0, 6.0, 9.0, 12.0]; // y = 2*x1 + 2*x2
865
866        let model = LinearRegression::elastic_net(0.1, 0.5)
867            .fit_intercept(false)
868            .fit(&x, &y)
869            .expect("operation should succeed");
870
871        // Both coefficients should be shrunk but non-zero
872        println!(
873            "ElasticNet coef[0] = {}, coef[1] = {}",
874            model.coef()[0],
875            model.coef()[1]
876        );
877        assert!(model.coef()[0] > 0.0);
878        assert!(model.coef()[0] < 3.0); // More lenient bound for weak regularization
879        assert!(model.coef()[1] > 0.0);
880        assert!(model.coef()[1] < 3.0); // More lenient bound for weak regularization
881    }
882
883    #[test]
884    fn test_lasso_sparsity() {
885        // Create data where only first feature is relevant
886        let n_samples = 20;
887        let mut x = Array2::zeros((n_samples, 5));
888        let mut y = Array1::zeros(n_samples);
889
890        for i in 0..n_samples {
891            x[[i, 0]] = i as f64;
892            x[[i, 1]] = (i as f64) * 0.1; // weak feature
893                                          // Add deterministic noise instead of random
894            x[[i, 2]] = ((i * 7) % 10) as f64 / 10.0; // pseudo-random noise
895            x[[i, 3]] = ((i * 13) % 10) as f64 / 10.0; // pseudo-random noise
896            x[[i, 4]] = ((i * 17) % 10) as f64 / 10.0; // pseudo-random noise
897            y[i] = 2.0 * x[[i, 0]] + 0.05 * (i % 3) as f64;
898        }
899
900        // With strong L1 penalty, should select only the first feature
901        let model = LinearRegression::lasso(1.0)
902            .fit_intercept(false)
903            .fit(&x, &y)
904            .expect("operation should succeed");
905
906        let coef = model.coef();
907
908        // First coefficient should be non-zero
909        assert!(coef[0] > 0.5);
910
911        // Other coefficients should be zero or very small
912        for i in 2..5 {
913            assert_abs_diff_eq!(coef[i], 0.0, epsilon = 0.01);
914        }
915    }
916
917    #[test]
918    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
919    fn test_linear_regression_early_stopping_lasso() {
920        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
921
922        // Create larger dataset for meaningful validation split
923        let n_samples = 100;
924        let n_features = 8;
925        let mut x = Array2::zeros((n_samples, n_features));
926        let mut y = Array1::zeros(n_samples);
927
928        // Generate synthetic data with linear relationship
929        for i in 0..n_samples {
930            for j in 0..n_features {
931                x[[i, j]] = (i * j + 1) as f64 / 20.0;
932            }
933            // Only first few features are relevant
934            y[i] = 2.0 * x[[i, 0]] + 1.5 * x[[i, 1]] + 0.8 * x[[i, 2]] + 0.1 * (i as f64 % 5.0);
935        }
936
937        let early_stopping_config = EarlyStoppingConfig {
938            criterion: StoppingCriterion::Patience(10),
939            validation_split: 0.25,
940            shuffle: true,
941            random_state: Some(42),
942            higher_is_better: true,
943            min_iterations: 5,
944            restore_best_weights: true,
945        };
946
947        let model = LinearRegression::lasso(0.1);
948        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);
949
950        assert!(result.is_ok());
951        let (fitted_model, validation_info) = result.expect("operation should succeed");
952
953        // Check model properties
954        assert_eq!(fitted_model.coef().len(), n_features);
955        assert!(fitted_model.intercept().is_some());
956
957        // Check validation info
958        assert!(!validation_info.validation_scores.is_empty());
959        assert!(validation_info.best_score.is_some());
960        assert!(validation_info.best_iteration >= 1);
961
962        // Predictions should work
963        let predictions = fitted_model.predict(&x);
964        assert!(predictions.is_ok());
965        assert_eq!(
966            predictions.expect("operation should succeed").len(),
967            n_samples
968        );
969    }
970
971    #[test]
972    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
973    fn test_linear_regression_early_stopping_elastic_net() {
974        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
975
976        let x = array![
977            [1.0, 2.0, 0.5],
978            [2.0, 3.0, 1.0],
979            [3.0, 4.0, 1.5],
980            [4.0, 5.0, 2.0],
981            [5.0, 6.0, 2.5],
982            [6.0, 7.0, 3.0],
983            [7.0, 8.0, 3.5],
984            [8.0, 9.0, 4.0]
985        ];
986        let y = array![4.5, 7.0, 9.5, 12.0, 14.5, 17.0, 19.5, 22.0]; // y ≈ 1.5*x1 + x2 + x3
987
988        let early_stopping_config = EarlyStoppingConfig {
989            criterion: StoppingCriterion::TolerancePatience {
990                tolerance: 0.005,
991                patience: 3,
992            },
993            validation_split: 0.25,
994            shuffle: false,
995            random_state: Some(123),
996            higher_is_better: true,
997            min_iterations: 2,
998            restore_best_weights: true,
999        };
1000
1001        let model = LinearRegression::elastic_net(0.1, 0.7);
1002        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);
1003
1004        assert!(result.is_ok());
1005        let (fitted_model, validation_info) = result.expect("operation should succeed");
1006
1007        assert_eq!(fitted_model.coef().len(), 3);
1008        assert!(fitted_model.intercept().is_some());
1009        assert!(!validation_info.validation_scores.is_empty());
1010        assert!(validation_info.best_score.is_some());
1011    }
1012
1013    #[test]
1014    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
1015    fn test_linear_regression_early_stopping_with_split() {
1016        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
1017
1018        // Training data
1019        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
1020        let y_train = array![5.0, 8.0, 11.0, 14.0, 17.0]; // y = 2*x1 + x2
1021
1022        // Validation data
1023        let x_val = array![[6.0, 7.0], [7.0, 8.0]];
1024        let y_val = array![20.0, 23.0];
1025
1026        let early_stopping_config = EarlyStoppingConfig {
1027            criterion: StoppingCriterion::TargetScore(0.9),
1028            validation_split: 0.2, // Ignored since we provide split data
1029            shuffle: false,
1030            random_state: None,
1031            higher_is_better: true,
1032            min_iterations: 1,
1033            restore_best_weights: false,
1034        };
1035
1036        let model = LinearRegression::lasso(0.01);
1037        let result = model.fit_with_early_stopping_split(
1038            &x_train,
1039            &y_train,
1040            &x_val,
1041            &y_val,
1042            early_stopping_config,
1043        );
1044
1045        assert!(result.is_ok());
1046        let (fitted_model, validation_info) = result.expect("operation should succeed");
1047
1048        assert_eq!(fitted_model.coef().len(), 2);
1049        assert!(fitted_model.intercept().is_some());
1050        assert!(!validation_info.validation_scores.is_empty());
1051
1052        // Coefficients should be close to true values [2, 1] with small regularization
1053        let coef = fitted_model.coef();
1054        assert!((coef[0] - 2.0).abs() < 0.5);
1055        assert!((coef[1] - 1.0).abs() < 0.5);
1056    }
1057
1058    #[test]
1059    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
1060    fn test_linear_regression_early_stopping_ols() {
1061        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
1062
1063        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
1064        let y = array![3.0, 5.0, 7.0, 9.0, 11.0, 13.0]; // y = 2*x + 1
1065
1066        let early_stopping_config = EarlyStoppingConfig {
1067            criterion: StoppingCriterion::Patience(5),
1068            validation_split: 0.33,
1069            shuffle: false,
1070            random_state: None,
1071            higher_is_better: true,
1072            min_iterations: 1,
1073            restore_best_weights: true,
1074        };
1075
1076        // For OLS (no penalty), early stopping returns dummy validation info
1077        let model = LinearRegression::new().fit_intercept(true);
1078        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);
1079
1080        assert!(result.is_ok());
1081        let (fitted_model, validation_info) = result.expect("operation should succeed");
1082
1083        assert_eq!(fitted_model.coef().len(), 1);
1084        assert!(fitted_model.intercept().is_some());
1085
1086        // For OLS, validation info indicates no early stopping occurred
1087        assert!(!validation_info.stopped_early);
1088        assert!(validation_info.converged);
1089        assert_eq!(validation_info.best_iteration, 1);
1090
1091        // Model should still work correctly
1092        assert_abs_diff_eq!(fitted_model.coef()[0], 2.0, epsilon = 1e-10);
1093        assert_abs_diff_eq!(
1094            fitted_model
1095                .intercept()
1096                .expect("intercept should be available"),
1097            1.0,
1098            epsilon = 1e-10
1099        );
1100    }
1101
1102    #[test]
1103    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
1104    fn test_linear_regression_early_stopping_ridge() {
1105        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};
1106
1107        let x = array![
1108            [1.0, 0.5],
1109            [2.0, 1.0],
1110            [3.0, 1.5],
1111            [4.0, 2.0],
1112            [5.0, 2.5],
1113            [6.0, 3.0]
1114        ];
1115        let y = array![2.5, 4.0, 5.5, 7.0, 8.5, 10.0]; // y ≈ 1.5*x1 + x2
1116
1117        let early_stopping_config = EarlyStoppingConfig {
1118            criterion: StoppingCriterion::Patience(3),
1119            validation_split: 0.33,
1120            shuffle: true,
1121            random_state: Some(456),
1122            higher_is_better: true,
1123            min_iterations: 1,
1124            restore_best_weights: false,
1125        };
1126
1127        // For Ridge regression, early stopping currently returns dummy validation info
1128        let model = LinearRegression::new()
1129            .regularization(0.1)
1130            .fit_intercept(true);
1131        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);
1132
1133        assert!(result.is_ok());
1134        let (fitted_model, validation_info) = result.expect("operation should succeed");
1135
1136        assert_eq!(fitted_model.coef().len(), 2);
1137        assert!(fitted_model.intercept().is_some());
1138
1139        // For Ridge, early stopping is not fully implemented yet, so it should indicate convergence
1140        assert!(!validation_info.stopped_early);
1141        assert!(validation_info.converged);
1142    }
1143}