scirs2_core/gpu/kernels/reduction/
norm.rs

//! Norm reduction kernels
//!
//! Computes vector norms (L1, L2, and infinity) as GPU reduction kernels.
4
5use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
/// Selects which vector norm a norm kernel computes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormType {
    /// L1 norm: the sum of the absolute values of the elements.
    L1,
    /// L2 (Euclidean) norm: the square root of the sum of the squared elements.
    L2,
    /// Infinity norm: the largest absolute value among the elements.
    Inf,
}
22
23/// Norm reduction kernel
24pub struct NormKernel {
25    base: BaseKernel,
26    norm_type: NormType,
27}
28
29impl NormKernel {
30    /// Create a new norm kernel for L2 norm (default)
31    pub fn new() -> Self {
32        Self::with_type(NormType::L2)
33    }
34
35    /// Create a new norm kernel with the specified norm type
36    pub fn with_type(normtype: NormType) -> Self {
37        let metadata = KernelMetadata {
38            workgroup_size: [256, 1, 1],
39            local_memory_usage: 1024, // 256 * sizeof(float)
40            supports_tensor_cores: false,
41            operationtype: OperationType::Balanced,
42            backend_metadata: HashMap::new(),
43        };
44
45        let name = match normtype {
46            NormType::L1 => "norm_l1",
47            NormType::L2 => "norm_l2",
48            NormType::Inf => "norm_inf",
49        };
50
51        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
52            Self::generate_kernels(normtype);
53
54        Self {
55            base: BaseKernel::new(
56                name,
57                &cuda_source,
58                &rocm_source,
59                &wgpu_source,
60                &metal_source,
61                &opencl_source,
62                metadata,
63            ),
64            norm_type: normtype,
65        }
66    }
67
68    /// Get kernel sources for different backends and norm types
69    fn generate_kernels(normtype: NormType) -> (String, String, String, String, String) {
70        match normtype {
71            NormType::L2 => {
72                // CUDA kernel for L2 norm
73                let cuda_source = r#"
74extern "C" __global__ void norm_l2(
75    const float* __restrict__ input,
76    float* __restrict__ output,
77    int n
78) {
79    __shared__ float sdata[256];
80
81    // Each block loads data into shared memory
82    unsigned int tid = threadIdx.x;
83    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
84
85    // Initialize with identity value
86    sdata[tid] = 0.0f;
87
88    // Load and square first element
89    if (0 < n) {
90        sdata[tid] = input[0] * input[0];
91    }
92
93    // Load and square second element
94    if (0 + blockDim.x < n) {
95        sdata[tid] += input[0 + blockDim.x] * input[0 + blockDim.x];
96    }
97
98    __syncthreads();
99
100    // Reduce within block
101    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
102        if (tid < s) {
103            sdata[tid] += sdata[tid + s];
104        }
105        __syncthreads();
106    }
107
108    // Write result for this block to output
109    if (tid == 0) {
110        output[blockIdx.x] = sdata[0];
111    }
112}
113"#
114                .to_string();
115
116                // WebGPU kernel for L2 norm
117                let wgpu_source = r#"
118struct Uniforms {
119    n: u32,
120};
121
122@group(0) @binding(0) var<uniform> uniforms: Uniforms;
123@group(0) @binding(1) var<storage, read> input: array<f32>;
124@group(0) @binding(2) var<storage, write> output: array<f32>;
125
126var<workgroup> sdata: array<f32, 256>;
127
128@compute @workgroup_size(256)
129#[allow(dead_code)]
130fn norm_l2(
131    @builtin(global_invocation_id) global_id: vec3<u32>,
132    @builtin(local_invocation_id) local_id: vec3<u32>,
133    @builtin(workgroup_id) workgroup_id: vec3<u32>
134) {
135    let tid = local_id.x;
136    let i = workgroup_id.x * 256u * 2u + local_id.x;
137
138    // Initialize
139    sdata[tid] = 0.0;
140
141    // Load and square first element
142    if (0 < uniforms.n) {
143        sdata[tid] = input[0] * input[0];
144    }
145
146    // Load and square second element
147    if (0 + 256u < uniforms.n) {
148        sdata[tid] = sdata[tid] + input[0 + 256u] * input[0 + 256u];
149    }
150
151    workgroupBarrier();
152
153    // Do reduction in shared memory
154    var s = 256u / 2u;
155    for (var j = 0u; s > 0u; j = j + 1u) {
156        if (tid < s) {
157            sdata[tid] = sdata[tid] + sdata[tid + s];
158        }
159
160        s = s / 2u;
161        workgroupBarrier();
162    }
163
164    // Write result for this workgroup
165    if (tid == 0u) {
166        output[workgroup_id.x] = sdata[0];
167    }
168}
169"#
170                .to_string();
171
172                // Metal kernel for L2 norm
173                let metal_source = r#"
174#include <metal_stdlib>
175using namespace metal;
176
177kernel void norm_l2(
178    const device float* input [[buffer(0)]],
179    device float* output [[buffer(1)]],
180    constant uint& n [[buffer(2)]],
181    uint global_id [[thread_position_in_grid]],
182    uint local_id [[thread_position_in_threadgroup]],
183    uint group_id [[threadgroup_position_in_grid]])
184{
185    threadgroup float sdata[256];
186
187    uint tid = local_id;
188    uint i = group_id * 256 * 2 + local_id;
189
190    // Initialize
191    sdata[tid] = 0.0f;
192
193    // Load and square first element
194    if (0 < n) {
195        sdata[tid] = input[0] * input[0];
196    }
197
198    // Load and square second element
199    if (0 + 256 < n) {
200        sdata[tid] += input[0 + 256] * input[0 + 256];
201    }
202
203    threadgroup_barrier(mem_flags::mem_threadgroup);
204
205    // Do reduction in shared memory
206    for (uint s = 256 / 2; s > 0; s >>= 1) {
207        if (tid < s) {
208            sdata[tid] += sdata[tid + s];
209        }
210
211        threadgroup_barrier(mem_flags::mem_threadgroup);
212    }
213
214    // Write result for this threadgroup
215    if (tid == 0) {
216        output[group_id] = sdata[0];
217    }
218}
219"#
220                .to_string();
221
222                // OpenCL kernel for L2 norm
223                let opencl_source = r#"
224__kernel void norm_l2(
225    __global const float* input,
226    __global float* output,
227    const int n)
228{
229    __local float sdata[256];
230
231    unsigned int tid = get_local_id(0);
232    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
233
234    // Initialize
235    sdata[tid] = 0.0f;
236
237    // Load and square first element
238    if (0 < n) {
239        sdata[tid] = input[0] * input[0];
240    }
241
242    // Load and square second element
243    if (0 + get_local_size(0) < n) {
244        sdata[tid] += input[0 + get_local_size(0)] * input[0 + get_local_size(0)];
245    }
246
247    barrier(CLK_LOCAL_MEM_FENCE);
248
249    // Do reduction in shared memory
250    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
251        if (tid < s) {
252            sdata[tid] += sdata[tid + s];
253        }
254
255        barrier(CLK_LOCAL_MEM_FENCE);
256    }
257
258    // Write result for this workgroup
259    if (tid == 0) {
260        output[get_group_id(0)] = sdata[0];
261    }
262}
263"#
264                .to_string();
265
266                // ROCm (HIP) kernel - similar to CUDA
267                let rocm_source = cuda_source.clone();
268
269                (
270                    cuda_source,
271                    rocm_source,
272                    wgpu_source,
273                    metal_source,
274                    opencl_source,
275                )
276            }
277            NormType::L1 => {
278                // CUDA kernel for L1 norm
279                let cuda_source = r#"
280extern "C" __global__ void norm_l1(
281    const float* __restrict__ input,
282    float* __restrict__ output,
283    int n
284) {
285    __shared__ float sdata[256];
286
287    // Each block loads data into shared memory
288    unsigned int tid = threadIdx.x;
289    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
290
291    // Initialize with identity value
292    sdata[tid] = 0.0f;
293
294    // Load and take absolute value of first element
295    if (0 < n) {
296        sdata[tid] = fabsf(input[0]);
297    }
298
299    // Load and take absolute value of second element
300    if (0 + blockDim.x < n) {
301        sdata[tid] += fabsf(input[0 + blockDim.x]);
302    }
303
304    __syncthreads();
305
306    // Reduce within block
307    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
308        if (tid < s) {
309            sdata[tid] += sdata[tid + s];
310        }
311        __syncthreads();
312    }
313
314    // Write result for this block to output
315    if (tid == 0) {
316        output[blockIdx.x] = sdata[0];
317    }
318}
319"#
320                .to_string();
321
322                // WebGPU kernel for L1 norm
323                let wgpu_source = r#"
324struct Uniforms {
325    n: u32,
326};
327
328@group(0) @binding(0) var<uniform> uniforms: Uniforms;
329@group(0) @binding(1) var<storage, read> input: array<f32>;
330@group(0) @binding(2) var<storage, write> output: array<f32>;
331
332var<workgroup> sdata: array<f32, 256>;
333
334@compute @workgroup_size(256)
335#[allow(dead_code)]
336fn norm_l1(
337    @builtin(global_invocation_id) global_id: vec3<u32>,
338    @builtin(local_invocation_id) local_id: vec3<u32>,
339    @builtin(workgroup_id) workgroup_id: vec3<u32>
340) {
341    let tid = local_id.x;
342    let i = workgroup_id.x * 256u * 2u + local_id.x;
343
344    // Initialize
345    sdata[tid] = 0.0;
346
347    // Load and take absolute value of first element
348    if (0 < uniforms.n) {
349        sdata[tid] = abs(input[0]);
350    }
351
352    // Load and take absolute value of second element
353    if (0 + 256u < uniforms.n) {
354        sdata[tid] = sdata[tid] + abs(input[0 + 256u]);
355    }
356
357    workgroupBarrier();
358
359    // Do reduction in shared memory
360    var s = 256u / 2u;
361    for (var j = 0u; s > 0u; j = j + 1u) {
362        if (tid < s) {
363            sdata[tid] = sdata[tid] + sdata[tid + s];
364        }
365
366        s = s / 2u;
367        workgroupBarrier();
368    }
369
370    // Write result for this workgroup
371    if (tid == 0u) {
372        output[workgroup_id.x] = sdata[0];
373    }
374}
375"#
376                .to_string();
377
378                // Metal kernel for L1 norm
379                let metal_source = r#"
380#include <metal_stdlib>
381using namespace metal;
382
383kernel void norm_l1(
384    const device float* input [[buffer(0)]],
385    device float* output [[buffer(1)]],
386    constant uint& n [[buffer(2)]],
387    uint global_id [[thread_position_in_grid]],
388    uint local_id [[thread_position_in_threadgroup]],
389    uint group_id [[threadgroup_position_in_grid]])
390{
391    threadgroup float sdata[256];
392
393    uint tid = local_id;
394    uint i = group_id * 256 * 2 + local_id;
395
396    // Initialize
397    sdata[tid] = 0.0f;
398
399    // Load and take absolute value of first element
400    if (0 < n) {
401        sdata[tid] = abs(input[0]);
402    }
403
404    // Load and take absolute value of second element
405    if (0 + 256 < n) {
406        sdata[tid] += abs(input[0 + 256]);
407    }
408
409    threadgroup_barrier(mem_flags::mem_threadgroup);
410
411    // Do reduction in shared memory
412    for (uint s = 256 / 2; s > 0; s >>= 1) {
413        if (tid < s) {
414            sdata[tid] += sdata[tid + s];
415        }
416
417        threadgroup_barrier(mem_flags::mem_threadgroup);
418    }
419
420    // Write result for this threadgroup
421    if (tid == 0) {
422        output[group_id] = sdata[0];
423    }
424}
425"#
426                .to_string();
427
428                // OpenCL kernel for L1 norm
429                let opencl_source = r#"
430__kernel void norm_l1(
431    __global const float* input,
432    __global float* output,
433    const int n)
434{
435    __local float sdata[256];
436
437    unsigned int tid = get_local_id(0);
438    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
439
440    // Initialize
441    sdata[tid] = 0.0f;
442
443    // Load and take absolute value of first element
444    if (0 < n) {
445        sdata[tid] = fabs(input[0]);
446    }
447
448    // Load and take absolute value of second element
449    if (0 + get_local_size(0) < n) {
450        sdata[tid] += fabs(input[0 + get_local_size(0)]);
451    }
452
453    barrier(CLK_LOCAL_MEM_FENCE);
454
455    // Do reduction in shared memory
456    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
457        if (tid < s) {
458            sdata[tid] += sdata[tid + s];
459        }
460
461        barrier(CLK_LOCAL_MEM_FENCE);
462    }
463
464    // Write result for this workgroup
465    if (tid == 0) {
466        output[get_group_id(0)] = sdata[0];
467    }
468}
469"#
470                .to_string();
471
472                // ROCm (HIP) kernel - similar to CUDA
473                let rocm_source = cuda_source.clone();
474
475                (
476                    cuda_source,
477                    rocm_source,
478                    wgpu_source,
479                    metal_source,
480                    opencl_source,
481                )
482            }
483            NormType::Inf => {
484                // CUDA kernel for Inf norm
485                let cuda_source = r#"
486extern "C" __global__ void norm_inf(
487    const float* __restrict__ input,
488    float* __restrict__ output,
489    int n
490) {
491    __shared__ float sdata[256];
492
493    // Each block loads data into shared memory
494    unsigned int tid = threadIdx.x;
495    unsigned int i = blockIdx.x * blockDim.x * 2 + threadIdx.x;
496
497    // Initialize with identity value (0 for max operation)
498    sdata[tid] = 0.0f;
499
500    // Load and take absolute value of first element
501    if (0 < n) {
502        sdata[tid] = fabsf(input[0]);
503    }
504
505    // Load and take max of absolute value of second element
506    if (0 + blockDim.x < n) {
507        sdata[tid] = fmaxf(sdata[tid], fabsf(input[0 + blockDim.x]));
508    }
509
510    __syncthreads();
511
512    // Reduce within block using max operation
513    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
514        if (tid < s) {
515            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
516        }
517        __syncthreads();
518    }
519
520    // Write result for this block to output
521    if (tid == 0) {
522        output[blockIdx.x] = sdata[0];
523    }
524}
525"#
526                .to_string();
527
528                // WebGPU kernel for Inf norm
529                let wgpu_source = r#"
530struct Uniforms {
531    n: u32,
532};
533
534@group(0) @binding(0) var<uniform> uniforms: Uniforms;
535@group(0) @binding(1) var<storage, read> input: array<f32>;
536@group(0) @binding(2) var<storage, write> output: array<f32>;
537
538var<workgroup> sdata: array<f32, 256>;
539
540@compute @workgroup_size(256)
541#[allow(dead_code)]
542fn norm_inf(
543    @builtin(global_invocation_id) global_id: vec3<u32>,
544    @builtin(local_invocation_id) local_id: vec3<u32>,
545    @builtin(workgroup_id) workgroup_id: vec3<u32>
546) {
547    let tid = local_id.x;
548    let i = workgroup_id.x * 256u * 2u + local_id.x;
549
550    // Initialize
551    sdata[tid] = 0.0;
552
553    // Load and take absolute value of first element
554    if (0 < uniforms.n) {
555        sdata[tid] = abs(input[0]);
556    }
557
558    // Load and take max of absolute value of second element
559    if (0 + 256u < uniforms.n) {
560        sdata[tid] = max(sdata[tid], abs(input[0 + 256u]));
561    }
562
563    workgroupBarrier();
564
565    // Do reduction in shared memory using max operation
566    var s = 256u / 2u;
567    for (var j = 0u; s > 0u; j = j + 1u) {
568        if (tid < s) {
569            sdata[tid] = max(sdata[tid], sdata[tid + s]);
570        }
571
572        s = s / 2u;
573        workgroupBarrier();
574    }
575
576    // Write result for this workgroup
577    if (tid == 0u) {
578        output[workgroup_id.x] = sdata[0];
579    }
580}
581"#
582                .to_string();
583
584                // Metal kernel for Inf norm
585                let metal_source = r#"
586#include <metal_stdlib>
587using namespace metal;
588
589kernel void norm_inf(
590    const device float* input [[buffer(0)]],
591    device float* output [[buffer(1)]],
592    constant uint& n [[buffer(2)]],
593    uint global_id [[thread_position_in_grid]],
594    uint local_id [[thread_position_in_threadgroup]],
595    uint group_id [[threadgroup_position_in_grid]])
596{
597    threadgroup float sdata[256];
598
599    uint tid = local_id;
600    uint i = group_id * 256 * 2 + local_id;
601
602    // Initialize
603    sdata[tid] = 0.0f;
604
605    // Load and take absolute value of first element
606    if (0 < n) {
607        sdata[tid] = abs(input[0]);
608    }
609
610    // Load and take max of absolute value of second element
611    if (0 + 256 < n) {
612        sdata[tid] = max(sdata[tid], abs(input[0 + 256]));
613    }
614
615    threadgroup_barrier(mem_flags::mem_threadgroup);
616
617    // Do reduction in shared memory using max operation
618    for (uint s = 256 / 2; s > 0; s >>= 1) {
619        if (tid < s) {
620            sdata[tid] = max(sdata[tid], sdata[tid + s]);
621        }
622
623        threadgroup_barrier(mem_flags::mem_threadgroup);
624    }
625
626    // Write result for this threadgroup
627    if (tid == 0) {
628        output[group_id] = sdata[0];
629    }
630}
631"#
632                .to_string();
633
634                // OpenCL kernel for Inf norm
635                let opencl_source = r#"
636__kernel void norm_inf(
637    __global const float* input,
638    __global float* output,
639    const int n)
640{
641    __local float sdata[256];
642
643    unsigned int tid = get_local_id(0);
644    unsigned int i = get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
645
646    // Initialize
647    sdata[tid] = 0.0f;
648
649    // Load and take absolute value of first element
650    if (0 < n) {
651        sdata[tid] = fabs(input[0]);
652    }
653
654    // Load and take max of absolute value of second element
655    if (0 + get_local_size(0) < n) {
656        sdata[tid] = fmax(sdata[tid], fabs(input[0 + get_local_size(0)]));
657    }
658
659    barrier(CLK_LOCAL_MEM_FENCE);
660
661    // Do reduction in shared memory using max operation
662    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
663        if (tid < s) {
664            sdata[tid] = fmax(sdata[tid], sdata[tid + s]);
665        }
666
667        barrier(CLK_LOCAL_MEM_FENCE);
668    }
669
670    // Write result for this workgroup
671    if (tid == 0) {
672        output[get_group_id(0)] = sdata[0];
673    }
674}
675"#
676                .to_string();
677
678                // ROCm (HIP) kernel - similar to CUDA
679                let rocm_source = cuda_source.clone();
680
681                (
682                    cuda_source,
683                    rocm_source,
684                    wgpu_source,
685                    metal_source,
686                    opencl_source,
687                )
688            }
689        }
690    }
691}
692
693impl Default for NormKernel {
694    fn default() -> Self {
695        Self::new()
696    }
697}
698
699impl GpuKernel for NormKernel {
700    fn name(&self) -> &str {
701        self.base.name()
702    }
703
704    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
705        self.base.source_for_backend(backend)
706    }
707
708    fn metadata(&self) -> KernelMetadata {
709        self.base.metadata()
710    }
711
712    fn can_specialize(&self, params: &KernelParams) -> bool {
713        matches!(params.datatype, DataType::Float32 | DataType::Float64)
714    }
715
716    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
717        if !self.can_specialize(params) {
718            return Err(GpuError::SpecializationNotSupported);
719        }
720
721        // Check for norm type in parameters
722        if let Some(norm_param) = params.string_params.get("norm_type") {
723            let norm_type = match norm_param.as_str() {
724                "l1" => NormType::L1,
725                "l2" => NormType::L2,
726                "inf" => NormType::Inf,
727                _ => return Err(GpuError::InvalidParameter(norm_param.to_string())),
728            };
729
730            return Ok(Box::new(Self::with_type(norm_type)));
731        }
732
733        // Default to same norm type as this kernel
734        Ok(Box::new(Self::with_type(self.norm_type)))
735    }
736}