1use std::marker::PhantomData;
4
5use scirs2_core::ndarray::{s, Array};
6use scirs2_linalg::compat::ArrayLinalgExt;
7use sklears_core::{
9 error::{validate, Result, SklearsError},
10 traits::{Estimator, Fit, Predict, Score, Trained, Untrained},
11 types::{Array1, Array2, Float},
12};
13
14use crate::{Penalty, Solver};
15
16#[cfg(feature = "coordinate-descent")]
17use crate::coordinate_descent::CoordinateDescentSolver;
18
19#[cfg(feature = "coordinate-descent")]
20use crate::coordinate_descent::ValidationInfo;
21
22#[cfg(feature = "early-stopping")]
23use crate::early_stopping::EarlyStoppingConfig;
24
/// Hyper-parameters controlling how a `LinearRegression` model is fitted.
#[derive(Debug, Clone)]
pub struct LinearRegressionConfig {
    /// Whether to estimate an intercept (bias) term in addition to the coefficients.
    pub fit_intercept: bool,
    /// Regularization applied to the coefficients (none, L1, L2, or elastic net).
    pub penalty: Penalty,
    /// Optimization backend used to solve the fitting problem.
    pub solver: Solver,
    /// Maximum number of iterations for iterative solvers (e.g. coordinate descent).
    pub max_iter: usize,
    /// Convergence tolerance for iterative solvers.
    pub tol: f64,
    /// Whether to reuse a previous solution as the starting point when refitting.
    pub warm_start: bool,
    /// Whether to attempt GPU acceleration for sufficiently large problems.
    #[cfg(feature = "gpu")]
    pub use_gpu: bool,
    /// Minimum problem size (`n_samples * n_features`) before the GPU path is tried.
    #[cfg(feature = "gpu")]
    pub gpu_min_size: usize,
}
47
impl Default for LinearRegressionConfig {
    /// Defaults to ordinary least squares: no penalty, automatic solver
    /// selection, intercept fitting enabled, 1000 iterations max, tol 1e-4.
    fn default() -> Self {
        Self {
            fit_intercept: true,
            penalty: Penalty::None,
            solver: Solver::Auto,
            max_iter: 1000,
            tol: 1e-4,
            warm_start: false,
            // GPU use is opportunistic: on by default, but only engaged when
            // the problem reaches `gpu_min_size` elements (checked in `fit`).
            #[cfg(feature = "gpu")]
            use_gpu: true,
            #[cfg(feature = "gpu")]
            gpu_min_size: 1000,
        }
    }
}
64
/// Linear regression estimator using the typestate pattern: the `State`
/// parameter is `Untrained` until a successful `fit`, which returns a
/// `LinearRegression<Trained>` with populated coefficients.
#[derive(Debug, Clone)]
pub struct LinearRegression<State = Untrained> {
    config: LinearRegressionConfig,
    // Zero-sized marker carrying the compile-time trained/untrained state.
    state: PhantomData<State>,
    // Learned coefficients; `None` until the model has been fitted.
    coef_: Option<Array1<Float>>,
    // Learned intercept; `None` when unfitted or when `fit_intercept` is false.
    intercept_: Option<Float>,
    // Number of features seen during fit; used to validate predict inputs.
    n_features_: Option<usize>,
}
75
76impl LinearRegression<Untrained> {
77 pub fn new() -> Self {
79 Self {
80 config: LinearRegressionConfig::default(),
81 state: PhantomData,
82 coef_: None,
83 intercept_: None,
84 n_features_: None,
85 }
86 }
87
88 pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
90 self.config.fit_intercept = fit_intercept;
91 self
92 }
93
94 pub fn regularization(mut self, alpha: f64) -> Self {
96 self.config.penalty = Penalty::L2(alpha);
97 self
98 }
99
100 pub fn lasso(alpha: f64) -> Self {
102 Self::new()
103 .penalty(Penalty::L1(alpha))
104 .solver(Solver::CoordinateDescent)
105 }
106
107 pub fn elastic_net(alpha: f64, l1_ratio: f64) -> Self {
109 Self::new()
110 .penalty(Penalty::ElasticNet { l1_ratio, alpha })
111 .solver(Solver::CoordinateDescent)
112 }
113
114 pub fn penalty(mut self, penalty: Penalty) -> Self {
116 self.config.penalty = penalty;
117 self
118 }
119
120 pub fn solver(mut self, solver: Solver) -> Self {
122 self.config.solver = solver;
123 self
124 }
125
126 pub fn max_iter(mut self, max_iter: usize) -> Self {
128 self.config.max_iter = max_iter;
129 self
130 }
131
132 pub fn warm_start(mut self, warm_start: bool) -> Self {
134 self.config.warm_start = warm_start;
135 self
136 }
137
138 #[cfg(feature = "gpu")]
140 pub fn use_gpu(mut self, use_gpu: bool) -> Self {
141 self.config.use_gpu = use_gpu;
142 self
143 }
144
145 #[cfg(feature = "gpu")]
147 pub fn gpu_min_size(mut self, min_size: usize) -> Self {
148 self.config.gpu_min_size = min_size;
149 self
150 }
151}
152
impl Default for LinearRegression<Untrained> {
    /// Equivalent to `LinearRegression::new()`.
    fn default() -> Self {
        Self::new()
    }
}
158
// Ties this estimator into the shared `Estimator` trait from `sklears_core`,
// declaring its configuration, error, and scalar types.
impl Estimator for LinearRegression<Untrained> {
    type Config = LinearRegressionConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Returns the current (builder-time) configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
168
impl Fit<Array2<Float>, Array1<Float>> for LinearRegression<Untrained> {
    type Fitted = LinearRegression<Trained>;

    /// Fits the model to `x` (n_samples x n_features) and targets `y`,
    /// consuming `self` and returning a `LinearRegression<Trained>`.
    ///
    /// # Errors
    /// Returns an error when `x` and `y` have inconsistent lengths, when the
    /// underlying linear solve fails, or when an L1/ElasticNet penalty is
    /// requested without the `coordinate-descent` feature enabled.
    fn fit(self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Self::Fitted> {
        validate::check_consistent_length(x, y)?;

        let n_samples = x.nrows();
        let n_features = x.ncols();

        // When fitting an intercept, prepend a column of ones so the intercept
        // becomes the first entry of the packed parameter vector.
        let (x_with_intercept, n_params) = if self.config.fit_intercept {
            let mut x_new = Array::ones((n_samples, n_features + 1));
            x_new.slice_mut(s![.., 1..]).assign(x);
            (x_new, n_features + 1)
        } else {
            (x.clone(), n_features)
        };

        let params = match self.config.penalty {
            Penalty::None => {
                // Ordinary least squares. The GPU path is only attempted for
                // large problems and silently falls back to the CPU solver on
                // any GPU-side error.
                #[cfg(feature = "gpu")]
                if self.config.use_gpu && n_samples * n_features >= self.config.gpu_min_size {
                    match self.solve_ols_gpu(&x_with_intercept, y) {
                        Ok(params) => params,
                        Err(_) => {
                            self.solve_ols_cpu(&x_with_intercept, y)?
                        }
                    }
                } else {
                    self.solve_ols_cpu(&x_with_intercept, y)?
                }

                #[cfg(not(feature = "gpu"))]
                self.solve_ols_cpu(&x_with_intercept, y)?
            }
            Penalty::L2(alpha) => {
                // Ridge regression in closed form: solve
                // (X^T X + alpha * I) w = X^T y.
                let xtx = x_with_intercept.t().dot(&x_with_intercept);
                let xty = x_with_intercept.t().dot(y);

                let mut regularized = xtx.clone();
                // Skip index 0 when an intercept column is present: the
                // intercept is conventionally left unpenalized.
                let start_idx = if self.config.fit_intercept { 1 } else { 0 };
                for i in start_idx..n_params {
                    regularized[[i, i]] += alpha;
                }

                regularized.solve(&xty).map_err(|e| {
                    SklearsError::NumericalError(format!("Failed to solve ridge regression: {}", e))
                })?
            }
            Penalty::L1(alpha) => {
                #[cfg(feature = "coordinate-descent")]
                {
                    // Lasso has no closed form; use cyclic coordinate descent
                    // on the raw `x` (the solver handles the intercept itself).
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_lasso(x, y, alpha, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    // Re-pack into the unified layout used below:
                    // [intercept, coef...] when an intercept was fitted.
                    if self.config.fit_intercept {
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "L1 regularization (Lasso) requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                #[cfg(feature = "coordinate-descent")]
                {
                    // Same coordinate-descent machinery as the L1 branch, with
                    // a mixed L1/L2 penalty.
                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net(x, y, alpha, l1_ratio, self.config.fit_intercept)
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    if self.config.fit_intercept {
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason:
                            "ElasticNet regularization requires the 'coordinate-descent' feature"
                                .to_string(),
                    });
                }
            }
        };

        // Split the packed parameter vector back into (coefficients, intercept).
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}
328
impl LinearRegression<Untrained> {
    /// Solves ordinary least squares on the CPU via the normal equations:
    /// (X^T X) w = X^T y.
    fn solve_ols_cpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        let xtx = x.t().dot(x);
        let xty = x.t().dot(y);

        xtx.solve(&xty).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to solve linear system: {}", e))
        })
    }

    /// Solves ordinary least squares with GPU-accelerated linear algebra.
    ///
    /// # Errors
    /// Returns an error when GPU initialization fails or no GPU is available;
    /// the caller (`fit`) treats any error as a signal to fall back to CPU.
    #[cfg(feature = "gpu")]
    fn solve_ols_gpu(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
        use crate::gpu_acceleration::{GpuConfig, GpuLinearOps};

        let gpu_config = GpuConfig {
            device_id: 0,
            use_pinned_memory: true,
            min_problem_size: self.config.gpu_min_size,
            ..Default::default()
        };

        let gpu_ops = GpuLinearOps::new(gpu_config).map_err(|e| {
            SklearsError::NumericalError(format!("Failed to initialize GPU operations: {}", e))
        })?;

        if !gpu_ops.is_gpu_available() {
            return Err(SklearsError::NumericalError(
                "GPU not available, falling back to CPU".to_string(),
            ));
        }

        // Normal equations assembled with GPU kernels: solve (X^T X) w = X^T y.
        let xt = gpu_ops.matrix_transpose(x)?;
        let xtx = gpu_ops.matrix_multiply(&xt, x)?;

        let xty = gpu_ops.matrix_vector_multiply(&xt, y)?;

        gpu_ops.solve_linear_system(&xtx, &xty)
    }

    /// Fits a regularized model, optionally seeding coordinate descent with
    /// `initial_coef` / `initial_intercept` (warm start).
    ///
    /// # Errors
    /// Returns `InvalidParameter` for `Penalty::None` (warm start only makes
    /// sense for iterative, regularized solvers) or when the
    /// `coordinate-descent` feature is disabled.
    pub fn fit_with_warm_start(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        initial_coef: Option<&Array1<Float>>,
        initial_intercept: Option<Float>,
    ) -> Result<LinearRegression<Trained>> {
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        let params: Array1<Float> = match self.config.penalty {
            Penalty::L1(_)
            | Penalty::L2(_)
            | Penalty::ElasticNet {
                alpha: _,
                l1_ratio: _,
            } => {
                #[cfg(feature = "coordinate-descent")]
                {
                    // Map every penalty onto the elastic-net formulation:
                    // L1 is l1_ratio = 1.0, L2 is l1_ratio = 0.0. Note that the
                    // L2 case is therefore solved iteratively here, unlike the
                    // closed-form ridge solve in `fit`.
                    let (alpha_val, l1_ratio) = match self.config.penalty {
                        Penalty::L1(alpha) => (alpha, 1.0),
                        Penalty::L2(alpha) => (alpha, 0.0),
                        Penalty::ElasticNet { alpha, l1_ratio } => (alpha, l1_ratio),
                        _ => unreachable!(),
                    };

                    let cd_solver = CoordinateDescentSolver {
                        max_iter: self.config.max_iter,
                        tol: self.config.tol,
                        cyclic: true,
                        #[cfg(feature = "early-stopping")]
                        early_stopping_config: None,
                    };

                    let (coef, intercept) = cd_solver
                        .solve_elastic_net_with_warm_start(
                            x,
                            y,
                            alpha_val,
                            l1_ratio,
                            self.config.fit_intercept,
                            initial_coef,
                            initial_intercept,
                        )
                        .map_err(|e| {
                            SklearsError::NumericalError(format!(
                                "Coordinate descent failed: {}",
                                e
                            ))
                        })?;

                    // Pack as [intercept, coef...] to match `fit`'s layout.
                    if self.config.fit_intercept {
                        let mut params = Array::zeros(coef.len() + 1);
                        params[0] = intercept.unwrap_or(0.0);
                        params.slice_mut(s![1..]).assign(&coef);
                        params
                    } else {
                        coef
                    }
                }
                #[cfg(not(feature = "coordinate-descent"))]
                {
                    return Err(SklearsError::InvalidParameter {
                        name: "penalty".to_string(),
                        reason: "Warm start requires the 'coordinate-descent' feature".to_string(),
                    });
                }
            }
            Penalty::None => {
                return Err(SklearsError::InvalidParameter {
                    name: "penalty".to_string(),
                    reason:
                        "Warm start only supported for regularized methods (L1, L2, ElasticNet)"
                            .to_string(),
                });
            }
        };

        // Split the packed parameter vector back into (coefficients, intercept).
        let (coef_, intercept_) = if self.config.fit_intercept {
            let intercept = params[0];
            let coef = params.slice(s![1..]).to_owned();
            (coef, Some(intercept))
        } else {
            (params, None)
        };

        Ok(LinearRegression {
            config: self.config,
            state: PhantomData,
            coef_: Some(coef_),
            intercept_,
            n_features_: Some(n_features),
        })
    }
}
481
impl LinearRegression<Trained> {
    /// Returns the learned coefficients (one per input feature).
    pub fn coef(&self) -> &Array1<Float> {
        // Safe by construction: the `Trained` typestate is only produced by a
        // successful fit, which always populates `coef_`.
        self.coef_.as_ref().expect("Model is trained")
    }

    /// Returns the learned intercept, or `None` when the model was fitted
    /// with `fit_intercept(false)`.
    pub fn intercept(&self) -> Option<Float> {
        self.intercept_
    }
}
493
494impl Predict<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
495 fn predict(&self, x: &Array2<Float>) -> Result<Array1<Float>> {
496 let n_features = self.n_features_.expect("Model is trained");
497 validate::check_n_features(x, n_features)?;
498
499 let coef = self.coef_.as_ref().expect("Model is trained");
500 let mut predictions = x.dot(coef);
501
502 if let Some(intercept) = self.intercept_ {
503 predictions += intercept;
504 }
505
506 Ok(predictions)
507 }
508}
509
510impl Score<Array2<Float>, Array1<Float>> for LinearRegression<Trained> {
511 type Float = Float;
512
513 fn score(&self, x: &Array2<Float>, y: &Array1<Float>) -> Result<f64> {
514 let predictions = self.predict(x)?;
515
516 let ss_res = (&predictions - y).mapv(|x| x * x).sum();
518 let y_mean = y.mean().unwrap_or(0.0);
519 let ss_tot = y.mapv(|yi| (yi - y_mean).powi(2)).sum();
520
521 if ss_tot == 0.0 {
522 return Ok(1.0);
523 }
524
525 Ok(1.0 - (ss_res / ss_tot))
526 }
527}
528
impl LinearRegression<Untrained> {
    /// Fits with early stopping, letting the coordinate-descent solver carve a
    /// validation split out of `(x, y)` per `early_stopping_config`.
    ///
    /// L1/ElasticNet penalties run coordinate descent with early stopping.
    /// L2/None are solved in closed form by `fit`, so early stopping does not
    /// apply and a synthetic single-iteration `ValidationInfo` (score 1.0) is
    /// returned for those penalties.
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping(
        self,
        x: &Array2<Float>,
        y: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        validate::check_consistent_length(x, y)?;

        let n_features = x.ncols();

        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping(x, y, alpha, self.config.fit_intercept)?;

                // Drop any intercept the solver produced if none was requested.
                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping(
                        x,
                        y,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // Closed-form ridge: fit normally and report a trivial,
                // converged single-iteration ValidationInfo.
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0],
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                // Same as the L2 case: OLS has no iterations to stop early.
                let fitted_model = self.fit(x, y)?;
                let validation_info = ValidationInfo {
                    validation_scores: vec![1.0],
                    best_score: Some(1.0),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };
                Ok((fitted_model, validation_info))
            }
        }
    }

    /// Fits with early stopping against a caller-provided validation set
    /// instead of splitting internally.
    ///
    /// # Errors
    /// Returns `FeatureMismatch` when the train and validation matrices have a
    /// different number of columns, plus the same errors as `fit`.
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    pub fn fit_with_early_stopping_split(
        self,
        x_train: &Array2<Float>,
        y_train: &Array1<Float>,
        x_val: &Array2<Float>,
        y_val: &Array1<Float>,
        early_stopping_config: EarlyStoppingConfig,
    ) -> Result<(LinearRegression<Trained>, ValidationInfo)> {
        validate::check_consistent_length(x_train, y_train)?;
        validate::check_consistent_length(x_val, y_val)?;

        let n_features = x_train.ncols();
        if x_val.ncols() != n_features {
            return Err(SklearsError::FeatureMismatch {
                expected: n_features,
                actual: x_val.ncols(),
            });
        }

        match self.config.penalty {
            Penalty::L1(alpha) => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_lasso_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::ElasticNet { l1_ratio, alpha } => {
                let cd_solver = CoordinateDescentSolver {
                    max_iter: self.config.max_iter,
                    tol: self.config.tol,
                    cyclic: true,
                    early_stopping_config: Some(early_stopping_config),
                };

                let (coef, intercept, validation_info) = cd_solver
                    .solve_elastic_net_with_early_stopping_split(
                        x_train,
                        y_train,
                        x_val,
                        y_val,
                        alpha,
                        l1_ratio,
                        self.config.fit_intercept,
                    )?;

                let intercept_ = if self.config.fit_intercept {
                    intercept
                } else {
                    None
                };

                let fitted_model = LinearRegression {
                    config: self.config,
                    state: PhantomData,
                    coef_: Some(coef),
                    intercept_,
                    n_features_: Some(n_features),
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::L2(_alpha) => {
                // Closed-form path: refit through the public builder, then
                // score on the held-out set to populate a real R^2 in the
                // ValidationInfo (unlike the internal-split variant above).
                let fitted_model = LinearRegression::new()
                    .penalty(self.config.penalty)
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
            Penalty::None => {
                let fitted_model = LinearRegression::new()
                    .fit_intercept(self.config.fit_intercept)
                    .fit(x_train, y_train)?;

                let val_predictions = fitted_model.predict(x_val)?;
                let r2_score = crate::coordinate_descent::compute_r2_score(&val_predictions, y_val);

                let validation_info = ValidationInfo {
                    validation_scores: vec![r2_score],
                    best_score: Some(r2_score),
                    best_iteration: 1,
                    stopped_early: false,
                    converged: true,
                };

                Ok((fitted_model, validation_info))
            }
        }
    }
}
776
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// OLS without intercept on perfectly linear data recovers the slope
    /// exactly and extrapolates correctly.
    #[test]
    fn test_linear_regression_simple() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::new()
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);

        let predictions = model
            .predict(&array![[5.0]])
            .expect("prediction should succeed");
        assert_abs_diff_eq!(predictions[0], 10.0, epsilon = 1e-10);
    }

    /// Data follows y = 2x + 1: both slope and intercept are recovered.
    #[test]
    fn test_linear_regression_with_intercept() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![3.0, 5.0, 7.0, 9.0];
        let model = LinearRegression::new()
            .fit_intercept(true)
            .fit(&x, &y)
            .expect("operation should succeed");

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(
            model.intercept().expect("intercept should be available"),
            1.0,
            epsilon = 1e-10
        );
    }

    /// L2 regularization should shrink the coefficient slightly below the
    /// unregularized value of 2.0 without destroying the fit.
    #[test]
    fn test_ridge_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0]];
        let y = array![2.0, 4.0, 6.0, 8.0];

        let model = LinearRegression::new()
            .fit_intercept(false)
            .regularization(0.1)
            .fit(&x, &y)
            .expect("operation should succeed");

        assert!(model.coef()[0] < 2.0);
        assert!(model.coef()[0] > 1.9);
    }

    /// Lasso: a tiny penalty leaves the slope near 2.0, a large penalty
    /// shrinks it noticeably but not to zero.
    #[test]
    fn test_lasso_regression() {
        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0]];
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let model = LinearRegression::lasso(0.01)
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        assert_abs_diff_eq!(model.coef()[0], 2.0, epsilon = 0.1);

        let model = LinearRegression::lasso(0.5)
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        assert!(model.coef()[0] < 2.0);
        assert!(model.coef()[0] > 1.0);
    }

    /// Elastic net on collinear features: the weight should be spread across
    /// both correlated columns, each staying strictly positive.
    #[test]
    fn test_elastic_net_regression() {
        let x = array![[1.0, 0.5], [2.0, 1.0], [3.0, 1.5], [4.0, 2.0]];
        let y = array![3.0, 6.0, 9.0, 12.0];
        let model = LinearRegression::elastic_net(0.1, 0.5)
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        println!(
            "ElasticNet coef[0] = {}, coef[1] = {}",
            model.coef()[0],
            model.coef()[1]
        );
        assert!(model.coef()[0] > 0.0);
        assert!(model.coef()[0] < 3.0);
        assert!(model.coef()[1] > 0.0);
        assert!(model.coef()[1] < 3.0);
    }

    /// A strong Lasso penalty should zero out the uninformative features
    /// (columns 2-4) while keeping the truly predictive one (column 0).
    #[test]
    fn test_lasso_sparsity() {
        let n_samples = 20;
        let mut x = Array2::zeros((n_samples, 5));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            // Feature 0 drives the target; the remaining columns are
            // pseudo-random patterns uncorrelated with y.
            x[[i, 0]] = i as f64;
            x[[i, 1]] = (i as f64) * 0.1;
            x[[i, 2]] = ((i * 7) % 10) as f64 / 10.0;
            x[[i, 3]] = ((i * 13) % 10) as f64 / 10.0;
            x[[i, 4]] = ((i * 17) % 10) as f64 / 10.0;
            y[i] = 2.0 * x[[i, 0]] + 0.05 * (i % 3) as f64;
        }

        let model = LinearRegression::lasso(1.0)
            .fit_intercept(false)
            .fit(&x, &y)
            .expect("operation should succeed");

        let coef = model.coef();

        assert!(coef[0] > 0.5);

        for i in 2..5 {
            assert_abs_diff_eq!(coef[i], 0.0, epsilon = 0.01);
        }
    }

    /// Lasso with early stopping on synthetic data: the returned model and
    /// ValidationInfo should both be well-formed and usable for prediction.
    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_lasso() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let n_samples = 100;
        let n_features = 8;
        let mut x = Array2::zeros((n_samples, n_features));
        let mut y = Array1::zeros(n_samples);

        for i in 0..n_samples {
            for j in 0..n_features {
                x[[i, j]] = (i * j + 1) as f64 / 20.0;
            }
            y[i] = 2.0 * x[[i, 0]] + 1.5 * x[[i, 1]] + 0.8 * x[[i, 2]] + 0.1 * (i as f64 % 5.0);
        }

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(10),
            validation_split: 0.25,
            shuffle: true,
            random_state: Some(42),
            higher_is_better: true,
            min_iterations: 5,
            restore_best_weights: true,
        };

        let model = LinearRegression::lasso(0.1);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.expect("operation should succeed");

        assert_eq!(fitted_model.coef().len(), n_features);
        assert!(fitted_model.intercept().is_some());

        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
        assert!(validation_info.best_iteration >= 1);

        let predictions = fitted_model.predict(&x);
        assert!(predictions.is_ok());
        assert_eq!(
            predictions.expect("operation should succeed").len(),
            n_samples
        );
    }

    /// Elastic net with a tolerance/patience stopping criterion.
    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_elastic_net() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![
            [1.0, 2.0, 0.5],
            [2.0, 3.0, 1.0],
            [3.0, 4.0, 1.5],
            [4.0, 5.0, 2.0],
            [5.0, 6.0, 2.5],
            [6.0, 7.0, 3.0],
            [7.0, 8.0, 3.5],
            [8.0, 9.0, 4.0]
        ];
        let y = array![4.5, 7.0, 9.5, 12.0, 14.5, 17.0, 19.5, 22.0];
        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TolerancePatience {
                tolerance: 0.005,
                patience: 3,
            },
            validation_split: 0.25,
            shuffle: false,
            random_state: Some(123),
            higher_is_better: true,
            min_iterations: 2,
            restore_best_weights: true,
        };

        let model = LinearRegression::elastic_net(0.1, 0.7);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.expect("operation should succeed");

        assert_eq!(fitted_model.coef().len(), 3);
        assert!(fitted_model.intercept().is_some());
        assert!(!validation_info.validation_scores.is_empty());
        assert!(validation_info.best_score.is_some());
    }

    /// Early stopping with an explicit caller-provided validation split.
    /// Data follows y = 2*x0 + 1*x1 + 1, so the recovered coefficients should
    /// be close to (2, 1).
    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_with_split() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x_train = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]];
        let y_train = array![5.0, 8.0, 11.0, 14.0, 17.0];
        let x_val = array![[6.0, 7.0], [7.0, 8.0]];
        let y_val = array![20.0, 23.0];

        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::TargetScore(0.9),
            // validation_split is unused here since the split is explicit.
            validation_split: 0.2,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: false,
        };

        let model = LinearRegression::lasso(0.01);
        let result = model.fit_with_early_stopping_split(
            &x_train,
            &y_train,
            &x_val,
            &y_val,
            early_stopping_config,
        );

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.expect("operation should succeed");

        assert_eq!(fitted_model.coef().len(), 2);
        assert!(fitted_model.intercept().is_some());
        assert!(!validation_info.validation_scores.is_empty());

        let coef = fitted_model.coef();
        assert!((coef[0] - 2.0).abs() < 0.5);
        assert!((coef[1] - 1.0).abs() < 0.5);
    }

    /// OLS (no penalty) takes the closed-form path: the synthetic
    /// ValidationInfo reports a single converged iteration and the solution
    /// is exact (y = 2x + 1).
    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_ols() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]];
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0, 13.0];
        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(5),
            validation_split: 0.33,
            shuffle: false,
            random_state: None,
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: true,
        };

        let model = LinearRegression::new().fit_intercept(true);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.expect("operation should succeed");

        assert_eq!(fitted_model.coef().len(), 1);
        assert!(fitted_model.intercept().is_some());

        assert!(!validation_info.stopped_early);
        assert!(validation_info.converged);
        assert_eq!(validation_info.best_iteration, 1);

        assert_abs_diff_eq!(fitted_model.coef()[0], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(
            fitted_model
                .intercept()
                .expect("intercept should be available"),
            1.0,
            epsilon = 1e-10
        );
    }

    /// Ridge also takes the closed-form path under early stopping; only the
    /// shape of the result and the synthetic ValidationInfo are asserted.
    #[test]
    #[cfg(all(feature = "coordinate-descent", feature = "early-stopping"))]
    fn test_linear_regression_early_stopping_ridge() {
        use crate::early_stopping::{EarlyStoppingConfig, StoppingCriterion};

        let x = array![
            [1.0, 0.5],
            [2.0, 1.0],
            [3.0, 1.5],
            [4.0, 2.0],
            [5.0, 2.5],
            [6.0, 3.0]
        ];
        let y = array![2.5, 4.0, 5.5, 7.0, 8.5, 10.0];
        let early_stopping_config = EarlyStoppingConfig {
            criterion: StoppingCriterion::Patience(3),
            validation_split: 0.33,
            shuffle: true,
            random_state: Some(456),
            higher_is_better: true,
            min_iterations: 1,
            restore_best_weights: false,
        };

        let model = LinearRegression::new()
            .regularization(0.1)
            .fit_intercept(true);
        let result = model.fit_with_early_stopping(&x, &y, early_stopping_config);

        assert!(result.is_ok());
        let (fitted_model, validation_info) = result.expect("operation should succeed");

        assert_eq!(fitted_model.coef().len(), 2);
        assert!(fitted_model.intercept().is_some());

        assert!(!validation_info.stopped_early);
        assert!(validation_info.converged);
    }
}