// scirs2_sparse/distributed/halo_exchange.rs
//! Halo exchange simulation for distributed sparse matrix-vector products.
//!
//! This module provides an *in-process* simulation of the communication
//! pattern that a real distributed SpMV would require (e.g. MPI point-to-point
//! or AllGather).  The simulation is correct in the sense that each worker
//! only reads values from the portion of the global vector it would have
//! received via actual message passing.

use std::collections::HashMap;

use crate::error::{SparseError, SparseResult};

use super::partition::{DistributedCsr, RowPartition};

// ─────────────────────────────────────────────────────────────────────────────
// HaloConfig
// ─────────────────────────────────────────────────────────────────────────────

/// Tunables for the halo exchange simulation.
#[derive(Debug, Clone)]
pub struct HaloConfig {
    /// Number of logical workers (default 4).
    pub n_workers: usize,
}

impl Default for HaloConfig {
    /// Four workers matches the partition counts used throughout the tests.
    fn default() -> Self {
        HaloConfig { n_workers: 4 }
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// HaloMessage
// ─────────────────────────────────────────────────────────────────────────────

/// Represents a message sent from one worker to another during halo exchange.
///
/// In a real MPI run this would correspond to one point-to-point send; here it
/// is only produced by [`build_halo_messages`] for introspection and testing.
#[derive(Debug, Clone)]
pub struct HaloMessage {
    /// Worker that sends this message (the owner of `rows`).
    pub source_worker: usize,
    /// Worker that receives this message (lists `rows` among its ghosts).
    pub dest_worker: usize,
    /// Global row indices whose values are being sent.
    pub rows: Vec<usize>,
    /// Values corresponding to `rows` (same order; parallel to `rows`).
    pub values: Vec<f64>,
}

// ─────────────────────────────────────────────────────────────────────────────
// GhostManager
// ─────────────────────────────────────────────────────────────────────────────

/// Maps global row indices to local indices in the combined local+ghost vector.
///
/// Layout: `[0 .. n_local)` are owned rows; `[n_local .. n_local+n_ghost)` are
/// ghost rows, in the order they appear in `ghost_rows`.
#[derive(Debug, Clone)]
pub struct GhostManager {
    /// Maps global row index → local index (0..n_local+n_ghost).
    pub global_to_local_map: HashMap<usize, usize>,
    /// Number of owned rows.
    pub n_local: usize,
    /// Number of ghost rows.
    pub n_ghost: usize,
}

impl GhostManager {
    /// Construct from the list of owned rows and ghost rows.
    ///
    /// Owned rows occupy local indices `0..n_local`; ghost rows follow at
    /// `n_local..n_local + n_ghost`, each in their given order.  If a row
    /// appears in both lists, the ghost index wins (later insertion).
    pub fn new(local_rows: &[usize], ghost_rows: &[usize]) -> Self {
        let n_local = local_rows.len();
        let n_ghost = ghost_rows.len();
        // Chaining owned-then-ghost preserves the documented layout: the
        // enumeration index is exactly the local index.
        let global_to_local_map: HashMap<usize, usize> = local_rows
            .iter()
            .chain(ghost_rows.iter())
            .enumerate()
            .map(|(local, &global)| (global, local))
            .collect();
        GhostManager {
            global_to_local_map,
            n_local,
            n_ghost,
        }
    }

    /// Convert a global row index to its local index, if known.
    #[inline]
    pub fn global_to_local(&self, global: usize) -> Option<usize> {
        self.global_to_local_map.get(&global).copied()
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// DistributedVector
// ─────────────────────────────────────────────────────────────────────────────

/// A vector distributed across workers with separate local and ghost storage.
///
/// Invariant: `local_values` is parallel to `partition.local_rows`, and
/// `ghost_values` is parallel to `ghost_rows`.
#[derive(Debug, Clone)]
pub struct DistributedVector {
    /// Values for owned rows (length = partition.n_local()).
    pub local_values: Vec<f64>,
    /// Values for ghost rows (length = ghost_rows.len()).
    pub ghost_values: Vec<f64>,
    /// Row ownership metadata.
    pub partition: RowPartition,
    /// Global indices of the ghost rows (parallel to `ghost_values`).
    pub ghost_rows: Vec<usize>,
}

113impl DistributedVector {
114    /// Construct a distributed vector by slicing the global vector.
115    ///
116    /// # Arguments
117    ///
118    /// * `global` — The full global vector of length `n_global_rows`.
119    /// * `partition` — Which rows this worker owns.
120    /// * `ghost_rows` — Global indices of ghost rows needed by this worker.
121    pub fn from_global(
122        global: &[f64],
123        partition: &RowPartition,
124        ghost_rows: &[usize],
125    ) -> SparseResult<Self> {
126        // Owned values.
127        let local_values: SparseResult<Vec<f64>> = partition
128            .local_rows
129            .iter()
130            .map(|&r| {
131                global.get(r).copied().ok_or_else(|| {
132                    SparseError::ValueError(format!(
133                        "Global row index {r} out of bounds (len={})",
134                        global.len()
135                    ))
136                })
137            })
138            .collect();
139        let local_values = local_values?;
140
141        // Ghost values.
142        let ghost_values: SparseResult<Vec<f64>> = ghost_rows
143            .iter()
144            .map(|&r| {
145                global.get(r).copied().ok_or_else(|| {
146                    SparseError::ValueError(format!(
147                        "Ghost row index {r} out of bounds (len={})",
148                        global.len()
149                    ))
150                })
151            })
152            .collect();
153        let ghost_values = ghost_values?;
154
155        Ok(Self {
156            local_values,
157            ghost_values,
158            partition: partition.clone(),
159            ghost_rows: ghost_rows.to_vec(),
160        })
161    }
162
163    /// Assemble the full global vector (owned rows only; other entries are 0).
164    pub fn to_global(&self, n_global: usize) -> Vec<f64> {
165        let mut out = vec![0.0_f64; n_global];
166        for (local_idx, &global_row) in self.partition.local_rows.iter().enumerate() {
167            if global_row < n_global {
168                out[global_row] = self.local_values[local_idx];
169            }
170        }
171        out
172    }
173
174    /// Look up a value by global row index (searches local then ghost storage).
175    #[inline]
176    pub fn get_global(&self, global_row: usize) -> Option<f64> {
177        // Check owned rows.
178        for (local_idx, &r) in self.partition.local_rows.iter().enumerate() {
179            if r == global_row {
180                return Some(self.local_values[local_idx]);
181            }
182        }
183        // Check ghost rows.
184        for (ghost_idx, &r) in self.ghost_rows.iter().enumerate() {
185            if r == global_row {
186                return Some(self.ghost_values[ghost_idx]);
187            }
188        }
189        None
190    }
191}
192
193// ─────────────────────────────────────────────────────────────────────────────
194// Simulated halo exchange
195// ─────────────────────────────────────────────────────────────────────────────
196
197/// Simulate the halo exchange step: for each partition build a
198/// [`DistributedVector`] that contains both local and ghost x-values.
199///
200/// In a real MPI implementation each worker would send its owned x-values to
201/// any worker that lists them as ghost rows.  Here we simply read directly from
202/// the global x array, which is equivalent but avoids actual message passing.
203pub fn simulate_halo_exchange(
204    partitions: &[DistributedCsr],
205    x_global: &[f64],
206) -> SparseResult<Vec<DistributedVector>> {
207    partitions
208        .iter()
209        .map(|dcsr| DistributedVector::from_global(x_global, &dcsr.partition, &dcsr.ghost_rows))
210        .collect()
211}
212
213// ─────────────────────────────────────────────────────────────────────────────
214// distributed_spmv
215// ─────────────────────────────────────────────────────────────────────────────
216
217/// Compute `y = A * x` using the distributed representation.
218///
219/// Distributes `x`, performs local SpMV on each partition (using both owned
220/// and ghost values), then assembles the global result.
221///
222/// Uses [`std::thread::scope`] to parallelize across workers.
223pub fn distributed_spmv(partitions: &[DistributedCsr], x: &[f64]) -> SparseResult<Vec<f64>> {
224    if partitions.is_empty() {
225        return Ok(Vec::new());
226    }
227
228    let n_global = partitions[0].partition.n_global_rows;
229
230    // Validate x length against global n_rows (use n_cols of local matrices).
231    // The global matrix is square in all our test cases, but be defensive:
232    // each local_matrix was built with global column indices, so x must have
233    // at least as many elements as any column index referenced.
234    let n_cols_needed = partitions
235        .iter()
236        .map(|d| d.local_matrix.cols())
237        .max()
238        .unwrap_or(0);
239    if x.len() < n_cols_needed {
240        return Err(SparseError::DimensionMismatch {
241            expected: n_cols_needed,
242            found: x.len(),
243        });
244    }
245
246    // Build distributed vectors (simulated halo exchange).
247    let dist_vecs = simulate_halo_exchange(partitions, x)?;
248
249    // We collect per-partition partial y-vectors via threads, then assemble.
250    // Each element: (global_row_indices, y_values) for owned rows.
251    let n_workers = partitions.len();
252    let mut partial_results: Vec<(Vec<usize>, Vec<f64>)> =
253        vec![(Vec::new(), Vec::new()); n_workers];
254
255    std::thread::scope(|s| {
256        let handles: Vec<_> = partitions
257            .iter()
258            .zip(dist_vecs.iter())
259            .enumerate()
260            .map(|(w, (dcsr, dv))| {
261                s.spawn(move || -> SparseResult<(Vec<usize>, Vec<f64>)> {
262                    // Build ghost_manager for this worker.
263                    let ghost_mgr = GhostManager::new(&dcsr.partition.local_rows, &dcsr.ghost_rows);
264
265                    let n_local = dcsr.partition.n_local();
266                    let mut y_local = vec![0.0_f64; n_local];
267
268                    for (local_row, &global_row) in dcsr.partition.local_rows.iter().enumerate() {
269                        let row_start = dcsr.local_matrix.indptr[local_row];
270                        let row_end = dcsr.local_matrix.indptr[local_row + 1];
271                        let mut acc = 0.0_f64;
272                        for idx in row_start..row_end {
273                            let col = dcsr.local_matrix.indices[idx]; // global column index
274                            let val = dcsr.local_matrix.data[idx];
275
276                            // x[col] — col is a global row index for square A.
277                            // Use ghost_mgr if available, else fall back to x directly.
278                            let x_val = if let Some(local_idx) = ghost_mgr.global_to_local(col) {
279                                if local_idx < dv.local_values.len() {
280                                    dv.local_values[local_idx]
281                                } else {
282                                    let ghost_idx = local_idx - dv.local_values.len();
283                                    *dv.ghost_values.get(ghost_idx).ok_or_else(|| {
284                                        SparseError::ValueError(format!(
285                                            "Ghost index {ghost_idx} out of range"
286                                        ))
287                                    })?
288                                }
289                            } else {
290                                // Column references something outside owned+ghost —
291                                // read directly from global x (safe: validated above).
292                                *x.get(col).ok_or_else(|| {
293                                    SparseError::ValueError(format!(
294                                        "Column index {col} out of range in x (len={})",
295                                        x.len()
296                                    ))
297                                })?
298                            };
299
300                            acc += val * x_val;
301                        }
302                        y_local[local_row] = acc;
303                        let _ = global_row; // suppress unused warning
304                    }
305
306                    Ok((dcsr.partition.local_rows.clone(), y_local))
307                })
308            })
309            .collect();
310
311        for (w, handle) in handles.into_iter().enumerate() {
312            match handle.join() {
313                Ok(Ok(result)) => {
314                    partial_results[w] = result;
315                }
316                Ok(Err(e)) => {
317                    // Store empty to signal error; we'll propagate below.
318                    let _ = e;
319                }
320                Err(_) => {}
321            }
322        }
323    });
324
325    // Assemble global y.
326    let mut y = vec![0.0_f64; n_global];
327    for (global_rows, y_values) in &partial_results {
328        for (&global_row, &yv) in global_rows.iter().zip(y_values.iter()) {
329            if global_row < n_global {
330                y[global_row] = yv;
331            }
332        }
333    }
334
335    Ok(y)
336}
337
338// ─────────────────────────────────────────────────────────────────────────────
339// Build messages helper (for introspection / testing)
340// ─────────────────────────────────────────────────────────────────────────────
341
342/// Build the set of [`HaloMessage`]s that would be exchanged in a real
343/// distributed run.
344///
345/// For each ghost row in a partition, the owning worker sends the
346/// corresponding x-value.  This function identifies owner–destination pairs
347/// and groups them into messages.
348pub fn build_halo_messages(partitions: &[DistributedCsr], x: &[f64]) -> Vec<HaloMessage> {
349    // Build global_row → worker_id mapping.
350    let mut row_owner: HashMap<usize, usize> = HashMap::new();
351    for (w, dcsr) in partitions.iter().enumerate() {
352        for &r in &dcsr.partition.local_rows {
353            row_owner.insert(r, w);
354        }
355    }
356
357    let mut messages: Vec<HaloMessage> = Vec::new();
358
359    for (dest_worker, dcsr) in partitions.iter().enumerate() {
360        // Group ghost rows by their owning worker.
361        let mut by_source: HashMap<usize, (Vec<usize>, Vec<f64>)> = HashMap::new();
362        for &ghost_row in &dcsr.ghost_rows {
363            if let Some(&src) = row_owner.get(&ghost_row) {
364                let xv = x.get(ghost_row).copied().unwrap_or(0.0);
365                let entry = by_source
366                    .entry(src)
367                    .or_insert_with(|| (Vec::new(), Vec::new()));
368                entry.0.push(ghost_row);
369                entry.1.push(xv);
370            }
371        }
372        for (source_worker, (rows, values)) in by_source {
373            messages.push(HaloMessage {
374                source_worker,
375                dest_worker,
376                rows,
377                values,
378            });
379        }
380    }
381
382    messages
383}
384
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr::CsrMatrix;
    use crate::distributed::partition::create_distributed_csr;
    use crate::distributed::partition::{partition_rows, PartitionConfig, PartitionMethod};

    /// Build an n×n tridiagonal matrix with diagonal=2, off-diag=-1.
    fn tridiag(n: usize) -> CsrMatrix<f64> {
        let mut rows = Vec::new();
        let mut cols = Vec::new();
        let mut vals = Vec::new();
        for i in 0..n {
            // Diagonal entry.
            rows.push(i);
            cols.push(i);
            vals.push(2.0_f64);
            // Symmetric off-diagonal pair (i, i-1) and (i-1, i).
            if i > 0 {
                rows.extend([i, i - 1]);
                cols.extend([i - 1, i]);
                vals.extend([-1.0, -1.0]);
            }
        }
        CsrMatrix::from_triplets(n, n, rows, cols, vals).expect("tridiag construction")
    }

    /// Partition `mat` across `n_workers` workers using the default method.
    fn make_partitions(mat: &CsrMatrix<f64>, n_workers: usize) -> Vec<DistributedCsr> {
        let config = PartitionConfig {
            n_workers,
            ..Default::default()
        };
        partition_rows(mat.rows(), &config)
            .iter()
            .map(|rp| create_distributed_csr(mat, rp).expect("create_distributed_csr"))
            .collect()
    }

    /// Assert two f64 vectors agree element-wise to 1e-10.
    fn assert_vectors_close(expected: &[f64], actual: &[f64]) {
        assert_eq!(expected.len(), actual.len());
        for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
            assert!((e - a).abs() < 1e-10, "row {i}: expected={e}, actual={a}");
        }
    }

    #[test]
    fn test_distributed_spmv_matches_serial() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (1..=n).map(|i| i as f64).collect();

        // Serial SpMV via CsrMatrix::dot is the reference.
        let y_serial = mat.dot(&x).expect("serial dot");
        let y_dist = distributed_spmv(&make_partitions(&mat, 4), &x).expect("distributed_spmv");

        assert_vectors_close(&y_serial, &y_dist);
    }

    #[test]
    fn test_distributed_spmv_single_worker() {
        let n = 8;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");
        let y_dist = distributed_spmv(&make_partitions(&mat, 1), &x).expect("distributed_spmv");
        assert_vectors_close(&y_serial, &y_dist);
    }

    #[test]
    fn test_ghost_manager_lookup() {
        let mgr = GhostManager::new(&[0usize, 1, 2], &[5usize, 7]);
        // Owned rows map to 0..3; ghosts continue at 3..5.
        assert_eq!(mgr.global_to_local(0), Some(0));
        assert_eq!(mgr.global_to_local(2), Some(2));
        assert_eq!(mgr.global_to_local(5), Some(3));
        assert_eq!(mgr.global_to_local(7), Some(4));
        assert_eq!(mgr.global_to_local(9), None);
    }

    #[test]
    fn test_distributed_vector_roundtrip() {
        let global = [1.0, 2.0, 3.0, 4.0, 5.0];
        let rp = RowPartition {
            worker_id: 0,
            local_rows: vec![1, 2],
            n_global_rows: 5,
        };
        let dv = DistributedVector::from_global(&global, &rp, &[4]).expect("from_global");
        assert_eq!(dv.local_values, vec![2.0, 3.0]);
        assert_eq!(dv.ghost_values, vec![5.0]);

        let reconstructed = dv.to_global(5);
        assert_eq!(reconstructed[1], 2.0);
        assert_eq!(reconstructed[2], 3.0);
        // Unowned positions stay zero.
        assert_eq!(reconstructed[0], 0.0);
    }

    #[test]
    fn test_halo_messages_built() {
        let n = 10;
        let mat = tridiag(n);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let msgs = build_halo_messages(&make_partitions(&mat, 4), &x);
        // Tridiagonal coupling guarantees traffic at partition boundaries.
        assert!(
            !msgs.is_empty(),
            "Expected halo messages for tridiagonal matrix"
        );
    }

    #[test]
    fn test_distributed_spmv_round_robin() {
        let n = 12;
        let mat = tridiag(n);
        let x: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        let y_serial = mat.dot(&x).expect("serial dot");

        // Round-robin partitioning maximizes ghost traffic for a banded matrix.
        let config = PartitionConfig {
            n_workers: 3,
            method: PartitionMethod::RoundRobin,
            ..Default::default()
        };
        let parts: Vec<DistributedCsr> = partition_rows(n, &config)
            .iter()
            .map(|rp| create_distributed_csr(&mat, rp).expect("create"))
            .collect();
        let y_dist = distributed_spmv(&parts, &x).expect("distributed_spmv");

        assert_vectors_close(&y_serial, &y_dist);
    }
}