scirs2_core/gpu/kernels/blas/gemv.rs

//! General matrix-vector multiplication (GEMV) kernels for GPU
//!
//! Implements y = alpha * A * x + beta * y where:
//! - A is an M x N matrix
//! - x is an N-dimensional vector
//! - y is an M-dimensional vector
//! - alpha and beta are scalar values
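//!
//! For reference, the GPU kernels below compute the same result as this plain
//! CPU sketch (illustrative only, not part of the GPU path; `a` is assumed to
//! be a row-major slice of `m * n` elements):
//!
//! ```ignore
//! fn gemv_reference(a: &[f32], x: &[f32], y: &mut [f32], alpha: f32, beta: f32, m: usize, n: usize) {
//!     for row in 0..m {
//!         let mut sum = 0.0f32;
//!         for col in 0..n {
//!             sum += a[row * n + col] * x[col];
//!         }
//!         y[row] = alpha * sum + beta * y[row];
//!     }
//! }
//! ```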

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// General matrix-vector multiplication kernel
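///
/// A minimal usage sketch (hedged: the `GpuBackend::Cuda` variant is assumed
/// here for illustration; use whichever backend variant the crate exposes):
///
/// ```ignore
/// let kernel = GemvKernel::new();
/// let source = kernel.source_for_backend(GpuBackend::Cuda)?;
/// // `source` holds the CUDA C kernel string registered under the name "gemv".
/// ```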
pub struct GemvKernel {
    base: BaseKernel,
}

impl Default for GemvKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GemvKernel {
    /// Create a new GEMV kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024, // 1 KB of local memory for the shared-vector variants
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

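        // Both CUDA variants below assign one thread per output row; a typical
        // dispatch (done by the host-side launcher, which is not part of this
        // file) would use ceil(M / 256) blocks of 256 threads. The `gemv_shared`
        // variant additionally expects N * 4 bytes of dynamic shared memory to
        // stage the input vector.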
        let cuda_source = r#"
extern "C" __global__ void gemv(
    const float* __restrict__ matrix,  // M x N matrix (row-major)
    const float* __restrict__ vector,  // N-dimensional vector
    float* __restrict__ result,        // M-dimensional result vector
    float alpha,
    float beta,
    int M,  // Number of rows
    int N   // Number of columns
) {
    int row = blockIdx.x * blockDim.x + threadIdx.x;

    if (row < M) {
        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[row] = alpha * sum + beta * result[row];
    }
}

// Optimized version using shared memory for larger matrices
extern "C" __global__ void gemv_shared(
    const float* __restrict__ matrix,
    const float* __restrict__ vector,
    float* __restrict__ result,
    float alpha,
    float beta,
    int M,
    int N
) {
    extern __shared__ float shared_vector[];

    int row = blockIdx.x * blockDim.x + threadIdx.x;
    int tid = threadIdx.x;

    // Load vector into shared memory in chunks
    for (int i = tid; i < N; i += blockDim.x) {
        if (i < N) {
            shared_vector[i] = vector[i];
        }
    }
    __syncthreads();

    if (row < M) {
        float sum = 0.0f;

        // Compute dot product using shared memory vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * shared_vector[col];
        }

        result[row] = alpha * sum + beta * result[row];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    alpha: f32,
    beta: f32,
    M: u32,  // Number of rows
    N: u32,  // Number of columns
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> matrix: array<f32>;   // M x N matrix
@group(0) @binding(2) var<storage, read> vector: array<f32>;   // N-dimensional vector
@group(0) @binding(3) var<storage, read_write> result: array<f32>;  // M-dimensional result (read for the beta term)

@compute @workgroup_size(256)
fn gemv(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row = global_id.x;

    if (row < uniforms.M) {
        var sum = 0.0;

        // Compute dot product of matrix row with vector
        for (var col = 0u; col < uniforms.N; col = col + 1u) {
            let matrix_idx = row * uniforms.N + col;
            sum = sum + matrix[matrix_idx] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[row] = uniforms.alpha * sum + uniforms.beta * result[row];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void gemv(
    const device float* matrix [[buffer(0)]],    // M x N matrix
    const device float* vector [[buffer(1)]],    // N-dimensional vector
    device float* result [[buffer(2)]],          // M-dimensional result
    constant float& alpha [[buffer(3)]],
    constant float& beta [[buffer(4)]],
    constant uint& M [[buffer(5)]],              // Number of rows
    constant uint& N [[buffer(6)]],              // Number of columns
    uint gid [[thread_position_in_grid]])
{
    if (gid < M) {
        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (uint col = 0; col < N; col++) {
            sum += matrix[gid * N + col] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[gid] = alpha * sum + beta * result[gid];
    }
}

// Optimized version using threadgroup memory
kernel void gemv_tiled(
    const device float* matrix [[buffer(0)]],
    const device float* vector [[buffer(1)]],
    device float* result [[buffer(2)]],
    constant float& alpha [[buffer(3)]],
    constant float& beta [[buffer(4)]],
    constant uint& M [[buffer(5)]],
    constant uint& N [[buffer(6)]],
    uint gid [[thread_position_in_grid]],
    uint lid [[thread_position_in_threadgroup]],
    uint blockSize [[threads_per_threadgroup]])
{
    threadgroup float shared_vector[256];  // Shared vector storage
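    // NOTE: this tiled variant assumes N <= 256 so the whole vector fits in the
    // fixed-size threadgroup array above; fall back to the plain `gemv` kernel
    // for larger N.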

    // Load vector into threadgroup memory
    for (uint i = lid; i < N; i += blockSize) {
        if (i < N) {
            shared_vector[i] = vector[i];
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (gid < M) {
        float sum = 0.0f;

        // Compute using shared vector
        for (uint col = 0; col < N; col++) {
            sum += matrix[gid * N + col] * shared_vector[col];
        }

        result[gid] = alpha * sum + beta * result[gid];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void gemv(
    __global const float* matrix,   // M x N matrix
    __global const float* vector,   // N-dimensional vector
    __global float* result,         // M-dimensional result
    const float alpha,
    const float beta,
    const int M,                    // Number of rows
    const int N)                    // Number of columns
{
    int row = get_global_id(0);

    if (row < M) {
        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * vector[col];
        }

        // Apply alpha and beta coefficients
        result[row] = alpha * sum + beta * result[row];
    }
}

// Version with local memory optimization
__kernel void gemv_local(
    __global const float* matrix,
    __global const float* vector,
    __global float* result,
    const float alpha,
    const float beta,
    const int M,
    const int N,
    __local float* local_vector)
{
    int row = get_global_id(0);
    int lid = get_local_id(0);
    int local_size = get_local_size(0);

    // Load vector into local memory
    for (int i = lid; i < N; i += local_size) {
        if (i < N) {
            local_vector[i] = vector[i];
        }
    }
    barrier(CLK_LOCAL_MEM_FENCE);

    if (row < M) {
        float sum = 0.0f;

        // Compute using local vector
        for (int col = 0; col < N; col++) {
            sum += matrix[row * N + col] * local_vector[col];
        }

        result[row] = alpha * sum + beta * result[row];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "gemv",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for GemvKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // For now, return the same kernel (no type specialization implemented yet)
        Ok(Box::new(Self::new()))
    }
}

/// Batched GEMV kernel for processing multiple matrix-vector multiplications
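///
/// Inputs are assumed to be packed contiguously: matrix `b` starts at element
/// offset `b * M * N`, vector `b` at `b * N`, and result `b` at `b * M`,
/// matching the index arithmetic in the kernel sources below.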
pub struct BatchGemvKernel {
    base: BaseKernel,
}

impl Default for BatchGemvKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl BatchGemvKernel {
    /// Create a new batched GEMV kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [16, 16, 1],
            local_memory_usage: 2048,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

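        // Each thread computes one (row, batch) output element: the batch index
        // comes from the z dimension of the grid and the row from the x
        // dimension. Grid sizing is left to the host-side dispatcher, which is
        // not part of this file.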
        let cuda_source = r#"
extern "C" __global__ void batch_gemv(
    const float* __restrict__ matrices,  // Batch of M x N matrices
    const float* __restrict__ vectors,   // Batch of N-dimensional vectors
    float* __restrict__ results,         // Batch of M-dimensional results
    float alpha,
    float beta,
    int batch_size,
    int M,  // Number of rows per matrix
    int N   // Number of columns per matrix
) {
    int batch_idx = blockIdx.z;
    int row = blockIdx.x * blockDim.x + threadIdx.x;

    if (batch_idx < batch_size && row < M) {
        // Calculate offsets for this batch
        int matrix_offset = batch_idx * M * N;
        int vector_offset = batch_idx * N;
        int result_offset = batch_idx * M;

        float sum = 0.0f;

        // Compute dot product of matrix row with vector
        for (int col = 0; col < N; col++) {
            sum += matrices[matrix_offset + row * N + col] *
                   vectors[vector_offset + col];
        }

        // Apply alpha and beta coefficients
        results[result_offset + row] = alpha * sum + beta * results[result_offset + row];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    alpha: f32,
    beta: f32,
    batch_size: u32,
    M: u32,  // Number of rows per matrix
    N: u32,  // Number of columns per matrix
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> matrices: array<f32>;  // Batch of matrices
@group(0) @binding(2) var<storage, read> vectors: array<f32>;   // Batch of vectors
@group(0) @binding(3) var<storage, read_write> results: array<f32>;  // Batch of results (read for the beta term)

@compute @workgroup_size(16, 16, 1)
fn batch_gemv(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let batch_idx = global_id.z;
    let row = global_id.x;

    if (batch_idx < uniforms.batch_size && row < uniforms.M) {
        // Calculate offsets for this batch
        let matrix_offset = batch_idx * uniforms.M * uniforms.N;
        let vector_offset = batch_idx * uniforms.N;
        let result_offset = batch_idx * uniforms.M;

        var sum = 0.0;

        // Compute dot product
        for (var col = 0u; col < uniforms.N; col = col + 1u) {
            let matrix_idx = matrix_offset + row * uniforms.N + col;
            let vector_idx = vector_offset + col;
            sum = sum + matrices[matrix_idx] * vectors[vector_idx];
        }

        // Apply coefficients
        let result_idx = result_offset + row;
        results[result_idx] = uniforms.alpha * sum + uniforms.beta * results[result_idx];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void batch_gemv(
    const device float* matrices [[buffer(0)]],   // Batch of matrices
    const device float* vectors [[buffer(1)]],    // Batch of vectors
    device float* results [[buffer(2)]],          // Batch of results
    constant float& alpha [[buffer(3)]],
    constant float& beta [[buffer(4)]],
    constant uint& batch_size [[buffer(5)]],
    constant uint& M [[buffer(6)]],               // Rows per matrix
    constant uint& N [[buffer(7)]],               // Columns per matrix
    uint3 gid [[thread_position_in_grid]])
{
    uint batch_idx = gid.z;
    uint row = gid.x;

    if (batch_idx < batch_size && row < M) {
        // Calculate offsets
        uint matrix_offset = batch_idx * M * N;
        uint vector_offset = batch_idx * N;
        uint result_offset = batch_idx * M;

        float sum = 0.0f;

        // Compute dot product
        for (uint col = 0; col < N; col++) {
            sum += matrices[matrix_offset + row * N + col] *
                   vectors[vector_offset + col];
        }

        // Apply coefficients
        results[result_offset + row] = alpha * sum + beta * results[result_offset + row];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void batch_gemv(
    __global const float* matrices,
    __global const float* vectors,
    __global float* results,
    const float alpha,
    const float beta,
    const int batch_size,
    const int M,
    const int N)
{
    int batch_idx = get_global_id(2);
    int row = get_global_id(0);

    if (batch_idx < batch_size && row < M) {
        // Calculate offsets
        int matrix_offset = batch_idx * M * N;
        int vector_offset = batch_idx * N;
        int result_offset = batch_idx * M;

        float sum = 0.0f;

        // Compute dot product
        for (int col = 0; col < N; col++) {
            sum += matrices[matrix_offset + row * N + col] *
                   vectors[vector_offset + col];
        }

        // Apply coefficients
        results[result_offset + row] = alpha * sum + beta * results[result_offset + row];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "batch_gemv",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for BatchGemvKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
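
// A small smoke-test sketch added for illustration; it only exercises metadata
// defined in this file and assumes the crate's usual `cargo test` setup.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn gemv_kernels_expose_expected_names_and_workgroup_sizes() {
        let gemv = GemvKernel::new();
        assert_eq!(gemv.name(), "gemv");
        assert_eq!(gemv.metadata().workgroup_size, [256, 1, 1]);

        let batch = BatchGemvKernel::new();
        assert_eq!(batch.name(), "batch_gemv");
        assert_eq!(batch.metadata().workgroup_size, [16, 16, 1]);
    }
}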