Skip to main content

scirs2_sparse/parallel_amg/
strength.rs

1//! Parallel Strength-of-Connection Computation for AMG
2//!
3//! Strength-of-connection determines which matrix entries are "strong"
4//! couplings, forming the basis for coarsening decisions.
5//!
6//! # Theory
7//!
8//! Node i strongly influences node j if:
9//!   |a_ij| >= θ · max_{k≠i} |a_ik|
10//!
11//! where θ is the strength threshold (typically 0.25).
12//!
13//! # Parallelism
14//!
15//! Rows are partitioned among threads. Each thread independently computes
16//! strength for its assigned rows. Since rows are disjoint, no synchronization
17//! is needed during computation.
18
19use crate::csr::CsrMatrix;
20use std::sync::Arc;
21
/// Strength-of-connection graph stored as per-node adjacency lists.
///
/// `strong_neighbors` holds the forward edges. `strong_influencers` is the
/// transpose of those edges, kept alongside so reverse lookups do not need a
/// full scan.
///
/// NOTE(review): the builders in this module fill `strong_neighbors[i]` from
/// row `i` of the matrix (columns `j` with `|a_ij| >= θ · max_{k≠i} |a_ik|`).
/// Classical Ruge–Stüben texts usually call that row-`i` condition
/// "i strongly depends on j". This module phrases it as "i strongly influences
/// j". For symmetric matrices the two readings coincide.
#[derive(Debug, Clone)]
pub struct StrengthGraph {
    /// Number of nodes
    pub n: usize,
    /// `strong_neighbors[i]` = nodes j such that i strongly influences j
    /// (i.e., |a_ij| >= θ * max_k |a_ik|)
    pub strong_neighbors: Vec<Vec<usize>>,
    /// `strong_influencers[i]` = nodes j such that j strongly influences i
    /// (transpose of strong_neighbors)
    pub strong_influencers: Vec<Vec<usize>>,
}

impl StrengthGraph {
    /// Builds a graph from forward adjacency lists and derives the transpose.
    ///
    /// For every edge `src -> dst` in `strong_neighbors`, `src` is appended to
    /// `strong_influencers[dst]`. Sources are visited in ascending order and
    /// each list in its stored order. Edges whose target is `>= n` are kept in
    /// `strong_neighbors` but have no transpose entry.
    pub fn from_neighbors(n: usize, strong_neighbors: Vec<Vec<usize>>) -> Self {
        let mut strong_influencers: Vec<Vec<usize>> = (0..n).map(|_| Vec::new()).collect();

        let edges = strong_neighbors
            .iter()
            .enumerate()
            .flat_map(|(src, targets)| targets.iter().map(move |&dst| (src, dst)));
        for (src, dst) in edges {
            // `get_mut` yields `None` exactly when `dst >= n`.
            if let Some(bucket) = strong_influencers.get_mut(dst) {
                bucket.push(src);
            }
        }

        Self {
            n,
            strong_neighbors,
            strong_influencers,
        }
    }

    /// Returns `true` when `j` appears in the forward list of `i`.
    ///
    /// An out-of-range `i` yields `false`.
    pub fn is_strong(&self, i: usize, j: usize) -> bool {
        match self.strong_neighbors.get(i) {
            Some(targets) => targets.iter().any(|&t| t == j),
            None => false,
        }
    }

    /// Returns `true` when an edge exists between `i` and `j` in either direction.
    pub fn is_strongly_connected(&self, i: usize, j: usize) -> bool {
        [(i, j), (j, i)]
            .iter()
            .any(|&(from, to)| self.is_strong(from, to))
    }
}
65
/// Computes the strong-neighbor list of every row in `row_start..row_end`.
///
/// `indptr`, `indices` and `data` are the raw CSR arrays. For row `i`, the
/// cutoff is `theta` times the largest off-diagonal magnitude in that row.
/// Every off-diagonal entry whose magnitude reaches the cutoff is strong. When
/// the cutoff is not strictly positive (no off-diagonals, all-zero
/// off-diagonals, or `theta <= 0`), the row keeps nothing.
///
/// Returns `(i, neighbors)` pairs in ascending row order. `i` is the global
/// row index, so callers can scatter the pairs straight into a full-size table.
fn compute_strength_row_range(
    indptr: &[usize],
    indices: &[usize],
    data: &[f64],
    theta: f64,
    row_start: usize,
    row_end: usize,
) -> Vec<(usize, Vec<usize>)> {
    let mut rows = Vec::with_capacity(row_end - row_start);
    rows.extend((row_start..row_end).map(|row| {
        let (lo, hi) = (indptr[row], indptr[row + 1]);

        // Largest off-diagonal magnitude; 0.0 when the row has none.
        let peak = (lo..hi)
            .filter(|&p| indices[p] != row)
            .map(|p| data[p].abs())
            .fold(0.0f64, f64::max);

        let cutoff = theta * peak;
        let strong: Vec<usize> = if cutoff > 0.0 {
            (lo..hi)
                .filter(|&p| indices[p] != row && data[p].abs() >= cutoff)
                .map(|p| indices[p])
                .collect()
        } else {
            Vec::new()
        };
        (row, strong)
    }));
    rows
}
109
110/// Compute the parallel strength-of-connection graph.
111///
112/// Partitions rows among `n_threads` threads and computes strength for
113/// each partition concurrently. Results are merged into a StrengthGraph.
114///
115/// # Arguments
116///
117/// * `a` - Input sparse matrix
118/// * `theta` - Strength threshold (typically 0.25)
119/// * `n_threads` - Number of threads to use
120///
121/// # Returns
122///
123/// StrengthGraph with strong neighbor lists for all nodes.
124pub fn parallel_strength_of_connection(
125    a: &CsrMatrix<f64>,
126    theta: f64,
127    n_threads: usize,
128) -> StrengthGraph {
129    let n = a.shape().0;
130    if n == 0 {
131        return StrengthGraph::from_neighbors(0, Vec::new());
132    }
133
134    let n_threads = n_threads.max(1);
135    let indptr = Arc::new(a.indptr.clone());
136    let indices = Arc::new(a.indices.clone());
137    let data = Arc::new(a.data.clone());
138
139    // Partition rows into blocks
140    let chunk_size = (n + n_threads - 1) / n_threads;
141
142    let mut strong_neighbors = vec![Vec::new(); n];
143
144    // Use thread::scope to compute in parallel
145    let mut all_results: Vec<Vec<(usize, Vec<usize>)>> = Vec::with_capacity(n_threads);
146
147    std::thread::scope(|s| {
148        let mut handles = Vec::new();
149
150        for t in 0..n_threads {
151            let row_start = t * chunk_size;
152            let row_end = ((t + 1) * chunk_size).min(n);
153            if row_start >= row_end {
154                continue;
155            }
156
157            let indptr_ref = Arc::clone(&indptr);
158            let indices_ref = Arc::clone(&indices);
159            let data_ref = Arc::clone(&data);
160
161            let handle = s.spawn(move || {
162                compute_strength_row_range(
163                    &indptr_ref,
164                    &indices_ref,
165                    &data_ref,
166                    theta,
167                    row_start,
168                    row_end,
169                )
170            });
171            handles.push(handle);
172        }
173
174        for h in handles {
175            if let Ok(result) = h.join() {
176                all_results.push(result);
177            }
178        }
179    });
180
181    // Merge results
182    for chunk in all_results {
183        for (i, neighbors) in chunk {
184            strong_neighbors[i] = neighbors;
185        }
186    }
187
188    StrengthGraph::from_neighbors(n, strong_neighbors)
189}
190
191/// Compute serial strength-of-connection (single-threaded baseline).
192///
193/// Useful for verification and small problems.
194pub fn serial_strength_of_connection(a: &CsrMatrix<f64>, theta: f64) -> StrengthGraph {
195    let n = a.shape().0;
196    let mut strong_neighbors = vec![Vec::new(); n];
197
198    for i in 0..n {
199        let mut max_abs = 0.0f64;
200        for pos in a.row_range(i) {
201            let j = a.indices[pos];
202            if j != i {
203                let v = a.data[pos].abs();
204                if v > max_abs {
205                    max_abs = v;
206                }
207            }
208        }
209        let threshold = theta * max_abs;
210        if threshold > 0.0 {
211            for pos in a.row_range(i) {
212                let j = a.indices[pos];
213                if j != i && a.data[pos].abs() >= threshold {
214                    strong_neighbors[i].push(j);
215                }
216            }
217        }
218    }
219
220    StrengthGraph::from_neighbors(n, strong_neighbors)
221}
222
223/// Compute the measure of importance λ_i for each node.
224///
225/// λ_i = |{j : i strongly influences j}| + 0.5 * |{j : j strongly influences i, i ∈ F-set}|
226///
227/// In the initial phase (before F-set is known), returns:
228/// `lambda_i = |strong_neighbors[i]|` (number of nodes i influences)
229///
230/// # Arguments
231///
232/// * `strength` - The strength graph
233///
234/// # Returns
235///
236/// Vector of λ values indexed by node.
237pub fn compute_lambda(strength: &StrengthGraph) -> Vec<f64> {
238    let n = strength.n;
239    let mut lambda = vec![0.0f64; n];
240    for i in 0..n {
241        // Count number of nodes that i strongly influences (out-degree in strong graph)
242        lambda[i] = strength.strong_neighbors[i].len() as f64;
243    }
244    lambda
245}
246
247/// Update lambda values given a partial F-set labeling (`cf_splitting[i]` = 0 means F, 1 means C, 2 means undecided).
248///
249/// λ_i = |{j : i strongly influences j}| + 0.5 * |{j : j influences i and j is F-node}|
250pub fn compute_lambda_with_fset(strength: &StrengthGraph, cf_splitting: &[u8]) -> Vec<f64> {
251    let n = strength.n;
252    let mut lambda = vec![0.0f64; n];
253    for i in 0..n {
254        // Out-degree: number of nodes i strongly influences
255        let out_degree = strength.strong_neighbors[i].len() as f64;
256        // Count F-node influencers of i
257        let f_influencers = strength
258            .strong_influencers
259            .get(i)
260            .map(|influencers| {
261                influencers
262                    .iter()
263                    .filter(|&&j| j < cf_splitting.len() && cf_splitting[j] == 0)
264                    .count()
265            })
266            .unwrap_or(0);
267        lambda[i] = out_degree + 0.5 * f_influencers as f64;
268    }
269    lambda
270}
271
272/// Compute the undirected strength-of-connection graph.
273///
274/// i and j have an undirected connection if:
275///   |a_ij| >= θ * max(max_k |a_ik|, max_k |a_jk|)
276///
277/// The resulting graph is symmetric.
278///
279/// # Arguments
280///
281/// * `a` - Input sparse matrix
282/// * `theta` - Strength threshold
283///
284/// # Returns
285///
286/// Symmetric StrengthGraph.
287pub fn undirected_strength(a: &CsrMatrix<f64>, theta: f64) -> StrengthGraph {
288    let n = a.shape().0;
289
290    // Compute row maxima
291    let mut row_max = vec![0.0f64; n];
292    for i in 0..n {
293        let mut max_abs = 0.0f64;
294        for pos in a.row_range(i) {
295            let j = a.indices[pos];
296            if j != i {
297                let v = a.data[pos].abs();
298                if v > max_abs {
299                    max_abs = v;
300                }
301            }
302        }
303        row_max[i] = max_abs;
304    }
305
306    let mut strong_neighbors = vec![Vec::new(); n];
307
308    for i in 0..n {
309        for pos in a.row_range(i) {
310            let j = a.indices[pos];
311            if j == i {
312                continue;
313            }
314            // Undirected threshold: max of both row maxima
315            let threshold = theta * row_max[i].max(row_max[j]);
316            if threshold > 0.0 && a.data[pos].abs() >= threshold {
317                // Add i→j edge (j→i will be added when processing row j)
318                if !strong_neighbors[i].contains(&j) {
319                    strong_neighbors[i].push(j);
320                }
321            }
322        }
323    }
324
325    StrengthGraph::from_neighbors(n, strong_neighbors)
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;

    /// Builds the `n × n` tridiagonal 1D Laplacian (2 on the diagonal, -1 on
    /// both off-diagonals) from COO triplets.
    fn laplacian_1d(n: usize) -> CsrMatrix<f64> {
        let diagonal = (0..n).map(|k| (k, k, 2.0f64));
        let couplings = (0..n - 1).flat_map(|k| [(k, k + 1, -1.0f64), (k + 1, k, -1.0f64)]);

        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for (r, c, v) in diagonal.chain(couplings) {
            rows.push(r);
            cols.push(c);
            vals.push(v);
        }
        CsrMatrix::new(vals, rows, cols, (n, n)).expect("valid Laplacian")
    }

    #[test]
    fn test_strength_threshold() {
        let theta = 0.25;
        let a = laplacian_1d(6);
        let g = serial_strength_of_connection(&a, theta);
        // Every off-diagonal of the 1D Laplacian has magnitude 1, so each row's
        // off-diagonal maximum is 1 and the cutoff is 0.25: all of them are strong.
        for (i, strong) in g.strong_neighbors.iter().enumerate().take(6) {
            let mut row_max = 0.0f64;
            for pos in a.row_range(i) {
                if a.indices[pos] != i {
                    row_max = row_max.max(a.data[pos].abs());
                }
            }
            for &j in strong {
                assert_ne!(i, j, "No self-loops in strong graph");
                assert!(
                    a.get(i, j).abs() >= theta * row_max,
                    "Strong connection must meet threshold"
                );
            }
        }
    }

    #[test]
    fn test_strength_parallel_matches_serial() {
        let a = laplacian_1d(16);
        let serial = serial_strength_of_connection(&a, 0.25);
        let parallel = parallel_strength_of_connection(&a, 0.25, 4);
        assert_eq!(serial.n, parallel.n);

        let sorted = |row: &[usize]| {
            let mut row = row.to_vec();
            row.sort_unstable();
            row
        };
        for i in 0..serial.n {
            assert_eq!(
                sorted(serial.strong_neighbors[i].as_slice()),
                sorted(parallel.strong_neighbors[i].as_slice()),
                "Mismatch at node {i}"
            );
        }
    }

    #[test]
    fn test_undirected_strength_symmetric() {
        let g = undirected_strength(&laplacian_1d(8), 0.25);
        // Every edge i -> j must have its mirror j -> i.
        for (i, neighbors) in g.strong_neighbors.iter().enumerate() {
            for &j in neighbors {
                assert!(
                    g.strong_neighbors[j].contains(&i),
                    "Undirected strength must be symmetric: {i} -> {j} but not {j} -> {i}"
                );
            }
        }
    }

    #[test]
    fn test_lambda_computation() {
        let g = serial_strength_of_connection(&laplacian_1d(8), 0.25);
        let lambda = compute_lambda(&g);
        assert_eq!(lambda.len(), 8);
        assert!(
            lambda.iter().all(|&l| l >= 0.0),
            "Lambda must be non-negative"
        );
    }

    #[test]
    fn test_parallel_strength_n_threads() {
        let a = laplacian_1d(20);
        for n_threads in [1, 2, 4] {
            let g = parallel_strength_of_connection(&a, 0.25, n_threads);
            assert_eq!(g.n, 20);
            // Nodes 1..19 are interior: each couples strongly to both neighbours.
            for (i, neighbors) in g.strong_neighbors.iter().enumerate().take(19).skip(1) {
                assert_eq!(
                    neighbors.len(),
                    2,
                    "Interior node {i} should have 2 strong neighbors with {n_threads} threads"
                );
            }
        }
    }
}