sprs/src/sparse/smmp.rs

//! Implementation of the paper
//! Bank and Douglas, 2001, Sparse Matrix Multiplication Package (SMMP)

use crate::indexing::SpIndex;
use crate::sparse::prelude::*;
use crate::sparse::CompressedStorage::CSR;
#[cfg(feature = "multi_thread")]
use rayon::prelude::*;

#[cfg(feature = "multi_thread")]
use std::cell::RefCell;

/// Control the strategy used to parallelize the matrix product workload.
///
/// The `Automatic` strategy will try to pick a good number of threads based
/// on the number of cores and an estimation of the nnz of the product
/// matrix. This strategy is used by default.
///
/// The `AutomaticPhysical` strategy will try to pick a good number of threads
/// based on the number of physical cores and an estimation of the nnz of the
/// product matrix. This strategy is a fallback for machines where virtual
/// cores do not provide a performance advantage.
///
/// The `Fixed` strategy leaves the control to the user. It is a programming
/// error to request 0 threads.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg(feature = "multi_thread")]
pub enum ThreadingStrategy {
    Automatic,
    AutomaticPhysical,
    Fixed(usize),
}

#[cfg(feature = "multi_thread")]
thread_local! {
    static THREADING_STRAT: RefCell<ThreadingStrategy> =
    const { RefCell::new(ThreadingStrategy::Automatic) };
}

/// Set the threading strategy for matrix products in this thread.
///
/// # Panics
///
/// If `ThreadingStrategy::Fixed(0)` is requested.
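///
/// # Example
///
/// A minimal sketch of pinning products to a fixed thread count, assuming
/// this module is accessible as `sprs::smmp` (the value `4` is arbitrary):
///
/// ```rust,ignore
/// use sprs::smmp::{set_thread_threading_strategy, ThreadingStrategy};
/// // Matrix products performed on this thread will now use 4 threads.
/// set_thread_threading_strategy(ThreadingStrategy::Fixed(4));
/// ```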
#[cfg(feature = "multi_thread")]
pub fn set_thread_threading_strategy(strategy: ThreadingStrategy) {
    if let ThreadingStrategy::Fixed(nb_threads) = strategy {
        assert!(nb_threads > 0);
    }
    THREADING_STRAT.with(|s| {
        *s.borrow_mut() = strategy;
    });
}

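/// Get the threading strategy used for matrix products in this thread.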
#[cfg(feature = "multi_thread")]
pub fn thread_threading_strategy() -> ThreadingStrategy {
    THREADING_STRAT.with(|s| *s.borrow())
}

/// Compute the symbolic structure of the matrix product C = A * B, with
/// A, B and C stored in the CSR matrix format.
///
/// This algorithm has a complexity of O(n * k * log(k)), where k is the
/// average number of nonzeros in the rows of the result.
///
/// # Panics
///
/// `seen.len()` should be equal to the number of columns of `b`, and
/// `c_indptr.len()` to the number of rows of `a` plus one.
///
/// The matrices should be in proper CSR structure, and their dimensions
/// should be compatible. Failure to do so may result in out of bounds errors
/// (though some cases might go unnoticed).
///
/// # Minimizing allocations
///
/// This function will reserve `a.nnz() + b.nnz()` elements in `c_indices`.
/// Therefore, to prevent this function from allocating, `c_indices` should
/// already have at least this much capacity reserved.
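///
/// # Example
///
/// A minimal sketch on a small diagonal CSR matrix multiplied with itself,
/// assuming the public `sprs::CsMat` constructor and that this function is
/// accessible as `sprs::smmp::symbolic`:
///
/// ```rust,ignore
/// // 2x2 diagonal matrix in CSR format.
/// let a = sprs::CsMat::new((2, 2), vec![0, 1, 2], vec![0, 1], vec![2.0, 3.0]);
/// let mut c_indptr = [0usize; 3]; // a.rows() + 1 entries
/// let mut c_indices = Vec::new();
/// let mut seen = [false; 2]; // b.cols() entries
/// sprs::smmp::symbolic(
///     a.structure_view(),
///     a.structure_view(),
///     &mut c_indptr,
///     &mut c_indices,
///     &mut seen,
/// );
/// // The product of two diagonal matrices is diagonal.
/// assert_eq!(&c_indptr, &[0, 1, 2]);
/// assert_eq!(&c_indices[..], &[0, 1]);
/// ```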
pub fn symbolic<Iptr: SpIndex, I: SpIndex>(
    a: CsStructureViewI<I, Iptr>,
    b: CsStructureViewI<I, Iptr>,
    c_indptr: &mut [Iptr],
    // TODO look for literature on the nnz of C to be able to have a slice here
    c_indices: &mut Vec<I>,
    seen: &mut [bool],
) {
    assert!(a.indptr().len() == c_indptr.len());
    let a_nnz = a.nnz();
    let b_nnz = b.nnz();
    c_indices.clear();
    c_indices.reserve_exact(a_nnz + b_nnz);

    assert_eq!(a.cols(), b.rows());
    assert!(seen.len() == b.cols());
    for elt in seen.iter_mut() {
        *elt = false;
    }

    c_indptr[0] = Iptr::from_usize(0);
    for (a_row, a_range) in a.indptr().iter_outer_sz().enumerate() {
        let mut length = 0;

        // FIXME are iterators possible here?
        // TODO benchmark unsafe indexing here. It's possible to get
        // a subslice using get(a_range). It should also be possible to use
        // index_unchecked and from_usize_unchecked
        for &a_col in &a.indices()[a_range] {
            let b_row = a_col.index();
            let b_range = b.indptr().outer_inds_sz(b_row);
            for b_col in &b.indices()[b_range] {
                let b_col = b_col.index();
                if !seen[b_col] {
                    seen[b_col] = true;
                    c_indices.push(I::from_usize(b_col));
                    length += 1;
                }
            }
        }
        c_indptr[a_row + 1] = c_indptr[a_row] + Iptr::from_usize(length);
        let c_start = c_indptr[a_row].index();
        let c_end = c_start + length;
        // TODO maybe sorting should be done outside, to have an even parallel
        // workload
        c_indices[c_start..c_end].sort_unstable();
        for c_col in &c_indices[c_start..c_end] {
            seen[c_col.index()] = false;
        }
    }
}

/// Numeric part of the matrix product C = A * B with A, B and C stored in the
/// CSR matrix format.
///
/// This function is low-level, and supports execution on chunks of the
/// rows of C and A. To use the chunks, split the indptrs of A and C and split
/// `c_indices` and `c_data` to only contain the elements referenced in
/// `c_indptr`. This function will take care of using the correct offset
/// inside the sliced indices and data.
///
/// # Panics
///
/// `tmp.len()` should be equal to the number of columns of `b`.
///
/// The matrices should be in proper CSR structure, and their dimensions
/// should be compatible. Failure to do so may result in out of bounds errors
/// (though some cases might go unnoticed).
///
/// The parts for the C matrix should come from the `symbolic` function.
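///
/// # Example
///
/// A minimal sketch chaining `symbolic` and `numeric`, mirroring the test at
/// the bottom of this module. `new_trusted` is crate-internal, so this is a
/// crate-internal sketch rather than a public-API recipe:
///
/// ```rust,ignore
/// // 1x1 CSR matrix containing the single value 3.0.
/// let a = crate::CsMat::new((1, 1), vec![0, 1], vec![0], vec![3.0_f64]);
/// let mut c_indptr = [0usize; 2];
/// let mut c_indices = Vec::new();
/// let mut seen = [false; 1];
/// symbolic(
///     a.structure_view(),
///     a.structure_view(),
///     &mut c_indptr,
///     &mut c_indices,
///     &mut seen,
/// );
/// let mut c_data = vec![0.0; c_indices.len()];
/// let mut tmp = [0.0; 1];
/// let c = crate::CsMatViewMutI::new_trusted(
///     crate::CompressedStorage::CSR,
///     (1, 1),
///     &c_indptr[..],
///     &c_indices[..],
///     &mut c_data[..],
/// );
/// numeric(a.view(), a.view(), c, &mut tmp);
/// assert_eq!(c_data, vec![9.0]); // 3.0 * 3.0
/// ```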
pub fn numeric<
    Iptr: SpIndex,
    I: SpIndex,
    A,
    B,
    N: crate::MulAcc<A, B> + num_traits::Zero,
>(
    a: CsMatViewI<A, I, Iptr>,
    b: CsMatViewI<B, I, Iptr>,
    mut c: CsMatViewMutI<N, I, Iptr>,
    tmp: &mut [N],
) {
    assert_eq!(a.rows(), c.rows());
    assert_eq!(a.cols(), b.rows());
    assert_eq!(b.cols(), c.cols());
    assert_eq!(tmp.len(), b.cols());
    assert!(a.is_csr());
    assert!(b.is_csr());

    for elt in tmp.iter_mut() {
        *elt = N::zero();
    }
    for (a_row, mut c_row) in a.outer_iterator().zip(c.outer_iterator_mut()) {
        for (a_col, a_val) in a_row.iter() {
            // TODO unchecked index
            let b_row = b.outer_view(a_col.index()).unwrap();
            for (b_col, b_val) in b_row.iter() {
                // TODO unsafe indexing
                tmp[b_col.index()].mul_acc(a_val, b_val);
            }
        }
        for (c_col, c_val) in c_row.iter_mut() {
            // TODO unsafe indexing
            let mut val = N::zero();
            std::mem::swap(&mut val, &mut tmp[c_col]);
            *c_val = val;
        }
    }
}

/// Compute a sparse matrix product using the SMMP routines.
///
/// With the `multi_thread` feature enabled, the workload is parallelized
/// according to the current `ThreadingStrategy`.
///
/// # Panics
///
/// - if `lhs.cols() != rhs.rows()`.
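///
/// # Example
///
/// A small sketch, assuming this function is accessible as
/// `sprs::smmp::mul_csr_csr`; the output scalar type is annotated since it
/// is a free type parameter:
///
/// ```rust,ignore
/// use sprs::CsMat;
/// // 2x2 diagonal CSR matrix with entries 2.0 and 3.0.
/// let a = CsMat::new((2, 2), vec![0, 1, 2], vec![0, 1], vec![2.0_f64, 3.0]);
/// let c: CsMat<f64> = sprs::smmp::mul_csr_csr(a.view(), a.view());
/// assert_eq!(c.to_dense()[[0, 0]], 4.0);
/// assert_eq!(c.to_dense()[[1, 1]], 9.0);
/// ```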
pub fn mul_csr_csr<N, A, B, I, Iptr>(
    lhs: CsMatViewI<A, I, Iptr>,
    rhs: CsMatViewI<B, I, Iptr>,
) -> CsMatI<N, I, Iptr>
where
    N: crate::MulAcc<A, B> + num_traits::Zero + Clone + Send + Sync,
    A: Send + Sync,
    B: Send + Sync,
    I: SpIndex,
    Iptr: SpIndex,
{
    assert_eq!(lhs.cols(), rhs.rows());
    let workspace_len = rhs.cols();
    #[cfg(feature = "multi_thread")]
    let nb_threads = std::cmp::min(lhs.rows().max(1), {
        use self::ThreadingStrategy::{Automatic, AutomaticPhysical};
        match thread_threading_strategy() {
            ThreadingStrategy::Fixed(nb_threads) => nb_threads,
            strat @ Automatic | strat @ AutomaticPhysical => {
                let nb_cpus = if strat == ThreadingStrategy::Automatic {
                    num_cpus::get()
                } else {
                    num_cpus::get_physical()
                };
                let ideal_chunk_size = 8128;
                let wanted_threads = (lhs.nnz() + rhs.nnz()) / ideal_chunk_size;
                // wanted_threads could be < nb_cpus
                #[allow(clippy::manual_clamp)]
                1.max(wanted_threads).min(nb_cpus)
            }
        }
    });
    #[cfg(not(feature = "multi_thread"))]
    let nb_threads = 1;
    let mut tmps = Vec::with_capacity(nb_threads);
    for _ in 0..nb_threads {
        tmps.push(vec![N::zero(); workspace_len].into_boxed_slice());
    }
    let mut seens =
        vec![vec![false; workspace_len].into_boxed_slice(); nb_threads];
    mul_csr_csr_with_workspace(lhs, rhs, &mut seens, &mut tmps)
}

/// Compute a sparse matrix product using the SMMP routines, using temporary
/// storage that was already allocated.
///
/// `seens` and `tmps` are temporary storage vectors used to accumulate
/// nonzero locations and values. Their values need not be specified on
/// input. They will be zeroed on output. They are slices of boxed slices,
/// where the outer slice is there to give multiple workspaces for
/// multi-threading. Therefore, `seens.len()` controls the number of threads
/// used for symbolic computation, and `tmps.len()` the number of threads for
/// numeric computation.
///
/// # Panics
///
/// - if `lhs.cols() != rhs.rows()`.
/// - if `seens.len() == 0`
/// - if `tmps.len() == 0`
/// - if `seens[i].len() != rhs.cols()`
/// - if `tmps[i].len() != rhs.cols()`
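///
/// # Example
///
/// A sketch of reusing workspaces across products, assuming this function is
/// accessible as `sprs::smmp::mul_csr_csr_with_workspace`. One workspace per
/// slice means both passes run single-threaded:
///
/// ```rust,ignore
/// use sprs::CsMat;
/// let a = CsMat::new((2, 2), vec![0, 1, 2], vec![0, 1], vec![2.0_f64, 3.0]);
/// // Each workspace must have rhs.cols() elements.
/// let mut seens = vec![vec![false; a.cols()].into_boxed_slice()];
/// let mut tmps = vec![vec![0.0_f64; a.cols()].into_boxed_slice()];
/// let c = sprs::smmp::mul_csr_csr_with_workspace(
///     a.view(),
///     a.view(),
///     &mut seens,
///     &mut tmps,
/// );
/// assert_eq!(c.nnz(), 2);
/// ```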
pub fn mul_csr_csr_with_workspace<N, A, B, I, Iptr>(
    lhs: CsMatViewI<A, I, Iptr>,
    rhs: CsMatViewI<B, I, Iptr>,
    seens: &mut [Box<[bool]>],
    tmps: &mut [Box<[N]>],
) -> CsMatI<N, I, Iptr>
where
    N: crate::MulAcc<A, B> + num_traits::Zero + Clone + Send + Sync,
    A: Send + Sync,
    B: Send + Sync,
    I: SpIndex,
    Iptr: SpIndex,
{
    let workspace_len = rhs.cols();
    assert_eq!(lhs.cols(), rhs.rows());
    assert!(seens.iter().all(|x| x.len() == workspace_len));
    assert!(tmps.iter().all(|x| x.len() == workspace_len));
    let indptr_len = lhs.rows() + 1;
    let mut res_indices = Vec::new();
    let nb_threads = seens.len();
    assert!(nb_threads > 0);
    let chunk_size = lhs.indptr().len() / nb_threads;
    let mut lhs_chunks = Vec::with_capacity(nb_threads);
    let mut res_indptr_chunks = Vec::with_capacity(nb_threads);
    let mut res_indices_chunks = Vec::with_capacity(nb_threads);
    // Split the rows of lhs evenly; each chunk is paired with its own
    // workspace for the symbolic pass.
    for chunk_id in 0..nb_threads {
        let start = chunk_id * chunk_size;
        let stop = if chunk_id + 1 < nb_threads {
            (chunk_id + 1) * chunk_size
        } else {
            lhs.rows()
        };
        lhs_chunks.push(lhs.slice_outer(start..stop));
        res_indptr_chunks.push(vec![Iptr::zero(); stop - start + 1]);
        res_indices_chunks
            .push(Vec::with_capacity(lhs.nnz() + rhs.nnz() / chunk_size));
    }
    #[cfg(feature = "multi_thread")]
    let iter = lhs_chunks
        .par_iter()
        .zip(res_indptr_chunks.par_iter_mut())
        .zip(res_indices_chunks.par_iter_mut())
        .zip(seens.par_iter_mut());
    #[cfg(not(feature = "multi_thread"))]
    let iter = lhs_chunks
        .iter()
        .zip(res_indptr_chunks.iter_mut())
        .zip(res_indices_chunks.iter_mut())
        .zip(seens.iter_mut());
    iter.for_each(
        |(((lhs_chunk, res_indptr_chunk), res_indices_chunk), seen)| {
            symbolic(
                lhs_chunk.structure_view(),
                rhs.structure_view(),
                res_indptr_chunk,
                res_indices_chunk,
                seen,
            );
        },
    );
    res_indices.reserve(res_indices_chunks.iter().map(Vec::len).sum());
    for res_indices_chunk in &res_indices_chunks {
        res_indices.extend_from_slice(res_indices_chunk);
    }
    let mut res_indptr = Vec::with_capacity(indptr_len);
    res_indptr.push(Iptr::zero());
    for res_indptr_chunk in &res_indptr_chunks {
        for row in res_indptr_chunk.windows(2) {
            let nnz = row[1] - row[0];
            res_indptr.push(nnz + *res_indptr.last().unwrap());
        }
    }
    let mut res_data = vec![N::zero(); res_indices.len()];
    let nb_threads = tmps.len();
    assert!(nb_threads > 0);
    let chunk_size = res_indices.len() / nb_threads;
    let mut res_indices_rem = &res_indices[..];
    let mut res_data_rem = &mut res_data[..];
    let mut prev_nnz = 0;
    let mut split_nnz = 0;
    let mut split_row = 0;
    let mut lhs_chunks = Vec::with_capacity(nb_threads);
    let mut res_indptr_chunks = Vec::with_capacity(nb_threads);
    let mut res_indices_chunks = Vec::with_capacity(nb_threads);
    let mut res_data_chunks = Vec::with_capacity(nb_threads);
    // For the numeric pass, split the rows so that each chunk holds roughly
    // the same number of nonzeros of the result, balancing the workload.
    for (row, nnz) in res_indptr.iter().enumerate() {
        let nnz = nnz.index();
        if nnz - split_nnz > chunk_size && row > 0 {
            lhs_chunks.push(lhs.slice_outer(split_row..row - 1));

            res_indptr_chunks.push(&res_indptr[split_row..row]);

            let (left, right) = res_indices_rem
                .split_at(prev_nnz - res_indptr[split_row].index());
            res_indices_chunks.push(left);
            res_indices_rem = right;

            // FIXME it would be a good idea to have split_outer_mut on
            // CsMatViewMut
            let (left, right) = res_data_rem
                .split_at_mut(prev_nnz - res_indptr[split_row].index());
            res_data_chunks.push(left);
            res_data_rem = right;

            split_nnz = nnz;
            split_row = row - 1;
        }
        prev_nnz = nnz;
    }
    lhs_chunks.push(lhs.slice_outer(split_row..lhs.rows()));
    res_indptr_chunks.push(&res_indptr[split_row..]);
    res_indices_chunks.push(res_indices_rem);
    res_data_chunks.push(res_data_rem);
    #[cfg(feature = "multi_thread")]
    let iter = lhs_chunks
        .par_iter()
        .zip(res_indptr_chunks.par_iter())
        .zip(res_indices_chunks.par_iter())
        .zip(res_data_chunks.par_iter_mut())
        .zip(tmps.par_iter_mut());
    #[cfg(not(feature = "multi_thread"))]
    let iter = lhs_chunks
        .iter()
        .zip(res_indptr_chunks.iter())
        .zip(res_indices_chunks.iter())
        .zip(res_data_chunks.iter_mut())
        .zip(tmps.iter_mut());
    iter.for_each(
        |(
            (
                ((lhs_chunk, res_indptr_chunk), res_indices_chunk),
                res_data_chunk,
            ),
            tmp,
        )| {
            let res_chunk = CsMatViewMutI::new_trusted(
                CSR,
                (lhs_chunk.rows(), rhs.cols()),
                res_indptr_chunk,
                res_indices_chunk,
                res_data_chunk,
            );
            numeric(lhs_chunk.view(), rhs.view(), res_chunk, tmp);
        },
    );

    // Correctness: The invariants of the output come from the invariants of
    // the inputs when in-bounds indices are concerned, and we are sorting
    // indices.
    CsMatI::new_trusted(
        CSR,
        (lhs.rows(), rhs.cols()),
        res_indptr,
        res_indices,
        res_data,
    )
}

#[cfg(test)]
mod test {
    use crate::test_data;

    #[test]
    fn symbolic_and_numeric() {
        let a = test_data::mat1();
        let b = test_data::mat2();
        // a * b's structure:
        //                | x x x   x |
        //                | x     x   |
        //                |           |
        //                |     x x   |
        //                |   x x     |
        //
        // |     x x   |  |     x x   |
        // |       x x |  |   x x x   |
        // |     x     |  |           |
        // |   x       |  | x     x   |
        // |       x   |  |     x x   |
        let exp = test_data::mat1_matprod_mat2();

        let mut c_indptr = [0; 6];
        let mut c_indices = Vec::new();
        let mut seen = [false; 5];

        super::symbolic(
            a.structure_view(),
            b.structure_view(),
            &mut c_indptr,
            &mut c_indices,
            &mut seen,
        );

        let mut c_data = vec![0.; c_indices.len()];
        let mut tmp = [0.; 5];
        let mut c = crate::CsMatViewMutI::new_trusted(
            crate::CompressedStorage::CSR,
            (a.rows(), b.cols()),
            &c_indptr[..],
            &c_indices[..],
            &mut c_data[..],
        );
        super::numeric(a.view(), b.view(), c.view_mut(), &mut tmp);
        assert_eq!(exp.indptr(), &c_indptr[..]);
        assert_eq!(exp.indices(), &c_indices[..]);
        assert_eq!(exp.data(), &c_data[..]);
    }

    #[test]
    fn mul_csr_csr() {
        let a = test_data::mat1();
        let exp = test_data::mat1_self_matprod();
        let res = super::mul_csr_csr(a.view(), a.view());
        assert_eq!(exp, res);
    }

    #[test]
    fn mul_zero_rows() {
        // See https://github.com/vbarrielle/sprs/issues/239
        let a = crate::CsMat::new((0, 11), vec![0], vec![], vec![]);
        let b = crate::CsMat::new(
            (11, 11),
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            vec![],
            vec![],
        );
        let c: crate::CsMat<f64> = &a * &b;
        assert_eq!(c.rows(), 0);
        assert_eq!(c.cols(), 11);
        assert_eq!(c.nnz(), 0);
    }

    #[test]
    #[cfg(feature = "multi_thread")]
    fn mul_csr_csr_multithreaded() {
        let a = test_data::mat1();
        let exp = test_data::mat1_self_matprod();
        super::set_thread_threading_strategy(super::ThreadingStrategy::Fixed(
            4,
        ));
        let res = super::mul_csr_csr(a.view(), a.view());
        assert_eq!(exp, res);
    }

    #[test]
    #[cfg(feature = "multi_thread")]
    fn mul_csr_csr_one_long_row_multithreaded() {
        super::set_thread_threading_strategy(super::ThreadingStrategy::Fixed(
            4,
        ));
        let a = crate::CsVec::<f32>::empty(100);
        let b = crate::CsMat::<f32>::zero((100, 10)).to_csc();

        let _ = &a * &b;
    }

    #[test]
    fn mul_complex() {
        use num_complex::Complex32;
        // | 0  1 0   0  |
        // | 0  0 0   0  |
        // | i  0 0  1+i |
        // | 0  0 2i  0  |
        let a = crate::CsMat::new(
            (4, 4),
            vec![0, 1, 1, 3, 4],
            vec![1, 0, 3, 2],
            vec![
                Complex32::new(1., 0.),
                Complex32::new(0., 1.),
                Complex32::new(1., 1.),
                Complex32::new(0., 2.),
            ],
        );
        //                 | 0  1 0      0  |
        //                 | 0  0 0      0  |
        //                 | i  0 0     1+i |
        //                 | 0  0 2i     0  |
        //
        // | 0  1 0   0  | | 0  0   0    0  |
        // | 0  0 0   0  | | 0  0   0    0  |
        // | i  0 0  1+i | | 0  i -2+2i  0  |
        // | 0  0 2i  0  | |-2  0   0  -2+2i|
        let expected = crate::CsMat::new(
            (4, 4),
            vec![0, 0, 0, 2, 4],
            vec![1, 2, 0, 3],
            vec![
                Complex32::new(0., 1.),
                Complex32::new(-2., 2.),
                Complex32::new(-2., 0.),
                Complex32::new(-2., 2.),
            ],
        );
        let b = &a * &a;
        assert_eq!(b, expected);
    }
}