
scry_learn/linear/regression.rs

// SPDX-License-Identifier: MIT OR Apache-2.0
//! Linear regression via OLS (Ordinary Least Squares).

use crate::accel;
use crate::dataset::Dataset;
use crate::error::{Result, ScryLearnError};
use crate::sparse::{CscMatrix, CsrMatrix};

/// Solver strategy for linear regression.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub enum LinRegSolver {
    /// Normal equations: (XᵀX + αI)⁻¹Xᵀy. Fast but numerically fragile.
    Normal,
    /// QR decomposition. More robust than Normal, faster than SVD.
    Qr,
    /// SVD (pseudoinverse). Most robust; handles rank-deficient and wide matrices.
    Svd,
    /// Auto: use Normal for well-conditioned problems, fall back to SVD otherwise.
    #[default]
    Auto,
}

/// Linear regression model.
///
/// Uses the **OLS** closed-form normal-equations solution by default:
/// `β = (XᵀX + αI)⁻¹Xᵀy`. Set `alpha > 0` for Ridge (L2) regularization.
///
/// Alternative solvers (QR, SVD) provide better numerical stability for
/// ill-conditioned or rank-deficient problems.
///
/// When the `scry-gpu` feature is enabled and the dataset is large enough,
/// the XᵀX/Xᵀy computation is automatically offloaded to GPU compute shaders.
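///
/// # Examples
///
/// A minimal usage sketch (marked `ignore` because the public import paths
/// and the exact `Dataset::new` signature are assumed from this module's tests):
///
/// ```ignore
/// use scry_learn::dataset::Dataset;
/// use scry_learn::linear::regression::{LinRegSolver, LinearRegression};
///
/// // One feature column x = 0..20 with target y = 2x + 3.
/// let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
/// let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
/// let data = Dataset::new(features, target, vec!["x".into()], "y");
///
/// let mut lr = LinearRegression::new()
///     .alpha(0.0) // 0.0 = plain OLS; > 0.0 enables Ridge
///     .solver(LinRegSolver::Auto);
/// lr.fit(&data).unwrap();
///
/// assert!((lr.coefficients()[0] - 2.0).abs() < 1e-6);
/// assert!((lr.intercept() - 3.0).abs() < 1e-6);
/// ```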
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct LinearRegression {
    /// Learned coefficients (one per feature).
    coefficients: Vec<f64>,
    /// Learned intercept (bias term).
    intercept: f64,
    /// L2 regularization strength (0.0 = OLS, >0 = Ridge).
    alpha: f64,
    /// Solver strategy.
    #[cfg_attr(feature = "serde", serde(skip))]
    solver: LinRegSolver,
    fitted: bool,
    #[cfg_attr(feature = "serde", serde(default))]
    _schema_version: u32,
}

impl LinearRegression {
    /// Create a new linear regression model.
    pub fn new() -> Self {
        Self {
            coefficients: Vec::new(),
            intercept: 0.0,
            alpha: 0.0,
            solver: LinRegSolver::Auto,
            fitted: false,
            _schema_version: crate::version::SCHEMA_VERSION,
        }
    }

    /// Set L2 regularization strength (Ridge regression).
    pub fn alpha(mut self, a: f64) -> Self {
        self.alpha = a;
        self
    }

    /// Set the solver strategy.
    pub fn solver(mut self, s: LinRegSolver) -> Self {
        self.solver = s;
        self
    }

    /// Train the model on the given dataset.
    pub fn fit(&mut self, data: &Dataset) -> Result<()> {
        data.validate_finite()?;
        if let Some(csc) = data.sparse_csc() {
            return self.fit_sparse(csc, &data.target);
        }
        let n = data.n_samples();
        let m = data.n_features();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }

        match &self.solver {
            LinRegSolver::Normal => self.fit_normal(data),
            LinRegSolver::Qr => self.fit_qr(data),
            LinRegSolver::Svd => self.fit_svd(data),
            LinRegSolver::Auto => {
                // Normal equations are the fastest path for tall, well-conditioned
                // problems. Fall back to SVD for underdetermined systems (m >= n)
                // or when the normal equations turn out to be singular.
                if m >= n {
                    return self.fit_svd(data);
                }
                match self.fit_normal(data) {
                    Ok(()) => Ok(()),
                    Err(_) => self.fit_svd(data),
                }
            }
        }
    }

    /// Normal-equations solver: forms XᵀX and Xᵀy, then solves the system directly.
    fn fit_normal(&mut self, data: &Dataset) -> Result<()> {
        let n = data.n_samples();
        let m = data.n_features();
        let dim = m + 1;

        let backend = accel::auto();
        let mat = data.matrix();
        let (mut xtx, mut xty) = backend.xtx_xty_contiguous(mat.as_slice(), &data.target, n, m);

        // Ridge: add alpha to the diagonal, skipping the intercept (index 0).
        for j in 1..dim {
            xtx[j * dim + j] += self.alpha;
        }

        let beta = solve_linear(dim, &mut xtx, &mut xty)?;

        self.intercept = beta[0];
        self.coefficients = beta[1..].to_vec();
        self.fitted = true;
        Ok(())
    }

    /// Build the augmented column-major design matrix [1 | X]: a leading
    /// all-ones column for the intercept followed by the feature columns.
    fn build_augmented(data: &Dataset) -> (Vec<f64>, usize, usize) {
        let n = data.n_samples();
        let m = data.n_features();
        let dim = m + 1;
        let mat = data.matrix();
        let mut x = vec![0.0; n * dim];
        // Column 0 is the intercept column of ones.
        for i in 0..n {
            x[i] = 1.0;
        }
        for j in 0..m {
            let offset = (j + 1) * n;
            x[offset..offset + n].copy_from_slice(mat.col(j));
        }
        (x, n, dim)
    }

    /// Build the augmented matrix with Ridge regularization rows appended.
    ///
    /// Minimizing ‖Xβ − y‖² + α‖β‖² (intercept excluded from the penalty) is
    /// equivalent to plain least squares on [X; √α·I] with target [y; 0], so
    /// QR and SVD can solve the Ridge problem without special-casing α.
    fn build_regularized(data: &Dataset, alpha: f64) -> (Vec<f64>, Vec<f64>, usize, usize) {
        let n = data.n_samples();
        let m = data.n_features();
        let dim = m + 1;
        let mat = data.matrix();
        let sqrt_a = alpha.sqrt();
        let aug_rows = n + m;
        let mut x_aug = vec![0.0; aug_rows * dim];
        let mut y_aug = vec![0.0; aug_rows];

        // Intercept column of ones over the data rows.
        for i in 0..n {
            x_aug[i] = 1.0;
        }
        for j in 0..m {
            let offset = (j + 1) * aug_rows;
            x_aug[offset..offset + n].copy_from_slice(mat.col(j));
        }
        y_aug[..n].copy_from_slice(&data.target);

        // Regularization rows: √α on the diagonal of the feature block only,
        // so the intercept is not penalized.
        for j in 0..m {
            x_aug[(j + 1) * aug_rows + n + j] = sqrt_a;
        }

        (x_aug, y_aug, aug_rows, dim)
    }

    /// QR decomposition solver.
    fn fit_qr(&mut self, data: &Dataset) -> Result<()> {
        if self.alpha > 0.0 {
            let (x_aug, y_aug, aug_rows, dim) = Self::build_regularized(data, self.alpha);
            let beta = super::qr::qr_solve(&x_aug, &y_aug, aug_rows, dim)?;
            self.intercept = beta[0];
            self.coefficients = beta[1..].to_vec();
        } else {
            let (x, n, dim) = Self::build_augmented(data);
            let beta = super::qr::qr_solve(&x, &data.target, n, dim)?;
            self.intercept = beta[0];
            self.coefficients = beta[1..].to_vec();
        }

        self.fitted = true;
        Ok(())
    }

    /// SVD solver.
    fn fit_svd(&mut self, data: &Dataset) -> Result<()> {
        if self.alpha > 0.0 {
            let (x_aug, y_aug, aug_rows, dim) = Self::build_regularized(data, self.alpha);
            let result = super::svd::svd_solve(&x_aug, &y_aug, aug_rows, dim)?;
            self.intercept = result.coefficients[0];
            self.coefficients = result.coefficients[1..].to_vec();
        } else {
            let (x, n, dim) = Self::build_augmented(data);
            let result = super::svd::svd_solve(&x, &data.target, n, dim)?;
            self.intercept = result.coefficients[0];
            self.coefficients = result.coefficients[1..].to_vec();
        }

        self.fitted = true;
        Ok(())
    }

    /// Predict target values, one per input row.
    ///
    /// Rows shorter than the coefficient vector are treated as zero-padded;
    /// extra trailing entries are ignored.
    pub fn predict(&self, features: &[Vec<f64>]) -> Result<Vec<f64>> {
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok(features
            .iter()
            .map(|row| {
                let mut y = self.intercept;
                for (j, &coeff) in self.coefficients.iter().enumerate() {
                    if j < row.len() {
                        y += coeff * row[j];
                    }
                }
                y
            })
            .collect())
    }

    /// Fit on sparse features (CSC format for column-oriented access).
    ///
    /// Builds XᵀX and Xᵀy by iterating only over non-zero entries.
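    ///
    /// # Examples
    ///
    /// A minimal sketch (marked `ignore`: the public paths and the
    /// `CscMatrix::from_dense` constructor are assumed from this module's tests):
    ///
    /// ```ignore
    /// use scry_learn::linear::regression::LinearRegression;
    /// use scry_learn::sparse::CscMatrix;
    ///
    /// // Column-oriented dense input: one feature column x = 0..20, y = 2x + 3.
    /// let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
    /// let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
    ///
    /// let csc = CscMatrix::from_dense(&features);
    /// let mut lr = LinearRegression::new();
    /// lr.fit_sparse(&csc, &target).unwrap();
    /// assert!((lr.coefficients()[0] - 2.0).abs() < 1e-4);
    /// ```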
    pub fn fit_sparse(&mut self, features: &CscMatrix, target: &[f64]) -> Result<()> {
        let n = features.n_rows();
        let m = features.n_cols();
        if n == 0 {
            return Err(ScryLearnError::EmptyDataset);
        }
        if target.len() != n {
            return Err(ScryLearnError::InvalidParameter(format!(
                "target length {} != n_rows {}",
                target.len(),
                n
            )));
        }

        let dim = m + 1; // intercept + features

        // Build XᵀX (dim×dim row-major) and Xᵀy (dim) with intercept column.
        let mut xtx = vec![0.0; dim * dim];
        let mut xty = vec![0.0; dim];

        // Intercept-intercept: XᵀX[0][0] = n
        xtx[0] = n as f64;

        // Intercept-target: Xᵀy[0] = Σ y_i
        xty[0] = target.iter().sum();

        // Intercept-feature cross terms: XᵀX[0][j+1] = XᵀX[j+1][0] = Σ x_ij
        for j in 0..m {
            let col = features.col(j);
            let sum: f64 = col.iter().map(|(_, v)| v).sum();
            xtx[j + 1] = sum;
            xtx[(j + 1) * dim] = sum;

            // Xᵀy[j+1] = Σ x_ij * y_i
            let mut dot = 0.0;
            for (row_idx, val) in col.iter() {
                dot += val * target[row_idx];
            }
            xty[j + 1] = dot;
        }

        // Feature-feature: XᵀX[i+1][j+1] = Σ_k x_ki * x_kj (only non-zero entries).
        // For efficiency, use a scatter approach: for each column j, scatter it
        // into a dense vector, then dot with each column i.
        let mut dense_col = vec![0.0; n];
        for j in 0..m {
            // Scatter column j into dense.
            for (row_idx, val) in features.col(j).iter() {
                dense_col[row_idx] = val;
            }

            // Diagonal: XᵀX[j+1][j+1]
            let mut diag = 0.0;
            for (row_idx, val) in features.col(j).iter() {
                diag += val * dense_col[row_idx];
            }
            xtx[(j + 1) * dim + j + 1] = diag;

            // Off-diagonal with columns i < j.
            for i in 0..j {
                let mut dot = 0.0;
                for (row_idx, val) in features.col(i).iter() {
                    dot += val * dense_col[row_idx];
                }
                xtx[(i + 1) * dim + j + 1] = dot;
                xtx[(j + 1) * dim + i + 1] = dot;
            }

            // Clear dense_col so the next column starts from zeros.
            for (row_idx, _) in features.col(j).iter() {
                dense_col[row_idx] = 0.0;
            }
        }

        // Add Ridge regularization to the diagonal, skipping the intercept.
        for j in 1..dim {
            xtx[j * dim + j] += self.alpha;
        }

        let beta = solve_linear(dim, &mut xtx, &mut xty)?;
        self.intercept = beta[0];
        self.coefficients = beta[1..].to_vec();
        self.fitted = true;
        Ok(())
    }

    /// Predict from sparse features (CSR format for row-oriented access).
    pub fn predict_sparse(&self, features: &CsrMatrix) -> Result<Vec<f64>> {
        // Match `predict`: reject models deserialized from an incompatible schema.
        crate::version::check_schema_version(self._schema_version)?;
        if !self.fitted {
            return Err(ScryLearnError::NotFitted);
        }
        Ok((0..features.n_rows())
            .map(|i| {
                let mut y = self.intercept;
                for (col, val) in features.row(i).iter() {
                    if col < self.coefficients.len() {
                        y += self.coefficients[col] * val;
                    }
                }
                y
            })
            .collect())
    }

    /// Get learned coefficients.
    pub fn coefficients(&self) -> &[f64] {
        &self.coefficients
    }

    /// Get learned intercept.
    pub fn intercept(&self) -> f64 {
        self.intercept
    }
}

impl Default for LinearRegression {
    fn default() -> Self {
        Self::new()
    }
}

/// Gauss-Jordan elimination with partial pivoting for the dense system Ax = b.
///
/// `a` is an n×n row-major matrix; `a` and `b` are used as scratch space and
/// overwritten, and the solution is returned as a fresh vector.
fn solve_linear(n: usize, a: &mut [f64], b: &mut [f64]) -> Result<Vec<f64>> {
    for col in 0..n {
        // Partial pivoting: pick the row with the largest magnitude in this column.
        let mut max_row = col;
        let mut max_val = a[col * n + col].abs();
        for row in (col + 1)..n {
            let val = a[row * n + col].abs();
            if val > max_val {
                max_val = val;
                max_row = row;
            }
        }
        if max_val < crate::constants::SINGULAR_THRESHOLD {
            return Err(ScryLearnError::InvalidParameter(
                "singular matrix — features may be linearly dependent".into(),
            ));
        }

        if max_row != col {
            for k in 0..n {
                a.swap(col * n + k, max_row * n + k);
            }
            b.swap(col, max_row);
        }

        // Normalize the pivot row, then eliminate the pivot column from all
        // other rows (full reduction, so no back-substitution pass is needed).
        let pivot = a[col * n + col];
        for k in col..n {
            a[col * n + k] /= pivot;
        }
        b[col] /= pivot;

        for row in 0..n {
            if row == col {
                continue;
            }
            let factor = a[row * n + col];
            for k in col..n {
                a[row * n + k] -= factor * a[col * n + k];
            }
            b[row] -= factor * b[col];
        }
    }

    Ok(b.to_vec())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_regression_fits_line() {
        // One feature column x = 0..20 with target y = 2x + 3.
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new();
        lr.fit(&data).unwrap();

        assert!(
            (lr.coefficients()[0] - 2.0).abs() < 1e-6,
            "coefficient should be ~2.0, got {}",
            lr.coefficients()[0]
        );
        assert!(
            (lr.intercept() - 3.0).abs() < 1e-6,
            "intercept should be ~3.0, got {}",
            lr.intercept()
        );
    }

    #[test]
    fn test_ridge_regression() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new().alpha(1.0);
        lr.fit(&data).unwrap();

        // Ridge shrinks the slope below the OLS value of 2.0.
        assert!(lr.coefficients()[0] < 2.0);
        assert!(lr.coefficients()[0] > 1.0);
    }

    #[test]
    fn test_svd_solver_matches_normal() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new();
        lr_normal.fit(&data).unwrap();

        let mut lr_svd = LinearRegression::new().solver(LinRegSolver::Svd);
        lr_svd.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_svd.coefficients()[0]).abs() < 1e-6,
            "Normal={} vs SVD={}",
            lr_normal.coefficients()[0],
            lr_svd.coefficients()[0]
        );
        assert!(
            (lr_normal.intercept() - lr_svd.intercept()).abs() < 1e-6,
            "Normal intercept={} vs SVD={}",
            lr_normal.intercept(),
            lr_svd.intercept()
        );
    }

    #[test]
    fn test_qr_solver_matches_normal() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new();
        lr_normal.fit(&data).unwrap();

        let mut lr_qr = LinearRegression::new().solver(LinRegSolver::Qr);
        lr_qr.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_qr.coefficients()[0]).abs() < 1e-6,
            "Normal={} vs QR={}",
            lr_normal.coefficients()[0],
            lr_qr.coefficients()[0]
        );
        assert!(
            (lr_normal.intercept() - lr_qr.intercept()).abs() < 1e-6,
            "Normal intercept={} vs QR={}",
            lr_normal.intercept(),
            lr_qr.intercept()
        );
    }

    #[test]
    fn test_svd_handles_ill_conditioned() {
        // Hilbert matrix entries 1/(i + j + 1): a classic ill-conditioned case.
        let n = 5;
        let mut features = vec![vec![0.0; n]; n];
        for j in 0..n {
            for i in 0..n {
                features[j][i] = 1.0 / (i + j + 1) as f64;
            }
        }
        let true_beta = vec![1.0; n];
        let target: Vec<f64> = (0..n)
            .map(|i| (0..n).map(|j| features[j][i] * true_beta[j]).sum())
            .collect();
        let names: Vec<String> = (0..n).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features, target, names, "y");

        let mut lr = LinearRegression::new().solver(LinRegSolver::Svd);
        lr.fit(&data).unwrap();

        for (i, &c) in lr.coefficients().iter().enumerate() {
            assert!(
                (c - 1.0).abs() < 0.5,
                "SVD Hilbert coeff[{}] = {}, expected ~1.0",
                i,
                c
            );
        }
    }

    #[test]
    fn test_ridge_with_svd() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new().alpha(1.0);
        lr_normal.fit(&data).unwrap();

        let mut lr_svd = LinearRegression::new().alpha(1.0).solver(LinRegSolver::Svd);
        lr_svd.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_svd.coefficients()[0]).abs() < 0.1,
            "Ridge Normal={} vs SVD={}",
            lr_normal.coefficients()[0],
            lr_svd.coefficients()[0]
        );
    }
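
    // Companion to test_ridge_with_svd: since build_regularized reduces Ridge
    // to plain least squares on [X; √α·I], the QR ridge path should agree with
    // the normal-equations solution to the same tolerance.
    #[test]
    fn test_ridge_with_qr() {
        let features = vec![vec![1.0, 2.0, 3.0, 4.0, 5.0]];
        let target = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr_normal = LinearRegression::new().alpha(1.0);
        lr_normal.fit(&data).unwrap();

        let mut lr_qr = LinearRegression::new().alpha(1.0).solver(LinRegSolver::Qr);
        lr_qr.fit(&data).unwrap();

        assert!(
            (lr_normal.coefficients()[0] - lr_qr.coefficients()[0]).abs() < 0.1,
            "Ridge Normal={} vs QR={}",
            lr_normal.coefficients()[0],
            lr_qr.coefficients()[0]
        );
    }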

    #[test]
    fn test_auto_solver() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new().solver(LinRegSolver::Auto);
        lr.fit(&data).unwrap();

        assert!(
            (lr.coefficients()[0] - 2.0).abs() < 1e-6,
            "Auto solver coefficient should be ~2.0, got {}",
            lr.coefficients()[0]
        );
    }

    #[test]
    fn test_sparse_fit_matches_dense() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features.clone(), target.clone(), vec!["x".into()], "y");

        let mut lr_dense = LinearRegression::new();
        lr_dense.fit(&data).unwrap();

        let csc = CscMatrix::from_dense(&features);
        let mut lr_sparse = LinearRegression::new();
        lr_sparse.fit_sparse(&csc, &target).unwrap();

        assert!(
            (lr_dense.coefficients()[0] - lr_sparse.coefficients()[0]).abs() < 1e-6,
            "Dense={} vs Sparse={}",
            lr_dense.coefficients()[0],
            lr_sparse.coefficients()[0]
        );
        assert!(
            (lr_dense.intercept() - lr_sparse.intercept()).abs() < 1e-6,
            "Dense intercept={} vs Sparse={}",
            lr_dense.intercept(),
            lr_sparse.intercept()
        );
    }

    #[test]
    fn test_sparse_predict_matches_dense() {
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let data = Dataset::new(features, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new();
        lr.fit(&data).unwrap();

        let test_rows = vec![vec![3.0], vec![10.0], vec![15.0]];
        let preds_dense = lr.predict(&test_rows).unwrap();

        let csr = CsrMatrix::from_dense(&test_rows);
        let preds_sparse = lr.predict_sparse(&csr).unwrap();

        for (d, s) in preds_dense.iter().zip(preds_sparse.iter()) {
            assert!((d - s).abs() < 1e-6, "Dense pred={d} vs Sparse pred={s}");
        }
    }

    #[test]
    fn test_auto_dispatch_sparse_fit() {
        // Create a sparse Dataset, call fit() (not fit_sparse), verify it works.
        let features = vec![(0..20).map(|i| i as f64).collect::<Vec<f64>>()];
        let target: Vec<f64> = (0..20).map(|i| 2.0 * i as f64 + 3.0).collect();
        let csc = CscMatrix::from_dense(&features);
        let data = crate::dataset::Dataset::from_sparse(csc, target, vec!["x".into()], "y");

        let mut lr = LinearRegression::new();
        lr.fit(&data).unwrap();

        assert!(
            (lr.coefficients()[0] - 2.0).abs() < 1e-4,
            "Auto-dispatched sparse fit: coefficient should be ~2.0, got {}",
            lr.coefficients()[0]
        );
    }
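
    // Direct check of the Gauss-Jordan helper on a tiny hand-computed system:
    // 2x + y = 3, x + 3y = 5 has the exact solution x = 0.8, y = 1.4.
    #[test]
    fn test_solve_linear_2x2() {
        let mut a = vec![2.0, 1.0, 1.0, 3.0]; // row-major 2×2
        let mut b = vec![3.0, 5.0];
        let x = solve_linear(2, &mut a, &mut b).unwrap();
        assert!((x[0] - 0.8).abs() < 1e-12, "x[0] = {}", x[0]);
        assert!((x[1] - 1.4).abs() < 1e-12, "x[1] = {}", x[1]);
    }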
}

#[cfg(all(test, feature = "scry-gpu"))]
mod gpu_tests {
    use super::*;

    #[test]
    fn gpu_linear_regression_smoke() {
        // Exercises the GPU-accelerated XᵀX/Xᵀy path on a dataset large enough
        // to trigger offload, and sanity-checks the fitted model's output.
        let n = 500;
        let m = 50;
        let mut features = Vec::with_capacity(m);
        for j in 0..m {
            let col: Vec<f64> = (0..n).map(|i| ((i * (j + 1)) % 97) as f64 * 0.1).collect();
            features.push(col);
        }
        let target: Vec<f64> = (0..n)
            .map(|i| features[0][i] * 2.0 + features[1][i] * 0.5 + features[2][i] + 3.0)
            .collect();
        let names: Vec<String> = (0..m).map(|j| format!("f{j}")).collect();
        let data = Dataset::new(features, target, names, "y");

        let mut lr = LinearRegression::new().alpha(0.1);
        lr.fit(&data).unwrap();

        assert_eq!(lr.coefficients().len(), m);
        let preds = lr.predict(&[vec![1.0; m]]).unwrap();
        assert!(
            preds[0].is_finite(),
            "prediction must be finite, got {}",
            preds[0]
        );
    }
}