#![allow(unused_variables)]
#![allow(unused_assignments)]
#![allow(unused_mut)]
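
//! LSMR-style iterative solver for sparse least-squares problems.
//!
//! Solves `min ||A*x - b||_2` via Golub-Kahan bidiagonalization in the spirit of
//! LSMR (Fong & Saunders, 2011), for square and rectangular sparse systems.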

use crate::error::{SparseError, SparseResult};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, ArrayView1};
use scirs2_core::numeric::{Float, SparseElement};
use std::fmt::Debug;

/// Options controlling the LSMR iteration.
#[derive(Debug, Clone)]
pub struct LSMROptions {
    /// Maximum number of iterations.
    pub max_iter: usize,
    /// Tolerance for the residual-based stopping test.
    pub atol: f64,
    /// Tolerance for the stopping test on the normal-equations residual `A^T * r`.
    pub btol: f64,
    /// Upper limit on the estimated condition number of `A`.
    pub conlim: f64,
    /// If true, compute approximate standard errors for the solution components.
    pub calc_var: bool,
    /// If true, record the estimated residual norm at each iteration.
    pub store_residual_history: bool,
    /// Number of vectors for local reorthogonalization (0 disables it; not used here).
    pub local_size: usize,
}

impl Default for LSMROptions {
    fn default() -> Self {
        Self {
            max_iter: 1000,
            atol: 1e-8,
            btol: 1e-8,
            conlim: 1e8,
            calc_var: false,
            store_residual_history: true,
            local_size: 0,
        }
    }
}
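
// Individual options are usually overridden with struct update syntax, e.g.
// `LSMROptions { calc_var: true, ..Default::default() }` (see the tests below).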

/// Result of an LSMR solve.
#[derive(Debug, Clone)]
pub struct LSMRResult<T> {
    /// Computed solution vector.
    pub x: Array1<T>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Norm of the final residual `b - A*x`.
    pub residualnorm: T,
    /// Norm of the solution vector.
    pub solution_norm: T,
    /// Estimate of the condition number of `A`.
    pub condition_number: T,
    /// Whether the iteration converged before hitting `max_iter`.
    pub converged: bool,
    /// Approximate standard errors (present only if `calc_var` was requested).
    pub standard_errors: Option<Array1<T>>,
    /// Residual-norm history (present only if `store_residual_history` was requested).
    pub residual_history: Option<Vec<T>>,
    /// Human-readable description of why the iteration stopped.
    pub convergence_reason: String,
}
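
/// Solves the least-squares problem `min ||A*x - b||_2` with the LSMR iteration.
///
/// An optional initial guess `x0` may be supplied; otherwise the iteration starts
/// from the zero vector. Returns an [`LSMRResult`] with the solution and
/// convergence diagnostics, or a dimension-mismatch error if `b` or `x0` does not
/// match the shape of `matrix`.
///
/// # Example
///
/// A minimal sketch mirroring `test_lsmr_diagonal_system` below (the `CsrArray`
/// import path is assumed to match the rest of this crate):
///
/// ```ignore
/// let rows = vec![0, 1, 2];
/// let cols = vec![0, 1, 2];
/// let data = vec![2.0, 3.0, 4.0];
/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
///
/// let b = Array1::from_vec(vec![4.0, 9.0, 16.0]);
/// let result = lsmr(&matrix, &b.view(), None, LSMROptions::default()).unwrap();
/// assert!(result.converged);
/// ```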
#[allow(dead_code)]
pub fn lsmr<T, S>(
    matrix: &S,
    b: &ArrayView1<T>,
    x0: Option<&ArrayView1<T>>,
    options: LSMROptions,
) -> SparseResult<LSMRResult<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (m, n) = matrix.shape();

    if b.len() != m {
        return Err(SparseError::DimensionMismatch {
            expected: m,
            found: b.len(),
        });
    }

    let mut x = match x0 {
        Some(x0_val) => {
            if x0_val.len() != n {
                return Err(SparseError::DimensionMismatch {
                    expected: n,
                    found: x0_val.len(),
                });
            }
            x0_val.to_owned()
        }
        None => Array1::zeros(n),
    };
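
    // Initial residual r = b - A*x and its norm (beta in the bidiagonalization).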
    let ax = matrix_vector_multiply(matrix, &x.view())?;
    let mut u = b - &ax;
    let mut beta = l2_norm(&u.view());

    let atol = T::from(options.atol).unwrap();
    let btol = T::from(options.btol).unwrap();
    let conlim = T::from(options.conlim).unwrap();

    let mut residual_history = if options.store_residual_history {
        Some(vec![beta])
    } else {
        None
    };

    if beta <= atol {
        let solution_norm = l2_norm(&x.view());
        return Ok(LSMRResult {
            x,
            iterations: 0,
            residualnorm: beta,
            solution_norm,
            condition_number: T::sparse_one(),
            converged: true,
            standard_errors: None,
            residual_history,
            convergence_reason: "Already converged".to_string(),
        });
    }
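
    // Golub-Kahan bidiagonalization setup: normalize u = r / beta, then form
    // v = A^T * u and normalize it by alpha.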
    if beta > T::sparse_zero() {
        for i in 0..m {
            u[i] = u[i] / beta;
        }
    }

    let mut v = matrix_transpose_vector_multiply(matrix, &u.view())?;
    let mut alpha = l2_norm(&v.view());

    if alpha > T::sparse_zero() {
        for i in 0..n {
            v[i] = v[i] / alpha;
        }
    }
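
    // Scalar state for the plane rotations, norm estimates, and solution recurrences.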
    let mut alphabar = alpha;
    let mut zetabar = alpha * beta;
    let mut rho = T::sparse_one();
    let mut rhobar = T::sparse_one();
    let mut cbar = T::sparse_one();
    let mut sbar = T::sparse_zero();

    let mut h = v.clone();
    let mut hbar = Array1::zeros(n);

    let mut arnorm = alpha * beta;
    let mut beta_dd = beta;
    let mut tau = T::sparse_zero();
    let mut theta = T::sparse_zero();
    let mut zeta = T::sparse_zero();
    let mut d = T::sparse_zero();
    let mut res2 = T::sparse_zero();
    let mut anorm = T::sparse_zero();
    let mut xxnorm = T::sparse_zero();

    let mut converged = false;
    let mut convergence_reason = String::new();
    let mut iter = 0;
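
    // Main LSMR iteration loop.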
    for k in 0..options.max_iter {
        iter = k + 1;
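
        // Extend the bidiagonalization: u = A*v - alpha*u (then normalize),
        // followed by v = A^T*u - beta*v (then normalize).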
        let au = matrix_vector_multiply(matrix, &v.view())?;
        for i in 0..m {
            u[i] = au[i] - alpha * u[i];
        }
        beta = l2_norm(&u.view());

        if beta > T::sparse_zero() {
            for i in 0..m {
                u[i] = u[i] / beta;
            }

            let atu = matrix_transpose_vector_multiply(matrix, &u.view())?;
            for i in 0..n {
                v[i] = atu[i] - beta * v[i];
            }
            alpha = l2_norm(&v.view());

            if alpha > T::sparse_zero() {
                for i in 0..n {
                    v[i] = v[i] / alpha;
                }
            }

            anorm = (anorm * anorm + alpha * alpha + beta * beta).sqrt();
        }
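
        // Form the rotation coefficients and update the scalar recurrences.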
        let rhobar1 = (rhobar * rhobar + beta * beta).sqrt();
        let cs1 = rhobar / rhobar1;
        let sn1 = beta / rhobar1;
        let psi = sn1 * alpha;
        alpha = cs1 * alpha;

        let cs = cbar * cs1;
        let sn = sbar * cs1;
        let theta = sbar * alpha;
        rho = (cs * alpha * cs * alpha + theta * theta).sqrt();
        let c = cs * alpha / rho;
        let s = theta / rho;
        zeta = c * zetabar;
        zetabar = -s * zetabar;
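
        // Update the solution x and the direction vectors h and hbar.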
        for i in 0..n {
            hbar[i] = h[i] - (theta * rho / (rhobar * rhobar1)) * hbar[i];
            x[i] = x[i] + (zeta / (rho * rhobar1)) * hbar[i];
            h[i] = v[i] - (alpha / rhobar1) * h[i];
        }

        xxnorm = (xxnorm + (zeta / rho) * (zeta / rho)).sqrt();
        let ddnorm = (d + (zeta / rho) * (zeta / rho)).sqrt();
        d = ddnorm;

        let beta_dd1 = beta_dd;
        let beta_dd = beta * sn1;
        let rhodold = rho;
        let tautilde = (zetabar * zetabar).sqrt();
        let tau = tau + tautilde * tautilde;
        let d1 = (d * d + (beta_dd1 / rhodold) * (beta_dd1 / rhodold)).sqrt();
        let d2 = (d1 * d1 + (beta_dd / rho) * (beta_dd / rho)).sqrt();

        res2 = (d2 * d2 + tau).sqrt();
        let arnorm = alpha * beta.abs();

        if let Some(ref mut history) = residual_history {
            history.push(res2);
        }
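
        // Stopping tests based on the estimated residual norm, the estimated
        // norm of A^T * r, and the condition-number estimate.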
        let r1norm = res2;
        let r2norm = arnorm;
        let cond = anorm * xxnorm;

        let test1 = res2 / (T::sparse_one() + anorm * xxnorm);
        let test2 = arnorm / (T::sparse_one() + anorm);
        let test3 = T::sparse_one() / (T::sparse_one() + cond);

        if test1 <= atol {
            converged = true;
            convergence_reason = "Residual tolerance satisfied".to_string();
            break;
        }

        if test2 <= btol {
            converged = true;
            convergence_reason = "Solution tolerance satisfied".to_string();
            break;
        }

        if test3 <= T::sparse_one() / conlim {
            converged = true;
            convergence_reason = "Condition number limit reached".to_string();
            break;
        }

        rhobar = rhobar1;
        cbar = cs1;
        sbar = sn1;
        alphabar = alpha;
    }

    if !converged {
        convergence_reason = "Maximum iterations reached".to_string();
    }
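
    // Recompute the true residual and norms for the returned diagnostics.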
    let ax_final = matrix_vector_multiply(matrix, &x.view())?;
    let final_residual = b - &ax_final;
    let final_residualnorm = l2_norm(&final_residual.view());
    let final_solution_norm = l2_norm(&x.view());

    let condition_number = anorm * xxnorm;

    let standard_errors = if options.calc_var {
        Some(compute_standard_errors(matrix, final_residualnorm, n)?)
    } else {
        None
    };

    Ok(LSMRResult {
        x,
        iterations: iter,
        residualnorm: final_residualnorm,
        solution_norm: final_solution_norm,
        condition_number,
        converged,
        standard_errors,
        residual_history,
        convergence_reason,
    })
}
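
/// Sparse matrix-vector product `A * x`, computed from the matrix's COO triplets.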
#[allow(dead_code)]
fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (rows, cols) = matrix.shape();
    if x.len() != cols {
        return Err(SparseError::DimensionMismatch {
            expected: cols,
            found: x.len(),
        });
    }

    let mut result = Array1::zeros(rows);
    let (row_indices, col_indices, values) = matrix.find();

    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
        result[i] = result[i] + values[k] * x[j];
    }

    Ok(result)
}
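
/// Sparse transposed matrix-vector product `A^T * x`, computed from the COO triplets.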
#[allow(dead_code)]
fn matrix_transpose_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (rows, cols) = matrix.shape();
    if x.len() != rows {
        return Err(SparseError::DimensionMismatch {
            expected: rows,
            found: x.len(),
        });
    }

    let mut result = Array1::zeros(cols);
    let (row_indices, col_indices, values) = matrix.find();

    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
        result[j] = result[j] + values[k] * x[i];
    }

    Ok(result)
}
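
/// Euclidean (L2) norm of a dense vector.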
#[allow(dead_code)]
fn l2_norm<T>(x: &ArrayView1<T>) -> T
where
    T: Float + SparseElement + Debug + Copy,
{
    x.iter()
        .map(|&val| val * val)
        .fold(T::sparse_zero(), |a, b| a + b)
        .sqrt()
}
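
/// Computes a uniform standard-error estimate for all solution components from the
/// final residual norm and the residual degrees of freedom (`m - n` when `m > n`).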
#[allow(dead_code)]
fn compute_standard_errors<T, S>(matrix: &S, residualnorm: T, n: usize) -> SparseResult<Array1<T>>
where
    T: Float + SparseElement + Debug + Copy + 'static,
    S: SparseArray<T>,
{
    let (m, _) = matrix.shape();

    let variance = if m > n {
        residualnorm * residualnorm / T::from(m - n).unwrap()
    } else {
        residualnorm * residualnorm
    };

    let std_err = variance.sqrt();
    Ok(Array1::from_elem(n, std_err))
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    #[test]
    #[ignore]
    fn test_lsmr_square_system() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 1, 2];
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let result = lsmr(&matrix, &b.view(), None, LSMROptions::default()).unwrap();

        assert!(result.converged);

        // Verify that A * x reproduces b to tight tolerance.
        let ax = matrix_vector_multiply(&matrix, &result.x.view()).unwrap();
        let residual = &b - &ax;
        let residualnorm = l2_norm(&residual.view());

        assert!(residualnorm < 1e-6);
    }

    #[test]
    #[ignore]
    fn test_lsmr_overdetermined_system() {
        let rows = vec![0, 0, 1, 1, 2, 2];
        let cols = vec![0, 1, 0, 1, 0, 1];
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).unwrap();

        let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let result = lsmr(&matrix, &b.view(), None, LSMROptions::default()).unwrap();

        assert!(result.converged);
        assert_eq!(result.x.len(), 2);

        // The least-squares residual need not vanish for an overdetermined system.
        assert!(result.residualnorm < 2.0);
    }

    #[test]
    #[ignore]
    fn test_lsmr_diagonal_system() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![2.0, 3.0, 4.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![4.0, 9.0, 16.0]);
        let result = lsmr(&matrix, &b.view(), None, LSMROptions::default()).unwrap();

        assert!(result.converged);

        // The exact solution is x = [2, 3, 4].
        assert_relative_eq!(result.x[0], 2.0, epsilon = 1e-6);
        assert_relative_eq!(result.x[1], 3.0, epsilon = 1e-6);
        assert_relative_eq!(result.x[2], 4.0, epsilon = 1e-6);
    }

    #[test]
    fn test_lsmr_with_initial_guess() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![5.0, 6.0, 7.0]);
        // Start from a guess close to the exact solution [5, 6, 7].
        let x0 = Array1::from_vec(vec![4.0, 5.0, 6.0]);
        let result = lsmr(&matrix, &b.view(), Some(&x0.view()), LSMROptions::default()).unwrap();

        assert!(result.converged);
        // A good initial guess should keep the iteration count small.
        assert!(result.iterations <= 10);
    }

    #[test]
    fn test_lsmr_standard_errors() {
        let rows = vec![0, 1, 2];
        let cols = vec![0, 1, 2];
        let data = vec![1.0, 1.0, 1.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        let options = LSMROptions {
            calc_var: true,
            ..Default::default()
        };

        let result = lsmr(&matrix, &b.view(), None, options).unwrap();

        assert!(result.converged);
        assert!(result.standard_errors.is_some());

        let std_errs = result.standard_errors.unwrap();
        assert_eq!(std_errs.len(), 3);
    }
}