// scirs2_core/gpu/kernels/blas/gemv.rs
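//! GEMV (general matrix-vector multiply) kernels: `result = alpha * matrix * vector + beta * result`.
//!
//! Provides per-backend kernel source (CUDA, ROCm, WGSL, Metal, OpenCL) for a
//! single-matrix GEMV kernel and a batched variant.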
use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

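/// Single-matrix GEMV kernel for a row-major `M x N` `f32` matrix, computing
/// `result = alpha * matrix * vector + beta * result`.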
pub struct GemvKernel {
    base: BaseKernel,
}

impl Default for GemvKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GemvKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void gemv(
    const float* __restrict__ matrix,   // M x N matrix (row-major)
    const float* __restrict__ vector,   // N-dimensional vector
    float* __restrict__ result,         // M-dimensional result vector
    float alpha,
    float beta,
    int M,                              // Number of rows
    int N                               // Number of columns
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M) {
        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[row] = alpha * sum + beta * result[row];
    }
}

// Optimized version using shared memory for larger matrices.
// NOTE: the launch must provide N * sizeof(float) bytes of dynamic shared memory.
extern "C" __global__ void gemv_shared(
    const float* __restrict__ matrix,
    const float* __restrict__ vector,
    float* __restrict__ result,
    float alpha,
    float beta,
    int M,
    int N
) {
    extern __shared__ float shared_vector[];

    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x;

    // Load vector into shared memory in chunks
    for (int i = tid; i < N; i += blockDim.x) {
        shared_vector[i] = vector[i];
    }
    __syncthreads();

    if (row < M) {
        float sum = 0.0f;

        // Compute dot product using shared memory vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * shared_vector[col];
        }

        result[row] = alpha * sum + beta * result[row];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    alpha: f32,
    beta: f32,
    M: u32,  // Number of rows
    N: u32,  // Number of columns
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> matrix: array<f32>;        // M x N matrix
@group(0) @binding(2) var<storage, read> vector: array<f32>;        // N-dimensional vector
// result must be read_write: the beta term reads the existing value.
@group(0) @binding(3) var<storage, read_write> result: array<f32>;  // M-dimensional result

@compute @workgroup_size(256)
fn gemv(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;

    if (row < uniforms.M) {
        var sum = 0.0;

        // Compute dot product of matrix row with vector
        for (var col = 0u; col < uniforms.N; col = col + 1u) {
            let matrix_idx = row * uniforms.N + col;
            sum = sum + matrix[matrix_idx] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[row] = uniforms.alpha * sum + uniforms.beta * result[row];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void gemv(
    const device float* matrix [[buffer(0)]],   // M x N matrix
    const device float* vector [[buffer(1)]],   // N-dimensional vector
    device float* result [[buffer(2)]],         // M-dimensional result
    constant float& alpha [[buffer(3)]],
    constant float& beta [[buffer(4)]],
    constant uint& M [[buffer(5)]],             // Number of rows
    constant uint& N [[buffer(6)]],             // Number of columns
    uint gid [[thread_position_in_grid]])
{
    if (gid < M) {
        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (uint col = 0; col < N; col++) {
            sum += matrix[gid * N + col] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[gid] = alpha * sum + beta * result[gid];
    }
}

// Optimized version using threadgroup memory.
// NOTE: the statically sized threadgroup buffer limits this kernel to N <= 256;
// use the plain gemv kernel for wider matrices.
kernel void gemv_tiled(
    const device float* matrix [[buffer(0)]],
    const device float* vector [[buffer(1)]],
    device float* result [[buffer(2)]],
    constant float& alpha [[buffer(3)]],
    constant float& beta [[buffer(4)]],
    constant uint& M [[buffer(5)]],
    constant uint& N [[buffer(6)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint blockSize [[threads_per_threadgroup]])
{
    threadgroup float shared_vector[256];  // Shared vector storage

    // Load vector into threadgroup memory
    for (uint i = lid; i < N; i += blockSize) {
        shared_vector[i] = vector[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (gid < M) {
        float sum = 0.0f;

        // Compute using shared vector
        for (uint col = 0; col < N; col++) {
            sum += matrix[gid * N + col] * shared_vector[col];
        }

        result[gid] = alpha * sum + beta * result[gid];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void gemv(
    __global const float* matrix,   // M x N matrix
    __global const float* vector,   // N-dimensional vector
    __global float* result,         // M-dimensional result
    const float alpha,
    const float beta,
    const int M,                    // Number of rows
    const int N)                    // Number of columns
{
    int row = get_global_id(0);

    if (row < M) {
        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[row] = alpha * sum + beta * result[row];
    }
}

// Version with local memory optimization
__kernel void gemv_local(
    __global const float* matrix,
    __global const float* vector,
    __global float* result,
    const float alpha,
    const float beta,
    const int M,
    const int N,
    __local float* local_vector)
{
    int row = get_global_id(0);
    int lid = get_local_id(0);
    int local_size = get_local_size(0);

    // Load vector into local memory
    for (int i = lid; i < N; i += local_size) {
        local_vector[i] = vector[i];
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    if (row < M) {
        float sum = 0.0f;

        // Compute using local vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * local_vector[col];
        }

        result[row] = alpha * sum + beta * result[row];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "gemv",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for GemvKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No parameter-specific variants yet; return a fresh copy of the generic kernel.
        Ok(Box::new(Self::new()))
    }
}

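/// Batched GEMV kernel: applies the same `alpha`/`beta` GEMV update independently
/// to `batch_size` row-major `M x N` matrices and their `N`-dimensional vectors.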
pub struct BatchGemvKernel {
    base: BaseKernel,
}

impl Default for BatchGemvKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchGemvKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [16, 16, 1],
            local_memory_usage: 2048,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void batch_gemv(
    const float* __restrict__ matrices,   // Batch of M x N matrices
    const float* __restrict__ vectors,    // Batch of N-dimensional vectors
    float* __restrict__ results,          // Batch of M-dimensional results
    float alpha,
    float beta,
    int batch_size,
    int M,                                // Number of rows per matrix
    int N                                 // Number of columns per matrix
) {
    int batch_idx = blockIdx.z;
    int row = blockIdx.x * blockDim.x + threadIdx.x;

    if (batch_idx < batch_size && row < M) {
        // Calculate offsets for this batch
        int matrix_offset = batch_idx * M * N;
        int vector_offset = batch_idx * N;
        int result_offset = batch_idx * M;

        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (int col = 0; col < N; col++) {
            sum += matrices[matrix_offset + row * N + col] *
                   vectors[vector_offset + col];
        }

        // Apply alpha and beta coefficients
        results[result_offset + row] = alpha * sum + beta * results[result_offset + row];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    alpha: f32,
    beta: f32,
    batch_size: u32,
    M: u32,  // Number of rows per matrix
    N: u32,  // Number of columns per matrix
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> matrices: array<f32>;       // Batch of matrices
@group(0) @binding(2) var<storage, read> vectors: array<f32>;        // Batch of vectors
// results must be read_write: the beta term reads the existing value.
@group(0) @binding(3) var<storage, read_write> results: array<f32>;  // Batch of results

@compute @workgroup_size(16, 16, 1)
fn batch_gemv(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let batch_idx = global_id.z;
    let row = global_id.x;

    if (batch_idx < uniforms.batch_size && row < uniforms.M) {
        // Calculate offsets for this batch
        let matrix_offset = batch_idx * uniforms.M * uniforms.N;
        let vector_offset = batch_idx * uniforms.N;
        let result_offset = batch_idx * uniforms.M;

        var sum = 0.0;

        // Compute dot product
        for (var col = 0u; col < uniforms.N; col = col + 1u) {
            let matrix_idx = matrix_offset + row * uniforms.N + col;
            let vector_idx = vector_offset + col;
            sum = sum + matrices[matrix_idx] * vectors[vector_idx];
        }

        // Apply coefficients
        let result_idx = result_offset + row;
        results[result_idx] = uniforms.alpha * sum + uniforms.beta * results[result_idx];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void batch_gemv(
    const device float* matrices [[buffer(0)]],   // Batch of matrices
    const device float* vectors [[buffer(1)]],    // Batch of vectors
    device float* results [[buffer(2)]],          // Batch of results
    constant float& alpha [[buffer(3)]],
    constant float& beta [[buffer(4)]],
    constant uint& batch_size [[buffer(5)]],
    constant uint& M [[buffer(6)]],               // Rows per matrix
    constant uint& N [[buffer(7)]],               // Columns per matrix
    uint3 gid [[thread_position_in_grid]])
{
    uint batch_idx = gid.z;
    uint row = gid.x;

    if (batch_idx < batch_size && row < M) {
        // Calculate offsets
        uint matrix_offset = batch_idx * M * N;
        uint vector_offset = batch_idx * N;
        uint result_offset = batch_idx * M;

        float sum = 0.0f;

        // Compute dot product
        for (uint col = 0; col < N; col++) {
            sum += matrices[matrix_offset + row * N + col] *
                   vectors[vector_offset + col];
        }

        // Apply coefficients
        results[result_offset + row] = alpha * sum + beta * results[result_offset + row];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void batch_gemv(
    __global const float* matrices,
    __global const float* vectors,
    __global float* results,
    const float alpha,
    const float beta,
    const int batch_size,
    const int M,
    const int N)
{
    int batch_idx = get_global_id(2);
    int row = get_global_id(0);

    if (batch_idx < batch_size && row < M) {
        // Calculate offsets
        int matrix_offset = batch_idx * M * N;
        int vector_offset = batch_idx * N;
        int result_offset = batch_idx * M;

        float sum = 0.0f;

        // Compute dot product
        for (int col = 0; col < N; col++) {
            sum += matrices[matrix_offset + row * N + col] *
                   vectors[vector_offset + col];
        }

        // Apply coefficients
        results[result_offset + row] = alpha * sum + beta * results[result_offset + row];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "batch_gemv",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for BatchGemvKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No parameter-specific variants yet; return a fresh copy of the generic kernel.
        Ok(Box::new(Self::new()))
    }
}
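
// A minimal smoke-test sketch, assuming only the constructors and trait methods
// used in this file (GemvKernel::new, BatchGemvKernel::new, GpuKernel::name,
// GpuKernel::metadata) and that `workgroup_size` is a plain integer array as
// constructed above; it is illustrative, not part of any existing test harness.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gemv_kernel_reports_name_and_workgroup_size() {
        let kernel = GemvKernel::new();
        assert_eq!(kernel.name(), "gemv");
        assert_eq!(kernel.metadata().workgroup_size, [256, 1, 1]);
    }

    #[test]
    fn batch_gemv_kernel_reports_name_and_workgroup_size() {
        let kernel = BatchGemvKernel::new();
        assert_eq!(kernel.name(), "batch_gemv");
        assert_eq!(kernel.metadata().workgroup_size, [16, 16, 1]);
    }
}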