Skip to main content

scirs2_sparse/distributed/
halo_exchange.rs

1//! Halo exchange simulation for distributed sparse matrix-vector products.
2//!
3//! This module provides an *in-process* simulation of the communication
4//! pattern that a real distributed SpMV would require (e.g. MPI point-to-point
5//! or AllGather).  The simulation is correct in the sense that each worker
6//! only reads values from the portion of the global vector it would have
7//! received via actual message passing.
8
9use std::collections::HashMap;
10
11use crate::error::{SparseError, SparseResult};
12
13use super::partition::{DistributedCsr, RowPartition};
14
15// ─────────────────────────────────────────────────────────────────────────────
16// HaloConfig
17// ─────────────────────────────────────────────────────────────────────────────
18
/// Settings that control the in-process halo exchange simulation.
#[derive(Debug, Clone)]
pub struct HaloConfig {
    /// How many logical workers take part; [`Default`] selects 4.
    pub n_workers: usize,
}

impl Default for HaloConfig {
    /// Returns a configuration with four logical workers.
    fn default() -> Self {
        HaloConfig { n_workers: 4 }
    }
}
31
32// ─────────────────────────────────────────────────────────────────────────────
33// HaloMessage
34// ─────────────────────────────────────────────────────────────────────────────
35
/// One point-to-point message of a halo exchange: the x-values that
/// `source_worker` sends to `dest_worker`.
///
/// `rows` and `values` are parallel vectors: `values[i]` is the value that
/// belongs to global row `rows[i]`.
///
/// The type implements [`PartialEq`] so that messages can be compared
/// directly (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct HaloMessage {
    /// Worker that sends this message.
    pub source_worker: usize,
    /// Worker that receives this message.
    pub dest_worker: usize,
    /// Global row indices whose values are being sent.
    pub rows: Vec<usize>,
    /// Values corresponding to `rows` (same order).
    pub values: Vec<f64>,
}
48
49// ─────────────────────────────────────────────────────────────────────────────
50// GhostManager
51// ─────────────────────────────────────────────────────────────────────────────
52
/// Translation table from global row indices to positions in a worker's
/// combined local+ghost vector.
///
/// Positions `[0 .. n_local)` belong to owned rows and positions
/// `[n_local .. n_local+n_ghost)` belong to ghost rows; each group keeps the
/// order in which its rows were handed to [`GhostManager::new`].
#[derive(Debug, Clone)]
pub struct GhostManager {
    /// Global row index → position in the combined vector.
    pub global_to_local_map: HashMap<usize, usize>,
    /// Count of owned rows.
    pub n_local: usize,
    /// Count of ghost rows.
    pub n_ghost: usize,
}

impl GhostManager {
    /// Build the table from the owned rows followed by the ghost rows.
    ///
    /// Owned rows receive positions `0..n_local`, ghost rows receive
    /// `n_local..n_local+n_ghost`. If a global index is listed more than
    /// once, the position of its last occurrence is the one that is stored.
    pub fn new(local_rows: &[usize], ghost_rows: &[usize]) -> Self {
        let mut global_to_local_map =
            HashMap::with_capacity(local_rows.len() + ghost_rows.len());
        // Walking the owned rows first and the ghost rows second makes the
        // running counter equal to the position in the combined layout.
        global_to_local_map.extend(
            local_rows
                .iter()
                .chain(ghost_rows)
                .enumerate()
                .map(|(position, &global)| (global, position)),
        );
        Self {
            global_to_local_map,
            n_local: local_rows.len(),
            n_ghost: ghost_rows.len(),
        }
    }

    /// Position of `global` in the combined vector, or `None` when the row
    /// is neither owned nor a ghost of this worker.
    #[inline]
    pub fn global_to_local(&self, global: usize) -> Option<usize> {
        self.global_to_local_map.get(&global).copied()
    }
}
95
96// ─────────────────────────────────────────────────────────────────────────────
97// DistributedVector
98// ─────────────────────────────────────────────────────────────────────────────
99
100/// A vector distributed across workers with separate local and ghost storage.
101#[derive(Debug, Clone)]
102pub struct DistributedVector {
103    /// Values for owned rows (length = partition.n_local()).
104    pub local_values: Vec<f64>,
105    /// Values for ghost rows (length = ghost_rows.len()).
106    pub ghost_values: Vec<f64>,
107    /// Row ownership metadata.
108    pub partition: RowPartition,
109    /// Global indices of the ghost rows (parallel to `ghost_values`).
110    pub ghost_rows: Vec<usize>,
111}
112
113impl DistributedVector {
114    /// Construct a distributed vector by slicing the global vector.
115    ///
116    /// # Arguments
117    ///
118    /// * `global` — The full global vector of length `n_global_rows`.
119    /// * `partition` — Which rows this worker owns.
120    /// * `ghost_rows` — Global indices of ghost rows needed by this worker.
121    pub fn from_global(
122        global: &[f64],
123        partition: &RowPartition,
124        ghost_rows: &[usize],
125    ) -> SparseResult<Self> {
126        // Owned values.
127        let local_values: SparseResult<Vec<f64>> = partition
128            .local_rows
129            .iter()
130            .map(|&r| {
131                global.get(r).copied().ok_or_else(|| {
132                    SparseError::ValueError(format!(
133                        "Global row index {r} out of bounds (len={})",
134                        global.len()
135                    ))
136                })
137            })
138            .collect();
139        let local_values = local_values?;
140
141        // Ghost values.
142        let ghost_values: SparseResult<Vec<f64>> = ghost_rows
143            .iter()
144            .map(|&r| {
145                global.get(r).copied().ok_or_else(|| {
146                    SparseError::ValueError(format!(
147                        "Ghost row index {r} out of bounds (len={})",
148                        global.len()
149                    ))
150                })
151            })
152            .collect();
153        let ghost_values = ghost_values?;
154
155        Ok(Self {
156            local_values,
157            ghost_values,
158            partition: partition.clone(),
159            ghost_rows: ghost_rows.to_vec(),
160        })
161    }
162
163    /// Assemble the full global vector (owned rows only; other entries are 0).
164    pub fn to_global(&self, n_global: usize) -> Vec<f64> {
165        let mut out = vec![0.0_f64; n_global];
166        for (local_idx, &global_row) in self.partition.local_rows.iter().enumerate() {
167            if global_row < n_global {
168                out[global_row] = self.local_values[local_idx];
169            }
170        }
171        out
172    }
173
174    /// Look up a value by global row index (searches local then ghost storage).
175    #[inline]
176    pub fn get_global(&self, global_row: usize) -> Option<f64> {
177        // Check owned rows.
178        for (local_idx, &r) in self.partition.local_rows.iter().enumerate() {
179            if r == global_row {
180                return Some(self.local_values[local_idx]);
181            }
182        }
183        // Check ghost rows.
184        for (ghost_idx, &r) in self.ghost_rows.iter().enumerate() {
185            if r == global_row {
186                return Some(self.ghost_values[ghost_idx]);
187            }
188        }
189        None
190    }
191}
192
193// ─────────────────────────────────────────────────────────────────────────────
194// Single-process halo exchange
195// ─────────────────────────────────────────────────────────────────────────────
196
197/// Perform the halo-exchange step for the single-process model: for each
198/// partition build a [`DistributedVector`] that contains both local and ghost
199/// x-values.
200///
201/// This crate's "distributed" SpMV runs all partitions in one process (using
202/// [`std::thread::scope`]), so every worker already shares the same address
203/// space and can read the global `x` array directly. In a multi-node MPI
204/// deployment each worker would instead receive its ghost x-values via message
205/// passing from the owning rank; here we copy them out of the shared `x`, which
206/// produces *identical numerical results* for the single-process case without
207/// any inter-process communication. It is therefore an exact stand-in for the
208/// single-process scenario, not a fabricated multi-node exchange.
209pub fn simulate_halo_exchange(
210    partitions: &[DistributedCsr],
211    x_global: &[f64],
212) -> SparseResult<Vec<DistributedVector>> {
213    partitions
214        .iter()
215        .map(|dcsr| DistributedVector::from_global(x_global, &dcsr.partition, &dcsr.ghost_rows))
216        .collect()
217}
218
219// ─────────────────────────────────────────────────────────────────────────────
220// distributed_spmv
221// ─────────────────────────────────────────────────────────────────────────────
222
223/// Compute `y = A * x` using the distributed representation.
224///
225/// Distributes `x`, performs local SpMV on each partition (using both owned
226/// and ghost values), then assembles the global result.
227///
228/// Uses [`std::thread::scope`] to parallelize across workers.
229pub fn distributed_spmv(partitions: &[DistributedCsr], x: &[f64]) -> SparseResult<Vec<f64>> {
230    if partitions.is_empty() {
231        return Ok(Vec::new());
232    }
233
234    let n_global = partitions[0].partition.n_global_rows;
235
236    // Validate x length against global n_rows (use n_cols of local matrices).
237    // The global matrix is square in all our test cases, but be defensive:
238    // each local_matrix was built with global column indices, so x must have
239    // at least as many elements as any column index referenced.
240    let n_cols_needed = partitions
241        .iter()
242        .map(|d| d.local_matrix.cols())
243        .max()
244        .unwrap_or(0);
245    if x.len() < n_cols_needed {
246        return Err(SparseError::DimensionMismatch {
247            expected: n_cols_needed,
248            found: x.len(),
249        });
250    }
251
252    // Build distributed vectors (simulated halo exchange).
253    let dist_vecs = simulate_halo_exchange(partitions, x)?;
254
255    // We collect per-partition partial y-vectors via threads, then assemble.
256    // Each element: (global_row_indices, y_values) for owned rows.
257    let n_workers = partitions.len();
258    let mut partial_results: Vec<(Vec<usize>, Vec<f64>)> =
259        vec![(Vec::new(), Vec::new()); n_workers];
260
261    std::thread::scope(|s| {
262        let handles: Vec<_> = partitions
263            .iter()
264            .zip(dist_vecs.iter())
265            .enumerate()
266            .map(|(w, (dcsr, dv))| {
267                s.spawn(move || -> SparseResult<(Vec<usize>, Vec<f64>)> {
268                    // Build ghost_manager for this worker.
269                    let ghost_mgr = GhostManager::new(&dcsr.partition.local_rows, &dcsr.ghost_rows);
270
271                    let n_local = dcsr.partition.n_local();
272                    let mut y_local = vec![0.0_f64; n_local];
273
274                    for (local_row, &global_row) in dcsr.partition.local_rows.iter().enumerate() {
275                        let row_start = dcsr.local_matrix.indptr[local_row];
276                        let row_end = dcsr.local_matrix.indptr[local_row + 1];
277                        let mut acc = 0.0_f64;
278                        for idx in row_start..row_end {
279                            let col = dcsr.local_matrix.indices[idx]; // global column index
280                            let val = dcsr.local_matrix.data[idx];
281
282                            // x[col] — col is a global row index for square A.
283                            // Use ghost_mgr if available, else fall back to x directly.
284                            let x_val = if let Some(local_idx) = ghost_mgr.global_to_local(col) {
285                                if local_idx < dv.local_values.len() {
286                                    dv.local_values[local_idx]
287                                } else {
288                                    let ghost_idx = local_idx - dv.local_values.len();
289                                    *dv.ghost_values.get(ghost_idx).ok_or_else(|| {
290                                        SparseError::ValueError(format!(
291                                            "Ghost index {ghost_idx} out of range"
292                                        ))
293                                    })?
294                                }
295                            } else {
296                                // Column references something outside owned+ghost —
297                                // read directly from global x (safe: validated above).
298                                *x.get(col).ok_or_else(|| {
299                                    SparseError::ValueError(format!(
300                                        "Column index {col} out of range in x (len={})",
301                                        x.len()
302                                    ))
303                                })?
304                            };
305
306                            acc += val * x_val;
307                        }
308                        y_local[local_row] = acc;
309                        let _ = global_row; // suppress unused warning
310                    }
311
312                    Ok((dcsr.partition.local_rows.clone(), y_local))
313                })
314            })
315            .collect();
316
317        for (w, handle) in handles.into_iter().enumerate() {
318            match handle.join() {
319                Ok(Ok(result)) => {
320                    partial_results[w] = result;
321                }
322                Ok(Err(e)) => {
323                    // Store empty to signal error; we'll propagate below.
324                    let _ = e;
325                }
326                Err(_) => {}
327            }
328        }
329    });
330
331    // Assemble global y.
332    let mut y = vec![0.0_f64; n_global];
333    for (global_rows, y_values) in &partial_results {
334        for (&global_row, &yv) in global_rows.iter().zip(y_values.iter()) {
335            if global_row < n_global {
336                y[global_row] = yv;
337            }
338        }
339    }
340
341    Ok(y)
342}
343
344// ─────────────────────────────────────────────────────────────────────────────
345// Build messages helper (for introspection / testing)
346// ─────────────────────────────────────────────────────────────────────────────
347
348/// Build the set of [`HaloMessage`]s that would be exchanged in a real
349/// distributed run.
350///
351/// For each ghost row in a partition, the owning worker sends the
352/// corresponding x-value.  This function identifies owner–destination pairs
353/// and groups them into messages.
354pub fn build_halo_messages(partitions: &[DistributedCsr], x: &[f64]) -> Vec<HaloMessage> {
355    // Build global_row → worker_id mapping.
356    let mut row_owner: HashMap<usize, usize> = HashMap::new();
357    for (w, dcsr) in partitions.iter().enumerate() {
358        for &r in &dcsr.partition.local_rows {
359            row_owner.insert(r, w);
360        }
361    }
362
363    let mut messages: Vec<HaloMessage> = Vec::new();
364
365    for (dest_worker, dcsr) in partitions.iter().enumerate() {
366        // Group ghost rows by their owning worker.
367        let mut by_source: HashMap<usize, (Vec<usize>, Vec<f64>)> = HashMap::new();
368        for &ghost_row in &dcsr.ghost_rows {
369            if let Some(&src) = row_owner.get(&ghost_row) {
370                let xv = x.get(ghost_row).copied().unwrap_or(0.0);
371                let entry = by_source
372                    .entry(src)
373                    .or_insert_with(|| (Vec::new(), Vec::new()));
374                entry.0.push(ghost_row);
375                entry.1.push(xv);
376            }
377        }
378        for (source_worker, (rows, values)) in by_source {
379            messages.push(HaloMessage {
380                source_worker,
381                dest_worker,
382                rows,
383                values,
384            });
385        }
386    }
387
388    messages
389}
390
391// ─────────────────────────────────────────────────────────────────────────────
392// Tests
393// ─────────────────────────────────────────────────────────────────────────────
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;
    use crate::distributed::partition::create_distributed_csr;
    use crate::distributed::partition::{partition_rows, PartitionConfig, PartitionMethod};

    /// Build an n×n tridiagonal matrix with diagonal=2, off-diag=-1.
    fn tridiag(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            rows.push(i);
            cols.push(i);
            vals.push(2.0_f64);
            if i > 0 {
                rows.push(i);
                cols.push(i - 1);
                vals.push(-1.0);
                rows.push(i - 1);
                cols.push(i);
                vals.push(-1.0);
            }
        }
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag construction")
    }

    /// Partition `mat` across `n_workers` using the default partition method.
    fn make_partitions(mat: &CsrMatrix<f64>, n_workers: usize) -> Vec<DistributedCsr> {
        let config = PartitionConfig {
            n_workers,
            ..Default::default()
        };
        let row_parts = partition_rows(mat.rows(), &config);
        row_parts
            .iter()
            .map(|rp| create_distributed_csr(mat, rp).expect("create_distributed_csr"))
            .collect()
    }

    #[test]
    fn test_distributed_spmv_matches_serial() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();

        // Serial SpMV via CsrMatrix::dot.
        let y_serial = mat.dot(&x).expect("serial dot");

        // Distributed SpMV.
        let parts = make_partitions(&mat, 4);
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        assert_eq!(y_serial.len(), y_dist.len());
        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            assert!(
                (ys - yd).abs() < 1e-10,
                "row {i}: serial={ys}, distributed={yd}"
            );
        }
    }

    #[test]
    fn test_distributed_spmv_single_worker() {
        let n = 8;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");
        let parts = make_partitions(&mat, 1);
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        // Without the length check a truncated result would pass the zip loop.
        assert_eq!(y_serial.len(), y_dist.len());
        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            assert!(
                (ys - yd).abs() < 1e-10,
                "row {i}: serial={ys}, distributed={yd}"
            );
        }
    }

    #[test]
    fn test_distributed_spmv_empty_partitions() {
        let y = distributed_spmv(&[], &[1.0, 2.0]).expect("distributed_spmv");
        assert!(y.is_empty());
    }

    #[test]
    fn test_ghost_manager_lookup() {
        let local_rows = vec![0usize, 1, 2];
        let ghost_rows = vec![5usize, 7];
        let mgr = GhostManager::new(&local_rows, &ghost_rows);
        assert_eq!(mgr.n_local, 3);
        assert_eq!(mgr.n_ghost, 2);
        assert_eq!(mgr.global_to_local(0), Some(0));
        assert_eq!(mgr.global_to_local(2), Some(2));
        assert_eq!(mgr.global_to_local(5), Some(3));
        assert_eq!(mgr.global_to_local(7), Some(4));
        assert_eq!(mgr.global_to_local(9), None);
    }

    #[test]
    fn test_distributed_vector_roundtrip() {
        let global = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let rp = RowPartition {
            worker_id: 0,
            local_rows: vec![1, 2],
            n_global_rows: 5,
        };
        let ghost_rows = vec![4usize];
        let dv = DistributedVector::from_global(&global, &rp, &ghost_rows).expect("from_global");
        assert_eq!(dv.local_values, vec![2.0, 3.0]);
        assert_eq!(dv.ghost_values, vec![5.0]);

        // Lookup by global index covers owned rows, ghost rows and misses.
        assert_eq!(dv.get_global(1), Some(2.0));
        assert_eq!(dv.get_global(4), Some(5.0));
        assert_eq!(dv.get_global(0), None);

        let reconstructed = dv.to_global(5);
        assert_eq!(reconstructed.len(), 5);
        assert_eq!(reconstructed[1], 2.0);
        assert_eq!(reconstructed[2], 3.0);
        // Other positions are 0 (ghost rows are not written back).
        assert_eq!(reconstructed[0], 0.0);
        assert_eq!(reconstructed[4], 0.0);
    }

    #[test]
    fn test_distributed_vector_out_of_bounds() {
        let global = vec![1.0, 2.0];
        let rp = RowPartition {
            worker_id: 0,
            local_rows: vec![0, 1],
            n_global_rows: 2,
        };
        // Ghost row 9 does not exist in a vector of length 2.
        assert!(DistributedVector::from_global(&global, &rp, &[9]).is_err());
    }

    #[test]
    fn test_halo_messages_built() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let parts = make_partitions(&mat, 4);
        let msgs = build_halo_messages(&parts, &x);
        // There should be messages at partition boundaries.
        assert!(
            !msgs.is_empty(),
            "Expected halo messages for tridiagonal matrix"
        );

        for msg in &msgs {
            assert!(msg.source_worker < parts.len());
            assert!(msg.dest_worker < parts.len());
            assert_eq!(
                msg.rows.len(),
                msg.values.len(),
                "rows and values must be parallel"
            );
            // Every value must be the x-entry of the row it is labelled with.
            for (&row, &value) in msg.rows.iter().zip(msg.values.iter()) {
                assert_eq!(x.get(row).copied(), Some(value), "row {row}");
            }
        }
    }

    #[test]
    fn test_distributed_spmv_round_robin() {
        let n = 12;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| (i + 1) as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");

        let config = PartitionConfig {
            n_workers: 3,
            method: PartitionMethod::RoundRobin,
            ..Default::default()
        };
        let row_parts = partition_rows(n, &config);
        let parts: Vec<DistributedCsr> = row_parts
            .iter()
            .map(|rp| create_distributed_csr(&mat, rp).expect("create"))
            .collect();
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        // Without the length check a truncated result would pass the zip loop.
        assert_eq!(y_serial.len(), y_dist.len());
        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
            assert!((ys - yd).abs() < 1e-10, "row {i}: serial={ys}, dist={yd}");
        }
    }
}
535        for (i, (ys, yd)) in y_serial.iter().zip(y_dist.iter()).enumerate() {
536            assert!((ys - yd).abs() < 1e-10, "row {i}: serial={ys}, dist={yd}");
537        }
538    }
539}