use crate::gpu::{GpuBackend, GpuBuffer, GpuError};
use std::fmt;
use thiserror::Error;

/// Data types supported for tensor-core style matrix operations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TensorDataType {
    Float16,
    BFloat16,
    Float32,
    Float64,
    Int8,
    Int4,
    Binary,
    /// Mixed precision: separate input and accumulator types.
    Mixed(Box<TensorDataType>, Box<TensorDataType>),
}

impl fmt::Display for TensorDataType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            TensorDataType::Float16 => write!(f, "f16"),
            TensorDataType::BFloat16 => write!(f, "bf16"),
            TensorDataType::Float32 => write!(f, "f32"),
            TensorDataType::Float64 => write!(f, "f64"),
            TensorDataType::Int8 => write!(f, "i8"),
            TensorDataType::Int4 => write!(f, "i4"),
            TensorDataType::Binary => write!(f, "binary"),
            TensorDataType::Mixed(input, accum) => write!(f, "mixed({input}, {accum})"),
        }
    }
}

/// Operation classes that can be accelerated by tensor/matrix cores.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorCoreOp {
    MatrixMultiply,
    Convolution,
    Attention,
    SparseOps,
    Elementwise,
    Custom(&'static str),
}

/// Tensor-core capabilities detected for a device/backend.
#[derive(Debug, Clone, Default)]
pub struct TensorCoreCapabilities {
    pub available: bool,
    pub supported_types: Vec<TensorDataType>,
    pub supported_ops: Vec<TensorCoreOp>,
    /// Supported (M, N, K) tile shapes.
    pub supported_dimensions: Vec<(usize, usize, usize)>,
    /// Peak throughput in TOPS, if known.
    pub peak_tops: Option<f64>,
    /// Memory bandwidth in GB/s, if known.
    pub memorybandwidth_gbps: Option<f64>,
    pub arch_features: Vec<String>,
}

/// Configuration for tensor-core accelerated operations.
#[derive(Debug, Clone)]
pub struct TensorCoreConfig {
    pub datatype: TensorDataType,
    pub use_mixed_precision: bool,
    pub auto_convert: bool,
    pub tile_size: (usize, usize),
    pub use_sparse: bool,
    pub arch_optimizations: Vec<String>,
}

impl Default for TensorCoreConfig {
    fn default() -> Self {
        Self {
            datatype: TensorDataType::Float16,
            use_mixed_precision: true,
            auto_convert: true,
            tile_size: (16, 16),
            use_sparse: false,
            arch_optimizations: Vec::new(),
        }
    }
}

#[derive(Error, Debug)]
pub enum TensorCoreError {
    #[error("Tensor cores not available on this device")]
    NotAvailable,

    #[error("Unsupported data type: {0}")]
    UnsupportedDataType(TensorDataType),

    #[error("Unsupported operation: {0:?}")]
    UnsupportedOperation(TensorCoreOp),

    #[error("Invalid matrix dimensions: {m}x{n}x{k}")]
    InvalidDimensions { m: usize, n: usize, k: usize },

    #[error("Memory alignment error: {0}")]
    MemoryAlignment(String),

    #[error("Performance warning: {0}")]
    PerformanceWarning(String),

    #[error("GPU error: {0}")]
    GpuError(#[from] GpuError),
}

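/// Detects and manages tensor-core capabilities and configuration for a GPU backend.
///
/// Capability detection in this module is table-driven rather than queried from
/// the device, so a manager can be constructed for any supported backend. A
/// minimal usage sketch (paths abbreviated, error handling elided):
///
/// ```ignore
/// let manager = TensorCoreManager::new(GpuBackend::Cuda)?;
/// if manager.is_operation_supported(TensorCoreOp::MatrixMultiply) {
///     for hint in manager.get_performance_hints(1024, 1024, 1024) {
///         eprintln!("{hint}");
///     }
/// }
/// ```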
#[derive(Debug)]
pub struct TensorCoreManager {
    backend: GpuBackend,
    capabilities: TensorCoreCapabilities,
    config: TensorCoreConfig,
}

impl TensorCoreManager {
    /// Creates a manager for `backend`, returning an error if no tensor cores are reported.
    pub fn new(backend: GpuBackend) -> Result<Self, TensorCoreError> {
        let capabilities = Self::detect_capabilities(backend)?;

        if !capabilities.available {
            return Err(TensorCoreError::NotAvailable);
        }

        let config = Self::optimal_config(&capabilities);

        Ok(Self {
            backend,
            capabilities,
            config,
        })
    }

    pub const fn capabilities(&self) -> &TensorCoreCapabilities {
        &self.capabilities
    }

    pub const fn config(&self) -> &TensorCoreConfig {
        &self.config
    }

    /// Replaces the active configuration after validating it against detected capabilities.
    pub fn set_config(&mut self, config: TensorCoreConfig) -> Result<(), TensorCoreError> {
        if !self.capabilities.supported_types.contains(&config.datatype) {
            return Err(TensorCoreError::UnsupportedDataType(config.datatype));
        }

        self.config = config;
        Ok(())
    }

    pub fn is_operation_supported(&self, op: TensorCoreOp) -> bool {
        self.capabilities.supported_ops.contains(&op)
    }

    /// Returns `true` if (m, n, k) are multiples of the backend's preferred tile granularity.
    pub fn are_dimensions_optimal(&self, m: usize, n: usize, k: usize) -> bool {
        match self.backend {
            // NVIDIA WMMA tiles operate on multiples of 16.
            GpuBackend::Cuda => m % 16 == 0 && n % 16 == 0 && k % 16 == 0,
            // AMD matrix cores are checked against multiples of 32 here.
            GpuBackend::Rocm => m % 32 == 0 && n % 32 == 0 && k % 32 == 0,
            _ => false,
        }
    }

    pub fn get_performance_hints(&self, m: usize, n: usize, k: usize) -> Vec<String> {
        let mut hints = Vec::new();

        if !self.are_dimensions_optimal(m, n, k) {
            hints.push(format!(
                "Consider padding dimensions to multiples of {} for optimal performance",
                match self.backend {
                    GpuBackend::Cuda => 16,
                    // Kept in sync with `are_dimensions_optimal`, which checks multiples of 32 for ROCm.
                    GpuBackend::Rocm => 32,
                    GpuBackend::Wgpu => 16,
                    GpuBackend::Metal => 16,
                    GpuBackend::OpenCL => 16,
                    GpuBackend::Cpu => 1,
                }
            ));
        }

        if self.config.use_mixed_precision && self.config.datatype == TensorDataType::Float32 {
            hints.push(
                "Consider using Float16 or BFloat16 for better tensor core utilization".to_string(),
            );
        }

        if m * n * k < 1024 * 1024 {
            hints.push("Small matrices may not fully utilize tensor cores".to_string());
        }

        hints
    }

    /// Suggests the best supported data type for `op` on this device.
    pub fn suggest_optimal_type(&self, op: TensorCoreOp) -> Option<TensorDataType> {
        let supports = |t: &TensorDataType| self.capabilities.supported_types.contains(t);

        match op {
            TensorCoreOp::MatrixMultiply => {
                if supports(&TensorDataType::BFloat16) {
                    Some(TensorDataType::BFloat16)
                } else if supports(&TensorDataType::Float16) {
                    Some(TensorDataType::Float16)
                } else {
                    Some(TensorDataType::Float32)
                }
            }
            TensorCoreOp::Convolution => {
                if supports(&TensorDataType::Int8) {
                    Some(TensorDataType::Int8)
                } else if supports(&TensorDataType::Float16) {
                    Some(TensorDataType::Float16)
                } else {
                    Some(TensorDataType::Float32)
                }
            }
            TensorCoreOp::Attention => {
                if supports(&TensorDataType::BFloat16) {
                    Some(TensorDataType::BFloat16)
                } else {
                    Some(TensorDataType::Float32)
                }
            }
            _ => self.capabilities.supported_types.first().cloned(),
        }
    }

    /// Returns the capability table for `backend`. Detection is table-driven; no device query is performed.
    fn detect_capabilities(backend: GpuBackend) -> Result<TensorCoreCapabilities, TensorCoreError> {
        match backend {
            GpuBackend::Cuda => Ok(Self::nvidia_tensor_capabilities()),
            GpuBackend::Rocm => Ok(Self::amdmatrix_capabilities()),
            GpuBackend::Metal => Ok(Self::apple_neural_capabilities()),
            // The CPU backend reports a minimal, simulated capability set (useful for testing).
            GpuBackend::Cpu => Ok(TensorCoreCapabilities {
                available: true,
                supported_types: vec![TensorDataType::Float32],
                supported_ops: vec![TensorCoreOp::MatrixMultiply],
                supported_dimensions: vec![(16, 16, 16)],
                peak_tops: Some(1.0),
                memorybandwidth_gbps: Some(100.0),
                arch_features: vec!["cpu_simulation".to_string()],
            }),
            _ => Ok(TensorCoreCapabilities::default()),
        }
    }

    /// Representative NVIDIA tensor-core capabilities (roughly A100-class figures; actual devices vary).
    fn nvidia_tensor_capabilities() -> TensorCoreCapabilities {
        TensorCoreCapabilities {
            available: true,
            supported_types: vec![
                TensorDataType::Float16,
                TensorDataType::BFloat16,
                TensorDataType::Float32,
                TensorDataType::Int8,
                TensorDataType::Int4,
                TensorDataType::Binary,
                TensorDataType::Mixed(
                    Box::new(TensorDataType::Float16),
                    Box::new(TensorDataType::Float32),
                ),
            ],
            supported_ops: vec![
                TensorCoreOp::MatrixMultiply,
                TensorCoreOp::Convolution,
                TensorCoreOp::Attention,
                TensorCoreOp::SparseOps,
            ],
            supported_dimensions: vec![(16, 16, 16), (32, 8, 16), (8, 32, 16)],
            peak_tops: Some(312.0),
            memorybandwidth_gbps: Some(2039.0),
            arch_features: vec![
                "Sparsity 2:4".to_string(),
                "Multi-precision".to_string(),
                "Transformer Engine".to_string(),
            ],
        }
    }

    /// Representative AMD matrix-core capabilities (roughly MI250X-class figures; actual devices vary).
    fn amdmatrix_capabilities() -> TensorCoreCapabilities {
        TensorCoreCapabilities {
            available: true,
            supported_types: vec![
                TensorDataType::Float16,
                TensorDataType::BFloat16,
                TensorDataType::Float32,
                TensorDataType::Int8,
                TensorDataType::Mixed(
                    Box::new(TensorDataType::Float16),
                    Box::new(TensorDataType::Float32),
                ),
            ],
            supported_ops: vec![TensorCoreOp::MatrixMultiply, TensorCoreOp::Convolution],
            supported_dimensions: vec![(32, 32, 8), (16, 16, 16)],
            peak_tops: Some(383.0),
            memorybandwidth_gbps: Some(3276.0),
            arch_features: vec!["MFMA instructions".to_string(), "Matrix cores".to_string()],
        }
    }

    /// Representative Apple capabilities; actual figures vary by device generation.
    fn apple_neural_capabilities() -> TensorCoreCapabilities {
        TensorCoreCapabilities {
            available: true,
            supported_types: vec![
                TensorDataType::Float16,
                TensorDataType::Float32,
                TensorDataType::Int8,
            ],
            supported_ops: vec![
                TensorCoreOp::MatrixMultiply,
                TensorCoreOp::Convolution,
                TensorCoreOp::Attention,
            ],
            supported_dimensions: vec![(16, 16, 16)],
            peak_tops: Some(15.8),
            memorybandwidth_gbps: Some(68.25),
            arch_features: vec!["Neural Engine".to_string(), "Unified memory".to_string()],
        }
    }

    /// Derives a sensible default configuration from detected capabilities.
    fn optimal_config(capabilities: &TensorCoreCapabilities) -> TensorCoreConfig {
        let datatype = if capabilities
            .supported_types
            .contains(&TensorDataType::BFloat16)
        {
            TensorDataType::BFloat16
        } else if capabilities
            .supported_types
            .contains(&TensorDataType::Float16)
        {
            TensorDataType::Float16
        } else {
            TensorDataType::Float32
        };

        let tile_size = capabilities
            .supported_dimensions
            .first()
            .map(|&(m, n, _)| (m, n))
            .unwrap_or((16, 16));

        TensorCoreConfig {
            datatype,
            use_mixed_precision: capabilities
                .supported_types
                .iter()
                .any(|t| matches!(t, TensorDataType::Mixed(_, _))),
            auto_convert: true,
            tile_size,
            use_sparse: capabilities
                .arch_features
                .iter()
                .any(|f| f.contains("Sparsity")),
            arch_optimizations: capabilities.arch_features.clone(),
        }
    }
}

/// Description of a single tensor-core operation.
#[derive(Debug, Clone)]
pub struct TensorOperation {
    pub op_type: TensorCoreOp,
    pub input_type: TensorDataType,
    pub output_type: TensorDataType,
    /// Problem size as (M, N, K).
    pub dimensions: (usize, usize, usize),
    pub mixed_precision: bool,
    pub sparsity: Option<SparsePattern>,
}

impl Default for TensorOperation {
    fn default() -> Self {
        Self {
            op_type: TensorCoreOp::MatrixMultiply,
            input_type: TensorDataType::Float32,
            output_type: TensorDataType::Float32,
            dimensions: (1, 1, 1),
            mixed_precision: false,
            sparsity: None,
        }
    }
}

/// Sparsity patterns that tensor cores may exploit.
#[derive(Debug, Clone)]
pub enum SparsePattern {
    /// 2:4 structured sparsity (two non-zero values in every group of four).
    Structured2_4,
    Random(f32),
    Block {
        block_size: (usize, usize),
        sparsity: f32,
    },
    Custom(String),
}

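/// Runs a tensor-core GEMM over `a` (m×k), `b` (k×n), and `c` (m×n), validating
/// operation support and buffer sizes before generating a backend-specific kernel.
///
/// A minimal sketch of the intended call shape; how the `GpuBuffer`s are created
/// is outside this module, so the buffers below are assumed to already exist:
///
/// ```ignore
/// let manager = TensorCoreManager::new(GpuBackend::Cuda)?;
/// // a: m * k elements, b: k * n elements, c: m * n elements
/// tensor_core_gemm(&manager, &a, &b, &mut c, m, n, k)?;
/// ```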
#[allow(dead_code)]
pub fn tensor_core_gemm<T>(
    manager: &TensorCoreManager,
    a: &GpuBuffer<T>,
    b: &GpuBuffer<T>,
    c: &mut GpuBuffer<T>,
    m: usize,
    n: usize,
    k: usize,
) -> Result<(), TensorCoreError>
where
    T: crate::gpu::GpuDataType,
{
    if !manager.are_dimensions_optimal(m, n, k) {
        let hints = manager.get_performance_hints(m, n, k);
        for hint in hints {
            eprintln!("Performance hint: {hint}");
        }
    }

    if !manager.is_operation_supported(TensorCoreOp::MatrixMultiply) {
        return Err(TensorCoreError::UnsupportedOperation(
            TensorCoreOp::MatrixMultiply,
        ));
    }

    if a.len() < m * k {
        return Err(TensorCoreError::InvalidDimensions { m, n, k });
    }
    if b.len() < k * n {
        return Err(TensorCoreError::InvalidDimensions { m, n, k });
    }
    if c.len() < m * n {
        return Err(TensorCoreError::InvalidDimensions { m, n, k });
    }

    let kernel_source = generate_tensor_core_gemm_kernel(manager, m, n, k)?;

    execute_tensor_core_operation(manager, &kernel_source, a, b, c, m, n, k)?;

    Ok(())
}

#[allow(dead_code)]
fn generate_tensor_core_gemm_kernel(
    manager: &TensorCoreManager,
    m: usize,
    n: usize,
    k: usize,
) -> Result<String, TensorCoreError> {
    let tile_size = manager.config().tile_size;
    let datatype = &manager.config().datatype;
    let use_mixed_precision = manager.config().use_mixed_precision;

    match manager.backend {
        GpuBackend::Cuda => generate_cuda_tensor_core_kernel(
            datatype.clone(),
            tile_size.0,
            m,
            n,
            k,
            use_mixed_precision,
        ),
        GpuBackend::Rocm => generate_rocmmatrix_core_kernel(
            datatype.clone(),
            tile_size.0,
            m,
            n,
            k,
            use_mixed_precision,
        ),
        GpuBackend::Metal => generate_metal_mps_kernel(datatype.clone(), tile_size.0, m, n, k),
        _ => Err(TensorCoreError::UnsupportedOperation(
            TensorCoreOp::MatrixMultiply,
        )),
    }
}

fn generate_cuda_tensor_core_kernel(
    _datatype: TensorDataType,
    _tile_size: usize,
    _m: usize,
    _n: usize,
    _k: usize,
    _use_mixed_precision: bool,
) -> Result<String, TensorCoreError> {
    Ok("/* CUDA tensor core kernel placeholder */".to_string())
}

fn generate_rocmmatrix_core_kernel(
    _datatype: TensorDataType,
    _tile_size: usize,
    _m: usize,
    _n: usize,
    _k: usize,
    _use_mixed_precision: bool,
) -> Result<String, TensorCoreError> {
    Ok("/* ROCm matrix core kernel placeholder */".to_string())
}

fn generate_metal_mps_kernel(
    _datatype: TensorDataType,
    _tile_size: usize,
    _m: usize,
    _n: usize,
    _k: usize,
) -> Result<String, TensorCoreError> {
    Ok("/* Metal MPS kernel placeholder */".to_string())
}

#[allow(dead_code)]
fn generate_cuda_kernel(
    datatype: TensorDataType,
    tile_size: (usize, usize),
    _m: usize,
    _n: usize,
    _k: usize,
    use_mixed_precision: bool,
) -> Result<String, TensorCoreError> {
    let (tile_m, tile_n) = tile_size;
    let dtype_str = match datatype {
        TensorDataType::Float16 => "__half",
        TensorDataType::BFloat16 => "__nv_bfloat16",
        TensorDataType::Float32 => "float",
        TensorDataType::Int8 => "int8_t",
        _ => return Err(TensorCoreError::UnsupportedDataType(datatype.clone())),
    };

    let accumulator_type = if use_mixed_precision {
        "float"
    } else {
        dtype_str
    };

    Ok(format!(
        r#"
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <mma.h>

using namespace nvcuda;

__global__ void tensor_core_gemm(
    const {dtype_str}* __restrict__ A,
    const {dtype_str}* __restrict__ B,
    {accumulator_type}* __restrict__ C,
    int M, int N, int K
) {{
    // Tensor core fragment declarations
    wmma::fragment<wmma::matrix_a, {tile_m}, {tile_n}, 16, {dtype_str}, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, {tile_m}, {tile_n}, 16, {dtype_str}, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, {tile_m}, {tile_n}, 16, {accumulator_type}> acc_frag;
    wmma::fragment<wmma::accumulator, {tile_m}, {tile_n}, 16, {accumulator_type}> c_frag;

    // Thread block and warp coordinates
    int warp_row = (blockIdx.y * blockDim.y + threadIdx.y) / 32;
    int warp_col = (blockIdx.x * blockDim.x + threadIdx.x) / 32;

    // Bounds checking
    if (warp_row * {tile_m} >= M || warp_col * {tile_n} >= N) return;

    // Initialize accumulator
    wmma::fill_fragment(acc_frag, 0.0f);

    // Main computation loop
    for (int i = 0; i < K; i += 16) {{
        int a_row = warp_row * {tile_m};
        int a_col = i;
        int b_row = i;
        int b_col = warp_col * {tile_n};

        // Bounds checking for partial tiles
        if (a_col + 16 <= K && b_row + 16 <= K) {{
            // Load matrix fragments
            wmma::load_matrix_sync(a_frag, A + a_row * K + a_col, K);
            wmma::load_matrix_sync(b_frag, B + b_row * N + b_col, N);

            // Perform matrix multiplication
            wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
        }}
    }}

    // Load existing C matrix values for accumulation
    int c_row = warp_row * {tile_m};
    int c_col = warp_col * {tile_n};

    if (c_row + {tile_m} <= M && c_col + {tile_n} <= N) {{
        wmma::load_matrix_sync(c_frag, C + c_row * N + c_col, N, wmma::mem_row_major);

        // Add to accumulator
        for (int i = 0; i < c_frag.num_elements; i++) {{
            c_frag.x[i] += acc_frag.x[i];
        }}

        // Store result
        wmma::store_matrix_sync(C + c_row * N + c_col, c_frag, N, wmma::mem_row_major);
    }}
}}
"#
    ))
}

#[allow(dead_code)]
fn generate_rocm_kernel(
    datatype: TensorDataType,
    tile_size: (usize, usize),
    _m: usize,
    _n: usize,
    _k: usize,
    use_mixed_precision: bool,
) -> Result<String, TensorCoreError> {
    let (tile_m, tile_n) = tile_size;
    let dtype_str = match datatype {
        TensorDataType::Float16 => "_Float16",
        TensorDataType::BFloat16 => "__bf16",
        TensorDataType::Float32 => "float",
        TensorDataType::Int8 => "int8_t",
        _ => return Err(TensorCoreError::UnsupportedDataType(datatype.clone())),
    };

    let accumulator_type = if use_mixed_precision {
        "float"
    } else {
        dtype_str
    };

    Ok(format!(
        r#"
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

__global__ void matrix_core_gemm(
    const {dtype_str}* __restrict__ A,
    const {dtype_str}* __restrict__ B,
    {accumulator_type}* __restrict__ C,
    int M, int N, int K
) {{
    // AMD MFMA intrinsics for matrix core operations
    const int tid = threadIdx.x;
    const int warp_id = tid / 64; // AMD wavefront size is 64
    const int lane_id = tid % 64;

    const int block_row = blockIdx.y * {tile_m};
    const int block_col = blockIdx.x * {tile_n};

    // Bounds checking
    if (block_row >= M || block_col >= N) return;

    // Shared memory for tile loading
    __shared__ {dtype_str} A_shared[{tile_m} * 32];
    __shared__ {dtype_str} B_shared[32 * {tile_n}];

    {accumulator_type} accumulator[{acc_size}] = {{0}};

    // Main computation loop
    for (int k_block = 0; k_block < K; k_block += 32) {{
        // Cooperative loading to shared memory
        if (tid < {tile_m} * 32 / 64) {{
            int load_idx = tid * 64 + lane_id;
            if (load_idx < {tile_m} * 32 && (block_row + load_idx / 32) < M && (k_block + load_idx % 32) < K) {{
                A_shared[load_idx] = A[(block_row + load_idx / 32) * K + (k_block + load_idx % 32)];
            }}
        }}

        if (tid < 32 * {tile_n} / 64) {{
            int load_idx = tid * 64 + lane_id;
            if (load_idx < 32 * {tile_n} && (k_block + load_idx / {tile_n}) < K && (block_col + load_idx % {tile_n}) < N) {{
                B_shared[load_idx] = B[(k_block + load_idx / {tile_n}) * N + (block_col + load_idx % {tile_n})];
            }}
        }}

        __syncthreads();

        // MFMA matrix multiplication (simplified)
        // In practice, would use __builtin_amdgcn_mfma_* intrinsics
        for (int i = 0; i < {tile_m}; i += 4) {{
            for (int j = 0; j < {tile_n}; j += 4) {{
                for (int k_inner = 0; k_inner < 32; k_inner++) {{
                    accumulator[(i * {tile_n} + j) / 16] +=
                        A_shared[i * 32 + k_inner] * B_shared[k_inner * {tile_n} + j];
                }}
            }}
        }}

        __syncthreads();
    }}

    // Store results
    for (int i = 0; i < {tile_m}; i++) {{
        for (int j = 0; j < {tile_n}; j++) {{
            int global_row = block_row + i;
            int global_col = block_col + j;
            if (global_row < M && global_col < N) {{
                C[global_row * N + global_col] += accumulator[(i * {tile_n} + j) / 16];
            }}
        }}
    }}
}}
"#,
        dtype_str = dtype_str,
        accumulator_type = accumulator_type,
        tile_m = tile_m,
        tile_n = tile_n,
        acc_size = (tile_m * tile_n) / 16,
    ))
}

#[allow(dead_code)]
fn generate_metal_kernel(
    datatype: TensorDataType,
    tile_size: (usize, usize),
    _m: usize,
    _n: usize,
    _k: usize,
) -> Result<String, TensorCoreError> {
    let (tile_m, tile_n) = tile_size;
    let dtype_str = match datatype {
        TensorDataType::Float16 => "half",
        TensorDataType::Float32 => "float",
        TensorDataType::Int8 => "char",
        _ => return Err(TensorCoreError::UnsupportedDataType(datatype.clone())),
    };

    Ok(format!(
        r#"
#include <metal_stdlib>
using namespace metal;

kernel void neural_engine_gemm(
    device const {dtype_str}* A [[buffer(0)]],
    device const {dtype_str}* B [[buffer(1)]],
    device {dtype_str}* C [[buffer(2)]],
    constant uint& M [[buffer(3)]],
    constant uint& N [[buffer(4)]],
    constant uint& K [[buffer(5)]],
    uint2 gid [[thread_position_in_grid]],
    uint2 tid [[thread_position_in_threadgroup]]
) {{
    const uint row = gid.y * {tile_m} + tid.y;
    const uint col = gid.x * {tile_n} + tid.x;

    if (row >= M || col >= N) return;

    {dtype_str} sum = 0.0;

    // Use SIMD group matrix operations when available
    for (uint k = 0; k < K; k++) {{
        sum += A[row * K + k] * B[k * N + col];
    }}

    C[row * N + col] = sum;
}}
"#
    ))
}

#[allow(dead_code)]
fn execute_tensor_core_operation<T>(
    manager: &TensorCoreManager,
    kernel_source: &str,
    a: &GpuBuffer<T>,
    b: &GpuBuffer<T>,
    c: &mut GpuBuffer<T>,
    m: usize,
    n: usize,
    k: usize,
) -> Result<(), TensorCoreError>
where
    T: crate::gpu::GpuDataType,
{
    // Placeholder execution path: compute launch geometry and log it.
    // Actual kernel compilation and dispatch are not wired up here yet.
    let tile_size = manager.config().tile_size;
    let grid_dim_x = n.div_ceil(tile_size.1);
    let grid_dim_y = m.div_ceil(tile_size.0);

    eprintln!("Executing tensor core GEMM:");
    eprintln!(
        "  Kernel source length: {} characters",
        kernel_source.len()
    );
    eprintln!("  Dimensions: {m}x{n}x{k}");
    eprintln!("  Grid dimensions: {grid_dim_x}x{grid_dim_y}");
    eprintln!("  Tile size: {tile_size:?}");
    eprintln!("  Backend: {:?}", manager.backend);
    eprintln!("  Data type: {}", manager.config().datatype);

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_datatype_display() {
        assert_eq!(TensorDataType::Float16.to_string(), "f16");
        assert_eq!(TensorDataType::BFloat16.to_string(), "bf16");
        assert_eq!(TensorDataType::Int8.to_string(), "i8");
    }

    #[test]
    fn test_nvidia_capabilities() {
        let caps = TensorCoreManager::nvidia_tensor_capabilities();
        assert!(caps.available);
        assert!(caps.supported_types.contains(&TensorDataType::Float16));
        assert!(caps.supported_ops.contains(&TensorCoreOp::MatrixMultiply));
    }

    #[test]
    fn test_amd_capabilities() {
        let caps = TensorCoreManager::amdmatrix_capabilities();
        assert!(caps.available);
        assert!(caps.supported_types.contains(&TensorDataType::BFloat16));
        assert!(caps.supported_ops.contains(&TensorCoreOp::MatrixMultiply));
    }

    #[test]
    fn test_optimal_config() {
        let caps = TensorCoreManager::nvidia_tensor_capabilities();
        let config = TensorCoreManager::optimal_config(&caps);
        assert_eq!(config.datatype, TensorDataType::BFloat16);
        assert!(config.use_mixed_precision);
    }

    #[test]
    fn test_dimension_optimization() {
        let caps = TensorCoreManager::nvidia_tensor_capabilities();
        let config = TensorCoreManager::optimal_config(&caps);

        assert!(config.auto_convert);
        assert_eq!(config.tile_size, (16, 16));
    }
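
    // The capability tables above are static (no device query), so the checks
    // below run without a physical GPU. They exercise configuration validation,
    // dimension heuristics, and data-type suggestions as implemented in this module.
    #[test]
    fn test_cpu_backend_manager() {
        let manager = TensorCoreManager::new(GpuBackend::Cpu).unwrap();
        assert_eq!(manager.config().datatype, TensorDataType::Float32);
        assert_eq!(manager.config().tile_size, (16, 16));
    }

    #[test]
    fn test_set_config_rejects_unsupported_type() {
        let mut manager = TensorCoreManager::new(GpuBackend::Cpu).unwrap();
        let config = TensorCoreConfig {
            datatype: TensorDataType::Float16,
            ..TensorCoreConfig::default()
        };
        assert!(matches!(
            manager.set_config(config),
            Err(TensorCoreError::UnsupportedDataType(_))
        ));
    }

    #[test]
    fn test_cuda_dimension_heuristics() {
        let manager = TensorCoreManager::new(GpuBackend::Cuda).unwrap();
        assert!(manager.are_dimensions_optimal(64, 64, 64));
        assert!(!manager.are_dimensions_optimal(65, 64, 64));
        assert!(!manager.get_performance_hints(65, 64, 64).is_empty());
    }

    #[test]
    fn test_suggest_optimal_type() {
        let manager = TensorCoreManager::new(GpuBackend::Cuda).unwrap();
        assert_eq!(
            manager.suggest_optimal_type(TensorCoreOp::MatrixMultiply),
            Some(TensorDataType::BFloat16)
        );
        assert_eq!(
            manager.suggest_optimal_type(TensorCoreOp::Convolution),
            Some(TensorDataType::Int8)
        );
    }

    #[test]
    fn test_tensor_operation_with_sparsity() {
        let op = TensorOperation {
            op_type: TensorCoreOp::MatrixMultiply,
            input_type: TensorDataType::Float16,
            dimensions: (128, 128, 64),
            sparsity: Some(SparsePattern::Structured2_4),
            ..TensorOperation::default()
        };
        assert!(op.sparsity.is_some());
        assert_eq!(op.output_type, TensorDataType::Float32);
    }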
}