// sklears_linear/utils.rs
1//! Utility functions for linear models
2//!
3//! This module provides standalone utility functions that implement
4//! core algorithms used by various linear models.
5
6use scirs2_core::ndarray::{Array1, Array2, Axis};
7use scirs2_linalg::compat::{qr, svd, ArrayLinalgExt};
8use sklears_core::{
9    error::{Result, SklearsError},
10    types::Float,
11};
12
13/// Helper function to safely compute mean
14#[inline]
15fn safe_mean(arr: &Array1<Float>) -> Result<Float> {
16    arr.mean()
17        .ok_or_else(|| SklearsError::NumericalError("Failed to compute mean".to_string()))
18}
19
20/// Helper function to safely compute mean along axis
21#[inline]
22fn safe_mean_axis(arr: &Array2<Float>, axis: Axis) -> Result<Array1<Float>> {
23    arr.mean_axis(axis).ok_or_else(|| {
24        SklearsError::NumericalError("Failed to compute mean along axis".to_string())
25    })
26}
27
/// Type alias for rank-revealing QR decomposition result:
/// `(Q, R, column permutation, estimated rank)` — see [`rank_revealing_qr`].
pub type RankRevealingQrResult = (Array2<Float>, Array2<Float>, Vec<usize>, usize);
30
31/// Orthogonal Matching Pursuit (OMP) algorithm
32///
33/// Solves the OMP problem: argmin ||y - X @ coef||^2 subject to ||coef||_0 <= n_nonzero_coefs
34///
35/// # Arguments
36/// * `x` - Design matrix of shape (n_samples, n_features)
37/// * `y` - Target values of shape (n_samples,)
38/// * `n_nonzero_coefs` - Maximum number of non-zero coefficients
39/// * `tol` - Tolerance for residual
40/// * `precompute` - Whether to precompute X.T @ X and X.T @ y
41///
42/// # Returns
43/// * Coefficient vector of shape (n_features,)
44pub fn orthogonal_mp(
45    x: &Array2<Float>,
46    y: &Array1<Float>,
47    n_nonzero_coefs: Option<usize>,
48    tol: Option<Float>,
49    precompute: bool,
50) -> Result<Array1<Float>> {
51    let n_samples = x.nrows();
52    let n_features = x.ncols();
53
54    if n_samples != y.len() {
55        return Err(SklearsError::InvalidInput(
56            "X and y have inconsistent numbers of samples".to_string(),
57        ));
58    }
59
60    let n_nonzero_coefs = n_nonzero_coefs.unwrap_or(n_features.min(n_samples));
61    let tol = tol.unwrap_or(1e-4);
62
63    // Initialize
64    let mut coef = Array1::zeros(n_features);
65    let mut residual = y.clone();
66    let mut selected = Vec::new();
67    let mut selected_mask = vec![false; n_features];
68
69    // Precompute if requested
70    let _gram = if precompute { Some(x.t().dot(x)) } else { None };
71
72    // Main OMP loop
73    for _ in 0..n_nonzero_coefs {
74        // Compute correlations with residual
75        let correlations = x.t().dot(&residual);
76
77        // Find the most correlated feature not yet selected
78        let mut best_idx = 0;
79        let mut best_corr = 0.0;
80
81        for (idx, &corr) in correlations.iter().enumerate() {
82            if !selected_mask[idx] && corr.abs() > best_corr {
83                best_corr = corr.abs();
84                best_idx = idx;
85            }
86        }
87
88        // Check convergence
89        if best_corr < tol {
90            break;
91        }
92
93        // Add to selected set
94        selected.push(best_idx);
95        selected_mask[best_idx] = true;
96
97        // Solve least squares on selected features
98        let x_selected = x.select(Axis(1), &selected);
99        let coef_selected = solve_least_squares(&x_selected, y)?;
100
101        // Update coefficients
102        for (i, &idx) in selected.iter().enumerate() {
103            coef[idx] = coef_selected[i];
104        }
105
106        // Update residual
107        residual = y - &x.dot(&coef);
108
109        // Check residual norm
110        let residual_norm = residual.dot(&residual).sqrt();
111        if residual_norm < tol {
112            break;
113        }
114    }
115
116    Ok(coef)
117}
118
119/// Orthogonal Matching Pursuit using precomputed Gram matrix
120///
121/// This is more efficient when n_features < n_samples and multiple OMP problems
122/// need to be solved with the same design matrix.
123///
124/// # Arguments
125/// * `gram` - Gram matrix X.T @ X of shape (n_features, n_features)
126/// * `xy` - X.T @ y of shape (n_features,)
127/// * `n_nonzero_coefs` - Maximum number of non-zero coefficients
128/// * `tol` - Tolerance for residual
129/// * `norms_squared` - Squared norms of each column of X (optional)
130///
131/// # Returns
132/// * Coefficient vector of shape (n_features,)
133pub fn orthogonal_mp_gram(
134    gram: &Array2<Float>,
135    xy: &Array1<Float>,
136    n_nonzero_coefs: Option<usize>,
137    tol: Option<Float>,
138    norms_squared: Option<&Array1<Float>>,
139) -> Result<Array1<Float>> {
140    let n_features = gram.nrows();
141
142    if gram.ncols() != n_features {
143        return Err(SklearsError::InvalidInput(
144            "Gram matrix must be square".to_string(),
145        ));
146    }
147
148    if xy.len() != n_features {
149        return Err(SklearsError::InvalidInput(
150            "xy must have length n_features".to_string(),
151        ));
152    }
153
154    let n_nonzero_coefs = n_nonzero_coefs.unwrap_or(n_features);
155    let tol = tol.unwrap_or(1e-4);
156
157    // Get squared norms from diagonal of Gram if not provided
158    let _norms_sq = match norms_squared {
159        Some(norms) => norms.clone(),
160        None => gram.diag().to_owned(),
161    };
162
163    // Initialize
164    let mut coef = Array1::zeros(n_features);
165    let mut selected = Vec::new();
166    let mut selected_mask = vec![false; n_features];
167    let mut correlations = xy.clone();
168
169    // Main OMP loop
170    for _ in 0..n_nonzero_coefs {
171        // Find the most correlated feature not yet selected
172        let mut best_idx = 0;
173        let mut best_corr = 0.0;
174
175        for (idx, &corr) in correlations.iter().enumerate() {
176            if !selected_mask[idx] && corr.abs() > best_corr {
177                best_corr = corr.abs();
178                best_idx = idx;
179            }
180        }
181
182        // Check convergence
183        if best_corr < tol {
184            break;
185        }
186
187        // Add to selected set
188        selected.push(best_idx);
189        selected_mask[best_idx] = true;
190
191        // Solve least squares on selected features using Gram matrix
192        let gram_selected = gram.select(Axis(0), &selected).select(Axis(1), &selected);
193        let xy_selected = xy.select(Axis(0), &selected);
194        let coef_selected = solve_gram_least_squares(&gram_selected, &xy_selected)?;
195
196        // Update coefficients
197        coef.fill(0.0);
198        for (i, &idx) in selected.iter().enumerate() {
199            coef[idx] = coef_selected[i];
200        }
201
202        // Update correlations
203        correlations = xy - &gram.dot(&coef);
204    }
205
206    Ok(coef)
207}
208
209/// Ridge regression solver
210///
211/// Solves the ridge regression problem: argmin ||y - X @ coef||^2 + alpha * ||coef||^2
212///
213/// # Arguments
214/// * `x` - Design matrix of shape (n_samples, n_features)
215/// * `y` - Target values of shape (n_samples,) or (n_samples, n_targets)
216/// * `alpha` - Regularization strength (must be positive)
217/// * `fit_intercept` - Whether to fit an intercept
218/// * `solver` - Solver to use ("auto", "svd", "cholesky", "lsqr", "sparse_cg", "sag", "saga")
219///
220/// # Returns
221/// * Coefficients of shape (n_features,) or (n_features, n_targets)
222/// * Intercept (scalar or array)
223pub fn ridge_regression(
224    x: &Array2<Float>,
225    y: &Array1<Float>,
226    alpha: Float,
227    fit_intercept: bool,
228    solver: &str,
229) -> Result<(Array1<Float>, Float)> {
230    let n_samples = x.nrows();
231    let n_features = x.ncols();
232
233    if n_samples != y.len() {
234        return Err(SklearsError::InvalidInput(
235            "X and y have inconsistent numbers of samples".to_string(),
236        ));
237    }
238
239    if alpha < 0.0 {
240        return Err(SklearsError::InvalidInput(
241            "alpha must be non-negative".to_string(),
242        ));
243    }
244
245    // Center data if fitting intercept
246    let (x_centered, y_centered, x_mean, y_mean) = if fit_intercept {
247        let x_mean = safe_mean_axis(x, Axis(0))?;
248        let y_mean = safe_mean(y)?;
249        let x_centered = x - &x_mean;
250        let y_centered = y - y_mean;
251        (x_centered, y_centered, x_mean, y_mean)
252    } else {
253        (x.clone(), y.clone(), Array1::zeros(n_features), 0.0)
254    };
255
256    // Solve ridge regression based on solver
257    let coef = match solver {
258        "auto" | "cholesky" => {
259            // Use Cholesky decomposition: solve (X.T @ X + alpha * I) @ coef = X.T @ y
260            let mut gram = x_centered.t().dot(&x_centered);
261
262            // Add regularization to diagonal
263            for i in 0..n_features {
264                gram[[i, i]] += alpha * n_samples as Float;
265            }
266
267            let xy = x_centered.t().dot(&y_centered);
268            solve_cholesky(&gram, &xy)?
269        }
270        "svd" => {
271            // Use SVD decomposition
272            // Placeholder - would use actual SVD implementation
273            solve_svd_ridge(&x_centered, &y_centered, alpha)?
274        }
275        _ => {
276            return Err(SklearsError::InvalidInput(format!(
277                "Unknown solver: {}",
278                solver
279            )));
280        }
281    };
282
283    // Compute intercept
284    let intercept = if fit_intercept {
285        y_mean - x_mean.dot(&coef)
286    } else {
287        0.0
288    };
289
290    Ok((coef, intercept))
291}
292
293/// Solve least squares using normal equations
294fn solve_least_squares(x: &Array2<Float>, y: &Array1<Float>) -> Result<Array1<Float>> {
295    let gram = x.t().dot(x);
296    let xy = x.t().dot(y);
297    solve_cholesky(&gram, &xy)
298}
299
/// Solve the normal equations given a precomputed Gram matrix (X^T X) and
/// right-hand side X^T y; thin wrapper over the Cholesky-based solver.
fn solve_gram_least_squares(gram: &Array2<Float>, xy: &Array1<Float>) -> Result<Array1<Float>> {
    solve_cholesky(gram, xy)
}
304
305/// Solve a linear system using Cholesky decomposition
306fn solve_cholesky(a: &Array2<Float>, b: &Array1<Float>) -> Result<Array1<Float>> {
307    let n = a.nrows();
308    if n != a.ncols() || n != b.len() {
309        return Err(SklearsError::InvalidInput(
310            "Invalid dimensions for linear solve".to_string(),
311        ));
312    }
313
314    // Use scirs2's linear solver which handles Cholesky decomposition
315    a.solve(&b)
316        .map_err(|e| SklearsError::NumericalError(format!("Cholesky decomposition failed: {}", e)))
317}
318
/// Solve ridge regression using SVD; thin wrapper over [`svd_ridge_regression`]
/// kept so the solver dispatch in [`ridge_regression`] reads uniformly.
fn solve_svd_ridge(x: &Array2<Float>, y: &Array1<Float>, alpha: Float) -> Result<Array1<Float>> {
    svd_ridge_regression(x, y, alpha)
}
323
/// Numerically stable solution to normal equations using QR decomposition
///
/// Solves the least squares problem min ||Ax - b||^2 using QR decomposition
/// instead of forming A^T A explicitly, which improves numerical stability.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `rcond` - Cutoff for small singular values. If None, use machine precision.
///
/// # Returns
/// * Solution vector x of shape (n_features,)
///
/// # Errors
/// Returns `InvalidInput` for mismatched or underdetermined shapes, and
/// `NumericalError` when QR fails or the matrix is numerically rank deficient.
pub fn stable_normal_equations(
    a: &Array2<Float>,
    b: &Array1<Float>,
    rcond: Option<Float>,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if n_samples < n_features {
        return Err(SklearsError::InvalidInput(
            "Underdetermined system: more features than samples".to_string(),
        ));
    }

    // Factor A = Q R via scirs2.
    let (q, r) = qr(&a.view())
        .map_err(|e| SklearsError::NumericalError(format!("QR decomposition failed: {}", e)))?;

    // Rank check: count diagonal entries of R that exceed a threshold
    // relative to the largest diagonal magnitude.
    let rcond = rcond.unwrap_or(Float::EPSILON * n_features.max(n_samples) as Float);
    let r_diag_abs: Vec<Float> = (0..n_features.min(n_samples))
        .map(|i| r[[i, i]].abs())
        .collect();

    let max_diag = r_diag_abs.iter().fold(0.0 as Float, |a, &b| a.max(b));
    let rank = r_diag_abs.iter().filter(|&&x| x > rcond * max_diag).count();

    if rank < n_features {
        return Err(SklearsError::NumericalError(format!(
            "Matrix is rank deficient: rank {} < {} features",
            rank, n_features
        )));
    }

    // With A = Q R, the least-squares solution satisfies R x = Q^T b.
    let qtb = q.t().dot(b);

    // Back substitution on the upper-triangular R, from the last row upward.
    let mut x = Array1::zeros(n_features);
    for i in (0..n_features).rev() {
        let mut sum = qtb[i];
        for j in (i + 1)..n_features {
            sum -= r[[i, j]] * x[j];
        }

        // NOTE(review): this per-row singularity check is redundant once the
        // rank check above has passed (rank == n_features implies every
        // diagonal entry exceeds the threshold); kept as a defensive guard.
        if r[[i, i]].abs() < rcond * max_diag {
            return Err(SklearsError::NumericalError(
                "Matrix is singular within working precision".to_string(),
            ));
        }

        x[i] = sum / r[[i, i]];
    }

    Ok(x)
}
398
/// Numerically stable solution to regularized normal equations
///
/// Solves the ridge regression problem min ||Ax - b||^2 + alpha * ||x||^2.
///
/// NOTE(review): despite the original doc wording ("using SVD"), the current
/// implementation forms the regularized normal equations
/// (A^T A + alpha I) x = A^T b and solves them directly; see
/// [`svd_ridge_regression`] for the actual SVD-based variant.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `alpha` - Regularization parameter (must be non-negative)
/// * `_fit_intercept` - Currently unused; callers are expected to center the
///   data beforehand if an intercept is desired
///
/// # Returns
/// * Solution vector x of shape (n_features,)
pub fn stable_ridge_regression(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
    _fit_intercept: bool,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "Regularization parameter must be non-negative".to_string(),
        ));
    }

    // Form the normal equations: (A^T A + alpha * I) * x = A^T * b.
    // (No QR is performed here, contrary to a previous comment.)
    let ata = a.t().dot(a);
    let atb = a.t().dot(b);

    // Add the (unscaled) regularization term to the diagonal.
    let mut regularized_ata = ata;
    for i in 0..n_features {
        regularized_ata[[i, i]] += alpha;
    }

    // Solve using scirs2's linear solver
    let x = regularized_ata
        .solve(&atb)
        .map_err(|e| SklearsError::NumericalError(format!("Linear solve failed: {}", e)))?;

    Ok(x)
}
451
/// Fast heuristic estimate of a matrix's condition number
///
/// Returns an *estimate* of the condition number (ratio of largest to smallest
/// singular value). Large values (>1e12) indicate ill-conditioned matrices
/// that may lead to numerical issues.
///
/// NOTE(review): despite the original doc wording, no SVD is performed here.
/// A determinant-based estimate is used for 2x2 matrices and a diagonal-ratio
/// heuristic otherwise; see [`accurate_condition_number`] for the true
/// SVD-based value.
pub fn condition_number(a: &Array2<Float>) -> Result<Float> {
    let n = a.nrows().min(a.ncols());
    if n == 0 {
        // Empty matrix: report perfect conditioning by convention.
        return Ok(1.0);
    }

    // For nearly singular matrices, compute the determinant and use it as a heuristic:
    // a matrix with very small determinant is likely ill-conditioned.
    if n == a.nrows() && n == a.ncols() {
        // Square matrix - compute determinant heuristic (2x2 closed form only).
        if n == 2 {
            let det = a[[0, 0]] * a[[1, 1]] - a[[0, 1]] * a[[1, 0]];
            let frobenius_norm = (a.mapv(|x| x * x).sum()).sqrt();
            if det.abs() < 1e-10 * frobenius_norm * frobenius_norm {
                return Ok(1e15); // Very ill-conditioned: return a large sentinel
            }
            // Estimate condition number from determinant and matrix norm.
            let scale = frobenius_norm / (n as Float).sqrt();
            return Ok(scale * scale / det.abs());
        }
    }

    // Fallback: ratio of largest to smallest diagonal magnitude. This is only
    // a rough proxy — it ignores off-diagonal structure entirely.
    let mut diag_max = Float::NEG_INFINITY;
    let mut diag_min = Float::INFINITY;

    for i in 0..n {
        let val = a[[i, i]].abs();
        if val > Float::EPSILON {
            diag_max = diag_max.max(val);
            diag_min = diag_min.min(val);
        }
    }

    // If every diagonal entry was negligible, report infinite conditioning.
    if diag_min <= Float::EPSILON || diag_min == Float::INFINITY {
        Ok(Float::INFINITY)
    } else {
        Ok(diag_max / diag_min)
    }
}
497
/// Solve linear system with iterative refinement for improved accuracy
///
/// This function solves Ax = b with iterative refinement to improve the accuracy
/// of the solution when dealing with ill-conditioned matrices. Refinement is
/// skipped entirely when the (heuristic) condition number is below 1e8.
///
/// # Arguments
/// * `a` - Coefficient matrix (must be square)
/// * `b` - Right-hand side vector
/// * `max_iter` - Maximum number of refinement iterations
/// * `tol` - Convergence tolerance, relative to ||b||
///
/// # Returns
/// * Refined solution vector
pub fn solve_with_iterative_refinement(
    a: &Array2<Float>,
    b: &Array1<Float>,
    max_iter: usize,
    tol: Float,
) -> Result<Array1<Float>> {
    let n = a.nrows();
    if n != a.ncols() || n != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix must be square and dimensions must match".to_string(),
        ));
    }

    // Get initial solution using direct method
    let mut x = a
        .solve(&b)
        .map_err(|e| SklearsError::NumericalError(format!("Initial solve failed: {}", e)))?;

    // Cheap heuristic estimate (see `condition_number`); refinement only pays
    // off when the system is ill-conditioned.
    let cond = condition_number(a)?;
    if cond < 1e8 {
        // Matrix is well-conditioned, no refinement needed
        return Ok(x);
    }

    // Iterative refinement loop: repeatedly solve for the correction to the
    // current iterate using the residual as the right-hand side.
    for iter in 0..max_iter {
        // Compute residual: r = b - A*x
        let ax = a.dot(&x);
        let residual = b - &ax;

        // Convergence test: residual small relative to ||b||.
        let residual_norm = residual.iter().map(|&x| x * x).sum::<Float>().sqrt();
        let b_norm = b.iter().map(|&x| x * x).sum::<Float>().sqrt();

        if residual_norm <= tol * b_norm {
            log::debug!("Iterative refinement converged after {} iterations", iter);
            break;
        }

        // Solve A*delta_x = residual for the correction term.
        let delta_x = &a.solve(&residual).map_err(|e| {
            SklearsError::NumericalError(format!("Refinement iteration {} failed: {}", iter, e))
        })?;

        // Update solution: x = x + delta_x
        x += delta_x;

        log::debug!(
            "Iterative refinement iteration {}: residual norm = {:.2e}",
            iter,
            residual_norm
        );
    }

    // NOTE(review): if max_iter is exhausted, the last (possibly unconverged)
    // iterate is returned without signaling non-convergence.
    Ok(x)
}
568
/// Enhanced ridge regression with iterative refinement for ill-conditioned problems
///
/// Uses iterative refinement when the condition number is high to improve numerical accuracy.
///
/// # Arguments
/// * `x` - Design matrix of shape (n_samples, n_features)
/// * `y` - Target values of shape (n_samples,)
/// * `alpha` - Regularization strength (must be non-negative)
/// * `fit_intercept` - Whether to center the data and fit an intercept
/// * `max_iter_refinement` - Maximum refinement iterations (defaults to 10)
/// * `tol_refinement` - Refinement convergence tolerance (defaults to 1e-12)
///
/// # Returns
/// * `(coefficients, intercept)`; intercept is 0.0 when `fit_intercept` is false
pub fn enhanced_ridge_regression(
    x: &Array2<Float>,
    y: &Array1<Float>,
    alpha: Float,
    fit_intercept: bool,
    max_iter_refinement: Option<usize>,
    tol_refinement: Option<Float>,
) -> Result<(Array1<Float>, Float)> {
    let n_samples = x.nrows();
    let n_features = x.ncols();

    if n_samples != y.len() {
        return Err(SklearsError::InvalidInput(
            "X and y have inconsistent numbers of samples".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "alpha must be non-negative".to_string(),
        ));
    }

    // Center data if fitting intercept so the intercept can be recovered in
    // closed form afterwards.
    let (x_centered, y_centered, x_mean, y_mean) = if fit_intercept {
        let x_mean = safe_mean_axis(x, Axis(0))?;
        let y_mean = safe_mean(y)?;
        let x_centered = x - &x_mean;
        let y_centered = y - y_mean;
        (x_centered, y_centered, x_mean, y_mean)
    } else {
        (x.clone(), y.clone(), Array1::zeros(n_features), 0.0)
    };

    // Form regularized normal equations: (X.T @ X + alpha * I) @ coef = X.T @ y
    let mut gram = x_centered.t().dot(&x_centered);

    // Add regularization to diagonal.
    // NOTE(review): alpha is scaled by n_samples here (consistent with an
    // averaged-loss formulation), whereas `svd_ridge_regression` applies
    // alpha unscaled — confirm which convention is intended.
    for i in 0..n_features {
        gram[[i, i]] += alpha * n_samples as Float;
    }

    let xy = x_centered.t().dot(&y_centered);

    // Check condition number and decide whether to use iterative refinement
    let cond = condition_number(&gram)?;

    let coef = if cond > 1e10 {
        log::warn!("Ill-conditioned matrix detected (condition number: {:.2e}), using iterative refinement", cond);
        let max_iter = max_iter_refinement.unwrap_or(10);
        let tol = tol_refinement.unwrap_or(1e-12);
        solve_with_iterative_refinement(&gram, &xy, max_iter, tol)?
    } else {
        // Standard solve for well-conditioned matrices
        gram.solve(&xy)
            .map_err(|e| SklearsError::NumericalError(format!("Linear solve failed: {}", e)))?
    };

    // Recover the intercept from the centering means.
    let intercept = if fit_intercept {
        y_mean - x_mean.dot(&coef)
    } else {
        0.0
    };

    Ok((coef, intercept))
}
639
/// SVD-based ridge regression solver for maximum numerical stability
///
/// Solves the ridge regression problem min ||Ax - b||^2 + alpha * ||x||^2
/// using Singular Value Decomposition, which is the most numerically stable
/// approach for ill-conditioned problems.
///
/// # Arguments
/// * `a` - Design matrix of shape (n_samples, n_features)
/// * `b` - Target values of shape (n_samples,)
/// * `alpha` - Regularization parameter (must be non-negative)
///
/// # Returns
/// * Solution vector x of shape (n_features,)
pub fn svd_ridge_regression(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
) -> Result<Array1<Float>> {
    let n_samples = a.nrows();
    let _n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    if alpha < 0.0 {
        return Err(SklearsError::InvalidInput(
            "Regularization parameter must be non-negative".to_string(),
        ));
    }

    // Use SVD via scirs2-linalg: A = U S V^T
    let (u, s, vt) = svd(&a.view(), true)
        .map_err(|e| SklearsError::NumericalError(format!("SVD failed: {}", e)))?;

    // Compute regularized solution: x = V * (S^2 + alpha*I)^(-1) * S * U^T * b
    let ut_b = u.t().dot(b);

    // Regularized inverse of each singular value: s_i / (s_i^2 + alpha).
    // As alpha -> 0 this tends to the pseudo-inverse 1/s_i; larger alpha
    // shrinks directions with small singular values the most.
    let mut regularized_s_inv = Array1::zeros(s.len());
    for (i, &si) in s.iter().enumerate() {
        if i < ut_b.len() {
            regularized_s_inv[i] = si / (si * si + alpha);
        }
    }

    // Compute V * (regularized S inverse) * U^T * b. Lengths are clamped
    // defensively in case the factorization shapes differ (e.g. full vs
    // economy SVD) — only the leading shared entries carry information.
    let mut temp = Array1::zeros(vt.nrows());
    for i in 0..temp.len().min(regularized_s_inv.len()).min(ut_b.len()) {
        temp[i] = regularized_s_inv[i] * ut_b[i];
    }

    let x = vt.t().dot(&temp);

    Ok(x)
}
698
699/// Numerically stable solution using regularized QR decomposition
700///
701/// Solves the regularized least squares problem using QR decomposition with
702/// regularization, avoiding the formation of normal equations.
703///
704/// # Arguments
705/// * `a` - Design matrix of shape (n_samples, n_features)
706/// * `b` - Target values of shape (n_samples,)
707/// * `alpha` - Regularization parameter
708///
709/// # Returns
710/// * Solution vector x of shape (n_features,)
711pub fn qr_ridge_regression(
712    a: &Array2<Float>,
713    b: &Array1<Float>,
714    alpha: Float,
715) -> Result<Array1<Float>> {
716    let n_samples = a.nrows();
717    let n_features = a.ncols();
718
719    if n_samples != b.len() {
720        return Err(SklearsError::InvalidInput(
721            "Matrix dimensions do not match".to_string(),
722        ));
723    }
724
725    if alpha < 0.0 {
726        return Err(SklearsError::InvalidInput(
727            "Regularization parameter must be non-negative".to_string(),
728        ));
729    }
730
731    // For ridge regression, we solve the augmented system:
732    // [A         ] [x] = [b]
733    // [sqrt(α)*I ]     [0]
734    //
735    // This avoids forming A^T A and is more numerically stable
736
737    let sqrt_alpha = alpha.sqrt();
738    let augmented_rows = n_samples + n_features;
739
740    // Create augmented matrix
741    let mut augmented_a = Array2::zeros((augmented_rows, n_features));
742    let mut augmented_b = Array1::zeros(augmented_rows);
743
744    // Copy original A and b
745    augmented_a
746        .slice_mut(scirs2_core::ndarray::s![0..n_samples, ..])
747        .assign(a);
748    augmented_b
749        .slice_mut(scirs2_core::ndarray::s![0..n_samples])
750        .assign(b);
751
752    // Add regularization block: sqrt(alpha) * I
753    for i in 0..n_features {
754        augmented_a[[n_samples + i, i]] = sqrt_alpha;
755    }
756    // augmented_b for regularization block is already zero
757
758    // Solve using QR decomposition
759    stable_normal_equations(&augmented_a, &augmented_b, None)
760}
761
762/// Improved condition number calculation using SVD
763///
764/// Computes the condition number as the ratio of largest to smallest singular value.
765/// This is more accurate than diagonal-based heuristics.
766pub fn accurate_condition_number(a: &Array2<Float>) -> Result<Float> {
767    let min_dim = a.nrows().min(a.ncols());
768    if min_dim == 0 {
769        return Ok(1.0);
770    }
771
772    // Compute SVD to get singular values using scirs2-linalg
773    let (_, s, _) = svd(&a.view(), false)
774        .map_err(|e| SklearsError::NumericalError(format!("SVD failed: {}", e)))?;
775
776    if s.is_empty() {
777        return Ok(Float::INFINITY);
778    }
779
780    let s_max = s[0]; // Singular values are sorted in descending order
781    let s_min = s[s.len() - 1];
782
783    if s_min <= Float::EPSILON {
784        Ok(Float::INFINITY)
785    } else {
786        Ok(s_max / s_min)
787    }
788}
789
790/// Rank-revealing QR decomposition with pivoting
791///
792/// Performs QR decomposition with column pivoting to handle rank-deficient matrices.
793/// Returns the rank and a permutation vector indicating column reordering.
794///
795/// # Arguments
796/// * `a` - Input matrix
797/// * `rcond` - Relative condition number threshold for rank determination
798///
799/// # Returns
800/// * (Q, R, permutation vector, rank)
801pub fn rank_revealing_qr(a: &Array2<Float>, rcond: Option<Float>) -> Result<RankRevealingQrResult> {
802    let n_samples = a.nrows();
803    let n_features = a.ncols();
804    let rcond = rcond.unwrap_or(Float::EPSILON * n_samples.max(n_features) as Float);
805
806    // For now, use regular QR and estimate rank from R diagonal
807    let (q, r) = qr(&a.view())
808        .map_err(|e| SklearsError::NumericalError(format!("QR decomposition failed: {}", e)))?;
809
810    // Estimate rank from R diagonal elements
811    let min_dim = n_samples.min(n_features);
812    let mut rank = 0;
813    let max_diag = (0..min_dim)
814        .map(|i| r[[i, i]].abs())
815        .fold(0.0f64, |a, b| a.max(b));
816
817    for i in 0..min_dim {
818        if r[[i, i]].abs() > rcond * max_diag {
819            rank += 1;
820        } else {
821            break;
822        }
823    }
824
825    // Return identity permutation for now (true pivoting would require more complex implementation)
826    let permutation: Vec<usize> = (0..n_features).collect();
827
828    Ok((q, r, permutation, rank))
829}
830
/// Numerically stable least squares solver with automatic method selection
///
/// Automatically selects the most appropriate numerical method based on
/// matrix properties (condition number, rank, regularization).
///
/// Method selection, driven by the fast heuristic [`condition_number`]:
/// * `alpha > 0`: SVD (cond > 1e12 or underdetermined) → QR (cond > 1e8)
///   → normal-equations solve otherwise
/// * `alpha == 0`: rank-revealing QR (cond > 1e12) → QR (cond > 1e8)
///   → Cholesky-style normal equations; underdetermined systems are rejected
///
/// # Arguments
/// * `a` - Design matrix
/// * `b` - Target vector
/// * `alpha` - Regularization parameter (0 for ordinary least squares)
/// * `rcond` - Relative condition number threshold
///
/// # Returns
/// * Solution vector and solver information
pub fn adaptive_least_squares(
    a: &Array2<Float>,
    b: &Array1<Float>,
    alpha: Float,
    rcond: Option<Float>,
) -> Result<(Array1<Float>, SolverInfo)> {
    let n_samples = a.nrows();
    let n_features = a.ncols();

    if n_samples != b.len() {
        return Err(SklearsError::InvalidInput(
            "Matrix dimensions do not match".to_string(),
        ));
    }

    let rcond = rcond.unwrap_or(Float::EPSILON * n_samples.max(n_features) as Float);

    // Estimate condition number (use fast diagonal-based method first)
    let cond_estimate = condition_number(a)?;

    let (solution, method_used) = if alpha > 0.0 {
        // Regularized problem
        if cond_estimate > 1e12 || n_samples < n_features {
            // Use SVD for extreme ill-conditioning or underdetermined systems
            let solution = svd_ridge_regression(a, b, alpha)?;
            (solution, "SVD-Ridge".to_string())
        } else if cond_estimate > 1e8 {
            // Use QR for moderate ill-conditioning
            let solution = qr_ridge_regression(a, b, alpha)?;
            (solution, "QR-Ridge".to_string())
        } else {
            // Use Cholesky for well-conditioned problems
            let solution = stable_ridge_regression(a, b, alpha, false)?;
            (solution, "Cholesky-Ridge".to_string())
        }
    } else {
        // Ordinary least squares
        if n_samples < n_features {
            return Err(SklearsError::InvalidInput(
                "Underdetermined system requires regularization (alpha > 0)".to_string(),
            ));
        }

        if cond_estimate > 1e12 {
            // Use rank-revealing QR to detect potential rank deficiency
            // before committing to a solve.
            let (_q, _r, _perm, rank) = rank_revealing_qr(a, Some(rcond))?;
            if rank < n_features {
                return Err(SklearsError::NumericalError(format!(
                    "Matrix is rank deficient: rank {} < {} features. Consider regularization.",
                    rank, n_features
                )));
            }
            let solution = stable_normal_equations(a, b, Some(rcond))?;
            (solution, "QR-Rank-Revealing".to_string())
        } else if cond_estimate > 1e8 {
            // Use standard QR for moderate ill-conditioning
            let solution = stable_normal_equations(a, b, Some(rcond))?;
            (solution, "QR-Standard".to_string())
        } else {
            // Use Cholesky for well-conditioned problems
            let solution = solve_least_squares(a, b)?;
            (solution, "Cholesky-OLS".to_string())
        }
    };

    // All paths here are direct (non-iterative) solves, hence n_iterations = 1
    // and converged = true.
    let info = SolverInfo {
        method_used,
        condition_number: cond_estimate,
        n_iterations: 1,
        converged: true,
        residual_norm: compute_residual_norm(a, b, &solution),
    };

    Ok((solution, info))
}
919
/// Information about the numerical solver used by [`adaptive_least_squares`]
#[derive(Debug, Clone)]
pub struct SolverInfo {
    /// Human-readable name of the method used for solving (e.g. "SVD-Ridge")
    pub method_used: String,
    /// Estimated condition number (fast heuristic; see `condition_number`)
    pub condition_number: Float,
    /// Number of iterations (for iterative methods; 1 for direct solves)
    pub n_iterations: usize,
    /// Whether the method converged
    pub converged: bool,
    /// Final residual norm ||Ax - b||
    pub residual_norm: Float,
}
934
935/// Compute residual norm ||Ax - b||
936fn compute_residual_norm(a: &Array2<Float>, b: &Array1<Float>, x: &Array1<Float>) -> Float {
937    let residual = b - &a.dot(x);
938    residual.dot(&residual).sqrt()
939}
940
941/// Numerical stability diagnostics for linear regression problems
942///
943/// Analyzes the numerical properties of a linear regression problem and
944/// provides recommendations for numerical stability.
945pub fn diagnose_numerical_stability(
946    a: &Array2<Float>,
947    b: &Array1<Float>,
948    alpha: Float,
949) -> Result<NumericalDiagnostics> {
950    let n_samples = a.nrows();
951    let n_features = a.ncols();
952
953    if n_samples != b.len() {
954        return Err(SklearsError::InvalidInput(
955            "Matrix dimensions do not match".to_string(),
956        ));
957    }
958
959    // Compute various numerical properties
960    let cond_estimate = condition_number(a)?;
961    let accurate_cond = if cond_estimate > 1e6 {
962        Some(accurate_condition_number(a)?)
963    } else {
964        None
965    };
966
967    // Check for rank deficiency
968    let (_q, _r, _perm, rank) = rank_revealing_qr(a, None)?;
969
970    // Analyze feature scaling
971    let feature_scales: Vec<Float> = (0..n_features)
972        .map(|j| {
973            let col = a.column(j);
974            col.dot(&col).sqrt() / (n_samples as Float).sqrt()
975        })
976        .collect();
977
978    let scale_ratio = if !feature_scales.is_empty() {
979        let max_scale = feature_scales.iter().fold(0.0_f64, |a, &b| a.max(b));
980        let min_scale = feature_scales
981            .iter()
982            .fold(Float::INFINITY, |a, &b| a.min(b));
983        if min_scale > Float::EPSILON {
984            max_scale / min_scale
985        } else {
986            Float::INFINITY
987        }
988    } else {
989        1.0
990    };
991
992    // Generate recommendations
993    let mut recommendations = Vec::new();
994
995    if accurate_cond.unwrap_or(cond_estimate) > 1e12 {
996        recommendations.push(
997            "Matrix is severely ill-conditioned. Consider using SVD-based solver.".to_string(),
998        );
999    } else if accurate_cond.unwrap_or(cond_estimate) > 1e8 {
1000        recommendations
1001            .push("Matrix is moderately ill-conditioned. Consider QR decomposition.".to_string());
1002    }
1003
1004    if rank < n_features {
1005        recommendations.push(format!(
1006            "Matrix is rank deficient (rank {} < {} features). Use regularization or feature selection.",
1007            rank, n_features
1008        ));
1009    }
1010
1011    if scale_ratio > 1e6 {
1012        recommendations.push(
1013            "Features have very different scales. Consider feature scaling/normalization."
1014                .to_string(),
1015        );
1016    }
1017
1018    if n_samples < n_features && alpha == 0.0 {
1019        recommendations.push(
1020            "Underdetermined system. Use regularization (Ridge, Lasso, ElasticNet).".to_string(),
1021        );
1022    }
1023
1024    if alpha > 0.0 && accurate_cond.unwrap_or(cond_estimate) > 1e10 {
1025        recommendations.push(
1026            "Even with regularization, consider increasing alpha for better numerical stability."
1027                .to_string(),
1028        );
1029    }
1030
1031    if recommendations.is_empty() {
1032        recommendations
1033            .push("Numerical properties look good. Standard solvers should work well.".to_string());
1034    }
1035
1036    Ok(NumericalDiagnostics {
1037        condition_number: cond_estimate,
1038        accurate_condition_number: accurate_cond,
1039        rank,
1040        n_samples,
1041        n_features,
1042        scale_ratio,
1043        alpha,
1044        recommendations,
1045    })
1046}
1047
/// Numerical diagnostics for a linear regression problem
///
/// Produced by `diagnose_numerical_stability`; summarizes conditioning,
/// rank, and feature-scaling properties together with human-readable
/// recommendations.
#[derive(Debug, Clone)]
pub struct NumericalDiagnostics {
    /// Estimated condition number (fast calculation)
    pub condition_number: Float,
    /// Accurate condition number (SVD-based; only computed when the fast
    /// estimate exceeds 1e6, otherwise `None`)
    pub accurate_condition_number: Option<Float>,
    /// Matrix rank, as determined by rank-revealing QR
    pub rank: usize,
    /// Number of samples (rows of the design matrix)
    pub n_samples: usize,
    /// Number of features (columns of the design matrix)
    pub n_features: usize,
    /// Ratio of largest to smallest feature scale (infinite when some
    /// feature column is numerically zero)
    pub scale_ratio: Float,
    /// Regularization parameter the diagnosis was run with
    pub alpha: Float,
    /// Recommendations for numerical stability
    pub recommendations: Vec<String>,
}
1068
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_orthogonal_mp_basic() {
        let design = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0],];
        let target = array![1.0, 1.0, 2.0, 3.0];

        let coef = orthogonal_mp(&design, &target, Some(2), None, false)
            .expect("operation should succeed");

        // Two features in, two coefficients out.
        assert_eq!(coef.len(), 2);
        // Exact values depend on the atom-selection order, so only require
        // that every coefficient is finite.
        assert!(coef.iter().all(|&c| c.is_finite()));
    }

    #[test]
    fn test_orthogonal_mp_gram() {
        let gram_matrix = array![[2.0, 1.0], [1.0, 2.0],];
        let xty = array![3.0, 3.0];

        let coef = orthogonal_mp_gram(&gram_matrix, &xty, Some(2), None, None)
            .expect("operation should succeed");
        assert_eq!(coef.len(), 2);
    }

    #[test]
    fn test_ridge_regression_basic() {
        let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0],];
        let target = array![1.0, 2.0, 3.0, 4.0];

        let (coef, intercept) = ridge_regression(&features, &target, 0.1, true, "auto")
            .expect("operation should succeed");

        assert_eq!(coef.len(), 2);
        // Regularization keeps the solution finite.
        assert!(coef.iter().all(|&c| c.is_finite()));
        assert!(intercept.is_finite());
    }

    #[test]
    fn test_ridge_regression_no_intercept() {
        let features = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0],];
        let target = array![1.0, 2.0, 3.0];

        let (coef, intercept) = ridge_regression(&features, &target, 0.1, false, "cholesky")
            .expect("operation should succeed");

        assert_eq!(coef.len(), 2);
        // No intercept requested, so it must be exactly zero.
        assert_eq!(intercept, 0.0);
    }

    #[test]
    fn test_invalid_alpha() {
        let x = array![[1.0]];
        let y = array![1.0];

        // Negative regularization strength must be rejected.
        assert!(ridge_regression(&x, &y, -0.1, true, "auto").is_err());
    }

    #[test]
    fn test_stable_normal_equations() {
        // Perfect linear relationship: y = 1 + x.
        let design = array![[1.0, 1.0], [1.0, 2.0], [1.0, 3.0], [1.0, 4.0]];
        let target = array![2.0, 3.0, 4.0, 5.0];

        let solution =
            stable_normal_equations(&design, &target, None).expect("operation should succeed");

        // Expect intercept ~= 1 and slope ~= 1.
        assert!((solution[0] - 1.0).abs() < 1e-10);
        assert!((solution[1] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_stable_ridge_regression() {
        let design = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let target = array![1.0, 1.0, 2.0];
        let alpha = 0.1;

        let solution = stable_ridge_regression(&design, &target, alpha, false)
            .expect("operation should succeed");

        // A finite solution of the right length is all that is asserted here.
        assert!(solution.iter().all(|&v| v.is_finite()));
        assert_eq!(solution.len(), 2);
    }

    #[test]
    fn test_condition_number() {
        // The identity matrix has condition number exactly 1.
        let identity = array![[1.0, 0.0], [0.0, 1.0]];
        let cond = condition_number(&identity).expect("operation should succeed");
        assert!((cond - 1.0).abs() < 1e-10);

        // A nearly singular matrix must report a large condition number.
        let near_singular = array![[1.0, 1.0], [1.0, 1.000001]];
        let cond_ill = condition_number(&near_singular).expect("operation should succeed");
        assert!(cond_ill > 1e5);
    }

    #[test]
    fn test_stable_equations_rank_deficient() {
        // Rank-1 matrix: the second row is twice the first.
        let design = array![[1.0, 2.0], [2.0, 4.0]];
        let target = array![1.0, 2.0];

        // Solving without regularization must fail for a rank-deficient matrix.
        assert!(stable_normal_equations(&design, &target, None).is_err());
    }
}