1use std::collections::HashMap;
6
7use crate::gpu::kernels::{
8 BaseKernel, DataType, GpuKernel, KernelMetadata, KernelParams, OperationType,
9};
10use crate::gpu::{GpuBackend, GpuError};
11
12pub struct Conv1dKernel {
14 base: BaseKernel,
15}
16
17impl Conv1dKernel {
18 pub fn new() -> Self {
20 let metadata = KernelMetadata {
21 workgroup_size: [256, 1, 1],
22 local_memory_usage: 2048, supports_tensor_cores: false,
24 operationtype: OperationType::ComputeIntensive,
25 backend_metadata: HashMap::new(),
26 };
27
28 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
29 Self::get_kernel_sources();
30
31 Self {
32 base: BaseKernel::new(
33 "conv1d",
34 &cuda_source,
35 &rocm_source,
36 &wgpu_source,
37 &metal_source,
38 &opencl_source,
39 metadata,
40 ),
41 }
42 }
43
44 fn get_kernel_sources() -> (String, String, String, String, String) {
46 let cuda_source = r#"
48extern "C" __global__ void conv1d(
49 const float* __restrict__ input,
50 const float* __restrict__ kernel,
51 float* __restrict__ output,
52 int input_length,
53 int kernel_length,
54 int output_length,
55 int stride,
56 int padding
57) {
58 int out_idx = blockIdx.x * blockDim.x + threadIdx.x;
59
60 if (out_idx >= output_length) {
61 return;
62 }
63
64 float sum = 0.0f;
65
66 for (int k = 0; k < kernel_length; k++) {
67 int input_idx = out_idx * stride + k - padding;
68
69 if (input_idx >= 0 && input_idx < input_length) {
70 sum += input[input_idx] * kernel[k];
71 }
72 }
73
74 output[out_idx] = sum;
75}
76"#
77 .to_string();
78
79 let wgpu_source = r#"
81struct Uniforms {
82 input_length: u32,
83 kernel_length: u32,
84 output_length: u32,
85 stride: u32,
86 padding: u32,
87};
88
89@group(0) @binding(0) var<uniform> uniforms: Uniforms;
90@group(0) @binding(1) var<storage, read> input: array<f32>;
91@group(0) @binding(2) var<storage, read> kernel_data: array<f32>;
92@group(0) @binding(3) var<storage, write> output: array<f32>;
93
94@compute @workgroup_size(256)
95#[allow(dead_code)]
96fn conv1d(@builtin(global_invocation_id) global_id: vec3<u32>) {
97 let out_idx = global_id.x;
98
99 if (out_idx >= uniforms.output_length) {
100 return;
101 }
102
103 var sum = 0.0;
104
105 for (var k = 0u; k < uniforms.kernel_length; k = k + 1u) {
106 let input_idx = i32(out_idx * uniforms.stride + k) - i32(uniforms.padding);
107
108 if (input_idx >= 0 && input_idx < i32(uniforms.input_length)) {
109 sum += input[input_idx] * kernel_data[k];
110 }
111 }
112
113 output[out_idx] = sum;
114}
115"#
116 .to_string();
117
118 let metal_source = r#"
120#include <metal_stdlib>
121using namespace metal;
122
123kernel void conv1d(
124 const device float* input [[buffer(0)]],
125 const device float* kernel_data [[buffer(1)]],
126 device float* output [[buffer(2)]],
127 constant uint& input_length [[buffer(3)]],
128 constant uint& kernel_length [[buffer(4)]],
129 constant uint& output_length [[buffer(5)]],
130 constant uint& stride [[buffer(6)]],
131 constant uint& padding [[buffer(7)]],
132 uint gid [[thread_position_in_grid]])
133{
134 if (gid >= output_length) {
135 return;
136 }
137
138 float sum = 0.0f;
139
140 for (uint k = 0; k < kernel_length; k++) {
141 int input_idx = int(gid * stride + k) - int(padding);
142
143 if (input_idx >= 0 && input_idx < int(input_length)) {
144 sum += input[input_idx] * kernel_data[k];
145 }
146 }
147
148 output[gid] = sum;
149}
150"#
151 .to_string();
152
153 let opencl_source = r#"
155__kernel void conv1d(
156 __global const float* input__global const float* kernel_data__global float* output,
157 const int input_length,
158 const int kernel_length,
159 const int output_length,
160 const int stride,
161 const int padding)
162{
163 int out_idx = get_global_id(0);
164
165 if (out_idx >= output_length) {
166 return;
167 }
168
169 float sum = 0.0f;
170
171 for (int k = 0; k < kernel_length; k++) {
172 int input_idx = out_idx * stride + k - padding;
173
174 if (input_idx >= 0 && input_idx < input_length) {
175 sum += input[input_idx] * kernel_data[k];
176 }
177 }
178
179 output[out_idx] = sum;
180}
181"#
182 .to_string();
183
184 let rocm_source = cuda_source.clone();
186
187 (
188 cuda_source,
189 rocm_source,
190 wgpu_source,
191 metal_source,
192 opencl_source,
193 )
194 }
195}
196
197impl GpuKernel for Conv1dKernel {
198 fn name(&self) -> &str {
199 self.base.name()
200 }
201
202 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
203 self.base.source_for_backend(backend)
204 }
205
206 fn metadata(&self) -> KernelMetadata {
207 self.base.metadata()
208 }
209
210 fn can_specialize(&self, params: &KernelParams) -> bool {
211 matches!(params.datatype, DataType::Float32 | DataType::Float64)
212 }
213
214 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
215 if !self.can_specialize(params) {
216 return Err(GpuError::SpecializationNotSupported);
217 }
218
219 Ok(Box::new(Self::new()))
220 }
221}
222
223impl Default for Conv1dKernel {
224 fn default() -> Self {
225 Self::new()
226 }
227}
228
229pub struct Conv2dKernel {
231 base: BaseKernel,
232}
233
234impl Conv2dKernel {
235 pub fn new() -> Self {
237 let metadata = KernelMetadata {
238 workgroup_size: [16, 16, 1],
239 local_memory_usage: 4096, supports_tensor_cores: true, operationtype: OperationType::ComputeIntensive,
242 backend_metadata: HashMap::new(),
243 };
244
245 let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
246 Self::get_kernel_sources();
247
248 Self {
249 base: BaseKernel::new(
250 "conv2d",
251 &cuda_source,
252 &rocm_source,
253 &wgpu_source,
254 &metal_source,
255 &opencl_source,
256 metadata,
257 ),
258 }
259 }
260
261 fn get_kernel_sources() -> (String, String, String, String, String) {
263 let cuda_source = r#"
265extern "C" __global__ void conv2d(
266 const float* __restrict__ input,
267 const float* __restrict__ kernel,
268 float* __restrict__ output,
269 int batch_size,
270 int in_channels,
271 int out_channels,
272 int input_height,
273 int input_width,
274 int output_height,
275 int output_width,
276 int kernel_height,
277 int kernel_width,
278 int stride_y,
279 int stride_x,
280 int padding_y,
281 int padding_x
282) {
283 int batch_idx = blockIdx.z;
284 int out_channel = blockIdx.y;
285 int out_y = blockIdx.x * blockDim.x + threadIdx.x;
286 int out_x = threadIdx.y;
287
288 if (batch_idx >= batch_size || out_channel >= out_channels ||
289 out_y >= output_height || out_x >= output_width) {
290 return;
291 }
292
293 float sum = 0.0f;
294
295 // Input and output offsets
296 int input_batch_offset = batch_idx * in_channels * input_height * input_width;
297 int output_batch_offset = batch_idx * out_channels * output_height * output_width;
298 int kernel_offset = out_channel * in_channels * kernel_height * kernel_width;
299
300 // Convolution computation
301 for (int in_ch = 0; in_ch < in_channels; in_ch++) {
302 for (int ky = 0; ky < kernel_height; ky++) {
303 for (int kx = 0; kx < kernel_width; kx++) {
304 int input_y = out_y * stride_y + ky - padding_y;
305 int input_x = out_x * stride_x + kx - padding_x;
306
307 if (input_y >= 0 && input_y < input_height &&
308 input_x >= 0 && input_x < input_width) {
309
310 int input_idx = input_batch_offset +
311 in_ch * input_height * input_width +
312 input_y * input_width + input_x;
313
314 int kernel_idx = kernel_offset +
315 in_ch * kernel_height * kernel_width +
316 ky * kernel_width + kx;
317
318 sum += input[input_idx] * kernel[kernel_idx];
319 }
320 }
321 }
322 }
323
324 int output_idx = output_batch_offset +
325 out_channel * output_height * output_width +
326 out_y * output_width + out_x;
327 output[output_idx] = sum;
328}
329"#
330 .to_string();
331
332 let wgpu_source = r#"
334struct Uniforms {
335 batch_size: u32,
336 in_channels: u32,
337 out_channels: u32,
338 input_height: u32,
339 input_width: u32,
340 output_height: u32,
341 output_width: u32,
342 kernel_height: u32,
343 kernel_width: u32,
344 stride_y: u32,
345 stride_x: u32,
346 padding_y: u32,
347 padding_x: u32,
348};
349
350@group(0) @binding(0) var<uniform> uniforms: Uniforms;
351@group(0) @binding(1) var<storage, read> input: array<f32>;
352@group(0) @binding(2) var<storage, read> kernel_data: array<f32>;
353@group(0) @binding(3) var<storage, write> output: array<f32>;
354
355@compute @workgroup_size(16, 16)
356#[allow(dead_code)]
357fn conv2d(@builtin(global_invocation_id) global_id: vec3<u32>) {
358 let batch_idx = global_id.z;
359 let out_channel = global_id.y % uniforms.out_channels;
360 let out_y = global_id.x;
361 let out_x = global_id.y / uniforms.out_channels;
362
363 if (batch_idx >= uniforms.batch_size || out_channel >= uniforms.out_channels ||
364 out_y >= uniforms.output_height || out_x >= uniforms.output_width) {
365 return;
366 }
367
368 var sum = 0.0;
369
370 // Simplified convolution - would need optimization
371 for (var in_ch = 0u; in_ch < uniforms.in_channels; in_ch = in_ch + 1u) {
372 for (var ky = 0u; ky < uniforms.kernel_height; ky = ky + 1u) {
373 for (var kx = 0u; kx < uniforms.kernel_width; kx = kx + 1u) {
374 let input_y = i32(out_y * uniforms.stride_y + ky) - i32(uniforms.padding_y);
375 let input_x = i32(out_x * uniforms.stride_x + kx) - i32(uniforms.padding_x);
376
377 if (input_y >= 0 && input_y < i32(uniforms.input_height) &&
378 input_x >= 0 && input_x < i32(uniforms.input_width)) {
379
380 let input_idx = batch_idx * uniforms.in_channels * uniforms.input_height * uniforms.input_width +
381 in_ch * uniforms.input_height * uniforms.input_width +
382 u32(input_y) * uniforms.input_width + u32(input_x);
383
384 let kernel_idx = out_channel * uniforms.in_channels * uniforms.kernel_height * uniforms.kernel_width +
385 in_ch * uniforms.kernel_height * uniforms.kernel_width +
386 ky * uniforms.kernel_width + kx;
387
388 sum += input[input_idx] * kernel_data[kernel_idx];
389 }
390 }
391 }
392 }
393
394 let output_idx = batch_idx * uniforms.out_channels * uniforms.output_height * uniforms.output_width +
395 out_channel * uniforms.output_height * uniforms.output_width +
396 out_y * uniforms.output_width + out_x;
397 output[output_idx] = sum;
398}
399"#
400 .to_string();
401
402 let metal_source = r#"
404// Metal 2D convolution implementation (simplified)
405#include <metal_stdlib>
406using namespace metal;
407
408kernel void conv2d(
409 const device float* input [[buffer(0)]],
410 const device float* kernel_data [[buffer(1)]],
411 device float* output [[buffer(2)]],
412 constant uint& batch_size [[buffer(3)]],
413 constant uint& in_channels [[buffer(4)]],
414 constant uint& out_channels [[buffer(5)]],
415 constant uint& input_height [[buffer(6)]],
416 constant uint& input_width [[buffer(7)]],
417 constant uint& output_height [[buffer(8)]],
418 constant uint& output_width [[buffer(9)]],
419 constant uint& kernel_height [[buffer(10)]],
420 constant uint& kernel_width [[buffer(11)]],
421 constant uint& stride_y [[buffer(12)]],
422 constant uint& stride_x [[buffer(13)]],
423 constant uint& padding_y [[buffer(14)]],
424 constant uint& padding_x [[buffer(15)]],
425 uint3 global_id [[thread_position_in_grid]])
426{
427 // Similar implementation to CUDA kernel
428}
429"#
430 .to_string();
431
432 let opencl_source = r#"
433// OpenCL 2D convolution implementation (simplified)
434__kernel void conv2d(
435 __global const float* input__global const float* kernel_data__global float* output,
436 const int batch_size,
437 const int in_channels,
438 const int out_channels,
439 const int input_height,
440 const int input_width,
441 const int output_height,
442 const int output_width,
443 const int kernel_height,
444 const int kernel_width,
445 const int stride_y,
446 const int stride_x,
447 const int padding_y,
448 const int padding_x)
449{
450 // Similar implementation to CUDA kernel
451}
452"#
453 .to_string();
454
455 let rocm_source = cuda_source.clone();
457
458 (
459 cuda_source,
460 rocm_source,
461 wgpu_source,
462 metal_source,
463 opencl_source,
464 )
465 }
466}
467
468impl GpuKernel for Conv2dKernel {
469 fn name(&self) -> &str {
470 self.base.name()
471 }
472
473 fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
474 self.base.source_for_backend(backend)
475 }
476
477 fn metadata(&self) -> KernelMetadata {
478 self.base.metadata()
479 }
480
481 fn can_specialize(&self, params: &KernelParams) -> bool {
482 matches!(
483 params.datatype,
484 DataType::Float32 | DataType::Float64 | DataType::Float16 | DataType::BFloat16
485 )
486 }
487
488 fn specialize(&self, params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
489 if !self.can_specialize(params) {
490 return Err(GpuError::SpecializationNotSupported);
491 }
492
493 Ok(Box::new(Self::new()))
500 }
501}
502
503impl Default for Conv2dKernel {
504 fn default() -> Self {
505 Self::new()
506 }
507}