scirs2_core/gpu/kernels/ml/
activation.rs

//! Activation function kernels for neural networks
//!
//! Implements common activation functions (ReLU, Sigmoid, Tanh, GELU, LeakyReLU, Swish)
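//!
//! A minimal usage sketch (the `use` paths below are an assumption based on this
//! file's location in the crate, and the example is marked `ignore` for that reason;
//! only items defined in this module and the `GpuKernel` trait are exercised):
//!
//! ```ignore
//! use scirs2_core::gpu::kernels::ml::activation::ReluKernel;
//! use scirs2_core::gpu::kernels::GpuKernel;
//!
//! let relu = ReluKernel::new();
//! assert_eq!(relu.name(), "relu");
//! assert_eq!(relu.metadata().workgroup_size, [256, 1, 1]);
//! ```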

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// ReLU activation function kernel
pub struct ReluKernel {
    base: BaseKernel,
}

impl Default for ReluKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ReluKernel {
    /// Create a new ReLU kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void relu(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = max(0.0f, input[i]);
    }
}
"#
        .to_string();

        // ROCm (HIP) kernel - similar to CUDA
        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn relu(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = max(0.0, input[i]);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void relu(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = max(0.0f, input[gid]);
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void relu(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = max(0.0f, input[i]);
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "relu",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ReluKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No specialization needed for ReLU
        Ok(Box::new(Self::new()))
    }
}

/// Sigmoid activation function kernel
pub struct SigmoidKernel {
    base: BaseKernel,
}

impl Default for SigmoidKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl SigmoidKernel {
    /// Create a new Sigmoid kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void sigmoid(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = 1.0f / (1.0f + expf(-input[i]));
    }
}
"#
        .to_string();

        // ROCm (HIP) kernel - similar to CUDA
        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn sigmoid(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = 1.0 / (1.0 + exp(-input[i]));
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void sigmoid(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = 1.0f / (1.0f + exp(-input[gid]));
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void sigmoid(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = 1.0f / (1.0f + exp(-input[i]));
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "sigmoid",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for SigmoidKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No specialization needed for Sigmoid
        Ok(Box::new(Self::new()))
    }
}

/// Tanh activation function kernel
pub struct TanhKernel {
    base: BaseKernel,
}

impl Default for TanhKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl TanhKernel {
    /// Create a new Tanh kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void tanh_activation(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = tanhf(input[i]);
    }
}
"#
        .to_string();

        // ROCm (HIP) kernel - similar to CUDA
        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn tanh_activation(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = tanh(input[i]);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void tanh_activation(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = tanh(input[gid]);
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void tanh_activation(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = tanh(input[i]);
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "tanh",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for TanhKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No specialization needed for Tanh
        Ok(Box::new(Self::new()))
    }
}

/// GELU (Gaussian Error Linear Unit) activation function kernel
/// Used heavily in modern transformer models and neural networks
pub struct GeluKernel {
    base: BaseKernel,
}

impl Default for GeluKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GeluKernel {
    /// Create a new GELU kernel
    /// GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void gelu_activation(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float x = input[i];
        float sqrt_2_over_pi = 0.7978845608f; // sqrt(2/π)
        float coeff = 0.044715f;

        float x_cubed = x * x * x;
        float tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        float tanh_result = tanhf(tanh_input);

        output[i] = 0.5f * x * (1.0f + tanh_result);
    }
}
"#
        .to_string();

        // ROCm (HIP) kernel - similar to CUDA
        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn gelu_activation(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        let x = input[i];
        let sqrt_2_over_pi = 0.7978845608; // sqrt(2/π)
        let coeff = 0.044715;

        let x_cubed = x * x * x;
        let tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        let tanh_result = tanh(tanh_input);

        output[i] = 0.5 * x * (1.0 + tanh_result);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void gelu_activation(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        float x = input[gid];
        float sqrt_2_over_pi = 0.7978845608f; // sqrt(2/π)
        float coeff = 0.044715f;

        float x_cubed = x * x * x;
        float tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        float tanh_result = tanh(tanh_input);

        output[gid] = 0.5f * x * (1.0f + tanh_result);
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void gelu_activation(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        float x = input[i];
        float sqrt_2_over_pi = 0.7978845608f; // sqrt(2/π)
        float coeff = 0.044715f;

        float x_cubed = x * x * x;
        float tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        float tanh_result = tanh(tanh_input);

        output[i] = 0.5f * x * (1.0f + tanh_result);
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "gelu_activation",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for GeluKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No specialization needed for GELU
        Ok(Box::new(Self::new()))
    }
}

/// LeakyReLU activation function kernel
/// LeakyReLU(x) = max(α*x, x) where α is typically 0.01
pub struct LeakyReluKernel {
    base: BaseKernel,
    alpha: f32,
}

impl Default for LeakyReluKernel {
    fn default() -> Self {
        Self::new(0.01)
    }
}

impl LeakyReluKernel {
    /// Create a new LeakyReLU kernel with specified negative slope
    pub fn new(alpha: f32) -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let alpha_str = format!("{:.6}f", alpha);

        let cuda_source = format!(
            r#"
extern "C" __global__ void leaky_relu(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {{
        float x = input[i];
        output[i] = x > 0.0f ? x : {alpha} * x;
    }}
}}
"#,
            alpha = alpha_str
        );

        let rocm_source = cuda_source.clone();

        let wgpu_source = format!(
            r#"
struct Uniforms {{
    n: u32,
}};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn leaky_relu(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let i = global_id.x;

    if (i < uniforms.n) {{
        let x = input[i];
        output[i] = select({alpha} * x, x, x > 0.0);
    }}
}}
"#,
            alpha = alpha
        );

        let metal_source = format!(
            r#"
#include <metal_stdlib>
using namespace metal;

kernel void leaky_relu(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{{
    if (gid < n) {{
        float x = input[gid];
        output[gid] = x > 0.0f ? x : {alpha} * x;
    }}
}}
"#,
            alpha = alpha_str
        );

        let opencl_source = format!(
            r#"
__kernel void leaky_relu(
    __global const float* input,
    __global float* output,
    const int n)
{{
    int i = get_global_id(0);
    if (i < n) {{
        float x = input[i];
        output[i] = x > 0.0f ? x : {alpha} * x;
    }}
}}
"#,
            alpha = alpha_str
        );

        Self {
            base: BaseKernel::new(
                "leaky_relu",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
            alpha,
        }
    }

    /// Get the negative slope parameter
    pub fn alpha(&self) -> f32 {
        self.alpha
    }
}

impl GpuKernel for LeakyReluKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // Create a new kernel with the same alpha value
        Ok(Box::new(Self::new(self.alpha)))
    }
}

/// Swish (SiLU) activation function kernel
/// Swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
pub struct SwishKernel {
    base: BaseKernel,
}

impl Default for SwishKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl SwishKernel {
    /// Create a new Swish kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void swish(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float x = input[i];
        float sigmoid_x = 1.0f / (1.0f + expf(-x));
        output[i] = x * sigmoid_x;
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn swish(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        let x = input[i];
        let sigmoid_x = 1.0 / (1.0 + exp(-x));
        output[i] = x * sigmoid_x;
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void swish(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        float x = input[gid];
        float sigmoid_x = 1.0f / (1.0f + exp(-x));
        output[gid] = x * sigmoid_x;
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void swish(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        float x = input[i];
        float sigmoid_x = 1.0f / (1.0f + exp(-x));
        output[i] = x * sigmoid_x;
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "swish",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for SwishKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
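
// A hedged, minimal test sketch for this module. It exercises only the
// constructors and methods defined above (`name`, `metadata`, `alpha`);
// backend-specific source retrieval is not covered here because the available
// `GpuBackend` variants are not defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kernel_names_and_workgroup_sizes() {
        // Registered names match the identifiers passed to `BaseKernel::new`.
        assert_eq!(ReluKernel::new().name(), "relu");
        assert_eq!(SigmoidKernel::new().name(), "sigmoid");
        assert_eq!(TanhKernel::new().name(), "tanh");
        assert_eq!(GeluKernel::new().name(), "gelu_activation");
        assert_eq!(SwishKernel::new().name(), "swish");
        // All activation kernels use a 1-D workgroup of 256 threads.
        assert_eq!(ReluKernel::new().metadata().workgroup_size, [256, 1, 1]);
    }

    #[test]
    fn leaky_relu_reports_its_negative_slope() {
        let kernel = LeakyReluKernel::new(0.2);
        assert_eq!(kernel.name(), "leaky_relu");
        assert_eq!(kernel.alpha(), 0.2);
        // The default slope matches the doc comment above (alpha = 0.01).
        assert_eq!(LeakyReluKernel::default().alpha(), 0.01);
    }
}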