scirs2_core/gpu/kernels/ml/
pooling.rs1use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8 BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
/// 2D max-pooling GPU kernel.
///
/// Wraps a [`BaseKernel`] registered under the name `"max_pool2d"` that
/// carries one source string per backend (CUDA, ROCm, WGPU, Metal, OpenCL).
pub struct MaxPoolKernel {
    /// Shared kernel state: name, per-backend sources and metadata.
    base: BaseKernel,
}
16
17impl Default for MaxPoolKernel {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl MaxPoolKernel {
24 pub fn new() -> Self {
26 let metadata = KernelMetadata {
27 workgroup_size: [16, 16, 1],
28 local_memory_usage: 0,
29 supports_tensor_cores: false,
30 operationtype: OperationType::MemoryIntensive,
31 backend_metadata: HashMap::new(),
32 };
33
34 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
35 Self::get_kernel_sources();
36
37 Self {
38 base: BaseKernel::new(
39 "max_pool2d",
40 &cuda_source,
41 &rocm_source,
42 &wgpu_source,
43 &metal_source,
44 &opencl_source,
45 metadata,
46 ),
47 }
48 }
49
50 fn get_kernel_sources() -> (String, String, String, String, String) {
52 let cuda_source = r#"
54extern "C" __global__ void max_pool2d(
55 const float* __restrict__ input,
56 float* __restrict__ output,
57 int batch_size,
58 int channels,
59 int input_height,
60 int input_width,
61 int output_height,
62 int output_width,
63 int pool_height,
64 int pool_width,
65 int stride_y,
66 int stride_x
67) {
68 int batch_idx = blockIdx.z;
69 int channel_idx = blockIdx.y;
70 int out_y = blockIdx.x * blockDim.x + threadIdx.x;
71 int out_x = threadIdx.y;
72
73 if (batch_idx >= batch_size || channel_idx >= channels ||
74 out_y >= output_height || out_x >= output_width) {
75 return;
76 }
77
78 int input_offset = ((batch_idx * channels + channel_idx) * input_height) * input_width;
79 int output_offset = ((batch_idx * channels + channel_idx) * output_height) * output_width;
80
81 int start_y = out_y * stride_y;
82 int start_x = out_x * stride_x;
83 int end_y = min(start_y + pool_height, input_height);
84 int end_x = min(start_x + pool_width, input_width);
85
86 float max_val = -INFINITY;
87
88 for (int y = start_y; y < end_y; y++) {
89 for (int x = start_x; x < end_x; x++) {
90 int input_idx = input_offset + y * input_width + x;
91 max_val = fmaxf(max_val, input[input_idx]);
92 }
93 }
94
95 int output_idx = output_offset + out_y * output_width + out_x;
96 output[output_idx] = max_val;
97}
98"#
99 .to_string();
100
101 let wgpu_source = r#"
103struct Uniforms {
104 batch_size: u32,
105 channels: u32,
106 input_height: u32,
107 input_width: u32,
108 output_height: u32,
109 output_width: u32,
110 pool_height: u32,
111 pool_width: u32,
112 stride_y: u32,
113 stride_x: u32,
114};
115
116@group(0) @binding(0) var<uniform> uniforms: Uniforms;
117@group(0) @binding(1) var<storage, read> input: array<f32>;
118@group(0) @binding(2) var<storage, write> output: array<f32>;
119
120@compute @workgroup_size(16, 16)
121#[allow(dead_code)]
122fn max_pool2d(
123 @builtin(global_invocation_id) global_id: vec3<u32>
124) {
125 let batch_idx = global_id.z;
126 let channel_idx = global_id.y % uniforms.channels;
127 let out_y = global_id.x;
128 let out_x = global_id.y / uniforms.channels;
129
130 if (batch_idx >= uniforms.batch_size || channel_idx >= uniforms.channels ||
131 out_y >= uniforms.output_height || out_x >= uniforms.output_width) {
132 return;
133 }
134
135 let input_offset = ((batch_idx * uniforms.channels + channel_idx) * uniforms.input_height) * uniforms.input_width;
136 let output_offset = ((batch_idx * uniforms.channels + channel_idx) * uniforms.output_height) * uniforms.output_width;
137
138 let start_y = out_y * uniforms.stride_y;
139 let start_x = out_x * uniforms.stride_x;
140 let end_y = min(start_y + uniforms.pool_height, uniforms.input_height);
141 let end_x = min(start_x + uniforms.pool_width, uniforms.input_width);
142
143 var max_val = -3.4028235e+38; // f32::NEG_INFINITY
144
145 for (var y = start_y; y < end_y; y = y + 1u) {
146 for (var x = start_x; x < end_x; x = x + 1u) {
147 let input_idx = input_offset + y * uniforms.input_width + x;
148 max_val = max(max_val, input[input_idx]);
149 }
150 }
151
152 let output_idx = output_offset + out_y * uniforms.output_width + out_x;
153 output[output_idx] = max_val;
154}
155"#
156 .to_string();
157
158 let metal_source = r#"
160#include <metal_stdlib>
161using namespace metal;
162
163kernel void max_pool2d(
164 const device float* input [[buffer(0)]],
165 device float* output [[buffer(1)]],
166 constant uint& batch_size [[buffer(2)]],
167 constant uint& channels [[buffer(3)]],
168 constant uint& input_height [[buffer(4)]],
169 constant uint& input_width [[buffer(5)]],
170 constant uint& output_height [[buffer(6)]],
171 constant uint& output_width [[buffer(7)]],
172 constant uint& pool_height [[buffer(8)]],
173 constant uint& pool_width [[buffer(9)]],
174 constant uint& stride_y [[buffer(10)]],
175 constant uint& stride_x [[buffer(11)]],
176 uint3 global_id [[thread_position_in_grid]])
177{
178 uint batch_idx = global_id.z;
179 uint channel_idx = global_id.y % channels;
180 uint out_y = global_id.x;
181 uint out_x = global_id.y / channels;
182
183 if (batch_idx >= batch_size || channel_idx >= channels ||
184 out_y >= output_height || out_x >= output_width) {
185 return;
186 }
187
188 uint input_offset = ((batch_idx * channels + channel_idx) * input_height) * input_width;
189 uint output_offset = ((batch_idx * channels + channel_idx) * output_height) * output_width;
190
191 uint start_y = out_y * stride_y;
192 uint start_x = out_x * stride_x;
193 uint end_y = min(start_y + pool_height, input_height);
194 uint end_x = min(start_x + pool_width, input_width);
195
196 float max_val = -INFINITY;
197
198 for (uint y = start_y; y < end_y; y++) {
199 for (uint x = start_x; x < end_x; x++) {
200 uint input_idx = input_offset + y * input_width + x;
201 max_val = max(max_val, input[input_idx]);
202 }
203 }
204
205 uint output_idx = output_offset + out_y * output_width + out_x;
206 output[output_idx] = max_val;
207}
208"#
209 .to_string();
210
211 let opencl_source = r#"
213__kernel void max_pool2d(
214 __global const float* input__global float* output,
215 const int batch_size,
216 const int channels,
217 const int input_height,
218 const int input_width,
219 const int output_height,
220 const int output_width,
221 const int pool_height,
222 const int pool_width,
223 const int stride_y,
224 const int stride_x)
225{
226 int batch_idx = get_global_id(2);
227 int channel_idx = get_global_id(1) % channels;
228 int out_y = get_global_id(0);
229 int out_x = get_global_id(1) / channels;
230
231 if (batch_idx >= batch_size || channel_idx >= channels ||
232 out_y >= output_height || out_x >= output_width) {
233 return;
234 }
235
236 int input_offset = ((batch_idx * channels + channel_idx) * input_height) * input_width;
237 int output_offset = ((batch_idx * channels + channel_idx) * output_height) * output_width;
238
239 int start_y = out_y * stride_y;
240 int start_x = out_x * stride_x;
241 int end_y = min(start_y + pool_height, input_height);
242 int end_x = min(start_x + pool_width, input_width);
243
244 float max_val = -INFINITY;
245
246 for (int y = start_y; y < end_y; y++) {
247 for (int x = start_x; x < end_x; x++) {
248 int input_idx = input_offset + y * input_width + x;
249 max_val = max(max_val, input[input_idx]);
250 }
251 }
252
253 int output_idx = output_offset + out_y * output_width + out_x;
254 output[output_idx] = max_val;
255}
256"#
257 .to_string();
258
259 let rocm_source = cuda_source.clone();
261
262 (
263 cuda_source,
264 rocm_source,
265 wgpu_source,
266 metal_source,
267 opencl_source,
268 )
269 }
270}
271
272impl GpuKernel for MaxPoolKernel {
273 fn name(&self) -> &str {
274 self.base.name()
275 }
276
277 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
278 self.base.source_for_backend(backend)
279 }
280
281 fn metadata(&self) -> KernelMetadata {
282 self.base.metadata()
283 }
284
285 fn can_specialize(&self, params: &KernelParams) -> bool {
286 matches!(
287 params.datatype,
288 DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
289 )
290 }
291
292 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
293 if !self.can_specialize(params) {
294 return Err(GpuError::SpecializationNotSupported);
295 }
296
297 Ok(Box::new(Self::new()))
298 }
299}
300
/// 2D average-pooling GPU kernel.
///
/// Wraps a [`BaseKernel`] registered under the name `"avg_pool2d"` that
/// carries one source string per backend (CUDA, ROCm, WGPU, Metal, OpenCL).
pub struct AvgPoolKernel {
    /// Shared kernel state: name, per-backend sources and metadata.
    base: BaseKernel,
}
305
306impl Default for AvgPoolKernel {
307 fn default() -> Self {
308 Self::new()
309 }
310}
311
312impl AvgPoolKernel {
313 pub fn new() -> Self {
315 let metadata = KernelMetadata {
316 workgroup_size: [16, 16, 1],
317 local_memory_usage: 0,
318 supports_tensor_cores: false,
319 operationtype: OperationType::MemoryIntensive,
320 backend_metadata: HashMap::new(),
321 };
322
323 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
324 Self::get_kernel_sources();
325
326 Self {
327 base: BaseKernel::new(
328 "avg_pool2d",
329 &cuda_source,
330 &rocm_source,
331 &wgpu_source,
332 &metal_source,
333 &opencl_source,
334 metadata,
335 ),
336 }
337 }
338
339 fn get_kernel_sources() -> (String, String, String, String, String) {
341 let cuda_source = r#"
343extern "C" __global__ void avg_pool2d(
344 const float* __restrict__ input,
345 float* __restrict__ output,
346 int batch_size,
347 int channels,
348 int input_height,
349 int input_width,
350 int output_height,
351 int output_width,
352 int pool_height,
353 int pool_width,
354 int stride_y,
355 int stride_x
356) {
357 int batch_idx = blockIdx.z;
358 int channel_idx = blockIdx.y;
359 int out_y = blockIdx.x * blockDim.x + threadIdx.x;
360 int out_x = threadIdx.y;
361
362 if (batch_idx >= batch_size || channel_idx >= channels ||
363 out_y >= output_height || out_x >= output_width) {
364 return;
365 }
366
367 int input_offset = ((batch_idx * channels + channel_idx) * input_height) * input_width;
368 int output_offset = ((batch_idx * channels + channel_idx) * output_height) * output_width;
369
370 int start_y = out_y * stride_y;
371 int start_x = out_x * stride_x;
372 int end_y = min(start_y + pool_height, input_height);
373 int end_x = min(start_x + pool_width, input_width);
374
375 float sum = 0.0f;
376 int count = 0;
377
378 for (int y = start_y; y < end_y; y++) {
379 for (int x = start_x; x < end_x; x++) {
380 int input_idx = input_offset + y * input_width + x;
381 sum += input[input_idx];
382 count++;
383 }
384 }
385
386 int output_idx = output_offset + out_y * output_width + out_x;
387 output[output_idx] = sum / (float)count;
388}
389"#
390 .to_string();
391
392 let wgpu_source = r#"
395// WebGPU average pooling implementation
396// Similar structure to max pooling but computing average instead of max
397struct Uniforms {
398 batch_size: u32,
399 channels: u32,
400 input_height: u32,
401 input_width: u32,
402 output_height: u32,
403 output_width: u32,
404 pool_height: u32,
405 pool_width: u32,
406 stride_y: u32,
407 stride_x: u32,
408};
409
410@group(0) @binding(0) var<uniform> uniforms: Uniforms;
411@group(0) @binding(1) var<storage, read> input: array<f32>;
412@group(0) @binding(2) var<storage, write> output: array<f32>;
413
414@compute @workgroup_size(16, 16)
415#[allow(dead_code)]
416fn avg_pool2d(@builtin(global_invocation_id) global_id: vec3<u32>) {
417 // Implementation similar to max pooling but computing average
418}
419"#
420 .to_string();
421
422 let metal_source = r#"
423// Metal average pooling implementation
424#include <metal_stdlib>
425using namespace metal;
426
427kernel void avg_pool2d(/* parameters similar to max pooling */) {
428 // Implementation similar to max pooling but computing average
429}
430"#
431 .to_string();
432
433 let opencl_source = r#"
434// OpenCL average pooling implementation
435__kernel void avg_pool2d(/* parameters similar to max pooling */) {
436 // Implementation similar to max pooling but computing average
437}
438"#
439 .to_string();
440
441 let rocm_source = cuda_source.clone();
443
444 (
445 cuda_source,
446 rocm_source,
447 wgpu_source,
448 metal_source,
449 opencl_source,
450 )
451 }
452}
453
454impl GpuKernel for AvgPoolKernel {
455 fn name(&self) -> &str {
456 self.base.name()
457 }
458
459 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
460 self.base.source_for_backend(backend)
461 }
462
463 fn metadata(&self) -> KernelMetadata {
464 self.base.metadata()
465 }
466
467 fn can_specialize(&self, params: &KernelParams) -> bool {
468 matches!(
469 params.datatype,
470 DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
471 )
472 }
473
474 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
475 if !self.can_specialize(params) {
476 return Err(GpuError::SpecializationNotSupported);
477 }
478
479 Ok(Box::new(Self::new()))
480 }
481}