zenu_matrix/operation/
mul.rs

1use cblas::Transpose;
2
3use crate::{
4    device::{cpu::Cpu, Device, DeviceBase},
5    dim::{DimDyn, DimTrait},
6    matrix::{Matrix, Owned, Ref, Repr},
7    matrix_blas::BlasTrans,
8    num::Num,
9    shape_stride::ShapeStride,
10};
11
12pub trait Gemm: DeviceBase {
13    #[expect(clippy::too_many_arguments)]
14    fn gemm_unchecked<T: Num>(
15        transa: BlasTrans,
16        transb: BlasTrans,
17        m: usize,
18        n: usize,
19        k: usize,
20        alpha: T,
21        a: *const T,
22        lda: usize,
23        b: *const T,
24        ldb: usize,
25        beta: T,
26        c: *mut T,
27        ldc: usize,
28    );
29}
30
31fn from_trans(value: BlasTrans) -> Transpose {
32    match value {
33        BlasTrans::None => Transpose::None,
34        BlasTrans::Ordinary => Transpose::Ordinary,
35        BlasTrans::Conjugate => Transpose::Conjugate,
36    }
37}
38
39impl Gemm for Cpu {
40    #[expect(clippy::many_single_char_names, clippy::similar_names)]
41    fn gemm_unchecked<T: Num>(
42        transa: BlasTrans,
43        transb: BlasTrans,
44        m: usize,
45        n: usize,
46        k: usize,
47        alpha: T,
48        a: *const T,
49        lda: usize,
50        b: *const T,
51        ldb: usize,
52        beta: T,
53        c: *mut T,
54        ldc: usize,
55    ) {
56        extern crate openblas_src;
57        use cblas::{dgemm, sgemm, Layout};
58        if T::is_f32() {
59            let a = unsafe { std::slice::from_raw_parts(a.cast(), m * k) };
60            let b = unsafe { std::slice::from_raw_parts(b.cast(), k * n) };
61            let c = unsafe { std::slice::from_raw_parts_mut(c.cast(), m * n) };
62            unsafe {
63                sgemm(
64                    Layout::RowMajor,
65                    from_trans(transa),
66                    from_trans(transb),
67                    m.try_into().unwrap(),
68                    n.try_into().unwrap(),
69                    k.try_into().unwrap(),
70                    alpha.to_f32().unwrap(),
71                    a,
72                    lda.try_into().unwrap(),
73                    b,
74                    ldb.try_into().unwrap(),
75                    beta.to_f32().unwrap(),
76                    c,
77                    ldc.try_into().unwrap(),
78                );
79            }
80        } else {
81            let a = unsafe { std::slice::from_raw_parts(a.cast(), m * k) };
82            let b = unsafe { std::slice::from_raw_parts(b.cast(), k * n) };
83            let c = unsafe { std::slice::from_raw_parts_mut(c.cast(), m * n) };
84            unsafe {
85                dgemm(
86                    Layout::RowMajor,
87                    from_trans(transa),
88                    from_trans(transb),
89                    m.try_into().unwrap(),
90                    n.try_into().unwrap(),
91                    k.try_into().unwrap(),
92                    alpha.to_f64().unwrap(),
93                    a,
94                    lda.try_into().unwrap(),
95                    b,
96                    ldb.try_into().unwrap(),
97                    beta.to_f64().unwrap(),
98                    c,
99                    ldc.try_into().unwrap(),
100                );
101            }
102        }
103    }
104}
105
106#[cfg(feature = "nvidia")]
107use crate::device::nvidia::Nvidia;
108
109#[cfg(feature = "nvidia")]
110use zenu_cuda::cublas::{cublas_gemm, ZenuCublasOperation};
111
#[cfg(feature = "nvidia")]
impl Gemm for Nvidia {
    /// Runs the GEMM on the GPU by delegating to `cublas_gemm`.
    ///
    /// The operands are handed over in swapped order (`B` before `A`, `n`
    /// before `m`) while `C` keeps its own leading dimension.
    ///
    /// # Panics
    ///
    /// Panics if a dimension or leading dimension does not fit in `i32`, or if
    /// `cublas_gemm` returns an error.
    #[expect(clippy::many_single_char_names, clippy::similar_names)]
    fn gemm_unchecked<T: Num>(
        transa: BlasTrans,
        transb: BlasTrans,
        m: usize,
        n: usize,
        k: usize,
        alpha: T,
        a: *const T,
        lda: usize,
        b: *const T,
        ldb: usize,
        beta: T,
        c: *mut T,
        ldc: usize,
    ) {
        let as_cublas_op = |trans: BlasTrans| match trans {
            BlasTrans::None => ZenuCublasOperation::N,
            BlasTrans::Ordinary => ZenuCublasOperation::T,
            BlasTrans::Conjugate => ZenuCublasOperation::ConjT,
        };
        let as_i32 = |value: usize| i32::try_from(value).unwrap();

        let (op_a, op_b) = (as_cublas_op(transa), as_cublas_op(transb));
        let (rows, cols, inner) = (as_i32(m), as_i32(n), as_i32(k));
        let (ld_a, ld_b, ld_c) = (as_i32(lda), as_i32(ldb), as_i32(ldc));

        // The `A`/`B` roles and the `m`/`n` extents are deliberately swapped.
        // NOTE(review): presumably because cuBLAS is column-major, so the
        // row-major product is obtained as `C^T = B^T * A^T` - confirm against
        // `zenu_cuda::cublas::cublas_gemm`.
        cublas_gemm::<T>(
            op_b, op_a, cols, rows, inner, alpha, b, ld_b, a, ld_a, beta, c, ld_c,
        )
        .unwrap();
    }
}
148
149fn gemm_shape_check<SA: DimTrait, SB: DimTrait, SC: DimTrait>(
150    a: ShapeStride<SA>,
151    b: ShapeStride<SB>,
152    c: ShapeStride<SC>,
153) -> Result<(), String> {
154    let c_shape = c.shape();
155    let a_shape = a.shape();
156    let b_shape = b.shape();
157
158    if c_shape.len() != 2 {
159        return Err("The output matrix C must be 2-D.".to_string());
160    }
161    if a_shape.len() != 2 {
162        return Err("The input matrix A must be 2-D.".to_string());
163    }
164    if b_shape.len() != 2 {
165        return Err("The input matrix B must be 2-D.".to_string());
166    }
167
168    let is_transposed_c = c.is_transposed();
169
170    if is_transposed_c {
171        return Err("The output matrix C must not be transposed.".to_string());
172    }
173
174    if a_shape[0] != c_shape[0] {
175        return Err(
176            "The number of rows of matrix A must match the number of rows of matrix C.".to_string(),
177        );
178    }
179
180    if b_shape[1] != c_shape[1] {
181        return Err(
182            "The number of columns of matrix B must match the number of columns of matrix C."
183                .to_string(),
184        );
185    }
186
187    if a_shape[1] != b_shape[0] {
188        return Err(
189            "The number of columns of matrix A must match the number of rows of matrix B."
190                .to_string(),
191        );
192    }
193
194    if a_shape[0] == 0
195        || a_shape[1] == 0
196        || b_shape[0] == 0
197        || b_shape[1] == 0
198        || c_shape[0] == 0
199        || c_shape[1] == 0
200    {
201        return Err(
202            "The dimensions of the input and output matrices must be greater than 0.".to_string(),
203        );
204    }
205    Ok(())
206}
207
208#[expect(clippy::missing_panics_doc, clippy::similar_names)]
209pub fn gemm_assign<T, D, RA, RB, SA, SB, SC>(
210    a: &Matrix<RA, SA, D>,
211    b: &Matrix<RB, SB, D>,
212    c: &Matrix<Ref<&mut T>, SC, D>,
213    alpha: T,
214    beta: T,
215) where
216    T: Num,
217    D: Device,
218    RA: Repr<Item = T>,
219    RB: Repr<Item = T>,
220    SA: DimTrait,
221    SB: DimTrait,
222    SC: DimTrait,
223{
224    if let Ok(()) = gemm_shape_check(a.shape_stride(), b.shape_stride(), c.shape_stride()) {
225        let transa = if a.shape_stride().is_transposed() {
226            BlasTrans::Ordinary
227        } else {
228            BlasTrans::None
229        };
230        let transb = if b.shape_stride().is_transposed() {
231            BlasTrans::Ordinary
232        } else {
233            BlasTrans::None
234        };
235        let get_lead_dim = |stride: &[usize], trans: BlasTrans| match trans {
236            BlasTrans::None => stride[0],
237            BlasTrans::Ordinary => stride[1],
238            BlasTrans::Conjugate => unreachable!(),
239        };
240        let lda = get_lead_dim(a.stride().slice(), transa);
241        let ldb = get_lead_dim(b.stride().slice(), transb);
242        D::gemm_unchecked(
243            transa,
244            transb,
245            c.shape()[0],
246            c.shape()[1],
247            a.shape()[1],
248            alpha,
249            a.as_ptr(),
250            lda,
251            b.as_ptr(),
252            ldb,
253            beta,
254            c.as_mut_ptr(),
255            c.stride()[0],
256        );
257        return;
258    }
259    panic!("Dimension mismatch");
260}
261
262pub fn gemm<T, D, RA, RB, SA, SB>(
263    a: &Matrix<RA, SA, D>,
264    b: &Matrix<RB, SB, D>,
265    alpha: T,
266    beta: T,
267) -> Matrix<Owned<T>, DimDyn, D>
268where
269    T: Num,
270    D: Device,
271    RA: Repr<Item = T>,
272    RB: Repr<Item = T>,
273    SA: DimTrait,
274    SB: DimTrait,
275{
276    let c_shape = [a.shape()[0], b.shape()[1]];
277    let mut c = Matrix::<_, DimDyn, D>::alloc(c_shape);
278    gemm_assign(a, b, &c.to_ref_mut(), alpha, beta);
279    c
280}
281
282pub fn matmul<T, D, RA, RB, SA, SB>(
283    a: &Matrix<RA, SA, D>,
284    b: &Matrix<RB, SB, D>,
285) -> Matrix<Owned<T>, DimDyn, D>
286where
287    T: Num,
288    D: Device,
289    RA: Repr<Item = T>,
290    RB: Repr<Item = T>,
291    SA: DimTrait,
292    SB: DimTrait,
293{
294    gemm(a, b, T::one(), T::zero())
295}
296
#[cfg(test)]
mod gemm {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use super::gemm_assign;

    /// Multiplies a 3x4 matrix by a 4x5 matrix on device `D` and compares the
    /// 3x5 result against precomputed values.
    fn gemm_3x4_4x5_3x5<D: Device>() {
        // A holds 1..=12 and B holds 1..=20, both laid out row by row.
        let lhs = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [3, 4],
        );
        let rhs = Matrix::<_, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
                19., 20.,
            ],
            [4, 5],
        );

        let mut product = Matrix::<_, DimDyn, D>::alloc([3, 5]);
        gemm_assign(&lhs, &rhs, &product.to_ref_mut(), 1., 0.);

        let expected = Matrix::<_, DimDyn, D>::from_vec(
            vec![
                110., 120., 130., 140., 150., 246., 272., 298., 324., 350., 382., 424., 466., 508.,
                550.,
            ],
            [3, 5],
        );

        // Sum of absolute differences must be (numerically) zero.
        let total_abs_error = (product - expected).asum();
        assert!(total_abs_error < 1e-6);
    }

    #[test]
    fn gemm_3x4_4x5_3x5_cpu() {
        gemm_3x4_4x5_3x5::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn gemm_3x4_4x5_3x5_nvidia() {
        gemm_3x4_4x5_3x5::<crate::device::nvidia::Nvidia>();
    }
}