// scirs2_sparse/distributed/csr.rs
use crate::error::{SparseError, SparseResult};
use crate::gpu::construction::{GpuCooMatrix, GpuCsrMatrix};
11
/// Configuration for row-wise partitioning of a CSR matrix across workers.
#[derive(Debug, Clone)]
pub struct DistributedCsrConfig {
    /// Number of partitions to create (clamped to at least 1 by `PartitionedCsr::from_csr`).
    pub n_workers: usize,
    /// Number of halo (ghost) rows replicated on each side of every partition.
    pub overlap: usize,
}
25
26impl Default for DistributedCsrConfig {
27 fn default() -> Self {
28 Self {
29 n_workers: 4,
30 overlap: 1,
31 }
32 }
33}
34
/// A CSR matrix split row-wise into per-worker partitions, each storing its
/// owned rows plus optional halo (ghost) rows copied from its neighbours.
#[derive(Debug, Clone)]
pub struct PartitionedCsr {
    /// Per-worker local CSR slices; each covers the owned rows plus halo rows.
    pub partitions: Vec<GpuCsrMatrix>,
    /// Global index of the first *owned* row of each partition.
    pub row_offsets: Vec<usize>,
    /// Global indices of the ghost (stored-but-not-owned) rows of each partition.
    pub halo_rows: Vec<Vec<usize>>,
    /// Number of rows in the original (global) matrix.
    pub n_total_rows: usize,
    /// Number of columns (shared by every partition; columns are not split).
    pub n_cols: usize,
    // Global index of the first *stored* row (i.e. including the leading halo)
    // of each partition; used to map owned rows back to local indices.
    partition_global_start: Vec<usize>,
    // Exclusive global end of each partition's owned row range.
    owned_ends: Vec<usize>,
}
63
64impl PartitionedCsr {
65 pub fn from_csr(matrix: &GpuCsrMatrix, config: &DistributedCsrConfig) -> Self {
71 let n_workers = config.n_workers.max(1);
72 let overlap = config.overlap;
73 let n = matrix.n_rows;
74 let n_cols = matrix.n_cols;
75
76 let total_nnz = matrix.n_nnz();
78 let target_nnz = total_nnz
79 .checked_div(n_workers)
80 .map(|q| q + usize::from(!total_nnz.is_multiple_of(n_workers)))
81 .unwrap_or(total_nnz);
82
83 let mut owned_starts: Vec<usize> = vec![0];
84 let mut acc = 0usize;
85 for row in 0..n {
86 acc += matrix.row_ptr[row + 1] - matrix.row_ptr[row];
87 if acc >= target_nnz && owned_starts.len() < n_workers {
88 owned_starts.push(row + 1);
89 acc = 0;
90 }
91 }
92 while owned_starts.len() < n_workers {
93 owned_starts.push(n);
94 }
95 let mut o_ends: Vec<usize> = owned_starts[1..].to_vec();
96 o_ends.push(n);
97
98 let mut partitions: Vec<GpuCsrMatrix> = Vec::with_capacity(n_workers);
100 let mut halo_rows: Vec<Vec<usize>> = Vec::with_capacity(n_workers);
101 let mut row_offsets: Vec<usize> = Vec::with_capacity(n_workers);
102 let mut part_global_starts: Vec<usize> = Vec::with_capacity(n_workers);
103
104 for w in 0..n_workers {
105 let own_start = owned_starts[w];
106 let own_end = o_ends[w];
107
108 let halo_start = own_start.saturating_sub(overlap);
110 let halo_end = (own_end + overlap).min(n);
111
112 let mut ghost: Vec<usize> = Vec::new();
114 for r in halo_start..own_start {
115 ghost.push(r);
116 }
117 for r in own_end..halo_end {
118 ghost.push(r);
119 }
120
121 let local_nrows = halo_end - halo_start;
123 let local_nnz_start = matrix.row_ptr[halo_start];
124 let local_nnz_end = matrix.row_ptr[halo_end];
125
126 let local_row_ptr: Vec<usize> = (halo_start..=halo_end)
127 .map(|r| matrix.row_ptr[r] - local_nnz_start)
128 .collect();
129 let local_col_idx = matrix.col_idx[local_nnz_start..local_nnz_end].to_vec();
130 let local_values = matrix.values[local_nnz_start..local_nnz_end].to_vec();
131
132 partitions.push(GpuCsrMatrix {
133 row_ptr: local_row_ptr,
134 col_idx: local_col_idx,
135 values: local_values,
136 n_rows: local_nrows,
137 n_cols,
138 });
139
140 halo_rows.push(ghost);
141 row_offsets.push(own_start);
142 part_global_starts.push(halo_start);
143 }
144
145 Self {
146 partitions,
147 row_offsets,
148 halo_rows,
149 n_total_rows: n,
150 n_cols,
151 partition_global_start: part_global_starts,
152 owned_ends: o_ends,
153 }
154 }
155
156 pub fn spmv(&self, x: &[f64]) -> SparseResult<Vec<f64>> {
165 if x.len() != self.n_cols {
166 return Err(SparseError::DimensionMismatch {
167 expected: self.n_cols,
168 found: x.len(),
169 });
170 }
171
172 let n_workers = self.partitions.len();
173 let mut y = vec![0.0_f64; self.n_total_rows];
174
175 for w in 0..n_workers {
176 let partition = &self.partitions[w];
177 let local_y = partition.spmv(x)?;
178
179 let own_global_start = self.row_offsets[w];
181 let own_global_end = self.owned_ends[w];
182
183 let local_owned_start = own_global_start - self.partition_global_start[w];
186
187 let owned_len = own_global_end - own_global_start;
189 for k in 0..owned_len {
190 let local_idx = local_owned_start + k;
191 if local_idx < local_y.len() {
192 y[own_global_start + k] = local_y[local_idx];
193 }
194 }
195 }
196
197 Ok(y)
198 }
199
200 pub fn to_csr(&self) -> GpuCsrMatrix {
202 let n_workers = self.partitions.len();
203 let mut coo = GpuCooMatrix::new(self.n_total_rows, self.n_cols);
204
205 for w in 0..n_workers {
206 let partition = &self.partitions[w];
207 let own_global_start = self.row_offsets[w];
208 let own_global_end = self.owned_ends[w];
209 let local_owned_start = own_global_start - self.partition_global_start[w];
210 let owned_len = own_global_end - own_global_start;
211
212 for k in 0..owned_len {
213 let local_row = local_owned_start + k;
214 let global_row = own_global_start + k;
215 let row_start = partition.row_ptr[local_row];
216 let row_end = partition.row_ptr[local_row + 1];
217 for idx in row_start..row_end {
218 coo.push(global_row, partition.col_idx[idx], partition.values[idx]);
219 }
220 }
221 }
222
223 coo.to_csr()
224 }
225
226 pub fn load_balance_quality(&self) -> f64 {
233 let n = self.partitions.len();
234 if n <= 1 {
235 return 0.0;
236 }
237 let counts: Vec<f64> = self.partitions.iter().map(|p| p.n_nnz() as f64).collect();
238 let mean = counts.iter().sum::<f64>() / n as f64;
239 if mean < f64::EPSILON {
240 return 0.0;
241 }
242 let variance = counts.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / n as f64;
243 variance.sqrt() / mean
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::construction::GpuCooMatrix;

    /// Builds the n x n tridiagonal matrix with 4 on the diagonal and -1
    /// on both off-diagonals.
    fn tridiag(n: usize) -> GpuCsrMatrix {
        let mut coo = GpuCooMatrix::new(n, n);
        for i in 0..n {
            coo.push(i, i, 4.0);
            if i > 0 {
                coo.push(i, i - 1, -1.0);
                coo.push(i - 1, i, -1.0);
            }
        }
        coo.to_csr()
    }

    /// Asserts that the distributed SpMV (overlap 0) matches the sequential
    /// one on an all-ones vector.
    fn check_spmv_agreement(n: usize, n_workers: usize) {
        let matrix = tridiag(n);
        let dist = PartitionedCsr::from_csr(
            &matrix,
            &DistributedCsrConfig { n_workers, overlap: 0 },
        );
        let x = vec![1.0; n];
        let expected = matrix.spmv(&x).expect("spmv failed");
        let actual = dist.spmv(&x).expect("distributed spmv failed");
        for (ys, yd) in expected.iter().zip(&actual) {
            assert!((ys - yd).abs() < 1e-10);
        }
    }

    #[test]
    fn test_distributed_csr_spmv_matches_sequential() {
        let n = 12;
        let matrix = tridiag(n);
        let x: Vec<f64> = (1..=n).map(|v| v as f64).collect();

        let expected = matrix.spmv(&x).expect("sequential spmv failed");

        let dist = PartitionedCsr::from_csr(
            &matrix,
            &DistributedCsrConfig {
                n_workers: 4,
                overlap: 1,
            },
        );
        let actual = dist.spmv(&x).expect("distributed spmv failed");

        assert_eq!(expected.len(), actual.len());
        for (i, (ys, yd)) in expected.iter().zip(&actual).enumerate() {
            assert!(
                (ys - yd).abs() < 1e-10,
                "row {i}: sequential={ys} distributed={yd}"
            );
        }
    }

    #[test]
    fn test_distributed_partitioning_row_split() {
        let n = 12;
        let dist = PartitionedCsr::from_csr(
            &tridiag(n),
            &DistributedCsrConfig {
                n_workers: 4,
                overlap: 0,
            },
        );
        assert_eq!(dist.partitions.len(), 4);
        // Every owned-row offset must lie inside the global row range.
        assert!(dist.row_offsets.iter().all(|&off| off <= n));
    }

    #[test]
    fn test_distributed_to_csr_roundtrip() {
        let n = 8;
        let matrix = tridiag(n);
        let dist = PartitionedCsr::from_csr(
            &matrix,
            &DistributedCsrConfig {
                n_workers: 3,
                overlap: 1,
            },
        );
        let rebuilt = dist.to_csr();

        assert_eq!(matrix.n_nnz(), rebuilt.n_nnz());
        let dense_a = matrix.to_dense();
        let dense_b = rebuilt.to_dense();
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (dense_a[[i, j]] - dense_b[[i, j]]).abs() < 1e-12,
                    "mismatch at ({i},{j})"
                );
            }
        }
    }

    #[test]
    fn test_load_balance_quality() {
        let dist = PartitionedCsr::from_csr(
            &tridiag(12),
            &DistributedCsrConfig {
                n_workers: 4,
                overlap: 0,
            },
        );
        let q = dist.load_balance_quality();
        // Coefficient of variation is non-negative and should be small
        // for a near-uniform tridiagonal split.
        assert!(q >= 0.0);
        assert!(q < 1.0);
    }

    #[test]
    fn test_single_worker() {
        check_spmv_agreement(6, 1);
    }

    #[test]
    fn test_more_workers_than_rows() {
        check_spmv_agreement(3, 6);
    }
}