1use crate::error::BackendResult as Result;
8use crate::performance_modeling::{
9 EnvironmentalFactors, PerformanceMeasurement, RuntimePerformanceModeler,
10};
11use crate::performance_tuning::{
12 AccessPattern, ActualPerformance, DataType, OperationType, PerformancePrediction, SystemState,
13 TuningParameters, WorkloadCharacteristics,
14};
15use crate::{BackendType, Device};
16use std::collections::HashMap;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, Mutex, RwLock};
19use std::time::{Duration, Instant};
20use torsh_core::error::TorshError;
21
22#[cfg(feature = "serialize")]
23use serde::{Deserialize, Serialize};
24
25#[cfg(not(feature = "std"))]
26use alloc::{boxed::Box, format, string::String, vec::Vec};
27
/// Hands out process-wide, strictly increasing measurement identifiers.
///
/// The first identifier returned is `1`; every call returns a value no other
/// call (on any thread) has received.
fn generate_measurement_id() -> u64 {
    // Function-local so that nothing else in the module can touch the counter.
    static MEASUREMENT_ID_COUNTER: AtomicU64 = AtomicU64::new(1);
    MEASUREMENT_ID_COUNTER.fetch_add(1, Ordering::SeqCst)
}
35
/// Chooses a kernel implementation for an operation at runtime.
///
/// Combines a registry of candidate kernels, a performance model used to
/// predict execution time, and a tracker of past selections and feedback.
pub struct AdaptiveKernelSelector {
    /// Registered kernels, shared behind a reader/writer lock.
    kernel_registry: Arc<RwLock<KernelRegistry>>,
    /// Model queried to predict the execution time of a candidate.
    performance_modeler: Arc<RuntimePerformanceModeler>,
    /// Strategy used to rank candidates.
    selection_algorithm: SelectionAlgorithm,
    /// Usage statistics and selection-accuracy bookkeeping.
    performance_tracker: Arc<Mutex<PerformanceTracker>>,
    /// Tunable selector behaviour (e.g. whether selections are tracked).
    config: AdaptiveSelectionConfig,
}
49
/// Storage for every kernel known to an [`AdaptiveKernelSelector`].
#[derive(Debug)]
pub struct KernelRegistry {
    /// Kernels added via `register_kernel`, grouped by (operation, backend).
    kernels: HashMap<(OperationType, BackendType), Vec<KernelImplementation>>,
    /// User-supplied kernels keyed by [`CustomKernel::name`].
    /// NOTE(review): not consulted by `get_candidates`.
    custom_kernels: HashMap<String, Box<dyn CustomKernel + Send + Sync>>,
    /// Characteristics of each registered kernel, keyed by kernel id.
    kernel_characteristics: HashMap<String, KernelCharacteristics>,
    /// Preferred kernel id per (operation, backend). Never read in this file,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    default_kernels: HashMap<(OperationType, BackendType), String>,
}
63
/// A concrete kernel that can be selected for an operation on a backend.
#[derive(Debug, Clone)]
pub struct KernelImplementation {
    /// Identifier; keys usage statistics and distinguishes the selected kernel
    /// from its alternatives.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Operation this kernel implements.
    pub operation_type: OperationType,
    /// Backend this kernel runs on.
    pub backend_type: BackendType,
    /// Implementation strategy label.
    pub variant: KernelVariant,
    /// Static performance characteristics consulted during scoring.
    pub characteristics: KernelCharacteristics,
    /// Declared input constraints.
    /// NOTE(review): candidate filtering only calls `can_handle`; these
    /// constraints are not checked in this file.
    pub constraints: KernelConstraints,
    /// Executor that runs the kernel; behind an `Arc` so clones share it.
    pub implementation: Arc<dyn KernelExecutor + Send + Sync>,
}
84
/// Label describing how a kernel is implemented.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum KernelVariant {
    /// Unoptimized implementation.
    Naive,
    /// General optimized implementation.
    Optimized,
    /// Tiled implementation.
    Tiled,
    /// Vectorized implementation.
    Vectorized,
    /// Parallel implementation.
    Parallel,
    /// Implementation fused with neighbouring operations.
    Fused,
    /// Implementation for specific hardware, named by the payload.
    HardwareSpecific(String),
    /// Any other variant, named by the payload.
    Custom(String),
}
106
/// Static description of how a kernel is expected to perform.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelCharacteristics {
    /// `(min, max)` input size the kernel is tuned for (units not fixed here — TODO confirm).
    pub optimal_size_range: (usize, usize),
    /// Memory access pattern of the kernel.
    pub memory_pattern: AccessPattern,
    /// Relative compute intensity.
    pub compute_intensity: f64,
    /// Parallel efficiency (presumably a 0..=1 fraction).
    pub parallelization_efficiency: f64,
    /// Cache efficiency; used as-is for the cache term of score-based
    /// selection (presumably a 0..=1 fraction).
    pub cache_efficiency: f64,
    /// Share of memory bandwidth used (presumably a 0..=1 fraction).
    pub memory_bandwidth_utilization: f64,
    /// One-time setup cost before the kernel can run.
    pub initialization_overhead: Duration,
    /// How performance scales with size, threads and memory hierarchy.
    pub scalability: ScalabilityCharacteristics,
}
128
/// Scaling behaviour of a kernel along several dimensions.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ScalabilityCharacteristics {
    /// Scaling with input size.
    pub size_scaling: ScalingBehavior,
    /// Scaling with thread count.
    pub thread_scaling: ScalingBehavior,
    /// Scaling across the memory hierarchy.
    pub memory_hierarchy_scaling: ScalingBehavior,
}
140
/// Shape of a scaling curve.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum ScalingBehavior {
    /// Linear scaling.
    Linear,
    /// Logarithmic scaling.
    Logarithmic,
    /// Exponential scaling.
    Exponential,
    /// No change with the scaled quantity.
    Constant,
    /// Any other behaviour, described by the payload.
    Custom(String),
}
156
/// Declared input limits of a kernel.
///
/// NOTE(review): nothing in this file enforces these constraints; candidate
/// filtering relies solely on [`KernelExecutor::can_handle`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelConstraints {
    /// Smallest supported input size.
    pub min_size: usize,
    /// Largest supported input size (presumably `None` means unbounded).
    pub max_size: Option<usize>,
    /// Element types the kernel accepts.
    pub supported_dtypes: Vec<DataType>,
    /// Required data alignment (presumably in bytes — TODO confirm).
    pub required_alignment: usize,
    /// Accepted input shapes (presumably `None` means any shape).
    pub supported_shapes: Option<Vec<Vec<usize>>>,
    /// Named hardware/software features the kernel needs.
    pub required_features: Vec<String>,
}
174
/// Runs a kernel and describes its cost.
pub trait KernelExecutor: std::fmt::Debug + Send + Sync {
    /// Runs the kernel on `inputs`; benchmarking calls this repeatedly.
    fn execute(&self, inputs: &KernelInputs) -> Result<KernelOutputs>;

    /// Estimated duration of [`execute`](Self::execute) for `inputs`.
    fn estimate_execution_time(&self, inputs: &KernelInputs) -> Duration;

    /// Whether this kernel accepts `inputs`; kernels returning `false` are
    /// excluded from selection and benchmarking.
    fn can_handle(&self, inputs: &KernelInputs) -> bool;

    /// Resources needed for `inputs`; `memory` feeds the memory term of
    /// score-based selection.
    fn get_resource_requirements(&self, inputs: &KernelInputs) -> ResourceRequirements;
}
189
/// Description of the data a kernel is asked to process.
#[derive(Debug, Clone)]
pub struct KernelInputs {
    /// Shape of each input tensor.
    pub input_shapes: Vec<Vec<usize>>,
    /// Element type of each input.
    pub data_types: Vec<DataType>,
    /// Total input size; benchmarking treats it as a byte count when
    /// estimating memory bandwidth and cache fit.
    pub total_size: usize,
    /// Operation-specific parameters, keyed by name.
    pub operation_params: HashMap<String, KernelParameter>,
    /// Device the inputs live on.
    pub device: Device,
}
204
/// Result of one kernel execution.
#[derive(Debug, Clone)]
pub struct KernelOutputs {
    /// Shape of each output.
    pub output_shapes: Vec<Vec<usize>>,
    /// Execution time reported by the kernel itself.
    pub execution_time: Duration,
    /// Memory used by the execution.
    pub memory_usage: usize,
    /// Whether the run succeeded; failed runs are dropped from benchmark statistics.
    pub success: bool,
    /// Error description when `success` is `false`.
    pub error_message: Option<String>,
}
219
/// Typed value stored in [`KernelInputs::operation_params`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum KernelParameter {
    /// Signed integer value.
    Integer(i64),
    /// Floating-point value.
    Float(f64),
    /// Text value.
    String(String),
    /// Boolean flag.
    Boolean(bool),
    /// List of integers.
    IntegerArray(Vec<i64>),
    /// List of floats.
    FloatArray(Vec<f64>),
}
231
/// Resources a kernel needs for a given input.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ResourceRequirements {
    /// Memory needed; scoring divides it by 1024² (presumably bytes → MiB).
    pub memory: usize,
    /// Number of compute units used.
    pub compute_units: usize,
    /// Memory bandwidth needed (units not fixed here — TODO confirm).
    pub bandwidth: usize,
    /// Scratch storage needed.
    pub temporary_storage: usize,
}
245
/// A self-describing, user-supplied kernel.
///
/// NOTE(review): custom kernels are stored by name in the registry but are not
/// returned by `KernelRegistry::get_candidates`, so the selection code in this
/// file never picks them.
pub trait CustomKernel: std::fmt::Debug + Send + Sync {
    /// Registry key; registering another kernel with the same name replaces this one.
    fn name(&self) -> &str;

    /// Operation this kernel implements.
    fn operation_type(&self) -> OperationType;

    /// Backend this kernel runs on.
    fn backend_type(&self) -> BackendType;

    /// Static performance characteristics.
    fn characteristics(&self) -> KernelCharacteristics;

    /// Declared input constraints.
    fn constraints(&self) -> KernelConstraints;

    /// Runs the kernel on `inputs`.
    fn execute(&self, inputs: &KernelInputs) -> Result<KernelOutputs>;

    /// Runs the kernel `iterations` times on `inputs` and summarizes the timings.
    fn benchmark(&self, inputs: &KernelInputs, iterations: usize) -> Result<BenchmarkResult>;
}
269
/// Timing summary for one kernel on one input.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct BenchmarkResult {
    /// Mean wall-clock time of the successful iterations.
    pub avg_execution_time: Duration,
    /// Fastest successful iteration.
    pub min_execution_time: Duration,
    /// Slowest successful iteration.
    pub max_execution_time: Duration,
    /// Population standard deviation of the iteration times.
    pub std_deviation: Duration,
    /// Executions per second (reciprocal of the mean time).
    pub throughput: f64,
    /// Estimated bytes per second, assuming `total_size * 2` bytes are touched.
    pub memory_bandwidth: f64,
    /// Heuristic hit-rate estimate derived from data size and timing jitter;
    /// not a hardware counter reading.
    pub cache_hit_rate: f64,
}
289
/// Strategy used to rank candidate kernels.
#[derive(Debug, Clone)]
pub enum SelectionAlgorithm {
    /// Weighted score over predicted time, memory, cache efficiency and history.
    ScoreBased(ScoreBasedConfig),
    /// Model-driven selection. Currently falls back to score-based selection
    /// with default weights; the config is ignored.
    MachineLearning(MLBasedConfig),
    /// Uses the ML path when its confidence exceeds `ml_threshold`, otherwise
    /// the score-based path.
    Hybrid(HybridConfig),
}
300
/// Weights for score-based kernel selection; a higher score wins.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct ScoreBasedConfig {
    /// Weight of `1 / (1 + predicted_seconds)`.
    pub execution_time_weight: f64,
    /// Weight of `1 / (1 + memory_MiB)`.
    pub memory_usage_weight: f64,
    /// Weight of the kernel's `cache_efficiency`.
    pub cache_efficiency_weight: f64,
    /// Weight of the historical success/recency score.
    pub historical_weight: f64,
    /// Subtracted when the candidate differs from the currently used kernel.
    pub switching_penalty: f64,
}
316
/// Configuration for model-driven selection.
///
/// NOTE(review): not consulted by the selection code in this file.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct MLBasedConfig {
    /// Kind of model to use.
    pub model_type: MLModelType,
    /// Training hyper-parameters.
    pub training_params: MLTrainingParams,
    /// Per-feature weights, keyed by feature name.
    pub feature_weights: HashMap<String, f64>,
}
328
/// Model family for ML-based selection.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub enum MLModelType {
    DecisionTree,
    RandomForest,
    NeuralNetwork,
    SupportVectorMachine,
    LinearRegression,
    /// Any other model, named by the payload.
    Custom(String),
}
340
/// Training hyper-parameters for ML-based selection.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct MLTrainingParams {
    /// Optimizer step size.
    pub learning_rate: f64,
    /// Number of passes over the training data.
    pub epochs: usize,
    /// Samples per training batch.
    pub batch_size: usize,
    /// Regularization strength.
    pub regularization: f64,
}
354
/// Configuration for hybrid selection.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct HybridConfig {
    /// Weights used on the score-based path.
    pub score_based: ScoreBasedConfig,
    /// Configuration used on the ML path.
    pub ml_based: MLBasedConfig,
    /// The ML path is taken only when ML confidence is strictly above this value.
    pub ml_threshold: f64,
}
366
/// Per-kernel usage statistics and selection-accuracy bookkeeping.
#[derive(Debug)]
pub struct PerformanceTracker {
    /// Detailed execution records per kernel id. Never populated in this
    /// file, hence the `dead_code` allowance.
    #[allow(dead_code)]
    performance_history: HashMap<String, Vec<KernelPerformanceRecord>>,
    /// Aggregated usage counters keyed by kernel id.
    usage_stats: HashMap<String, KernelUsageStats>,
    /// Counts of total vs. "optimal" selections.
    selection_accuracy: SelectionAccuracyTracker,
}
378
/// One recorded kernel execution together with its context.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelPerformanceRecord {
    /// When the execution was recorded.
    pub timestamp: std::time::SystemTime,
    /// Workload the kernel ran on.
    pub input_characteristics: WorkloadCharacteristics,
    /// System state at the time.
    pub system_state: SystemState,
    /// Measured performance.
    pub actual_performance: ActualPerformance,
    /// Prediction made before execution, if any.
    pub predicted_performance: Option<PerformancePrediction>,
    /// Confidence of the selection that chose this kernel.
    pub selection_confidence: f64,
}
396
/// Aggregated usage counters for one kernel.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct KernelUsageStats {
    /// Incremented each time the kernel is selected.
    pub total_executions: usize,
    /// Incremented each time performance feedback is reported for the kernel.
    pub successful_executions: usize,
    /// Execution time derived from performance feedback.
    pub avg_execution_time: Duration,
    /// Wall-clock time of the most recent selection.
    pub last_used: std::time::SystemTime,
    /// Relative selection frequency (not updated in this file).
    pub selection_frequency: f64,
}
412
/// Counts how many selections were considered optimal.
#[derive(Debug)]
pub struct SelectionAccuracyTracker {
    /// Number of selections recorded.
    total_selections: usize,
    /// Number of selections whose confidence passed the optimality threshold.
    optimal_selections: usize,
    /// Accuracy per operation type (not populated in this file).
    accuracy_by_operation: HashMap<OperationType, f64>,
    /// Accuracy per backend (not populated in this file).
    accuracy_by_backend: HashMap<BackendType, f64>,
}
425
/// Tunable behaviour of an [`AdaptiveKernelSelector`].
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct AdaptiveSelectionConfig {
    /// Record each selection in the performance tracker.
    pub enable_learning: bool,
    /// Share of selections intended for exploring non-best kernels.
    pub exploration_factor: f64,
    /// Confidence below which a selection is considered uncertain.
    pub min_confidence_threshold: f64,
    /// Upper bound on benchmarks run at the same time.
    pub max_concurrent_benchmarks: usize,
    /// Time limit for a single benchmark.
    pub benchmark_timeout: Duration,
    /// How long performance history is kept.
    pub history_retention: Duration,
}

impl Default for AdaptiveSelectionConfig {
    /// Learning on, 10 % exploration, 0.8 confidence threshold, four concurrent
    /// benchmarks with a 30 s timeout, and one week of history.
    fn default() -> Self {
        const SECS_PER_DAY: u64 = 24 * 3600;

        Self {
            exploration_factor: 0.1,
            enable_learning: true,
            min_confidence_threshold: 0.8,
            max_concurrent_benchmarks: 4,
            history_retention: Duration::from_secs(7 * SECS_PER_DAY),
            benchmark_timeout: Duration::from_secs(30),
        }
    }
}
456
457impl AdaptiveKernelSelector {
458 pub fn new(performance_modeler: Arc<RuntimePerformanceModeler>) -> Self {
460 Self {
461 kernel_registry: Arc::new(RwLock::new(KernelRegistry::new())),
462 performance_modeler,
463 selection_algorithm: SelectionAlgorithm::ScoreBased(ScoreBasedConfig::default()),
464 performance_tracker: Arc::new(Mutex::new(PerformanceTracker::new())),
465 config: AdaptiveSelectionConfig::default(),
466 }
467 }
468
469 pub fn register_kernel(&self, kernel: KernelImplementation) -> Result<()> {
471 let mut registry = self
472 .kernel_registry
473 .write()
474 .expect("lock should not be poisoned");
475 registry.register_kernel(kernel)
476 }
477
478 pub fn register_custom_kernel(
480 &self,
481 kernel: Box<dyn CustomKernel + Send + Sync>,
482 ) -> Result<()> {
483 let mut registry = self
484 .kernel_registry
485 .write()
486 .expect("lock should not be poisoned");
487 registry.register_custom_kernel(kernel)
488 }
489
490 pub fn select_kernel(
492 &self,
493 operation_type: OperationType,
494 backend_type: BackendType,
495 inputs: &KernelInputs,
496 workload: &WorkloadCharacteristics,
497 system_state: &SystemState,
498 ) -> Result<KernelSelection> {
499 let registry = self
500 .kernel_registry
501 .read()
502 .expect("lock should not be poisoned");
503
504 let candidates = registry.get_candidates(operation_type, backend_type, inputs)?;
506
507 if candidates.is_empty() {
508 return Err(TorshError::BackendError(format!(
509 "No suitable kernels found for operation {:?} on backend {:?}",
510 operation_type, backend_type
511 )));
512 }
513
514 let selection = match &self.selection_algorithm {
516 SelectionAlgorithm::ScoreBased(config) => {
517 self.score_based_selection(&candidates, inputs, workload, system_state, config)?
518 }
519 SelectionAlgorithm::MachineLearning(config) => {
520 self.ml_based_selection(&candidates, inputs, workload, system_state, config)?
521 }
522 SelectionAlgorithm::Hybrid(config) => {
523 self.hybrid_selection(&candidates, inputs, workload, system_state, config)?
524 }
525 };
526
527 if self.config.enable_learning {
529 self.track_selection(&selection, workload, system_state)?;
530 }
531
532 Ok(selection)
533 }
534
535 fn score_based_selection(
537 &self,
538 candidates: &[KernelImplementation],
539 inputs: &KernelInputs,
540 workload: &WorkloadCharacteristics,
541 system_state: &SystemState,
542 config: &ScoreBasedConfig,
543 ) -> Result<KernelSelection> {
544 let mut best_kernel = None;
545 let mut best_score = f64::NEG_INFINITY;
546
547 for kernel in candidates {
548 let score =
549 self.calculate_kernel_score(kernel, inputs, workload, system_state, config)?;
550
551 if score > best_score {
552 best_score = score;
553 best_kernel = Some(kernel);
554 }
555 }
556
557 let selected_kernel = best_kernel
558 .ok_or_else(|| TorshError::BackendError("No suitable kernel found".to_string()))?;
559
560 Ok(KernelSelection {
561 kernel: selected_kernel.clone(),
562 confidence: (best_score + 1.0) / 2.0, selection_reason: SelectionReason::ScoreBased(best_score),
564 alternatives: candidates
565 .iter()
566 .filter(|k| k.id != selected_kernel.id)
567 .cloned()
568 .collect(),
569 })
570 }
571
572 fn ml_based_selection(
574 &self,
575 candidates: &[KernelImplementation],
576 inputs: &KernelInputs,
577 workload: &WorkloadCharacteristics,
578 system_state: &SystemState,
579 _config: &MLBasedConfig,
580 ) -> Result<KernelSelection> {
581 let score_config = ScoreBasedConfig::default();
584 self.score_based_selection(candidates, inputs, workload, system_state, &score_config)
585 }
586
587 fn hybrid_selection(
589 &self,
590 candidates: &[KernelImplementation],
591 inputs: &KernelInputs,
592 workload: &WorkloadCharacteristics,
593 system_state: &SystemState,
594 config: &HybridConfig,
595 ) -> Result<KernelSelection> {
596 let ml_confidence = self.get_ml_confidence(inputs, workload, system_state)?;
598
599 if ml_confidence > config.ml_threshold {
600 self.ml_based_selection(candidates, inputs, workload, system_state, &config.ml_based)
601 } else {
602 self.score_based_selection(
603 candidates,
604 inputs,
605 workload,
606 system_state,
607 &config.score_based,
608 )
609 }
610 }
611
612 fn calculate_kernel_score(
614 &self,
615 kernel: &KernelImplementation,
616 inputs: &KernelInputs,
617 workload: &WorkloadCharacteristics,
618 system_state: &SystemState,
619 config: &ScoreBasedConfig,
620 ) -> Result<f64> {
621 let mut score = 0.0;
622
623 let predicted_time = self.predict_execution_time(kernel, inputs, workload, system_state)?;
625 let time_score = 1.0 / (1.0 + predicted_time.as_secs_f64());
626 score += config.execution_time_weight * time_score;
627
628 let memory_requirements = kernel.implementation.get_resource_requirements(inputs);
630 let memory_score = 1.0 / (1.0 + memory_requirements.memory as f64 / 1024.0 / 1024.0);
631 score += config.memory_usage_weight * memory_score;
632
633 let cache_score = kernel.characteristics.cache_efficiency;
635 score += config.cache_efficiency_weight * cache_score;
636
637 let historical_score = self.get_historical_performance_score(&kernel.id)?;
639 score += config.historical_weight * historical_score;
640
641 if let Some(current_kernel) = self.get_current_kernel(workload)? {
643 if current_kernel != kernel.id {
644 score -= config.switching_penalty;
645 }
646 }
647
648 Ok(score)
649 }
650
651 fn predict_execution_time(
653 &self,
654 kernel: &KernelImplementation,
655 inputs: &KernelInputs,
656 workload: &WorkloadCharacteristics,
657 system_state: &SystemState,
658 ) -> Result<Duration> {
659 let device_id = inputs.device.id();
661
662 let _measurement = PerformanceMeasurement {
664 id: generate_measurement_id(),
665 timestamp: std::time::SystemTime::now(),
666 backend_type: kernel.backend_type,
667 device_id,
668 workload: workload.clone(),
669 parameters: TuningParameters::default(),
670 system_state: system_state.clone(),
671 actual_performance: ActualPerformance::default(),
672 predicted_performance: None,
673 prediction_accuracy: None,
674 environment: crate::performance_modeling::EnvironmentalFactors::default(),
675 };
676
677 let default_params = TuningParameters::default();
679 let default_env = EnvironmentalFactors::default();
680 let prediction = self.performance_modeler.predict_performance(
681 kernel.backend_type,
682 workload,
683 &default_params,
684 system_state,
685 &default_env,
686 )?;
687
688 Ok(prediction.execution_time)
689 }
690
691 fn get_historical_performance_score(&self, kernel_id: &str) -> Result<f64> {
693 let tracker = self
694 .performance_tracker
695 .lock()
696 .expect("lock should not be poisoned");
697
698 if let Some(stats) = tracker.usage_stats.get(kernel_id) {
699 let success_rate = stats.successful_executions as f64 / stats.total_executions as f64;
700 let recency_factor = self.calculate_recency_factor(stats.last_used);
701 Ok(success_rate * recency_factor)
702 } else {
703 Ok(0.5) }
705 }
706
707 fn calculate_recency_factor(&self, last_used: std::time::SystemTime) -> f64 {
709 let now = std::time::SystemTime::now();
710 let elapsed = now
711 .duration_since(last_used)
712 .unwrap_or(Duration::from_secs(0));
713 let days_elapsed = elapsed.as_secs() as f64 / (24.0 * 3600.0);
714
715 (-days_elapsed / 7.0).exp()
717 }
718
719 fn get_current_kernel(&self, _workload: &WorkloadCharacteristics) -> Result<Option<String>> {
721 Ok(None)
724 }
725
726 fn get_ml_confidence(
728 &self,
729 _inputs: &KernelInputs,
730 _workload: &WorkloadCharacteristics,
731 _system_state: &SystemState,
732 ) -> Result<f64> {
733 Ok(0.5)
735 }
736
737 fn track_selection(
739 &self,
740 selection: &KernelSelection,
741 workload: &WorkloadCharacteristics,
742 system_state: &SystemState,
743 ) -> Result<()> {
744 let mut tracker = self
745 .performance_tracker
746 .lock()
747 .expect("lock should not be poisoned");
748 tracker.track_selection(selection, workload, system_state)
749 }
750
751 pub fn update_performance_feedback(
753 &self,
754 kernel_id: &str,
755 actual_performance: ActualPerformance,
756 predicted_performance: Option<PerformancePrediction>,
757 ) -> Result<()> {
758 let mut tracker = self
759 .performance_tracker
760 .lock()
761 .expect("lock should not be poisoned");
762 tracker.update_performance_feedback(kernel_id, actual_performance, predicted_performance)
763 }
764
765 pub fn get_selection_statistics(&self) -> Result<SelectionStatistics> {
767 let tracker = self
768 .performance_tracker
769 .lock()
770 .expect("lock should not be poisoned");
771 Ok(tracker.get_statistics())
772 }
773
774 pub fn benchmark_kernels(
776 &self,
777 operation_type: OperationType,
778 backend_type: BackendType,
779 test_inputs: &[KernelInputs],
780 ) -> Result<BenchmarkResults> {
781 let registry = self
782 .kernel_registry
783 .read()
784 .expect("lock should not be poisoned");
785 let kernels = registry.get_kernels_for_operation(operation_type, backend_type);
786
787 let mut results = BenchmarkResults::new();
788
789 for kernel in kernels {
790 for inputs in test_inputs {
791 if kernel.implementation.can_handle(inputs) {
792 let benchmark = self.benchmark_kernel(&kernel, inputs)?;
793 results.add_result(kernel.id.clone(), benchmark);
794 }
795 }
796 }
797
798 Ok(results)
799 }
800
801 fn benchmark_kernel(
803 &self,
804 kernel: &KernelImplementation,
805 inputs: &KernelInputs,
806 ) -> Result<BenchmarkResult> {
807 let iterations = 10;
808 let mut execution_times = Vec::new();
809
810 for _ in 0..iterations {
811 let start = Instant::now();
812 let result = kernel.implementation.execute(inputs)?;
813 let execution_time = start.elapsed();
814
815 if result.success {
816 execution_times.push(execution_time);
817 }
818 }
819
820 if execution_times.is_empty() {
821 return Err(TorshError::BackendError(
822 "All benchmark iterations failed".to_string(),
823 ));
824 }
825
826 let avg_time = execution_times.iter().sum::<Duration>() / execution_times.len() as u32;
827 let min_time = *execution_times
828 .iter()
829 .min()
830 .expect("execution_times should not be empty after check");
831 let max_time = *execution_times
832 .iter()
833 .max()
834 .expect("execution_times should not be empty after check");
835
836 let variance = execution_times
838 .iter()
839 .map(|t| {
840 let diff = t.as_secs_f64() - avg_time.as_secs_f64();
841 diff * diff
842 })
843 .sum::<f64>()
844 / execution_times.len() as f64;
845 let std_dev = Duration::from_secs_f64(variance.sqrt());
846
847 let total_bytes_accessed = (inputs.total_size * 2) as f64; let memory_bandwidth = total_bytes_accessed / avg_time.as_secs_f64();
851
852 let cache_size_estimate = 32.0 * 1024.0 * 1024.0; let data_size_ratio = (inputs.total_size as f64 / cache_size_estimate).min(1.0);
857 let variance_factor = 1.0 - (std_dev.as_secs_f64() / avg_time.as_secs_f64()).min(0.5);
858 let cache_hit_rate = (1.0 - data_size_ratio) * variance_factor;
859
860 Ok(BenchmarkResult {
861 avg_execution_time: avg_time,
862 min_execution_time: min_time,
863 max_execution_time: max_time,
864 std_deviation: std_dev,
865 throughput: 1.0 / avg_time.as_secs_f64(),
866 memory_bandwidth,
867 cache_hit_rate,
868 })
869 }
870}
871
/// Outcome of a kernel selection.
#[derive(Debug, Clone)]
pub struct KernelSelection {
    /// The chosen kernel.
    pub kernel: KernelImplementation,
    /// Confidence in the choice, derived from the winning score.
    pub confidence: f64,
    /// Why this kernel was chosen.
    pub selection_reason: SelectionReason,
    /// Other candidates (every candidate whose id differs from the chosen one).
    pub alternatives: Vec<KernelImplementation>,
}
884
/// Why a kernel was selected. Only `ScoreBased` is constructed in this file.
#[derive(Debug, Clone)]
pub enum SelectionReason {
    /// Won score-based ranking with the given raw score.
    ScoreBased(f64),
    /// Chosen by the ML path with the given value.
    MachineLearning(f64),
    /// Chosen by the hybrid path with the given value.
    Hybrid(f64),
    /// Fallback/default choice.
    Default,
}
897
/// Benchmark results keyed by kernel id (one entry per kernel).
#[derive(Debug)]
pub struct BenchmarkResults {
    /// Latest result stored for each kernel id.
    results: HashMap<String, BenchmarkResult>,
}
904
905impl BenchmarkResults {
906 pub fn new() -> Self {
907 Self {
908 results: HashMap::new(),
909 }
910 }
911
912 pub fn add_result(&mut self, kernel_id: String, result: BenchmarkResult) {
913 self.results.insert(kernel_id, result);
914 }
915
916 pub fn get_result(&self, kernel_id: &str) -> Option<&BenchmarkResult> {
917 self.results.get(kernel_id)
918 }
919
920 pub fn get_best_kernel(&self) -> Option<(&String, &BenchmarkResult)> {
921 self.results
922 .iter()
923 .min_by(|a, b| a.1.avg_execution_time.cmp(&b.1.avg_execution_time))
924 }
925}
926
/// Snapshot of selection statistics.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct SelectionStatistics {
    /// Number of selections recorded.
    pub total_selections: usize,
    /// Fraction of selections counted as optimal (0.0 when none recorded).
    pub overall_accuracy: f64,
    /// Accuracy per operation type.
    pub accuracy_by_operation: HashMap<OperationType, f64>,
    /// Accuracy per backend.
    pub accuracy_by_backend: HashMap<BackendType, f64>,
    /// Up to 10 `(kernel id, selection count)` pairs, most-selected first.
    pub popular_kernels: Vec<(String, usize)>,
}
942
943impl KernelRegistry {
944 pub fn new() -> Self {
945 Self {
946 kernels: HashMap::new(),
947 custom_kernels: HashMap::new(),
948 kernel_characteristics: HashMap::new(),
949 default_kernels: HashMap::new(),
950 }
951 }
952
953 pub fn register_kernel(&mut self, kernel: KernelImplementation) -> Result<()> {
954 let key = (kernel.operation_type, kernel.backend_type);
955 self.kernels
956 .entry(key)
957 .or_insert_with(Vec::new)
958 .push(kernel.clone());
959 self.kernel_characteristics
960 .insert(kernel.id.clone(), kernel.characteristics);
961 Ok(())
962 }
963
964 pub fn register_custom_kernel(
965 &mut self,
966 kernel: Box<dyn CustomKernel + Send + Sync>,
967 ) -> Result<()> {
968 let name = kernel.name().to_string();
969 self.custom_kernels.insert(name, kernel);
970 Ok(())
971 }
972
973 pub fn get_candidates(
974 &self,
975 operation_type: OperationType,
976 backend_type: BackendType,
977 inputs: &KernelInputs,
978 ) -> Result<Vec<KernelImplementation>> {
979 let key = (operation_type, backend_type);
980
981 if let Some(kernels) = self.kernels.get(&key) {
982 let candidates = kernels
983 .iter()
984 .filter(|k| k.implementation.can_handle(inputs))
985 .cloned()
986 .collect();
987 Ok(candidates)
988 } else {
989 Ok(Vec::new())
990 }
991 }
992
993 pub fn get_kernels_for_operation(
994 &self,
995 operation_type: OperationType,
996 backend_type: BackendType,
997 ) -> Vec<KernelImplementation> {
998 let key = (operation_type, backend_type);
999 self.kernels.get(&key).cloned().unwrap_or_default()
1000 }
1001}
1002
1003impl PerformanceTracker {
1004 pub fn new() -> Self {
1005 Self {
1006 performance_history: HashMap::new(),
1007 usage_stats: HashMap::new(),
1008 selection_accuracy: SelectionAccuracyTracker::new(),
1009 }
1010 }
1011
1012 pub fn track_selection(
1013 &mut self,
1014 selection: &KernelSelection,
1015 workload: &WorkloadCharacteristics,
1016 system_state: &SystemState,
1017 ) -> Result<()> {
1018 let kernel_id = &selection.kernel.id;
1019
1020 let stats = self
1022 .usage_stats
1023 .entry(kernel_id.clone())
1024 .or_insert_with(KernelUsageStats::default);
1025 stats.total_executions += 1;
1026 stats.last_used = std::time::SystemTime::now();
1027
1028 self.selection_accuracy
1030 .track_selection(selection, workload, system_state);
1031
1032 Ok(())
1033 }
1034
1035 pub fn update_performance_feedback(
1036 &mut self,
1037 kernel_id: &str,
1038 actual_performance: ActualPerformance,
1039 _predicted_performance: Option<PerformancePrediction>,
1040 ) -> Result<()> {
1041 if let Some(stats) = self.usage_stats.get_mut(kernel_id) {
1043 stats.successful_executions += 1;
1044 stats.avg_execution_time = actual_performance.execution_time;
1045 }
1046
1047 Ok(())
1048 }
1049
1050 pub fn get_statistics(&self) -> SelectionStatistics {
1051 SelectionStatistics {
1052 total_selections: self.selection_accuracy.total_selections,
1053 overall_accuracy: self.selection_accuracy.get_overall_accuracy(),
1054 accuracy_by_operation: self.selection_accuracy.accuracy_by_operation.clone(),
1055 accuracy_by_backend: self.selection_accuracy.accuracy_by_backend.clone(),
1056 popular_kernels: self.get_popular_kernels(),
1057 }
1058 }
1059
1060 fn get_popular_kernels(&self) -> Vec<(String, usize)> {
1061 let mut kernels: Vec<_> = self
1062 .usage_stats
1063 .iter()
1064 .map(|(id, stats)| (id.clone(), stats.total_executions))
1065 .collect();
1066 kernels.sort_by(|a, b| b.1.cmp(&a.1));
1067 kernels.into_iter().take(10).collect()
1068 }
1069}
1070
1071impl SelectionAccuracyTracker {
1072 pub fn new() -> Self {
1073 Self {
1074 total_selections: 0,
1075 optimal_selections: 0,
1076 accuracy_by_operation: HashMap::new(),
1077 accuracy_by_backend: HashMap::new(),
1078 }
1079 }
1080
1081 pub fn track_selection(
1082 &mut self,
1083 selection: &KernelSelection,
1084 _workload: &WorkloadCharacteristics,
1085 _system_state: &SystemState,
1086 ) {
1087 self.total_selections += 1;
1088
1089 if selection.confidence > 0.8 {
1092 self.optimal_selections += 1;
1093 }
1094 }
1095
1096 pub fn get_overall_accuracy(&self) -> f64 {
1097 if self.total_selections == 0 {
1098 0.0
1099 } else {
1100 self.optimal_selections as f64 / self.total_selections as f64
1101 }
1102 }
1103}
1104
1105impl Default for ScoreBasedConfig {
1107 fn default() -> Self {
1108 Self {
1109 execution_time_weight: 0.4,
1110 memory_usage_weight: 0.2,
1111 cache_efficiency_weight: 0.2,
1112 historical_weight: 0.15,
1113 switching_penalty: 0.05,
1114 }
1115 }
1116}
1117
1118impl Default for KernelUsageStats {
1119 fn default() -> Self {
1120 Self {
1121 total_executions: 0,
1122 successful_executions: 0,
1123 avg_execution_time: Duration::from_secs(0),
1124 last_used: std::time::SystemTime::now(),
1125 selection_frequency: 0.0,
1126 }
1127 }
1128}
1129
1130impl Default for crate::performance_modeling::EnvironmentalFactors {
1131 fn default() -> Self {
1132 Self {
1133 ambient_temperature: None,
1134 system_load: 0.0,
1135 background_processes: 0,
1136 network_activity: 0.0,
1137 storage_io: 0.0,
1138 available_memory: 0,
1139 cpu_frequency: None,
1140 gpu_frequency: None,
1141 }
1142 }
1143}
1144
1145impl Default for TuningParameters {
1146 fn default() -> Self {
1147 Self {
1148 thread_count: 1,
1149 vector_width: 1,
1150 block_size: Some(1024),
1151 tile_size: None,
1152 unroll_factor: 1,
1153 scheduling_strategy: crate::performance_tuning::SchedulingStrategy::Static,
1154 memory_allocation_strategy:
1155 crate::performance_tuning::MemoryAllocationStrategy::Default,
1156 optimization_level: crate::performance_tuning::OptimizationLevel::Default,
1157 backend_specific: HashMap::new(),
1158 }
1159 }
1160}
1161
1162impl Default for ActualPerformance {
1163 fn default() -> Self {
1164 Self {
1165 execution_time: Duration::from_secs(0),
1166 throughput: 0.0,
1167 memory_usage_peak: 0,
1168 power_consumption_avg: 0.0,
1169 cache_hit_ratio: 0.0,
1170 thermal_increase: 0.0,
1171 cpu_utilization: 0.0,
1172 }
1173 }
1174}
1175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kernel_registry() {
        let mut registry = KernelRegistry::new();

        assert!(registry.register_kernel(create_test_kernel()).is_ok());

        let inputs = create_test_inputs();
        let candidates = registry
            .get_candidates(OperationType::MatrixMultiply, BackendType::Cpu, &inputs)
            .unwrap();
        assert_eq!(candidates.len(), 1);
        assert_eq!(candidates[0].id, "test_kernel");
    }

    #[test]
    fn test_performance_tracker() {
        let mut tracker = PerformanceTracker::new();

        let selection = KernelSelection {
            kernel: create_test_kernel(),
            confidence: 0.9,
            selection_reason: SelectionReason::ScoreBased(0.8),
            alternatives: vec![],
        };
        let workload = create_test_workload();
        let system_state = create_test_system_state();

        assert!(tracker
            .track_selection(&selection, &workload, &system_state)
            .is_ok());

        let stats = tracker.get_statistics();
        assert_eq!(stats.total_selections, 1);
        // Confidence 0.9 is above the 0.8 optimality threshold.
        assert_eq!(stats.overall_accuracy, 1.0);
        assert_eq!(stats.popular_kernels, vec![("test_kernel".to_string(), 1)]);
    }

    #[test]
    fn test_benchmark_results_best_kernel() {
        let mut results = BenchmarkResults::new();
        assert!(results.get_best_kernel().is_none());

        results.add_result(
            "slow".to_string(),
            create_test_benchmark(Duration::from_millis(20)),
        );
        results.add_result(
            "fast".to_string(),
            create_test_benchmark(Duration::from_millis(5)),
        );

        let (best_id, best) = results.get_best_kernel().unwrap();
        assert_eq!(best_id, "fast");
        assert_eq!(best.avg_execution_time, Duration::from_millis(5));
        assert!(results.get_result("slow").is_some());
        assert!(results.get_result("missing").is_none());
    }

    /// Naive CPU matrix-multiply kernel backed by [`MockKernelExecutor`].
    fn create_test_kernel() -> KernelImplementation {
        KernelImplementation {
            id: "test_kernel".to_string(),
            name: "Test Kernel".to_string(),
            operation_type: OperationType::MatrixMultiply,
            backend_type: BackendType::Cpu,
            variant: KernelVariant::Naive,
            characteristics: KernelCharacteristics {
                optimal_size_range: (1, 1000),
                memory_pattern: AccessPattern::Sequential,
                compute_intensity: 1.0,
                parallelization_efficiency: 0.8,
                cache_efficiency: 0.7,
                memory_bandwidth_utilization: 0.6,
                initialization_overhead: Duration::from_millis(1),
                scalability: ScalabilityCharacteristics {
                    size_scaling: ScalingBehavior::Linear,
                    thread_scaling: ScalingBehavior::Linear,
                    memory_hierarchy_scaling: ScalingBehavior::Constant,
                },
            },
            constraints: KernelConstraints {
                min_size: 1,
                max_size: Some(1000),
                supported_dtypes: vec![DataType::F32],
                required_alignment: 4,
                supported_shapes: None,
                required_features: vec![],
            },
            implementation: std::sync::Arc::new(MockKernelExecutor),
        }
    }

    /// A single 10x10 f32 input on the CPU.
    fn create_test_inputs() -> KernelInputs {
        KernelInputs {
            input_shapes: vec![vec![10, 10]],
            data_types: vec![DataType::F32],
            total_size: 400,
            operation_params: HashMap::new(),
            device: Device::cpu().unwrap(),
        }
    }

    /// Matrix-multiply workload matching [`create_test_inputs`].
    fn create_test_workload() -> WorkloadCharacteristics {
        WorkloadCharacteristics {
            operation_type: OperationType::MatrixMultiply,
            data_size: 1000,
            data_shape: vec![10, 10],
            data_type: DataType::F32,
            access_pattern: AccessPattern::Sequential,
            compute_intensity: 1.0,
            memory_bandwidth_requirement: 0.5,
            parallelization_potential: 0.8,
            cache_locality: 0.7,
            branch_predictability: 0.9,
            vectorization_potential: 0.8,
        }
    }

    /// A moderately loaded single-NUMA-node system.
    fn create_test_system_state() -> SystemState {
        SystemState {
            cpu_utilization: 0.5,
            memory_utilization: 0.6,
            thermal_state: crate::performance_tuning::ThermalState {
                cpu_temperature: 65.0,
                gpu_temperature: Some(70.0),
                thermal_throttling_active: false,
                cooling_efficiency: 0.85,
            },
            power_state: crate::performance_tuning::PowerState {
                power_limit: Some(100.0),
                current_power_draw: 75.0,
                battery_level: Some(0.8),
                power_efficiency_mode: crate::performance_tuning::PowerEfficiencyMode::Balanced,
            },
            concurrent_workloads: 2,
            available_memory_bandwidth: 0.7,
            cache_pressure: 0.4,
            numa_topology: crate::performance_tuning::NumaTopologyState {
                node_count: 1,
                current_node: 0,
                memory_distribution: vec![0.6],
                cross_node_traffic: 0.0,
            },
        }
    }

    /// Benchmark result whose timings are all `avg`.
    fn create_test_benchmark(avg: Duration) -> BenchmarkResult {
        BenchmarkResult {
            avg_execution_time: avg,
            min_execution_time: avg,
            max_execution_time: avg,
            std_deviation: Duration::from_secs(0),
            throughput: 1.0 / avg.as_secs_f64(),
            memory_bandwidth: 0.0,
            cache_hit_rate: 0.0,
        }
    }

    /// Executor that accepts every input and reports a fixed 10 ms run.
    #[derive(Debug)]
    struct MockKernelExecutor;

    impl KernelExecutor for MockKernelExecutor {
        fn execute(&self, inputs: &KernelInputs) -> Result<KernelOutputs> {
            Ok(KernelOutputs {
                output_shapes: inputs.input_shapes.clone(),
                execution_time: Duration::from_millis(10),
                memory_usage: inputs.total_size,
                success: true,
                error_message: None,
            })
        }

        fn estimate_execution_time(&self, _inputs: &KernelInputs) -> Duration {
            Duration::from_millis(10)
        }

        fn can_handle(&self, _inputs: &KernelInputs) -> bool {
            true
        }

        fn get_resource_requirements(&self, inputs: &KernelInputs) -> ResourceRequirements {
            ResourceRequirements {
                memory: inputs.total_size,
                compute_units: 1,
                bandwidth: 1000,
                temporary_storage: 0,
            }
        }
    }
}