// scirs2_core/gpu/kernels/reduction/sum.rs

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// Sum reduction kernel: each launch writes one partial sum per workgroup.
pub struct SumKernel {
    base: BaseKernel,
}

impl SumKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 1024,
            supports_tensor_cores: false,
            operationtype: OperationType::Balanced,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "sum_reduce",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

    fn get_kernel_sources() -> (String, String, String, String, String) {
        let cuda_source = r#"
extern "C" __global__ void sum_reduce(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    __shared__ float sdata[256];

    // Each block loads data into shared memory
    unsigned int tid = threadIdx.x;
    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Initialize with identity value
    sdata[tid] = 0.0f;

    // Load the first element handled by this thread
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add the second element, one blockDim.x stride away
    if (i + blockDim.x < n) {
        sdata[tid] += input[i + blockDim.x];
    }

    __syncthreads();

    // Tree reduction within the block
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Write result for this block to output
    if (tid == 0) {
        output[blockIdx.x] = sdata[0];
    }
}
"#
        .to_string();

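        // Note: in every backend variant below, each thread loads up to two
        // input elements, the block reduces them in shared/workgroup memory,
        // and thread 0 writes a single partial sum per block. Producing the
        // final scalar therefore needs a second pass: either another reduction
        // launch over the partial sums, or a host-side fold of the output buffer.
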
        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn sum_reduce(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let tid = local_id.x;
    let i = workgroup_id.x * 256u * 2u + local_id.x;

    // Initialize with identity value
    sdata[tid] = 0.0;

    // Load the first element handled by this invocation
    if (i < uniforms.n) {
        sdata[tid] = input[i];
    }

    // Add the second element, one workgroup-size stride away
    if (i + 256u < uniforms.n) {
        sdata[tid] = sdata[tid] + input[i + 256u];
    }

    workgroupBarrier();

    // Tree reduction in workgroup memory
    var s = 256u / 2u;
    while (s > 0u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }

        workgroupBarrier();
        s = s / 2u;
    }

    // Write result for this workgroup
    if (tid == 0u) {
        output[workgroup_id.x] = sdata[0];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void sum_reduce(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint tid = local_id;
    uint i = group_id * 256 * 2 + local_id;

    // Initialize with identity value
    sdata[tid] = 0.0f;

    // Load the first element handled by this thread
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add the second element, one threadgroup-size stride away
    if (i + 256 < n) {
        sdata[tid] += input[i + 256];
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Tree reduction in threadgroup memory
    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write result for this threadgroup
    if (tid == 0) {
        output[group_id] = sdata[0];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void sum_reduce(
    __global const float* input,
    __global float* output,
    const int n)
{
    __local float sdata[256];

    unsigned int tid = get_local_id(0);
    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    // Initialize with identity value
    sdata[tid] = 0.0f;

    // Load the first element handled by this work-item
    if (i < n) {
        sdata[tid] = input[i];
    }

    // Add the second element, one work-group-size stride away
    if (i + get_local_size(0) < n) {
        sdata[tid] += input[i + get_local_size(0)];
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Tree reduction in local memory
    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }

        barrier(CLK_LOCAL_MEM_FENCE);
    }

    // Write result for this work-group
    if (tid == 0) {
        output[get_group_id(0)] = sdata[0];
    }
}
"#
        .to_string();

        // ROCm (HIP) compiles the CUDA source unchanged.
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}
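
// For reference, a minimal CPU model of what one launch of `sum_reduce` computes
// with the fixed workgroup size of 256 used above: each block of
// `2 * workgroup_size` consecutive input elements is folded into a single
// partial sum. This is a hypothetical documentation-only helper, not part of the
// crate's API and not used by the GPU path.
#[allow(dead_code)]
fn sum_reduce_cpu_model(input: &[f32], workgroup_size: usize) -> Vec<f32> {
    input
        .chunks(2 * workgroup_size)
        .map(|chunk| chunk.iter().sum())
        .collect()
}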

impl Default for SumKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GpuKernel for SumKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32 | DataType::UInt32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
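
// A minimal sketch of a unit test for the kernel wrapper, exercising only what
// is visible in this file. `GpuBackend::Cuda` is an assumed variant name and may
// need adjusting to the actual enum defined in `crate::gpu`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sum_kernel_exposes_name_and_cuda_source() {
        let kernel = SumKernel::default();
        assert_eq!(kernel.name(), "sum_reduce");

        // The CUDA source should be non-empty and contain the kernel entry point.
        let source = kernel
            .source_for_backend(GpuBackend::Cuda)
            .expect("CUDA source should be available");
        assert!(source.contains("sum_reduce"));
    }
}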