Skip to main content

scirs2_sparse/distributed/
partition.rs

1//! Row partitioning for distributed sparse matrices.
2//!
3//! Provides [`partition_rows`] and [`create_distributed_csr`] for splitting a
4//! [`CsrMatrix<f64>`] across multiple logical workers, including identification
5//! of halo (ghost) rows needed for halo-exchange SpMV.
6
7use std::collections::HashSet;
8
9use crate::csr::CsrMatrix;
10use crate::error::{SparseError, SparseResult};
11
12// ─────────────────────────────────────────────────────────────────────────────
13// PartitionMethod
14// ─────────────────────────────────────────────────────────────────────────────
15
/// How global rows are distributed among workers.
///
/// In the formulas below `n` is the number of global rows and `p` the number
/// of workers.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub enum PartitionMethod {
    /// Consecutive row ranges: worker *i* receives rows `[i*n/p, (i+1)*n/p)`.
    #[default]
    Contiguous,
    /// Cyclic assignment: row *r* is handed to worker `r % p`.
    RoundRobin,
    /// NNZ-aware assignment that aims to give every worker roughly the same
    /// number of non-zeros.
    ///
    /// [`partition_rows`] has no NNZ information and therefore treats this
    /// variant exactly like [`PartitionMethod::Contiguous`].
    GraphBased,
}
29
30// ─────────────────────────────────────────────────────────────────────────────
31// PartitionConfig
32// ─────────────────────────────────────────────────────────────────────────────
33
34/// Configuration for row partitioning.
35#[derive(Debug, Clone)]
36pub struct PartitionConfig {
37    /// Number of workers (partitions) to create (default 4).
38    pub n_workers: usize,
39    /// Number of halo rows to include on each side (default 0 — halo exchange
40    /// based on column references, not geometric proximity).
41    pub overlap: usize,
42    /// Row assignment strategy (default [`PartitionMethod::Contiguous`]).
43    pub method: PartitionMethod,
44}
45
46impl Default for PartitionConfig {
47    fn default() -> Self {
48        Self {
49            n_workers: 4,
50            overlap: 0,
51            method: PartitionMethod::Contiguous,
52        }
53    }
54}
55
56// ─────────────────────────────────────────────────────────────────────────────
57// RowPartition
58// ─────────────────────────────────────────────────────────────────────────────
59
/// The set of global rows assigned to one worker.
#[derive(Clone, Debug)]
pub struct RowPartition {
    /// Zero-based identifier of the worker.
    pub worker_id: usize,
    /// Global indices of the rows this worker owns, in ascending order.
    pub local_rows: Vec<usize>,
    /// Row count of the full (global) matrix.
    pub n_global_rows: usize,
}

impl RowPartition {
    /// Returns how many rows this worker owns, i.e. the length of
    /// `local_rows`.
    #[inline]
    pub fn n_local(&self) -> usize {
        let owned = &self.local_rows;
        owned.len()
    }
}
78
79// ─────────────────────────────────────────────────────────────────────────────
80// DistributedCsr
81// ─────────────────────────────────────────────────────────────────────────────
82
83/// A worker's local view of a distributed CSR matrix.
84///
85/// `local_matrix` rows correspond to the global rows listed in
86/// `partition.local_rows` (index 0 = `partition.local_rows[0]`, etc.).
87/// `ghost_rows` holds the global row indices of off-partition rows referenced
88/// by at least one non-zero in the local rows.
89#[derive(Debug, Clone)]
90pub struct DistributedCsr {
91    /// Local CSR matrix (only owned rows; columns are global indices).
92    pub local_matrix: CsrMatrix<f64>,
93    /// Row ownership information.
94    pub partition: RowPartition,
95    /// Global row indices that appear as column targets in the local rows but
96    /// are owned by other workers — i.e. halo / ghost rows.
97    pub ghost_rows: Vec<usize>,
98}
99
100// ─────────────────────────────────────────────────────────────────────────────
101// partition_rows
102// ─────────────────────────────────────────────────────────────────────────────
103
104/// Partition `n_rows` global rows across `config.n_workers` workers.
105///
106/// Returns a [`Vec<RowPartition>`] of length `config.n_workers`.
107pub fn partition_rows(n_rows: usize, config: &PartitionConfig) -> Vec<RowPartition> {
108    let p = config.n_workers.max(1);
109
110    match config.method {
111        PartitionMethod::Contiguous => (0..p)
112            .map(|w| {
113                let start = w * n_rows / p;
114                let end = (w + 1) * n_rows / p;
115                RowPartition {
116                    worker_id: w,
117                    local_rows: (start..end).collect(),
118                    n_global_rows: n_rows,
119                }
120            })
121            .collect(),
122        PartitionMethod::RoundRobin => {
123            let mut bins: Vec<Vec<usize>> = vec![Vec::new(); p];
124            for r in 0..n_rows {
125                bins[r % p].push(r);
126            }
127            bins.into_iter()
128                .enumerate()
129                .map(|(w, rows)| RowPartition {
130                    worker_id: w,
131                    local_rows: rows,
132                    n_global_rows: n_rows,
133                })
134                .collect()
135        }
136        PartitionMethod::GraphBased => {
137            // Partition_rows does not have NNZ information — return contiguous
138            // blocks as a fallback; create_distributed_csr with GraphBased
139            // re-balances by NNZ.
140            (0..p)
141                .map(|w| {
142                    let start = w * n_rows / p;
143                    let end = (w + 1) * n_rows / p;
144                    RowPartition {
145                        worker_id: w,
146                        local_rows: (start..end).collect(),
147                        n_global_rows: n_rows,
148                    }
149                })
150                .collect()
151        }
152    }
153}
154
155// ─────────────────────────────────────────────────────────────────────────────
156// create_distributed_csr
157// ─────────────────────────────────────────────────────────────────────────────
158
159/// Build a [`DistributedCsr`] for one worker from the global CSR matrix.
160///
161/// The function extracts the rows listed in `partition.local_rows`, re-indexes
162/// them into a compact local matrix (rows 0..n_local, global column indices
163/// preserved), and identifies the ghost rows.
164pub fn create_distributed_csr(
165    global_matrix: &CsrMatrix<f64>,
166    partition: &RowPartition,
167) -> SparseResult<DistributedCsr> {
168    let n_local = partition.local_rows.len();
169    let n_cols = global_matrix.cols();
170    let n_global_rows = global_matrix.rows();
171
172    // Build a set of owned global rows for fast ghost detection.
173    let owned_set: HashSet<usize> = partition.local_rows.iter().copied().collect();
174
175    // Collect triplets for the local matrix and accumulate ghost rows.
176    let mut row_indices: Vec<usize> = Vec::new();
177    let mut col_indices: Vec<usize> = Vec::new();
178    let mut values: Vec<f64> = Vec::new();
179    let mut ghost_set: HashSet<usize> = HashSet::new();
180
181    for (local_row, &global_row) in partition.local_rows.iter().enumerate() {
182        if global_row >= n_global_rows {
183            return Err(SparseError::ValueError(format!(
184                "Global row {global_row} out of bounds (n_rows={n_global_rows})"
185            )));
186        }
187        let row_start = global_matrix.indptr[global_row];
188        let row_end = global_matrix.indptr[global_row + 1];
189
190        for idx in row_start..row_end {
191            let col = global_matrix.indices[idx];
192            let val = global_matrix.data[idx];
193
194            row_indices.push(local_row);
195            col_indices.push(col);
196            values.push(val);
197
198            // column `col` corresponds to a row in the global matrix; if it is
199            // outside the owned set it becomes a ghost row.
200            if col < n_global_rows && !owned_set.contains(&col) {
201                ghost_set.insert(col);
202            }
203        }
204    }
205
206    let local_matrix = CsrMatrix::from_triplets(n_local, n_cols, row_indices, col_indices, values)?;
207
208    let mut ghost_rows: Vec<usize> = ghost_set.into_iter().collect();
209    ghost_rows.sort_unstable();
210
211    Ok(DistributedCsr {
212        local_matrix,
213        partition: partition.clone(),
214        ghost_rows,
215    })
216}
217
218// ─────────────────────────────────────────────────────────────────────────────
219// NNZ-balanced partitioning helper (used by halo_exchange & dist_amg)
220// ─────────────────────────────────────────────────────────────────────────────
221
222/// Partition `global_matrix` by NNZ balance into `n_workers` pieces.
223///
224/// Returns [`DistributedCsr`] per worker.  Uses the `GraphBased` strategy
225/// (greedy NNZ balance with contiguous row blocks for cache friendliness).
226pub fn partition_matrix_nnz(
227    global_matrix: &CsrMatrix<f64>,
228    n_workers: usize,
229) -> SparseResult<Vec<DistributedCsr>> {
230    let n_rows = global_matrix.rows();
231    let p = n_workers.max(1);
232
233    // Compute per-row NNZ.
234    let row_nnz: Vec<usize> = (0..n_rows)
235        .map(|r| global_matrix.indptr[r + 1] - global_matrix.indptr[r])
236        .collect();
237    let total_nnz: usize = row_nnz.iter().sum();
238    let target = (total_nnz + p - 1) / p;
239
240    // Greedy contiguous block assignment.
241    let mut partitions_rows: Vec<Vec<usize>> = vec![Vec::new(); p];
242    let mut worker = 0usize;
243    let mut acc = 0usize;
244
245    for r in 0..n_rows {
246        partitions_rows[worker].push(r);
247        acc += row_nnz[r];
248        if acc >= target && worker + 1 < p {
249            worker += 1;
250            acc = 0;
251        }
252    }
253
254    let result: SparseResult<Vec<DistributedCsr>> = partitions_rows
255        .into_iter()
256        .enumerate()
257        .map(|(w, rows)| {
258            let rp = RowPartition {
259                worker_id: w,
260                local_rows: rows,
261                n_global_rows: n_rows,
262            };
263            create_distributed_csr(global_matrix, &rp)
264        })
265        .collect();
266
267    result
268}
269
270// ─────────────────────────────────────────────────────────────────────────────
271// Tests
272// ─────────────────────────────────────────────────────────────────────────────
273
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the 100×100 tridiagonal matrix with 2 on the diagonal and -1 on
    /// the first sub- and super-diagonal.
    fn tridiag_100() -> CsrMatrix<f64> {
        let n = 100usize;
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            // Diagonal entry.
            rows.push(i);
            cols.push(i);
            vals.push(2.0_f64);
            if i > 0 {
                // Sub-diagonal entry (i, i-1) and its mirror (i-1, i).
                rows.push(i);
                cols.push(i - 1);
                vals.push(-1.0);
                rows.push(i - 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag_100 construction")
    }

    /// Partition 100 rows across 4 workers using `method`.
    fn four_way(method: PartitionMethod) -> Vec<RowPartition> {
        let config = PartitionConfig {
            n_workers: 4,
            method,
            ..Default::default()
        };
        partition_rows(100, &config)
    }

    #[test]
    fn test_contiguous_row_count_sums_to_n() {
        let parts = four_way(PartitionMethod::Contiguous);
        assert_eq!(parts.len(), 4);
        let total: usize = parts.iter().map(RowPartition::n_local).sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_contiguous_first_partition_rows() {
        let parts = four_way(PartitionMethod::Contiguous);
        // Worker w should own the half-open range [25*w, 25*(w+1)).
        assert_eq!(parts[0].local_rows, (0..25).collect::<Vec<_>>());
        assert_eq!(parts[1].local_rows, (25..50).collect::<Vec<_>>());
        assert_eq!(parts[2].local_rows, (50..75).collect::<Vec<_>>());
        assert_eq!(parts[3].local_rows, (75..100).collect::<Vec<_>>());
    }

    #[test]
    fn test_round_robin_all_rows_assigned() {
        let parts = four_way(PartitionMethod::RoundRobin);
        let total: usize = parts.iter().map(RowPartition::n_local).sum();
        assert_eq!(total, 100);
    }

    #[test]
    fn test_create_distributed_csr_ghost_rows() {
        let mat = tridiag_100();
        let partitions = four_way(PartitionMethod::Contiguous);
        // Worker 1 owns rows [25, 50).
        let dcsr =
            create_distributed_csr(&mat, &partitions[1]).expect("create_distributed_csr failed");
        // The boundary rows 25 and 49 reference columns 24 and 50, which
        // belong to the neighbouring workers and must show up as ghosts.
        for expected in [24usize, 50] {
            assert!(
                dcsr.ghost_rows.contains(&expected),
                "Expected row {} as ghost, got {:?}",
                expected,
                dcsr.ghost_rows
            );
        }
    }

    #[test]
    fn test_distributed_csr_local_matrix_nnz() {
        let mat = tridiag_100();
        let partitions = four_way(PartitionMethod::Contiguous);
        let dcsr =
            create_distributed_csr(&mat, &partitions[0]).expect("create_distributed_csr failed");
        // Worker 0 owns rows 0..25. Row 0 stores (0,0) and (0,1): 2 entries.
        // Rows 1..=24 store 3 entries each; row 24 is the last owned row but
        // still references column 25. Total: 2 + 24 * 3 = 74.
        assert_eq!(dcsr.local_matrix.nnz(), 2 + 24 * 3);
    }

    #[test]
    fn test_partition_matrix_nnz_balanced() {
        let mat = tridiag_100();
        let dcsrs = partition_matrix_nnz(&mat, 4).expect("partition_matrix_nnz failed");
        assert_eq!(dcsrs.len(), 4);
        let total_rows: usize = dcsrs.iter().map(|d| d.partition.n_local()).sum();
        assert_eq!(total_rows, 100);
    }
}