sparse_ir/gemm.rs
//! Matrix multiplication utilities with a pluggable BLAS backend
//!
//! This module provides thin wrappers around matrix multiplication operations,
//! with support for runtime selection of BLAS implementations.
//!
//! # Design
//! - **Default**: Pure Rust Faer backend (no external dependencies)
//! - **Optional**: External BLAS via function pointer injection
//! - **Thread-safe**: Global dispatcher protected by an RwLock
//!
//! # Example
//! ```ignore
//! use sparse_ir::gemm::{matmul_par, set_blas_backend};
//!
//! // Use the default Faer backend via the global dispatcher
//! let c = matmul_par(&a, &b, None);
//!
//! // Or inject a custom BLAS (e.g. from the C API)
//! unsafe {
//!     set_blas_backend(my_dgemm_ptr, my_zgemm_ptr);
//! }
//! let c = matmul_par(&a, &b, None); // Now uses the custom BLAS
//! ```

use mdarray::{DSlice, DTensor, DView, DViewMut, Layout};
use once_cell::sync::Lazy;
use std::sync::{Arc, RwLock};

#[cfg(feature = "system-blas")]
use blas_sys::dgemm_;

//==============================================================================
// BLAS Function Pointer Types
//==============================================================================

/// BLAS dgemm function pointer type (LP64: 32-bit integers)
///
/// The signature matches Fortran BLAS dgemm:
/// ```c
/// void dgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             double *alpha, double *a, int *lda, double *b, int *ldb,
///             double *beta, double *c, int *ldc);
/// ```
/// Note: all parameters are passed by reference (pointers).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
pub type DgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const libc::c_int,
    b: *const libc::c_double,
    ldb: *const libc::c_int,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const libc::c_int,
);

/// BLAS zgemm function pointer type (LP64: 32-bit integers)
///
/// The signature matches Fortran BLAS zgemm:
/// ```c
/// void zgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             void *alpha, void *a, int *lda, void *b, int *ldb,
///             void *beta, void *c, int *ldc);
/// ```
/// Note: all parameters are passed by reference (pointers).
/// Complex numbers are passed as void* (typically complex<double>*).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
pub type ZgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
);

// When using system BLAS via `blas-sys`, we need a small wrapper to adapt
// `blas_sys::zgemm_` (which uses `c_double_complex = [f64; 2]`) to the
// `ZgemmFnPtr` signature that takes `num_complex::Complex<f64>`.
#[cfg(feature = "system-blas")]
unsafe extern "C" fn zgemm_wrapper(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
) {
    // Safety: `blas_sys::c_double_complex` is defined as `[f64; 2]` and is
    // layout-compatible with `num_complex::Complex<f64>`, so casting between
    // the two pointer types is sound.
    unsafe {
        blas_sys::zgemm_(
            transa,
            transb,
            m,
            n,
            k,
            alpha as *const blas_sys::c_double_complex,
            a as *const blas_sys::c_double_complex,
            lda,
            b as *const blas_sys::c_double_complex,
            ldb,
            beta as *const blas_sys::c_double_complex,
            c as *mut blas_sys::c_double_complex,
            ldc,
        );
    }
}

/// BLAS dgemm function pointer type (ILP64: 64-bit integers)
///
/// The signature matches Fortran BLAS dgemm (ILP64):
/// ```c
/// void dgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             double *alpha, double *a, long long *lda, double *b, long long *ldb,
///             double *beta, double *c, long long *ldc);
/// ```
pub type Dgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const i64,
    b: *const libc::c_double,
    ldb: *const i64,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const i64,
);

/// BLAS zgemm function pointer type (ILP64: 64-bit integers)
///
/// The signature matches Fortran BLAS zgemm (ILP64):
/// ```c
/// void zgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             void *alpha, void *a, long long *lda, void *b, long long *ldb,
///             void *beta, void *c, long long *ldc);
/// ```
pub type Zgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const i64,
    b: *const num_complex::Complex<f64>,
    ldb: *const i64,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const i64,
);

//==============================================================================
// GemmBackend Trait
//==============================================================================

/// GEMM backend trait for runtime dispatch
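///
/// # Example
/// A minimal custom backend (an illustrative sketch only; a naive triple loop
/// stands in for a real GEMM kernel):
/// ```ignore
/// struct NaiveBackend;
///
/// impl GemmBackend for NaiveBackend {
///     unsafe fn dgemm(&self, m: usize, n: usize, k: usize,
///                     a: *const f64, b: *const f64, c: *mut f64) {
///         // Row-major C = A * B
///         for i in 0..m {
///             for j in 0..n {
///                 let mut acc = 0.0;
///                 for p in 0..k {
///                     acc += *a.add(i * k + p) * *b.add(p * n + j);
///                 }
///                 *c.add(i * n + j) = acc;
///             }
///         }
///     }
///     // zgemm is analogous; name() returns a descriptive label.
/// }
/// ```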
pub trait GemmBackend: Send + Sync {
    /// Matrix multiplication: C = A * B (f64)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions: (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Leading dimensions are derived internally from the row-major to
    /// column-major conversion.
    ///
    /// # Safety
    /// `a`, `b`, and `c` must point to valid, contiguous buffers of at least
    /// m*k, k*n, and m*n elements, respectively.
    unsafe fn dgemm(&self, m: usize, n: usize, k: usize, a: *const f64, b: *const f64, c: *mut f64);

    /// Matrix multiplication: C = A * B (Complex<f64>)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions: (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Leading dimensions are derived internally from the row-major to
    /// column-major conversion.
    ///
    /// # Safety
    /// `a`, `b`, and `c` must point to valid, contiguous buffers of at least
    /// m*k, k*n, and m*n elements, respectively.
    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    );

    /// Returns true if this backend uses 64-bit integers (ILP64)
    fn is_ilp64(&self) -> bool {
        false
    }

    /// Returns the backend name for debugging
    fn name(&self) -> &'static str;
}

//==============================================================================
// Faer Backend (Default, Pure Rust, Zero-Copy)
//==============================================================================

/// Default Faer backend (Pure Rust, no external dependencies)
///
/// This implementation uses faer's native API with raw pointer views,
/// achieving zero-copy matrix multiplication without intermediate allocations.
struct FaerBackend;

impl GemmBackend for FaerBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        use faer::linalg::matmul::matmul;
        use faer::mat::{MatMut, MatRef};
        use faer::{Accum, Par};

        // Create views directly from raw pointers (zero-copy).
        // Row-major layout: row_stride = number of columns, col_stride = 1.
        let lhs = unsafe { MatRef::from_raw_parts(a, m, k, k as isize, 1) };
        let rhs = unsafe { MatRef::from_raw_parts(b, k, n, n as isize, 1) };
        let mut dst = unsafe { MatMut::from_raw_parts_mut(c, m, n, n as isize, 1) };

        // In-place matrix multiplication (no intermediate allocations)
        matmul(&mut dst, Accum::Replace, &lhs, &rhs, 1.0, Par::Seq);
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        use faer::linalg::matmul::matmul;
        use faer::mat::{MatMut, MatRef};
        use faer::{Accum, Par};

        // Create views directly from raw pointers (zero-copy).
        // Row-major layout: row_stride = number of columns, col_stride = 1.
        let lhs = unsafe { MatRef::from_raw_parts(a, m, k, k as isize, 1) };
        let rhs = unsafe { MatRef::from_raw_parts(b, k, n, n as isize, 1) };
        let mut dst = unsafe { MatMut::from_raw_parts_mut(c, m, n, n as isize, 1) };

        // In-place matrix multiplication (no intermediate allocations)
        matmul(
            &mut dst,
            Accum::Replace,
            &lhs,
            &rhs,
            num_complex::Complex::new(1.0, 0.0),
            Par::Seq,
        );
    }

    fn name(&self) -> &'static str {
        "Faer (Pure Rust)"
    }
}

//==============================================================================
// External BLAS Backends (LP64 and ILP64)
//==============================================================================

// Conversion rules for row-major data to column-major BLAS:
//
// **Goal**: Compute C = A * B where:
// - A is m×k (row-major)
// - B is k×n (row-major)
// - C is m×n (row-major)
//
// **Row-major to column-major interpretation**:
// - Row-major A (m×k) appears as A^T (k×m) in column-major → call this At
// - Row-major B (k×n) appears as B^T (n×k) in column-major → call this Bt
// - Row-major C (m×n) appears as C^T (n×m) in column-major → call this Ct
// - To compute C = A * B, we need: C^T = (A * B)^T = B^T * A^T
// - So: Ct = Bt * At
//
// **BLAS call transformation**:
// - Original: C = A * B (row-major world)
// - BLAS call: Ct = Bt * At (column-major world)
// - transa = 'N' (Bt is already transposed-looking, no transpose needed)
// - transb = 'N' (At is already transposed-looking, no transpose needed)
// - Call: dgemm('N', 'N', n, m, k, alpha, B, lda, A, ldb, beta, C, ldc)
//
// **Dimension conversions**:
// - m_blas = n (Ct rows = Bt rows)
// - n_blas = m (Ct cols = At cols)
// - k_blas = k (common dimension)
// - lda = n (leading dimension of Bt: n×k in column-major, lda = n)
// - ldb = k (leading dimension of At: k×m in column-major, ldb = k)
// - ldc = n (leading dimension of Ct: n×m in column-major, ldc = n)
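//
// Worked example of the rules above (illustrative): multiplying a 2×4
// row-major A by a 4×3 row-major B into a 2×3 row-major C (m=2, n=3, k=4)
// maps to the column-major LP64 call
//     dgemm('N', 'N', 3, 2, 4, alpha, B, 3, A, 4, beta, C, 3)
// i.e. the operands are swapped and each leading dimension is the column
// count of the corresponding row-major matrix.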

/// External BLAS backend (LP64: 32-bit integers)
pub struct ExternalBlasBackend {
    dgemm: DgemmFnPtr,
    zgemm: ZgemmFnPtr,
}

impl ExternalBlasBackend {
    pub fn new(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) -> Self {
        Self { dgemm, zgemm }
    }
}

impl GemmBackend for ExternalBlasBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        // Validate that the dimensions fit in i32
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );

        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
        let k_i32 = k as i32; // k_blas = k (common dimension)
        let alpha = 1.0f64;
        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = 0.0f64;
        // For row-major C (m×n) viewed as column-major Ct (n×m): the leading
        // dimension in column-major is the stride between rows, which in
        // row-major is the number of columns, n.
        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.dgemm)(
                &transa,
                &transb,
                &m_i32,
                &n_i32,
                &k_i32,
                &alpha,
                b, // B first (Bt)
                &lda,
                a, // A second (At)
                &ldb,
                &beta,
                c,
                &ldc_i32,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );

        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
        let k_i32 = k as i32; // k_blas = k (common dimension)
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = num_complex::Complex::new(0.0, 0.0);
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension is the stride between rows = n.
        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.zgemm)(
                &transa,
                &transb,
                &m_i32,
                &n_i32,
                &k_i32,
                &alpha,
                b as *const _, // B first (Bt)
                &lda,
                a as *const _, // A second (At)
                &ldb,
                &beta,
                c as *mut _,
                &ldc_i32,
            );
        }
    }

    fn name(&self) -> &'static str {
        "External BLAS (LP64)"
    }
}

/// External BLAS backend (ILP64: 64-bit integers)
pub struct ExternalBlas64Backend {
    dgemm64: Dgemm64FnPtr,
    zgemm64: Zgemm64FnPtr,
}

impl ExternalBlas64Backend {
    pub fn new(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) -> Self {
        Self { dgemm64, zgemm64 }
    }
}

impl GemmBackend for ExternalBlas64Backend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
        let k_i64 = k as i64; // k_blas = k (common dimension)
        let alpha = 1.0f64;
        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = 0.0f64;
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension is the stride between rows = n.
        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.dgemm64)(
                &transa,
                &transb,
                &m_i64,
                &n_i64,
                &k_i64,
                &alpha,
                b, // B first (Bt)
                &lda,
                a, // A second (At)
                &ldb,
                &beta,
                c,
                &ldc_i64,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
        let k_i64 = k as i64; // k_blas = k (common dimension)
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = num_complex::Complex::new(0.0, 0.0);
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension is the stride between rows = n.
        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.zgemm64)(
                &transa,
                &transb,
                &m_i64,
                &n_i64,
                &k_i64,
                &alpha,
                b as *const _, // B first (Bt)
                &lda,
                a as *const _, // A second (At)
                &ldb,
                &beta,
                c as *mut _,
                &ldc_i64,
            );
        }
    }

    fn is_ilp64(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "External BLAS (ILP64)"
    }
}

//==============================================================================
// Backend Handle
//==============================================================================

/// Thread-safe handle to a GEMM backend
///
/// This type wraps an `Arc<dyn GemmBackend>` so that a backend can be shared
/// across multiple function calls without global state.
///
/// # Example
/// ```ignore
/// use sparse_ir::gemm::{matmul_par, GemmBackendHandle};
///
/// let backend = GemmBackendHandle::default();
/// let result = matmul_par(&a, &b, Some(&backend));
/// ```
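///
/// To route a specific call through an injected BLAS without touching the
/// global dispatcher (a sketch; `my_dgemm` and `my_zgemm` are placeholder
/// function pointers):
/// ```ignore
/// let blas = ExternalBlasBackend::new(my_dgemm, my_zgemm);
/// let handle = GemmBackendHandle::new(Box::new(blas));
/// let c = matmul_par(&a, &b, Some(&handle));
/// ```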
#[derive(Clone)]
pub struct GemmBackendHandle {
    inner: Arc<dyn GemmBackend>,
}

impl GemmBackendHandle {
    /// Create a new backend handle from a boxed backend
    pub fn new(backend: Box<dyn GemmBackend>) -> Self {
        Self {
            inner: Arc::from(backend),
        }
    }

    /// Get a reference to the inner backend
    pub(crate) fn as_ref(&self) -> &dyn GemmBackend {
        self.inner.as_ref()
    }
}

/// The default backend handle wraps the pure Rust Faer backend.
impl Default for GemmBackendHandle {
    fn default() -> Self {
        Self {
            inner: Arc::new(FaerBackend),
        }
    }
}

//==============================================================================
// Global Dispatcher (for backward compatibility)
//==============================================================================

/// Global BLAS dispatcher (thread-safe)
///
/// This is kept for backward compatibility when `None` is passed as the
/// backend. New code should use `GemmBackendHandle` explicitly.
static BLAS_DISPATCHER: Lazy<RwLock<Box<dyn GemmBackend>>> = Lazy::new(|| {
    #[cfg(feature = "system-blas")]
    {
        // Use system BLAS (LP64) by default via `blas-sys`.
        let backend = ExternalBlasBackend::new(dgemm_ as DgemmFnPtr, zgemm_wrapper as ZgemmFnPtr);
        RwLock::new(Box::new(backend) as Box<dyn GemmBackend>)
    }
    #[cfg(not(feature = "system-blas"))]
    {
        // Default to the pure Rust Faer backend.
        RwLock::new(Box::new(FaerBackend) as Box<dyn GemmBackend>)
    }
});

/// Set the BLAS backend (LP64: 32-bit integers)
///
/// # Safety
/// - Function pointers must be valid and thread-safe
/// - Must remain valid for the lifetime of the program
/// - Must follow the Fortran BLAS calling convention
///
/// # Example
/// ```ignore
/// unsafe {
///     set_blas_backend(dgemm_ as _, zgemm_ as _);
/// }
/// ```
pub unsafe fn set_blas_backend(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) {
    let backend = ExternalBlasBackend::new(dgemm, zgemm);
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(backend);
}

/// Set the ILP64 BLAS backend (64-bit integers)
///
/// # Safety
/// - Function pointers must be valid, thread-safe, and use 64-bit integers
/// - Must remain valid for the lifetime of the program
/// - Must follow the Fortran BLAS calling convention with the ILP64 interface
///
/// # Example
/// ```ignore
/// unsafe {
///     set_ilp64_backend(dgemm_ as _, zgemm_ as _);
/// }
/// ```
pub unsafe fn set_ilp64_backend(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) {
    let backend = ExternalBlas64Backend::new(dgemm64, zgemm64);
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(backend);
}

/// Clear the BLAS backend (reset to the default Faer backend)
///
/// This function resets the GEMM dispatcher to the default pure Rust Faer backend.
pub fn clear_blas_backend() {
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(FaerBackend);
}

/// Get information about the current BLAS backend
///
/// Returns:
/// - `(backend_name, is_external, is_ilp64)`
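///
/// # Example
/// ```ignore
/// let (name, is_external, is_ilp64) = get_backend_info();
/// println!("GEMM backend: {name} (external: {is_external}, ILP64: {is_ilp64})");
/// ```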
pub fn get_backend_info() -> (&'static str, bool, bool) {
    let dispatcher = BLAS_DISPATCHER.read().unwrap();
    let name = dispatcher.name();
    let is_external = !name.contains("Faer");
    let is_ilp64 = dispatcher.is_ilp64();
    (name, is_external, is_ilp64)
}

//==============================================================================
// Public API
//==============================================================================

/// Parallel matrix multiplication: C = A * B
///
/// Dispatches to the provided backend, or to the global dispatcher if `None`.
///
/// # Arguments
/// * `a` - Left matrix (M x K)
/// * `b` - Right matrix (K x N)
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher (for backward compatibility)
///
/// # Returns
/// Result matrix (M x N)
///
/// # Panics
/// Panics if matrix dimensions are incompatible (A.cols != B.rows)
///
/// # Example
/// ```ignore
/// use mdarray::tensor;
/// use sparse_ir::gemm::{matmul_par, GemmBackendHandle};
///
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let backend = GemmBackendHandle::default();
/// let c = matmul_par(&a, &b, Some(&backend));
/// // c = [[19.0, 22.0], [43.0, 50.0]]
/// ```
pub fn matmul_par<T>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );

    // Allocate the M x N result and delegate to the overwrite variant, which
    // dispatches to the selected backend without intermediate copies.
    let mut result = DTensor::<T, 2>::from_elem([m, n], T::zero().into());
    matmul_par_overwrite(a, b, &mut result, backend);
    result
}

/// Parallel matrix multiplication accepting DView (assumes contiguous memory)
///
/// This function accepts `DView` instead of `DTensor`, allowing views of arrays
/// to be used directly without copying. The views must have contiguous memory layout.
///
/// # Arguments
/// * `a` - Left matrix view (M x K) - must be contiguous
/// * `b` - Right matrix view (K x N) - must be contiguous
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher
///
/// # Panics
/// Panics if:
/// - Matrix dimensions are incompatible (A.cols != B.rows)
/// - Views are not contiguous in memory
///
/// # Example
/// ```ignore
/// use mdarray::{tensor, DView};
///
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let a_view: DView<'_, f64, 2> = a.view(.., ..);
/// let b_view: DView<'_, f64, 2> = b.view(.., ..);
/// let c = matmul_par_view(&a_view, &b_view, None);
/// ```
pub fn matmul_par_view<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    // Check that the views are contiguous (required for BLAS operations)
    assert!(
        a.is_contiguous(),
        "Matrix A view must be contiguous in memory"
    );
    assert!(
        b.is_contiguous(),
        "Matrix B view must be contiguous in memory"
    );

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );

    // Create the result tensor
    let mut result = DTensor::<T, 2>::from_elem([m, n], T::zero().into());
    matmul_par_overwrite_view(a, b, &mut result, backend);
    result
}

/// Parallel matrix multiplication with overwrite accepting DView (assumes contiguous memory)
///
/// This function writes the result directly into the provided buffer `c`,
/// accepting `DView` inputs. The views must have contiguous memory layout.
///
/// # Arguments
/// * `a` - Left matrix view (M x K) - must be contiguous
/// * `b` - Right matrix view (K x N) - must be contiguous
/// * `c` - Output matrix (M x N) - will be overwritten with the result
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher
///
/// # Panics
/// Panics if:
/// - Matrix dimensions are incompatible
/// - Views are not contiguous in memory
pub fn matmul_par_overwrite_view<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    c: &mut DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    // Check that the views are contiguous (required for BLAS operations)
    assert!(
        a.is_contiguous(),
        "Matrix A view must be contiguous in memory"
    );
    assert!(
        b.is_contiguous(),
        "Matrix B view must be contiguous in memory"
    );

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case: get pointers from the views (contiguous memory assumed)
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        // Get the backend: use the provided handle or fall back to the global dispatcher
        match backend {
            Some(handle) => unsafe {
                handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        match backend {
            Some(handle) => unsafe {
                handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback to Faer for unsupported types.
        // Convert the views to DTensors for Faer (this copies, but only for unsupported types).
        let a_tensor = DTensor::<T, 2>::from_fn(*a.shape(), |idx| a[idx]);
        let b_tensor = DTensor::<T, 2>::from_fn(*b.shape(), |idx| b[idx]);
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(&a_tensor, &b_tensor).parallelize().overwrite(c);
    }
}

/// Parallel matrix multiplication with overwrite to a mutable view: C = A * B
///
/// This function writes the result directly into the provided mutable view `c`,
/// allowing zero-copy writes to pre-allocated buffers (e.g., C pointers via FFI).
///
/// # Arguments
/// * `a` - Left matrix view (M x K), must be contiguous
/// * `b` - Right matrix view (K x N), must be contiguous
/// * `c` - Output mutable view (M x N), must be contiguous - will be overwritten
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher
///
/// # Panics
/// Panics if:
/// - Matrix dimensions are incompatible
/// - Views are not contiguous in memory
pub fn matmul_par_to_viewmut<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    c: &mut DViewMut<'_, T, 2>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    // Check that the views are contiguous (required for BLAS operations)
    assert!(
        a.is_contiguous(),
        "Matrix A view must be contiguous in memory"
    );
    assert!(
        b.is_contiguous(),
        "Matrix B view must be contiguous in memory"
    );
    assert!(
        c.is_contiguous(),
        "Matrix C view must be contiguous in memory"
    );

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        match backend {
            Some(handle) => unsafe {
                handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        match backend {
            Some(handle) => unsafe {
                handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback: convert to DTensors (this copies)
        let a_tensor = DTensor::<T, 2>::from_fn(*a.shape(), |idx| a[idx]);
        let b_tensor = DTensor::<T, 2>::from_fn(*b.shape(), |idx| b[idx]);
        let mut c_tensor = DTensor::<T, 2>::from_fn(*c.shape(), |_| T::zero());
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(&a_tensor, &b_tensor)
            .parallelize()
            .overwrite(&mut c_tensor);

        // Copy the result back into the view
        for i in 0..mc {
            for j in 0..nc {
                c[[i, j]] = c_tensor[[i, j]];
            }
        }
    }
}

/// Parallel matrix multiplication with overwrite: C = A * B (writes to an existing buffer)
///
/// This function writes the result directly into the provided buffer `c`,
/// avoiding memory allocation. This is more memory-efficient for repeated operations.
///
/// # Arguments
/// * `a` - Left matrix (M x K)
/// * `b` - Right matrix (K x N)
/// * `c` - Output matrix (M x N) - will be overwritten with the result
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher (for backward compatibility)
///
/// # Panics
/// Panics if matrix dimensions are incompatible (A.cols != B.rows or C.shape != [M, N])
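///
/// # Example
/// A sketch of buffer reuse across repeated multiplications (`batches` and
/// `consume` are placeholders):
/// ```ignore
/// let mut c = DTensor::<f64, 2>::from_elem([m, n], 0.0);
/// for (a, b) in batches {
///     matmul_par_overwrite(&a, &b, &mut c, None);
///     consume(&c);
/// }
/// ```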
pub fn matmul_par_overwrite<T, Lc: Layout>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    c: &mut DSlice<T, 2, Lc>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case: get pointers directly from the DTensors (row-major order)
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        // Get the backend: use the provided handle or fall back to the global dispatcher
        match backend {
            Some(handle) => {
                // Call the backend directly with pointers (no temporary buffer needed);
                // leading dimensions are calculated inside the backend.
                unsafe {
                    handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
            None => {
                // Backward compatibility: use the global dispatcher
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case: get pointers directly from the DTensors (row-major order)
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        // Get the backend: use the provided handle or fall back to the global dispatcher
        match backend {
            Some(handle) => {
                // Call the backend directly with pointers (no temporary buffer needed);
                // leading dimensions are calculated inside the backend.
                unsafe {
                    handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
            None => {
                // Backward compatibility: use the global dispatcher
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback to Faer for unsupported types
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(a, b).parallelize().overwrite(c);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mdarray::DView;

    #[test]
    #[cfg(not(feature = "system-blas"))]
    fn test_default_backend_is_faer() {
        let (name, is_external, is_ilp64) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
        assert!(!is_external);
        assert!(!is_ilp64);
    }

    #[test]
    fn test_matmul_par_view() {
        // Test with f64
        let a = DTensor::<f64, 2>::from([[1.0, 2.0], [3.0, 4.0]]);
        let b = DTensor::<f64, 2>::from([[5.0, 6.0], [7.0, 8.0]]);
        let a_view: DView<'_, f64, 2> = a.view(.., ..);
        let b_view: DView<'_, f64, 2> = b.view(.., ..);

        let c_view = matmul_par_view(&a_view, &b_view, None);
        let c_expected = matmul_par(&a, &b, None);

        // Results should be identical
        assert_eq!(c_view.shape(), c_expected.shape());
        for i in 0..c_view.shape().0 {
            for j in 0..c_view.shape().1 {
                assert!((c_view[[i, j]] - c_expected[[i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_matmul_par_overwrite_view() {
        // Test with Complex<f64>
        use num_complex::Complex;
        let a = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| {
            Complex::new((idx[0] * 2 + idx[1]) as f64, 0.0)
        });
        let b = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| {
            Complex::new((idx[0] * 2 + idx[1] + 10) as f64, 0.0)
        });
        let a_view: DView<'_, Complex<f64>, 2> = a.view(.., ..);
        let b_view: DView<'_, Complex<f64>, 2> = b.view(.., ..);

        let mut c_view = DTensor::<Complex<f64>, 2>::from_elem([2, 2], Complex::new(0.0, 0.0));
        matmul_par_overwrite_view(&a_view, &b_view, &mut c_view, None);

        let c_expected = matmul_par(&a, &b, None);

        // Results should be identical
        assert_eq!(c_view.shape(), c_expected.shape());
        for i in 0..c_view.shape().0 {
            for j in 0..c_view.shape().1 {
                assert!((c_view[[i, j]] - c_expected[[i, j]]).norm() < 1e-10);
            }
        }
    }

    #[test]
    fn test_clear_backend() {
        // Should not panic
        clear_blas_backend();
        let (name, _, _) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
    }

    #[test]
    fn test_matmul_f64() {
        let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];

        let a = DTensor::<f64, 2>::from_fn([2, 3], |idx| a_data[idx[0] * 3 + idx[1]]);
        let b = DTensor::<f64, 2>::from_fn([3, 2], |idx| b_data[idx[0] * 2 + idx[1]]);
        let c = matmul_par(&a, &b, None);

        assert_eq!(*c.shape(), (2, 2));
        // First row: [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12] = [58, 64]
        // Second row: [4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12] = [139, 154]
        assert!((c[[0, 0]] - 58.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 64.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 139.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 154.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_basic() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
        let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        //         = [[19, 22], [43, 50]]
        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_non_square() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2x3
        let b: DTensor<f64, 2> = tensor![[7.0], [8.0], [9.0]]; // 3x1
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*7 + 2*8 + 3*9], [4*7 + 5*8 + 6*9]]
        //         = [[50], [122]]
        assert!((c[[0, 0]] - 50.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 122.0).abs() < 1e-10);
    }
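
    #[test]
    fn test_matmul_par_to_viewmut() {
        // A sketch exercising the zero-copy output path; this assumes
        // `view_mut(.., ..)` yields a contiguous `DViewMut`, mirroring the
        // `view(.., ..)` calls used above.
        let a = DTensor::<f64, 2>::from([[1.0, 2.0], [3.0, 4.0]]);
        let b = DTensor::<f64, 2>::from([[5.0, 6.0], [7.0, 8.0]]);
        let mut c = DTensor::<f64, 2>::from_elem([2, 2], 0.0);
        {
            let a_view = a.view(.., ..);
            let b_view = b.view(.., ..);
            let mut c_view = c.view_mut(.., ..);
            matmul_par_to_viewmut(&a_view, &b_view, &mut c_view, None);
        }
        // Same expected values as test_matmul_par_basic
        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }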
}