// scirs2_sparse/distributed/dist_amg.rs
//! Distributed Algebraic Multigrid (AMG) setup and V-cycle.
//!
//! Implements a distributed RS coarsening + direct interpolation AMG hierarchy
//! with a simulated (shared-memory) communication pattern.  In a real MPI
//! deployment the `simulate_*` steps would be replaced by actual point-to-point
//! or AllGather calls.

8use crate::csr::CsrMatrix;
9use crate::error::{SparseError, SparseResult};
10
11use super::halo_exchange::distributed_spmv;
12use super::partition::{DistributedCsr, RowPartition};
13
14// ─────────────────────────────────────────────────────────────────────────────
15// DistAMGConfig
16// ─────────────────────────────────────────────────────────────────────────────
17
/// Configuration for the distributed AMG hierarchy.
///
/// Consumed by `build_distributed_amg` (setup) and `dist_vcycle` (solve).
#[derive(Debug, Clone)]
pub struct DistAMGConfig {
    /// Number of logical workers (default 4).
    ///
    /// NOTE(review): `build_distributed_amg` does not read this field — it
    /// assembles all partitions into one global matrix and coarsens as a
    /// single worker. Confirm whether the per-worker path is still planned.
    pub n_workers: usize,
    /// Maximum number of AMG levels (default 4).
    pub max_levels: usize,
    /// Target coarsening ratio: stop coarsening when
    /// `n_coarse / n_fine >= coarsening_ratio` (default 0.25).
    ///
    /// NOTE(review): the setup loop actually stops only once the ratio
    /// exceeds `coarsening_ratio + 0.05` — a small undocumented slack.
    pub coarsening_ratio: f64,
    /// Number of pre/post-smoother iterations on each level (default 2).
    pub smoother_iters: usize,
}
31
32impl Default for DistAMGConfig {
33    fn default() -> Self {
34        Self {
35            n_workers: 4,
36            max_levels: 4,
37            coarsening_ratio: 0.25,
38            smoother_iters: 2,
39        }
40    }
41}
42
43// ─────────────────────────────────────────────────────────────────────────────
44// DistAMGLevel
45// ─────────────────────────────────────────────────────────────────────────────
46
/// A single level in the distributed AMG hierarchy.
///
/// Holds the fine matrix of this level together with the transfer operators
/// down to (restriction) and up from (interpolation) the next-coarser level.
#[derive(Debug, Clone)]
pub struct DistAMGLevel {
    /// The (local) fine-level matrix for this level, represented in global
    /// row/column numbering.
    pub local_matrix: CsrMatrix<f64>,
    /// Partition metadata for this level's rows.
    ///
    /// NOTE(review): `build_distributed_amg` always fills this with a single
    /// worker-0 partition covering every row.
    pub partition: RowPartition,
    /// Local prolongation operator P (n_fine_local × n_coarse).
    pub interpolation: CsrMatrix<f64>,
    /// Local restriction operator R = P^T (n_coarse × n_fine_local).
    pub restriction: CsrMatrix<f64>,
}
60
61// ─────────────────────────────────────────────────────────────────────────────
62// DistAMGHierarchy
63// ─────────────────────────────────────────────────────────────────────────────
64
/// The full distributed AMG hierarchy.
///
/// `levels` may be empty, in which case `coarsest_matrix` is the original
/// fine matrix and `dist_vcycle` falls back to a plain Jacobi solve.
#[derive(Debug, Clone)]
pub struct DistAMGHierarchy {
    /// AMG levels from fine (index 0) to coarsest-1.
    pub levels: Vec<DistAMGLevel>,
    /// The coarsest-level matrix (solved approximately with Jacobi sweeps).
    pub coarsest_matrix: CsrMatrix<f64>,
}
73
74// ─────────────────────────────────────────────────────────────────────────────
75// distributed_rs_coarsening
76// ─────────────────────────────────────────────────────────────────────────────
77
78/// Perform Ruge-Stüben (RS) coarsening independently on each partition.
79///
80/// Returns a `Vec<Vec<bool>>` (one per worker) where `true` means the row is
81/// a *coarse* point and `false` means it is a *fine* point.
82///
83/// Algorithm per worker:
84/// 1. Compute the row-wise maximum off-diagonal |a_ij|.
85/// 2. Mark a connection as *strong* if |a_{ij}| ≥ 0.25 * max_j|a_ij|.
86/// 3. RS pass 1: greedily mark a local row as C if it has at least one strong
87///    connection to an undecided row; mark neighbours of new C-points as F.
88pub fn distributed_rs_coarsening(partitions: &[DistributedCsr]) -> Vec<Vec<bool>> {
89    partitions
90        .iter()
91        .map(|dcsr| rs_coarsen_local(&dcsr.local_matrix))
92        .collect()
93}
94
95/// RS coarsening for a single local matrix (in local row indices).
96fn rs_coarsen_local(mat: &CsrMatrix<f64>) -> Vec<bool> {
97    let n = mat.rows();
98    // is_coarse[i]: None = undecided, Some(true) = C, Some(false) = F
99    let mut status: Vec<Option<bool>> = vec![None; n];
100
101    // Step 1: compute per-row max |off-diagonal|.
102    let max_off_diag: Vec<f64> = (0..n)
103        .map(|i| {
104            let start = mat.indptr[i];
105            let end = mat.indptr[i + 1];
106            let mut m = 0.0_f64;
107            for k in start..end {
108                if mat.indices[k] != i {
109                    m = m.max(mat.data[k].abs());
110                }
111            }
112            m
113        })
114        .collect();
115
116    // Step 2: greedy C/F selection.
117    for i in 0..n {
118        if status[i].is_some() {
119            continue;
120        }
121        let start = mat.indptr[i];
122        let end = mat.indptr[i + 1];
123        let threshold = 0.25 * max_off_diag[i];
124
125        // Check whether row i has any strong connection to an undecided node.
126        // Note: column indices may be global (>= n) for off-partition entries;
127        // those are treated as always-undecided (None) for the local C/F decision.
128        let has_strong = (start..end).any(|k| {
129            let j = mat.indices[k];
130            if j == i {
131                return false;
132            }
133            let strong_val = mat.data[k].abs() >= threshold;
134            let neighbour_status = if j < n { status[j] } else { None };
135            strong_val && neighbour_status != Some(false)
136        });
137
138        if has_strong {
139            status[i] = Some(true); // C-point
140                                    // Mark strongly-connected undecided neighbours as F.
141            for k in start..end {
142                let j = mat.indices[k];
143                if j < n && j != i && status[j].is_none() && mat.data[k].abs() >= threshold {
144                    status[j] = Some(false);
145                }
146            }
147        } else {
148            status[i] = Some(true); // isolated → C-point
149        }
150    }
151
152    status.into_iter().map(|s| s.unwrap_or(true)).collect()
153}
154
155// ─────────────────────────────────────────────────────────────────────────────
156// build_distributed_interpolation
157// ─────────────────────────────────────────────────────────────────────────────
158
159/// Build local prolongation matrices P for each partition.
160///
161/// Direct interpolation formula: for each fine point *i* and each strong
162/// coarse neighbour *j*:
163///   w_{ij} = -a_{ij} / (a_{ii} * Σ_{k∈C_i} a_{ik} / a_{kk})
164///
165/// If a fine point has no coarse neighbours it is treated as a trivially
166/// prolonged point (zero interpolation — it will be smoothed away).
167pub fn build_distributed_interpolation(
168    partitions: &[DistributedCsr],
169    coarse_masks: &[Vec<bool>],
170) -> SparseResult<Vec<CsrMatrix<f64>>> {
171    if partitions.len() != coarse_masks.len() {
172        return Err(SparseError::DimensionMismatch {
173            expected: partitions.len(),
174            found: coarse_masks.len(),
175        });
176    }
177
178    partitions
179        .iter()
180        .zip(coarse_masks.iter())
181        .map(|(dcsr, mask)| build_local_interpolation(&dcsr.local_matrix, mask))
182        .collect()
183}
184
/// Build P for one local matrix (local row/column indexing).
///
/// Returns an `n × n_coarse` CSR prolongator where C-points map to their
/// compact coarse index with weight 1, and F-points receive direct
/// interpolation weights from their strong coarse neighbours:
/// `w_ij = -a_ij / (a_ii * Σ_{k∈C_i} a_ik / a_kk)`.
///
/// Rows with a near-zero diagonal, or with no strong coarse neighbours,
/// are left empty in P (they rely on smoothing alone).
fn build_local_interpolation(
    mat: &CsrMatrix<f64>,
    coarse_mask: &[bool],
) -> SparseResult<CsrMatrix<f64>> {
    let n = mat.rows();

    // Assign compact coarse indices (C-points numbered 0..n_coarse in order).
    let mut coarse_idx: Vec<Option<usize>> = vec![None; n];
    let mut n_coarse = 0usize;
    for (i, &is_c) in coarse_mask.iter().enumerate() {
        if is_c {
            coarse_idx[i] = Some(n_coarse);
            n_coarse += 1;
        }
    }

    if n_coarse == 0 {
        // No coarse points — return an empty n×1 P (1 column keeps the
        // matrix shape valid for downstream products).
        return CsrMatrix::from_triplets(n, 1, vec![], vec![], vec![]);
    }

    // Pre-compute the diagonal; a missing stored diagonal falls back to 1.0.
    // NOTE(review): the 1.0 fallback assumes an implicit unit diagonal —
    // confirm this matches the matrices produced upstream.
    let diagonal: Vec<f64> = (0..n)
        .map(|i| {
            let start = mat.indptr[i];
            let end = mat.indptr[i + 1];
            (start..end)
                .find(|&k| mat.indices[k] == i)
                .map(|k| mat.data[k])
                .unwrap_or(1.0)
        })
        .collect();

    // Triplet buffers for P.
    let mut p_rows: Vec<usize> = Vec::new();
    let mut p_cols: Vec<usize> = Vec::new();
    let mut p_vals: Vec<f64> = Vec::new();

    for i in 0..n {
        if coarse_mask[i] {
            // C-point: injection, P[i, coarse_idx[i]] = 1.
            p_rows.push(i);
            p_cols.push(coarse_idx[i].unwrap_or(0));
            p_vals.push(1.0);
        } else {
            // F-point: direct interpolation from strong coarse neighbours.
            let start = mat.indptr[i];
            let end = mat.indptr[i + 1];
            let a_ii = diagonal[i];
            if a_ii.abs() < f64::EPSILON * 1e6 {
                continue; // degenerate (near-zero diagonal) row — leave P row empty
            }

            // Strength threshold: same 0.25 * max|off-diag| rule as coarsening.
            let max_off = (start..end)
                .filter(|&k| mat.indices[k] != i)
                .map(|k| mat.data[k].abs())
                .fold(0.0_f64, f64::max);
            let threshold = 0.25 * max_off;

            // Strong *local* coarse neighbours only (global columns j >= n
            // are excluded from interpolation).
            let coarse_nbrs: Vec<(usize, f64)> = (start..end)
                .filter_map(|k| {
                    let j = mat.indices[k];
                    if j < n && j != i && coarse_mask[j] && mat.data[k].abs() >= threshold {
                        Some((j, mat.data[k]))
                    } else {
                        None
                    }
                })
                .collect();

            if coarse_nbrs.is_empty() {
                continue; // no coarse support — leave P row empty
            }

            // Σ_{k∈C_i} a_{ik} / a_{kk}  (terms with near-zero a_kk contribute 0)
            let sum_ratio: f64 = coarse_nbrs
                .iter()
                .map(|&(j, a_ij)| {
                    let a_jj = diagonal[j];
                    if a_jj.abs() < f64::EPSILON * 1e6 {
                        0.0
                    } else {
                        a_ij / a_jj
                    }
                })
                .sum();

            // Guard against a vanishing denominator (keeps weights finite).
            let denom = if sum_ratio.abs() < f64::EPSILON * 1e6 {
                1.0
            } else {
                a_ii * sum_ratio
            };

            for (j, a_ij) in coarse_nbrs {
                let w = -a_ij / denom;
                if let Some(ci) = coarse_idx[j] {
                    p_rows.push(i);
                    p_cols.push(ci);
                    p_vals.push(w);
                }
            }
        }
    }

    CsrMatrix::from_triplets(n, n_coarse.max(1), p_rows, p_cols, p_vals)
}
292
293// ─────────────────────────────────────────────────────────────────────────────
294// Coarse matrix: R A P
295// ─────────────────────────────────────────────────────────────────────────────
296
297/// Compute A_c = R * A * P where R = P^T.
298///
299/// All matrices use local row/column indices.
300fn triple_product(a: &CsrMatrix<f64>, p: &CsrMatrix<f64>) -> SparseResult<CsrMatrix<f64>> {
301    // R = P^T  (n_coarse × n_fine)
302    let r = p.transpose();
303    // B = A * P  (n_fine × n_coarse)
304    let b = sparse_matmul(a, p)?;
305    // A_c = R * B  (n_coarse × n_coarse)
306    sparse_matmul(&r, &b)
307}
308
309/// Sparse matrix multiplication C = A * B (all in CSR, T=f64).
310fn sparse_matmul(a: &CsrMatrix<f64>, b: &CsrMatrix<f64>) -> SparseResult<CsrMatrix<f64>> {
311    let (m, k_a) = a.shape();
312    let (k_b, n) = b.shape();
313    if k_a != k_b {
314        return Err(SparseError::DimensionMismatch {
315            expected: k_a,
316            found: k_b,
317        });
318    }
319
320    // Dense temporary row accumulator.
321    let mut c_rows: Vec<usize> = Vec::new();
322    let mut c_cols: Vec<usize> = Vec::new();
323    let mut c_vals: Vec<f64> = Vec::new();
324    let mut row_buf: Vec<f64> = vec![0.0; n];
325    let mut nz_cols: Vec<usize> = Vec::new();
326
327    for i in 0..m {
328        let a_start = a.indptr[i];
329        let a_end = a.indptr[i + 1];
330
331        nz_cols.clear();
332
333        for ka in a_start..a_end {
334            let ka_col = a.indices[ka];
335            let a_val = a.data[ka];
336            let b_start = b.indptr[ka_col];
337            let b_end = b.indptr[ka_col + 1];
338            for kb in b_start..b_end {
339                let j = b.indices[kb];
340                if row_buf[j] == 0.0 {
341                    nz_cols.push(j);
342                }
343                row_buf[j] += a_val * b.data[kb];
344            }
345        }
346
347        nz_cols.sort_unstable();
348        for &j in &nz_cols {
349            let v = row_buf[j];
350            if v.abs() > f64::EPSILON * 1e-3 {
351                c_rows.push(i);
352                c_cols.push(j);
353                c_vals.push(v);
354            }
355            row_buf[j] = 0.0; // reset
356        }
357    }
358
359    CsrMatrix::from_triplets(m, n, c_rows, c_cols, c_vals)
360}
361
362// ─────────────────────────────────────────────────────────────────────────────
363// build_distributed_amg
364// ─────────────────────────────────────────────────────────────────────────────
365
/// Build the full distributed AMG hierarchy from an initial set of partitions.
///
/// The partitions are first assembled into one global fine matrix; each
/// subsequent level is produced by RS coarsening, direct interpolation and
/// the Galerkin product A_c = P^T A P. Coarsening stops early when the
/// matrix is tiny (n <= 4) or the coarsening ratio target is missed.
///
/// NOTE(review): despite the "distributed" name, coarsening here runs on the
/// assembled global matrix as a single worker (`config.n_workers` is unused);
/// confirm whether a true per-partition setup is still intended.
///
/// # Errors
/// `ValueError` when `partitions` is empty; otherwise propagates errors from
/// interpolation construction and the sparse triple product.
pub fn build_distributed_amg(
    partitions: &[DistributedCsr],
    config: &DistAMGConfig,
) -> SparseResult<DistAMGHierarchy> {
    if partitions.is_empty() {
        return Err(SparseError::ValueError(
            "Cannot build AMG hierarchy from empty partition list".to_string(),
        ));
    }

    // Assemble the global fine matrix from partitions.
    let global_fine = assemble_global_matrix(partitions)?;

    let mut levels: Vec<DistAMGLevel> = Vec::new();
    let mut current_mat = global_fine;

    // max_levels counts all levels including the coarsest, hence the -1.
    for _lvl in 0..config.max_levels.saturating_sub(1) {
        let n = current_mat.rows();
        if n <= 4 {
            break; // small enough to treat as the coarsest level
        }

        // RS coarsening on the local (single-worker) view.
        let coarse_mask = rs_coarsen_local(&current_mat);
        let n_coarse = coarse_mask.iter().filter(|&&c| c).count();

        // Stop if the coarsening ratio is not achieved (with 0.05 slack
        // beyond the configured target).
        if n_coarse == 0 || (n_coarse as f64) / (n as f64) > config.coarsening_ratio + 0.05 {
            break;
        }

        // Build P, R = P^T, and the Galerkin coarse matrix.
        let p = build_local_interpolation(&current_mat, &coarse_mask)?;
        let a_c = triple_product(&current_mat, &p)?;

        // Wrap in DistAMGLevel (single global partition for simplicity).
        let partition = RowPartition {
            worker_id: 0,
            local_rows: (0..n).collect(),
            n_global_rows: n,
        };
        let r = p.transpose();
        levels.push(DistAMGLevel {
            local_matrix: current_mat,
            partition,
            interpolation: p,
            restriction: r,
        });

        // The coarse matrix becomes the fine matrix of the next iteration.
        current_mat = a_c;
    }

    Ok(DistAMGHierarchy {
        levels,
        coarsest_matrix: current_mat,
    })
}
424
425/// Assemble the global CSR matrix by gathering all owned rows from partitions.
426fn assemble_global_matrix(partitions: &[DistributedCsr]) -> SparseResult<CsrMatrix<f64>> {
427    if partitions.is_empty() {
428        return CsrMatrix::from_triplets(0, 0, vec![], vec![], vec![]);
429    }
430
431    let n_global = partitions[0].partition.n_global_rows;
432    let n_cols = partitions
433        .iter()
434        .map(|d| d.local_matrix.cols())
435        .max()
436        .unwrap_or(n_global);
437
438    let mut rows: Vec<usize> = Vec::new();
439    let mut cols: Vec<usize> = Vec::new();
440    let mut vals: Vec<f64> = Vec::new();
441
442    for dcsr in partitions {
443        let mat = &dcsr.local_matrix;
444        for (local_row, &global_row) in dcsr.partition.local_rows.iter().enumerate() {
445            let start = mat.indptr[local_row];
446            let end = mat.indptr[local_row + 1];
447            for k in start..end {
448                rows.push(global_row);
449                cols.push(mat.indices[k]);
450                vals.push(mat.data[k]);
451            }
452        }
453    }
454
455    CsrMatrix::from_triplets(n_global, n_cols, rows, cols, vals)
456}
457
458// ─────────────────────────────────────────────────────────────────────────────
459// dist_vcycle
460// ─────────────────────────────────────────────────────────────────────────────
461
462/// Perform one AMG V-cycle: pre-smooth → restrict → coarse solve → prolongate
463/// → post-smooth.
464///
465/// Returns the approximate solution `x` given right-hand side `rhs`.
466pub fn dist_vcycle(
467    hierarchy: &DistAMGHierarchy,
468    rhs: &[f64],
469    config: &DistAMGConfig,
470) -> SparseResult<Vec<f64>> {
471    if hierarchy.levels.is_empty() {
472        // Only coarsest level — direct solve (Jacobi).
473        return jacobi_solve(&hierarchy.coarsest_matrix, rhs, 50);
474    }
475
476    vcycle_recursive(hierarchy, 0, rhs, config)
477}
478
479/// Recursive V-cycle implementation.
480fn vcycle_recursive(
481    hierarchy: &DistAMGHierarchy,
482    level: usize,
483    rhs: &[f64],
484    config: &DistAMGConfig,
485) -> SparseResult<Vec<f64>> {
486    let n = rhs.len();
487
488    if level >= hierarchy.levels.len() {
489        // Coarsest level — direct (Jacobi) solve.
490        return jacobi_solve(&hierarchy.coarsest_matrix, rhs, 50);
491    }
492
493    let lvl = &hierarchy.levels[level];
494    let mat = &lvl.local_matrix;
495
496    // ── Pre-smooth ────────────────────────────────────────────────────────────
497    let mut x = jacobi_smooth(mat, rhs, &vec![0.0; n], config.smoother_iters)?;
498
499    // ── Compute residual r = rhs - A*x ───────────────────────────────────────
500    let ax = mat.dot(&x)?;
501    let residual: Vec<f64> = rhs
502        .iter()
503        .zip(ax.iter())
504        .map(|(&b, &ax_i)| b - ax_i)
505        .collect();
506
507    // ── Restrict: r_c = R * r ────────────────────────────────────────────────
508    let r = &lvl.restriction;
509    let rhs_coarse = csr_matvec(r, &residual)?;
510
511    // ── Coarse correction ─────────────────────────────────────────────────────
512    let e_coarse = vcycle_recursive(hierarchy, level + 1, &rhs_coarse, config)?;
513
514    // ── Prolongate: x += P * e_c ─────────────────────────────────────────────
515    let p = &lvl.interpolation;
516    let e_fine = csr_matvec(p, &e_coarse)?;
517
518    for (xi, ei) in x.iter_mut().zip(e_fine.iter()) {
519        *xi += ei;
520    }
521
522    // ── Post-smooth ───────────────────────────────────────────────────────────
523    x = jacobi_smooth(mat, rhs, &x, config.smoother_iters)?;
524
525    Ok(x)
526}
527
528// ─────────────────────────────────────────────────────────────────────────────
529// Smoother helpers
530// ─────────────────────────────────────────────────────────────────────────────
531
532/// Jacobi smoother: x_new[i] = (b[i] - Σ_{j≠i} a_{ij}*x[j]) / a_{ii}
533fn jacobi_smooth(
534    mat: &CsrMatrix<f64>,
535    rhs: &[f64],
536    x0: &[f64],
537    iters: usize,
538) -> SparseResult<Vec<f64>> {
539    let n = mat.rows();
540    let mut x = x0.to_vec();
541
542    // Pre-compute diagonal.
543    let diag: Vec<f64> = (0..n)
544        .map(|i| {
545            let start = mat.indptr[i];
546            let end = mat.indptr[i + 1];
547            (start..end)
548                .find(|&k| mat.indices[k] == i)
549                .map(|k| mat.data[k])
550                .unwrap_or(1.0)
551        })
552        .collect();
553
554    for _ in 0..iters {
555        let ax = mat.dot(&x)?;
556        for i in 0..n {
557            let d = diag[i];
558            if d.abs() > f64::EPSILON * 1e6 {
559                let off_diag = ax[i] - d * x[i];
560                x[i] = (rhs[i] - off_diag) / d;
561            }
562        }
563    }
564    Ok(x)
565}
566
567/// Simple Jacobi iterative solve (used at coarsest level).
568fn jacobi_solve(mat: &CsrMatrix<f64>, rhs: &[f64], iters: usize) -> SparseResult<Vec<f64>> {
569    jacobi_smooth(mat, rhs, &vec![0.0; rhs.len()], iters)
570}
571
572/// Dense CSR matrix-vector product (for small operators).
573fn csr_matvec(mat: &CsrMatrix<f64>, x: &[f64]) -> SparseResult<Vec<f64>> {
574    if x.len() < mat.cols() {
575        return Err(SparseError::DimensionMismatch {
576            expected: mat.cols(),
577            found: x.len(),
578        });
579    }
580    let n = mat.rows();
581    let mut y = vec![0.0_f64; n];
582    for i in 0..n {
583        let start = mat.indptr[i];
584        let end = mat.indptr[i + 1];
585        let mut acc = 0.0_f64;
586        for k in start..end {
587            let j = mat.indices[k];
588            if j < x.len() {
589                acc += mat.data[k] * x[j];
590            }
591        }
592        y[i] = acc;
593    }
594    Ok(y)
595}
596
597// ─────────────────────────────────────────────────────────────────────────────
598// Tests
599// ─────────────────────────────────────────────────────────────────────────────
600
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distributed::partition::{create_distributed_csr, partition_rows, PartitionConfig};

    /// Build a symmetric positive definite tridiagonal n×n matrix:
    /// diagonal = 2, off-diagonal = -1 (the 1-D Poisson stencil).
    fn tridiag(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(2.0_f64);
            if i > 0 {
                // Add the symmetric pair (i, i-1) and (i-1, i).
                rows.push(i);
                cols.push(i - 1);
                vals.push(-1.0);
                rows.push(i - 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag")
    }

    /// Split `mat` row-wise across `n_workers` simulated workers.
    fn make_partitions(mat: &CsrMatrix<f64>, n_workers: usize) -> Vec<DistributedCsr> {
        let config = PartitionConfig {
            n_workers,
            ..Default::default()
        };
        let rps = partition_rows(mat.rows(), &config);
        rps.iter()
            .map(|rp| create_distributed_csr(mat, rp).expect("create_distributed_csr"))
            .collect()
    }

    /// Per-worker RS coarsening should pick strictly fewer C-points than rows.
    #[test]
    fn test_rs_coarsening_reduces_size() {
        let n = 20;
        let mat = tridiag(n);
        let parts = make_partitions(&mat, 2);
        let masks = distributed_rs_coarsening(&parts);

        assert_eq!(masks.len(), 2);
        // Each worker's mask should have fewer C-points than total rows.
        for mask in &masks {
            let n_coarse = mask.iter().filter(|&&c| c).count();
            let n_fine_local = mask.len();
            assert!(
                n_coarse < n_fine_local,
                "Expected coarsening; got n_coarse={n_coarse} of {n_fine_local}"
            );
        }
    }

    /// Setup should produce at least one level and a strictly smaller
    /// coarsest matrix.
    #[test]
    fn test_build_amg_two_level() {
        let n = 20;
        let mat = tridiag(n);
        let parts = make_partitions(&mat, 2);
        let config = DistAMGConfig {
            n_workers: 2,
            max_levels: 2,
            coarsening_ratio: 0.6, // slightly relaxed for n=20
            smoother_iters: 1,
        };
        let hierarchy =
            build_distributed_amg(&parts, &config).expect("build_distributed_amg failed");

        // Should have at least one level.
        assert!(
            !hierarchy.levels.is_empty(),
            "Expected at least one AMG level"
        );

        // Coarsest matrix should be smaller than fine level.
        let n_fine = hierarchy.levels[0].local_matrix.rows();
        let n_coarse = hierarchy.coarsest_matrix.rows();
        assert!(
            n_coarse < n_fine,
            "Coarsest ({n_coarse}) should be smaller than fine ({n_fine})"
        );
    }

    /// One V-cycle from a zero guess must shrink the relative residual.
    #[test]
    fn test_vcycle_reduces_residual() {
        let n = 20;
        let mat = tridiag(n);
        let rhs: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();

        let parts = make_partitions(&mat, 2);
        let config = DistAMGConfig {
            n_workers: 2,
            max_levels: 3,
            coarsening_ratio: 0.7,
            smoother_iters: 2,
        };

        let hierarchy =
            build_distributed_amg(&parts, &config).expect("build_distributed_amg failed");

        let x = dist_vcycle(&hierarchy, &rhs, &config).expect("dist_vcycle failed");

        // Compute residual ||b - Ax|| against the fine-level operator.
        let ax = if hierarchy.levels.is_empty() {
            hierarchy.coarsest_matrix.dot(&x).expect("coarsest dot")
        } else {
            hierarchy.levels[0]
                .local_matrix
                .dot(&x)
                .expect("level 0 dot")
        };

        let residual_norm: f64 = rhs
            .iter()
            .zip(ax.iter())
            .map(|(&b, &ax_i)| (b - ax_i).powi(2))
            .sum::<f64>()
            .sqrt();

        let rhs_norm: f64 = rhs.iter().map(|&b| b * b).sum::<f64>().sqrt();
        let relative = residual_norm / rhs_norm;

        // A weak sanity bound: the V-cycle must do better than the zero guess.
        assert!(
            relative < 1.0,
            "V-cycle should reduce relative residual below 1.0; got {relative}"
        );
    }

    /// sparse_matmul sanity check: multiplying by the identity is a no-op.
    #[test]
    fn test_sparse_matmul_identity() {
        // A * I = A
        let n = 5;
        let mat = tridiag(n);
        // Build n×n identity in CSR.
        let i_rows: Vec<usize> = (0..n).collect();
        let i_cols: Vec<usize> = (0..n).collect();
        let i_vals: Vec<f64> = vec![1.0; n];
        let identity = CsrMatrix::from_triplets(n, n, i_rows, i_cols, i_vals).expect("identity");

        let result = sparse_matmul(&mat, &identity).expect("matmul");

        for i in 0..n {
            for j in 0..n {
                let expected = mat.get(i, j);
                let got = result.get(i, j);
                assert!(
                    (expected - got).abs() < 1e-10,
                    "mismatch at ({i},{j}): {expected} vs {got}"
                );
            }
        }
    }

    /// C-points must be injected with a single weight-1 entry in P.
    #[test]
    fn test_build_interpolation_coarse_points_identity() {
        let n = 6;
        let mat = tridiag(n);
        // Manually coarsen: every even row is C.
        let coarse_mask: Vec<bool> = (0..n).map(|i| i % 2 == 0).collect();
        let p = build_local_interpolation(&mat, &coarse_mask).expect("build_local_interpolation");
        // Each C-point should map to exactly one coarse column with weight 1.
        for (i, &is_c) in coarse_mask.iter().enumerate() {
            if is_c {
                let start = p.indptr[i];
                let end = p.indptr[i + 1];
                assert_eq!(end - start, 1, "C-point {i} should have exactly 1 entry");
                assert!(
                    (p.data[start] - 1.0).abs() < 1e-10,
                    "C-point {i} interpolation weight should be 1.0"
                );
            }
        }
    }
}