scirs2_core/gpu/kernels/ml/softmax.rs

//! Softmax activation function kernel
//!
//! Implements the softmax function for neural networks.
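//!
//! For a row vector `x`, softmax is `softmax(x)_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x))`;
//! subtracting the row maximum before exponentiating keeps the exponentials from overflowing.
//! Each backend source below implements this as three passes over the data: find the per-row
//! maximum, accumulate the sum of shifted exponentials, then normalize each element.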

use std::collections::HashMap;

use crate::gpu::kernels::{
    BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
};
use crate::gpu::{GpuBackend, GpuError};

/// Softmax activation function kernel
pub struct SoftmaxKernel {
    base: BaseKernel,
}

impl Default for SoftmaxKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl SoftmaxKernel {
    /// Create a new softmax kernel
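    ///
    /// # Example
    ///
    /// A minimal usage sketch; the import paths are assumed from this file's
    /// location in the crate, so the snippet is not compiled as a doctest.
    ///
    /// ```ignore
    /// use scirs2_core::gpu::kernels::ml::softmax::SoftmaxKernel;
    /// use scirs2_core::gpu::kernels::GpuKernel;
    ///
    /// let kernel = SoftmaxKernel::new();
    /// assert_eq!(kernel.name(), "softmax");
    /// ```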
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 2048, // Need extra memory for reductions
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "softmax",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

    /// Get kernel sources for different backends
    fn get_kernel_sources() -> (String, String, String, String, String) {
        // CUDA kernel for softmax
        let cuda_source = r#"
// Three-pass softmax implementation for numerical stability

// First pass: find maximum value
extern "C" __global__ void softmax_find_max(
    const float* __restrict__ input,
    float* __restrict__ max_vals,
    int n,
    int batch_size
) {
    __shared__ float sdata[256];

    int batch_idx = blockIdx.y;
    int tid = threadIdx.x;
    // Each block covers 2 * blockDim.x elements of the row
    int i = batch_idx * n + blockIdx.x * blockDim.x * 2 + threadIdx.x;

    // Initialize with -infinity
    sdata[tid] = -INFINITY;

    // Load and compare first element
    if (blockIdx.x * blockDim.x * 2 + threadIdx.x < n) {
        sdata[tid] = input[i];
    }

    // Load and compare second element
    if (blockIdx.x * blockDim.x * 2 + blockDim.x + threadIdx.x < n) {
        sdata[tid] = fmaxf(sdata[tid], input[i + blockDim.x]);
    }

    __syncthreads();

    // Reduce to find max
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
        }
        __syncthreads();
    }

    // Write partial max
    if (tid == 0) {
        max_vals[batch_idx * gridDim.x + blockIdx.x] = sdata[0];
    }
}

// Second pass: compute sum of exponentials
extern "C" __global__ void softmax_compute_sum(
    const float* __restrict__ input,
    const float* __restrict__ max_val,
    float* __restrict__ sum_vals,
    int n,
    int batch_size
) {
    __shared__ float sdata[256];

    int batch_idx = blockIdx.y;
    int tid = threadIdx.x;
    int i = batch_idx * n + blockIdx.x * blockDim.x * 2 + threadIdx.x;

    sdata[tid] = 0.0f;

    // Compute exp(x - max) for first element
    if (blockIdx.x * blockDim.x * 2 + threadIdx.x < n) {
        sdata[tid] = expf(input[i] - max_val[batch_idx]);
    }

    // Compute exp(x - max) for second element
    if (blockIdx.x * blockDim.x * 2 + blockDim.x + threadIdx.x < n) {
        sdata[tid] += expf(input[i + blockDim.x] - max_val[batch_idx]);
    }

    __syncthreads();

    // Reduce to find sum
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        __syncthreads();
    }

    // Write partial sum
    if (tid == 0) {
        sum_vals[batch_idx * gridDim.x + blockIdx.x] = sdata[0];
    }
}

// Third pass: compute final softmax values
extern "C" __global__ void softmax_finalize(
    const float* __restrict__ input,
    float* __restrict__ output,
    const float* __restrict__ max_val,
    const float* __restrict__ sum_val,
    int n,
    int batch_size
) {
    int batch_idx = blockIdx.y;
    int i = batch_idx * n + blockIdx.x * blockDim.x + threadIdx.x;

    if (blockIdx.x * blockDim.x + threadIdx.x < n) {
        output[i] = expf(input[i] - max_val[batch_idx]) / sum_val[batch_idx];
    }
}
"#
        .to_string();

        // WebGPU kernel for softmax
        let wgpu_source = r#"
struct Uniforms {
    n: u32,
    batch_size: u32,
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<storage, read_write> max_vals: array<f32>;
@group(0) @binding(4) var<storage, read_write> sum_vals: array<f32>;

var<workgroup> sdata: array<f32, 256>;

@compute @workgroup_size(256)
fn softmax_find_max(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let batch_idx = workgroup_id.y;
    let tid = local_id.x;
    // Each workgroup covers 2 * 256 elements of the row
    let i = batch_idx * uniforms.n + workgroup_id.x * 256u * 2u + local_id.x;

    // Initialize
    sdata[tid] = -3.4028235e+38; // f32::MIN (WGSL has no literal for -infinity)

    if (workgroup_id.x * 256u * 2u + local_id.x < uniforms.n) {
        sdata[tid] = input[i];
    }

    if (workgroup_id.x * 256u * 2u + 256u + local_id.x < uniforms.n) {
        sdata[tid] = max(sdata[tid], input[i + 256u]);
    }

    workgroupBarrier();

    // Reduce to find max
    for (var s = 128u; s > 0u; s = s / 2u) {
        if (tid < s) {
            sdata[tid] = max(sdata[tid], sdata[tid + s]);
        }
        workgroupBarrier();
    }

    if (tid == 0u) {
        max_vals[batch_idx * 32u + workgroup_id.x] = sdata[0]; // Assuming max 32 workgroups
    }
}

@compute @workgroup_size(256)
fn softmax_compute_sum(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let batch_idx = workgroup_id.y;
    let tid = local_id.x;
    let i = batch_idx * uniforms.n + workgroup_id.x * 256u * 2u + local_id.x;

    sdata[tid] = 0.0;

    if (workgroup_id.x * 256u * 2u + local_id.x < uniforms.n) {
        sdata[tid] = exp(input[i] - max_vals[batch_idx]);
    }

    if (workgroup_id.x * 256u * 2u + 256u + local_id.x < uniforms.n) {
        sdata[tid] += exp(input[i + 256u] - max_vals[batch_idx]);
    }

    workgroupBarrier();

    // Reduce to find sum
    for (var s = 128u; s > 0u; s = s / 2u) {
        if (tid < s) {
            sdata[tid] = sdata[tid] + sdata[tid + s];
        }
        workgroupBarrier();
    }

    if (tid == 0u) {
        sum_vals[batch_idx * 32u + workgroup_id.x] = sdata[0];
    }
}

@compute @workgroup_size(256)
fn softmax_finalize(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>
) {
    let batch_idx = workgroup_id.y;
    let i = batch_idx * uniforms.n + workgroup_id.x * 256u + global_id.x % 256u;

    if (workgroup_id.x * 256u + global_id.x % 256u < uniforms.n) {
        output[i] = exp(input[i] - max_vals[batch_idx]) / sum_vals[batch_idx];
    }
}
"#
        .to_string();

        // Metal kernel for softmax
        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

kernel void softmax_find_max(
    const device float* input [[buffer(0)]],
    device float* max_vals [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    constant uint& batch_size [[buffer(3)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint batch_idx = group_id / 32; // Assuming max 32 groups per batch
    uint tid = local_id;
    // Each threadgroup covers 2 * 256 elements of the row
    uint i = batch_idx * n + (group_id % 32) * 256 * 2 + local_id;

    sdata[tid] = -INFINITY;

    if ((group_id % 32) * 256 * 2 + local_id < n) {
        sdata[tid] = input[i];
    }

    if ((group_id % 32) * 256 * 2 + 256 + local_id < n) {
        sdata[tid] = max(sdata[tid], input[i + 256]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = max(sdata[tid], sdata[tid + s]);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        max_vals[group_id] = sdata[0];
    }
}

kernel void softmax_compute_sum(
    const device float* input [[buffer(0)]],
    const device float* max_vals [[buffer(1)]],
    device float* sum_vals [[buffer(2)]],
    constant uint& n [[buffer(3)]],
    constant uint& batch_size [[buffer(4)]],
    uint global_id [[thread_position_in_grid]],
    uint local_id [[thread_position_in_threadgroup]],
    uint group_id [[threadgroup_position_in_grid]])
{
    threadgroup float sdata[256];

    uint batch_idx = group_id / 32;
    uint tid = local_id;
    uint i = batch_idx * n + (group_id % 32) * 256 * 2 + local_id;

    sdata[tid] = 0.0f;

    if ((group_id % 32) * 256 * 2 + local_id < n) {
        sdata[tid] = exp(input[i] - max_vals[batch_idx]);
    }

    if ((group_id % 32) * 256 * 2 + 256 + local_id < n) {
        sdata[tid] += exp(input[i + 256] - max_vals[batch_idx]);
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    for (uint s = 256 / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    if (tid == 0) {
        sum_vals[group_id] = sdata[0];
    }
}

kernel void softmax_finalize(
    const device float* input [[buffer(0)]],
    device float* output [[buffer(1)]],
    const device float* max_vals [[buffer(2)]],
    const device float* sum_vals [[buffer(3)]],
    constant uint& n [[buffer(4)]],
    constant uint& batch_size [[buffer(5)]],
    uint global_id [[thread_position_in_grid]],
    uint group_id [[threadgroup_position_in_grid]])
{
    uint batch_idx = group_id / 32;
    uint i = batch_idx * n + (group_id % 32) * 256 + global_id % 256;

    if ((group_id % 32) * 256 + global_id % 256 < n) {
        output[i] = exp(input[i] - max_vals[batch_idx]) / sum_vals[batch_idx];
    }
}
"#
        .to_string();

        // OpenCL kernel for softmax
        let opencl_source = r#"
__kernel void softmax_find_max(
    __global const float* input,
    __global float* max_vals,
    const int n,
    const int batch_size)
{
    __local float sdata[256];

    int batch_idx = get_group_id(1);
    int tid = get_local_id(0);
    // Each work-group covers 2 * get_local_size(0) elements of the row
    int i = batch_idx * n + get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    sdata[tid] = -INFINITY;

    if (get_group_id(0) * get_local_size(0) * 2 + get_local_id(0) < n) {
        sdata[tid] = input[i];
    }

    if (get_group_id(0) * get_local_size(0) * 2 + get_local_size(0) + get_local_id(0) < n) {
        sdata[tid] = max(sdata[tid], input[i + get_local_size(0)]);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] = max(sdata[tid], sdata[tid + s]);
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (tid == 0) {
        max_vals[batch_idx * get_num_groups(0) + get_group_id(0)] = sdata[0];
    }
}

__kernel void softmax_compute_sum(
    __global const float* input,
    __global const float* max_vals,
    __global float* sum_vals,
    const int n,
    const int batch_size)
{
    __local float sdata[256];

    int batch_idx = get_group_id(1);
    int tid = get_local_id(0);
    int i = batch_idx * n + get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);

    sdata[tid] = 0.0f;

    if (get_group_id(0) * get_local_size(0) * 2 + get_local_id(0) < n) {
        sdata[tid] = exp(input[i] - max_vals[batch_idx]);
    }

    if (get_group_id(0) * get_local_size(0) * 2 + get_local_size(0) + get_local_id(0) < n) {
        sdata[tid] += exp(input[i + get_local_size(0)] - max_vals[batch_idx]);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sdata[tid] += sdata[tid + s];
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (tid == 0) {
        sum_vals[batch_idx * get_num_groups(0) + get_group_id(0)] = sdata[0];
    }
}

__kernel void softmax_finalize(
    __global const float* input,
    __global float* output,
    __global const float* max_vals,
    __global const float* sum_vals,
    const int n,
    const int batch_size)
{
    int batch_idx = get_group_id(1);
    int i = batch_idx * n + get_group_id(0) * get_local_size(0) + get_local_id(0);

    if (get_group_id(0) * get_local_size(0) + get_local_id(0) < n) {
        output[i] = exp(input[i] - max_vals[batch_idx]) / sum_vals[batch_idx];
    }
}
"#
        .to_string();

        // ROCm (HIP) kernel - similar to CUDA
        let rocm_source = cuda_source.clone();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

impl GpuKernel for SoftmaxKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, params: &KernelParams) -> bool {
        matches!(
            params.datatype,
            DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
        )
    }

    fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        if !self.can_specialize(params) {
            return Err(GpuError::SpecializationNotSupported);
        }

        // No specialization needed for Softmax
        Ok(Box::new(Self::new()))
    }
}
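
// A minimal CPU reference of the same three-pass softmax (find the row maximum,
// accumulate the sum of shifted exponentials, normalize) that the GPU sources
// above implement. It is a sketch intended as a host-side correctness oracle for
// tests; it assumes only the standard library and row-major `batch_size x n` input.
#[cfg(test)]
mod tests {
    /// Numerically stable softmax applied independently to each row.
    fn softmax_reference(input: &[f32], n: usize, batch_size: usize) -> Vec<f32> {
        assert_eq!(input.len(), n * batch_size);
        let mut output = vec![0.0f32; input.len()];
        for b in 0..batch_size {
            let row = &input[b * n..(b + 1) * n];
            // Pass 1: per-row maximum, for numerical stability.
            let max = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            // Pass 2: sum of shifted exponentials.
            let sum: f32 = row.iter().map(|&x| (x - max).exp()).sum();
            // Pass 3: normalize each element.
            for (out, &x) in output[b * n..(b + 1) * n].iter_mut().zip(row) {
                *out = (x - max).exp() / sum;
            }
        }
        output
    }

    #[test]
    fn reference_rows_sum_to_one() {
        let input = [1.0f32, 2.0, 3.0, -1.0, 0.0, 1.0];
        let out = softmax_reference(&input, 3, 2);
        for row in out.chunks(3) {
            let total: f32 = row.iter().sum();
            assert!((total - 1.0).abs() < 1e-5);
        }
    }
}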