Skip to main content

torsh_backend/
adaptive_kernel_selection.rs

1//! Adaptive kernel selection based on input characteristics
2//!
3//! This module provides intelligent kernel selection that adapts to input characteristics,
4//! system state, and historical performance data to automatically choose optimal kernels
5//! for different operations.
6
7use crate::error::BackendResult as Result;
8use crate::performance_modeling::{
9    EnvironmentalFactors, PerformanceMeasurement, RuntimePerformanceModeler,
10};
11use crate::performance_tuning::{
12    AccessPattern, ActualPerformance, DataType, OperationType, PerformancePrediction, SystemState,
13    TuningParameters, WorkloadCharacteristics,
14};
15use crate::{BackendType, Device};
16use std::collections::HashMap;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, Mutex, RwLock};
19use std::time::{Duration, Instant};
20use torsh_core::error::TorshError;
21
22#[cfg(feature = "serialize")]
23use serde::{Deserialize, Serialize};
24
25#[cfg(not(feature = "std"))]
26use alloc::{boxed::Box, format, string::String, vec::Vec};
27
/// Global counter for generating unique measurement IDs
static MEASUREMENT_ID_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Generate a process-wide unique measurement ID.
///
/// IDs start at 1 and increase monotonically within this process.
/// `Relaxed` ordering is sufficient here: only the atomicity of the
/// increment matters, there is no other memory this counter must be
/// ordered against (`SeqCst` would add cost without benefit).
fn generate_measurement_id() -> u64 {
    MEASUREMENT_ID_COUNTER.fetch_add(1, Ordering::Relaxed)
}
35
/// Adaptive kernel selection coordinator
///
/// Ties together the shared kernel registry, the runtime performance
/// modeler, and the performance tracker, and applies the configured
/// [`SelectionAlgorithm`] to choose a kernel per request.
pub struct AdaptiveKernelSelector {
    /// Kernel registry for different backends (shared; reader/writer locked)
    kernel_registry: Arc<RwLock<KernelRegistry>>,
    /// Performance modeler used to predict candidate execution times
    performance_modeler: Arc<RuntimePerformanceModeler>,
    /// Selection algorithm (score-based, ML-based, or hybrid)
    selection_algorithm: SelectionAlgorithm,
    /// Performance tracker for learning from selection feedback
    performance_tracker: Arc<Mutex<PerformanceTracker>>,
    /// Configuration parameters
    config: AdaptiveSelectionConfig,
}
49
/// Registry of available kernels organized by operation type and backend
#[derive(Debug)]
pub struct KernelRegistry {
    /// Built-in kernels keyed by (operation type, backend type)
    kernels: HashMap<(OperationType, BackendType), Vec<KernelImplementation>>,
    /// Custom kernel implementations, keyed by string id
    /// (presumably the kernel name — registration impl not visible here)
    custom_kernels: HashMap<String, Box<dyn CustomKernel + Send + Sync>>,
    /// Kernel performance characteristics keyed by kernel id
    kernel_characteristics: HashMap<String, KernelCharacteristics>,
    /// Default kernel fallbacks (kernel id per operation/backend pair)
    #[allow(dead_code)] // reserved for fallback selection; not read yet
    default_kernels: HashMap<(OperationType, BackendType), String>,
}
63
/// Kernel implementation metadata
///
/// Bundles the executable kernel (`implementation`) with the descriptive
/// data the selector uses to score it against a workload.
#[derive(Debug, Clone)]
pub struct KernelImplementation {
    /// Unique kernel identifier (used as the key in tracking/benchmarks)
    pub id: String,
    /// Human-readable name
    pub name: String,
    /// Operation type this kernel implements
    pub operation_type: OperationType,
    /// Backend type this kernel runs on
    pub backend_type: BackendType,
    /// Kernel variant (e.g., "naive", "optimized", "tiled")
    pub variant: KernelVariant,
    /// Performance characteristics
    pub characteristics: KernelCharacteristics,
    /// Supported input constraints
    pub constraints: KernelConstraints,
    /// Kernel implementation (shared so clones stay cheap)
    pub implementation: Arc<dyn KernelExecutor + Send + Sync>,
}
84
/// Kernel variant types
///
/// Describes the optimization strategy a kernel implementation embodies.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum KernelVariant {
    /// Naive implementation (simple, works for all inputs)
    Naive,
    /// Optimized implementation (tuned for specific characteristics)
    Optimized,
    /// Tiled implementation (memory hierarchy optimized)
    Tiled,
    /// Vectorized implementation (SIMD optimized)
    Vectorized,
    /// Parallel implementation (multi-threaded)
    Parallel,
    /// Fused implementation (multiple operations combined)
    Fused,
    /// Hardware-specific implementation; the string names the target hardware
    HardwareSpecific(String),
    /// Custom implementation; the string is a user-chosen identifier
    Custom(String),
}
106
/// Kernel performance characteristics
///
/// Static, per-kernel descriptors consumed by the scoring logic. Ratio
/// fields are assumed to be in [0.0, 1.0] — TODO confirm with registrants.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelCharacteristics {
    /// Optimal input size range as (min, max), in elements or bytes —
    /// units not fixed by this file; confirm against registry usage
    pub optimal_size_range: (usize, usize),
    /// Memory access pattern
    pub memory_pattern: AccessPattern,
    /// Compute intensity (operations per byte)
    pub compute_intensity: f64,
    /// Parallelization efficiency
    pub parallelization_efficiency: f64,
    /// Cache efficiency (used directly as a score term in selection)
    pub cache_efficiency: f64,
    /// Memory bandwidth utilization
    pub memory_bandwidth_utilization: f64,
    /// One-time initialization overhead
    pub initialization_overhead: Duration,
    /// Scalability characteristics
    pub scalability: ScalabilityCharacteristics,
}
128
/// Kernel scalability characteristics
///
/// Captures how a kernel's performance responds along three axes.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ScalabilityCharacteristics {
    /// How performance scales with input size
    pub size_scaling: ScalingBehavior,
    /// How performance scales with thread count
    pub thread_scaling: ScalingBehavior,
    /// How performance scales with memory hierarchy (cache levels)
    pub memory_hierarchy_scaling: ScalingBehavior,
}
140
/// Scaling behavior patterns
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum ScalingBehavior {
    /// Linear scaling
    Linear,
    /// Logarithmic scaling
    Logarithmic,
    /// Exponential scaling
    Exponential,
    /// Constant (no scaling)
    Constant,
    /// Custom scaling function, identified by name
    Custom(String),
}
156
/// Kernel input constraints
///
/// Hard requirements an input must satisfy for a kernel to be a candidate.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelConstraints {
    /// Minimum input size
    pub min_size: usize,
    /// Maximum input size (`None` = unbounded)
    pub max_size: Option<usize>,
    /// Supported data types
    pub supported_dtypes: Vec<DataType>,
    /// Required memory alignment, in bytes
    pub required_alignment: usize,
    /// Supported tensor shapes (None means any shape)
    pub supported_shapes: Option<Vec<Vec<usize>>>,
    /// Required hardware features (e.g. instruction-set extensions —
    /// naming convention not fixed by this file)
    pub required_features: Vec<String>,
}
174
/// Kernel executor interface
///
/// Implemented by every runnable kernel; the selector uses `can_handle`
/// and `get_resource_requirements` during scoring, and `execute` during
/// benchmarking and actual dispatch.
pub trait KernelExecutor: std::fmt::Debug + Send + Sync {
    /// Execute the kernel with given inputs
    fn execute(&self, inputs: &KernelInputs) -> Result<KernelOutputs>;

    /// Get estimated execution time (heuristic; infallible)
    fn estimate_execution_time(&self, inputs: &KernelInputs) -> Duration;

    /// Check if kernel can handle given inputs
    fn can_handle(&self, inputs: &KernelInputs) -> bool;

    /// Get kernel resource requirements for the given inputs
    fn get_resource_requirements(&self, inputs: &KernelInputs) -> ResourceRequirements;
}
189
/// Kernel input specification
///
/// Describes the inputs of a single kernel invocation without carrying
/// the tensor data itself.
#[derive(Debug, Clone)]
pub struct KernelInputs {
    /// Input tensor dimensions, one shape vector per input
    pub input_shapes: Vec<Vec<usize>>,
    /// Data types, one per input
    pub data_types: Vec<DataType>,
    /// Total data size in bytes (used for bandwidth/cache estimates)
    pub total_size: usize,
    /// Operation parameters, keyed by parameter name
    pub operation_params: HashMap<String, KernelParameter>,
    /// Device information
    pub device: Device,
}
204
/// Kernel output specification
///
/// Result summary of a kernel execution; `success == false` runs are
/// excluded from benchmark statistics.
#[derive(Debug, Clone)]
pub struct KernelOutputs {
    /// Output tensor dimensions, one shape vector per output
    pub output_shapes: Vec<Vec<usize>>,
    /// Wall-clock execution time
    pub execution_time: Duration,
    /// Memory usage, in bytes
    pub memory_usage: usize,
    /// Success flag
    pub success: bool,
    /// Error message (if any; expected `Some` when `success` is false)
    pub error_message: Option<String>,
}
219
/// Kernel parameter values
///
/// Tagged union of the value types an operation parameter may carry.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum KernelParameter {
    Integer(i64),
    Float(f64),
    String(String),
    Boolean(bool),
    IntegerArray(Vec<i64>),
    FloatArray(Vec<f64>),
}
231
/// Resource requirements for kernel execution
///
/// Reported by [`KernelExecutor::get_resource_requirements`]; the memory
/// figure feeds directly into the memory-usage score term.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ResourceRequirements {
    /// Memory requirement in bytes
    pub memory: usize,
    /// Compute units required
    pub compute_units: usize,
    /// Bandwidth requirement in bytes/second
    pub bandwidth: usize,
    /// Temporary (scratch) storage requirement, in bytes
    pub temporary_storage: usize,
}
245
/// Custom kernel trait for user-defined kernels
///
/// Unlike [`KernelExecutor`], a custom kernel is fully self-describing:
/// it carries its own metadata (operation, backend, characteristics,
/// constraints) alongside the executable code.
pub trait CustomKernel: std::fmt::Debug + Send + Sync {
    /// Get kernel name
    fn name(&self) -> &str;

    /// Get operation type
    fn operation_type(&self) -> OperationType;

    /// Get backend type
    fn backend_type(&self) -> BackendType;

    /// Get kernel characteristics
    fn characteristics(&self) -> KernelCharacteristics;

    /// Get kernel constraints
    fn constraints(&self) -> KernelConstraints;

    /// Execute the kernel
    fn execute(&self, inputs: &KernelInputs) -> Result<KernelOutputs>;

    /// Benchmark the kernel over `iterations` runs of `inputs`
    fn benchmark(&self, inputs: &KernelInputs, iterations: usize) -> Result<BenchmarkResult>;
}
269
/// Benchmark result for kernel performance
///
/// Aggregated timing statistics over the successful iterations of a
/// benchmark run (see `benchmark_kernel`).
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct BenchmarkResult {
    /// Average execution time
    pub avg_execution_time: Duration,
    /// Minimum execution time
    pub min_execution_time: Duration,
    /// Maximum execution time
    pub max_execution_time: Duration,
    /// Standard deviation of execution time
    pub std_deviation: Duration,
    /// Throughput (executions per second, i.e. 1 / average time)
    pub throughput: f64,
    /// Estimated memory bandwidth, in bytes per second
    pub memory_bandwidth: f64,
    /// Estimated cache hit rate (heuristic, not a hardware counter)
    pub cache_hit_rate: f64,
}
289
/// Kernel selection algorithm
///
/// Strategy used by [`AdaptiveKernelSelector::select_kernel`]; each
/// variant carries its own configuration.
#[derive(Debug, Clone)]
pub enum SelectionAlgorithm {
    /// Score-based selection (weighted sum of heuristic score terms)
    ScoreBased(ScoreBasedConfig),
    /// Machine learning-based selection
    MachineLearning(MLBasedConfig),
    /// Hybrid approach: ML above a confidence threshold, scores otherwise
    Hybrid(HybridConfig),
}
300
/// Configuration for score-based selection
///
/// Weights for the additive score computed in `calculate_kernel_score`;
/// larger weight = more influence on the final ranking.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ScoreBasedConfig {
    /// Weight for predicted execution time
    pub execution_time_weight: f64,
    /// Weight for memory usage
    pub memory_usage_weight: f64,
    /// Weight for cache efficiency
    pub cache_efficiency_weight: f64,
    /// Weight for historical performance
    pub historical_weight: f64,
    /// Penalty subtracted when switching away from the current kernel
    pub switching_penalty: f64,
}
316
/// Configuration for ML-based selection
///
/// NOTE(review): the ML path currently falls back to score-based
/// selection (see `ml_based_selection`); this config is forward-looking.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct MLBasedConfig {
    /// Model type
    pub model_type: MLModelType,
    /// Training parameters
    pub training_params: MLTrainingParams,
    /// Feature weights, keyed by feature name
    pub feature_weights: HashMap<String, f64>,
}
328
/// Machine learning model types
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum MLModelType {
    DecisionTree,
    RandomForest,
    NeuralNetwork,
    SupportVectorMachine,
    LinearRegression,
    /// User-defined model, identified by name
    Custom(String),
}
340
/// ML training parameters
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct MLTrainingParams {
    /// Learning rate
    pub learning_rate: f64,
    /// Number of epochs
    pub epochs: usize,
    /// Batch size
    pub batch_size: usize,
    /// Regularization parameter (strength; scheme not fixed by this file)
    pub regularization: f64,
}
354
/// Configuration for hybrid selection
///
/// Combines both strategies; `hybrid_selection` uses the ML path only
/// when its confidence exceeds `ml_threshold`.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct HybridConfig {
    /// Score-based configuration (fallback path)
    pub score_based: ScoreBasedConfig,
    /// ML-based configuration
    pub ml_based: MLBasedConfig,
    /// Minimum ML confidence required to take the ML path
    pub ml_threshold: f64,
}
366
/// Performance tracker for kernel learning
///
/// Mutable state behind the selector's `Mutex`; records selections and
/// feedback so future choices can exploit observed performance.
#[derive(Debug)]
pub struct PerformanceTracker {
    /// Historical performance data, keyed by kernel id
    #[allow(dead_code)] // populated by feedback; readers not yet in view
    performance_history: HashMap<String, Vec<KernelPerformanceRecord>>,
    /// Kernel usage statistics, keyed by kernel id
    usage_stats: HashMap<String, KernelUsageStats>,
    /// Selection accuracy tracking
    selection_accuracy: SelectionAccuracyTracker,
}
378
/// Kernel performance record
///
/// One observed execution: the context it ran in, what was predicted,
/// and what actually happened.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelPerformanceRecord {
    /// Timestamp of the execution
    pub timestamp: std::time::SystemTime,
    /// Input characteristics at selection time
    pub input_characteristics: WorkloadCharacteristics,
    /// System state at selection time
    pub system_state: SystemState,
    /// Actual measured performance
    pub actual_performance: ActualPerformance,
    /// Predicted performance (if a prediction was made)
    pub predicted_performance: Option<PerformancePrediction>,
    /// Selection confidence at the time of selection
    pub selection_confidence: f64,
}
396
/// Kernel usage statistics
///
/// Running aggregates per kernel; `successful_executions / total_executions`
/// is the success rate used by the historical score term.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelUsageStats {
    /// Total executions
    pub total_executions: usize,
    /// Successful executions
    pub successful_executions: usize,
    /// Average execution time
    pub avg_execution_time: Duration,
    /// Last used timestamp (drives exponential recency decay)
    pub last_used: std::time::SystemTime,
    /// Selection frequency
    pub selection_frequency: f64,
}
412
/// Selection accuracy tracker
///
/// Measures how often the selector's choice turned out to be optimal
/// in hindsight, overall and broken down by operation and backend.
#[derive(Debug)]
pub struct SelectionAccuracyTracker {
    /// Total selections made
    total_selections: usize,
    /// Selections that proved optimal (in hindsight)
    optimal_selections: usize,
    /// Selection accuracy by operation type
    accuracy_by_operation: HashMap<OperationType, f64>,
    /// Selection accuracy by backend
    accuracy_by_backend: HashMap<BackendType, f64>,
}
425
/// Adaptive selection configuration
///
/// Tuning knobs for the selector; see the `Default` impl for baseline
/// values.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct AdaptiveSelectionConfig {
    /// Enable learning from performance feedback
    pub enable_learning: bool,
    /// Exploration vs exploitation trade-off (fraction in [0, 1])
    pub exploration_factor: f64,
    /// Minimum confidence threshold for selections
    pub min_confidence_threshold: f64,
    /// Maximum number of concurrent benchmarks
    pub max_concurrent_benchmarks: usize,
    /// Benchmark timeout
    pub benchmark_timeout: Duration,
    /// How long performance history is retained before expiry
    pub history_retention: Duration,
}
443
444impl Default for AdaptiveSelectionConfig {
445    fn default() -> Self {
446        Self {
447            enable_learning: true,
448            exploration_factor: 0.1,
449            min_confidence_threshold: 0.8,
450            max_concurrent_benchmarks: 4,
451            benchmark_timeout: Duration::from_secs(30),
452            history_retention: Duration::from_secs(7 * 24 * 3600), // 7 days
453        }
454    }
455}
456
457impl AdaptiveKernelSelector {
458    /// Create a new adaptive kernel selector
459    pub fn new(performance_modeler: Arc<RuntimePerformanceModeler>) -> Self {
460        Self {
461            kernel_registry: Arc::new(RwLock::new(KernelRegistry::new())),
462            performance_modeler,
463            selection_algorithm: SelectionAlgorithm::ScoreBased(ScoreBasedConfig::default()),
464            performance_tracker: Arc::new(Mutex::new(PerformanceTracker::new())),
465            config: AdaptiveSelectionConfig::default(),
466        }
467    }
468
469    /// Register a kernel implementation
470    pub fn register_kernel(&self, kernel: KernelImplementation) -> Result<()> {
471        let mut registry = self
472            .kernel_registry
473            .write()
474            .expect("lock should not be poisoned");
475        registry.register_kernel(kernel)
476    }
477
478    /// Register a custom kernel
479    pub fn register_custom_kernel(
480        &self,
481        kernel: Box<dyn CustomKernel + Send + Sync>,
482    ) -> Result<()> {
483        let mut registry = self
484            .kernel_registry
485            .write()
486            .expect("lock should not be poisoned");
487        registry.register_custom_kernel(kernel)
488    }
489
490    /// Select optimal kernel for given inputs
491    pub fn select_kernel(
492        &self,
493        operation_type: OperationType,
494        backend_type: BackendType,
495        inputs: &KernelInputs,
496        workload: &WorkloadCharacteristics,
497        system_state: &SystemState,
498    ) -> Result<KernelSelection> {
499        let registry = self
500            .kernel_registry
501            .read()
502            .expect("lock should not be poisoned");
503
504        // Get candidate kernels
505        let candidates = registry.get_candidates(operation_type, backend_type, inputs)?;
506
507        if candidates.is_empty() {
508            return Err(TorshError::BackendError(format!(
509                "No suitable kernels found for operation {:?} on backend {:?}",
510                operation_type, backend_type
511            )));
512        }
513
514        // Apply selection algorithm
515        let selection = match &self.selection_algorithm {
516            SelectionAlgorithm::ScoreBased(config) => {
517                self.score_based_selection(&candidates, inputs, workload, system_state, config)?
518            }
519            SelectionAlgorithm::MachineLearning(config) => {
520                self.ml_based_selection(&candidates, inputs, workload, system_state, config)?
521            }
522            SelectionAlgorithm::Hybrid(config) => {
523                self.hybrid_selection(&candidates, inputs, workload, system_state, config)?
524            }
525        };
526
527        // Track selection for learning
528        if self.config.enable_learning {
529            self.track_selection(&selection, workload, system_state)?;
530        }
531
532        Ok(selection)
533    }
534
535    /// Score-based kernel selection
536    fn score_based_selection(
537        &self,
538        candidates: &[KernelImplementation],
539        inputs: &KernelInputs,
540        workload: &WorkloadCharacteristics,
541        system_state: &SystemState,
542        config: &ScoreBasedConfig,
543    ) -> Result<KernelSelection> {
544        let mut best_kernel = None;
545        let mut best_score = f64::NEG_INFINITY;
546
547        for kernel in candidates {
548            let score =
549                self.calculate_kernel_score(kernel, inputs, workload, system_state, config)?;
550
551            if score > best_score {
552                best_score = score;
553                best_kernel = Some(kernel);
554            }
555        }
556
557        let selected_kernel = best_kernel
558            .ok_or_else(|| TorshError::BackendError("No suitable kernel found".to_string()))?;
559
560        Ok(KernelSelection {
561            kernel: selected_kernel.clone(),
562            confidence: (best_score + 1.0) / 2.0, // Normalize to [0, 1]
563            selection_reason: SelectionReason::ScoreBased(best_score),
564            alternatives: candidates
565                .iter()
566                .filter(|k| k.id != selected_kernel.id)
567                .cloned()
568                .collect(),
569        })
570    }
571
572    /// Machine learning-based kernel selection
573    fn ml_based_selection(
574        &self,
575        candidates: &[KernelImplementation],
576        inputs: &KernelInputs,
577        workload: &WorkloadCharacteristics,
578        system_state: &SystemState,
579        _config: &MLBasedConfig,
580    ) -> Result<KernelSelection> {
581        // In a real implementation, this would use a trained ML model
582        // For now, fall back to score-based selection
583        let score_config = ScoreBasedConfig::default();
584        self.score_based_selection(candidates, inputs, workload, system_state, &score_config)
585    }
586
587    /// Hybrid kernel selection
588    fn hybrid_selection(
589        &self,
590        candidates: &[KernelImplementation],
591        inputs: &KernelInputs,
592        workload: &WorkloadCharacteristics,
593        system_state: &SystemState,
594        config: &HybridConfig,
595    ) -> Result<KernelSelection> {
596        // Use ML if confidence is above threshold, otherwise use score-based
597        let ml_confidence = self.get_ml_confidence(inputs, workload, system_state)?;
598
599        if ml_confidence > config.ml_threshold {
600            self.ml_based_selection(candidates, inputs, workload, system_state, &config.ml_based)
601        } else {
602            self.score_based_selection(
603                candidates,
604                inputs,
605                workload,
606                system_state,
607                &config.score_based,
608            )
609        }
610    }
611
612    /// Calculate kernel score for score-based selection
613    fn calculate_kernel_score(
614        &self,
615        kernel: &KernelImplementation,
616        inputs: &KernelInputs,
617        workload: &WorkloadCharacteristics,
618        system_state: &SystemState,
619        config: &ScoreBasedConfig,
620    ) -> Result<f64> {
621        let mut score = 0.0;
622
623        // Execution time score
624        let predicted_time = self.predict_execution_time(kernel, inputs, workload, system_state)?;
625        let time_score = 1.0 / (1.0 + predicted_time.as_secs_f64());
626        score += config.execution_time_weight * time_score;
627
628        // Memory usage score
629        let memory_requirements = kernel.implementation.get_resource_requirements(inputs);
630        let memory_score = 1.0 / (1.0 + memory_requirements.memory as f64 / 1024.0 / 1024.0);
631        score += config.memory_usage_weight * memory_score;
632
633        // Cache efficiency score
634        let cache_score = kernel.characteristics.cache_efficiency;
635        score += config.cache_efficiency_weight * cache_score;
636
637        // Historical performance score
638        let historical_score = self.get_historical_performance_score(&kernel.id)?;
639        score += config.historical_weight * historical_score;
640
641        // Switching penalty (if currently using a different kernel)
642        if let Some(current_kernel) = self.get_current_kernel(workload)? {
643            if current_kernel != kernel.id {
644                score -= config.switching_penalty;
645            }
646        }
647
648        Ok(score)
649    }
650
651    /// Predict execution time for a kernel
652    fn predict_execution_time(
653        &self,
654        kernel: &KernelImplementation,
655        inputs: &KernelInputs,
656        workload: &WorkloadCharacteristics,
657        system_state: &SystemState,
658    ) -> Result<Duration> {
659        // Extract device ID from inputs
660        let device_id = inputs.device.id();
661
662        // Use the performance modeler to predict execution time
663        let _measurement = PerformanceMeasurement {
664            id: generate_measurement_id(),
665            timestamp: std::time::SystemTime::now(),
666            backend_type: kernel.backend_type,
667            device_id,
668            workload: workload.clone(),
669            parameters: TuningParameters::default(),
670            system_state: system_state.clone(),
671            actual_performance: ActualPerformance::default(),
672            predicted_performance: None,
673            prediction_accuracy: None,
674            environment: crate::performance_modeling::EnvironmentalFactors::default(),
675        };
676
677        // Get prediction from performance modeler
678        let default_params = TuningParameters::default();
679        let default_env = EnvironmentalFactors::default();
680        let prediction = self.performance_modeler.predict_performance(
681            kernel.backend_type,
682            workload,
683            &default_params,
684            system_state,
685            &default_env,
686        )?;
687
688        Ok(prediction.execution_time)
689    }
690
691    /// Get historical performance score for a kernel
692    fn get_historical_performance_score(&self, kernel_id: &str) -> Result<f64> {
693        let tracker = self
694            .performance_tracker
695            .lock()
696            .expect("lock should not be poisoned");
697
698        if let Some(stats) = tracker.usage_stats.get(kernel_id) {
699            let success_rate = stats.successful_executions as f64 / stats.total_executions as f64;
700            let recency_factor = self.calculate_recency_factor(stats.last_used);
701            Ok(success_rate * recency_factor)
702        } else {
703            Ok(0.5) // Neutral score for new kernels
704        }
705    }
706
707    /// Calculate recency factor for historical performance
708    fn calculate_recency_factor(&self, last_used: std::time::SystemTime) -> f64 {
709        let now = std::time::SystemTime::now();
710        let elapsed = now
711            .duration_since(last_used)
712            .unwrap_or(Duration::from_secs(0));
713        let days_elapsed = elapsed.as_secs() as f64 / (24.0 * 3600.0);
714
715        // Exponential decay with half-life of 7 days
716        (-days_elapsed / 7.0).exp()
717    }
718
719    /// Get current kernel for workload (if any)
720    fn get_current_kernel(&self, _workload: &WorkloadCharacteristics) -> Result<Option<String>> {
721        // In a real implementation, this would track the currently selected kernel
722        // For now, return None
723        Ok(None)
724    }
725
726    /// Get ML confidence for hybrid selection
727    fn get_ml_confidence(
728        &self,
729        _inputs: &KernelInputs,
730        _workload: &WorkloadCharacteristics,
731        _system_state: &SystemState,
732    ) -> Result<f64> {
733        // Placeholder implementation
734        Ok(0.5)
735    }
736
737    /// Track kernel selection for learning
738    fn track_selection(
739        &self,
740        selection: &KernelSelection,
741        workload: &WorkloadCharacteristics,
742        system_state: &SystemState,
743    ) -> Result<()> {
744        let mut tracker = self
745            .performance_tracker
746            .lock()
747            .expect("lock should not be poisoned");
748        tracker.track_selection(selection, workload, system_state)
749    }
750
751    /// Update performance feedback
752    pub fn update_performance_feedback(
753        &self,
754        kernel_id: &str,
755        actual_performance: ActualPerformance,
756        predicted_performance: Option<PerformancePrediction>,
757    ) -> Result<()> {
758        let mut tracker = self
759            .performance_tracker
760            .lock()
761            .expect("lock should not be poisoned");
762        tracker.update_performance_feedback(kernel_id, actual_performance, predicted_performance)
763    }
764
765    /// Get selection statistics
766    pub fn get_selection_statistics(&self) -> Result<SelectionStatistics> {
767        let tracker = self
768            .performance_tracker
769            .lock()
770            .expect("lock should not be poisoned");
771        Ok(tracker.get_statistics())
772    }
773
774    /// Benchmark kernels for calibration
775    pub fn benchmark_kernels(
776        &self,
777        operation_type: OperationType,
778        backend_type: BackendType,
779        test_inputs: &[KernelInputs],
780    ) -> Result<BenchmarkResults> {
781        let registry = self
782            .kernel_registry
783            .read()
784            .expect("lock should not be poisoned");
785        let kernels = registry.get_kernels_for_operation(operation_type, backend_type);
786
787        let mut results = BenchmarkResults::new();
788
789        for kernel in kernels {
790            for inputs in test_inputs {
791                if kernel.implementation.can_handle(inputs) {
792                    let benchmark = self.benchmark_kernel(&kernel, inputs)?;
793                    results.add_result(kernel.id.clone(), benchmark);
794                }
795            }
796        }
797
798        Ok(results)
799    }
800
801    /// Benchmark a single kernel
802    fn benchmark_kernel(
803        &self,
804        kernel: &KernelImplementation,
805        inputs: &KernelInputs,
806    ) -> Result<BenchmarkResult> {
807        let iterations = 10;
808        let mut execution_times = Vec::new();
809
810        for _ in 0..iterations {
811            let start = Instant::now();
812            let result = kernel.implementation.execute(inputs)?;
813            let execution_time = start.elapsed();
814
815            if result.success {
816                execution_times.push(execution_time);
817            }
818        }
819
820        if execution_times.is_empty() {
821            return Err(TorshError::BackendError(
822                "All benchmark iterations failed".to_string(),
823            ));
824        }
825
826        let avg_time = execution_times.iter().sum::<Duration>() / execution_times.len() as u32;
827        let min_time = *execution_times
828            .iter()
829            .min()
830            .expect("execution_times should not be empty after check");
831        let max_time = *execution_times
832            .iter()
833            .max()
834            .expect("execution_times should not be empty after check");
835
836        // Calculate standard deviation
837        let variance = execution_times
838            .iter()
839            .map(|t| {
840                let diff = t.as_secs_f64() - avg_time.as_secs_f64();
841                diff * diff
842            })
843            .sum::<f64>()
844            / execution_times.len() as f64;
845        let std_dev = Duration::from_secs_f64(variance.sqrt());
846
847        // Calculate memory bandwidth (bytes per second)
848        // Estimate: assume we read and write the data once each
849        let total_bytes_accessed = (inputs.total_size * 2) as f64; // Read + Write
850        let memory_bandwidth = total_bytes_accessed / avg_time.as_secs_f64();
851
852        // Estimate cache hit rate based on data size and variance
853        // Lower variance often indicates better cache locality
854        // This is a heuristic: cache_hit_rate decreases as data size increases
855        let cache_size_estimate = 32.0 * 1024.0 * 1024.0; // 32 MB L3 cache estimate
856        let data_size_ratio = (inputs.total_size as f64 / cache_size_estimate).min(1.0);
857        let variance_factor = 1.0 - (std_dev.as_secs_f64() / avg_time.as_secs_f64()).min(0.5);
858        let cache_hit_rate = (1.0 - data_size_ratio) * variance_factor;
859
860        Ok(BenchmarkResult {
861            avg_execution_time: avg_time,
862            min_execution_time: min_time,
863            max_execution_time: max_time,
864            std_deviation: std_dev,
865            throughput: 1.0 / avg_time.as_secs_f64(),
866            memory_bandwidth,
867            cache_hit_rate,
868        })
869    }
870}
871
/// Kernel selection result
///
/// Returned by [`AdaptiveKernelSelector::select_kernel`]; carries the
/// winner plus the alternatives that lost, for diagnostics or fallback.
#[derive(Debug, Clone)]
pub struct KernelSelection {
    /// Selected kernel
    pub kernel: KernelImplementation,
    /// Selection confidence (0.0 to 1.0)
    pub confidence: f64,
    /// Reason for selection (which algorithm chose it, and its score)
    pub selection_reason: SelectionReason,
    /// Alternative kernels considered but not chosen
    pub alternatives: Vec<KernelImplementation>,
}
884
/// Reason for kernel selection
#[derive(Debug, Clone)]
pub enum SelectionReason {
    /// Score-based selection; payload is the winning kernel's score
    ScoreBased(f64),
    /// Machine learning prediction; payload is the associated score/confidence
    MachineLearning(f64),
    /// Hybrid selection combining multiple strategies; payload is the combined score
    Hybrid(f64),
    /// Fallback to a default kernel when no better candidate was identified
    Default,
}
897
/// Benchmark results for multiple kernels
#[derive(Debug)]
pub struct BenchmarkResults {
    /// Individual benchmark results keyed by kernel ID
    results: HashMap<String, BenchmarkResult>,
}
904
905impl BenchmarkResults {
906    pub fn new() -> Self {
907        Self {
908            results: HashMap::new(),
909        }
910    }
911
912    pub fn add_result(&mut self, kernel_id: String, result: BenchmarkResult) {
913        self.results.insert(kernel_id, result);
914    }
915
916    pub fn get_result(&self, kernel_id: &str) -> Option<&BenchmarkResult> {
917        self.results.get(kernel_id)
918    }
919
920    pub fn get_best_kernel(&self) -> Option<(&String, &BenchmarkResult)> {
921        self.results
922            .iter()
923            .min_by(|a, b| a.1.avg_execution_time.cmp(&b.1.avg_execution_time))
924    }
925}
926
/// Selection statistics
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct SelectionStatistics {
    /// Total number of kernel selections tracked
    pub total_selections: usize,
    /// Overall fraction of selections judged optimal (0.0 when none tracked)
    pub overall_accuracy: f64,
    /// Selection accuracy broken down by operation type
    pub accuracy_by_operation: HashMap<OperationType, f64>,
    /// Selection accuracy broken down by backend
    pub accuracy_by_backend: HashMap<BackendType, f64>,
    /// Most frequently selected kernels as `(kernel_id, execution_count)`,
    /// at most ten entries, most-executed first
    pub popular_kernels: Vec<(String, usize)>,
}
942
943impl KernelRegistry {
944    pub fn new() -> Self {
945        Self {
946            kernels: HashMap::new(),
947            custom_kernels: HashMap::new(),
948            kernel_characteristics: HashMap::new(),
949            default_kernels: HashMap::new(),
950        }
951    }
952
953    pub fn register_kernel(&mut self, kernel: KernelImplementation) -> Result<()> {
954        let key = (kernel.operation_type, kernel.backend_type);
955        self.kernels
956            .entry(key)
957            .or_insert_with(Vec::new)
958            .push(kernel.clone());
959        self.kernel_characteristics
960            .insert(kernel.id.clone(), kernel.characteristics);
961        Ok(())
962    }
963
964    pub fn register_custom_kernel(
965        &mut self,
966        kernel: Box<dyn CustomKernel + Send + Sync>,
967    ) -> Result<()> {
968        let name = kernel.name().to_string();
969        self.custom_kernels.insert(name, kernel);
970        Ok(())
971    }
972
973    pub fn get_candidates(
974        &self,
975        operation_type: OperationType,
976        backend_type: BackendType,
977        inputs: &KernelInputs,
978    ) -> Result<Vec<KernelImplementation>> {
979        let key = (operation_type, backend_type);
980
981        if let Some(kernels) = self.kernels.get(&key) {
982            let candidates = kernels
983                .iter()
984                .filter(|k| k.implementation.can_handle(inputs))
985                .cloned()
986                .collect();
987            Ok(candidates)
988        } else {
989            Ok(Vec::new())
990        }
991    }
992
993    pub fn get_kernels_for_operation(
994        &self,
995        operation_type: OperationType,
996        backend_type: BackendType,
997    ) -> Vec<KernelImplementation> {
998        let key = (operation_type, backend_type);
999        self.kernels.get(&key).cloned().unwrap_or_default()
1000    }
1001}
1002
1003impl PerformanceTracker {
1004    pub fn new() -> Self {
1005        Self {
1006            performance_history: HashMap::new(),
1007            usage_stats: HashMap::new(),
1008            selection_accuracy: SelectionAccuracyTracker::new(),
1009        }
1010    }
1011
1012    pub fn track_selection(
1013        &mut self,
1014        selection: &KernelSelection,
1015        workload: &WorkloadCharacteristics,
1016        system_state: &SystemState,
1017    ) -> Result<()> {
1018        let kernel_id = &selection.kernel.id;
1019
1020        // Update usage stats
1021        let stats = self
1022            .usage_stats
1023            .entry(kernel_id.clone())
1024            .or_insert_with(KernelUsageStats::default);
1025        stats.total_executions += 1;
1026        stats.last_used = std::time::SystemTime::now();
1027
1028        // Update selection accuracy tracker
1029        self.selection_accuracy
1030            .track_selection(selection, workload, system_state);
1031
1032        Ok(())
1033    }
1034
1035    pub fn update_performance_feedback(
1036        &mut self,
1037        kernel_id: &str,
1038        actual_performance: ActualPerformance,
1039        _predicted_performance: Option<PerformancePrediction>,
1040    ) -> Result<()> {
1041        // Update usage stats
1042        if let Some(stats) = self.usage_stats.get_mut(kernel_id) {
1043            stats.successful_executions += 1;
1044            stats.avg_execution_time = actual_performance.execution_time;
1045        }
1046
1047        Ok(())
1048    }
1049
1050    pub fn get_statistics(&self) -> SelectionStatistics {
1051        SelectionStatistics {
1052            total_selections: self.selection_accuracy.total_selections,
1053            overall_accuracy: self.selection_accuracy.get_overall_accuracy(),
1054            accuracy_by_operation: self.selection_accuracy.accuracy_by_operation.clone(),
1055            accuracy_by_backend: self.selection_accuracy.accuracy_by_backend.clone(),
1056            popular_kernels: self.get_popular_kernels(),
1057        }
1058    }
1059
1060    fn get_popular_kernels(&self) -> Vec<(String, usize)> {
1061        let mut kernels: Vec<_> = self
1062            .usage_stats
1063            .iter()
1064            .map(|(id, stats)| (id.clone(), stats.total_executions))
1065            .collect();
1066        kernels.sort_by(|a, b| b.1.cmp(&a.1));
1067        kernels.into_iter().take(10).collect()
1068    }
1069}
1070
1071impl SelectionAccuracyTracker {
1072    pub fn new() -> Self {
1073        Self {
1074            total_selections: 0,
1075            optimal_selections: 0,
1076            accuracy_by_operation: HashMap::new(),
1077            accuracy_by_backend: HashMap::new(),
1078        }
1079    }
1080
1081    pub fn track_selection(
1082        &mut self,
1083        selection: &KernelSelection,
1084        _workload: &WorkloadCharacteristics,
1085        _system_state: &SystemState,
1086    ) {
1087        self.total_selections += 1;
1088
1089        // In a real implementation, this would determine if the selection was optimal
1090        // For now, assume high-confidence selections are more likely to be optimal
1091        if selection.confidence > 0.8 {
1092            self.optimal_selections += 1;
1093        }
1094    }
1095
1096    pub fn get_overall_accuracy(&self) -> f64 {
1097        if self.total_selections == 0 {
1098            0.0
1099        } else {
1100            self.optimal_selections as f64 / self.total_selections as f64
1101        }
1102    }
1103}
1104
1105// Default implementations
impl Default for ScoreBasedConfig {
    /// Default weights for score-based selection; the five weights sum to 1.0,
    /// with execution time dominating and a small penalty for switching kernels.
    fn default() -> Self {
        Self {
            execution_time_weight: 0.4,
            memory_usage_weight: 0.2,
            cache_efficiency_weight: 0.2,
            historical_weight: 0.15,
            switching_penalty: 0.05,
        }
    }
}
1117
1118impl Default for KernelUsageStats {
1119    fn default() -> Self {
1120        Self {
1121            total_executions: 0,
1122            successful_executions: 0,
1123            avg_execution_time: Duration::from_secs(0),
1124            last_used: std::time::SystemTime::now(),
1125            selection_frequency: 0.0,
1126        }
1127    }
1128}
1129
// NOTE(review): `EnvironmentalFactors` is declared in `performance_modeling`;
// consider moving this impl next to the type so readers find it there
// (same-crate impls are permitted by coherence, so this compiles fine).
impl Default for crate::performance_modeling::EnvironmentalFactors {
    /// A neutral environment: no measured temperature, load, I/O, memory,
    /// or frequency data.
    fn default() -> Self {
        Self {
            ambient_temperature: None,
            system_load: 0.0,
            background_processes: 0,
            network_activity: 0.0,
            storage_io: 0.0,
            available_memory: 0,
            cpu_frequency: None,
            gpu_frequency: None,
        }
    }
}
1144
impl Default for TuningParameters {
    /// Conservative baseline: single-threaded, non-vectorized, no unrolling,
    /// with the backend's default scheduling/allocation/optimization choices.
    fn default() -> Self {
        Self {
            thread_count: 1,
            vector_width: 1,
            // 1024 as a generic starting block size — presumably a reasonable
            // default for most backends; confirm against the tuner's usage.
            block_size: Some(1024),
            tile_size: None,
            unroll_factor: 1,
            scheduling_strategy: crate::performance_tuning::SchedulingStrategy::Static,
            memory_allocation_strategy:
                crate::performance_tuning::MemoryAllocationStrategy::Default,
            optimization_level: crate::performance_tuning::OptimizationLevel::Default,
            backend_specific: HashMap::new(),
        }
    }
}
1161
1162impl Default for ActualPerformance {
1163    fn default() -> Self {
1164        Self {
1165            execution_time: Duration::from_secs(0),
1166            throughput: 0.0,
1167            memory_usage_peak: 0,
1168            power_consumption_avg: 0.0,
1169            cache_hit_ratio: 0.0,
1170            thermal_increase: 0.0,
1171            cpu_utilization: 0.0,
1172        }
1173    }
1174}
1175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the kernel fixture shared by the tests below: a naive CPU
    /// matrix-multiply kernel for sizes 1..=1000 over `F32` data.
    fn create_test_kernel() -> KernelImplementation {
        KernelImplementation {
            id: "test_kernel".to_string(),
            name: "Test Kernel".to_string(),
            operation_type: OperationType::MatrixMultiply,
            backend_type: BackendType::Cpu,
            variant: KernelVariant::Naive,
            characteristics: KernelCharacteristics {
                optimal_size_range: (1, 1000),
                memory_pattern: AccessPattern::Sequential,
                compute_intensity: 1.0,
                parallelization_efficiency: 0.8,
                cache_efficiency: 0.7,
                memory_bandwidth_utilization: 0.6,
                initialization_overhead: Duration::from_millis(1),
                scalability: ScalabilityCharacteristics {
                    size_scaling: ScalingBehavior::Linear,
                    thread_scaling: ScalingBehavior::Linear,
                    memory_hierarchy_scaling: ScalingBehavior::Constant,
                },
            },
            constraints: KernelConstraints {
                min_size: 1,
                max_size: Some(1000),
                supported_dtypes: vec![DataType::F32],
                required_alignment: 4,
                supported_shapes: None,
                required_features: vec![],
            },
            implementation: std::sync::Arc::new(MockKernelExecutor),
        }
    }

    #[test]
    fn test_kernel_registry() {
        let mut registry = KernelRegistry::new();

        // Reuse the shared fixture instead of duplicating the full
        // `KernelImplementation` literal (the previous copy was identical to
        // `create_test_kernel()`).
        assert!(registry.register_kernel(create_test_kernel()).is_ok());

        let inputs = KernelInputs {
            input_shapes: vec![vec![10, 10]],
            data_types: vec![DataType::F32],
            total_size: 400,
            operation_params: HashMap::new(),
            device: Device::cpu().unwrap(),
        };

        let candidates = registry
            .get_candidates(OperationType::MatrixMultiply, BackendType::Cpu, &inputs)
            .unwrap();
        assert_eq!(candidates.len(), 1);
    }

    #[test]
    fn test_performance_tracker() {
        let mut tracker = PerformanceTracker::new();

        let selection = KernelSelection {
            kernel: create_test_kernel(),
            confidence: 0.9,
            selection_reason: SelectionReason::ScoreBased(0.8),
            alternatives: vec![],
        };

        let workload = WorkloadCharacteristics {
            operation_type: OperationType::MatrixMultiply,
            data_size: 1000,
            data_shape: vec![10, 10],
            data_type: DataType::F32,
            access_pattern: AccessPattern::Sequential,
            compute_intensity: 1.0,
            memory_bandwidth_requirement: 0.5,
            parallelization_potential: 0.8,
            cache_locality: 0.7,
            branch_predictability: 0.9,
            vectorization_potential: 0.8,
        };

        let system_state = SystemState {
            cpu_utilization: 0.5,
            memory_utilization: 0.6,
            thermal_state: crate::performance_tuning::ThermalState {
                cpu_temperature: 65.0,
                gpu_temperature: Some(70.0),
                thermal_throttling_active: false,
                cooling_efficiency: 0.85,
            },
            power_state: crate::performance_tuning::PowerState {
                power_limit: Some(100.0),
                current_power_draw: 75.0,
                battery_level: Some(0.8),
                power_efficiency_mode: crate::performance_tuning::PowerEfficiencyMode::Balanced,
            },
            concurrent_workloads: 2,
            available_memory_bandwidth: 0.7,
            cache_pressure: 0.4,
            numa_topology: crate::performance_tuning::NumaTopologyState {
                node_count: 1,
                current_node: 0,
                memory_distribution: vec![0.6],
                cross_node_traffic: 0.0,
            },
        };

        assert!(tracker
            .track_selection(&selection, &workload, &system_state)
            .is_ok());

        let stats = tracker.get_statistics();
        assert_eq!(stats.total_selections, 1);
    }

    /// Minimal executor used as the `implementation` for the test kernel:
    /// echoes the input shapes and reports a fixed 10 ms execution time.
    #[derive(Debug)]
    struct MockKernelExecutor;

    impl KernelExecutor for MockKernelExecutor {
        fn execute(&self, inputs: &KernelInputs) -> Result<KernelOutputs> {
            Ok(KernelOutputs {
                output_shapes: inputs.input_shapes.clone(),
                execution_time: Duration::from_millis(10),
                memory_usage: inputs.total_size,
                success: true,
                error_message: None,
            })
        }

        fn estimate_execution_time(&self, _inputs: &KernelInputs) -> Duration {
            Duration::from_millis(10)
        }

        fn can_handle(&self, _inputs: &KernelInputs) -> bool {
            true
        }

        fn get_resource_requirements(&self, inputs: &KernelInputs) -> ResourceRequirements {
            ResourceRequirements {
                memory: inputs.total_size,
                compute_units: 1,
                bandwidth: 1000,
                temporary_storage: 0,
            }
        }
    }
}