scirs2_core/gpu/kernels/elementwise.rs

//! Element-wise operation kernels for GPU
//!
//! These kernels implement basic element-wise operations that are fundamental
//! to tensor computations: addition, subtraction, multiplication, division,
//! scalar multiplication, power, square root, exponential, and natural
//! logarithm.
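//!
//! A minimal usage sketch (illustrative only; the exact module path and the
//! `GpuBackend::Cuda` variant are assumptions, not confirmed by this file):
//!
//! ```ignore
//! use crate::gpu::kernels::elementwise::ElementwiseAddKernel;
//! use crate::gpu::kernels::GpuKernel;
//!
//! let kernel = ElementwiseAddKernel::new();
//! assert_eq!(kernel.name(), "elementwise_add");
//! // Fetch the kernel source for a specific backend through the trait:
//! // let cuda_src = kernel.source_for_backend(GpuBackend::Cuda)?;
//! ```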

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// Element-wise addition kernel (a + b)
pub struct ElementwiseAddKernel {
    base: BaseKernel,
}

impl Default for ElementwiseAddKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseAddKernel {
    /// Create a new element-wise addition kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_add(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = a[i] + b[i];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, write> result: array<f32>;

@compute @workgroup_size(256)
fn elementwise_add(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        result[i] = a[i] + b[i];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void elementwise_add(
    const device float* a [[buffer(0)]],
    const device float* b [[buffer(1)]],
    device float* result [[buffer(2)]],
    constant uint& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        result[gid] = a[gid] + b[gid];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void elementwise_add(
    __global const float* a,
    __global const float* b,
    __global float* result,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        result[i] = a[i] + b[i];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "elementwise_add",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseAddKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // For now, return the same kernel (no type specialization implemented yet)
        Ok(Box::new(Self::new()))
    }
}

/// Element-wise multiplication kernel (a * b)
pub struct ElementwiseMulKernel {
    base: BaseKernel,
}

impl Default for ElementwiseMulKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseMulKernel {
    /// Create a new element-wise multiplication kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_mul(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = a[i] * b[i];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, write> result: array<f32>;

@compute @workgroup_size(256)
fn elementwise_mul(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        result[i] = a[i] * b[i];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void elementwise_mul(
    const device float* a [[buffer(0)]],
    const device float* b [[buffer(1)]],
    device float* result [[buffer(2)]],
    constant uint& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        result[gid] = a[gid] * b[gid];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void elementwise_mul(
    __global const float* a,
    __global const float* b,
    __global float* result,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        result[i] = a[i] * b[i];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "elementwise_mul",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseMulKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

/// Element-wise division kernel (a / b)
pub struct ElementwiseDivKernel {
    base: BaseKernel,
}

impl Default for ElementwiseDivKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseDivKernel {
    /// Create a new element-wise division kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive, // Division is more compute-intensive
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_div(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = a[i] / b[i];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> a: array<f32>;
@group(0) @binding(2) var<storage, read> b: array<f32>;
@group(0) @binding(3) var<storage, write> result: array<f32>;

@compute @workgroup_size(256)
fn elementwise_div(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        result[i] = a[i] / b[i];
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void elementwise_div(
    const device float* a [[buffer(0)]],
    const device float* b [[buffer(1)]],
    device float* result [[buffer(2)]],
    constant uint& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        result[gid] = a[gid] / b[gid];
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void elementwise_div(
    __global const float* a,
    __global const float* b,
    __global float* result,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        result[i] = a[i] / b[i];
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "elementwise_div",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseDivKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

/// Scalar multiplication kernel (a * scalar)
pub struct ScalarMulKernel {
    base: BaseKernel,
}

impl Default for ScalarMulKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ScalarMulKernel {
    /// Create a new scalar multiplication kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void scalar_mul(
    const float* __restrict__ input,
    float* __restrict__ output,
    float scalar,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        output[i] = input[i] * scalar;
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();

        let wgpu_source = r#"
struct Uniforms {
    n: u32,
    scalar: f32,
};

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, write> output: array<f32>;

@compute @workgroup_size(256)
fn scalar_mul(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        output[i] = input[i] * uniforms.scalar;
    }
}
"#
        .to_string();

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void scalar_mul(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    constant float& scalar [[buffer(2)]],
    constant uint& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid] = input[gid] * scalar;
    }
}
"#
        .to_string();

        let opencl_source = r#"
__kernel void scalar_mul(
    __global const float* input,
    __global float* output,
    const float scalar,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        output[i] = input[i] * scalar;
    }
}
"#
        .to_string();

        Self {
            base: BaseKernel::new(
                "scalar_mul",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ScalarMulKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        Ok(Box::new(Self::new()))
    }
}

/// Element-wise subtraction kernel (a - b)
pub struct ElementwiseSubKernel {
    base: BaseKernel,
}

impl Default for ElementwiseSubKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseSubKernel {
    /// Create a new element-wise subtraction kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_sub(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = a[i] - b[i];
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();
        let wgpu_source = String::new();
        let metal_source = String::new();
        let opencl_source = String::new();

        Self {
            base: BaseKernel::new(
                "elementwise_sub",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseSubKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Int32
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }
        Ok(Box::new(Self::new()))
    }
}

/// Element-wise power kernel (pow(a, b))
pub struct ElementwisePowKernel {
    base: BaseKernel,
}

impl Default for ElementwisePowKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwisePowKernel {
    /// Create a new element-wise power kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_pow(
    const float* __restrict__ a,
    const float* __restrict__ b,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = powf(a[i], b[i]);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();
        let wgpu_source = String::new();
        let metal_source = String::new();
        let opencl_source = String::new();

        Self {
            base: BaseKernel::new(
                "elementwise_pow",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwisePowKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }
        Ok(Box::new(Self::new()))
    }
}

/// Element-wise square root kernel (sqrt(a))
pub struct ElementwiseSqrtKernel {
    base: BaseKernel,
}

impl Default for ElementwiseSqrtKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseSqrtKernel {
    /// Create a new element-wise square root kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_sqrt(
    const float* __restrict__ a,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = sqrtf(a[i]);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();
        let wgpu_source = String::new();
        let metal_source = String::new();
        let opencl_source = String::new();

        Self {
            base: BaseKernel::new(
                "elementwise_sqrt",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseSqrtKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }
        Ok(Box::new(Self::new()))
    }
}

/// Element-wise exponential kernel (exp(a))
pub struct ElementwiseExpKernel {
    base: BaseKernel,
}

impl Default for ElementwiseExpKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseExpKernel {
    /// Create a new element-wise exponential kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_exp(
    const float* __restrict__ a,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = expf(a[i]);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();
        let wgpu_source = String::new();
        let metal_source = String::new();
        let opencl_source = String::new();

        Self {
            base: BaseKernel::new(
                "elementwise_exp",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseExpKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }
        Ok(Box::new(Self::new()))
    }
}

/// Element-wise natural logarithm kernel (ln(a))
pub struct ElementwiseLogKernel {
    base: BaseKernel,
}

impl Default for ElementwiseLogKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ElementwiseLogKernel {
    /// Create a new element-wise natural logarithm kernel
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let cuda_source = r#"
extern "C" __global__ void elementwise_log(
    const float* __restrict__ a,
    float* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = logf(a[i]);
    }
}
"#
        .to_string();

        let rocm_source = cuda_source.clone();
        let wgpu_source = String::new();
        let metal_source = String::new();
        let opencl_source = String::new();

        Self {
            base: BaseKernel::new(
                "elementwise_log",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ElementwiseLogKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(params.datatype, DataType::Float32 | DataType::Float64)
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }
        Ok(Box::new(Self::new()))
    }
}
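
// A minimal smoke-test sketch: it exercises only the constructors and trait
// methods defined in this file and makes no assumptions about GPU hardware,
// concrete `GpuBackend` variants, or the shape of `KernelParams`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::kernels::GpuKernel;

    #[test]
    fn kernel_names_match_their_entry_points() {
        // Each kernel registers a name equal to the entry point used in its
        // per-backend kernel sources.
        assert_eq!(ElementwiseAddKernel::new().name(), "elementwise_add");
        assert_eq!(ElementwiseMulKernel::new().name(), "elementwise_mul");
        assert_eq!(ElementwiseDivKernel::new().name(), "elementwise_div");
        assert_eq!(ScalarMulKernel::new().name(), "scalar_mul");
        assert_eq!(ElementwiseSubKernel::new().name(), "elementwise_sub");
        assert_eq!(ElementwisePowKernel::new().name(), "elementwise_pow");
        assert_eq!(ElementwiseSqrtKernel::new().name(), "elementwise_sqrt");
        assert_eq!(ElementwiseExpKernel::new().name(), "elementwise_exp");
        assert_eq!(ElementwiseLogKernel::new().name(), "elementwise_log");
    }

    #[test]
    fn elementwise_kernels_use_one_dimensional_workgroups() {
        // Kernels in this module declare 1-D workgroups of 256 threads;
        // spot-check the addition kernel's metadata.
        let meta = ElementwiseAddKernel::new().metadata();
        assert_eq!(meta.workgroup_size, [256, 1, 1]);
    }
}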