1use crate::errors::{Result, TrustformersError};
31use serde::{Deserialize, Serialize};
32use std::collections::HashMap;
33use std::path::PathBuf;
34use std::time::{Duration, Instant};
35
/// Kernel operation categories that can be tuned independently.
///
/// Each variant gets its own heuristic block size (see
/// `PlatformInfo::suggested_block_size`) and its own cache entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Operation {
    /// Dense matrix multiplication.
    MatMul,
    /// Convolution.
    Convolution,
    /// Softmax normalization.
    Softmax,
    /// Layer normalization.
    LayerNorm,
    /// Attention computation.
    Attention,
    /// Element-wise (map-style) operations.
    ElementWise,
    /// Reductions (sum, max, etc.).
    Reduction,
    /// Matrix/tensor transposition.
    Transpose,
}
56
/// Compute backends a kernel can target.
///
/// The discriminant (`as u8`) is embedded in cache file names, so variant
/// order is part of the on-disk cache format — append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Backend {
    /// Host CPU execution.
    CPU,
    /// NVIDIA CUDA GPUs.
    CUDA,
    /// AMD ROCm GPUs.
    ROCm,
    /// Apple Metal GPUs.
    Metal,
    /// Vulkan compute.
    Vulkan,
    /// Intel oneAPI devices.
    OneAPI,
    /// Tensor processing units.
    TPU,
}
75
/// Static description of the device a tuner targets.
///
/// Used both for cache keying (`backend` + `device_name`) and for
/// heuristic launch-parameter selection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformInfo {
    /// Backend this device is driven through.
    pub backend: Backend,

    /// Human-readable device name; also embedded in the cache file name.
    pub device_name: String,

    /// Number of parallel compute units (logical CPU cores when built by
    /// `detect()`; presumably SMs/CUs on GPUs — not consumed in this module).
    pub compute_units: usize,

    /// Total device memory in bytes.
    pub total_memory: usize,

    /// Memory bandwidth (presumably GB/s, matching the 50.0/900.0 values
    /// set by the constructors — TODO confirm; not consumed in this module).
    pub memory_bandwidth: f32,

    /// Peak throughput in TFLOPS (estimate; not consumed in this module).
    pub peak_tflops: f32,

    /// Cache sizes in bytes, smallest level first (assumed from the
    /// constructor values — TODO confirm).
    pub cache_sizes: Vec<usize>,

    /// SIMT warp/wavefront width; 1 on CPU.
    pub warp_size: usize,

    /// Upper bound enforced on candidate thread counts during tuning.
    pub max_threads_per_block: usize,
}
106
107impl PlatformInfo {
108 pub fn detect() -> Result<Self> {
110 Ok(Self {
113 backend: Backend::CPU,
114 device_name: "Generic CPU".to_string(),
115 compute_units: num_cpus::get(),
116 total_memory: 16 * 1024 * 1024 * 1024, memory_bandwidth: 50.0, peak_tflops: 1.0,
119 cache_sizes: vec![32768, 262144, 8388608], warp_size: 1,
121 max_threads_per_block: 256,
122 })
123 }
124
125 #[cfg(feature = "cuda")]
127 pub fn cuda(device_id: usize) -> Result<Self> {
128 Ok(Self {
130 backend: Backend::CUDA,
131 device_name: format!("CUDA Device {}", device_id),
132 compute_units: 128,
133 total_memory: 24 * 1024 * 1024 * 1024,
134 memory_bandwidth: 900.0,
135 peak_tflops: 82.0,
136 cache_sizes: vec![128 * 1024, 40 * 1024 * 1024], warp_size: 32,
138 max_threads_per_block: 1024,
139 })
140 }
141
142 pub fn suggested_block_size(&self, operation: Operation) -> (usize, usize, usize) {
144 match self.backend {
145 Backend::CUDA => {
146 match operation {
148 Operation::MatMul => (16, 16, 1),
149 Operation::Convolution => (16, 16, 1),
150 Operation::Softmax => (256, 1, 1),
151 Operation::LayerNorm => (256, 1, 1),
152 Operation::Attention => (64, 1, 1),
153 Operation::ElementWise => (256, 1, 1),
154 Operation::Reduction => (256, 1, 1),
155 Operation::Transpose => (32, 8, 1),
156 }
157 },
158 Backend::CPU => {
159 match operation {
161 Operation::MatMul => (64, 64, 64),
162 _ => (32, 32, 1),
163 }
164 },
165 _ => (16, 16, 1), }
167 }
168}
169
/// A concrete kernel launch configuration, produced either by auto-tuning
/// or by platform heuristics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelParams {
    /// Operation these parameters apply to.
    pub operation: Operation,

    /// Thread-block dimensions (x, y, z).
    pub block_size: (usize, usize, usize),

    /// Threads launched per block.
    pub threads_per_block: usize,

    /// Whether the kernel should stage data in shared/local memory.
    pub use_shared_memory: bool,

    /// Loop unroll factor the kernel should apply.
    pub unroll_factor: usize,

    /// Vector load/store width (elements per instruction).
    pub vector_width: usize,

    /// Grid dimensions (x, y, z); derived from the problem size for matmul,
    /// `(1, 1, 1)` placeholder otherwise.
    pub grid_size: (usize, usize, usize),

    /// Mean benchmarked runtime in microseconds; 0.0 when never benchmarked.
    pub estimated_time_us: f64,
}
197
198impl Default for KernelParams {
199 fn default() -> Self {
200 Self {
201 operation: Operation::ElementWise,
202 block_size: (16, 16, 1),
203 threads_per_block: 256,
204 use_shared_memory: true,
205 unroll_factor: 4,
206 vector_width: 4,
207 grid_size: (1, 1, 1),
208 estimated_time_us: 0.0,
209 }
210 }
211}
212
/// Knobs controlling the auto-tuning process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TuningConfig {
    /// When false, tuners return heuristic defaults without benchmarking.
    pub enable_tuning: bool,

    /// Untimed runs executed before measurement begins.
    pub warmup_iterations: usize,

    /// Timed runs averaged per candidate configuration.
    pub benchmark_iterations: usize,

    /// Directory for the persisted tuning cache; `None` disables persistence.
    pub cache_dir: Option<PathBuf>,

    /// Wall-clock budget for one auto-tuning search, in seconds.
    pub max_tuning_time_secs: f32,

    /// Minimum relative improvement to accept a new configuration.
    /// NOTE(review): not consulted anywhere in this module — confirm
    /// whether other code reads it or it is dead config.
    pub min_improvement_threshold: f32,
}
234
235impl Default for TuningConfig {
236 fn default() -> Self {
237 Self {
238 enable_tuning: true,
239 warmup_iterations: 3,
240 benchmark_iterations: 10,
241 cache_dir: Some(PathBuf::from(".kernel_cache")),
242 max_tuning_time_secs: 10.0,
243 min_improvement_threshold: 0.05, }
245 }
246}
247
/// Outcome of benchmarking a single candidate configuration.
#[derive(Debug, Clone)]
struct TuningResult {
    // The configuration that was benchmarked.
    params: KernelParams,
    // Mean wall-clock time across the benchmark iterations.
    mean_time: Duration,
    // Standard deviation of the timings in seconds; kept for diagnostics
    // even though nothing reads it yet.
    #[allow(dead_code)]
    std_dev: f64,
}
256
/// Key identifying one cached tuning result: there is one entry per
/// (operation, backend, device, input shape) combination.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct CacheKey {
    operation: Operation,
    backend: Backend,
    // Distinguishes different devices of the same backend.
    device_name: String,
    // Problem dimensions, e.g. `[m, n, k]` for matmul.
    input_shape: Vec<usize>,
}
265
/// Auto-tuner that benchmarks kernel launch configurations, memoizes the
/// best one per (operation, device, shape), and persists results to disk
/// between runs.
pub struct KernelTuner {
    /// Tuning behavior knobs.
    config: TuningConfig,

    /// Description of the device being tuned for.
    platform: PlatformInfo,

    /// Memoized best parameters per cache key.
    cache: HashMap<CacheKey, KernelParams>,

    /// True when `cache` has changes not yet written by `save_cache`.
    cache_dirty: bool,
}
280
281impl KernelTuner {
282 pub fn new(config: TuningConfig) -> Result<Self> {
284 let platform = PlatformInfo::detect()?;
285
286 let mut tuner = Self {
287 config,
288 platform,
289 cache: HashMap::new(),
290 cache_dirty: false,
291 };
292
293 tuner.load_cache()?;
295
296 Ok(tuner)
297 }
298
299 pub fn for_backend(backend: Backend, config: TuningConfig) -> Result<Self> {
301 let platform = match backend {
302 #[cfg(feature = "cuda")]
303 Backend::CUDA => PlatformInfo::cuda(0)?,
304 _ => PlatformInfo::detect()?,
305 };
306
307 let mut tuner = Self {
308 config,
309 platform,
310 cache: HashMap::new(),
311 cache_dirty: false,
312 };
313
314 tuner.load_cache()?;
315
316 Ok(tuner)
317 }
318
319 pub fn tune_matmul(&mut self, m: usize, n: usize, k: usize) -> Result<KernelParams> {
321 let key = CacheKey {
322 operation: Operation::MatMul,
323 backend: self.platform.backend,
324 device_name: self.platform.device_name.clone(),
325 input_shape: vec![m, n, k],
326 };
327
328 if let Some(cached) = self.cache.get(&key) {
329 return Ok(cached.clone());
330 }
331
332 if !self.config.enable_tuning {
333 return Ok(self.default_matmul_params(m, n, k));
335 }
336
337 let params = self.auto_tune_matmul(m, n, k)?;
339
340 self.cache.insert(key, params.clone());
341 self.cache_dirty = true;
342
343 Ok(params)
344 }
345
346 fn auto_tune_matmul(&self, m: usize, n: usize, k: usize) -> Result<KernelParams> {
348 let start_time = Instant::now();
349 let max_duration = Duration::from_secs_f32(self.config.max_tuning_time_secs);
350
351 let mut best_result: Option<TuningResult> = None;
352
353 let block_sizes = vec![
355 (8, 8, 8),
356 (16, 16, 16),
357 (32, 32, 32),
358 (64, 64, 64),
359 (128, 128, 8),
360 ];
361
362 let thread_counts = vec![64, 128, 256, 512, 1024];
364
365 let unroll_factors = vec![1, 2, 4, 8];
367
368 for &block_size in &block_sizes {
369 if start_time.elapsed() > max_duration {
370 break;
371 }
372
373 for &threads in &thread_counts {
374 if threads > self.platform.max_threads_per_block {
375 continue;
376 }
377
378 for &unroll in &unroll_factors {
379 if start_time.elapsed() > max_duration {
380 break;
381 }
382
383 let params = KernelParams {
384 operation: Operation::MatMul,
385 block_size,
386 threads_per_block: threads,
387 use_shared_memory: true,
388 unroll_factor: unroll,
389 vector_width: 4,
390 grid_size: self.compute_grid_size(m, n, block_size),
391 estimated_time_us: 0.0,
392 };
393
394 if let Ok(result) = self.benchmark_config(¶ms, m, n, k) {
396 let is_better = match &best_result {
397 None => true,
398 Some(best) => result.mean_time < best.mean_time,
399 };
400 if is_better {
401 best_result = Some(result);
402 }
403 }
404 }
405 }
406 }
407
408 if let Some(result) = best_result {
409 let mut params = result.params;
410 params.estimated_time_us = result.mean_time.as_secs_f64() * 1_000_000.0;
411 Ok(params)
412 } else {
413 Ok(self.default_matmul_params(m, n, k))
414 }
415 }
416
417 fn benchmark_config(
419 &self,
420 params: &KernelParams,
421 m: usize,
422 n: usize,
423 k: usize,
424 ) -> Result<TuningResult> {
425 let mut timings = Vec::new();
426
427 for _ in 0..self.config.warmup_iterations {
429 self.execute_kernel(params, m, n, k)?;
430 }
431
432 for _ in 0..self.config.benchmark_iterations {
434 let start = Instant::now();
435 self.execute_kernel(params, m, n, k)?;
436 timings.push(start.elapsed());
437 }
438
439 let mean_time = timings.iter().sum::<Duration>() / timings.len() as u32;
441
442 let variance = timings
443 .iter()
444 .map(|t| {
445 let diff = t.as_secs_f64() - mean_time.as_secs_f64();
446 diff * diff
447 })
448 .sum::<f64>()
449 / timings.len() as f64;
450
451 let std_dev = variance.sqrt();
452
453 Ok(TuningResult {
454 params: params.clone(),
455 mean_time,
456 std_dev,
457 })
458 }
459
460 fn execute_kernel(
462 &self,
463 _params: &KernelParams,
464 _m: usize,
465 _n: usize,
466 _k: usize,
467 ) -> Result<()> {
468 std::thread::sleep(Duration::from_micros(10));
471 Ok(())
472 }
473
474 fn compute_grid_size(
476 &self,
477 m: usize,
478 n: usize,
479 block_size: (usize, usize, usize),
480 ) -> (usize, usize, usize) {
481 let grid_x = m.div_ceil(block_size.0);
482 let grid_y = n.div_ceil(block_size.1);
483 (grid_x, grid_y, 1)
484 }
485
486 fn default_matmul_params(&self, m: usize, n: usize, _k: usize) -> KernelParams {
488 let block_size = self.platform.suggested_block_size(Operation::MatMul);
489
490 KernelParams {
491 operation: Operation::MatMul,
492 block_size,
493 threads_per_block: 256,
494 use_shared_memory: true,
495 unroll_factor: 4,
496 vector_width: 4,
497 grid_size: self.compute_grid_size(m, n, block_size),
498 estimated_time_us: 0.0,
499 }
500 }
501
502 pub fn tune_operation(
504 &mut self,
505 operation: Operation,
506 input_shape: &[usize],
507 ) -> Result<KernelParams> {
508 let key = CacheKey {
509 operation,
510 backend: self.platform.backend,
511 device_name: self.platform.device_name.clone(),
512 input_shape: input_shape.to_vec(),
513 };
514
515 if let Some(cached) = self.cache.get(&key) {
516 return Ok(cached.clone());
517 }
518
519 let block_size = self.platform.suggested_block_size(operation);
521
522 let params = KernelParams {
523 operation,
524 block_size,
525 threads_per_block: 256,
526 use_shared_memory: matches!(
527 operation,
528 Operation::Attention | Operation::LayerNorm | Operation::Softmax
529 ),
530 unroll_factor: 4,
531 vector_width: 4,
532 grid_size: (1, 1, 1),
533 estimated_time_us: 0.0,
534 };
535
536 self.cache.insert(key, params.clone());
537 self.cache_dirty = true;
538
539 Ok(params)
540 }
541
542 fn load_cache(&mut self) -> Result<()> {
544 if let Some(cache_dir) = &self.config.cache_dir {
545 let cache_file = cache_dir.join(format!(
546 "kernel_cache_{}_{}.json",
547 self.platform.backend as u8, self.platform.device_name
548 ));
549
550 if cache_file.exists() {
551 let contents = std::fs::read_to_string(&cache_file).map_err(|e| {
552 TrustformersError::io_error(format!("Failed to read cache: {}", e))
553 })?;
554
555 let cache_vec: Vec<(CacheKey, KernelParams)> = serde_json::from_str(&contents)
557 .map_err(|e| {
558 TrustformersError::io_error(format!("Failed to parse cache: {}", e))
559 })?;
560
561 self.cache = cache_vec.into_iter().collect();
562 }
563 }
564
565 Ok(())
566 }
567
568 pub fn save_cache(&mut self) -> Result<()> {
570 if !self.cache_dirty {
571 return Ok(());
572 }
573
574 if let Some(cache_dir) = &self.config.cache_dir {
575 std::fs::create_dir_all(cache_dir).map_err(|e| {
576 TrustformersError::io_error(format!("Failed to create cache dir: {}", e))
577 })?;
578
579 let cache_file = cache_dir.join(format!(
580 "kernel_cache_{}_{}.json",
581 self.platform.backend as u8, self.platform.device_name
582 ));
583
584 let cache_vec: Vec<(CacheKey, KernelParams)> =
586 self.cache.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
587
588 let contents = serde_json::to_string_pretty(&cache_vec).map_err(|e| {
589 TrustformersError::io_error(format!("Failed to serialize cache: {}", e))
590 })?;
591
592 std::fs::write(&cache_file, contents).map_err(|e| {
593 TrustformersError::io_error(format!("Failed to write cache: {}", e))
594 })?;
595
596 self.cache_dirty = false;
597 }
598
599 Ok(())
600 }
601
602 pub fn clear_cache(&mut self) {
604 self.cache.clear();
605 self.cache_dirty = true;
606 }
607
608 pub fn platform_info(&self) -> &PlatformInfo {
610 &self.platform
611 }
612
613 pub fn get_statistics(&self) -> TuningStatistics {
615 TuningStatistics {
616 total_cached_configs: self.cache.len(),
617 backends_covered: vec![self.platform.backend],
618 operations_tuned: self
619 .cache
620 .keys()
621 .map(|k| k.operation)
622 .collect::<std::collections::HashSet<_>>()
623 .into_iter()
624 .collect(),
625 }
626 }
627}
628
629impl Drop for KernelTuner {
630 fn drop(&mut self) {
631 let _ = self.save_cache();
633 }
634}
635
/// Summary of what a tuner has cached so far.
#[derive(Debug, Clone)]
pub struct TuningStatistics {
    /// Number of cached configurations across all operations and shapes.
    pub total_cached_configs: usize,

    /// Backends represented (currently always just the tuner's own backend).
    pub backends_covered: Vec<Backend>,

    /// Distinct operations with at least one cached configuration
    /// (unordered — collected from a HashSet).
    pub operations_tuned: Vec<Operation>,
}
648
// Lazily-initialized process-wide tuner instance.
//
// NOTE(review): handing out `&'static mut` from a `static mut` is unsound:
// two calls to `get_kernel_tuner` yield aliasing mutable references, and
// nothing synchronizes cross-thread access, so concurrent use is a data
// race. A `OnceLock<Mutex<KernelTuner>>` would fix both, but changes the
// public return type — left as-is pending that API change.
static mut GLOBAL_TUNER: Option<KernelTuner> = None;
static TUNER_INIT: std::sync::Once = std::sync::Once::new();

/// Returns the process-wide tuner, initializing it on first call.
///
/// # Panics
/// Panics if the first-time initialization (`KernelTuner::new`) fails.
#[allow(static_mut_refs)]
pub fn get_kernel_tuner() -> &'static mut KernelTuner {
    unsafe {
        TUNER_INIT.call_once(|| {
            GLOBAL_TUNER = Some(
                KernelTuner::new(TuningConfig::default())
                    .expect("Failed to initialize kernel tuner"),
            );
        });

        // SAFETY (partial): `call_once` guarantees GLOBAL_TUNER is `Some`
        // before this point, so the `expect` cannot fire. The aliasing and
        // data-race caveats in the note above still apply.
        GLOBAL_TUNER.as_mut().expect("GLOBAL_TUNER initialized in call_once")
    }
}
667
#[cfg(test)]
mod tests {
    use super::*;

    // Detection should always yield a plausible, non-empty description.
    #[test]
    fn test_platform_detection() -> Result<()> {
        let info = PlatformInfo::detect()?;

        assert!(info.compute_units > 0, "no compute units reported");
        assert!(info.total_memory > 0, "no memory reported");
        assert!(!info.device_name.is_empty(), "device name missing");

        Ok(())
    }

    // Default construction lands on the CPU backend.
    #[test]
    fn test_kernel_tuner_creation() -> Result<()> {
        let tuner = KernelTuner::new(TuningConfig::default())?;
        assert_eq!(tuner.platform.backend, Backend::CPU);
        Ok(())
    }

    // With tuning disabled, matmul still yields sane heuristic parameters.
    #[test]
    fn test_matmul_tuning() -> Result<()> {
        let config = TuningConfig {
            enable_tuning: false,
            ..Default::default()
        };
        let mut tuner = KernelTuner::new(config)?;

        let params = tuner.tune_matmul(1024, 768, 512)?;

        assert_eq!(params.operation, Operation::MatMul);
        assert!(params.block_size.0 > 0);
        assert!(params.threads_per_block > 0);

        Ok(())
    }

    // A tuned entry written by one tuner is visible to a fresh one.
    #[test]
    fn test_cache_persistence() -> Result<()> {
        let temp_dir = std::env::temp_dir().join("kernel_cache_test");

        // First tuner: tune one shape and flush the cache to disk.
        {
            let config = TuningConfig {
                cache_dir: Some(temp_dir.clone()),
                enable_tuning: true,
                max_tuning_time_secs: 1.0,
                ..Default::default()
            };
            let mut tuner = KernelTuner::new(config)?;

            tuner.tune_matmul(128, 128, 128)?;
            assert!(
                !tuner.cache.is_empty(),
                "Cache should be populated after tuning"
            );
            tuner.save_cache()?;
        }

        // Second tuner: must pick the entry back up from disk.
        {
            let config = TuningConfig {
                cache_dir: Some(temp_dir.clone()),
                ..Default::default()
            };
            let tuner = KernelTuner::new(config)?;
            assert!(!tuner.cache.is_empty(), "Cache should be loaded from disk");
        }

        let _ = std::fs::remove_dir_all(temp_dir);

        Ok(())
    }

    // Non-matmul operations get heuristic parameters tagged correctly.
    #[test]
    fn test_operation_tuning() -> Result<()> {
        let mut tuner = KernelTuner::new(TuningConfig::default())?;
        let softmax = tuner.tune_operation(Operation::Softmax, &[1024, 512])?;
        assert_eq!(softmax.operation, Operation::Softmax);
        Ok(())
    }

    // CUDA heuristics: 2D tiles for matmul, 1D rows for softmax.
    #[test]
    fn test_suggested_block_sizes() {
        let gpu = PlatformInfo {
            backend: Backend::CUDA,
            device_name: "Test GPU".to_string(),
            compute_units: 80,
            total_memory: 16 * 1024 * 1024 * 1024,
            memory_bandwidth: 600.0,
            peak_tflops: 40.0,
            cache_sizes: vec![128 * 1024],
            warp_size: 32,
            max_threads_per_block: 1024,
        };

        assert_eq!(gpu.suggested_block_size(Operation::MatMul), (16, 16, 1));
        assert_eq!(gpu.suggested_block_size(Operation::Softmax), (256, 1, 1));
    }
}