use crate::indexing::SpIndex;
use crate::sparse::prelude::*;
use crate::sparse::CompressedStorage::CSR;
#[cfg(feature = "multi_thread")]
use rayon::prelude::*;

#[cfg(feature = "multi_thread")]
use std::cell::RefCell;

/// The threading strategy used for sparse matrix products.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg(feature = "multi_thread")]
pub enum ThreadingStrategy {
    /// Pick the number of threads from the workload size and the number of
    /// logical cores.
    Automatic,
    /// Pick the number of threads from the workload size and the number of
    /// physical cores.
    AutomaticPhysical,
    /// Use the requested number of threads (capped by the number of rows of
    /// the left operand).
    Fixed(usize),
}

#[cfg(feature = "multi_thread")]
thread_local! {
    static THREADING_STRAT: RefCell<ThreadingStrategy> =
        const { RefCell::new(ThreadingStrategy::Automatic) };
}

/// Set the threading strategy used for matrix products in this thread.
///
/// # Panics
///
/// Panics if a fixed strategy of 0 threads is requested.
#[cfg(feature = "multi_thread")]
pub fn set_thread_threading_strategy(strategy: ThreadingStrategy) {
    if let ThreadingStrategy::Fixed(nb_threads) = strategy {
        assert!(nb_threads > 0);
    }
    THREADING_STRAT.with(|s| {
        *s.borrow_mut() = strategy;
    });
}

/// Get the threading strategy used for matrix products in this thread.
#[cfg(feature = "multi_thread")]
pub fn thread_threading_strategy() -> ThreadingStrategy {
    THREADING_STRAT.with(|s| *s.borrow())
}

/// Compute the symbolic structure (indptr and indices) of the matrix product
/// `C = A * B`, where all matrices are in CSR format.
///
/// `c_indptr` must have length `a.rows() + 1`, and `seen` is a workspace of
/// length `b.cols()`. `c_indices` is cleared and filled with the column
/// indices of the product, sorted within each row.
pub fn symbolic<Iptr: SpIndex, I: SpIndex>(
    a: CsStructureViewI<I, Iptr>,
    b: CsStructureViewI<I, Iptr>,
    c_indptr: &mut [Iptr],
    c_indices: &mut Vec<I>,
    seen: &mut [bool],
) {
    assert!(a.indptr().len() == c_indptr.len());
    let a_nnz = a.nnz();
    let b_nnz = b.nnz();
    c_indices.clear();
    c_indices.reserve_exact(a_nnz + b_nnz);

    assert_eq!(a.cols(), b.rows());
    assert!(seen.len() == b.cols());
    for elt in seen.iter_mut() {
        *elt = false;
    }

    c_indptr[0] = Iptr::from_usize(0);
    for (a_row, a_range) in a.indptr().iter_outer_sz().enumerate() {
        let mut length = 0;

        // Each nonzero column in the current row of `a` selects a row of `b`;
        // the columns of that row of `b` are the candidate nonzeros of the
        // current row of `c`. `seen` ensures each column is recorded once.
        for &a_col in &a.indices()[a_range] {
            let b_row = a_col.index();
            let b_range = b.indptr().outer_inds_sz(b_row);
            for b_col in &b.indices()[b_range] {
                let b_col = b_col.index();
                if !seen[b_col] {
                    seen[b_col] = true;
                    c_indices.push(I::from_usize(b_col));
                    length += 1;
                }
            }
        }
        c_indptr[a_row + 1] = c_indptr[a_row] + Iptr::from_usize(length);
        let c_start = c_indptr[a_row].index();
        let c_end = c_start + length;
        c_indices[c_start..c_end].sort_unstable();
        // Reset the workspace for the next row, touching only the entries
        // that were set for this row.
        for c_col in &c_indices[c_start..c_end] {
            seen[c_col.index()] = false;
        }
    }
}

/// Fill the values of the matrix product `C = A * B`, whose sparsity
/// structure must already be computed (e.g. by `symbolic`). All matrices are
/// in CSR format.
///
/// `tmp` is a dense workspace of length `b.cols()`.
pub fn numeric<
    Iptr: SpIndex,
    I: SpIndex,
    A,
    B,
    N: crate::MulAcc<A, B> + num_traits::Zero,
>(
    a: CsMatViewI<A, I, Iptr>,
    b: CsMatViewI<B, I, Iptr>,
    mut c: CsMatViewMutI<N, I, Iptr>,
    tmp: &mut [N],
) {
    assert_eq!(a.rows(), c.rows());
    assert_eq!(a.cols(), b.rows());
    assert_eq!(b.cols(), c.cols());
    assert_eq!(tmp.len(), b.cols());
    assert!(a.is_csr());
    assert!(b.is_csr());

    for elt in tmp.iter_mut() {
        *elt = N::zero();
    }
    for (a_row, mut c_row) in a.outer_iterator().zip(c.outer_iterator_mut()) {
        for (a_col, a_val) in a_row.iter() {
            let b_row = b.outer_view(a_col.index()).unwrap();
            // Accumulate `a_val * b_row` into the dense workspace.
            for (b_col, b_val) in b_row.iter() {
                tmp[b_col.index()].mul_acc(a_val, b_val);
            }
        }
        // Scatter the accumulated row into `c`, resetting the workspace to
        // zero along the way.
        for (c_col, c_val) in c_row.iter_mut() {
            let mut val = N::zero();
            std::mem::swap(&mut val, &mut tmp[c_col]);
            *c_val = val;
        }
    }
}

/// Compute the product of two CSR matrices, allocating the workspaces and
/// choosing the number of threads according to the current
/// `ThreadingStrategy` (when the `multi_thread` feature is enabled).
pub fn mul_csr_csr<N, A, B, I, Iptr>(
    lhs: CsMatViewI<A, I, Iptr>,
    rhs: CsMatViewI<B, I, Iptr>,
) -> CsMatI<N, I, Iptr>
where
    N: crate::MulAcc<A, B> + num_traits::Zero + Clone + Send + Sync,
    A: Send + Sync,
    B: Send + Sync,
    I: SpIndex,
    Iptr: SpIndex,
{
    assert_eq!(lhs.cols(), rhs.rows());
    let workspace_len = rhs.cols();
    #[cfg(feature = "multi_thread")]
    let nb_threads = std::cmp::min(lhs.rows().max(1), {
        use self::ThreadingStrategy::{Automatic, AutomaticPhysical};
        match thread_threading_strategy() {
            ThreadingStrategy::Fixed(nb_threads) => nb_threads,
            strat @ Automatic | strat @ AutomaticPhysical => {
                let nb_cpus = if strat == ThreadingStrategy::Automatic {
                    num_cpus::get()
                } else {
                    num_cpus::get_physical()
                };
                let ideal_chunk_size = 8128;
                let wanted_threads = (lhs.nnz() + rhs.nnz()) / ideal_chunk_size;
                // Use at least one thread, but never more than the cpu count.
                #[allow(clippy::manual_clamp)]
                1.max(wanted_threads).min(nb_cpus)
            }
        }
    });
    #[cfg(not(feature = "multi_thread"))]
    let nb_threads = 1;
    let mut tmps = Vec::with_capacity(nb_threads);
    for _ in 0..nb_threads {
        tmps.push(vec![N::zero(); workspace_len].into_boxed_slice());
    }
    let mut seens =
        vec![vec![false; workspace_len].into_boxed_slice(); nb_threads];
    mul_csr_csr_with_workspace(lhs, rhs, &mut seens, &mut tmps)
}

/// Compute the product of two CSR matrices using the provided workspaces.
///
/// The number of `seen` workspaces controls how many chunks are used for the
/// symbolic pass, and the number of `tmp` workspaces controls how many chunks
/// are used for the numeric pass; with the `multi_thread` feature these
/// chunks are processed in parallel. Every workspace must have length
/// `rhs.cols()`.
pub fn mul_csr_csr_with_workspace<N, A, B, I, Iptr>(
    lhs: CsMatViewI<A, I, Iptr>,
    rhs: CsMatViewI<B, I, Iptr>,
    seens: &mut [Box<[bool]>],
    tmps: &mut [Box<[N]>],
) -> CsMatI<N, I, Iptr>
where
    N: crate::MulAcc<A, B> + num_traits::Zero + Clone + Send + Sync,
    A: Send + Sync,
    B: Send + Sync,
    I: SpIndex,
    Iptr: SpIndex,
{
    let workspace_len = rhs.cols();
    assert_eq!(lhs.cols(), rhs.rows());
    assert!(seens.iter().all(|x| x.len() == workspace_len));
    assert!(tmps.iter().all(|x| x.len() == workspace_len));
    let indptr_len = lhs.rows() + 1;
    let mut res_indices = Vec::new();
    let nb_threads = seens.len();
    assert!(nb_threads > 0);
    let chunk_size = lhs.indptr().len() / nb_threads;
    let mut lhs_chunks = Vec::with_capacity(nb_threads);
    let mut res_indptr_chunks = Vec::with_capacity(nb_threads);
    let mut res_indices_chunks = Vec::with_capacity(nb_threads);
    // Split the rows of `lhs` into roughly equal chunks, one per `seen`
    // workspace, for the symbolic pass.
    for chunk_id in 0..nb_threads {
        let start = if chunk_id == 0 {
            0
        } else {
            chunk_id * chunk_size
        };
        let stop = if chunk_id + 1 < nb_threads {
            (chunk_id + 1) * chunk_size
        } else {
            lhs.rows()
        };
        lhs_chunks.push(lhs.slice_outer(start..stop));
        res_indptr_chunks.push(vec![Iptr::zero(); stop - start + 1]);
        res_indices_chunks
            .push(Vec::with_capacity(lhs.nnz() + rhs.nnz() / chunk_size));
    }
    #[cfg(feature = "multi_thread")]
    let iter = lhs_chunks
        .par_iter()
        .zip(res_indptr_chunks.par_iter_mut())
        .zip(res_indices_chunks.par_iter_mut())
        .zip(seens.par_iter_mut());
    #[cfg(not(feature = "multi_thread"))]
    let iter = lhs_chunks
        .iter()
        .zip(res_indptr_chunks.iter_mut())
        .zip(res_indices_chunks.iter_mut())
        .zip(seens.iter_mut());
    iter.for_each(
        |(((lhs_chunk, res_indptr_chunk), res_indices_chunk), seen)| {
            symbolic(
                lhs_chunk.structure_view(),
                rhs.structure_view(),
                res_indptr_chunk,
                res_indices_chunk,
                seen,
            );
        },
    );
    // Gather the per-chunk column indices into a single buffer.
    res_indices.reserve(res_indices_chunks.iter().map(Vec::len).sum());
    for res_indices_chunk in &res_indices_chunks {
        res_indices.extend_from_slice(res_indices_chunk);
    }
    // Stitch the per-chunk index pointers (each starting at 0) into a single
    // global `indptr`.
    let mut res_indptr = Vec::with_capacity(indptr_len);
    res_indptr.push(Iptr::zero());
    for res_indptr_chunk in &res_indptr_chunks {
        for row in res_indptr_chunk.windows(2) {
            let nnz = row[1] - row[0];
            res_indptr.push(nnz + *res_indptr.last().unwrap());
        }
    }
    let mut res_data = vec![N::zero(); res_indices.len()];
    let nb_threads = tmps.len();
    assert!(nb_threads > 0);
    let chunk_size = res_indices.len() / nb_threads;
    let mut res_indices_rem = &res_indices[..];
    let mut res_data_rem = &mut res_data[..];
    let mut prev_nnz = 0;
    let mut split_nnz = 0;
    let mut split_row = 0;
    let mut lhs_chunks = Vec::with_capacity(nb_threads);
    let mut res_indptr_chunks = Vec::with_capacity(nb_threads);
    let mut res_indices_chunks = Vec::with_capacity(nb_threads);
    let mut res_data_chunks = Vec::with_capacity(nb_threads);
    // Split the rows for the numeric pass so that each chunk holds roughly
    // the same number of nonzeros of the result.
    for (row, nnz) in res_indptr.iter().enumerate() {
        let nnz = nnz.index();
        if nnz - split_nnz > chunk_size && row > 0 {
            lhs_chunks.push(lhs.slice_outer(split_row..row - 1));

            res_indptr_chunks.push(&res_indptr[split_row..row]);

            let (left, right) = res_indices_rem
                .split_at(prev_nnz - res_indptr[split_row].index());
            res_indices_chunks.push(left);
            res_indices_rem = right;

            let (left, right) = res_data_rem
                .split_at_mut(prev_nnz - res_indptr[split_row].index());
            res_data_chunks.push(left);
            res_data_rem = right;

            split_nnz = nnz;
            split_row = row - 1;
        }
        prev_nnz = nnz;
    }
    lhs_chunks.push(lhs.slice_outer(split_row..lhs.rows()));
    res_indptr_chunks.push(&res_indptr[split_row..]);
    res_indices_chunks.push(res_indices_rem);
    res_data_chunks.push(res_data_rem);
    #[cfg(feature = "multi_thread")]
    let iter = lhs_chunks
        .par_iter()
        .zip(res_indptr_chunks.par_iter())
        .zip(res_indices_chunks.par_iter())
        .zip(res_data_chunks.par_iter_mut())
        .zip(tmps.par_iter_mut());
    #[cfg(not(feature = "multi_thread"))]
    let iter = lhs_chunks
        .iter()
        .zip(res_indptr_chunks.iter())
        .zip(res_indices_chunks.iter())
        .zip(res_data_chunks.iter_mut())
        .zip(tmps.iter_mut());
    iter.for_each(
        |(
            (
                ((lhs_chunk, res_indptr_chunk), res_indices_chunk),
                res_data_chunk,
            ),
            tmp,
        )| {
            let res_chunk = CsMatViewMutI::new_trusted(
                CSR,
                (lhs_chunk.rows(), rhs.cols()),
                res_indptr_chunk,
                res_indices_chunk,
                res_data_chunk,
            );
            numeric(lhs_chunk.view(), rhs.view(), res_chunk, tmp);
        },
    );

    // The structure was produced by `symbolic` (sorted indices, consistent
    // indptr) and filled by `numeric`, so no validation is needed here.
    CsMatI::new_trusted(
        CSR,
        (lhs.rows(), rhs.cols()),
        res_indptr,
        res_indices,
        res_data,
    )
}

#[cfg(test)]
mod test {
    use crate::test_data;

    #[test]
    fn symbolic_and_numeric() {
        let a = test_data::mat1();
        let b = test_data::mat2();
        let exp = test_data::mat1_matprod_mat2();

        let mut c_indptr = [0; 6];
        let mut c_indices = Vec::new();
        let mut seen = [false; 5];

        super::symbolic(
            a.structure_view(),
            b.structure_view(),
            &mut c_indptr,
            &mut c_indices,
            &mut seen,
        );

        let mut c_data = vec![0.; c_indices.len()];
        let mut tmp = [0.; 5];
        let mut c = crate::CsMatViewMutI::new_trusted(
            crate::CompressedStorage::CSR,
            (a.rows(), b.cols()),
            &c_indptr[..],
            &c_indices[..],
            &mut c_data[..],
        );
        super::numeric(a.view(), b.view(), c.view_mut(), &mut tmp);
        assert_eq!(exp.indptr(), &c_indptr[..]);
        assert_eq!(exp.indices(), &c_indices[..]);
        assert_eq!(exp.data(), &c_data[..]);
    }
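
    // Added check (sketch): the high-level `mul_csr_csr` entry point should
    // agree with the `mat1 * mat2` fixture used by the symbolic/numeric test
    // above.
    #[test]
    fn mul_csr_csr_mat1_mat2() {
        let a = test_data::mat1();
        let b = test_data::mat2();
        let exp = test_data::mat1_matprod_mat2();
        let res = super::mul_csr_csr(a.view(), b.view());
        assert_eq!(exp, res);
    }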

    #[test]
    fn mul_csr_csr() {
        let a = test_data::mat1();
        let exp = test_data::mat1_self_matprod();
        let res = super::mul_csr_csr(a.view(), a.view());
        assert_eq!(exp, res);
    }
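
    // Added check (sketch): drive `mul_csr_csr_with_workspace` directly with
    // two explicitly provided workspaces, so the chunked symbolic/numeric
    // path is exercised even without the `multi_thread` feature.
    #[test]
    fn mul_csr_csr_with_workspace_two_chunks() {
        let a = test_data::mat1();
        let exp = test_data::mat1_self_matprod();
        let workspace_len = a.cols();
        let mut seens = vec![vec![false; workspace_len].into_boxed_slice(); 2];
        let mut tmps = vec![vec![0.; workspace_len].into_boxed_slice(); 2];
        let res = super::mul_csr_csr_with_workspace(
            a.view(),
            a.view(),
            &mut seens,
            &mut tmps,
        );
        assert_eq!(exp, res);
    }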

    #[test]
    fn mul_zero_rows() {
        let a = crate::CsMat::new((0, 11), vec![0], vec![], vec![]);
        let b = crate::CsMat::new(
            (11, 11),
            vec![0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            vec![],
            vec![],
        );
        let c: crate::CsMat<f64> = &a * &b;
        assert_eq!(c.rows(), 0);
        assert_eq!(c.cols(), 11);
        assert_eq!(c.nnz(), 0);
    }

    #[test]
    #[cfg(feature = "multi_thread")]
    fn mul_csr_csr_multithreaded() {
        let a = test_data::mat1();
        let exp = test_data::mat1_self_matprod();
        super::set_thread_threading_strategy(super::ThreadingStrategy::Fixed(
            4,
        ));
        let res = super::mul_csr_csr(a.view(), a.view());
        assert_eq!(exp, res);
    }

    #[test]
    #[cfg(feature = "multi_thread")]
    fn mul_csr_csr_one_long_row_multithreaded() {
        super::set_thread_threading_strategy(super::ThreadingStrategy::Fixed(
            4,
        ));
        let a = crate::CsVec::<f32>::empty(100);
        let b = crate::CsMat::<f32>::zero((100, 10)).to_csc();

        let _ = &a * &b;
    }
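
    // Added check (sketch): the automatic strategies should also reproduce
    // the reference result; this exercises the physical-core based thread
    // count selection.
    #[test]
    #[cfg(feature = "multi_thread")]
    fn mul_csr_csr_automatic_physical() {
        super::set_thread_threading_strategy(
            super::ThreadingStrategy::AutomaticPhysical,
        );
        let a = test_data::mat1();
        let exp = test_data::mat1_self_matprod();
        let res = super::mul_csr_csr(a.view(), a.view());
        assert_eq!(exp, res);
    }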

    #[test]
    fn mul_complex() {
        use num_complex::Complex32;
        let a = crate::CsMat::new(
            (4, 4),
            vec![0, 1, 1, 3, 4],
            vec![1, 0, 3, 2],
            vec![
                Complex32::new(1., 0.),
                Complex32::new(0., 1.),
                Complex32::new(1., 1.),
                Complex32::new(0., 2.),
            ],
        );
        let expected = crate::CsMat::new(
            (4, 4),
            vec![0, 0, 0, 2, 4],
            vec![1, 2, 0, 3],
            vec![
                Complex32::new(0., 1.),
                Complex32::new(-2., 2.),
                Complex32::new(-2., 0.),
                Complex32::new(-2., 2.),
            ],
        );
        let b = &a * &a;
        assert_eq!(b, expected);
    }
}