// trustformers_core/kernel_tuning.rs
1//! Automatic kernel tuning for hardware adaptation
2//!
3//! This module provides automatic performance tuning for kernel operations across
4//! different hardware backends. It profiles kernel execution times and adaptively
5//! selects optimal parameters (block sizes, thread counts, memory layouts) for the
6//! specific hardware being used.
7//!
8//! # Features
9//!
10//! - **Auto-tuning:** Automatic parameter selection through benchmarking
11//! - **Hardware Detection:** Platform capability detection and profiling
12//! - **Caching:** Persistent tuning results for faster subsequent runs
13//! - **Multi-Backend:** Support for CUDA, ROCm, Metal, CPU, and more
14//! - **Adaptive:** Dynamic adjustment based on tensor sizes and operations
15//!
16//! # Examples
17//!
18//! ```rust,no_run
19//! use trustformers_core::kernel_tuning::{KernelTuner, TuningConfig, Operation};
20//!
21//! // Create tuner with default configuration
22//! let mut tuner = KernelTuner::new(TuningConfig::default())?;
23//!
24//! // Auto-tune matrix multiplication parameters for 1024x768 * 768x512
25//! let params = tuner.tune_matmul(1024, 512, 768)?;
26//! println!("Optimal block size: {:?}", params.block_size);
27//! # Ok::<(), Box<dyn std::error::Error>>(())
28//! ```
29
30use crate::errors::{Result, TrustformersError};
31use serde::{Deserialize, Serialize};
32use std::collections::HashMap;
33use std::path::PathBuf;
34use std::time::{Duration, Instant};
35
/// Kernel operation types for tuning
///
/// Each variant identifies one kernel family. The value participates in the
/// tuning-cache key (see `CacheKey`), hence the `Hash`/`Eq`/serde derives.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Operation {
    /// Matrix multiplication (GEMM)
    MatMul,
    /// Convolution operation
    Convolution,
    /// Softmax activation
    Softmax,
    /// Layer normalization
    LayerNorm,
    /// Attention computation
    Attention,
    /// Element-wise operations
    ElementWise,
    /// Reduction operations
    Reduction,
    /// Transpose/permute
    Transpose,
}
56
/// Hardware backend types
///
/// The discriminant (`backend as u8`) is embedded in on-disk cache file names
/// (see `load_cache`/`save_cache`), so reordering or inserting variants
/// invalidates previously written cache files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Backend {
    /// CPU backend
    CPU,
    /// NVIDIA CUDA
    CUDA,
    /// AMD ROCm/HIP
    ROCm,
    /// Apple Metal
    Metal,
    /// Vulkan Compute
    Vulkan,
    /// Intel oneAPI
    OneAPI,
    /// Google TPU
    TPU,
}
75
/// Platform characteristics for tuning decisions
///
/// NOTE(review): instances produced by `detect()` contain mostly hard-coded
/// placeholder values rather than measured hardware properties; treat these
/// fields as heuristics until real detection is implemented.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlatformInfo {
    /// Backend type
    pub backend: Backend,

    /// Device name (e.g., "NVIDIA RTX 4090", "Apple M3 Max")
    /// Also embedded verbatim in cache file names.
    pub device_name: String,

    /// Number of compute units (SMs, CUs, cores)
    pub compute_units: usize,

    /// Total memory in bytes
    pub total_memory: usize,

    /// Memory bandwidth in GB/s
    pub memory_bandwidth: f32,

    /// Peak compute performance in TFLOPS
    pub peak_tflops: f32,

    /// Cache sizes (L1, L2, L3) in bytes
    pub cache_sizes: Vec<usize>,

    /// Warp/wavefront size
    pub warp_size: usize,

    /// Maximum threads per block/workgroup
    /// Candidate thread counts above this are skipped during tuning.
    pub max_threads_per_block: usize,
}
106
107impl PlatformInfo {
    /// Detect current platform characteristics
    ///
    /// NOTE(review): placeholder implementation — apart from the core count
    /// (via `num_cpus`), every value below is a hard-coded generic default
    /// (16 GB RAM, 50 GB/s bandwidth, typical x86 cache sizes), not a
    /// measured property of the actual machine.
    pub fn detect() -> Result<Self> {
        // This would use actual hardware detection APIs
        // Simplified implementation for now
        Ok(Self {
            backend: Backend::CPU,
            device_name: "Generic CPU".to_string(),
            compute_units: num_cpus::get(),
            total_memory: 16 * 1024 * 1024 * 1024, // 16GB default
            memory_bandwidth: 50.0,                // GB/s
            peak_tflops: 1.0,
            cache_sizes: vec![32768, 262144, 8388608], // L1: 32KB, L2: 256KB, L3: 8MB
            warp_size: 1, // scalar execution; no SIMT warp on CPU
            max_threads_per_block: 256,
        })
    }
124
    /// Create platform info for CUDA device
    ///
    /// Only available with the `cuda` feature enabled.
    ///
    /// NOTE(review): `device_id` is only used for the display name; the
    /// reported characteristics are hard-coded placeholders (roughly an
    /// RTX 4090-class card), not queried from the driver.
    #[cfg(feature = "cuda")]
    pub fn cuda(device_id: usize) -> Result<Self> {
        // Would query actual CUDA device properties
        Ok(Self {
            backend: Backend::CUDA,
            device_name: format!("CUDA Device {}", device_id),
            compute_units: 128,
            total_memory: 24 * 1024 * 1024 * 1024,
            memory_bandwidth: 900.0,
            peak_tflops: 82.0,
            cache_sizes: vec![128 * 1024, 40 * 1024 * 1024], // L1: 128KB, L2: 40MB
            warp_size: 32,
            max_threads_per_block: 1024,
        })
    }
141
142    /// Get optimal block size based on hardware characteristics
143    pub fn suggested_block_size(&self, operation: Operation) -> (usize, usize, usize) {
144        match self.backend {
145            Backend::CUDA => {
146                // CUDA-specific block sizes
147                match operation {
148                    Operation::MatMul => (16, 16, 1),
149                    Operation::Convolution => (16, 16, 1),
150                    Operation::Softmax => (256, 1, 1),
151                    Operation::LayerNorm => (256, 1, 1),
152                    Operation::Attention => (64, 1, 1),
153                    Operation::ElementWise => (256, 1, 1),
154                    Operation::Reduction => (256, 1, 1),
155                    Operation::Transpose => (32, 8, 1),
156                }
157            },
158            Backend::CPU => {
159                // CPU tile sizes (for blocked algorithms)
160                match operation {
161                    Operation::MatMul => (64, 64, 64),
162                    _ => (32, 32, 1),
163                }
164            },
165            _ => (16, 16, 1), // Conservative default
166        }
167    }
168}
169
/// Tuned kernel parameters
///
/// Produced either by benchmarking (`auto_tune_matmul`) or by platform
/// heuristics; serialized into the on-disk tuning cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelParams {
    /// Operation type
    pub operation: Operation,

    /// Block/tile size (x, y, z)
    pub block_size: (usize, usize, usize),

    /// Thread count per block
    pub threads_per_block: usize,

    /// Use shared/local memory
    pub use_shared_memory: bool,

    /// Unroll factor for loops
    pub unroll_factor: usize,

    /// Vectorization width (1, 2, 4, 8, 16)
    pub vector_width: usize,

    /// Grid dimensions
    pub grid_size: (usize, usize, usize),

    /// Estimated execution time in microseconds
    /// 0.0 means "not measured" (heuristic defaults); the auto-tuner fills
    /// this in from the benchmarked mean time.
    pub estimated_time_us: f64,
}
197
198impl Default for KernelParams {
199    fn default() -> Self {
200        Self {
201            operation: Operation::ElementWise,
202            block_size: (16, 16, 1),
203            threads_per_block: 256,
204            use_shared_memory: true,
205            unroll_factor: 4,
206            vector_width: 4,
207            grid_size: (1, 1, 1),
208            estimated_time_us: 0.0,
209        }
210    }
211}
212
/// Tuning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TuningConfig {
    /// Enable auto-tuning (vs. using cached results)
    /// When false, heuristic defaults are returned instead of benchmarking.
    pub enable_tuning: bool,

    /// Number of warmup iterations
    pub warmup_iterations: usize,

    /// Number of benchmark iterations
    pub benchmark_iterations: usize,

    /// Cache directory for tuning results
    /// `None` disables persistence entirely (nothing is loaded or saved).
    pub cache_dir: Option<PathBuf>,

    /// Maximum tuning time per kernel in seconds
    /// The search space is truncated once this budget is exhausted.
    pub max_tuning_time_secs: f32,

    /// Minimum performance improvement threshold (fraction)
    /// NOTE(review): not currently consulted anywhere in this module; the
    /// tuner always keeps the strictly fastest candidate.
    pub min_improvement_threshold: f32,
}
234
235impl Default for TuningConfig {
236    fn default() -> Self {
237        Self {
238            enable_tuning: true,
239            warmup_iterations: 3,
240            benchmark_iterations: 10,
241            cache_dir: Some(PathBuf::from(".kernel_cache")),
242            max_tuning_time_secs: 10.0,
243            min_improvement_threshold: 0.05, // 5% improvement
244        }
245    }
246}
247
/// Tuning result for a specific configuration
#[derive(Debug, Clone)]
struct TuningResult {
    // Candidate parameters that were benchmarked.
    params: KernelParams,
    // Mean wall-clock time over the benchmark iterations.
    mean_time: Duration,
    // Standard deviation of run times, in seconds. Computed but not yet
    // consulted by the selection logic, hence the allow.
    #[allow(dead_code)]
    std_dev: f64,
}
256
/// Cache key for tuning results
///
/// Serialized as part of a `Vec<(CacheKey, KernelParams)>` on disk, because
/// JSON map keys must be strings.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
struct CacheKey {
    operation: Operation,
    backend: Backend,
    // Distinguishes devices of the same backend type.
    device_name: String,
    // Problem dimensions, e.g. [m, n, k] for matmul.
    input_shape: Vec<usize>,
}
265
/// Automatic kernel tuner
///
/// Benchmarks candidate kernel configurations and caches the winners, both
/// in memory and (optionally) on disk. The cache is auto-saved on drop.
pub struct KernelTuner {
    /// Tuning configuration
    config: TuningConfig,

    /// Platform information
    platform: PlatformInfo,

    /// Cache of tuned parameters
    cache: HashMap<CacheKey, KernelParams>,

    /// Whether cache has been modified
    /// Guards `save_cache` so unchanged caches are not rewritten.
    cache_dirty: bool,
}
280
281impl KernelTuner {
282    /// Create a new kernel tuner
283    pub fn new(config: TuningConfig) -> Result<Self> {
284        let platform = PlatformInfo::detect()?;
285
286        let mut tuner = Self {
287            config,
288            platform,
289            cache: HashMap::new(),
290            cache_dirty: false,
291        };
292
293        // Load cached tuning results
294        tuner.load_cache()?;
295
296        Ok(tuner)
297    }
298
    /// Create tuner for specific backend
    ///
    /// NOTE(review): only CUDA (behind the `cuda` feature) has a dedicated
    /// platform query. Any other requested backend — including CUDA when the
    /// feature is disabled — silently falls back to `PlatformInfo::detect()`,
    /// which reports a CPU platform. Callers should check
    /// `platform_info().backend` if the distinction matters.
    pub fn for_backend(backend: Backend, config: TuningConfig) -> Result<Self> {
        let platform = match backend {
            #[cfg(feature = "cuda")]
            Backend::CUDA => PlatformInfo::cuda(0)?,
            _ => PlatformInfo::detect()?,
        };

        let mut tuner = Self {
            config,
            platform,
            cache: HashMap::new(),
            cache_dirty: false,
        };

        // Reuse any persisted results for this (backend, device) pair.
        tuner.load_cache()?;

        Ok(tuner)
    }
318
319    /// Get or tune parameters for matrix multiplication
320    pub fn tune_matmul(&mut self, m: usize, n: usize, k: usize) -> Result<KernelParams> {
321        let key = CacheKey {
322            operation: Operation::MatMul,
323            backend: self.platform.backend,
324            device_name: self.platform.device_name.clone(),
325            input_shape: vec![m, n, k],
326        };
327
328        if let Some(cached) = self.cache.get(&key) {
329            return Ok(cached.clone());
330        }
331
332        if !self.config.enable_tuning {
333            // Use heuristic defaults
334            return Ok(self.default_matmul_params(m, n, k));
335        }
336
337        // Auto-tune parameters
338        let params = self.auto_tune_matmul(m, n, k)?;
339
340        self.cache.insert(key, params.clone());
341        self.cache_dirty = true;
342
343        Ok(params)
344    }
345
346    /// Auto-tune matrix multiplication parameters
347    fn auto_tune_matmul(&self, m: usize, n: usize, k: usize) -> Result<KernelParams> {
348        let start_time = Instant::now();
349        let max_duration = Duration::from_secs_f32(self.config.max_tuning_time_secs);
350
351        let mut best_result: Option<TuningResult> = None;
352
353        // Search space for block sizes
354        let block_sizes = vec![
355            (8, 8, 8),
356            (16, 16, 16),
357            (32, 32, 32),
358            (64, 64, 64),
359            (128, 128, 8),
360        ];
361
362        // Search space for thread counts
363        let thread_counts = vec![64, 128, 256, 512, 1024];
364
365        // Search space for unroll factors
366        let unroll_factors = vec![1, 2, 4, 8];
367
368        for &block_size in &block_sizes {
369            if start_time.elapsed() > max_duration {
370                break;
371            }
372
373            for &threads in &thread_counts {
374                if threads > self.platform.max_threads_per_block {
375                    continue;
376                }
377
378                for &unroll in &unroll_factors {
379                    if start_time.elapsed() > max_duration {
380                        break;
381                    }
382
383                    let params = KernelParams {
384                        operation: Operation::MatMul,
385                        block_size,
386                        threads_per_block: threads,
387                        use_shared_memory: true,
388                        unroll_factor: unroll,
389                        vector_width: 4,
390                        grid_size: self.compute_grid_size(m, n, block_size),
391                        estimated_time_us: 0.0,
392                    };
393
394                    // Benchmark this configuration
395                    if let Ok(result) = self.benchmark_config(&params, m, n, k) {
396                        let is_better = match &best_result {
397                            None => true,
398                            Some(best) => result.mean_time < best.mean_time,
399                        };
400                        if is_better {
401                            best_result = Some(result);
402                        }
403                    }
404                }
405            }
406        }
407
408        if let Some(result) = best_result {
409            let mut params = result.params;
410            params.estimated_time_us = result.mean_time.as_secs_f64() * 1_000_000.0;
411            Ok(params)
412        } else {
413            Ok(self.default_matmul_params(m, n, k))
414        }
415    }
416
417    /// Benchmark a specific kernel configuration
418    fn benchmark_config(
419        &self,
420        params: &KernelParams,
421        m: usize,
422        n: usize,
423        k: usize,
424    ) -> Result<TuningResult> {
425        let mut timings = Vec::new();
426
427        // Warmup iterations
428        for _ in 0..self.config.warmup_iterations {
429            self.execute_kernel(params, m, n, k)?;
430        }
431
432        // Benchmark iterations
433        for _ in 0..self.config.benchmark_iterations {
434            let start = Instant::now();
435            self.execute_kernel(params, m, n, k)?;
436            timings.push(start.elapsed());
437        }
438
439        // Compute statistics
440        let mean_time = timings.iter().sum::<Duration>() / timings.len() as u32;
441
442        let variance = timings
443            .iter()
444            .map(|t| {
445                let diff = t.as_secs_f64() - mean_time.as_secs_f64();
446                diff * diff
447            })
448            .sum::<f64>()
449            / timings.len() as f64;
450
451        let std_dev = variance.sqrt();
452
453        Ok(TuningResult {
454            params: params.clone(),
455            mean_time,
456            std_dev,
457        })
458    }
459
    /// Execute kernel with given parameters (mock implementation)
    ///
    /// NOTE(review): the stub sleeps a fixed 10us regardless of `_params` and
    /// the problem size, so every candidate configuration measures (roughly)
    /// the same and the auto-tuner effectively keeps the first candidate it
    /// benchmarks. Replace with a real kernel launch before relying on
    /// tuning results.
    fn execute_kernel(
        &self,
        _params: &KernelParams,
        _m: usize,
        _n: usize,
        _k: usize,
    ) -> Result<()> {
        // This would execute the actual kernel
        // For now, simulate execution time based on parameters
        std::thread::sleep(Duration::from_micros(10));
        Ok(())
    }
473
474    /// Compute grid size for given problem and block size
475    fn compute_grid_size(
476        &self,
477        m: usize,
478        n: usize,
479        block_size: (usize, usize, usize),
480    ) -> (usize, usize, usize) {
481        let grid_x = m.div_ceil(block_size.0);
482        let grid_y = n.div_ceil(block_size.1);
483        (grid_x, grid_y, 1)
484    }
485
486    /// Get default parameters for matrix multiplication
487    fn default_matmul_params(&self, m: usize, n: usize, _k: usize) -> KernelParams {
488        let block_size = self.platform.suggested_block_size(Operation::MatMul);
489
490        KernelParams {
491            operation: Operation::MatMul,
492            block_size,
493            threads_per_block: 256,
494            use_shared_memory: true,
495            unroll_factor: 4,
496            vector_width: 4,
497            grid_size: self.compute_grid_size(m, n, block_size),
498            estimated_time_us: 0.0,
499        }
500    }
501
502    /// Tune parameters for a generic operation
503    pub fn tune_operation(
504        &mut self,
505        operation: Operation,
506        input_shape: &[usize],
507    ) -> Result<KernelParams> {
508        let key = CacheKey {
509            operation,
510            backend: self.platform.backend,
511            device_name: self.platform.device_name.clone(),
512            input_shape: input_shape.to_vec(),
513        };
514
515        if let Some(cached) = self.cache.get(&key) {
516            return Ok(cached.clone());
517        }
518
519        // Use heuristic defaults for non-matmul operations
520        let block_size = self.platform.suggested_block_size(operation);
521
522        let params = KernelParams {
523            operation,
524            block_size,
525            threads_per_block: 256,
526            use_shared_memory: matches!(
527                operation,
528                Operation::Attention | Operation::LayerNorm | Operation::Softmax
529            ),
530            unroll_factor: 4,
531            vector_width: 4,
532            grid_size: (1, 1, 1),
533            estimated_time_us: 0.0,
534        };
535
536        self.cache.insert(key, params.clone());
537        self.cache_dirty = true;
538
539        Ok(params)
540    }
541
542    /// Load tuning cache from disk
543    fn load_cache(&mut self) -> Result<()> {
544        if let Some(cache_dir) = &self.config.cache_dir {
545            let cache_file = cache_dir.join(format!(
546                "kernel_cache_{}_{}.json",
547                self.platform.backend as u8, self.platform.device_name
548            ));
549
550            if cache_file.exists() {
551                let contents = std::fs::read_to_string(&cache_file).map_err(|e| {
552                    TrustformersError::io_error(format!("Failed to read cache: {}", e))
553                })?;
554
555                // Deserialize from Vec and convert to HashMap
556                let cache_vec: Vec<(CacheKey, KernelParams)> = serde_json::from_str(&contents)
557                    .map_err(|e| {
558                        TrustformersError::io_error(format!("Failed to parse cache: {}", e))
559                    })?;
560
561                self.cache = cache_vec.into_iter().collect();
562            }
563        }
564
565        Ok(())
566    }
567
568    /// Save tuning cache to disk
569    pub fn save_cache(&mut self) -> Result<()> {
570        if !self.cache_dirty {
571            return Ok(());
572        }
573
574        if let Some(cache_dir) = &self.config.cache_dir {
575            std::fs::create_dir_all(cache_dir).map_err(|e| {
576                TrustformersError::io_error(format!("Failed to create cache dir: {}", e))
577            })?;
578
579            let cache_file = cache_dir.join(format!(
580                "kernel_cache_{}_{}.json",
581                self.platform.backend as u8, self.platform.device_name
582            ));
583
584            // Convert to Vec for serialization (JSON doesn't support non-string keys)
585            let cache_vec: Vec<(CacheKey, KernelParams)> =
586                self.cache.iter().map(|(k, v)| (k.clone(), v.clone())).collect();
587
588            let contents = serde_json::to_string_pretty(&cache_vec).map_err(|e| {
589                TrustformersError::io_error(format!("Failed to serialize cache: {}", e))
590            })?;
591
592            std::fs::write(&cache_file, contents).map_err(|e| {
593                TrustformersError::io_error(format!("Failed to write cache: {}", e))
594            })?;
595
596            self.cache_dirty = false;
597        }
598
599        Ok(())
600    }
601
    /// Clear all cached tuning results
    ///
    /// Marks the cache dirty so the (now empty) cache is persisted on the
    /// next `save_cache` or on drop, overwriting stale on-disk entries.
    pub fn clear_cache(&mut self) {
        self.cache.clear();
        self.cache_dirty = true;
    }
607
    /// Get platform information
    ///
    /// Useful for checking which backend/device the tuner actually targets
    /// (see the fallback note on `for_backend`).
    pub fn platform_info(&self) -> &PlatformInfo {
        &self.platform
    }
612
613    /// Get tuning statistics
614    pub fn get_statistics(&self) -> TuningStatistics {
615        TuningStatistics {
616            total_cached_configs: self.cache.len(),
617            backends_covered: vec![self.platform.backend],
618            operations_tuned: self
619                .cache
620                .keys()
621                .map(|k| k.operation)
622                .collect::<std::collections::HashSet<_>>()
623                .into_iter()
624                .collect(),
625        }
626    }
627}
628
impl Drop for KernelTuner {
    fn drop(&mut self) {
        // Auto-save cache on drop.
        // Errors are deliberately ignored: drop cannot propagate them, and a
        // failed save only costs re-tuning on the next run.
        let _ = self.save_cache();
    }
}
635
/// Statistics about tuning results
///
/// Snapshot produced by `KernelTuner::get_statistics`; reflects only the
/// current in-memory cache.
#[derive(Debug, Clone)]
pub struct TuningStatistics {
    /// Total number of cached configurations
    pub total_cached_configs: usize,

    /// Backends that have tuned configurations
    /// Currently always the single backend of the owning tuner.
    pub backends_covered: Vec<Backend>,

    /// Operations that have been tuned
    /// Deduplicated; order is unspecified.
    pub operations_tuned: Vec<Operation>,
}
648
/// Global kernel tuner instance
///
/// NOTE(review): this `static mut` + `&'static mut` accessor pattern is
/// unsound in the general case — two calls to `get_kernel_tuner` can yield
/// aliasing `&mut` references, and `Once` only synchronizes initialization,
/// not subsequent access from multiple threads. Prefer migrating to
/// `std::sync::OnceLock<Mutex<KernelTuner>>`; kept as-is here because the
/// `&'static mut` return type is part of the public API.
static mut GLOBAL_TUNER: Option<KernelTuner> = None;
static TUNER_INIT: std::sync::Once = std::sync::Once::new();

/// Get or initialize the global kernel tuner
///
/// # Panics
/// Panics if platform detection or cache loading fails during first-time
/// initialization.
#[allow(static_mut_refs)]
pub fn get_kernel_tuner() -> &'static mut KernelTuner {
    unsafe {
        TUNER_INIT.call_once(|| {
            GLOBAL_TUNER = Some(
                KernelTuner::new(TuningConfig::default())
                    .expect("Failed to initialize kernel tuner"),
            );
        });

        // SAFETY (partial): `call_once` guarantees the Option is populated
        // before this point, so the expect cannot fire. Aliasing of the
        // returned `&mut` is NOT prevented — see the note on `GLOBAL_TUNER`.
        GLOBAL_TUNER.as_mut().expect("GLOBAL_TUNER initialized in call_once")
    }
}
667
#[cfg(test)]
mod tests {
    use super::*;

    // `detect()` currently returns mostly hard-coded values; this only
    // checks the invariants that hold regardless.
    #[test]
    fn test_platform_detection() -> Result<()> {
        let platform = PlatformInfo::detect()?;

        assert!(platform.compute_units > 0);
        assert!(platform.total_memory > 0);
        assert!(!platform.device_name.is_empty());

        Ok(())
    }

    // `detect()` always reports CPU, so a default tuner must too.
    #[test]
    fn test_kernel_tuner_creation() -> Result<()> {
        let tuner = KernelTuner::new(TuningConfig::default())?;

        assert_eq!(tuner.platform.backend, Backend::CPU);

        Ok(())
    }

    // With tuning disabled, `tune_matmul` must return heuristic defaults
    // without benchmarking.
    #[test]
    fn test_matmul_tuning() -> Result<()> {
        let mut tuner = KernelTuner::new(TuningConfig {
            enable_tuning: false, // Use defaults for testing
            ..Default::default()
        })?;

        let params = tuner.tune_matmul(1024, 768, 512)?;

        assert_eq!(params.operation, Operation::MatMul);
        assert!(params.block_size.0 > 0);
        assert!(params.threads_per_block > 0);

        Ok(())
    }

    // Round-trips the cache through disk: tune + save in one tuner instance,
    // then verify a fresh instance loads the persisted entries.
    #[test]
    fn test_cache_persistence() -> Result<()> {
        let temp_dir = std::env::temp_dir().join("kernel_cache_test");

        {
            let mut tuner = KernelTuner::new(TuningConfig {
                cache_dir: Some(temp_dir.clone()),
                enable_tuning: true,
                max_tuning_time_secs: 1.0, // Short tuning time for tests
                ..Default::default()
            })?;

            let _ = tuner.tune_matmul(128, 128, 128)?;
            assert!(
                !tuner.cache.is_empty(),
                "Cache should be populated after tuning"
            );
            tuner.save_cache()?;
        }
        // (The first tuner's Drop also attempts a save, but `cache_dirty`
        // was already cleared by the explicit save above.)

        // Load cache in new instance
        {
            let tuner = KernelTuner::new(TuningConfig {
                cache_dir: Some(temp_dir.clone()),
                ..Default::default()
            })?;

            assert!(!tuner.cache.is_empty(), "Cache should be loaded from disk");
        }

        // Cleanup
        let _ = std::fs::remove_dir_all(temp_dir);

        Ok(())
    }

    // Non-matmul operations take the heuristic (non-benchmarked) path.
    #[test]
    fn test_operation_tuning() -> Result<()> {
        let mut tuner = KernelTuner::new(TuningConfig::default())?;

        let params = tuner.tune_operation(Operation::Softmax, &[1024, 512])?;

        assert_eq!(params.operation, Operation::Softmax);

        Ok(())
    }

    // Pins the CUDA block-size heuristics for two representative operations.
    #[test]
    fn test_suggested_block_sizes() {
        let platform = PlatformInfo {
            backend: Backend::CUDA,
            device_name: "Test GPU".to_string(),
            compute_units: 80,
            total_memory: 16 * 1024 * 1024 * 1024,
            memory_bandwidth: 600.0,
            peak_tflops: 40.0,
            cache_sizes: vec![128 * 1024],
            warp_size: 32,
            max_threads_per_block: 1024,
        };

        let matmul_size = platform.suggested_block_size(Operation::MatMul);
        assert_eq!(matmul_size, (16, 16, 1));

        let softmax_size = platform.suggested_block_size(Operation::Softmax);
        assert_eq!(softmax_size, (256, 1, 1));
    }
}