scirs2_sparse/linalg/
lsqr.rs

1//! Least Squares QR (LSQR) method for sparse linear systems
2//!
3//! LSQR is an iterative method for solving sparse least squares problems
4//! and sparse linear systems. It can handle both overdetermined and
5//! underdetermined systems.
6
7#![allow(unused_variables)]
8#![allow(unused_assignments)]
9#![allow(unused_mut)]
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::SparseArray;
13use scirs2_core::ndarray::{Array1, ArrayView1};
14use scirs2_core::numeric::{Float, SparseElement};
15use std::fmt::Debug;
16
/// Configuration for the [`lsqr`] solver.
///
/// Start from `LSQROptions::default()` and override individual fields with
/// struct-update syntax when only a few settings need to change.
#[derive(Clone, Debug)]
pub struct LSQROptions {
    /// Upper bound on the number of LSQR iterations.
    pub max_iter: usize,
    /// Absolute term of the stopping thresholds; the solver compares its
    /// residual and solution measures against `atol + btol * scale`.
    pub atol: f64,
    /// Relative term of the stopping thresholds (multiplies the scale in
    /// `atol + btol * scale`).
    pub btol: f64,
    /// Limit on the solver's condition number estimate; the iteration stops
    /// once the estimate exceeds this value.
    pub conlim: f64,
    /// When `true`, standard errors are computed and returned with the result.
    pub calc_var: bool,
    /// When `true`, the residual estimate of every iteration is recorded.
    pub store_residual_history: bool,
}
33
34impl Default for LSQROptions {
35    fn default() -> Self {
36        Self {
37            max_iter: 1000,
38            atol: 1e-8,
39            btol: 1e-8,
40            conlim: 1e8,
41            calc_var: false,
42            store_residual_history: true,
43        }
44    }
45}
46
47/// Result from LSQR solver
48#[derive(Debug, Clone)]
49pub struct LSQRResult<T> {
50    /// Solution vector
51    pub x: Array1<T>,
52    /// Number of iterations performed
53    pub iterations: usize,
54    /// Final residual norm ||Ax - b||
55    pub residualnorm: T,
56    /// Final solution norm ||x||
57    pub solution_norm: T,
58    /// Condition number estimate
59    pub condition_number: T,
60    /// Whether the solver converged
61    pub converged: bool,
62    /// Standard errors (if requested)
63    pub standard_errors: Option<Array1<T>>,
64    /// Residual history (if requested)
65    pub residual_history: Option<Vec<T>>,
66    /// Convergence reason
67    pub convergence_reason: String,
68}
69
70/// LSQR algorithm for sparse least squares problems
71///
72/// Solves the least squares problem min ||Ax - b||_2 or the linear system Ax = b.
73/// The method is based on the bidiagonalization of A.
74///
75/// # Arguments
76///
77/// * `matrix` - The coefficient matrix A (m x n)
78/// * `b` - The right-hand side vector (length m)
79/// * `x0` - Initial guess (optional, length n)
80/// * `options` - Solver options
81///
82/// # Returns
83///
84/// An `LSQRResult` containing the solution and convergence information
85///
86/// # Example
87///
88/// ```rust
89/// use scirs2_sparse::csr_array::CsrArray;
90/// use scirs2_sparse::linalg::{lsqr, LSQROptions};
91/// use scirs2_core::ndarray::Array1;
92///
93/// // Create an overdetermined system
94/// let rows = vec![0, 0, 1, 1, 2, 2];
95/// let cols = vec![0, 1, 0, 1, 0, 1];
96/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
97/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();
98///
99/// // Right-hand side
100/// let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
101///
102/// // Solve using LSQR
103/// let result = lsqr(&matrix, &b.view(), None, LSQROptions::default()).unwrap();
104/// ```
105#[allow(dead_code)]
106pub fn lsqr<T, S>(
107    matrix: &S,
108    b: &ArrayView1<T>,
109    x0: Option<&ArrayView1<T>>,
110    options: LSQROptions,
111) -> SparseResult<LSQRResult<T>>
112where
113    T: Float + SparseElement + Debug + Copy + 'static,
114    S: SparseArray<T>,
115{
116    let (m, n) = matrix.shape();
117
118    if b.len() != m {
119        return Err(SparseError::DimensionMismatch {
120            expected: m,
121            found: b.len(),
122        });
123    }
124
125    // Initialize solution vector
126    let mut x = match x0 {
127        Some(x0_val) => {
128            if x0_val.len() != n {
129                return Err(SparseError::DimensionMismatch {
130                    expected: n,
131                    found: x0_val.len(),
132                });
133            }
134            x0_val.to_owned()
135        }
136        None => Array1::zeros(n),
137    };
138
139    // Compute initial residual
140    let ax = matrix_vector_multiply(matrix, &x.view())?;
141    let mut u = b - &ax;
142    let beta = l2_norm(&u.view());
143
144    if beta > T::sparse_zero() {
145        for i in 0..m {
146            u[i] = u[i] / beta;
147        }
148    }
149
150    // Initialize variables
151    let mut v = matrix_transpose_vector_multiply(matrix, &u.view())?;
152    let mut alpha = l2_norm(&v.view());
153
154    if alpha > T::sparse_zero() {
155        for i in 0..n {
156            v[i] = v[i] / alpha;
157        }
158    }
159
160    let mut w = v.clone();
161    let mut x_norm = T::sparse_zero();
162    let mut dd_norm = T::sparse_zero();
163    let mut res2 = beta;
164
165    // Variables for QR factorization of bidiagonal matrix
166    let mut rho_bar = alpha;
167    let mut phi_bar = beta;
168
169    // Tolerances
170    let atol = T::from(options.atol).unwrap();
171    let btol = T::from(options.btol).unwrap();
172    let conlim = T::from(options.conlim).unwrap();
173
174    let mut residual_history = if options.store_residual_history {
175        Some(vec![beta])
176    } else {
177        None
178    };
179
180    let mut converged = false;
181    let mut convergence_reason = String::new();
182    let mut iter = 0;
183
184    for k in 0..options.max_iter {
185        iter = k + 1;
186
187        // Bidiagonalization step: u := A*v - alpha*u
188        let av = matrix_vector_multiply(matrix, &v.view())?;
189        for i in 0..m {
190            u[i] = av[i] - alpha * u[i];
191        }
192        let beta_new = l2_norm(&u.view());
193
194        if beta_new > T::sparse_zero() {
195            for i in 0..m {
196                u[i] = u[i] / beta_new;
197            }
198        }
199
200        // v := A^T*u - beta_new*v
201        let atu = matrix_transpose_vector_multiply(matrix, &u.view())?;
202        for i in 0..n {
203            v[i] = atu[i] - beta_new * v[i];
204        }
205        let alpha_new = l2_norm(&v.view());
206
207        if alpha_new > T::sparse_zero() {
208            for i in 0..n {
209                v[i] = v[i] / alpha_new;
210            }
211        }
212
213        // QR factorization of the bidiagonal matrix
214        let rho = (rho_bar * rho_bar + beta_new * beta_new).sqrt();
215        let c = rho_bar / rho;
216        let s = beta_new / rho;
217        let theta = s * alpha_new;
218        let rho_bar_new = -c * alpha_new;
219        let phi = c * phi_bar;
220        let phi_bar_new = s * phi_bar;
221
222        // Update solution
223        for i in 0..n {
224            x[i] = x[i] + (phi / rho) * w[i];
225            w[i] = v[i] - (theta / rho) * w[i];
226        }
227
228        // Update norms and residual estimate
229        x_norm = (x_norm * x_norm + (phi / rho) * (phi / rho)).sqrt();
230        dd_norm = dd_norm + (T::sparse_one() / rho) * (T::sparse_one() / rho);
231        res2 = phi_bar_new.abs();
232
233        if let Some(ref mut history) = residual_history {
234            history.push(res2);
235        }
236
237        // Check convergence
238        let r1_norm = res2;
239        let r2_norm = if x_norm > T::sparse_zero() {
240            alpha_new.abs() * x_norm
241        } else {
242            alpha_new.abs()
243        };
244
245        let test1 = r1_norm / (atol + btol * beta);
246        let test2 = if x_norm > T::sparse_zero() {
247            alpha_new.abs() / (atol + btol * x_norm)
248        } else {
249            alpha_new.abs() / atol
250        };
251        let test3 = T::sparse_one() / conlim;
252
253        if test1 <= T::sparse_one() {
254            converged = true;
255            convergence_reason = "Residual tolerance satisfied".to_string();
256            break;
257        }
258
259        if test2 <= T::sparse_one() {
260            converged = true;
261            convergence_reason = "Solution tolerance satisfied".to_string();
262            break;
263        }
264
265        // Condition number estimate should be compared to limit, not x_norm to test3
266        let condition_estimate = if dd_norm > T::sparse_zero() {
267            x_norm / dd_norm.sqrt()
268        } else {
269            T::sparse_one()
270        };
271
272        if condition_estimate > conlim {
273            converged = true;
274            convergence_reason = "Condition number limit reached".to_string();
275            break;
276        }
277
278        // Update for next iteration
279        alpha = alpha_new;
280        rho_bar = rho_bar_new;
281        phi_bar = phi_bar_new;
282    }
283
284    if !converged {
285        convergence_reason = "Maximum iterations reached".to_string();
286    }
287
288    // Compute final metrics
289    let ax_final = matrix_vector_multiply(matrix, &x.view())?;
290    let final_residual = b - &ax_final;
291    let final_residualnorm = l2_norm(&final_residual.view());
292    let final_solution_norm = l2_norm(&x.view());
293
294    // Estimate condition number (simplified)
295    let condition_number = if dd_norm > T::sparse_zero() {
296        x_norm / dd_norm.sqrt()
297    } else {
298        T::sparse_one()
299    };
300
301    // Compute standard errors if requested
302    let standard_errors = if options.calc_var {
303        Some(compute_standard_errors(matrix, final_residualnorm, n)?)
304    } else {
305        None
306    };
307
308    Ok(LSQRResult {
309        x,
310        iterations: iter,
311        residualnorm: final_residualnorm,
312        solution_norm: final_solution_norm,
313        condition_number,
314        converged,
315        standard_errors,
316        residual_history,
317        convergence_reason,
318    })
319}
320
321/// Helper function for matrix-vector multiplication
322#[allow(dead_code)]
323fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
324where
325    T: Float + SparseElement + Debug + Copy + 'static,
326    S: SparseArray<T>,
327{
328    let (rows, cols) = matrix.shape();
329    if x.len() != cols {
330        return Err(SparseError::DimensionMismatch {
331            expected: cols,
332            found: x.len(),
333        });
334    }
335
336    let mut result = Array1::zeros(rows);
337    let (row_indices, col_indices, values) = matrix.find();
338
339    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
340        result[i] = result[i] + values[k] * x[j];
341    }
342
343    Ok(result)
344}
345
346/// Helper function for matrix transpose-vector multiplication
347#[allow(dead_code)]
348fn matrix_transpose_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
349where
350    T: Float + SparseElement + Debug + Copy + 'static,
351    S: SparseArray<T>,
352{
353    let (rows, cols) = matrix.shape();
354    if x.len() != rows {
355        return Err(SparseError::DimensionMismatch {
356            expected: rows,
357            found: x.len(),
358        });
359    }
360
361    let mut result = Array1::zeros(cols);
362    let (row_indices, col_indices, values) = matrix.find();
363
364    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
365        result[j] = result[j] + values[k] * x[i];
366    }
367
368    Ok(result)
369}
370
371/// Compute L2 norm of a vector
372#[allow(dead_code)]
373fn l2_norm<T>(x: &ArrayView1<T>) -> T
374where
375    T: Float + SparseElement + Debug + Copy,
376{
377    (x.iter()
378        .map(|&val| val * val)
379        .fold(T::sparse_zero(), |a, b| a + b))
380    .sqrt()
381}
382
383/// Compute standard errors (simplified implementation)
384#[allow(dead_code)]
385fn compute_standard_errors<T, S>(matrix: &S, residualnorm: T, n: usize) -> SparseResult<Array1<T>>
386where
387    T: Float + SparseElement + Debug + Copy + 'static,
388    S: SparseArray<T>,
389{
390    let (m, _) = matrix.shape();
391
392    // Simplified standard error computation
393    // In practice, this should use the diagonal of (A^T A)^(-1)
394    let variance = if m > n {
395        residualnorm * residualnorm / T::from(m - n).unwrap()
396    } else {
397        residualnorm * residualnorm
398    };
399
400    let std_err = variance.sqrt();
401    Ok(Array1::from_elem(n, std_err))
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// A small square system must converge to a tiny true residual.
    #[test]
    fn test_lsqr_square_system() {
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 1, 2];
        let vals = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &vals, (3, 3), false).unwrap();
        let rhs = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let outcome = lsqr(&a, &rhs.view(), None, LSQROptions::default()).unwrap();
        assert!(outcome.converged);

        // Check the solution through the residual it actually produces.
        let product = matrix_vector_multiply(&a, &outcome.x.view()).unwrap();
        let gap = &rhs - &product;
        let gap_norm = l2_norm(&gap.view());
        assert!(gap_norm < 1e-6);
    }

    /// A 3x2 overdetermined system must yield a reasonable least squares fit.
    #[test]
    fn test_lsqr_overdetermined_system() {
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 0, 1];
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &vals, (3, 2), false).unwrap();
        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let outcome = lsqr(&a, &rhs.view(), None, LSQROptions::default()).unwrap();

        assert!(outcome.converged);
        assert_eq!(outcome.x.len(), 2);
        // Only a loose bound: the fit has to be reasonable, not exact.
        assert!(outcome.residualnorm < 2.0);
    }

    /// diag(2, 3, 4) * x = (4, 9, 16) has the solution (2, 3, 4).
    #[test]
    fn test_lsqr_diagonal_system() {
        let row_idx = vec![0, 1, 2];
        let col_idx = vec![0, 1, 2];
        let vals = vec![2.0, 3.0, 4.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &vals, (3, 3), false).unwrap();
        let rhs = Array1::from_vec(vec![4.0, 9.0, 16.0]);

        let outcome = lsqr(&a, &rhs.view(), None, LSQROptions::default()).unwrap();
        assert!(outcome.converged);

        for (i, want) in [2.0, 3.0, 4.0].iter().enumerate() {
            assert_relative_eq!(outcome.x[i], *want, epsilon = 1e-6);
        }
    }

    /// A starting point close to the solution must converge within a few iterations.
    #[test]
    fn test_lsqr_with_initial_guess() {
        let row_idx = vec![0, 1, 2];
        let col_idx = vec![0, 1, 2];
        let vals = vec![1.0, 1.0, 1.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &vals, (3, 3), false).unwrap();
        let rhs = Array1::from_vec(vec![5.0, 6.0, 7.0]);
        let guess = Array1::from_vec(vec![4.0, 5.0, 6.0]); // Close to solution

        let outcome =
            lsqr(&a, &rhs.view(), Some(&guess.view()), LSQROptions::default()).unwrap();

        assert!(outcome.converged);
        assert!(outcome.iterations <= 5);
    }

    /// Requesting standard errors must return one value per unknown.
    #[test]
    fn test_lsqr_standard_errors() {
        let row_idx = vec![0, 1, 2];
        let col_idx = vec![0, 1, 2];
        let vals = vec![1.0, 1.0, 1.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &vals, (3, 3), false).unwrap();
        let rhs = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        let with_errors = LSQROptions {
            calc_var: true,
            ..Default::default()
        };
        let outcome = lsqr(&a, &rhs.view(), None, with_errors).unwrap();

        assert!(outcome.converged);
        assert!(outcome.standard_errors.is_some());

        let errors = outcome.standard_errors.unwrap();
        assert_eq!(errors.len(), 3);
    }
}