use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

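/// GPU kernel for the ReLU activation: `output[i] = max(0, input[i])`.
///
/// Provides equivalent kernel source for the CUDA, ROCm, WGSL, Metal, and
/// OpenCL backends.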
pub struct ReluKernel {
    base: BaseKernel,
}

impl Default for ReluKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ReluKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void relu(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = max(0.0f, input[i]);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn relu(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = max(0.0, input[i]);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void relu(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = max(0.0f, input[gid]);
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void relu(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = max(0.0f, input[i]);
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "relu",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ReluKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

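/// GPU kernel for the logistic sigmoid activation:
/// `output[i] = 1 / (1 + exp(-input[i]))`.
///
/// Provides equivalent kernel source for the CUDA, ROCm, WGSL, Metal, and
/// OpenCL backends.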
pub struct SigmoidKernel {
    base: BaseKernel,
}

impl Default for SigmoidKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl SigmoidKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void sigmoid(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = 1.0f / (1.0f + expf(-input[i]));
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn sigmoid(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = 1.0 / (1.0 + exp(-input[i]));
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void sigmoid(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = 1.0f / (1.0f + exp(-input[gid]));
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void sigmoid(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = 1.0f / (1.0f + exp(-input[i]));
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "sigmoid",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for SigmoidKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

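/// GPU kernel for the hyperbolic tangent activation:
/// `output[i] = tanh(input[i])`.
///
/// Provides equivalent kernel source for the CUDA, ROCm, WGSL, Metal, and
/// OpenCL backends.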
pub struct TanhKernel {
    base: BaseKernel,
}

impl Default for TanhKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl TanhKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void tanh_activation(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = tanhf(input[i]);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn tanh_activation(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = tanh(input[i]);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void tanh_activation(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = tanh(input[gid]);
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void tanh_activation(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = tanh(input[i]);
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "tanh",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for TanhKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

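/// GPU kernel for the GELU activation using the tanh approximation:
/// `output[i] = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³)))`.
///
/// Provides equivalent kernel source for the CUDA, ROCm, WGSL, Metal, and
/// OpenCL backends.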
pub struct GeluKernel {
    base: BaseKernel,
}

impl Default for GeluKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl GeluKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void gelu_activation(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float x = input[i];
        float sqrt_2_over_pi = 0.7978845608f; // sqrt(2/π)
        float coeff = 0.044715f;

        float x_cubed = x * x * x;
        float tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        float tanh_result = tanhf(tanh_input);

        output[i] = 0.5f * x * (1.0f + tanh_result);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn gelu_activation(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        let x = input[i];
        let sqrt_2_over_pi = 0.7978845608; // sqrt(2/π)
        let coeff = 0.044715;

        let x_cubed = x * x * x;
        let tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        let tanh_result = tanh(tanh_input);

        output[i] = 0.5 * x * (1.0 + tanh_result);
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void gelu_activation(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        float x = input[gid];
        float sqrt_2_over_pi = 0.7978845608f; // sqrt(2/π)
        float coeff = 0.044715f;

        float x_cubed = x * x * x;
        float tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        float tanh_result = tanh(tanh_input);

        output[gid] = 0.5f * x * (1.0f + tanh_result);
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void gelu_activation(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        float x = input[i];
        float sqrt_2_over_pi = 0.7978845608f; // sqrt(2/π)
        float coeff = 0.044715f;

        float x_cubed = x * x * x;
        float tanh_input = sqrt_2_over_pi * (x + coeff * x_cubed);
        float tanh_result = tanh(tanh_input);

        output[i] = 0.5f * x * (1.0f + tanh_result);
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "gelu_activation",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for GeluKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

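/// GPU kernel for the leaky ReLU activation:
/// `output[i] = x` if `x > 0`, otherwise `alpha * x`.
///
/// The negative slope `alpha` is baked into the generated kernel source at
/// construction time; the default is 0.01.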
pub struct LeakyReluKernel {
    base: BaseKernel,
    alpha: f32,
}

impl Default for LeakyReluKernel {
    fn default() -> Self {
        Self::new(0.01)
    }
}

impl LeakyReluKernel {
    pub fn new(alpha: f32) -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let alpha_str = format!("{:.6}f", alpha);

        let cuda_source = format!(
            r#"
extern "C" __global__ void leaky_relu(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {{
        float x = input[i];
        output[i] = x > 0.0f ? x : {alpha} * x;
    }}
}}
"#,
            alpha = alpha_str
        );

        let rocm_source = cuda_source.clone();

627
628 let wgpu_source = format!(
629 r#"
630struct Uniforms {{
631 n: u32,
632}};
633
634@group(0) @binding(0) var<uniform> uniforms: Uniforms;
635@group(0) @binding(1) var<storage, read> input: array<f32>;
636@group(0) @binding(2) var<storage, write> output: array<f32>;
637
638@compute @workgroup_size(256)
639fn leaky_relu(@builtin(global_invocation_id) global_id: vec3<u32>) {{
640 let i = global_id.x;
641
642 if (i < uniforms.n) {{
643 let x = input[i];
644 output[i] = select({alpha} * x, x, x > 0.0);
645 }}
646}}
647"#,
648 alpha = alpha
649 );

        let metal_source = format!(
            r#"
#include <metal_stdlib>
using namespace metal;

kernel void leaky_relu(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{{
    if (gid < n) {{
        float x = input[gid];
        output[gid] = x > 0.0f ? x : {alpha} * x;
    }}
}}
"#,
            alpha = alpha_str
        );

        let opencl_source = format!(
            r#"
__kernel void leaky_relu(
    __global const float* input,
    __global float* output,
    const int n)
{{
    int i = get_global_id(0);
    if (i < n) {{
        float x = input[i];
        output[i] = x > 0.0f ? x : {alpha} * x;
    }}
}}
"#,
            alpha = alpha_str
        );

        Self {
            base: BaseKernel::new(
                "leaky_relu",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
            alpha,
        }
    }

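    /// Returns the negative-slope coefficient `alpha` used by this kernel.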
    pub fn alpha(&self) -> f32 {
        self.alpha
    }
}

impl GpuKernel for LeakyReluKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new(self.alpha)))
    }
}

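/// GPU kernel for the Swish (SiLU) activation: `output[i] = x * sigmoid(x)`.
///
/// Provides equivalent kernel source for the CUDA, ROCm, WGSL, Metal, and
/// OpenCL backends.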
pub struct SwishKernel {
    base: BaseKernel,
}

impl Default for SwishKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl SwishKernel {
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void swish(
    const float* __restrict__ input,
    float* __restrict__ output,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float x = input[i];
        float sigmoid_x = 1.0f / (1.0f + expf(-x));
        output[i] = x * sigmoid_x;
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(256)
fn swish(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        let x = input[i];
        let sigmoid_x = 1.0 / (1.0 + exp(-x));
        output[i] = x * sigmoid_x;
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void swish(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        float x = input[gid];
        float sigmoid_x = 1.0f / (1.0f + exp(-x));
        output[gid] = x * sigmoid_x;
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void swish(
    __global const float* input,
    __global float* output,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        float x = input[i];
        float sigmoid_x = 1.0f / (1.0f + exp(-x));
        output[i] = x * sigmoid_x;
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "swish",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for SwishKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}
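
// A minimal sanity-check sketch, not an exhaustive test suite: it only touches
// behaviour visible in this file (kernel names, the leaky-ReLU alpha, and the
// metadata built above) and assumes the usual `#[cfg(test)]` module layout.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kernels_report_expected_names() {
        assert_eq!(ReluKernel::new().name(), "relu");
        assert_eq!(SigmoidKernel::new().name(), "sigmoid");
        assert_eq!(TanhKernel::new().name(), "tanh");
        assert_eq!(GeluKernel::new().name(), "gelu_activation");
        assert_eq!(LeakyReluKernel::new(0.2).name(), "leaky_relu");
        assert_eq!(SwishKernel::new().name(), "swish");
    }

    #[test]
    fn leaky_relu_keeps_its_alpha_and_workgroup_size() {
        let kernel = LeakyReluKernel::new(0.2);
        assert_eq!(kernel.alpha(), 0.2);
        // All activation kernels in this module use a 256-wide workgroup.
        assert_eq!(kernel.metadata().workgroup_size, [256, 1, 1]);
    }
}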