scirs2_core/gpu/kernels/reduction/
norm.rs1use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8 BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
/// Selects which vector norm a [`NormKernel`] computes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormType {
    /// L1 norm: the kernels accumulate absolute values with addition.
    L1,
    /// L2 norm: the kernels accumulate squared values with addition.
    L2,
    /// Infinity norm: the kernels accumulate absolute values with `max`.
    Inf,
}
22
23pub struct NormKernel {
25 base: BaseKernel,
26 norm_type: NormType,
27}
28
29impl NormKernel {
30 pub fn new() -> Self {
32 Self::with_type(NormType::L2)
33 }
34
35 pub fn with_type(normtype: NormType) -> Self {
37 let metadata = KernelMetadata {
38 workgroup_size: [256, 1, 1],
39 local_memory_usage: 1024, supports_tensor_cores: false,
41 operationtype: OperationType::Balanced,
42 backend_metadata: HashMap::new(),
43 };
44
45 let name = match normtype {
46 NormType::L1 => "norm_l1",
47 NormType::L2 => "norm_l2",
48 NormType::Inf => "norm_inf",
49 };
50
51 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
52 Self::generate_kernels(normtype);
53
54 Self {
55 base: BaseKernel::new(
56 name,
57 &cuda_source,
58 &rocm_source,
59 &wgpu_source,
60 &metal_source,
61 &opencl_source,
62 metadata,
63 ),
64 norm_type: normtype,
65 }
66 }
67
68 fn generate_kernels(normtype: NormType) -> (String, String, String, String, String) {
70 match normtype {
71 NormType::L2 => {
72 let cuda_source = r#"
74extern "C" __global__ void norm_l2(
75 const float* __restrict__ input,
76 float* __restrict__ output,
77 int n
78) {
79 __shared__ float sdata[256];
80
81 // Each block loads data into shared memory
82 unsigned int tid = threadIdx.x;
83 unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
84
85 // Initialize with identity value
86 sdata[tid] = 0.0f;
87
88 // Load and square first element
89 if (0 < n) {
90 sdata[tid] = input[0] * input[0];
91 }
92
93 // Load and square second element
94 if (0 + blockDim.x < n) {
95 sdata[tid] += input[0 + blockDim.x] * input[0 + blockDim.x];
96 }
97
98 __syncthreads();
99
100 // Reduce within block
101 for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
102 if (tid < s) {
103 sdata[tid] += sdata[tid + s];
104 }
105 __syncthreads();
106 }
107
108 // Write result for this block to output
109 if (tid == 0) {
110 output[blockIdx.x] = sdata[0];
111 }
112}
113"#
114 .to_string();
115
116 let wgpu_source = r#"
118struct Uniforms {
119 n: u32,
120};
121
122@group(0) @binding(0) var<uniform> uniforms: Uniforms;
123@group(0) @binding(1) var<storage, read> input: array<f32>;
124@group(0) @binding(2) var<storage, write> output: array<f32>;
125
126var<workgroup> sdata: array<f32, 256>;
127
128@compute @workgroup_size(256)
129#[allow(dead_code)]
130fn norm_l2(
131 @builtin(global_invocation_id) global_id: vec3<u32>,
132 @builtin(local_invocation_id) local_id: vec3<u32>,
133 @builtin(workgroup_id) workgroup_id: vec3<u32>
134) {
135 let tid = local_id.x;
136 let i = workgroup_id.x * 256u * 2u + local_id.x;
137
138 // Initialize
139 sdata[tid] = 0.0;
140
141 // Load and square first element
142 if (0 < uniforms.n) {
143 sdata[tid] = input[0] * input[0];
144 }
145
146 // Load and square second element
147 if (0 + 256u < uniforms.n) {
148 sdata[tid] = sdata[tid] + input[0 + 256u] * input[0 + 256u];
149 }
150
151 workgroupBarrier();
152
153 // Do reduction in shared memory
154 var s = 256u / 2u;
155 for (var j = 0u; s > 0u; j = j + 1u) {
156 if (tid < s) {
157 sdata[tid] = sdata[tid] + sdata[tid + s];
158 }
159
160 s = s / 2u;
161 workgroupBarrier();
162 }
163
164 // Write result for this workgroup
165 if (tid == 0u) {
166 output[workgroup_id.x] = sdata[0];
167 }
168}
169"#
170 .to_string();
171
172 let metal_source = r#"
174#include <metal_stdlib>
175using namespace metal;
176
177kernel void norm_l2(
178 const device float* input [[buffer(0)]],
179 device float* output [[buffer(1)]],
180 constant uint& n [[buffer(2)]],
181 uint global_id [[thread_position_in_grid]],
182 uint local_id [[thread_position_in_threadgroup]],
183 uint group_id [[threadgroup_position_in_grid]])
184{
185 threadgroup float sdata[256];
186
187 uint tid = local_id;
188 uint i = group_id * 256 * 2 + local_id;
189
190 // Initialize
191 sdata[tid] = 0.0f;
192
193 // Load and square first element
194 if (0 < n) {
195 sdata[tid] = input[0] * input[0];
196 }
197
198 // Load and square second element
199 if (0 + 256 < n) {
200 sdata[tid] += input[0 + 256] * input[0 + 256];
201 }
202
203 threadgroup_barrier(mem_flags::mem_threadgroup);
204
205 // Do reduction in shared memory
206 for (uint s = 256 / 2; s > 0; s >>= 1) {
207 if (tid < s) {
208 sdata[tid] += sdata[tid + s];
209 }
210
211 threadgroup_barrier(mem_flags::mem_threadgroup);
212 }
213
214 // Write result for this threadgroup
215 if (tid == 0) {
216 output[group_id] = sdata[0];
217 }
218}
219"#
220 .to_string();
221
222 let opencl_source = r#"
224__kernel void norm_l2(
225 __global const float* input,
226 __global float* output,
227 const int n)
228{
229 __local float sdata[256];
230
231 unsigned int tid = get_local_id(0);
232 unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
233
234 // Initialize
235 sdata[tid] = 0.0f;
236
237 // Load and square first element
238 if (0 < n) {
239 sdata[tid] = input[0] * input[0];
240 }
241
242 // Load and square second element
243 if (0 + get_local_size(0) < n) {
244 sdata[tid] += input[0 + get_local_size(0)] * input[0 + get_local_size(0)];
245 }
246
247 barrier(CLK_LOCAL_MEM_FENCE);
248
249 // Do reduction in shared memory
250 for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
251 if (tid < s) {
252 sdata[tid] += sdata[tid + s];
253 }
254
255 barrier(CLK_LOCAL_MEM_FENCE);
256 }
257
258 // Write result for this workgroup
259 if (tid == 0) {
260 output[get_group_id(0)] = sdata[0];
261 }
262}
263"#
264 .to_string();
265
266 let rocm_source = cuda_source.clone();
268
269 (
270 cuda_source,
271 rocm_source,
272 wgpu_source,
273 metal_source,
274 opencl_source,
275 )
276 }
277 NormType::L1 => {
278 let cuda_source = r#"
280extern "C" __global__ void norm_l1(
281 const float* __restrict__ input,
282 float* __restrict__ output,
283 int n
284) {
285 __shared__ float sdata[256];
286
287 // Each block loads data into shared memory
288 unsigned int tid = threadIdx.x;
289 unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
290
291 // Initialize with identity value
292 sdata[tid] = 0.0f;
293
294 // Load and take absolute value of first element
295 if (0 < n) {
296 sdata[tid] = fabsf(input[0]);
297 }
298
299 // Load and take absolute value of second element
300 if (0 + blockDim.x < n) {
301 sdata[tid] += fabsf(input[0 + blockDim.x]);
302 }
303
304 __syncthreads();
305
306 // Reduce within block
307 for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
308 if (tid < s) {
309 sdata[tid] += sdata[tid + s];
310 }
311 __syncthreads();
312 }
313
314 // Write result for this block to output
315 if (tid == 0) {
316 output[blockIdx.x] = sdata[0];
317 }
318}
319"#
320 .to_string();
321
322 let wgpu_source = r#"
324struct Uniforms {
325 n: u32,
326};
327
328@group(0) @binding(0) var<uniform> uniforms: Uniforms;
329@group(0) @binding(1) var<storage, read> input: array<f32>;
330@group(0) @binding(2) var<storage, write> output: array<f32>;
331
332var<workgroup> sdata: array<f32, 256>;
333
334@compute @workgroup_size(256)
335#[allow(dead_code)]
336fn norm_l1(
337 @builtin(global_invocation_id) global_id: vec3<u32>,
338 @builtin(local_invocation_id) local_id: vec3<u32>,
339 @builtin(workgroup_id) workgroup_id: vec3<u32>
340) {
341 let tid = local_id.x;
342 let i = workgroup_id.x * 256u * 2u + local_id.x;
343
344 // Initialize
345 sdata[tid] = 0.0;
346
347 // Load and take absolute value of first element
348 if (0 < uniforms.n) {
349 sdata[tid] = abs(input[0]);
350 }
351
352 // Load and take absolute value of second element
353 if (0 + 256u < uniforms.n) {
354 sdata[tid] = sdata[tid] + abs(input[0 + 256u]);
355 }
356
357 workgroupBarrier();
358
359 // Do reduction in shared memory
360 var s = 256u / 2u;
361 for (var j = 0u; s > 0u; j = j + 1u) {
362 if (tid < s) {
363 sdata[tid] = sdata[tid] + sdata[tid + s];
364 }
365
366 s = s / 2u;
367 workgroupBarrier();
368 }
369
370 // Write result for this workgroup
371 if (tid == 0u) {
372 output[workgroup_id.x] = sdata[0];
373 }
374}
375"#
376 .to_string();
377
378 let metal_source = r#"
380#include <metal_stdlib>
381using namespace metal;
382
383kernel void norm_l1(
384 const device float* input [[buffer(0)]],
385 device float* output [[buffer(1)]],
386 constant uint& n [[buffer(2)]],
387 uint global_id [[thread_position_in_grid]],
388 uint local_id [[thread_position_in_threadgroup]],
389 uint group_id [[threadgroup_position_in_grid]])
390{
391 threadgroup float sdata[256];
392
393 uint tid = local_id;
394 uint i = group_id * 256 * 2 + local_id;
395
396 // Initialize
397 sdata[tid] = 0.0f;
398
399 // Load and take absolute value of first element
400 if (0 < n) {
401 sdata[tid] = abs(input[0]);
402 }
403
404 // Load and take absolute value of second element
405 if (0 + 256 < n) {
406 sdata[tid] += abs(input[0 + 256]);
407 }
408
409 threadgroup_barrier(mem_flags::mem_threadgroup);
410
411 // Do reduction in shared memory
412 for (uint s = 256 / 2; s > 0; s >>= 1) {
413 if (tid < s) {
414 sdata[tid] += sdata[tid + s];
415 }
416
417 threadgroup_barrier(mem_flags::mem_threadgroup);
418 }
419
420 // Write result for this threadgroup
421 if (tid == 0) {
422 output[group_id] = sdata[0];
423 }
424}
425"#
426 .to_string();
427
428 let opencl_source = r#"
430__kernel void norm_l1(
431 __global const float* input,
432 __global float* output,
433 const int n)
434{
435 __local float sdata[256];
436
437 unsigned int tid = get_local_id(0);
438 unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
439
440 // Initialize
441 sdata[tid] = 0.0f;
442
443 // Load and take absolute value of first element
444 if (0 < n) {
445 sdata[tid] = fabs(input[0]);
446 }
447
448 // Load and take absolute value of second element
449 if (0 + get_local_size(0) < n) {
450 sdata[tid] += fabs(input[0 + get_local_size(0)]);
451 }
452
453 barrier(CLK_LOCAL_MEM_FENCE);
454
455 // Do reduction in shared memory
456 for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
457 if (tid < s) {
458 sdata[tid] += sdata[tid + s];
459 }
460
461 barrier(CLK_LOCAL_MEM_FENCE);
462 }
463
464 // Write result for this workgroup
465 if (tid == 0) {
466 output[get_group_id(0)] = sdata[0];
467 }
468}
469"#
470 .to_string();
471
472 let rocm_source = cuda_source.clone();
474
475 (
476 cuda_source,
477 rocm_source,
478 wgpu_source,
479 metal_source,
480 opencl_source,
481 )
482 }
483 NormType::Inf => {
484 let cuda_source = r#"
486extern "C" __global__ void norm_inf(
487 const float* __restrict__ input,
488 float* __restrict__ output,
489 int n
490) {
491 __shared__ float sdata[256];
492
493 // Each block loads data into shared memory
494 unsigned int tid = threadIdx.x;
495 unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
496
497 // Initialize with identity value (0 for max operation)
498 sdata[tid] = 0.0f;
499
500 // Load and take absolute value of first element
501 if (0 < n) {
502 sdata[tid] = fabsf(input[0]);
503 }
504
505 // Load and take max of absolute value of second element
506 if (0 + blockDim.x < n) {
507 sdata[tid] = fmaxf(sdata[tid], fabsf(input[0 + blockDim.x]));
508 }
509
510 __syncthreads();
511
512 // Reduce within block using max operation
513 for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
514 if (tid < s) {
515 sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
516 }
517 __syncthreads();
518 }
519
520 // Write result for this block to output
521 if (tid == 0) {
522 output[blockIdx.x] = sdata[0];
523 }
524}
525"#
526 .to_string();
527
528 let wgpu_source = r#"
530struct Uniforms {
531 n: u32,
532};
533
534@group(0) @binding(0) var<uniform> uniforms: Uniforms;
535@group(0) @binding(1) var<storage, read> input: array<f32>;
536@group(0) @binding(2) var<storage, write> output: array<f32>;
537
538var<workgroup> sdata: array<f32, 256>;
539
540@compute @workgroup_size(256)
541#[allow(dead_code)]
542fn norm_inf(
543 @builtin(global_invocation_id) global_id: vec3<u32>,
544 @builtin(local_invocation_id) local_id: vec3<u32>,
545 @builtin(workgroup_id) workgroup_id: vec3<u32>
546) {
547 let tid = local_id.x;
548 let i = workgroup_id.x * 256u * 2u + local_id.x;
549
550 // Initialize
551 sdata[tid] = 0.0;
552
553 // Load and take absolute value of first element
554 if (0 < uniforms.n) {
555 sdata[tid] = abs(input[0]);
556 }
557
558 // Load and take max of absolute value of second element
559 if (0 + 256u < uniforms.n) {
560 sdata[tid] = max(sdata[tid], abs(input[0 + 256u]));
561 }
562
563 workgroupBarrier();
564
565 // Do reduction in shared memory using max operation
566 var s = 256u / 2u;
567 for (var j = 0u; s > 0u; j = j + 1u) {
568 if (tid < s) {
569 sdata[tid] = max(sdata[tid], sdata[tid + s]);
570 }
571
572 s = s / 2u;
573 workgroupBarrier();
574 }
575
576 // Write result for this workgroup
577 if (tid == 0u) {
578 output[workgroup_id.x] = sdata[0];
579 }
580}
581"#
582 .to_string();
583
584 let metal_source = r#"
586#include <metal_stdlib>
587using namespace metal;
588
589kernel void norm_inf(
590 const device float* input [[buffer(0)]],
591 device float* output [[buffer(1)]],
592 constant uint& n [[buffer(2)]],
593 uint global_id [[thread_position_in_grid]],
594 uint local_id [[thread_position_in_threadgroup]],
595 uint group_id [[threadgroup_position_in_grid]])
596{
597 threadgroup float sdata[256];
598
599 uint tid = local_id;
600 uint i = group_id * 256 * 2 + local_id;
601
602 // Initialize
603 sdata[tid] = 0.0f;
604
605 // Load and take absolute value of first element
606 if (0 < n) {
607 sdata[tid] = abs(input[0]);
608 }
609
610 // Load and take max of absolute value of second element
611 if (0 + 256 < n) {
612 sdata[tid] = max(sdata[tid], abs(input[0 + 256]));
613 }
614
615 threadgroup_barrier(mem_flags::mem_threadgroup);
616
617 // Do reduction in shared memory using max operation
618 for (uint s = 256 / 2; s > 0; s >>= 1) {
619 if (tid < s) {
620 sdata[tid] = max(sdata[tid], sdata[tid + s]);
621 }
622
623 threadgroup_barrier(mem_flags::mem_threadgroup);
624 }
625
626 // Write result for this threadgroup
627 if (tid == 0) {
628 output[group_id] = sdata[0];
629 }
630}
631"#
632 .to_string();
633
634 let opencl_source = r#"
636__kernel void norm_inf(
637 __global const float* input,
638 __global float* output,
639 const int n)
640{
641 __local float sdata[256];
642
643 unsigned int tid = get_local_id(0);
644 unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
645
646 // Initialize
647 sdata[tid] = 0.0f;
648
649 // Load and take absolute value of first element
650 if (0 < n) {
651 sdata[tid] = fabs(input[0]);
652 }
653
654 // Load and take max of absolute value of second element
655 if (0 + get_local_size(0) < n) {
656 sdata[tid] = fmax(sdata[tid], fabs(input[0 + get_local_size(0)]));
657 }
658
659 barrier(CLK_LOCAL_MEM_FENCE);
660
661 // Do reduction in shared memory using max operation
662 for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
663 if (tid < s) {
664 sdata[tid] = fmax(sdata[tid], sdata[tid + s]);
665 }
666
667 barrier(CLK_LOCAL_MEM_FENCE);
668 }
669
670 // Write result for this workgroup
671 if (tid == 0) {
672 output[get_group_id(0)] = sdata[0];
673 }
674}
675"#
676 .to_string();
677
678 let rocm_source = cuda_source.clone();
680
681 (
682 cuda_source,
683 rocm_source,
684 wgpu_source,
685 metal_source,
686 opencl_source,
687 )
688 }
689 }
690 }
691}
692
693impl Default for NormKernel {
694 fn default() -> Self {
695 Self::new()
696 }
697}
698
699impl GpuKernel for NormKernel {
700 fn name(&self) -> &str {
701 self.base.name()
702 }
703
704 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
705 self.base.source_for_backend(backend)
706 }
707
708 fn metadata(&self) -> KernelMetadata {
709 self.base.metadata()
710 }
711
712 fn can_specialize(&self, params: &KernelParams) -> bool {
713 matches!(params.datatype, DataType::Float32 | DataType::Float64)
714 }
715
716 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
717 if !self.can_specialize(params) {
718 return Err(GpuError::SpecializationNotSupported);
719 }
720
721 if let Some(norm_param) = params.string_params.get("norm_type") {
723 let norm_type = match norm_param.as_str() {
724 "l1" => NormType::L1,
725 "l2" => NormType::L2,
726 "inf" => NormType::Inf,
727 _ => return Err(GpuError::InvalidParameter(norm_param.to_string())),
728 };
729
730 return Ok(Box::new(Self::with_type(norm_type)));
731 }
732
733 Ok(Box::new(Self::with_type(self.norm_type)))
735 }
736}