Skip to main content

scirs2_sparse/distributed/
csr.rs

//! Distributed CSR with row-based partitioning and halo exchange.
//!
//! Implements a row-striped decomposition of a sparse matrix where each
//! logical "worker" owns a contiguous range of rows.  The halo exchange is
//! simulated via shared memory: every worker reads the same full input
//! vector, so no data is actually communicated (suitable for multi-threaded
//! or single-process testing).  A real distributed implementation would
//! replace this with an explicit halo broadcast over MPI or similar.
8
9use crate::error::{SparseError, SparseResult};
10use crate::gpu::construction::{GpuCooMatrix, GpuCsrMatrix};
11
12// ============================================================
13// Configuration
14// ============================================================
15
/// Settings that control how a matrix is split into row partitions.
///
/// Use [`Default::default`] for four workers with a one-row halo.
#[derive(Debug, Clone)]
pub struct DistributedCsrConfig {
    /// How many partitions (logical workers) the rows are divided into.
    /// Defaults to 4.
    pub n_workers: usize,
    /// How many ghost rows each local matrix borrows from every adjacent
    /// partition. Defaults to 1.
    pub overlap: usize,
}

impl Default for DistributedCsrConfig {
    /// Four workers, each with a single halo row on either side.
    fn default() -> Self {
        DistributedCsrConfig {
            overlap: 1,
            n_workers: 4,
        }
    }
}
34
35// ============================================================
36// PartitionedCsr
37// ============================================================
38
39/// A CSR matrix partitioned across multiple logical workers using a
40/// row-striped decomposition with optional halo rows.
41///
42/// Each entry of `partitions` holds the **local** matrix for that worker,
43/// which covers `[partition_row_start[w], partition_row_start[w] + partitions[w].n_rows)`.
44#[derive(Debug, Clone)]
45pub struct PartitionedCsr {
46    /// Local matrix per worker (may include halo rows at the boundaries).
47    pub partitions: Vec<GpuCsrMatrix>,
48    /// Global row index of the first **owned** (non-halo) row for each worker.
49    pub row_offsets: Vec<usize>,
50    /// For each worker: the list of global row indices that are ghost rows
51    /// (owned by another worker but needed by this worker).
52    pub halo_rows: Vec<Vec<usize>>,
53    /// Total number of rows in the original matrix.
54    pub n_total_rows: usize,
55    /// Number of columns in the original matrix.
56    pub n_cols: usize,
57    /// Global row index at which each worker's **local** matrix begins
58    /// (including any leading halo rows).
59    partition_global_start: Vec<usize>,
60    /// Global row index at which each worker's **owned** rows end (exclusive).
61    owned_ends: Vec<usize>,
62}
63
64impl PartitionedCsr {
65    /// Partition `matrix` into `config.n_workers` pieces using a row-striped
66    /// decomposition.
67    ///
68    /// Each partition gets its owned rows plus up to `config.overlap` halo
69    /// rows from each adjacent partition.
70    pub fn from_csr(matrix: &GpuCsrMatrix, config: &DistributedCsrConfig) -> Self {
71        let n_workers = config.n_workers.max(1);
72        let overlap = config.overlap;
73        let n = matrix.n_rows;
74        let n_cols = matrix.n_cols;
75
76        // ── Compute owned row ranges (NNZ-balanced partitioning) ──────────
77        let total_nnz = matrix.n_nnz();
78        let target_nnz = total_nnz
79            .checked_div(n_workers)
80            .map(|q| q + usize::from(!total_nnz.is_multiple_of(n_workers)))
81            .unwrap_or(total_nnz);
82
83        let mut owned_starts: Vec<usize> = vec![0];
84        let mut acc = 0usize;
85        for row in 0..n {
86            acc += matrix.row_ptr[row + 1] - matrix.row_ptr[row];
87            if acc >= target_nnz && owned_starts.len() < n_workers {
88                owned_starts.push(row + 1);
89                acc = 0;
90            }
91        }
92        while owned_starts.len() < n_workers {
93            owned_starts.push(n);
94        }
95        let mut o_ends: Vec<usize> = owned_starts[1..].to_vec();
96        o_ends.push(n);
97
98        // ── Build per-worker local matrices (owned + halo rows) ───────────
99        let mut partitions: Vec<GpuCsrMatrix> = Vec::with_capacity(n_workers);
100        let mut halo_rows: Vec<Vec<usize>> = Vec::with_capacity(n_workers);
101        let mut row_offsets: Vec<usize> = Vec::with_capacity(n_workers);
102        let mut part_global_starts: Vec<usize> = Vec::with_capacity(n_workers);
103
104        for w in 0..n_workers {
105            let own_start = owned_starts[w];
106            let own_end = o_ends[w];
107
108            // Halo rows: `overlap` rows before and after the owned range.
109            let halo_start = own_start.saturating_sub(overlap);
110            let halo_end = (own_end + overlap).min(n);
111
112            // Collect global row indices for ghost rows only.
113            let mut ghost: Vec<usize> = Vec::new();
114            for r in halo_start..own_start {
115                ghost.push(r);
116            }
117            for r in own_end..halo_end {
118                ghost.push(r);
119            }
120
121            // Build local CSR from global rows [halo_start, halo_end).
122            let local_nrows = halo_end - halo_start;
123            let local_nnz_start = matrix.row_ptr[halo_start];
124            let local_nnz_end = matrix.row_ptr[halo_end];
125
126            let local_row_ptr: Vec<usize> = (halo_start..=halo_end)
127                .map(|r| matrix.row_ptr[r] - local_nnz_start)
128                .collect();
129            let local_col_idx = matrix.col_idx[local_nnz_start..local_nnz_end].to_vec();
130            let local_values = matrix.values[local_nnz_start..local_nnz_end].to_vec();
131
132            partitions.push(GpuCsrMatrix {
133                row_ptr: local_row_ptr,
134                col_idx: local_col_idx,
135                values: local_values,
136                n_rows: local_nrows,
137                n_cols,
138            });
139
140            halo_rows.push(ghost);
141            row_offsets.push(own_start);
142            part_global_starts.push(halo_start);
143        }
144
145        Self {
146            partitions,
147            row_offsets,
148            halo_rows,
149            n_total_rows: n,
150            n_cols,
151            partition_global_start: part_global_starts,
152            owned_ends: o_ends,
153        }
154    }
155
156    /// Compute `y = A * x` using the distributed representation.
157    ///
158    /// Each worker computes SpMV on its local matrix (which includes halo rows)
159    /// and contributes only the owned rows to the global result.
160    ///
161    /// # Errors
162    ///
163    /// Returns [`SparseError::DimensionMismatch`] when `x.len() != n_cols`.
164    pub fn spmv(&self, x: &[f64]) -> SparseResult<Vec<f64>> {
165        if x.len() != self.n_cols {
166            return Err(SparseError::DimensionMismatch {
167                expected: self.n_cols,
168                found: x.len(),
169            });
170        }
171
172        let n_workers = self.partitions.len();
173        let mut y = vec![0.0_f64; self.n_total_rows];
174
175        for w in 0..n_workers {
176            let partition = &self.partitions[w];
177            let local_y = partition.spmv(x)?;
178
179            // Owned global row range for this worker.
180            let own_global_start = self.row_offsets[w];
181            let own_global_end = self.owned_ends[w];
182
183            // The local matrix starts at global row `partition_global_start[w]`.
184            // The local index of the first owned row is:
185            let local_owned_start = own_global_start - self.partition_global_start[w];
186
187            // Copy owned rows from local result to global output.
188            let owned_len = own_global_end - own_global_start;
189            for k in 0..owned_len {
190                let local_idx = local_owned_start + k;
191                if local_idx < local_y.len() {
192                    y[own_global_start + k] = local_y[local_idx];
193                }
194            }
195        }
196
197        Ok(y)
198    }
199
200    /// Reassemble all partitions' owned rows into a single [`GpuCsrMatrix`].
201    pub fn to_csr(&self) -> GpuCsrMatrix {
202        let n_workers = self.partitions.len();
203        let mut coo = GpuCooMatrix::new(self.n_total_rows, self.n_cols);
204
205        for w in 0..n_workers {
206            let partition = &self.partitions[w];
207            let own_global_start = self.row_offsets[w];
208            let own_global_end = self.owned_ends[w];
209            let local_owned_start = own_global_start - self.partition_global_start[w];
210            let owned_len = own_global_end - own_global_start;
211
212            for k in 0..owned_len {
213                let local_row = local_owned_start + k;
214                let global_row = own_global_start + k;
215                let row_start = partition.row_ptr[local_row];
216                let row_end = partition.row_ptr[local_row + 1];
217                for idx in row_start..row_end {
218                    coo.push(global_row, partition.col_idx[idx], partition.values[idx]);
219                }
220            }
221        }
222
223        coo.to_csr()
224    }
225
226    /// Measure load balance quality.
227    ///
228    /// Returns `std_dev(nnz_per_partition) / mean(nnz_per_partition)`.
229    /// A value of 0.0 means perfect balance; lower is better.
230    ///
231    /// Returns 0.0 for degenerate cases (0 or 1 workers, or 0 total nnz).
232    pub fn load_balance_quality(&self) -> f64 {
233        let n = self.partitions.len();
234        if n <= 1 {
235            return 0.0;
236        }
237        let counts: Vec<f64> = self.partitions.iter().map(|p| p.n_nnz() as f64).collect();
238        let mean = counts.iter().sum::<f64>() / n as f64;
239        if mean < f64::EPSILON {
240            return 0.0;
241        }
242        let variance = counts.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / n as f64;
243        variance.sqrt() / mean
244    }
245}
246
247// ============================================================
248// Tests
249// ============================================================
250
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::construction::GpuCooMatrix;

    /// Build an n×n tridiagonal matrix (4 on the diagonal, -1 off-diagonal).
    fn tridiag(n: usize) -> GpuCsrMatrix {
        let mut coo = GpuCooMatrix::new(n, n);
        for i in 0..n {
            coo.push(i, i, 4.0);
            if i > 0 {
                coo.push(i, i - 1, -1.0);
                coo.push(i - 1, i, -1.0);
            }
        }
        coo.to_csr()
    }

    /// Assert both vectors have the same length and agree element-wise within `tol`.
    ///
    /// A plain `zip` would silently ignore a length mismatch.
    fn assert_close(expected: &[f64], actual: &[f64], tol: f64) {
        assert_eq!(expected.len(), actual.len(), "length mismatch");
        for (i, (e, a)) in expected.iter().zip(actual).enumerate() {
            assert!((e - a).abs() < tol, "row {i}: sequential={e} distributed={a}");
        }
    }

    #[test]
    fn test_distributed_csr_spmv_matches_sequential() {
        let n = 12;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();

        let y_seq = mat.spmv(&x).expect("sequential spmv failed");

        let config = DistributedCsrConfig {
            n_workers: 4,
            overlap: 1,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        let y_dist = dist.spmv(&x).expect("distributed spmv failed");

        assert_close(&y_seq, &y_dist, 1e-10);
    }

    #[test]
    fn test_distributed_spmv_with_wide_overlap() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| (i as f64) * 0.5 - 2.0).collect();
        let y_seq = mat.spmv(&x).expect("sequential spmv failed");

        let config = DistributedCsrConfig {
            n_workers: 3,
            overlap: 3,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        let y_dist = dist.spmv(&x).expect("distributed spmv failed");

        assert_close(&y_seq, &y_dist, 1e-10);
    }

    #[test]
    fn test_distributed_partitioning_row_split() {
        let n = 12;
        let mat = tridiag(n);
        let config = DistributedCsrConfig {
            n_workers: 4,
            overlap: 0,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        assert_eq!(dist.partitions.len(), 4);
        assert_eq!(dist.row_offsets.len(), 4);

        // Row offsets must start at 0, be non-decreasing and lie within [0, n].
        assert_eq!(dist.row_offsets[0], 0);
        for pair in dist.row_offsets.windows(2) {
            assert!(
                pair[0] <= pair[1],
                "row offsets not sorted: {:?}",
                dist.row_offsets
            );
        }
        for w in &dist.row_offsets {
            assert!(*w <= n);
        }

        // Without overlap the local matrices tile the rows exactly.
        let total_rows: usize = dist.partitions.iter().map(|p| p.n_rows).sum();
        assert_eq!(total_rows, n);
        assert!(dist.halo_rows.iter().all(Vec::is_empty));
    }

    #[test]
    fn test_halo_rows_lie_outside_owned_range() {
        let n = 12;
        let mat = tridiag(n);
        let overlap = 2;
        let config = DistributedCsrConfig {
            n_workers: 4,
            overlap,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);

        for w in 0..dist.partitions.len() {
            let (start, end) = (dist.row_offsets[w], dist.owned_ends[w]);
            for &g in &dist.halo_rows[w] {
                assert!(g < start || g >= end, "worker {w}: ghost row {g} is owned");
                assert!(g + overlap >= start, "worker {w}: ghost row {g} too far before");
                assert!(g < end + overlap && g < n, "worker {w}: ghost row {g} too far after");
            }
            // Local rows = owned rows + ghost rows.
            assert_eq!(
                dist.partitions[w].n_rows,
                (end - start) + dist.halo_rows[w].len()
            );
        }
    }

    #[test]
    fn test_distributed_to_csr_roundtrip() {
        let n = 8;
        let mat = tridiag(n);
        let config = DistributedCsrConfig {
            n_workers: 3,
            overlap: 1,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        let reassembled = dist.to_csr();

        // NNZ should be the same (halo rows must not be duplicated).
        assert_eq!(mat.n_nnz(), reassembled.n_nnz());
        // Dense representations must match.
        let d1 = mat.to_dense();
        let d2 = reassembled.to_dense();
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (d1[[i, j]] - d2[[i, j]]).abs() < 1e-12,
                    "mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_load_balance_quality() {
        let n = 12;
        let mat = tridiag(n);
        let config = DistributedCsrConfig {
            n_workers: 4,
            overlap: 0,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        let q = dist.load_balance_quality();
        assert!(q >= 0.0);
        assert!(q < 1.0);
    }

    #[test]
    fn test_single_worker() {
        let n = 6;
        let mat = tridiag(n);
        let config = DistributedCsrConfig {
            n_workers: 1,
            overlap: 0,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        let x = vec![1.0; n];
        let y_seq = mat.spmv(&x).expect("spmv failed");
        let y_dist = dist.spmv(&x).expect("distributed spmv failed");
        assert_close(&y_seq, &y_dist, 1e-10);
    }

    #[test]
    fn test_more_workers_than_rows() {
        let n = 3;
        let mat = tridiag(n);
        let config = DistributedCsrConfig {
            n_workers: 6,
            overlap: 0,
        };
        let dist = PartitionedCsr::from_csr(&mat, &config);
        assert_eq!(dist.partitions.len(), 6);
        let x = vec![1.0; n];
        let y_seq = mat.spmv(&x).expect("spmv failed");
        let y_dist = dist.spmv(&x).expect("distributed spmv failed");
        assert_close(&y_seq, &y_dist, 1e-10);
    }
}