1#![allow(unused_variables)]
8#![allow(unused_assignments)]
9#![allow(unused_mut)]
10
11use crate::error::{SparseError, SparseResult};
12use crate::sparray::SparseArray;
13use scirs2_core::ndarray::{Array1, ArrayView1};
14use scirs2_core::numeric::{Float, SparseElement};
15use std::fmt::Debug;
16
/// Configuration for the [`lsqr`] solver.
#[derive(Debug, Clone)]
pub struct LSQROptions {
    /// Maximum number of LSQR iterations to perform.
    pub max_iter: usize,
    /// Absolute tolerance used by the stopping tests.
    pub atol: f64,
    /// Relative tolerance; multiplied by the initial residual norm `‖b − A·x0‖`
    /// in the residual stopping test.
    pub btol: f64,
    /// Upper limit on the estimated condition number; the iteration stops once
    /// the estimate exceeds it.
    pub conlim: f64,
    /// When `true`, the result's `standard_errors` field is populated.
    pub calc_var: bool,
    /// When `true`, the residual norm of every iteration is recorded in the
    /// result's `residual_history` field.
    pub store_residual_history: bool,
}

impl Default for LSQROptions {
    /// 1000 iterations, `atol = btol = 1e-8`, `conlim = 1e8`, no standard
    /// errors, residual history recorded.
    fn default() -> Self {
        const TOL: f64 = 1e-8;
        LSQROptions {
            store_residual_history: true,
            calc_var: false,
            conlim: 1e8,
            btol: TOL,
            atol: TOL,
            max_iter: 1_000,
        }
    }
}
46
/// Outcome of an [`lsqr`] solve.
#[derive(Debug, Clone)]
pub struct LSQRResult<T> {
    /// Computed solution vector; its length equals the number of matrix columns.
    pub x: Array1<T>,
    /// Number of LSQR iterations performed.
    pub iterations: usize,
    /// Euclidean norm of the true residual `b − A·x`, recomputed explicitly
    /// from the returned `x` after the iteration ends.
    pub residualnorm: T,
    /// Euclidean norm of the returned `x`.
    pub solution_norm: T,
    /// Condition-number estimate produced by the iteration (falls back to 1
    /// when no estimate was formed).
    pub condition_number: T,
    /// `true` when the solver stopped because a stopping rule fired rather
    /// than by exhausting `max_iter`. This includes the condition-number limit,
    /// so check `convergence_reason` to tell the cases apart.
    pub converged: bool,
    /// Present only when `LSQROptions::calc_var` was set.
    pub standard_errors: Option<Array1<T>>,
    /// Present only when `LSQROptions::store_residual_history` was set. The
    /// first entry is the initial residual norm `‖b − A·x0‖`. Each later entry is
    /// the per-iteration value LSQR uses as its residual-norm estimate.
    pub residual_history: Option<Vec<T>>,
    /// Human-readable description of why the iteration stopped.
    pub convergence_reason: String,
}
69
70#[allow(dead_code)]
106pub fn lsqr<T, S>(
107 matrix: &S,
108 b: &ArrayView1<T>,
109 x0: Option<&ArrayView1<T>>,
110 options: LSQROptions,
111) -> SparseResult<LSQRResult<T>>
112where
113 T: Float + SparseElement + Debug + Copy + 'static,
114 S: SparseArray<T>,
115{
116 let (m, n) = matrix.shape();
117
118 if b.len() != m {
119 return Err(SparseError::DimensionMismatch {
120 expected: m,
121 found: b.len(),
122 });
123 }
124
125 let mut x = match x0 {
127 Some(x0_val) => {
128 if x0_val.len() != n {
129 return Err(SparseError::DimensionMismatch {
130 expected: n,
131 found: x0_val.len(),
132 });
133 }
134 x0_val.to_owned()
135 }
136 None => Array1::zeros(n),
137 };
138
139 let ax = matrix_vector_multiply(matrix, &x.view())?;
141 let mut u = b - &ax;
142 let beta = l2_norm(&u.view());
143
144 if beta > T::sparse_zero() {
145 for i in 0..m {
146 u[i] = u[i] / beta;
147 }
148 }
149
150 let mut v = matrix_transpose_vector_multiply(matrix, &u.view())?;
152 let mut alpha = l2_norm(&v.view());
153
154 if alpha > T::sparse_zero() {
155 for i in 0..n {
156 v[i] = v[i] / alpha;
157 }
158 }
159
160 let mut w = v.clone();
161 let mut x_norm = T::sparse_zero();
162 let mut dd_norm = T::sparse_zero();
163 let mut res2 = beta;
164
165 let mut rho_bar = alpha;
167 let mut phi_bar = beta;
168
169 let atol = T::from(options.atol).unwrap();
171 let btol = T::from(options.btol).unwrap();
172 let conlim = T::from(options.conlim).unwrap();
173
174 let mut residual_history = if options.store_residual_history {
175 Some(vec![beta])
176 } else {
177 None
178 };
179
180 let mut converged = false;
181 let mut convergence_reason = String::new();
182 let mut iter = 0;
183
184 for k in 0..options.max_iter {
185 iter = k + 1;
186
187 let av = matrix_vector_multiply(matrix, &v.view())?;
189 for i in 0..m {
190 u[i] = av[i] - alpha * u[i];
191 }
192 let beta_new = l2_norm(&u.view());
193
194 if beta_new > T::sparse_zero() {
195 for i in 0..m {
196 u[i] = u[i] / beta_new;
197 }
198 }
199
200 let atu = matrix_transpose_vector_multiply(matrix, &u.view())?;
202 for i in 0..n {
203 v[i] = atu[i] - beta_new * v[i];
204 }
205 let alpha_new = l2_norm(&v.view());
206
207 if alpha_new > T::sparse_zero() {
208 for i in 0..n {
209 v[i] = v[i] / alpha_new;
210 }
211 }
212
213 let rho = (rho_bar * rho_bar + beta_new * beta_new).sqrt();
215 let c = rho_bar / rho;
216 let s = beta_new / rho;
217 let theta = s * alpha_new;
218 let rho_bar_new = -c * alpha_new;
219 let phi = c * phi_bar;
220 let phi_bar_new = s * phi_bar;
221
222 for i in 0..n {
224 x[i] = x[i] + (phi / rho) * w[i];
225 w[i] = v[i] - (theta / rho) * w[i];
226 }
227
228 x_norm = (x_norm * x_norm + (phi / rho) * (phi / rho)).sqrt();
230 dd_norm = dd_norm + (T::sparse_one() / rho) * (T::sparse_one() / rho);
231 res2 = phi_bar_new.abs();
232
233 if let Some(ref mut history) = residual_history {
234 history.push(res2);
235 }
236
237 let r1_norm = res2;
239 let r2_norm = if x_norm > T::sparse_zero() {
240 alpha_new.abs() * x_norm
241 } else {
242 alpha_new.abs()
243 };
244
245 let test1 = r1_norm / (atol + btol * beta);
246 let test2 = if x_norm > T::sparse_zero() {
247 alpha_new.abs() / (atol + btol * x_norm)
248 } else {
249 alpha_new.abs() / atol
250 };
251 let test3 = T::sparse_one() / conlim;
252
253 if test1 <= T::sparse_one() {
254 converged = true;
255 convergence_reason = "Residual tolerance satisfied".to_string();
256 break;
257 }
258
259 if test2 <= T::sparse_one() {
260 converged = true;
261 convergence_reason = "Solution tolerance satisfied".to_string();
262 break;
263 }
264
265 let condition_estimate = if dd_norm > T::sparse_zero() {
267 x_norm / dd_norm.sqrt()
268 } else {
269 T::sparse_one()
270 };
271
272 if condition_estimate > conlim {
273 converged = true;
274 convergence_reason = "Condition number limit reached".to_string();
275 break;
276 }
277
278 alpha = alpha_new;
280 rho_bar = rho_bar_new;
281 phi_bar = phi_bar_new;
282 }
283
284 if !converged {
285 convergence_reason = "Maximum iterations reached".to_string();
286 }
287
288 let ax_final = matrix_vector_multiply(matrix, &x.view())?;
290 let final_residual = b - &ax_final;
291 let final_residualnorm = l2_norm(&final_residual.view());
292 let final_solution_norm = l2_norm(&x.view());
293
294 let condition_number = if dd_norm > T::sparse_zero() {
296 x_norm / dd_norm.sqrt()
297 } else {
298 T::sparse_one()
299 };
300
301 let standard_errors = if options.calc_var {
303 Some(compute_standard_errors(matrix, final_residualnorm, n)?)
304 } else {
305 None
306 };
307
308 Ok(LSQRResult {
309 x,
310 iterations: iter,
311 residualnorm: final_residualnorm,
312 solution_norm: final_solution_norm,
313 condition_number,
314 converged,
315 standard_errors,
316 residual_history,
317 convergence_reason,
318 })
319}
320
321#[allow(dead_code)]
323fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
324where
325 T: Float + SparseElement + Debug + Copy + 'static,
326 S: SparseArray<T>,
327{
328 let (rows, cols) = matrix.shape();
329 if x.len() != cols {
330 return Err(SparseError::DimensionMismatch {
331 expected: cols,
332 found: x.len(),
333 });
334 }
335
336 let mut result = Array1::zeros(rows);
337 let (row_indices, col_indices, values) = matrix.find();
338
339 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
340 result[i] = result[i] + values[k] * x[j];
341 }
342
343 Ok(result)
344}
345
346#[allow(dead_code)]
348fn matrix_transpose_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
349where
350 T: Float + SparseElement + Debug + Copy + 'static,
351 S: SparseArray<T>,
352{
353 let (rows, cols) = matrix.shape();
354 if x.len() != rows {
355 return Err(SparseError::DimensionMismatch {
356 expected: rows,
357 found: x.len(),
358 });
359 }
360
361 let mut result = Array1::zeros(cols);
362 let (row_indices, col_indices, values) = matrix.find();
363
364 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
365 result[j] = result[j] + values[k] * x[i];
366 }
367
368 Ok(result)
369}
370
371#[allow(dead_code)]
373fn l2_norm<T>(x: &ArrayView1<T>) -> T
374where
375 T: Float + SparseElement + Debug + Copy,
376{
377 (x.iter()
378 .map(|&val| val * val)
379 .fold(T::sparse_zero(), |a, b| a + b))
380 .sqrt()
381}
382
383#[allow(dead_code)]
385fn compute_standard_errors<T, S>(matrix: &S, residualnorm: T, n: usize) -> SparseResult<Array1<T>>
386where
387 T: Float + SparseElement + Debug + Copy + 'static,
388 S: SparseArray<T>,
389{
390 let (m, _) = matrix.shape();
391
392 let variance = if m > n {
395 residualnorm * residualnorm / T::from(m - n).unwrap()
396 } else {
397 residualnorm * residualnorm
398 };
399
400 let std_err = variance.sqrt();
401 Ok(Array1::from_elem(n, std_err))
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// Square, nonsingular system: the recomputed residual must be tiny.
    #[test]
    fn test_lsqr_square_system() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let result = lsqr(&matrix, &b.view(), None, LSQROptions::default()).unwrap();
        assert!(result.converged);

        // Check the residual independently of the value the solver reports.
        let ax = matrix_vector_multiply(&matrix, &result.x.view()).unwrap();
        let residual = &b - &ax;
        assert!(l2_norm(&residual.view()) < 1e-6);
    }

    /// Tall 3x2 system: one coefficient per column is returned.
    #[test]
    fn test_lsqr_overdetermined_system() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();
        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let result = lsqr(&matrix, &b.view(), None, LSQROptions::default()).unwrap();
        assert!(result.converged);
        assert_eq!(result.x.len(), 2);
        assert!(result.residualnorm < 2.0);
    }

    /// Diagonal system with a known exact solution.
    #[test]
    fn test_lsqr_diagonal_system() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![4.0, 9.0, 16.0]);

        let result = lsqr(&matrix, &b.view(), None, LSQROptions::default()).unwrap();
        assert!(result.converged);

        let expected = [2.0, 3.0, 4.0];
        for (i, &want) in expected.iter().enumerate() {
            assert_relative_eq!(result.x[i], want, epsilon = 1e-6);
        }
    }

    /// A starting point close to the solution should need only a few iterations.
    #[test]
    fn test_lsqr_with_initial_guess() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![5.0, 6.0, 7.0]);
        let x0 = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let result =
            lsqr(&matrix, &b.view(), Some(&x0.view()), LSQROptions::default()).unwrap();
        assert!(result.converged);
        assert!(result.iterations <= 5);
    }

    /// `calc_var` must produce one standard error per unknown.
    #[test]
    fn test_lsqr_standard_errors() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        let options = LSQROptions {
            calc_var: true,
            ..Default::default()
        };
        let result = lsqr(&matrix, &b.view(), None, options).unwrap();

        assert!(result.converged);
        assert_eq!(result.standard_errors.as_ref().map(|se| se.len()), Some(3));
    }
}