scirs2_core/gpu/kernels/blas/gemm.rs

//! GEMM (general matrix multiplication) kernels for the GPU BLAS module.

use std::collections::HashMap;
use std::fmt;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};
/// Available GEMM kernel implementation variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GemmImpl {
    /// Default tiled implementation.
    Standard,
    /// Variant tuned for large matrices (larger workgroups and tiles).
    Large,
    /// Variant tuned for small matrices (smaller workgroups and tiles).
    Small,
    /// Variant that targets tensor cores for reduced-precision data types.
    TensorCore,
}

impl fmt::Display for GemmImpl {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            GemmImpl::Standard => write!(f, "standard"),
            GemmImpl::Large => write!(f, "large"),
            GemmImpl::Small => write!(f, "small"),
            GemmImpl::TensorCore => write!(f, "tensor_core"),
        }
    }
}

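/// General matrix multiplication (GEMM) kernel: computes `C = alpha * A * B + beta * C`.
///
/// A minimal usage sketch (illustrative, not from the original source; it relies only on
/// the `KernelMetadata::workgroup_size` values set in `with_implementation` below):
///
/// ```ignore
/// let kernel = GemmKernel::with_implementation(GemmImpl::Large);
/// // The "large" variant is configured with 32x32 workgroups.
/// assert_eq!(kernel.metadata().workgroup_size, [32, 32, 1]);
/// ```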
pub struct GemmKernel {
    /// Shared kernel state (name, per-backend sources, metadata).
    base: BaseKernel,
    /// Implementation variant this kernel was built with.
    #[allow(dead_code)]
    implementation: GemmImpl,
}

impl Default for GemmKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GemmKernel {
    /// Creates a GEMM kernel using the standard implementation.
    pub fn new() -> Self {
        Self::with_implementation(GemmImpl::Standard)
    }

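    /// Creates a GEMM kernel using the given implementation variant.
    ///
    /// Illustrative sketch (not from the original source; it relies only on the
    /// `supports_tensor_cores` flag set in the metadata below):
    ///
    /// ```ignore
    /// let kernel = GemmKernel::with_implementation(GemmImpl::TensorCore);
    /// assert!(kernel.metadata().supports_tensor_cores);
    /// ```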
    pub fn with_implementation(implementation: GemmImpl) -> Self {
        let metadata = match implementation {
            GemmImpl::Standard => KernelMetadata {
                workgroup_size: [16, 16, 1],
                local_memory_usage: 8192,
                supports_tensor_cores: false,
                operationtype: OperationType::ComputeIntensive,
                backend_metadata: HashMap::new(),
            },
            GemmImpl::Large => KernelMetadata {
                workgroup_size: [32, 32, 1],
                local_memory_usage: 32768,
                supports_tensor_cores: false,
                operationtype: OperationType::ComputeIntensive,
                backend_metadata: HashMap::new(),
            },
            GemmImpl::Small => KernelMetadata {
                workgroup_size: [8, 8, 1],
                local_memory_usage: 2048,
                supports_tensor_cores: false,
                operationtype: OperationType::ComputeIntensive,
                backend_metadata: HashMap::new(),
            },
            GemmImpl::TensorCore => KernelMetadata {
                workgroup_size: [16, 16, 1],
                local_memory_usage: 8192,
                supports_tensor_cores: true,
                operationtype: OperationType::ComputeIntensive,
                backend_metadata: HashMap::new(),
            },
        };

        let (name, cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_sources_for_implementation(implementation);

        Self {
            base: BaseKernel::new(
                &name,
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
            implementation,
        }
    }

    /// Creates a boxed GEMM kernel with the standard implementation.
    ///
    /// The `alpha` and `beta` scaling factors are not baked into the kernel source here;
    /// the generated kernels take them as launch-time arguments instead.
    pub fn with_alpha_beta(_alpha: f32, _beta: f32) -> Box<dyn GpuKernel> {
        let kernel = Self::new();

        Box::new(kernel)
    }

    /// Returns the kernel name and the per-backend kernel sources for `implementation`.
    fn get_sources_for_implementation(
        implementation: GemmImpl,
    ) -> (String, String, String, String, String, String) {
        let name = format!("{implementation}");

        let cuda_source = match implementation {
            GemmImpl::Standard => r#"
extern "C" __global__ void gemm_standard(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ c,
    int m, int n, int k,
    float alpha, float beta
) {
    // Block index
    int bx = blockIdx.x;
    int by = blockIdx.y;

    // Thread index
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    // Define block size
    const int BLOCK_SIZE = 16;

    // Index of the first sub-matrix of A processed by the block
    int aBegin = k * BLOCK_SIZE * by;

    // Index of the last sub-matrix of A processed by the block
    int aEnd = aBegin + k - 1;

    // Step size used to iterate through the sub-matrices of A
    int aStep = BLOCK_SIZE;

    // Index of the first sub-matrix of B processed by the block
    int bBegin = BLOCK_SIZE * bx;

    // Step size used to iterate through the sub-matrices of B
    int bStep = BLOCK_SIZE * n;

    // The element of the block sub-matrix that is computed
    // by the thread
    float Csub = 0;

    // Loop over all the sub-matrices of A and B required to
    // compute the block sub-matrix. The loop indices are named
    // aIdx/bIdx so they do not shadow the input pointers a and b.
    for (int aIdx = aBegin, bIdx = bBegin;
         aIdx <= aEnd;
         aIdx += aStep, bIdx += bStep) {

        // Shared memory for the sub-matrix of A
        __shared__ float As[BLOCK_SIZE][BLOCK_SIZE];

        // Shared memory for the sub-matrix of B
        __shared__ float Bs[BLOCK_SIZE][BLOCK_SIZE];

        // Load the matrices from global memory to shared memory
        As[ty][tx] = a[aIdx + k * ty + tx];
        Bs[ty][tx] = b[bIdx + n * ty + tx];

        // Synchronize to make sure the matrices are loaded
        __syncthreads();

        // Multiply the two matrices together
        #pragma unroll
        for (int i = 0; i < BLOCK_SIZE; ++i) {
            Csub += As[ty][i] * Bs[i][tx];
        }

        // Synchronize to make sure that the preceding
        // computation is done before loading two new
        // sub-matrices of A and B in the next iteration
        __syncthreads();
    }

    // Write the block sub-matrix to global memory
    int c_idx = n * BLOCK_SIZE * by + BLOCK_SIZE * bx;
    int c_elem = c_idx + n * ty + tx;

    if (beta == 0) {
        c[c_elem] = alpha * Csub;
    } else {
        c[c_elem] = alpha * Csub + beta * c[c_elem];
    }
}
"#
            .to_string(),
            _ => r#"
// Placeholder for other optimized CUDA kernels
extern "C" __global__ void gemm_standard(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ c,
    int m, int n, int k,
    float alpha, float beta
) {
    // Implementation similar to standard but with optimizations
    // specific to the implementation type
}
"#
            .to_string(),
        };

        let wgpu_source = r#"
struct Uniforms {
    m: u32,
    n: u32,
    k: u32,
    alpha: f32,
    beta: f32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
// c is read back when beta != 0, so it needs read_write access.
@group(0) @binding(3) var<storage, read_write> c: array<f32>;

var<workgroup> As: array<array<f32, 16>, 16>;
var<workgroup> Bs: array<array<f32, 16>, 16>;

@compute @workgroup_size(16, 16)
fn gemm_standard(@builtin(global_invocation_id) global_id: vec3<u32>,
                 @builtin(workgroup_id) workgroup_id: vec3<u32>,
                 @builtin(local_invocation_id) local_id: vec3<u32>) {

    let bx = workgroup_id.x;
    let by = workgroup_id.y;

    let tx = local_id.x;
    let ty = local_id.y;

    let block_size = 16u;

    // Index of c
    let row = by * block_size + ty;
    let col = bx * block_size + tx;

    var sum = 0.0;

    // Loop over A and B tiles
    for (var t = 0u; t < (uniforms.k + block_size - 1u) / block_size; t = t + 1u) {
        // Load A tile
        if (row < uniforms.m && t * block_size + tx < uniforms.k) {
            As[ty][tx] = a[row * uniforms.k + t * block_size + tx];
        } else {
            As[ty][tx] = 0.0;
        }

        // Load B tile
        if (t * block_size + ty < uniforms.k && col < uniforms.n) {
            Bs[ty][tx] = b[(t * block_size + ty) * uniforms.n + col];
        } else {
            Bs[ty][tx] = 0.0;
        }

        workgroupBarrier();

        // Compute
        for (var k = 0u; k < block_size; k = k + 1u) {
            sum = sum + As[ty][k] * Bs[k][tx];
        }

        workgroupBarrier();
    }

    // Write result
    if (row < uniforms.m && col < uniforms.n) {
        let c_idx = row * uniforms.n + col;
        if (uniforms.beta == 0.0) {
            c[c_idx] = uniforms.alpha * sum;
        } else {
            c[c_idx] = uniforms.alpha * sum + uniforms.beta * c[c_idx];
        }
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void gemm_standard(
    const device float* a [[buffer(0)]],
    const device float* b [[buffer(1)]],
    device float* c [[buffer(2)]],
    constant uint& m [[buffer(3)]],
    constant uint& n [[buffer(4)]],
    constant uint& k [[buffer(5)]],
    constant float& alpha [[buffer(6)]],
    constant float& beta [[buffer(7)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 lid [[thread_position_in_threadgroup]],
    uint2 wgid [[threadgroup_position_in_grid]])
{
    const uint block_size = 16;

    // Thread indices
    uint tx = lid.x;
    uint ty = lid.y;

    // Block indices
    uint bx = wgid.x;
    uint by = wgid.y;

    // Global indices
    uint row = by * block_size + ty;
    uint col = bx * block_size + tx;

    // Shared memory for tile
    threadgroup float As[16][16];
    threadgroup float Bs[16][16];

    float sum = 0.0;

    // Loop over tiles
    for (uint t = 0; t < (k + block_size - 1) / block_size; t++) {
        // Load tiles
        if (row < m && t * block_size + tx < k) {
            As[ty][tx] = a[row * k + t * block_size + tx];
        } else {
            As[ty][tx] = 0.0;
        }

        if (t * block_size + ty < k && col < n) {
            Bs[ty][tx] = b[(t * block_size + ty) * n + col];
        } else {
            Bs[ty][tx] = 0.0;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Compute
        for (uint i = 0; i < block_size; i++) {
            sum += As[ty][i] * Bs[i][tx];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write result
    if (row < m && col < n) {
        uint c_idx = row * n + col;
        if (beta == 0.0) {
            c[c_idx] = alpha * sum;
        } else {
            c[c_idx] = alpha * sum + beta * c[c_idx];
        }
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void gemm_standard(
    __global const float* a,
    __global const float* b,
    __global float* c,
    const int m,
    const int n,
    const int k,
    const float alpha,
    const float beta)
{
    const int block_size = 16;

    // Thread indices
    const int tx = get_local_id(0);
    const int ty = get_local_id(1);

    // Block indices
    const int bx = get_group_id(0);
    const int by = get_group_id(1);

    // Global indices
    const int row = by * block_size + ty;
    const int col = bx * block_size + tx;

    // Shared memory for tile
    __local float As[16][16];
    __local float Bs[16][16];

    float sum = 0.0f;

    // Loop over tiles
    for (int t = 0; t < (k + block_size - 1) / block_size; t++) {
        // Load tiles
        if (row < m && t * block_size + tx < k) {
            As[ty][tx] = a[row * k + t * block_size + tx];
        } else {
            As[ty][tx] = 0.0f;
        }

        if (t * block_size + ty < k && col < n) {
            Bs[ty][tx] = b[(t * block_size + ty) * n + col];
        } else {
            Bs[ty][tx] = 0.0f;
        }

        barrier(CLK_LOCAL_MEM_FENCE);

        // Compute
        for (int i = 0; i < block_size; i++) {
            sum += As[ty][i] * Bs[i][tx];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write result
    if (row < m && col < n) {
        const int c_idx = row * n + col;
        if (beta == 0.0f) {
            c[c_idx] = alpha * sum;
        } else {
            c[c_idx] = alpha * sum + beta * c[c_idx];
        }
    }
}
"#
        .to_string();

        // ROCm (HIP) accepts the same CUDA-style kernel source.
        let rocm_source = cuda_source.clone();

        (
            name,
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }

    /// Selects an implementation variant from the data type and problem size,
    /// then builds the corresponding kernel.
    fn generate_kernel(
        datatype: DataType,
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<GemmKernel, GpuError> {
        let implementation = if datatype == DataType::Float16 || datatype == DataType::BFloat16 {
            // Reduced-precision inputs can use tensor cores.
            GemmImpl::TensorCore
        } else if m >= 1024 && n >= 1024 && k >= 1024 {
            GemmImpl::Large
        } else if m <= 128 && n <= 128 && k <= 128 {
            GemmImpl::Small
        } else {
            GemmImpl::Standard
        };

        Ok(GemmKernel::with_implementation(implementation))
    }
}

impl GpuKernel for GemmKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        match params.datatype {
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16 => {
                params.input_dims.len() >= 2 && params.output_dims.len() >= 2
            }
            _ => false,
        }
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // Infer GEMM dimensions: m and k come from the A input, n from the output C.
        let m = params.input_dims.first().copied().unwrap_or(0);
        let k = params.input_dims.get(1).copied().unwrap_or(0);
        let n = params.output_dims.get(1).copied().unwrap_or(0);

        let specialized = Self::generate_kernel(params.datatype, m, n, k)?;

        Ok(Box::new(specialized))
    }
}
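
// A minimal test sketch (added for illustration; it uses only items defined in this file)
// covering the implementation-selection thresholds encoded in `generate_kernel` above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn selects_implementation_from_dtype_and_shape() {
        // Reduced precision routes to the tensor-core variant.
        let tc = GemmKernel::generate_kernel(DataType::Float16, 256, 256, 256);
        assert!(matches!(tc.map(|kernel| kernel.implementation), Ok(GemmImpl::TensorCore)));

        // Large f32 problems use the large-tile variant.
        let large = GemmKernel::generate_kernel(DataType::Float32, 2048, 2048, 2048);
        assert!(matches!(large.map(|kernel| kernel.implementation), Ok(GemmImpl::Large)));

        // Small f32 problems use the small-tile variant.
        let small = GemmKernel::generate_kernel(DataType::Float32, 64, 64, 64);
        assert!(matches!(small.map(|kernel| kernel.implementation), Ok(GemmImpl::Small)));
    }
}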