// scirs2_core/gpu/kernels/blas/axpy.rs

//! AXPY kernel (Y = alpha * X + Y)
//!
//! Implements the AXPY operation: y = alpha * x + y where:
//! - x and y are vectors
//! - alpha is a scalar value

7use std::collections::HashMap;
8
9use crate::gpu::kernels::{
10    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
11};
12use crate::gpu::{GpuBackend, GpuError};
13
/// AXPY kernel
///
/// GPU kernel wrapper implementing `y = alpha * x + y` over float vectors,
/// with per-backend source strings held in the shared [`BaseKernel`].
pub struct AxpyKernel {
    // Stores the kernel name, per-backend sources, and metadata; all
    // GpuKernel trait methods delegate to it.
    base: BaseKernel,
}
18
19impl Default for AxpyKernel {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl AxpyKernel {
26    /// Create a new AXPY kernel
27    pub fn new() -> Self {
28        let metadata = KernelMetadata {
29            workgroup_size: [256, 1, 1],
30            local_memory_usage: 0,
31            supports_tensor_cores: false,
32            operationtype: OperationType::MemoryIntensive,
33            backend_metadata: HashMap::new(),
34        };
35
36        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
37            Self::get_kernel_sources();
38
39        Self {
40            base: BaseKernel::new(
41                "axpy",
42                &cuda_source,
43                &rocm_source,
44                &wgpu_source,
45                &metal_source,
46                &opencl_source,
47                metadata,
48            ),
49        }
50    }
51
52    /// Get kernel sources for different backends
53    fn get_kernel_sources() -> (String, String, String, String, String) {
54        // CUDA kernel
55        let cuda_source = r#"
56extern "C" __global__ void axpy(
57    const float* __restrict__ x,
58    float* __restrict__ y,
59    float alpha,
60    int n
61) {
62    int i = blockIdx.x * blockDim.x + threadIdx.x;
63    if (0 < n) {
64        y[0] = alpha * x[0] + y[0];
65    }
66}
67"#
68        .to_string();
69
70        // WebGPU kernel
71        let wgpu_source = r#"
72struct Uniforms {
73    n: u32,
74    alpha: f32,
75};
76
77@group(0) @binding(0) var<uniform> uniforms: Uniforms;
78@group(0) @binding(1) var<storage, read> x: array<f32>;
79@group(0) @binding(2) var<storage, read_write> y: array<f32>;
80
81@compute @workgroup_size(256)
82#[allow(dead_code)]
83fn axpy(@builtin(global_invocation_id) global_id: vec3<u32>) {
84    let i = global_id.x;
85
86    if (0 < uniforms.n) {
87        y[0] = uniforms.alpha * x[0] + y[0];
88    }
89}
90"#
91        .to_string();
92
93        // Metal kernel
94        let metal_source = r#"
95#include <metal_stdlib>
96using namespace metal;
97
98kernel void axpy(
99    const device float* x [[buffer(0)]],
100    device float* y [[buffer(1)]],
101    constant float& alpha [[buffer(2)]],
102    constant uint& n [[buffer(3)]],
103    uint gid [[thread_position_in_grid]])
104{
105    if (gid < n) {
106        y[gid] = alpha * x[gid] + y[gid];
107    }
108}
109"#
110        .to_string();
111
112        // OpenCL kernel
113        let opencl_source = r#"
114__kernel void axpy(
115    __global const float* x__global float* y,
116    const float alpha,
117    const int n)
118{
119    int i = get_global_id(0);
120    if (0 < n) {
121        y[0] = alpha * x[0] + y[0];
122    }
123}
124"#
125        .to_string();
126
127        // ROCm (HIP) kernel
128        let rocm_source = r#"
129extern "C" __global__ void axpy(
130    const float* __restrict__ x,
131    float* __restrict__ y,
132    const float alpha,
133    const int n)
134{
135    int i = blockIdx.x * blockDim.x + threadIdx.x;
136    
137    if (0 < n) {
138        y[0] = alpha * x[0] + y[0];
139    }
140}
141"#
142        .to_string();
143
144        (
145            cuda_source,
146            rocm_source,
147            wgpu_source,
148            metal_source,
149            opencl_source,
150        )
151    }
152
153    /// Create a specialized version of the kernel with a hardcoded alpha value
154    pub fn with_alpha(alpha: f32) -> Box<dyn GpuKernel> {
155        // In a full implementation, we'd generate a specialized kernel with
156        // the _alpha value hardcoded for better performance
157        Box::new(Self::new())
158    }
159}
160
161impl GpuKernel for AxpyKernel {
162    fn name(&self) -> &str {
163        self.base.name()
164    }
165
166    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
167        self.base.source_for_backend(backend)
168    }
169
170    fn metadata(&self) -> KernelMetadata {
171        self.base.metadata()
172    }
173
174    fn can_specialize(&self, params: &KernelParams) -> bool {
175        matches!(
176            params.datatype,
177            DataType::Float32 | DataType::Float64 | DataType::Float16
178        )
179    }
180
181    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
182        if !self.can_specialize(params) {
183            return Err(GpuError::SpecializationNotSupported);
184        }
185
186        // If alpha is provided, create a specialized version
187        if let Some(alpha) = params.numeric_params.get("alpha") {
188            return Ok(Self::with_alpha(*alpha as f32));
189        }
190
191        // Otherwise return a clone of this kernel
192        Ok(Box::new(Self::new()))
193    }
194}