use std::collections::HashMap;

use crate::gpu::kernels::{BaseKernel, GpuKernel, KernelMetadata, KernelParams, OperationType};
use crate::gpu::{GpuBackend, GpuError};

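/// GPU kernel for element-wise multiplication of complex (float32) buffers.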
pub struct ComplexMultiplyKernel {
    base: BaseKernel,
}

impl Default for ComplexMultiplyKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ComplexMultiplyKernel {
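    /// Creates the kernel with per-backend sources and default metadata.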
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let (cuda_source, rocm_source, wgpu_source, metal_source, opencl_source) =
            Self::get_kernel_sources();

        Self {
            base: BaseKernel::new(
                "complex_multiply",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }

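    /// Returns the per-backend kernel sources as
    /// `(cuda, rocm, wgpu, metal, opencl)`.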
    fn get_kernel_sources() -> (String, String, String, String, String) {
        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

// Complex number structure for float32
struct complex_f32 {
    float real;
    float imag;

    complex_f32(float r = 0.0f, float i = 0.0f) : real(r), imag(i) {}
};

// Complex multiplication
complex_f32 complex_mul(complex_f32 a, complex_f32 b) {
    return complex_f32(
        a.real * b.real - a.imag * b.imag,
        a.real * b.imag + a.imag * b.real
    );
}

kernel void complex_multiply(
    const device complex_f32* a [[buffer(0)]],
    const device complex_f32* b [[buffer(1)]],
    device complex_f32* result [[buffer(2)]],
    constant uint& n [[buffer(3)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        result[gid] = complex_mul(a[gid], b[gid]);
    }
}
"#
        .to_string();

        let cuda_source = r#"
#include <cuComplex.h>

extern "C" __global__ void complex_multiply(
    const cuFloatComplex* __restrict__ a,
    const cuFloatComplex* __restrict__ b,
    cuFloatComplex* __restrict__ result,
    int n
) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        result[i] = cuCmulf(a[i], b[i]);
    }
}
"#
        .to_string();

        let wgpu_source = r#"
struct Complex {
    real: f32,
    imag: f32,
}

struct Uniforms {
    n: u32,
}

@group(0) @binding(0) var<uniform> uniforms: Uniforms;
@group(0) @binding(1) var<storage, read> a: array<Complex>;
@group(0) @binding(2) var<storage, read> b: array<Complex>;
@group(0) @binding(3) var<storage, read_write> result: array<Complex>;

fn complex_mul(a: Complex, b: Complex) -> Complex {
    var res: Complex;
    res.real = a.real * b.real - a.imag * b.imag;
    res.imag = a.real * b.imag + a.imag * b.real;
    return res;
}

@compute @workgroup_size(256)
fn complex_multiply(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;

    if (i < uniforms.n) {
        result[i] = complex_mul(a[i], b[i]);
    }
}
"#
        .to_string();

        let opencl_source = r#"
typedef struct {
    float real;
    float imag;
} complex_f32;

complex_f32 complex_mul(complex_f32 a, complex_f32 b) {
    complex_f32 result;
    result.real = a.real * b.real - a.imag * b.imag;
    result.imag = a.real * b.imag + a.imag * b.real;
    return result;
}

__kernel void complex_multiply(
    __global const complex_f32* a,
    __global const complex_f32* b,
    __global complex_f32* result,
    const int n)
{
    int i = get_global_id(0);
    if (i < n) {
        result[i] = complex_mul(a[i], b[i]);
    }
}
"#
        .to_string();

        let rocm_source = r#"
#include <hip/hip_complex.h>

extern "C" __global__ void complex_multiply(
    const hipFloatComplex* __restrict__ a,
    const hipFloatComplex* __restrict__ b,
    hipFloatComplex* __restrict__ result,
    const int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;

    if (i < n) {
        result[i] = hipCmulf(a[i], b[i]);
    }
}
"#
        .to_string();

        (
            cuda_source,
            rocm_source,
            wgpu_source,
            metal_source,
            opencl_source,
        )
    }
}

impl GpuKernel for ComplexMultiplyKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, _params: &KernelParams) -> bool {
        false
    }

    fn specialize(&self, _params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        Err(GpuError::SpecializationNotSupported)
    }
}

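/// GPU kernel that computes the element-wise complex conjugate.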
pub struct ComplexConjugateKernel {
    base: BaseKernel,
}

impl Default for ComplexConjugateKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ComplexConjugateKernel {
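    /// Creates the kernel. Only the Metal source is implemented; the other
    /// backend sources are placeholders.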
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [256, 1, 1],
            local_memory_usage: 0,
            supports_tensor_cores: false,
            operationtype: OperationType::MemoryIntensive,
            backend_metadata: HashMap::new(),
        };

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

struct complex_f32 {
    float real;
    float imag;
};

kernel void complex_conjugate(
    const device complex_f32* input [[buffer(0)]],
    device complex_f32* output [[buffer(1)]],
    constant uint& n [[buffer(2)]],
    uint gid [[thread_position_in_grid]])
{
    if (gid < n) {
        output[gid].real = input[gid].real;
        output[gid].imag = -input[gid].imag;
    }
}
"#
        .to_string();

        let cuda_source = "/* CUDA complex conjugate */".to_string();
        let rocm_source = "/* ROCm complex conjugate */".to_string();
        let wgpu_source = "/* WebGPU complex conjugate */".to_string();
        let opencl_source = "/* OpenCL complex conjugate */".to_string();

        Self {
            base: BaseKernel::new(
                "complex_conjugate",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ComplexConjugateKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, _params: &KernelParams) -> bool {
        false
    }

    fn specialize(&self, _params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        Err(GpuError::SpecializationNotSupported)
    }
}

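/// GPU kernel for tiled complex matrix multiplication, aimed at small
/// matrices such as 2x2 and 4x4 quantum gates.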
pub struct ComplexMatMulKernel {
    base: BaseKernel,
}

impl Default for ComplexMatMulKernel {
    fn default() -> Self {
        Self::new()
    }
}

impl ComplexMatMulKernel {
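    /// Creates the kernel with 16x16 threadgroup tiling. Only the Metal
    /// source is implemented; the other backend sources are placeholders.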
    pub fn new() -> Self {
        let metadata = KernelMetadata {
            workgroup_size: [16, 16, 1],
            local_memory_usage: 2 * 16 * 16 * 8, // two 16x16 tiles of 8-byte complex_f32 values
            supports_tensor_cores: false,
            operationtype: OperationType::ComputeIntensive,
            backend_metadata: HashMap::new(),
        };

        let metal_source = r#"
#include <metal_stdlib>
using namespace metal;

struct complex_f32 {
    float real;
    float imag;

    complex_f32(float r = 0.0f, float i = 0.0f) : real(r), imag(i) {}
};

complex_f32 complex_add(complex_f32 a, complex_f32 b) {
    return complex_f32(a.real + b.real, a.imag + b.imag);
}

complex_f32 complex_mul(complex_f32 a, complex_f32 b) {
    return complex_f32(
        a.real * b.real - a.imag * b.imag,
        a.real * b.imag + a.imag * b.real
    );
}

// Tiled complex matrix multiplication for small matrices (e.g., 2x2, 4x4 quantum gates)
kernel void complex_matmul_small(
    const device complex_f32* A [[buffer(0)]],
    const device complex_f32* B [[buffer(1)]],
    device complex_f32* C [[buffer(2)]],
    constant uint& M [[buffer(3)]],
    constant uint& N [[buffer(4)]],
    constant uint& K [[buffer(5)]],
    threadgroup complex_f32* tileA [[threadgroup(0)]],
    threadgroup complex_f32* tileB [[threadgroup(1)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]],
    uint2 tgid [[threadgroup_position_in_grid]])
{
    const uint TILE_SIZE = 16;

    // Compute the row and column for this thread
    uint row = tgid.y * TILE_SIZE + tid.y;
    uint col = tgid.x * TILE_SIZE + tid.x;

    // Initialize accumulator
    complex_f32 sum(0.0f, 0.0f);

    // Loop over tiles
    for (uint t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
        // Load tile from A
        uint aRow = row;
        uint aCol = t * TILE_SIZE + tid.x;
        if (aRow < M && aCol < K) {
            tileA[tid.y * TILE_SIZE + tid.x] = A[aRow * K + aCol];
        } else {
            tileA[tid.y * TILE_SIZE + tid.x] = complex_f32(0.0f, 0.0f);
        }

        // Load tile from B
        uint bRow = t * TILE_SIZE + tid.y;
        uint bCol = col;
        if (bRow < K && bCol < N) {
            tileB[tid.y * TILE_SIZE + tid.x] = B[bRow * N + bCol];
        } else {
            tileB[tid.y * TILE_SIZE + tid.x] = complex_f32(0.0f, 0.0f);
        }

        // Synchronize threads
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Compute partial dot product
        for (uint k = 0; k < TILE_SIZE; k++) {
            sum = complex_add(sum,
                complex_mul(tileA[tid.y * TILE_SIZE + k],
                    tileB[k * TILE_SIZE + tid.x]));
        }

        // Synchronize before loading next tile
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Write result
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}
"#
        .to_string();

        let cuda_source = "/* CUDA complex matmul */".to_string();
        let rocm_source = "/* ROCm complex matmul */".to_string();
        let wgpu_source = "/* WebGPU complex matmul */".to_string();
        let opencl_source = "/* OpenCL complex matmul */".to_string();

        Self {
            base: BaseKernel::new(
                "complex_matmul",
                &cuda_source,
                &rocm_source,
                &wgpu_source,
                &metal_source,
                &opencl_source,
                metadata,
            ),
        }
    }
}

impl GpuKernel for ComplexMatMulKernel {
    fn name(&self) -> &str {
        self.base.name()
    }

    fn source_for_backend(&self, backend: GpuBackend) -> Result<String, GpuError> {
        self.base.source_for_backend(backend)
    }

    fn metadata(&self) -> KernelMetadata {
        self.base.metadata()
    }

    fn can_specialize(&self, _params: &KernelParams) -> bool {
        false
    }

    fn specialize(&self, _params: &KernelParams) -> Result<Box<dyn GpuKernel>, GpuError> {
        // Consistent with `can_specialize` returning false, as in the other kernels.
        Err(GpuError::SpecializationNotSupported)
    }
}

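// Each kernel is fully determined by its constructor, so `Clone` simply
// rebuilds the kernel via `new()`.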
impl Clone for ComplexMultiplyKernel {
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl Clone for ComplexConjugateKernel {
    fn clone(&self) -> Self {
        Self::new()
    }
}

impl Clone for ComplexMatMulKernel {
    fn clone(&self) -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::gpu::kernels::DataType;

    #[test]
    fn test_complex_multiply_kernel() {
        let kernel = ComplexMultiplyKernel::new();
        assert_eq!(kernel.name(), "complex_multiply");
        assert!(!kernel.can_specialize(&KernelParams::new(DataType::Float32)));
    }

    #[test]
    fn test_complex_kernel_metadata() {
        let kernel = ComplexMultiplyKernel::new();
        let metadata = kernel.metadata();
        assert_eq!(metadata.workgroup_size, [256, 1, 1]);
        assert_eq!(metadata.operationtype, OperationType::ComputeIntensive);
    }

    #[test]
    fn test_metal_source_generation() {
        let kernel = ComplexMultiplyKernel::new();
        let source = kernel
            .source_for_backend(GpuBackend::Metal)
            .expect("Metal source should be available");
        assert!(source.contains("complex_f32"));
        assert!(source.contains("complex_mul"));
    }
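
    // Sanity checks for the remaining kernels, mirroring the tests above.
    #[test]
    fn test_complex_conjugate_kernel() {
        let kernel = ComplexConjugateKernel::new();
        assert_eq!(kernel.name(), "complex_conjugate");
        assert_eq!(
            kernel.metadata().operationtype,
            OperationType::MemoryIntensive
        );
    }

    #[test]
    fn test_complex_matmul_kernel() {
        let kernel = ComplexMatMulKernel::new();
        assert_eq!(kernel.name(), "complex_matmul");
        assert_eq!(kernel.metadata().workgroup_size, [16, 16, 1]);
    }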
}