use cblas::Transpose;

use crate::{
    device::{cpu::Cpu, Device, DeviceBase},
    dim::{DimDyn, DimTrait},
    matrix::{Matrix, Owned, Ref, Repr},
    matrix_blas::BlasTrans,
    num::Num,
    shape_stride::ShapeStride,
};

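/// Backend-level GEMM kernel: computes `C = alpha * op(A) * op(B) + beta * C`.
/// "Unchecked" means no shape, stride, or pointer validation is performed here;
/// callers (see `gemm_assign`) are responsible for passing consistent row-major
/// operands and leading dimensions.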
pub trait Gemm: DeviceBase {
    #[expect(clippy::too_many_arguments)]
    fn gemm_unchecked<T: Num>(
        transa: BlasTrans,
        transb: BlasTrans,
        m: usize,
        n: usize,
        k: usize,
        alpha: T,
        a: *const T,
        lda: usize,
        b: *const T,
        ldb: usize,
        beta: T,
        c: *mut T,
        ldc: usize,
    );
}

fn from_trans(value: BlasTrans) -> Transpose {
    match value {
        BlasTrans::None => Transpose::None,
        BlasTrans::Ordinary => Transpose::Ordinary,
        BlasTrans::Conjugate => Transpose::Conjugate,
    }
}

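// CPU backend: forwards to OpenBLAS `sgemm`/`dgemm` through the `cblas`
// bindings, dispatching on the element type at runtime.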
impl Gemm for Cpu {
    #[expect(clippy::many_single_char_names, clippy::similar_names)]
    fn gemm_unchecked<T: Num>(
        transa: BlasTrans,
        transb: BlasTrans,
        m: usize,
        n: usize,
        k: usize,
        alpha: T,
        a: *const T,
        lda: usize,
        b: *const T,
        ldb: usize,
        beta: T,
        c: *mut T,
        ldc: usize,
    ) {
        extern crate openblas_src;
        use cblas::{dgemm, sgemm, Layout};
        if T::is_f32() {
            // Reinterpret the raw pointers as contiguous slices of the sizes
            // implied by `m`, `n`, and `k`; the caller guarantees validity.
            let a = unsafe { std::slice::from_raw_parts(a.cast(), m * k) };
            let b = unsafe { std::slice::from_raw_parts(b.cast(), k * n) };
            let c = unsafe { std::slice::from_raw_parts_mut(c.cast(), m * n) };
            unsafe {
                sgemm(
                    Layout::RowMajor,
                    from_trans(transa),
                    from_trans(transb),
                    m.try_into().unwrap(),
                    n.try_into().unwrap(),
                    k.try_into().unwrap(),
                    alpha.to_f32().unwrap(),
                    a,
                    lda.try_into().unwrap(),
                    b,
                    ldb.try_into().unwrap(),
                    beta.to_f32().unwrap(),
                    c,
                    ldc.try_into().unwrap(),
                );
            }
        } else {
            let a = unsafe { std::slice::from_raw_parts(a.cast(), m * k) };
            let b = unsafe { std::slice::from_raw_parts(b.cast(), k * n) };
            let c = unsafe { std::slice::from_raw_parts_mut(c.cast(), m * n) };
            unsafe {
                dgemm(
                    Layout::RowMajor,
                    from_trans(transa),
                    from_trans(transb),
                    m.try_into().unwrap(),
                    n.try_into().unwrap(),
                    k.try_into().unwrap(),
                    alpha.to_f64().unwrap(),
                    a,
                    lda.try_into().unwrap(),
                    b,
                    ldb.try_into().unwrap(),
                    beta.to_f64().unwrap(),
                    c,
                    ldc.try_into().unwrap(),
                );
            }
        }
    }
}

#[cfg(feature = "nvidia")]
use crate::device::nvidia::Nvidia;

#[cfg(feature = "nvidia")]
use zenu_cuda::cublas::{cublas_gemm, ZenuCublasOperation};

#[cfg(feature = "nvidia")]
impl Gemm for Nvidia {
    #[expect(clippy::many_single_char_names, clippy::similar_names)]
    fn gemm_unchecked<T: Num>(
        transa: BlasTrans,
        transb: BlasTrans,
        m: usize,
        n: usize,
        k: usize,
        alpha: T,
        a: *const T,
        lda: usize,
        b: *const T,
        ldb: usize,
        beta: T,
        c: *mut T,
        ldc: usize,
    ) {
        fn to_cuda_ops(trans: BlasTrans) -> ZenuCublasOperation {
            match trans {
                BlasTrans::None => ZenuCublasOperation::N,
                BlasTrans::Ordinary => ZenuCublasOperation::T,
                BlasTrans::Conjugate => ZenuCublasOperation::ConjT,
            }
        }
        let transa = to_cuda_ops(transa);
        let transb = to_cuda_ops(transb);
        let m = i32::try_from(m).unwrap();
        let n = i32::try_from(n).unwrap();
        let k = i32::try_from(k).unwrap();
        let lda = i32::try_from(lda).unwrap();
        let ldb = i32::try_from(ldb).unwrap();
        let ldc = i32::try_from(ldc).unwrap();
        // cuBLAS is column-major, so the row-major product C = op(A) * op(B) is
        // computed as C^T = op(B)^T * op(A)^T by swapping the operands and the
        // m/n dimensions.
        cublas_gemm::<T>(transb, transa, n, m, k, alpha, b, ldb, a, lda, beta, c, ldc).unwrap();
    }
}

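/// Checks that `a`, `b`, and `c` describe a valid 2-D GEMM: all operands are
/// two-dimensional and non-empty, `c` is not transposed, and the shapes satisfy
/// `a: (m, k)`, `b: (k, n)`, `c: (m, n)`.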
fn gemm_shape_check<SA: DimTrait, SB: DimTrait, SC: DimTrait>(
    a: ShapeStride<SA>,
    b: ShapeStride<SB>,
    c: ShapeStride<SC>,
) -> Result<(), String> {
    let c_shape = c.shape();
    let a_shape = a.shape();
    let b_shape = b.shape();

    if c_shape.len() != 2 {
        return Err("The output matrix C must be 2-D.".to_string());
    }
    if a_shape.len() != 2 {
        return Err("The input matrix A must be 2-D.".to_string());
    }
    if b_shape.len() != 2 {
        return Err("The input matrix B must be 2-D.".to_string());
    }

    if c.is_transposed() {
        return Err("The output matrix C must not be transposed.".to_string());
    }

    if a_shape[0] != c_shape[0] {
        return Err(
            "The number of rows of matrix A must match the number of rows of matrix C.".to_string(),
        );
    }

    if b_shape[1] != c_shape[1] {
        return Err(
            "The number of columns of matrix B must match the number of columns of matrix C."
                .to_string(),
        );
    }

    if a_shape[1] != b_shape[0] {
        return Err(
            "The number of columns of matrix A must match the number of rows of matrix B."
                .to_string(),
        );
    }

    if a_shape[0] == 0
        || a_shape[1] == 0
        || b_shape[0] == 0
        || b_shape[1] == 0
        || c_shape[0] == 0
        || c_shape[1] == 0
    {
        return Err(
            "The dimensions of the input and output matrices must be greater than 0.".to_string(),
        );
    }
    Ok(())
}

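/// In-place GEMM: computes `C = alpha * op(A) * op(B) + beta * C`, inferring the
/// transpose of `A` and `B` from their strides. Panics with "Dimension mismatch"
/// if the operands fail the shape check.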
#[expect(clippy::missing_panics_doc, clippy::similar_names)]
pub fn gemm_assign<T, D, RA, RB, SA, SB, SC>(
    a: &Matrix<RA, SA, D>,
    b: &Matrix<RB, SB, D>,
    c: &Matrix<Ref<&mut T>, SC, D>,
    alpha: T,
    beta: T,
) where
    T: Num,
    D: Device,
    RA: Repr<Item = T>,
    RB: Repr<Item = T>,
    SA: DimTrait,
    SB: DimTrait,
    SC: DimTrait,
{
    if let Ok(()) = gemm_shape_check(a.shape_stride(), b.shape_stride(), c.shape_stride()) {
        let transa = if a.shape_stride().is_transposed() {
            BlasTrans::Ordinary
        } else {
            BlasTrans::None
        };
        let transb = if b.shape_stride().is_transposed() {
            BlasTrans::Ordinary
        } else {
            BlasTrans::None
        };
        // Leading dimension as expected by BLAS: the row stride (`stride[0]`)
        // for an untransposed operand, the column stride (`stride[1]`) for a
        // transposed one.
        let get_lead_dim = |stride: &[usize], trans: BlasTrans| match trans {
            BlasTrans::None => stride[0],
            BlasTrans::Ordinary => stride[1],
            BlasTrans::Conjugate => unreachable!(),
        };
        let lda = get_lead_dim(a.stride().slice(), transa);
        let ldb = get_lead_dim(b.stride().slice(), transb);
        D::gemm_unchecked(
            transa,
            transb,
            c.shape()[0],
            c.shape()[1],
            a.shape()[1],
            alpha,
            a.as_ptr(),
            lda,
            b.as_ptr(),
            ldb,
            beta,
            c.as_mut_ptr(),
            c.stride()[0],
        );
        return;
    }
    panic!("Dimension mismatch");
}

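/// Allocates the `[a.shape()[0], b.shape()[1]]` output matrix and forwards to
/// [`gemm_assign`].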
pub fn gemm<T, D, RA, RB, SA, SB>(
    a: &Matrix<RA, SA, D>,
    b: &Matrix<RB, SB, D>,
    alpha: T,
    beta: T,
) -> Matrix<Owned<T>, DimDyn, D>
where
    T: Num,
    D: Device,
    RA: Repr<Item = T>,
    RB: Repr<Item = T>,
    SA: DimTrait,
    SB: DimTrait,
{
    let c_shape = [a.shape()[0], b.shape()[1]];
    let mut c = Matrix::<_, DimDyn, D>::alloc(c_shape);
    gemm_assign(a, b, &c.to_ref_mut(), alpha, beta);
    c
}

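/// Matrix product `A * B`, i.e. [`gemm`] with `alpha = 1` and `beta = 0`.
///
/// A minimal sketch using the same API as the test module below; marked
/// `ignore` so it is not compiled as a doctest:
///
/// ```ignore
/// let a = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![1., 2., 3., 4.], [2, 2]);
/// let b = Matrix::<Owned<f32>, DimDyn, Cpu>::from_vec(vec![5., 6., 7., 8.], [2, 2]);
/// let c = matmul(&a, &b); // [[19., 22.], [43., 50.]]
/// ```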
pub fn matmul<T, D, RA, RB, SA, SB>(
    a: &Matrix<RA, SA, D>,
    b: &Matrix<RB, SB, D>,
) -> Matrix<Owned<T>, DimDyn, D>
where
    T: Num,
    D: Device,
    RA: Repr<Item = T>,
    RB: Repr<Item = T>,
    SA: DimTrait,
    SB: DimTrait,
{
    gemm(a, b, T::one(), T::zero())
}

#[cfg(test)]
mod gemm {
    use crate::{
        device::Device,
        dim::DimDyn,
        matrix::{Matrix, Owned},
    };

    use super::gemm_assign;

    fn gemm_3x4_4x5_3x5<D: Device>() {
        let a = Matrix::<Owned<f32>, DimDyn, D>::from_vec(
            vec![1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.],
            [3, 4],
        );
        let b = Matrix::<_, DimDyn, D>::from_vec(
            vec![
                1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.,
                19., 20.,
            ],
            [4, 5],
        );
        let mut c = Matrix::<_, DimDyn, D>::alloc([3, 5]);
        gemm_assign(&a, &b, &c.to_ref_mut(), 1., 0.);
        let ans = vec![
            110., 120., 130., 140., 150., 246., 272., 298., 324., 350., 382., 424., 466., 508.,
            550.,
        ];
        let ans = Matrix::<_, DimDyn, D>::from_vec(ans, [3, 5]);
        let diff = (c - ans).asum();
        assert!(diff < 1e-6);
    }

    #[test]
    fn gemm_3x4_4x5_3x5_cpu() {
        gemm_3x4_4x5_3x5::<crate::device::cpu::Cpu>();
    }

    #[cfg(feature = "nvidia")]
    #[test]
    fn gemm_3x4_4x5_3x5_nvidia() {
        gemm_3x4_4x5_3x5::<crate::device::nvidia::Nvidia>();
    }
}