//! Halo exchange and distributed sparse matrix-vector utilities.

use std::collections::{BTreeMap, HashMap};

use crate::error::{SparseError, SparseResult};

use super::partition::{DistributedCsr, RowPartition};
/// Configuration for a simulated halo exchange.
#[derive(Debug, Clone)]
pub struct HaloConfig {
    /// Number of simulated workers participating in the exchange.
    pub n_workers: usize,
}

impl Default for HaloConfig {
    /// Defaults to 4 workers.
    fn default() -> Self {
        HaloConfig { n_workers: 4 }
    }
}
31
/// A batch of vector values sent from one worker to another during a
/// halo exchange.
#[derive(Debug, Clone)]
pub struct HaloMessage {
    /// Worker that owns the rows being sent.
    pub source_worker: usize,
    /// Worker that needs these rows as ghost values.
    pub dest_worker: usize,
    /// Global row indices carried by this message.
    pub rows: Vec<usize>,
    /// Vector values corresponding one-to-one with `rows`.
    pub values: Vec<f64>,
}
48
/// Translates global row indices into a worker-local numbering where the
/// worker's own rows come first, followed by its ghost rows.
#[derive(Debug, Clone)]
pub struct GhostManager {
    /// Global row index -> local index. Owned rows occupy `0..n_local`,
    /// ghost rows occupy `n_local..n_local + n_ghost`.
    pub global_to_local_map: HashMap<usize, usize>,
    /// Number of rows owned by this worker.
    pub n_local: usize,
    /// Number of ghost rows replicated on this worker.
    pub n_ghost: usize,
}

impl GhostManager {
    /// Build the local numbering from `local_rows` (owned) followed by
    /// `ghost_rows` (replicated). A row present in both slices resolves
    /// to its ghost index, since later insertions win.
    pub fn new(local_rows: &[usize], ghost_rows: &[usize]) -> Self {
        // Chaining owned rows before ghosts and enumerating yields
        // exactly the numbering documented above.
        let global_to_local_map = local_rows
            .iter()
            .chain(ghost_rows.iter())
            .enumerate()
            .map(|(local_idx, &global)| (global, local_idx))
            .collect();
        Self {
            global_to_local_map,
            n_local: local_rows.len(),
            n_ghost: ghost_rows.len(),
        }
    }

    /// Look up the local index for `global`, if this worker knows the row.
    #[inline]
    pub fn global_to_local(&self, global: usize) -> Option<usize> {
        self.global_to_local_map.get(&global).copied()
    }
}
95
/// A vector partitioned across workers: the locally owned entries plus a
/// replicated copy of the ghost entries this worker reads.
#[derive(Debug, Clone)]
pub struct DistributedVector {
    /// Values for the rows owned by this worker, ordered as in
    /// `partition.local_rows`.
    pub local_values: Vec<f64>,
    /// Values for ghost rows, ordered as in `ghost_rows`.
    pub ghost_values: Vec<f64>,
    /// Row partition describing which global rows this worker owns.
    pub partition: RowPartition,
    /// Global indices of the ghost rows replicated on this worker.
    pub ghost_rows: Vec<usize>,
}
112
113impl DistributedVector {
114 pub fn from_global(
122 global: &[f64],
123 partition: &RowPartition,
124 ghost_rows: &[usize],
125 ) -> SparseResult<Self> {
126 let local_values: SparseResult<Vec<f64>> = partition
128 .local_rows
129 .iter()
130 .map(|&r| {
131 global.get(r).copied().ok_or_else(|| {
132 SparseError::ValueError(format!(
133 "Global row index {r} out of bounds (len={})",
134 global.len()
135 ))
136 })
137 })
138 .collect();
139 let local_values = local_values?;
140
141 let ghost_values: SparseResult<Vec<f64>> = ghost_rows
143 .iter()
144 .map(|&r| {
145 global.get(r).copied().ok_or_else(|| {
146 SparseError::ValueError(format!(
147 "Ghost row index {r} out of bounds (len={})",
148 global.len()
149 ))
150 })
151 })
152 .collect();
153 let ghost_values = ghost_values?;
154
155 Ok(Self {
156 local_values,
157 ghost_values,
158 partition: partition.clone(),
159 ghost_rows: ghost_rows.to_vec(),
160 })
161 }
162
163 pub fn to_global(&self, n_global: usize) -> Vec<f64> {
165 let mut out = vec![0.0_f64; n_global];
166 for (local_idx, &global_row) in self.partition.local_rows.iter().enumerate() {
167 if global_row < n_global {
168 out[global_row] = self.local_values[local_idx];
169 }
170 }
171 out
172 }
173
174 #[inline]
176 pub fn get_global(&self, global_row: usize) -> Option<f64> {
177 for (local_idx, &r) in self.partition.local_rows.iter().enumerate() {
179 if r == global_row {
180 return Some(self.local_values[local_idx]);
181 }
182 }
183 for (ghost_idx, &r) in self.ghost_rows.iter().enumerate() {
185 if r == global_row {
186 return Some(self.ghost_values[ghost_idx]);
187 }
188 }
189 None
190 }
191}
192
193pub fn simulate_halo_exchange(
204 partitions: &[DistributedCsr],
205 x_global: &[f64],
206) -> SparseResult<Vec<DistributedVector>> {
207 partitions
208 .iter()
209 .map(|dcsr| DistributedVector::from_global(x_global, &dcsr.partition, &dcsr.ghost_rows))
210 .collect()
211}
212
213pub fn distributed_spmv(partitions: &[DistributedCsr], x: &[f64]) -> SparseResult<Vec<f64>> {
224 if partitions.is_empty() {
225 return Ok(Vec::new());
226 }
227
228 let n_global = partitions[0].partition.n_global_rows;
229
230 let n_cols_needed = partitions
235 .iter()
236 .map(|d| d.local_matrix.cols())
237 .max()
238 .unwrap_or(0);
239 if x.len() < n_cols_needed {
240 return Err(SparseError::DimensionMismatch {
241 expected: n_cols_needed,
242 found: x.len(),
243 });
244 }
245
246 let dist_vecs = simulate_halo_exchange(partitions, x)?;
248
249 let n_workers = partitions.len();
252 let mut partial_results: Vec<(Vec<usize>, Vec<f64>)> =
253 vec![(Vec::new(), Vec::new()); n_workers];
254
255 std::thread::scope(|s| {
256 let handles: Vec<_> = partitions
257 .iter()
258 .zip(dist_vecs.iter())
259 .enumerate()
260 .map(|(w, (dcsr, dv))| {
261 s.spawn(move || -> SparseResult<(Vec<usize>, Vec<f64>)> {
262 let ghost_mgr = GhostManager::new(&dcsr.partition.local_rows, &dcsr.ghost_rows);
264
265 let n_local = dcsr.partition.n_local();
266 let mut y_local = vec![0.0_f64; n_local];
267
268 for (local_row, &global_row) in dcsr.partition.local_rows.iter().enumerate() {
269 let row_start = dcsr.local_matrix.indptr[local_row];
270 let row_end = dcsr.local_matrix.indptr[local_row + 1];
271 let mut acc = 0.0_f64;
272 for idx in row_start..row_end {
273 let col = dcsr.local_matrix.indices[idx]; let val = dcsr.local_matrix.data[idx];
275
276 let x_val = if let Some(local_idx) = ghost_mgr.global_to_local(col) {
279 if local_idx < dv.local_values.len() {
280 dv.local_values[local_idx]
281 } else {
282 let ghost_idx = local_idx - dv.local_values.len();
283 *dv.ghost_values.get(ghost_idx).ok_or_else(|| {
284 SparseError::ValueError(format!(
285 "Ghost index {ghost_idx} out of range"
286 ))
287 })?
288 }
289 } else {
290 *x.get(col).ok_or_else(|| {
293 SparseError::ValueError(format!(
294 "Column index {col} out of range in x (len={})",
295 x.len()
296 ))
297 })?
298 };
299
300 acc += val * x_val;
301 }
302 y_local[local_row] = acc;
303 let _ = global_row; }
305
306 Ok((dcsr.partition.local_rows.clone(), y_local))
307 })
308 })
309 .collect();
310
311 for (w, handle) in handles.into_iter().enumerate() {
312 match handle.join() {
313 Ok(Ok(result)) => {
314 partial_results[w] = result;
315 }
316 Ok(Err(e)) => {
317 let _ = e;
319 }
320 Err(_) => {}
321 }
322 }
323 });
324
325 let mut y = vec![0.0_f64; n_global];
327 for (global_rows, y_values) in &partial_results {
328 for (&global_row, &yv) in global_rows.iter().zip(y_values.iter()) {
329 if global_row < n_global {
330 y[global_row] = yv;
331 }
332 }
333 }
334
335 Ok(y)
336}
337
338pub fn build_halo_messages(partitions: &[DistributedCsr], x: &[f64]) -> Vec<HaloMessage> {
349 let mut row_owner: HashMap<usize, usize> = HashMap::new();
351 for (w, dcsr) in partitions.iter().enumerate() {
352 for &r in &dcsr.partition.local_rows {
353 row_owner.insert(r, w);
354 }
355 }
356
357 let mut messages: Vec<HaloMessage> = Vec::new();
358
359 for (dest_worker, dcsr) in partitions.iter().enumerate() {
360 let mut by_source: HashMap<usize, (Vec<usize>, Vec<f64>)> = HashMap::new();
362 for &ghost_row in &dcsr.ghost_rows {
363 if let Some(&src) = row_owner.get(&ghost_row) {
364 let xv = x.get(ghost_row).copied().unwrap_or(0.0);
365 let entry = by_source
366 .entry(src)
367 .or_insert_with(|| (Vec::new(), Vec::new()));
368 entry.0.push(ghost_row);
369 entry.1.push(xv);
370 }
371 }
372 for (source_worker, (rows, values)) in by_source {
373 messages.push(HaloMessage {
374 source_worker,
375 dest_worker,
376 rows,
377 values,
378 });
379 }
380 }
381
382 messages
383}
384
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;
    use crate::distributed::partition::create_distributed_csr;
    use crate::distributed::partition::{partition_rows, PartitionConfig, PartitionMethod};

    /// Build the classic n x n tridiagonal matrix with 2.0 on the
    /// diagonal and -1.0 on both off-diagonals.
    fn tridiag(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(2.0_f64);
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                vals.push(-1.0);
                rows.push(i - 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag construction")
    }

    /// Partition `mat` row-wise across `n_workers` workers using the
    /// default partition method.
    fn make_partitions(mat: &CsrMatrix<f64>, n_workers: usize) -> Vec<DistributedCsr> {
        let config = PartitionConfig {
            n_workers,
            ..Default::default()
        };
        let row_parts = partition_rows(mat.rows(), &config);
        row_parts
            .iter()
            .map(|rp| create_distributed_csr(mat, rp).expect("create_distributed_csr"))
            .collect()
    }

    // Distributed SpMV over 4 workers must agree with the serial product.
    #[test]
    fn test_distributed_spmv_matches_serial() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();

        let y_serial = mat.dot(&x).expect("serial dot");

        let parts = make_partitions(&mat, 4);
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        assert_eq!(y_serial.len(), y_dist.len());
        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            assert!(
                (ys - yd).abs() < 1e-10,
                "row {i}: serial={ys}, distributed={yd}"
            );
        }
    }

    // Degenerate case: a single worker owns every row.
    #[test]
    fn test_distributed_spmv_single_worker() {
        let n = 8;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");
        let parts = make_partitions(&mat, 1);
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");
        for (ys, yd) in y_serial.iter().zip(y_dist.iter()) {
            assert!((ys - yd).abs() < 1e-10);
        }
    }

    // GhostManager numbers owned rows first (0..n_local), then ghosts.
    #[test]
    fn test_ghost_manager_lookup() {
        let local_rows = vec![0usize, 1, 2];
        let ghost_rows = vec![5usize, 7];
        let mgr = GhostManager::new(&local_rows, &ghost_rows);
        assert_eq!(mgr.global_to_local(0), Some(0));
        assert_eq!(mgr.global_to_local(2), Some(2));
        assert_eq!(mgr.global_to_local(5), Some(3));
        assert_eq!(mgr.global_to_local(7), Some(4));
        assert_eq!(mgr.global_to_local(9), None);
    }

    // from_global -> to_global round-trips owned values and leaves
    // unowned positions zero-filled.
    #[test]
    fn test_distributed_vector_roundtrip() {
        let global = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let rp = RowPartition {
            worker_id: 0,
            local_rows: vec![1, 2],
            n_global_rows: 5,
        };
        let ghost_rows = vec![4usize];
        let dv = DistributedVector::from_global(&global, &rp, &ghost_rows).expect("from_global");
        assert_eq!(dv.local_values, vec![2.0, 3.0]);
        assert_eq!(dv.ghost_values, vec![5.0]);

        let reconstructed = dv.to_global(5);
        assert_eq!(reconstructed[1], 2.0);
        assert_eq!(reconstructed[2], 3.0);
        assert_eq!(reconstructed[0], 0.0);
    }

    // A tridiagonal matrix split across workers always has boundary
    // rows, so at least one halo message must be produced.
    #[test]
    fn test_halo_messages_built() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let parts = make_partitions(&mat, 4);
        let msgs = build_halo_messages(&parts, &x);
        assert!(
            !msgs.is_empty(),
            "Expected halo messages for tridiagonal matrix"
        );
    }

    // Round-robin partitioning (maximally scattered rows, heavy ghost
    // traffic) must still match the serial product.
    #[test]
    fn test_distributed_spmv_round_robin() {
        let n = 12;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");

        let config = PartitionConfig {
            n_workers: 3,
            method: PartitionMethod::RoundRobin,
            ..Default::default()
        };
        let row_parts = partition_rows(n, &config);
        let parts: Vec<DistributedCsr> = row_parts
            .iter()
            .map(|rp| create_distributed_csr(&mat, rp).expect("create"))
            .collect();
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            assert!((ys - yd).abs() < 1e-10, "row {i}: serial={ys}, dist={yd}");
        }
    }
}