sparse_ir/
gemm.rs

1//! Matrix multiplication utilities with pluggable BLAS backend
2//!
3//! This module provides thin wrappers around matrix multiplication operations,
4//! with support for runtime selection of BLAS implementations.
5//!
6//! # Design
7//! - **Default**: Pure Rust Faer backend (no external dependencies)
8//! - **Optional**: External BLAS via function pointer injection
9//! - **Thread-safe**: Global dispatcher protected by RwLock
10//!
11//! # Example
12//! ```ignore
13//! use sparse_ir::gemm::{matmul_par, set_blas_backend};
14//!
//! // Use the global dispatcher by passing `None` as the backend
//! let c = matmul_par(&a, &b, None);
//!
//! // Or inject custom BLAS (from C-API)
//! unsafe {
//!     set_blas_backend(my_dgemm_ptr, my_zgemm_ptr);
//! }
//! let c = matmul_par(&a, &b, None);  // Now uses custom BLAS
23//! ```
24
25use mdarray::{DSlice, DTensor, Layout};
26use once_cell::sync::Lazy;
27use std::sync::{Arc, RwLock};
28
29#[cfg(feature = "system-blas")]
30use blas_sys::dgemm_;
31
32//==============================================================================
33// BLAS Function Pointer Types
34//==============================================================================
35
/// BLAS dgemm function pointer type (LP64: 32-bit integers)
///
/// Signature matches Fortran BLAS dgemm:
/// ```c
/// void dgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             double *alpha, double *a, int *lda, double *b, int *ldb,
///             double *beta, double *c, int *ldc);
/// ```
/// Note: All parameters are passed by reference (pointers).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
///
/// Every integer argument is passed as a pointer to `libc::c_int`. `c` is the only
/// `*mut` parameter, i.e. the only buffer the callee is expected to write to.
pub type DgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const libc::c_int,
    b: *const libc::c_double,
    ldb: *const libc::c_int,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const libc::c_int,
);
61
/// BLAS zgemm function pointer type (LP64: 32-bit integers)
///
/// Signature matches Fortran BLAS zgemm:
/// ```c
/// void zgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             void *alpha, void *a, int *lda, void *b, int *ldb,
///             void *beta, void *c, int *ldc);
/// ```
/// Note: All parameters are passed by reference (pointers).
/// Complex numbers are passed as void* (typically complex<double>*).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
///
/// On the Rust side the complex arguments are typed as pointers to
/// `num_complex::Complex<f64>`, and every integer argument as a pointer to `libc::c_int`.
/// `c` is the only `*mut` parameter, i.e. the only buffer the callee is expected to write to.
pub type ZgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
);
88
// Adapter used with the `system-blas` feature: `blas_sys::zgemm_` takes
// `c_double_complex` pointers, while `ZgemmFnPtr` is declared in terms of
// `num_complex::Complex<f64>`. This wrapper has the `ZgemmFnPtr` signature and
// forwards every argument to `blas_sys::zgemm_`, re-typing only the complex
// pointers.
#[cfg(feature = "system-blas")]
unsafe extern "C" fn zgemm_wrapper(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
) {
    use blas_sys::c_double_complex;

    // Re-type the complex pointers; addresses and mutability are unchanged.
    let alpha = alpha.cast::<c_double_complex>();
    let a = a.cast::<c_double_complex>();
    let b = b.cast::<c_double_complex>();
    let beta = beta.cast::<c_double_complex>();
    let c = c.cast::<c_double_complex>();

    // SAFETY: `blas_sys::c_double_complex` is defined as `[f64; 2]` and is
    // layout-compatible with `num_complex::Complex<f64>` in memory, so the
    // re-typed pointers refer to the same data. All other arguments are
    // forwarded untouched; their validity is the caller's obligation.
    unsafe {
        blas_sys::zgemm_(
            transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
        );
    }
}
129
/// BLAS dgemm function pointer type (ILP64: 64-bit integers)
///
/// Signature matches Fortran BLAS dgemm (ILP64):
/// ```c
/// void dgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             double *alpha, double *a, long long *lda, double *b, long long *ldb,
///             double *beta, double *c, long long *ldc);
/// ```
///
/// Identical in shape to [`DgemmFnPtr`], except that every integer argument is passed
/// as a pointer to `i64` instead of `libc::c_int`.
pub type Dgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const i64,
    b: *const libc::c_double,
    ldb: *const i64,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const i64,
);
153
/// BLAS zgemm function pointer type (ILP64: 64-bit integers)
///
/// Signature matches Fortran BLAS zgemm (ILP64):
/// ```c
/// void zgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             void *alpha, void *a, long long *lda, void *b, long long *ldb,
///             void *beta, void *c, long long *ldc);
/// ```
///
/// Identical in shape to [`ZgemmFnPtr`], except that every integer argument is passed
/// as a pointer to `i64` instead of `libc::c_int`.
pub type Zgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const i64,
    b: *const num_complex::Complex<f64>,
    ldb: *const i64,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const i64,
);
177
178//==============================================================================
179// Fortran BLAS Constants
180//==============================================================================
181
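// Fortran BLAS transpose characters: the only one used in this file, 'N', is built
// inline at each call site (`b'N' as libc::c_char`); no module-level constants are defined.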
183
184//==============================================================================
185// GemmBackend Trait
186//==============================================================================
187
/// GEMM backend trait for runtime dispatch
///
/// Implementors must be `Send + Sync` so that a backend can be stored in the global
/// dispatcher and shared through backend handles.
pub trait GemmBackend: Send + Sync {
    /// Matrix multiplication: C = A * B (f64)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Note: Leading dimension is calculated internally based on row-major to column-major conversion
    ///
    /// # Safety
    /// The implementations in this file treat `a`, `b` and `c` as dense row-major buffers
    /// of `m * k`, `k * n` and `m * n` elements respectively, reading the first two and
    /// writing the third. Callers must pass pointers that are valid for those accesses.
    unsafe fn dgemm(&self, m: usize, n: usize, k: usize, a: *const f64, b: *const f64, c: *mut f64);

    /// Matrix multiplication: C = A * B (Complex<f64>)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Note: Leading dimension is calculated internally based on row-major to column-major conversion
    ///
    /// # Safety
    /// Same contract as [`GemmBackend::dgemm`], with `Complex<f64>` elements: `a`, `b`
    /// and `c` must be valid for `m * k` reads, `k * n` reads and `m * n` writes.
    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    );

    /// Returns true if this backend uses 64-bit integers (ILP64)
    ///
    /// The provided default returns `false`; only ILP64 backends need to override it.
    fn is_ilp64(&self) -> bool {
        false
    }

    /// Returns backend name for debugging
    fn name(&self) -> &'static str;
}
226
227//==============================================================================
228// Faer Backend (Default, Pure Rust)
229//==============================================================================
230
231/// Default Faer backend (Pure Rust, no external dependencies)
232struct FaerBackend;
233
234impl GemmBackend for FaerBackend {
235    unsafe fn dgemm(
236        &self,
237        m: usize,
238        n: usize,
239        k: usize,
240        a: *const f64,
241        b: *const f64,
242        c: *mut f64,
243    ) {
244        use mdarray_linalg::matmul::MatMulBuilder;
245        use mdarray_linalg::prelude::MatMul;
246        use mdarray_linalg_faer::Faer;
247
248        // Create tensors from pointers (row-major order)
249        let a_slice = unsafe { std::slice::from_raw_parts(a, m * k) };
250        let b_slice = unsafe { std::slice::from_raw_parts(b, k * n) };
251        let a_tensor = DTensor::<f64, 2>::from_fn([m, k], |idx| a_slice[idx[0] * k + idx[1]]);
252        let b_tensor = DTensor::<f64, 2>::from_fn([k, n], |idx| b_slice[idx[0] * n + idx[1]]);
253
254        // Perform matrix multiplication
255        let c_tensor = Faer.matmul(&*a_tensor, &*b_tensor).parallelize().eval();
256
257        // Copy result back to output pointer (row-major order)
258        // For row-major, ldc = n (number of columns)
259        let ldc = n;
260        let c_slice = unsafe { std::slice::from_raw_parts_mut(c, m * ldc) };
261        for i in 0..m {
262            for j in 0..n {
263                c_slice[i * ldc + j] = c_tensor[[i, j]];
264            }
265        }
266    }
267
268    unsafe fn zgemm(
269        &self,
270        m: usize,
271        n: usize,
272        k: usize,
273        a: *const num_complex::Complex<f64>,
274        b: *const num_complex::Complex<f64>,
275        c: *mut num_complex::Complex<f64>,
276    ) {
277        use mdarray_linalg::matmul::MatMulBuilder;
278        use mdarray_linalg::prelude::MatMul;
279        use mdarray_linalg_faer::Faer;
280
281        // Create tensors from pointers (row-major order)
282        let a_slice = unsafe { std::slice::from_raw_parts(a, m * k) };
283        let b_slice = unsafe { std::slice::from_raw_parts(b, k * n) };
284        let a_tensor = DTensor::<num_complex::Complex<f64>, 2>::from_fn([m, k], |idx| {
285            a_slice[idx[0] * k + idx[1]]
286        });
287        let b_tensor = DTensor::<num_complex::Complex<f64>, 2>::from_fn([k, n], |idx| {
288            b_slice[idx[0] * n + idx[1]]
289        });
290
291        // Perform matrix multiplication
292        let c_tensor = Faer.matmul(&*a_tensor, &*b_tensor).parallelize().eval();
293
294        // Copy result back to output pointer (row-major order)
295        // For row-major, ldc = n (number of columns)
296        let ldc = n;
297        let c_slice = unsafe { std::slice::from_raw_parts_mut(c, m * ldc) };
298        for i in 0..m {
299            for j in 0..n {
300                c_slice[i * ldc + j] = c_tensor[[i, j]];
301            }
302        }
303    }
304
305    fn name(&self) -> &'static str {
306        "Faer (Pure Rust)"
307    }
308}
309
310//==============================================================================
311// External BLAS Backends (LP64 and ILP64)
312//==============================================================================
313
/// External BLAS backend (LP64: 32-bit integers)
///
/// # Conversion rules for row-major data to column-major BLAS
///
/// **Goal**: Compute C = A * B where:
///   - A is m×k (row-major)
///   - B is k×n (row-major)
///   - C is m×n (row-major)
///
/// **Row-major to column-major interpretation**:
///   - Row-major A (m×k) appears as A^T (k×m) in column-major → call this At
///   - Row-major B (k×n) appears as B^T (n×k) in column-major → call this Bt
///   - Row-major C (m×n) appears as C^T (n×m) in column-major → call this Ct
///   - To compute C = A * B, we need: C^T = (A * B)^T = B^T * A^T
///   - So: Ct = Bt * At
///
/// **BLAS call transformation**:
///   - Original: C = A * B (row-major world)
///   - BLAS call: Ct = Bt * At (column-major world)
///   - transa = 'N' (Bt is already transposed-looking, no transpose needed)
///   - transb = 'N' (At is already transposed-looking, no transpose needed)
///   - Call: dgemm('N', 'N', n, m, k, alpha, B, lda, A, ldb, beta, C, ldc)
///
/// **Dimension conversions**:
///   - m_blas = n (Ct rows = Bt rows)
///   - n_blas = m (Ct cols = At cols)
///   - k_blas = k (common dimension)
///   - lda = n (leading dimension of Bt: n×k in column-major, lda = n)
///   - ldb = k (leading dimension of At: k×m in column-major, ldb = k)
///   - ldc = n (leading dimension of Ct: n×m in column-major, ldc = n)
pub struct ExternalBlasBackend {
    // Fortran-convention `dgemm` entry point used for `f64` products.
    dgemm: DgemmFnPtr,
    // Fortran-convention `zgemm` entry point used for `Complex<f64>` products.
    zgemm: ZgemmFnPtr,
}

impl ExternalBlasBackend {
    /// Creates a backend that forwards to the given LP64 `dgemm` / `zgemm` function pointers.
    ///
    /// The pointers are only stored here; they are invoked by the `GemmBackend`
    /// methods of this type.
    pub fn new(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) -> Self {
        Self { dgemm, zgemm }
    }
}
354
355impl GemmBackend for ExternalBlasBackend {
356    unsafe fn dgemm(
357        &self,
358        m: usize,
359        n: usize,
360        k: usize,
361        a: *const f64,
362        b: *const f64,
363        c: *mut f64,
364    ) {
365        // Validate dimensions fit in i32
366        assert!(
367            m <= i32::MAX as usize,
368            "Matrix dimension m too large for LP64 BLAS"
369        );
370        assert!(
371            n <= i32::MAX as usize,
372            "Matrix dimension n too large for LP64 BLAS"
373        );
374        assert!(
375            k <= i32::MAX as usize,
376            "Matrix dimension k too large for LP64 BLAS"
377        );
378
379        // Fortran BLAS requires all parameters passed by reference
380        // Apply row-major to column-major conversion (see conversion rules above)
381        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
382        let transb = b'N' as libc::c_char; // At is already transposed-looking
383        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
384        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
385        let k_i32 = k as i32; // k_blas = k (common dimension)
386        let alpha = 1.0f64;
387        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
388        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
389        let beta = 0.0f64;
390        // For row-major C (m×n) viewed as column-major Ct (n×m):
391        // Leading dimension in column-major is the stride between rows
392        // In row-major, stride between rows = number of columns = n
393        // So ldc = n (the number of columns in the original row-major matrix)
394        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)
395
396        unsafe {
397            (self.dgemm)(
398                &transa, &transb, &m_i32, &n_i32, &k_i32, &alpha, b, // B first (Bt)
399                &lda, a, // A second (At)
400                &ldb, &beta, c, &ldc_i32,
401            );
402        }
403    }
404
405    unsafe fn zgemm(
406        &self,
407        m: usize,
408        n: usize,
409        k: usize,
410        a: *const num_complex::Complex<f64>,
411        b: *const num_complex::Complex<f64>,
412        c: *mut num_complex::Complex<f64>,
413    ) {
414        assert!(
415            m <= i32::MAX as usize,
416            "Matrix dimension m too large for LP64 BLAS"
417        );
418        assert!(
419            n <= i32::MAX as usize,
420            "Matrix dimension n too large for LP64 BLAS"
421        );
422        assert!(
423            k <= i32::MAX as usize,
424            "Matrix dimension k too large for LP64 BLAS"
425        );
426
427        // Fortran BLAS requires all parameters passed by reference
428        // Apply row-major to column-major conversion (see conversion rules above)
429        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
430        let transb = b'N' as libc::c_char; // At is already transposed-looking
431        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
432        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
433        let k_i32 = k as i32; // k_blas = k (common dimension)
434        let alpha = num_complex::Complex::new(1.0, 0.0);
435        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
436        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
437        let beta = num_complex::Complex::new(0.0, 0.0);
438        // For row-major C (m×n) viewed as column-major Ct (n×m):
439        // Leading dimension in column-major is the stride between rows = n
440        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)
441
442        unsafe {
443            (self.zgemm)(
444                &transa,
445                &transb,
446                &m_i32,
447                &n_i32,
448                &k_i32,
449                &alpha,
450                b as *const _, // B first (Bt)
451                &lda,
452                a as *const _, // A second (At)
453                &ldb,
454                &beta,
455                c as *mut _,
456                &ldc_i32,
457            );
458        }
459    }
460
461    fn name(&self) -> &'static str {
462        "External BLAS (LP64)"
463    }
464}
465
/// External BLAS backend (ILP64: 64-bit integers)
///
/// Counterpart of the LP64 external backend for BLAS libraries whose integer
/// arguments are 64 bits wide.
pub struct ExternalBlas64Backend {
    // Fortran-convention ILP64 `dgemm` entry point used for `f64` products.
    dgemm64: Dgemm64FnPtr,
    // Fortran-convention ILP64 `zgemm` entry point used for `Complex<f64>` products.
    zgemm64: Zgemm64FnPtr,
}

impl ExternalBlas64Backend {
    /// Creates a backend that forwards to the given ILP64 `dgemm` / `zgemm` function pointers.
    ///
    /// The pointers are only stored here; they are invoked by the `GemmBackend`
    /// methods of this type.
    pub fn new(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) -> Self {
        Self { dgemm64, zgemm64 }
    }
}
477
478impl GemmBackend for ExternalBlas64Backend {
479    unsafe fn dgemm(
480        &self,
481        m: usize,
482        n: usize,
483        k: usize,
484        a: *const f64,
485        b: *const f64,
486        c: *mut f64,
487    ) {
488        // Fortran BLAS requires all parameters passed by reference
489        // Apply row-major to column-major conversion (see conversion rules above)
490        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
491        let transb = b'N' as libc::c_char; // At is already transposed-looking
492        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
493        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
494        let k_i64 = k as i64; // k_blas = k (common dimension)
495        let alpha = 1.0f64;
496        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
497        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
498        let beta = 0.0f64;
499        // For row-major C (m×n) viewed as column-major Ct (n×m):
500        // Leading dimension in column-major is the stride between rows = n
501        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)
502
503        unsafe {
504            (self.dgemm64)(
505                &transa, &transb, &m_i64, &n_i64, &k_i64, &alpha, b, // B first (Bt)
506                &lda, a, // A second (At)
507                &ldb, &beta, c, &ldc_i64,
508            );
509        }
510    }
511
512    unsafe fn zgemm(
513        &self,
514        m: usize,
515        n: usize,
516        k: usize,
517        a: *const num_complex::Complex<f64>,
518        b: *const num_complex::Complex<f64>,
519        c: *mut num_complex::Complex<f64>,
520    ) {
521        // Fortran BLAS requires all parameters passed by reference
522        // Apply row-major to column-major conversion (see conversion rules above)
523        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
524        let transb = b'N' as libc::c_char; // At is already transposed-looking
525        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
526        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
527        let k_i64 = k as i64; // k_blas = k (common dimension)
528        let alpha = num_complex::Complex::new(1.0, 0.0);
529        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
530        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
531        let beta = num_complex::Complex::new(0.0, 0.0);
532        // For row-major C (m×n) viewed as column-major Ct (n×m):
533        // Leading dimension in column-major is the stride between rows = n
534        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)
535
536        unsafe {
537            (self.zgemm64)(
538                &transa,
539                &transb,
540                &m_i64,
541                &n_i64,
542                &k_i64,
543                &alpha,
544                b as *const _, // B first (Bt)
545                &lda,
546                a as *const _, // A second (At)
547                &ldb,
548                &beta,
549                c as *mut _,
550                &ldc_i64,
551            );
552        }
553    }
554
555    fn is_ilp64(&self) -> bool {
556        true
557    }
558
559    fn name(&self) -> &'static str {
560        "External BLAS (ILP64)"
561    }
562}
563
564//==============================================================================
565// Backend Handle
566//==============================================================================
567
568/// Thread-safe handle to a GEMM backend
569///
570/// This type wraps an `Arc<dyn GemmBackend>` to allow sharing a backend
571/// across multiple function calls without global state.
572///
573/// # Example
574/// ```ignore
575/// use sparse_ir::gemm::GemmBackendHandle;
576///
577/// let backend = GemmBackendHandle::default();
578/// let result = matmul_par(&a, &b, Some(&backend));
579/// ```
580#[derive(Clone)]
581pub struct GemmBackendHandle {
582    inner: Arc<dyn GemmBackend>,
583}
584
585impl GemmBackendHandle {
586    /// Create a new backend handle from a boxed backend
587    pub fn new(backend: Box<dyn GemmBackend>) -> Self {
588        Self {
589            inner: Arc::from(backend),
590        }
591    }
592
593    /// Create a default backend handle (Faer backend)
594    pub fn default() -> Self {
595        Self {
596            inner: Arc::new(FaerBackend),
597        }
598    }
599
600    /// Get a reference to the inner backend
601    pub(crate) fn as_ref(&self) -> &dyn GemmBackend {
602        self.inner.as_ref()
603    }
604}
605
606//==============================================================================
607// Global Dispatcher (for backward compatibility)
608//==============================================================================
609
610/// Global BLAS dispatcher (thread-safe)
611///
612/// This is kept for backward compatibility when `None` is passed as backend.
613/// New code should use `GemmBackendHandle` explicitly.
614static BLAS_DISPATCHER: Lazy<RwLock<Box<dyn GemmBackend>>> = Lazy::new(|| {
615    #[cfg(feature = "system-blas")]
616    {
617        // Use system BLAS (LP64) by default via `blas-sys`.
618        let backend =
619            ExternalBlasBackend::new(dgemm_ as DgemmFnPtr, zgemm_wrapper as ZgemmFnPtr);
620        RwLock::new(Box::new(backend) as Box<dyn GemmBackend>)
621    }
622    #[cfg(not(feature = "system-blas"))]
623    {
624        // Default to the pure Rust Faer backend.
625        RwLock::new(Box::new(FaerBackend) as Box<dyn GemmBackend>)
626    }
627});
628
629/// Set BLAS backend (LP64: 32-bit integers)
630///
631/// # Safety
632/// - Function pointers must be valid and thread-safe
633/// - Must remain valid for the lifetime of the program
634/// - Must follow Fortran BLAS calling convention
635///
636/// # Example
637/// ```ignore
638/// unsafe {
639///     set_blas_backend(dgemm_ as _, zgemm_ as _);
640/// }
641/// ```
642pub unsafe fn set_blas_backend(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) {
643    let backend = ExternalBlasBackend { dgemm, zgemm };
644    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
645    *dispatcher = Box::new(backend);
646}
647
648/// Set ILP64 BLAS backend (64-bit integers)
649///
650/// # Safety
651/// - Function pointers must be valid, thread-safe, and use 64-bit integers
652/// - Must remain valid for the lifetime of the program
653/// - Must follow Fortran BLAS calling convention with ILP64 interface
654///
655/// # Example
656/// ```ignore
657/// unsafe {
658///     set_ilp64_backend(dgemm_ as _, zgemm_ as _);
659/// }
660/// ```
661pub unsafe fn set_ilp64_backend(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) {
662    let backend = ExternalBlas64Backend { dgemm64, zgemm64 };
663    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
664    *dispatcher = Box::new(backend);
665}
666
667/// Clear BLAS backend (reset to default Faer)
668///
669/// This function resets the GEMM dispatcher to use the default Pure Rust Faer backend.
670pub fn clear_blas_backend() {
671    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
672    *dispatcher = Box::new(FaerBackend);
673}
674
675/// Get current BLAS backend information
676///
677/// Returns:
678/// - `(backend_name, is_external, is_ilp64)`
679pub fn get_backend_info() -> (&'static str, bool, bool) {
680    let dispatcher = BLAS_DISPATCHER.read().unwrap();
681    let name = dispatcher.name();
682    let is_external = !name.contains("Faer");
683    let is_ilp64 = dispatcher.is_ilp64();
684    (name, is_external, is_ilp64)
685}
686
687//==============================================================================
688// Public API
689//==============================================================================
690
691/// Parallel matrix multiplication: C = A * B
692///
693/// Dispatches to the provided backend, or the global dispatcher if `None`.
694///
695/// # Arguments
696/// * `a` - Left matrix (M x K)
697/// * `b` - Right matrix (K x N)
698/// * `backend` - Optional backend handle. If `None`, uses global dispatcher (for backward compatibility)
699///
700/// # Returns
701/// Result matrix (M x N)
702///
703/// # Panics
704/// Panics if matrix dimensions are incompatible (A.cols != B.rows)
705///
706/// # Example
707/// ```ignore
708/// use mdarray::tensor;
709/// use sparse_ir::gemm::{matmul_par, GemmBackendHandle};
710///
711/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
712/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
713/// let backend = GemmBackendHandle::default();
714/// let c = matmul_par(&a, &b, Some(&backend));
715/// // c = [[19.0, 22.0], [43.0, 50.0]]
716/// ```
717pub fn matmul_par<T>(
718    a: &DTensor<T, 2>,
719    b: &DTensor<T, 2>,
720    backend: Option<&GemmBackendHandle>,
721) -> DTensor<T, 2>
722where
723    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
724{
725    let (_m, k) = *a.shape();
726    let (k2, _n) = *b.shape();
727
728    // Validate dimensions
729    assert_eq!(
730        k, k2,
731        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
732        k, k2
733    );
734
735    // Use Faer directly to avoid creating intermediate DTensors through backend
736    // create _m x _n result tensor
737    let mut result = DTensor::<T, 2>::from_elem([_m, _n], T::zero().into());
738    matmul_par_overwrite(a, b, &mut result, backend);
739    result
740}
741
742/// Parallel matrix multiplication with overwrite: C = A * B (writes to existing buffer)
743///
744/// This function writes the result directly into the provided buffer `c`,
745/// avoiding memory allocation. This is more memory-efficient for repeated operations.
746///
747/// # Arguments
748/// * `a` - Left matrix (M x K)
749/// * `b` - Right matrix (K x N)
750/// * `c` - Output matrix (M x N) - will be overwritten with result
751/// * `backend` - Optional backend handle. If `None`, uses global dispatcher (for backward compatibility)
752///
753/// # Panics
754/// Panics if matrix dimensions are incompatible (A.cols != B.rows or C.shape != [M, N])
755pub fn matmul_par_overwrite<T, Lc: Layout>(
756    a: &DTensor<T, 2>,
757    b: &DTensor<T, 2>,
758    c: &mut DSlice<T, 2, Lc>,
759    backend: Option<&GemmBackendHandle>,
760) where
761    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
762{
763    let (m, k) = *a.shape();
764    let (k2, n) = *b.shape();
765    let (mc, nc) = *c.shape();
766
767    // Validate dimensions
768    assert_eq!(
769        k, k2,
770        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
771        k, k2
772    );
773    assert_eq!(
774        m, mc,
775        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
776        mc, m
777    );
778    assert_eq!(
779        n, nc,
780        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
781        nc, n
782    );
783
784    // Type dispatch: f64 or Complex<f64>
785    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
786        // f64 case
787        // Get pointers directly from DTensors (row-major order)
788        let a_ptr = a.as_ptr() as *const f64;
789        let b_ptr = b.as_ptr() as *const f64;
790        let c_ptr = c.as_mut_ptr() as *mut f64;
791
792        // Get backend: use provided handle or fall back to global dispatcher
793        match backend {
794            Some(handle) => {
795                // Call backend directly with pointers (no temporary buffer needed)
796                // Leading dimension is calculated internally in the backend
797                unsafe {
798                    handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
799                }
800            }
801            None => {
802                // Backward compatibility: use global dispatcher
803                let dispatcher = BLAS_DISPATCHER.read().unwrap();
804                unsafe {
805                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
806                }
807            }
808        }
809    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
810        // Complex<f64> case
811        // Get pointers directly from DTensors (row-major order)
812        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
813        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
814        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;
815
816        // Get backend: use provided handle or fall back to global dispatcher
817        match backend {
818            Some(handle) => {
819                // Call backend directly with pointers (no temporary buffer needed)
820                // Leading dimension is calculated internally in the backend
821                unsafe {
822                    handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
823                }
824            }
825            None => {
826                // Backward compatibility: use global dispatcher
827                let dispatcher = BLAS_DISPATCHER.read().unwrap();
828                unsafe {
829                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
830                }
831            }
832        }
833    } else {
834        // Fallback to Faer for unsupported types
835        use mdarray_linalg::matmul::MatMulBuilder;
836        use mdarray_linalg::prelude::MatMul;
837        use mdarray_linalg_faer::Faer;
838
839        Faer.matmul(a, b).parallelize().overwrite(c);
840    }
841}
842
#[cfg(test)]
mod tests {
    use super::*;

    // With the `system-blas` feature the dispatcher starts on the external LP64
    // BLAS, so the "default is Faer" expectation only holds without that feature.
    #[cfg(not(feature = "system-blas"))]
    #[test]
    fn test_default_backend_is_faer() {
        let (name, is_external, is_ilp64) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
        assert!(!is_external);
        assert!(!is_ilp64);
    }

    #[test]
    fn test_clear_backend() {
        // Should not panic
        clear_blas_backend();
        let (name, _, _) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
    }

    #[test]
    fn test_matmul_f64() {
        let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];

        let a = DTensor::<f64, 2>::from_fn([2, 3], |idx| a_data[idx[0] * 3 + idx[1]]);
        let b = DTensor::<f64, 2>::from_fn([3, 2], |idx| b_data[idx[0] * 2 + idx[1]]);
        let c = matmul_par(&a, &b, None);

        assert_eq!(*c.shape(), (2, 2));
        // First row: [1*7+2*9+3*11, 1*8+2*10+3*12] = [58, 64]
        // Second row: [4*7+5*9+6*11, 4*8+5*10+6*12] = [139, 154]
        assert!((c[[0, 0]] - 58.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 64.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 139.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 154.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_basic() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
        let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        //         = [[19, 22], [43, 50]]
        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_non_square() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2x3
        let b: DTensor<f64, 2> = tensor![[7.0], [8.0], [9.0]]; // 3x1
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*7+2*8+3*9], [4*7+5*8+6*9]]
        //         = [[50], [122]]
        assert!((c[[0, 0]] - 50.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 122.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_explicit_backend() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
        let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
        // An explicit handle bypasses the global dispatcher entirely.
        let backend = GemmBackendHandle::default();
        let c = matmul_par(&a, &b, Some(&backend));

        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_complex() {
        use num_complex::Complex;
        // A = [[i, 1], [0, 2]], B = [[1, 0], [i, 1]]
        let a_data = [
            Complex::new(0.0, 1.0),
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.0),
            Complex::new(2.0, 0.0),
        ];
        let b_data = [
            Complex::new(1.0, 0.0),
            Complex::new(0.0, 0.0),
            Complex::new(0.0, 1.0),
            Complex::new(1.0, 0.0),
        ];
        let a = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| a_data[idx[0] * 2 + idx[1]]);
        let b = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| b_data[idx[0] * 2 + idx[1]]);
        let backend = GemmBackendHandle::default();
        let c = matmul_par(&a, &b, Some(&backend));

        // Expected: [[i*1 + 1*i, i*0 + 1*1], [0*1 + 2*i, 0*0 + 2*1]]
        //         = [[2i, 1], [2i, 2]]
        assert!((c[[0, 0]] - Complex::new(0.0, 2.0)).norm() < 1e-10);
        assert!((c[[0, 1]] - Complex::new(1.0, 0.0)).norm() < 1e-10);
        assert!((c[[1, 0]] - Complex::new(0.0, 2.0)).norm() < 1e-10);
        assert!((c[[1, 1]] - Complex::new(2.0, 0.0)).norm() < 1e-10);
    }

    #[test]
    fn test_matmul_par_overwrite_replaces_stale_data() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
        let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
        // Pre-fill the output so a skipped write would be detected.
        let mut c = DTensor::<f64, 2>::from_elem([2, 2], 99.0);
        let backend = GemmBackendHandle::default();
        matmul_par_overwrite(&a, &b, &mut c, Some(&backend));

        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }
}
908}