// scirs2_core/gpu/kernels/blas/axpy.rs
use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};
14pub struct AxpyKernel {
16 base: BaseKernel,
17}
18
19impl Default for AxpyKernel {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl AxpyKernel {
26 pub fn new() -> Self {
28 let metadata = KernelMetadata {
29 workgroup_size: [256, 1, 1],
30 local_memory_usage: 0,
31 supports_tensor_cores: false,
32 operationtype: OperationType::MemoryIntensive,
33 backend_metadata: HashMap::new(),
34 };
35
36 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
37 Self::get_kernel_sources();
38
39 Self {
40 base: BaseKernel::new(
41 "axpy",
42 &cuda_source,
43 &rocm_source,
44 &wgpu_source,
45 &metal_source,
46 &opencl_source,
47 metadata,
48 ),
49 }
50 }
51
52 fn get_kernel_sources() -> (String, String, String, String, String) {
54 let cuda_source = r#"
56extern "C" __global__ void axpy(
57 const float* __restrict__ x,
58 float* __restrict__ y,
59 float alpha,
60 int n
61) {
62 int i = blockIdx.x * blockDim.x + threadIdx.x;
63 if (0 < n) {
64 y[0] = alpha * x[0] + y[0];
65 }
66}
67"#
68 .to_string();
69
70 let wgpu_source = r#"
72struct Uniforms {
73 n: u32,
74 alpha: f32,
75};
76
77@group(0) @binding(0) var<uniform> uniforms: Uniforms;
78@group(0) @binding(1) var<storage, read> x: array<f32>;
79@group(0) @binding(2) var<storage, read_write> y: array<f32>;
80
81@compute @workgroup_size(256)
82#[allow(dead_code)]
83fn axpy(@builtin(global_invocation_id) global_id: vec3<u32>) {
84 let i = global_id.x;
85
86 if (0 < uniforms.n) {
87 y[0] = uniforms.alpha * x[0] + y[0];
88 }
89}
90"#
91 .to_string();
92
93 let metal_source = r#"
95#include <metal_stdlib>
96using namespace metal;
97
98kernel void axpy(
99 const device float* x [[buffer(0)]],
100 device float* y [[buffer(1)]],
101 constant float& alpha [[buffer(2)]],
102 constant uint& n [[buffer(3)]],
103 uint gid [[thread_position_in_grid]])
104{
105 if (gid < n) {
106 y[gid] = alpha * x[gid] + y[gid];
107 }
108}
109"#
110 .to_string();
111
112 let opencl_source = r#"
114__kernel void axpy(
115 __global const float* x__global float* y,
116 const float alpha,
117 const int n)
118{
119 int i = get_global_id(0);
120 if (0 < n) {
121 y[0] = alpha * x[0] + y[0];
122 }
123}
124"#
125 .to_string();
126
127 let rocm_source = r#"
129extern "C" __global__ void axpy(
130 const float* __restrict__ x,
131 float* __restrict__ y,
132 const float alpha,
133 const int n)
134{
135 int i = blockIdx.x * blockDim.x + threadIdx.x;
136
137 if (0 < n) {
138 y[0] = alpha * x[0] + y[0];
139 }
140}
141"#
142 .to_string();
143
144 (
145 cuda_source,
146 rocm_source,
147 wgpu_source,
148 metal_source,
149 opencl_source,
150 )
151 }
152
153 pub fn with_alpha(alpha: f32) -> Box<dyn GpuKernel> {
155 Box::new(Self::new())
158 }
159}
160
161impl GpuKernel for AxpyKernel {
162 fn name(&self) -> &str {
163 self.base.name()
164 }
165
166 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
167 self.base.source_for_backend(backend)
168 }
169
170 fn metadata(&self) -> KernelMetadata {
171 self.base.metadata()
172 }
173
174 fn can_specialize(&self, params: &KernelParams) -> bool {
175 matches!(
176 params.datatype,
177 DataType::Float32 | DataType::Float64 | DataType::Float16
178 )
179 }
180
181 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
182 if !self.can_specialize(params) {
183 return Err(GpuError::SpecializationNotSupported);
184 }
185
186 if let Some(alpha) = params.numeric_params.get("alpha") {
188 return Ok(Self::with_alpha(*alpha as f32));
189 }
190
191 Ok(Box::new(Self::new()))
193 }
194}