1use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::fmt::Debug;
13
14type GCROTInnerResult<T> = SparseResult<(Array1<T>, Option<Array1<T>>, Option<Array1<T>>, bool)>;
16
/// Tuning parameters accepted by the GCROT solver.
#[derive(Clone, Debug)]
pub struct GCROTOptions {
    /// Upper bound on the number of outer iterations.
    pub max_iter: usize,
    /// Relative tolerance; the solver scales it by the norm of the right-hand side.
    pub tol: f64,
    /// Maximum number of stored direction vectors kept between iterations.
    pub truncation_size: usize,
    /// Whether the residual norm of every iteration is recorded.
    pub store_residual_history: bool,
}
29
30impl Default for GCROTOptions {
31 fn default() -> Self {
32 Self {
33 max_iter: 1000,
34 tol: 1e-6,
35 truncation_size: 20,
36 store_residual_history: true,
37 }
38 }
39}
40
41#[derive(Debug, Clone)]
43pub struct GCROTResult<T> {
44 pub x: Array1<T>,
46 pub iterations: usize,
48 pub residual_norm: T,
50 pub converged: bool,
52 pub residual_history: Option<Vec<T>>,
54}
55
56#[allow(dead_code)]
93pub fn gcrot<T, S>(
94 matrix: &S,
95 b: &ArrayView1<T>,
96 x0: Option<&ArrayView1<T>>,
97 options: GCROTOptions,
98) -> SparseResult<GCROTResult<T>>
99where
100 T: Float + SparseElement + Debug + Copy + 'static,
101 S: SparseArray<T>,
102{
103 let n = b.len();
104 let (rows, cols) = matrix.shape();
105
106 if rows != cols || rows != n {
107 return Err(SparseError::DimensionMismatch {
108 expected: n,
109 found: rows,
110 });
111 }
112
113 let mut x = match x0 {
115 Some(x0_val) => x0_val.to_owned(),
116 None => Array1::zeros(n),
117 };
118
119 let ax = matrix_vector_multiply(matrix, &x.view())?;
121 let mut r = b - &ax;
122
123 let initial_residual_norm = l2_norm(&r.view());
125 let b_norm = l2_norm(b);
126 let tolerance = T::from(options.tol).unwrap() * b_norm;
127
128 if initial_residual_norm <= tolerance {
129 return Ok(GCROTResult {
130 x,
131 iterations: 0,
132 residual_norm: initial_residual_norm,
133 converged: true,
134 residual_history: if options.store_residual_history {
135 Some(vec![initial_residual_norm])
136 } else {
137 None
138 },
139 });
140 }
141
142 let m = options.truncation_size;
143
144 let mut c_vectors = Array2::zeros((n, 0)); let mut u_vectors = Array2::zeros((n, 0)); let mut residual_history = if options.store_residual_history {
149 Some(vec![initial_residual_norm])
150 } else {
151 None
152 };
153
154 let mut converged = false;
155 let mut iter = 0;
156
157 for k in 0..options.max_iter {
158 iter = k + 1;
159
160 let (delta_x, new_c, new_u, inner_converged) = gcrot_inner_iteration(
162 matrix,
163 &r.view(),
164 &c_vectors.view(),
165 &u_vectors.view(),
166 tolerance,
167 )?;
168
169 x = &x + &delta_x;
171
172 let ax = matrix_vector_multiply(matrix, &x.view())?;
174 r = b - &ax;
175 let residual_norm = l2_norm(&r.view());
176
177 if let Some(ref mut history) = residual_history {
178 history.push(residual_norm);
179 }
180
181 if residual_norm <= tolerance || inner_converged {
183 converged = true;
184 break;
185 }
186
187 if let (Some(c), Some(u)) = (new_c, new_u) {
189 if c_vectors.ncols() >= m {
190 let mut new_c_vectors = Array2::zeros((n, m));
192 let mut new_u_vectors = Array2::zeros((n, m));
193
194 for j in 1..c_vectors.ncols() {
196 for i in 0..n {
197 new_c_vectors[[i, j - 1]] = c_vectors[[i, j]];
198 new_u_vectors[[i, j - 1]] = u_vectors[[i, j]];
199 }
200 }
201
202 for i in 0..n {
204 new_c_vectors[[i, m - 1]] = c[i];
205 new_u_vectors[[i, m - 1]] = u[i];
206 }
207
208 c_vectors = new_c_vectors;
209 u_vectors = new_u_vectors;
210 } else {
211 let old_cols = c_vectors.ncols();
213 let mut new_c_vectors = Array2::zeros((n, old_cols + 1));
214 let mut new_u_vectors = Array2::zeros((n, old_cols + 1));
215
216 for j in 0..old_cols {
218 for i in 0..n {
219 new_c_vectors[[i, j]] = c_vectors[[i, j]];
220 new_u_vectors[[i, j]] = u_vectors[[i, j]];
221 }
222 }
223
224 for i in 0..n {
226 new_c_vectors[[i, old_cols]] = c[i];
227 new_u_vectors[[i, old_cols]] = u[i];
228 }
229
230 c_vectors = new_c_vectors;
231 u_vectors = new_u_vectors;
232 }
233 }
234 }
235
236 let ax_final = matrix_vector_multiply(matrix, &x.view())?;
238 let final_residual = b - &ax_final;
239 let final_residual_norm = l2_norm(&final_residual.view());
240
241 Ok(GCROTResult {
242 x,
243 iterations: iter,
244 residual_norm: final_residual_norm,
245 converged,
246 residual_history,
247 })
248}
249
250#[allow(dead_code)]
252fn gcrot_inner_iteration<T, S>(
253 matrix: &S,
254 r: &ArrayView1<T>,
255 c_vectors: &scirs2_core::ndarray::ArrayView2<T>,
256 u_vectors: &scirs2_core::ndarray::ArrayView2<T>,
257 tolerance: T,
258) -> GCROTInnerResult<T>
259where
260 T: Float + SparseElement + Debug + Copy + 'static,
261 S: SparseArray<T>,
262{
263 let n = r.len();
264 let k = c_vectors.ncols(); let mut v = r.to_owned();
268 let beta = l2_norm(&v.view());
269
270 if beta <= tolerance {
271 return Ok((Array1::zeros(n), None, None, true));
272 }
273
274 for i in 0..n {
276 v[i] = v[i] / beta;
277 }
278
279 for j in 0..k {
281 let mut proj = T::sparse_zero();
282 for i in 0..n {
283 proj = proj + u_vectors[[i, j]] * v[i];
284 }
285
286 for i in 0..n {
287 v[i] = v[i] - proj * c_vectors[[i, j]];
288 }
289 }
290
291 let v_norm = l2_norm(&v.view());
293 if v_norm > T::from(1e-12).unwrap() {
294 for i in 0..n {
295 v[i] = v[i] / v_norm;
296 }
297 }
298
299 let av = matrix_vector_multiply(matrix, &v.view())?;
301
302 let av_norm_sq = dot_product(&av.view(), &av.view());
304 let av_r_dot = dot_product(&av.view(), r);
305
306 if av_norm_sq > T::from(1e-12).unwrap() {
307 let alpha = av_r_dot / av_norm_sq;
308 let mut delta_x = Array1::zeros(n);
309
310 for i in 0..n {
311 delta_x[i] = alpha * v[i];
312 }
313
314 Ok((delta_x, Some(v), Some(av), false))
315 } else {
316 Ok((Array1::zeros(n), None, None, true))
317 }
318}
319
320#[allow(dead_code)]
322fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
323where
324 T: Float + SparseElement + Debug + Copy + 'static,
325 S: SparseArray<T>,
326{
327 let (rows, cols) = matrix.shape();
328 if x.len() != cols {
329 return Err(SparseError::DimensionMismatch {
330 expected: cols,
331 found: x.len(),
332 });
333 }
334
335 let mut result = Array1::zeros(rows);
336 let (row_indices, col_indices, values) = matrix.find();
337
338 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
339 result[i] = result[i] + values[k] * x[j];
340 }
341
342 Ok(result)
343}
344
345#[allow(dead_code)]
347fn l2_norm<T>(x: &ArrayView1<T>) -> T
348where
349 T: Float + Debug + Copy + SparseElement,
350{
351 (x.iter()
352 .map(|&val| val * val)
353 .fold(T::sparse_zero(), |a, b| a + b))
354 .sqrt()
355}
356
357#[allow(dead_code)]
359fn dot_product<T>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T
360where
361 T: Float + Debug + Copy + SparseElement,
362{
363 x.iter()
364 .zip(y.iter())
365 .map(|(&xi, &yi)| xi * yi)
366 .fold(T::sparse_zero(), |a, b| a + b)
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Norm of `b - matrix * x`, recomputed independently of the solver's own report.
    fn true_residual_norm<S>(matrix: &S, x: &Array1<f64>, b: &Array1<f64>) -> f64
    where
        S: SparseArray<f64>,
    {
        let product = matrix_vector_multiply(matrix, &x.view()).unwrap();
        let residual = b - &product;
        l2_norm(&residual.view())
    }

    /// A small non-symmetric 3x3 system must converge to a tiny true residual.
    #[test]
    fn test_gcrot_simple_system() {
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 1, 2];
        let entries = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix = CsrArray::from_triplets(&row_idx, &col_idx, &entries, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let result = gcrot(&matrix, &b.view(), None, GCROTOptions::default()).unwrap();

        assert!(result.converged);
        assert!(true_residual_norm(&matrix, &result.x, &b) < 1e-6);
    }

    /// A scaled identity system must converge to a tiny true residual.
    #[test]
    fn test_gcrot_diagonal_system() {
        let row_idx = vec![0, 1, 2];
        let col_idx = vec![0, 1, 2];
        let entries = vec![5.0, 5.0, 5.0];
        let matrix = CsrArray::from_triplets(&row_idx, &col_idx, &entries, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![5.0, 10.0, 15.0]);

        let result = gcrot(&matrix, &b.view(), None, GCROTOptions::default()).unwrap();

        assert!(result.converged);
        assert!(true_residual_norm(&matrix, &result.x, &b) < 1e-6);
    }

    /// The solver must still report convergence when only two vectors are retained.
    #[test]
    fn test_gcrot_truncation() {
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 1, 2];
        let entries = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix = CsrArray::from_triplets(&row_idx, &col_idx, &entries, (3, 3), false).unwrap();

        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let options = GCROTOptions {
            truncation_size: 2,
            ..Default::default()
        };

        let result = gcrot(&matrix, &b.view(), None, options).unwrap();

        assert!(result.converged);
    }
}