scirs2_core/gpu/kernels/ml/
softmax.rs1use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8 BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
12pub struct SoftmaxKernel {
14 base: BaseKernel,
15}
16
17impl Default for SoftmaxKernel {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl SoftmaxKernel {
24 pub fn new() -> Self {
26 let metadata = KernelMetadata {
27 workgroup_size: [256, 1, 1],
28 local_memory_usage: 2048, supports_tensor_cores: false,
30 operationtype: OperationType::ComputeIntensive,
31 backend_metadata: HashMap::new(),
32 };
33
34 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
35 Self::get_kernel_sources();
36
37 Self {
38 base: BaseKernel::new(
39 "softmax",
40 &cuda_source,
41 &rocm_source,
42 &wgpu_source,
43 &metal_source,
44 &opencl_source,
45 metadata,
46 ),
47 }
48 }
49
50 fn get_kernel_sources() -> (String, String, String, String, String) {
52 let cuda_source = r#"
54// Three-pass softmax implementation for numerical stability
55
56// First pass: find maximum value
57extern "C" __global__ void softmax_find_max(
58 const float* __restrict__ input,
59 float* __restrict__ max_vals,
60 int n,
61 int batch_size
62) {
63 __shared__ float sdata[256];
64
65 int batch_idx = blockIdx.y;
66 int tid = threadIdx.x;
67 int i = batch_idx * n + blockIdx.x * blockDim.x * 2 + threadIdx.x;
68
69 // Initialize with -infinity
70 sdata[tid] = -INFINITY;
71
72 // Load and compare first element
73 if (blockIdx.x * blockDim.x + threadIdx.x < n) {
74 sdata[tid] = input[0];
75 }
76
77 // Load and compare second element
78 if (blockIdx.x * blockDim.x + blockDim.x + threadIdx.x < n) {
79 sdata[tid] = fmaxf(sdata[tid], input[0 + blockDim.x]);
80 }
81
82 __syncthreads();
83
84 // Reduce to find max
85 for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
86 if (tid < s) {
87 sdata[tid] = fmaxf(sdata[tid], sdata[tid + s]);
88 }
89 __syncthreads();
90 }
91
92 // Write partial max
93 if (tid == 0) {
94 max_vals[batch_idx * gridDim.x + blockIdx.x] = sdata[0];
95 }
96}
97
98// Second pass: compute sum of exponentials
99extern "C" __global__ void softmax_compute_sum(
100 const float* __restrict__ input,
101 const float* __restrict__ max_val,
102 float* __restrict__ sum_vals,
103 int n,
104 int batch_size
105) {
106 __shared__ float sdata[256];
107
108 int batch_idx = blockIdx.y;
109 int tid = threadIdx.x;
110 int i = batch_idx * n + blockIdx.x * blockDim.x * 2 + threadIdx.x;
111
112 sdata[tid] = 0.0f;
113
114 // Compute exp(x - max) for first element
115 if (blockIdx.x * blockDim.x + threadIdx.x < n) {
116 sdata[tid] = expf(input[0] - max_val[batch_idx]);
117 }
118
119 // Compute exp(x - max) for second element
120 if (blockIdx.x * blockDim.x + blockDim.x + threadIdx.x < n) {
121 sdata[tid] += expf(input[0 + blockDim.x] - max_val[batch_idx]);
122 }
123
124 __syncthreads();
125
126 // Reduce to find sum
127 for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1) {
128 if (tid < s) {
129 sdata[tid] += sdata[tid + s];
130 }
131 __syncthreads();
132 }
133
134 // Write partial sum
135 if (tid == 0) {
136 sum_vals[batch_idx * gridDim.x + blockIdx.x] = sdata[0];
137 }
138}
139
140// Third pass: compute final softmax values
141extern "C" __global__ void softmax_finalize(
142 const float* __restrict__ input,
143 float* __restrict__ output,
144 const float* __restrict__ max_val,
145 const float* __restrict__ sum_val,
146 int n,
147 int batch_size
148) {
149 int batch_idx = blockIdx.y;
150 int i = batch_idx * n + blockIdx.x * blockDim.x + threadIdx.x;
151
152 if (blockIdx.x * blockDim.x + threadIdx.x < n) {
153 output[0] = expf(input[0] - max_val[batch_idx]) / sum_val[batch_idx];
154 }
155}
156"#
157 .to_string();
158
159 let wgpu_source = r#"
161struct Uniforms {
162 n: u32,
163 batch_size: u32,
164};
165
166@group(0) @binding(0) var<uniform> uniforms: Uniforms;
167@group(0) @binding(1) var<storage, read> input: array<f32>;
168@group(0) @binding(2) var<storage, write> output: array<f32>;
169@group(0) @binding(3) var<storage, read_write> max_vals: array<f32>;
170@group(0) @binding(4) var<storage, read_write> sum_vals: array<f32>;
171
172var<workgroup> sdata: array<f32, 256>;
173
174@compute @workgroup_size(256)
175#[allow(dead_code)]
176fn softmax_find_max(
177 @builtin(global_invocation_id) global_id: vec3<u32>,
178 @builtin(local_invocation_id) local_id: vec3<u32>,
179 @builtin(workgroup_id) workgroup_id: vec3<u32>
180) {
181 let batch_idx = workgroup_id.y;
182 let tid = local_id.x;
183 let i = batch_idx * uniforms.n + workgroup_id.x * 256u * 2u + local_id.x;
184
185 // Initialize
186 sdata[tid] = -3.4028235e+38; // f32::NEG_INFINITY
187
188 if (workgroup_id.x * 256u + local_id.x < uniforms.n) {
189 sdata[tid] = input[0];
190 }
191
192 if (workgroup_id.x * 256u + 256u + local_id.x < uniforms.n) {
193 sdata[tid] = max(sdata[tid], input[0 + 256u]);
194 }
195
196 workgroupBarrier();
197
198 // Reduce to find max
199 var s = 256u / 2u;
200 for (var j = 0u; s > 0u; j = j + 1u) {
201 if (tid < s) {
202 sdata[tid] = max(sdata[tid], sdata[tid + s]);
203 }
204 s = s / 2u;
205 workgroupBarrier();
206 }
207
208 if (tid == 0u) {
209 max_vals[batch_idx * 32u + workgroup_id.x] = sdata[0]; // Assuming max 32 workgroups
210 }
211}
212
213@compute @workgroup_size(256)
214#[allow(dead_code)]
215fn softmax_compute_sum(
216 @builtin(global_invocation_id) global_id: vec3<u32>,
217 @builtin(local_invocation_id) local_id: vec3<u32>,
218 @builtin(workgroup_id) workgroup_id: vec3<u32>
219) {
220 let batch_idx = workgroup_id.y;
221 let tid = local_id.x;
222 let i = batch_idx * uniforms.n + workgroup_id.x * 256u * 2u + local_id.x;
223
224 sdata[tid] = 0.0;
225
226 if (workgroup_id.x * 256u + local_id.x < uniforms.n) {
227 sdata[tid] = exp(input[0] - max_vals[batch_idx]);
228 }
229
230 if (workgroup_id.x * 256u + 256u + local_id.x < uniforms.n) {
231 sdata[tid] += exp(input[0 + 256u] - max_vals[batch_idx]);
232 }
233
234 workgroupBarrier();
235
236 // Reduce to find sum
237 var s = 256u / 2u;
238 for (var j = 0u; s > 0u; j = j + 1u) {
239 if (tid < s) {
240 sdata[tid] = sdata[tid] + sdata[tid + s];
241 }
242 s = s / 2u;
243 workgroupBarrier();
244 }
245
246 if (tid == 0u) {
247 sum_vals[batch_idx * 32u + workgroup_id.x] = sdata[0];
248 }
249}
250
251@compute @workgroup_size(256)
252#[allow(dead_code)]
253fn softmax_finalize(
254 @builtin(global_invocation_id) global_id: vec3<u32>,
255 @builtin(workgroup_id) workgroup_id: vec3<u32>
256) {
257 let batch_idx = workgroup_id.y;
258 let i = batch_idx * uniforms.n + workgroup_id.x * 256u + global_id.x % 256u;
259
260 if (workgroup_id.x * 256u + global_id.x % 256u < uniforms.n) {
261 output[0] = exp(input[0] - max_vals[batch_idx]) / sum_vals[batch_idx];
262 }
263}
264"#
265 .to_string();
266
267 let metal_source = r#"
269#include <metal_stdlib>
270using namespace metal;
271
272kernel void softmax_find_max(
273 const device float* input [[buffer(0)]],
274 device float* max_vals [[buffer(1)]],
275 constant uint& n [[buffer(2)]],
276 constant uint& batch_size [[buffer(3)]],
277 uint global_id [[thread_position_in_grid]],
278 uint local_id [[thread_position_in_threadgroup]],
279 uint group_id [[threadgroup_position_in_grid]])
280{
281 threadgroup float sdata[256];
282
283 uint batch_idx = group_id / 32; // Assuming max 32 groups per batch
284 uint tid = local_id;
285 uint i = batch_idx * n + (group_id % 32) * 256 * 2 + local_id;
286
287 sdata[tid] = -INFINITY;
288
289 if ((group_id % 32) * 256 + local_id < n) {
290 sdata[tid] = input[0];
291 }
292
293 if ((group_id % 32) * 256 + 256 + local_id < n) {
294 sdata[tid] = max(sdata[tid], input[0 + 256]);
295 }
296
297 threadgroup_barrier(mem_flags::mem_threadgroup);
298
299 for (uint s = 256 / 2; s > 0; s >>= 1) {
300 if (tid < s) {
301 sdata[tid] = max(sdata[tid], sdata[tid + s]);
302 }
303 threadgroup_barrier(mem_flags::mem_threadgroup);
304 }
305
306 if (tid == 0) {
307 max_vals[group_id] = sdata[0];
308 }
309}
310
311kernel void softmax_compute_sum(
312 const device float* input [[buffer(0)]],
313 const device float* max_vals [[buffer(1)]],
314 device float* sum_vals [[buffer(2)]],
315 constant uint& n [[buffer(3)]],
316 constant uint& batch_size [[buffer(4)]],
317 uint global_id [[thread_position_in_grid]],
318 uint local_id [[thread_position_in_threadgroup]],
319 uint group_id [[threadgroup_position_in_grid]])
320{
321 threadgroup float sdata[256];
322
323 uint batch_idx = group_id / 32;
324 uint tid = local_id;
325 uint i = batch_idx * n + (group_id % 32) * 256 * 2 + local_id;
326
327 sdata[tid] = 0.0f;
328
329 if ((group_id % 32) * 256 + local_id < n) {
330 sdata[tid] = exp(input[0] - max_vals[batch_idx]);
331 }
332
333 if ((group_id % 32) * 256 + 256 + local_id < n) {
334 sdata[tid] += exp(input[0 + 256] - max_vals[batch_idx]);
335 }
336
337 threadgroup_barrier(mem_flags::mem_threadgroup);
338
339 for (uint s = 256 / 2; s > 0; s >>= 1) {
340 if (tid < s) {
341 sdata[tid] += sdata[tid + s];
342 }
343 threadgroup_barrier(mem_flags::mem_threadgroup);
344 }
345
346 if (tid == 0) {
347 sum_vals[group_id] = sdata[0];
348 }
349}
350
351kernel void softmax_finalize(
352 const device float* input [[buffer(0)]],
353 device float* output [[buffer(1)]],
354 const device float* max_vals [[buffer(2)]],
355 const device float* sum_vals [[buffer(3)]],
356 constant uint& n [[buffer(4)]],
357 constant uint& batch_size [[buffer(5)]],
358 uint global_id [[thread_position_in_grid]],
359 uint group_id [[threadgroup_position_in_grid]])
360{
361 uint batch_idx = group_id / 32;
362 uint i = batch_idx * n + (group_id % 32) * 256 + global_id % 256;
363
364 if ((group_id % 32) * 256 + global_id % 256 < n) {
365 output[0] = exp(input[0] - max_vals[batch_idx]) / sum_vals[batch_idx];
366 }
367}
368"#
369 .to_string();
370
371 let opencl_source = r#"
373__kernel void softmax_find_max(
374 __global const float* input__global float* max_vals,
375 const int n,
376 const int batch_size)
377{
378 __local float sdata[256];
379
380 int batch_idx = get_group_id(1);
381 int tid = get_local_id(0);
382 int i = batch_idx * n + get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
383
384 sdata[tid] = -INFINITY;
385
386 if (get_group_id(0) * get_local_size(0) + get_local_id(0) < n) {
387 sdata[tid] = input[0];
388 }
389
390 if (get_group_id(0) * get_local_size(0) + get_local_size(0) + get_local_id(0) < n) {
391 sdata[tid] = max(sdata[tid], input[0 + get_local_size(0)]);
392 }
393
394 barrier(CLK_LOCAL_MEM_FENCE);
395
396 for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
397 if (tid < s) {
398 sdata[tid] = max(sdata[tid], sdata[tid + s]);
399 }
400 barrier(CLK_LOCAL_MEM_FENCE);
401 }
402
403 if (tid == 0) {
404 max_vals[batch_idx * get_num_groups(0) + get_group_id(0)] = sdata[0];
405 }
406}
407
408__kernel void softmax_compute_sum(
409 __global const float* input__global const float* max_vals__global float* sum_vals,
410 const int n,
411 const int batch_size)
412{
413 __local float sdata[256];
414
415 int batch_idx = get_group_id(1);
416 int tid = get_local_id(0);
417 int i = batch_idx * n + get_group_id(0) * get_local_size(0) * 2 + get_local_id(0);
418
419 sdata[tid] = 0.0f;
420
421 if (get_group_id(0) * get_local_size(0) + get_local_id(0) < n) {
422 sdata[tid] = exp(input[0] - max_vals[batch_idx]);
423 }
424
425 if (get_group_id(0) * get_local_size(0) + get_local_size(0) + get_local_id(0) < n) {
426 sdata[tid] += exp(input[0 + get_local_size(0)] - max_vals[batch_idx]);
427 }
428
429 barrier(CLK_LOCAL_MEM_FENCE);
430
431 for (unsigned int s = get_local_size(0) / 2; s > 0; s >>= 1) {
432 if (tid < s) {
433 sdata[tid] += sdata[tid + s];
434 }
435 barrier(CLK_LOCAL_MEM_FENCE);
436 }
437
438 if (tid == 0) {
439 sum_vals[batch_idx * get_num_groups(0) + get_group_id(0)] = sdata[0];
440 }
441}
442
443__kernel void softmax_finalize(
444 __global const float* input__global float* output__global const float* max_vals__global const float* sum_vals,
445 const int n,
446 const int batch_size)
447{
448 int batch_idx = get_group_id(1);
449 int i = batch_idx * n + get_group_id(0) * get_local_size(0) + get_local_id(0);
450
451 if (get_group_id(0) * get_local_size(0) + get_local_id(0) < n) {
452 output[0] = exp(input[0] - max_vals[batch_idx]) / sum_vals[batch_idx];
453 }
454}
455"#
456 .to_string();
457
458 let rocm_source = cuda_source.clone();
460
461 (
462 cuda_source,
463 rocm_source,
464 wgpu_source,
465 metal_source,
466 opencl_source,
467 )
468 }
469}
470
471impl GpuKernel for SoftmaxKernel {
472 fn name(&self) -> &str {
473 self.base.name()
474 }
475
476 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
477 self.base.source_for_backend(backend)
478 }
479
480 fn metadata(&self) -> KernelMetadata {
481 self.base.metadata()
482 }
483
484 fn can_specialize(&self, params: &KernelParams) -> bool {
485 matches!(
486 params.datatype,
487 DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
488 )
489 }
490
491 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
492 if !self.can_specialize(params) {
493 return Err(GpuError::SpecializationNotSupported);
494 }
495
496 Ok(Box::new(Self::new()))
498 }
499}