scirs2_sparse/distributed/
halo_exchange.rs1use std::collections::HashMap;
10
11use crate::error::{SparseError, SparseResult};
12
13use super::partition::{DistributedCsr, RowPartition};
14
/// Configuration for halo exchange.
///
/// NOTE(review): nothing in this file reads this type; it is presumably
/// consumed by callers elsewhere in the crate.
#[derive(Debug, Clone)]
pub struct HaloConfig {
    /// Worker count.
    pub n_workers: usize,
}

impl Default for HaloConfig {
    /// Returns a configuration with four workers.
    fn default() -> Self {
        HaloConfig { n_workers: 4 }
    }
}
31
/// One batch of vector entries sent from the worker that owns them to a
/// worker that needs them as ghost values.
///
/// `rows` and `values` are parallel: `values[i]` is the vector entry at
/// global row `rows[i]`.
#[derive(Debug, Clone)]
pub struct HaloMessage {
    /// Index of the worker that owns the rows in this message.
    pub source_worker: usize,
    /// Index of the worker that lists these rows among its ghost rows.
    pub dest_worker: usize,
    /// Global row indices carried by the message.
    pub rows: Vec<usize>,
    /// Vector entries for `rows`, in the same order.
    pub values: Vec<f64>,
}
48
/// Lookup table from global row indices to slots in a worker's combined
/// "local rows followed by ghost rows" numbering.
#[derive(Debug, Clone)]
pub struct GhostManager {
    /// Global row index -> slot. Local rows occupy `0..n_local`; ghost rows
    /// occupy `n_local..n_local + n_ghost`.
    pub global_to_local_map: HashMap<usize, usize>,
    /// Number of local rows passed to [`GhostManager::new`].
    pub n_local: usize,
    /// Number of ghost rows passed to [`GhostManager::new`].
    pub n_ghost: usize,
}

impl GhostManager {
    /// Builds the lookup table from the worker's local and ghost row lists.
    ///
    /// Slots are assigned in order: first every entry of `local_rows`, then
    /// every entry of `ghost_rows`. If a global index occurs more than once,
    /// the slot of its last occurrence wins.
    pub fn new(local_rows: &[usize], ghost_rows: &[usize]) -> Self {
        let mut lookup = HashMap::with_capacity(local_rows.len() + ghost_rows.len());
        // Chaining the two lists makes `enumerate` hand out the ghost slots
        // directly after the local ones.
        lookup.extend(
            local_rows
                .iter()
                .chain(ghost_rows)
                .enumerate()
                .map(|(slot, &global)| (global, slot)),
        );

        Self {
            global_to_local_map: lookup,
            n_local: local_rows.len(),
            n_ghost: ghost_rows.len(),
        }
    }

    /// Returns the slot assigned to `global`, or `None` if the row is neither
    /// local nor ghost for this worker.
    #[inline]
    pub fn global_to_local(&self, global: usize) -> Option<usize> {
        self.global_to_local_map.get(&global).copied()
    }
}
95
96#[derive(Debug, Clone)]
102pub struct DistributedVector {
103 pub local_values: Vec<f64>,
105 pub ghost_values: Vec<f64>,
107 pub partition: RowPartition,
109 pub ghost_rows: Vec<usize>,
111}
112
113impl DistributedVector {
114 pub fn from_global(
122 global: &[f64],
123 partition: &RowPartition,
124 ghost_rows: &[usize],
125 ) -> SparseResult<Self> {
126 let local_values: SparseResult<Vec<f64>> = partition
128 .local_rows
129 .iter()
130 .map(|&r| {
131 global.get(r).copied().ok_or_else(|| {
132 SparseError::ValueError(format!(
133 "Global row index {r} out of bounds (len={})",
134 global.len()
135 ))
136 })
137 })
138 .collect();
139 let local_values = local_values?;
140
141 let ghost_values: SparseResult<Vec<f64>> = ghost_rows
143 .iter()
144 .map(|&r| {
145 global.get(r).copied().ok_or_else(|| {
146 SparseError::ValueError(format!(
147 "Ghost row index {r} out of bounds (len={})",
148 global.len()
149 ))
150 })
151 })
152 .collect();
153 let ghost_values = ghost_values?;
154
155 Ok(Self {
156 local_values,
157 ghost_values,
158 partition: partition.clone(),
159 ghost_rows: ghost_rows.to_vec(),
160 })
161 }
162
163 pub fn to_global(&self, n_global: usize) -> Vec<f64> {
165 let mut out = vec![0.0_f64; n_global];
166 for (local_idx, &global_row) in self.partition.local_rows.iter().enumerate() {
167 if global_row < n_global {
168 out[global_row] = self.local_values[local_idx];
169 }
170 }
171 out
172 }
173
174 #[inline]
176 pub fn get_global(&self, global_row: usize) -> Option<f64> {
177 for (local_idx, &r) in self.partition.local_rows.iter().enumerate() {
179 if r == global_row {
180 return Some(self.local_values[local_idx]);
181 }
182 }
183 for (ghost_idx, &r) in self.ghost_rows.iter().enumerate() {
185 if r == global_row {
186 return Some(self.ghost_values[ghost_idx]);
187 }
188 }
189 None
190 }
191}
192
193pub fn simulate_halo_exchange(
210 partitions: &[DistributedCsr],
211 x_global: &[f64],
212) -> SparseResult<Vec<DistributedVector>> {
213 partitions
214 .iter()
215 .map(|dcsr| DistributedVector::from_global(x_global, &dcsr.partition, &dcsr.ghost_rows))
216 .collect()
217}
218
219pub fn distributed_spmv(partitions: &[DistributedCsr], x: &[f64]) -> SparseResult<Vec<f64>> {
230 if partitions.is_empty() {
231 return Ok(Vec::new());
232 }
233
234 let n_global = partitions[0].partition.n_global_rows;
235
236 let n_cols_needed = partitions
241 .iter()
242 .map(|d| d.local_matrix.cols())
243 .max()
244 .unwrap_or(0);
245 if x.len() < n_cols_needed {
246 return Err(SparseError::DimensionMismatch {
247 expected: n_cols_needed,
248 found: x.len(),
249 });
250 }
251
252 let dist_vecs = simulate_halo_exchange(partitions, x)?;
254
255 let n_workers = partitions.len();
258 let mut partial_results: Vec<(Vec<usize>, Vec<f64>)> =
259 vec![(Vec::new(), Vec::new()); n_workers];
260
261 std::thread::scope(|s| {
262 let handles: Vec<_> = partitions
263 .iter()
264 .zip(dist_vecs.iter())
265 .enumerate()
266 .map(|(w, (dcsr, dv))| {
267 s.spawn(move || -> SparseResult<(Vec<usize>, Vec<f64>)> {
268 let ghost_mgr = GhostManager::new(&dcsr.partition.local_rows, &dcsr.ghost_rows);
270
271 let n_local = dcsr.partition.n_local();
272 let mut y_local = vec![0.0_f64; n_local];
273
274 for (local_row, &global_row) in dcsr.partition.local_rows.iter().enumerate() {
275 let row_start = dcsr.local_matrix.indptr[local_row];
276 let row_end = dcsr.local_matrix.indptr[local_row + 1];
277 let mut acc = 0.0_f64;
278 for idx in row_start..row_end {
279 let col = dcsr.local_matrix.indices[idx]; let val = dcsr.local_matrix.data[idx];
281
282 let x_val = if let Some(local_idx) = ghost_mgr.global_to_local(col) {
285 if local_idx < dv.local_values.len() {
286 dv.local_values[local_idx]
287 } else {
288 let ghost_idx = local_idx - dv.local_values.len();
289 *dv.ghost_values.get(ghost_idx).ok_or_else(|| {
290 SparseError::ValueError(format!(
291 "Ghost index {ghost_idx} out of range"
292 ))
293 })?
294 }
295 } else {
296 *x.get(col).ok_or_else(|| {
299 SparseError::ValueError(format!(
300 "Column index {col} out of range in x (len={})",
301 x.len()
302 ))
303 })?
304 };
305
306 acc += val * x_val;
307 }
308 y_local[local_row] = acc;
309 let _ = global_row; }
311
312 Ok((dcsr.partition.local_rows.clone(), y_local))
313 })
314 })
315 .collect();
316
317 for (w, handle) in handles.into_iter().enumerate() {
318 match handle.join() {
319 Ok(Ok(result)) => {
320 partial_results[w] = result;
321 }
322 Ok(Err(e)) => {
323 let _ = e;
325 }
326 Err(_) => {}
327 }
328 }
329 });
330
331 let mut y = vec![0.0_f64; n_global];
333 for (global_rows, y_values) in &partial_results {
334 for (&global_row, &yv) in global_rows.iter().zip(y_values.iter()) {
335 if global_row < n_global {
336 y[global_row] = yv;
337 }
338 }
339 }
340
341 Ok(y)
342}
343
344pub fn build_halo_messages(partitions: &[DistributedCsr], x: &[f64]) -> Vec<HaloMessage> {
355 let mut row_owner: HashMap<usize, usize> = HashMap::new();
357 for (w, dcsr) in partitions.iter().enumerate() {
358 for &r in &dcsr.partition.local_rows {
359 row_owner.insert(r, w);
360 }
361 }
362
363 let mut messages: Vec<HaloMessage> = Vec::new();
364
365 for (dest_worker, dcsr) in partitions.iter().enumerate() {
366 let mut by_source: HashMap<usize, (Vec<usize>, Vec<f64>)> = HashMap::new();
368 for &ghost_row in &dcsr.ghost_rows {
369 if let Some(&src) = row_owner.get(&ghost_row) {
370 let xv = x.get(ghost_row).copied().unwrap_or(0.0);
371 let entry = by_source
372 .entry(src)
373 .or_insert_with(|| (Vec::new(), Vec::new()));
374 entry.0.push(ghost_row);
375 entry.1.push(xv);
376 }
377 }
378 for (source_worker, (rows, values)) in by_source {
379 messages.push(HaloMessage {
380 source_worker,
381 dest_worker,
382 rows,
383 values,
384 });
385 }
386 }
387
388 messages
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;
    use crate::distributed::partition::create_distributed_csr;
    use crate::distributed::partition::{partition_rows, PartitionConfig, PartitionMethod};

    /// Builds the `n x n` tridiagonal matrix with 2.0 on the diagonal and
    /// -1.0 on both neighbouring diagonals.
    fn tridiag(n: usize) -> CsrMatrix<f64> {
        let mut triplets: Vec<(usize, usize, f64)> = Vec::new();
        for i in 0..n {
            triplets.push((i, i, 2.0_f64));
            if i > 0 {
                triplets.push((i, i - 1, -1.0));
                triplets.push((i - 1, i, -1.0));
            }
        }
        let rows: Vec<usize> = triplets.iter().map(|t| t.0).collect();
        let cols: Vec<usize> = triplets.iter().map(|t| t.1).collect();
        let vals: Vec<f64> = triplets.iter().map(|t| t.2).collect();
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag construction")
    }

    /// Splits `mat` across `n_workers` using the default partition settings.
    fn make_partitions(mat: &CsrMatrix<f64>, n_workers: usize) -> Vec<DistributedCsr> {
        let config = PartitionConfig {
            n_workers,
            ..Default::default()
        };
        let row_parts = partition_rows(mat.rows(), &config);
        let mut parts = Vec::new();
        for rp in row_parts.iter() {
            parts.push(create_distributed_csr(mat, rp).expect("create_distributed_csr"));
        }
        parts
    }

    /// Four-worker distributed product agrees with the serial product.
    #[test]
    fn test_distributed_spmv_matches_serial() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (1..=n).map(|k| k as f64).collect();

        let y_serial = mat.dot(&x).expect("serial dot");
        let y_dist = distributed_spmv(&make_partitions(&mat, 4), &x).expect("distributed_spmv");

        assert_eq!(y_serial.len(), y_dist.len());
        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            let within_tolerance = (ys - yd).abs() < 1e-10;
            assert!(
                within_tolerance,
                "row {i}: serial={ys}, distributed={yd}"
            );
        }
    }

    /// A single worker reproduces the serial product.
    #[test]
    fn test_distributed_spmv_single_worker() {
        let n = 8;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();

        let y_serial = mat.dot(&x).expect("serial dot");
        let y_dist = distributed_spmv(&make_partitions(&mat, 1), &x).expect("distributed_spmv");

        for (ys, yd) in y_serial.iter().zip(y_dist.iter()) {
            assert!((ys - yd).abs() < 1e-10);
        }
    }

    /// Local rows map to the first slots, ghost rows to the slots after them.
    #[test]
    fn test_ghost_manager_lookup() {
        let mgr = GhostManager::new(&[0usize, 1, 2], &[5usize, 7]);

        let expected = [
            (0, Some(0)),
            (2, Some(2)),
            (5, Some(3)),
            (7, Some(4)),
            (9, None),
        ];
        for (global, slot) in expected.iter() {
            assert_eq!(mgr.global_to_local(*global), *slot);
        }
    }

    /// `from_global` picks out local and ghost entries; `to_global` scatters
    /// only the local ones back.
    #[test]
    fn test_distributed_vector_roundtrip() {
        let global = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let rp = RowPartition {
            worker_id: 0,
            local_rows: vec![1, 2],
            n_global_rows: 5,
        };
        let ghost_rows = vec![4usize];

        let dv = DistributedVector::from_global(&global, &rp, &ghost_rows).expect("from_global");
        assert_eq!(dv.local_values, vec![2.0, 3.0]);
        assert_eq!(dv.ghost_values, vec![5.0]);

        let reconstructed = dv.to_global(5);
        assert_eq!(reconstructed[1], 2.0);
        assert_eq!(reconstructed[2], 3.0);
        assert_eq!(reconstructed[0], 0.0);
    }

    /// A tridiagonal matrix split four ways must need at least one message.
    #[test]
    fn test_halo_messages_built() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();

        let msgs = build_halo_messages(&make_partitions(&mat, 4), &x);
        assert!(
            !msgs.is_empty(),
            "Expected halo messages for tridiagonal matrix"
        );
    }

    /// Round-robin partitioning over three workers agrees with the serial
    /// product.
    #[test]
    fn test_distributed_spmv_round_robin() {
        let n = 12;
        let mat = tridiag(n);
        let x: Vec<f64> = (1..=n).map(|k| k as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");

        let config = PartitionConfig {
            n_workers: 3,
            method: PartitionMethod::RoundRobin,
            ..Default::default()
        };
        let row_parts = partition_rows(n, &config);
        let mut parts: Vec<DistributedCsr> = Vec::new();
        for rp in row_parts.iter() {
            parts.push(create_distributed_csr(&mat, rp).expect("create"));
        }
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            assert!((ys - yd).abs() < 1e-10, "row {i}: serial={ys}, dist={yd}");
        }
    }
}