1use crate::{Result, TensorError};
7use scirs2_core::profiling::Profiler;
8use std::collections::HashMap;
9use std::sync::{Arc, Mutex};
10
/// Registry of hardware-specialized compute kernels, grouped per operation,
/// with strategy-driven selection and a cache of measured performance.
#[allow(dead_code)]
pub struct AdvancedKernelRegistry {
    /// Candidate kernels keyed by operation name (e.g. "matmul"), kept
    /// sorted by expected throughput, highest first.
    kernels: Arc<Mutex<HashMap<String, Vec<SpecializedKernel>>>>,
    /// Shared profiler handle; not referenced by the code visible in this
    /// file — TODO confirm intended use.
    profiler: Arc<Profiler>,
    /// Strategy applied when scoring candidate kernels.
    selection_strategy: KernelOptimizationStrategy,
    /// Measured per-kernel performance, keyed by kernel id.
    performance_cache: Arc<Mutex<HashMap<String, KernelPerformanceData>>>,
}
23
/// A single kernel implementation together with the hardware and data
/// characteristics it is tuned for.
#[derive(Debug, Clone)]
pub struct SpecializedKernel {
    /// Unique id; also the key into the performance cache.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Operation this kernel implements (registry grouping key).
    pub operation: String,
    /// Hardware the kernel needs in order to run well.
    pub hardware_requirements: HardwareRequirements,
    /// Data size/pattern the kernel is optimized for.
    pub optimal_data_profile: DataProfile,
    /// Advertised (expected) performance characteristics.
    pub performance_profile: PerformanceProfile,
    /// The callable backing this kernel.
    pub implementation: KernelImplementation,
    /// Optional post-execution output sanity check.
    pub validator: Option<ValidationFunction>,
}
44
/// Minimum hardware capabilities a kernel expects.
#[derive(Debug, Clone)]
pub struct HardwareRequirements {
    /// Required CPU feature flags (e.g. "avx2").
    pub required_cpu_features: Vec<String>,
    /// Minimum cache sizes per level.
    pub min_cache_sizes: CacheSizeRequirements,
    /// Minimum memory bandwidth — presumably bytes/sec given the 50e9/25e9
    /// defaults below; confirm with callers.
    pub min_memory_bandwidth: f64,
    /// Minimum number of SIMD registers.
    pub min_simd_registers: usize,
    /// Architectures the kernel prefers (e.g. "x86_64", "aarch64").
    pub preferred_architecture: Vec<String>,
}
59
/// Per-level minimum cache sizes, presumably in bytes (the defaults used in
/// this file — 32768, 262144, 8388608 — match 32 KiB / 256 KiB / 8 MiB).
#[derive(Debug, Clone)]
pub struct CacheSizeRequirements {
    pub min_l1_size: usize,
    pub min_l2_size: usize,
    pub min_l3_size: usize,
}
67
/// Characteristics of the data a kernel handles best; the same type also
/// describes a caller's workload when selecting a kernel.
#[derive(Debug, Clone)]
pub struct DataProfile {
    /// Inclusive (min, max) data-size range the kernel targets.
    pub size_range: (usize, usize),
    /// Required buffer alignment, in bytes.
    pub alignment_requirement: usize,
    /// Dominant memory access pattern.
    pub access_pattern: AccessPattern,
    /// Preferred memory layout.
    pub layout_preference: MemoryLayout,
    /// Tolerated sparsity fraction (0.0–1.0); not enforced by code in this
    /// file.
    pub sparsity_tolerance: f64,
}
82
/// Memory access pattern a kernel is optimized for.
///
/// `PartialEq`/`Eq`/`Hash` are derived (backward-compatible addition) so
/// variants can be compared directly rather than via
/// `std::mem::discriminant`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AccessPattern {
    /// Contiguous front-to-back traversal.
    Sequential,
    /// Fixed-stride traversal.
    Strided,
    /// Unpredictable access order.
    Random,
    /// Sequential within blocks (as used by the blocked matmul kernel).
    BlockedSequential,
    /// Cache-oblivious traversal.
    CacheOblivious,
}
92
/// Preferred in-memory data layout.
///
/// `PartialEq`/`Eq`/`Hash` are derived (backward-compatible addition) so
/// variants can be compared directly rather than via
/// `std::mem::discriminant`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryLayout {
    RowMajor,
    ColumnMajor,
    Blocked,
    Tiled,
    Interleaved,
}
102
/// Advertised performance characteristics of a kernel. Throughput-like
/// values are normalized by 1e12 when scored (see `score_kernel`).
#[derive(Debug, Clone)]
pub struct PerformanceProfile {
    /// Expected throughput (ops/sec scale — divided by 1e12 when scored).
    pub expected_throughput: f64,
    /// Expected per-call latency, in seconds.
    pub expected_latency: f64,
    /// Memory-efficiency estimate — presumably a 0.0–1.0 fraction, judging
    /// from the defaults in this file.
    pub memory_efficiency: f64,
    /// Cache-efficiency estimate, returned verbatim by
    /// `estimate_cache_efficiency`.
    pub cache_efficiency: f64,
    /// Energy efficiency (same 1e12 normalization as throughput).
    pub energy_efficiency: f64,
    /// Scalability estimate — not consumed by code in this file.
    pub scalability_factor: f64,
}
119
/// The concrete callable backing a kernel. Note: only `Native` and
/// `Vectorized` are dispatched by `execute_kernel` in this file; the other
/// variants currently return an "Unsupported kernel implementation" error.
#[derive(Debug, Clone)]
pub enum KernelImplementation {
    /// Plain Rust implementation (safe slice signature).
    Native(NativeKernelFn),
    /// Hand-written assembly (unsafe, raw-pointer signature).
    Assembly(AssemblyKernelFn),
    /// Explicitly vectorized / SIMD-friendly implementation.
    Vectorized(VectorizedKernelFn),
    /// GPU-backed implementation (not dispatched here).
    Gpu(GpuKernelFn),
    /// Mixed CPU/GPU implementation (not dispatched here).
    Hybrid(HybridKernelFn),
}
134
/// Safe scalar kernel: (input_a, input_b, output, params) -> Result.
pub type NativeKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Assembly kernel; raw pointers, hence `unsafe` — callers must guarantee
/// pointer validity and lengths implied by `params`.
pub type AssemblyKernelFn =
    unsafe fn(*const f32, *const f32, *mut f32, &KernelParams) -> Result<()>;

/// Vectorized kernel; same safe signature as `NativeKernelFn`.
pub type VectorizedKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// GPU kernel entry point (host-side signature; per its name — dispatch is
/// not wired up in this file).
pub type GpuKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Hybrid CPU/GPU kernel entry point (dispatch not wired up in this file).
pub type HybridKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Post-execution validator: (input_a, input_b, output, params) -> true when
/// the output looks plausible.
pub type ValidationFunction = fn(&[f32], &[f32], &[f32], &KernelParams) -> bool;
153
/// Runtime parameters passed to every kernel invocation.
#[derive(Debug, Clone)]
pub struct KernelParams {
    /// Logical dimensions; the matmul kernel reads `[m, n, k]` from the
    /// front, and `calculate_throughput` uses the product of all entries.
    pub dimensions: Vec<usize>,
    /// Strides per dimension (not consumed by the kernels in this file).
    pub strides: Vec<usize>,
    /// Element type tag, e.g. "f32".
    pub data_type: String,
    /// Named numeric options; the element-wise kernel reads the
    /// "operation" key (0 = add, 1 = multiply).
    pub operation_params: HashMap<String, f64>,
    /// Free-form hints (not consumed by the kernels in this file).
    pub performance_hints: Vec<String>,
}
168
/// Policy used by `score_kernel` to weight a kernel's performance profile.
#[derive(Debug, Clone)]
pub enum KernelOptimizationStrategy {
    /// Favor the highest expected throughput.
    MaxThroughput,
    /// Favor the lowest expected latency.
    MinLatency,
    /// Favor the best energy efficiency.
    EnergyEfficient,
    /// Split weight evenly between throughput and energy efficiency.
    Balanced,
    /// Prefer measured history from the performance cache when available.
    Adaptive,
}
183
/// Accumulated runtime measurements for one kernel id.
#[derive(Debug, Clone)]
pub struct KernelPerformanceData {
    /// Measured throughput. NOTE(review): initialized to 0.0 in
    /// `update_performance_cache` and never written anywhere in this file —
    /// confirm the intended producer.
    pub measured_throughput: f64,
    /// Most recent execution wall time, in seconds.
    pub measured_latency: f64,
    /// Incremental mean of successful executions (0.0–1.0).
    pub success_rate: f64,
    /// Total number of recorded executions.
    pub execution_count: u64,
    /// When this entry was last written.
    pub last_updated: std::time::Instant,
}
198
impl AdvancedKernelRegistry {
    /// Builds a registry configured with `strategy` and pre-populates it
    /// with the built-in default kernels.
    ///
    /// # Panics
    /// Panics if default-kernel registration fails, which for a freshly
    /// constructed registry would require a poisoned mutex.
    pub fn new(strategy: KernelOptimizationStrategy) -> Self {
        let mut registry = Self {
            kernels: Arc::new(Mutex::new(HashMap::new())),
            profiler: Arc::new(Profiler::new()),
            selection_strategy: strategy,
            performance_cache: Arc::new(Mutex::new(HashMap::new())),
        };

        registry
            .register_default_kernels()
            .expect("Failed to register default kernels");

        registry
    }
220
221 pub fn register_kernel(&self, kernel: SpecializedKernel) -> Result<()> {
223 let mut kernels = self.kernels.lock().map_err(|_| {
224 TensorError::compute_error_simple("Failed to lock kernel registry".to_string())
225 })?;
226
227 let operation_kernels = kernels
228 .entry(kernel.operation.clone())
229 .or_insert_with(Vec::new);
230 operation_kernels.push(kernel);
231
232 operation_kernels.sort_by(|a, b| {
234 b.performance_profile
235 .expected_throughput
236 .partial_cmp(&a.performance_profile.expected_throughput)
237 .expect("Throughput values must be valid floating-point numbers")
238 });
239
240 Ok(())
241 }
242
243 pub fn select_optimal_kernel(
245 &self,
246 operation: &str,
247 data_size: usize,
248 data_profile: &DataProfile,
249 ) -> Result<SpecializedKernel> {
250 let kernels = self.kernels.lock().map_err(|_| {
251 TensorError::compute_error_simple("Failed to lock kernel registry".to_string())
252 })?;
253
254 let operation_kernels = kernels.get(operation).ok_or_else(|| {
255 TensorError::compute_error_simple(format!(
256 "No kernels registered for operation: {}",
257 operation
258 ))
259 })?;
260
261 let mut scored_kernels: Vec<(f64, &SpecializedKernel)> = operation_kernels
263 .iter()
264 .map(|kernel| (self.score_kernel(kernel, data_size, data_profile), kernel))
265 .collect();
266
267 scored_kernels.sort_by(|a, b| {
269 b.0.partial_cmp(&a.0)
270 .expect("partial_cmp should not return None for valid values")
271 });
272
273 if let Some((score, kernel)) = scored_kernels.first() {
274 if *score > 0.0 {
275 return Ok((*kernel).clone());
276 }
277 }
278
279 Err(TensorError::compute_error_simple(
280 "No suitable kernel found".to_string(),
281 ))
282 }
283
284 fn score_kernel(
286 &self,
287 kernel: &SpecializedKernel,
288 data_size: usize,
289 data_profile: &DataProfile,
290 ) -> f64 {
291 let mut score = 0.0;
292
293 if data_size >= kernel.optimal_data_profile.size_range.0
295 && data_size <= kernel.optimal_data_profile.size_range.1
296 {
297 score += 0.3;
298 }
299
300 if std::mem::discriminant(&kernel.optimal_data_profile.access_pattern)
302 == std::mem::discriminant(&data_profile.access_pattern)
303 {
304 score += 0.2;
305 }
306
307 if std::mem::discriminant(&kernel.optimal_data_profile.layout_preference)
309 == std::mem::discriminant(&data_profile.layout_preference)
310 {
311 score += 0.2;
312 }
313
314 match self.selection_strategy {
316 KernelOptimizationStrategy::MaxThroughput => {
317 score += kernel.performance_profile.expected_throughput / 1e12 * 0.3;
318 }
319 KernelOptimizationStrategy::MinLatency => {
320 score += (1.0 / kernel.performance_profile.expected_latency.max(1e-9)) / 1e9 * 0.3;
321 }
322 KernelOptimizationStrategy::EnergyEfficient => {
323 score += kernel.performance_profile.energy_efficiency / 1e12 * 0.3;
324 }
325 KernelOptimizationStrategy::Balanced => {
326 score += (kernel.performance_profile.expected_throughput / 1e12
327 + kernel.performance_profile.energy_efficiency / 1e12)
328 * 0.15;
329 }
330 KernelOptimizationStrategy::Adaptive => {
331 score += self.get_adaptive_score(kernel) * 0.3;
333 }
334 }
335
336 score.clamp(0.0, 1.0)
337 }
338
339 fn get_adaptive_score(&self, kernel: &SpecializedKernel) -> f64 {
341 if let Ok(cache) = self.performance_cache.lock() {
342 if let Some(perf_data) = cache.get(&kernel.id) {
343 return perf_data.measured_throughput / 1e12 * perf_data.success_rate;
344 }
345 }
346
347 kernel.performance_profile.expected_throughput / 1e12
349 }
350
351 pub fn execute_kernel(
353 &self,
354 kernel: &SpecializedKernel,
355 input_a: &[f32],
356 input_b: &[f32],
357 output: &mut [f32],
358 params: &KernelParams,
359 ) -> Result<KernelExecutionResult> {
360 let start_time = std::time::Instant::now();
361
362 let result = match &kernel.implementation {
364 KernelImplementation::Native(kernel_fn) => kernel_fn(input_a, input_b, output, params),
365 KernelImplementation::Vectorized(kernel_fn) => {
366 kernel_fn(input_a, input_b, output, params)
367 }
368 _ => {
369 Err(TensorError::compute_error_simple(
371 "Unsupported kernel implementation".to_string(),
372 ))
373 }
374 };
375
376 let execution_time = start_time.elapsed();
377
378 self.update_performance_cache(&kernel.id, &result, execution_time);
380
381 if let Some(validator) = &kernel.validator {
383 let is_valid = validator(input_a, input_b, output, params);
384 if !is_valid {
385 return Err(TensorError::compute_error_simple(
386 "Kernel validation failed".to_string(),
387 ));
388 }
389 }
390
391 Ok(KernelExecutionResult {
392 success: result.is_ok(),
393 execution_time,
394 throughput: self.calculate_throughput(params, execution_time),
395 energy_estimate: self.estimate_energy_consumption(kernel, execution_time),
396 cache_efficiency: self.estimate_cache_efficiency(kernel, params),
397 })
398 }
399
    /// Folds one execution's outcome into the per-kernel performance cache.
    ///
    /// `success_rate` is maintained as an incremental mean over all
    /// executions; `measured_latency` holds only the most recent timing.
    ///
    /// NOTE(review): `measured_throughput` is initialized to 0.0 here and
    /// never updated anywhere in this file — confirm whether it should be
    /// derived per execution.
    fn update_performance_cache(
        &self,
        kernel_id: &str,
        result: &Result<()>,
        execution_time: std::time::Duration,
    ) {
        // Caching is best-effort: a poisoned mutex silently skips the update.
        if let Ok(mut cache) = self.performance_cache.lock() {
            let entry = cache
                .entry(kernel_id.to_string())
                .or_insert(KernelPerformanceData {
                    measured_throughput: 0.0,
                    measured_latency: 0.0,
                    success_rate: 0.0,
                    execution_count: 0,
                    last_updated: std::time::Instant::now(),
                });

            entry.execution_count += 1;
            entry.measured_latency = execution_time.as_secs_f64();

            // Incremental mean: old_rate * (n-1)/n + outcome/n, where the
            // count was already bumped above.
            if result.is_ok() {
                entry.success_rate = (entry.success_rate * (entry.execution_count - 1) as f64
                    + 1.0)
                    / entry.execution_count as f64;
            } else {
                entry.success_rate = (entry.success_rate * (entry.execution_count - 1) as f64)
                    / entry.execution_count as f64;
            }

            entry.last_updated = std::time::Instant::now();
        }
    }
433
434 fn calculate_throughput(
436 &self,
437 params: &KernelParams,
438 execution_time: std::time::Duration,
439 ) -> f64 {
440 let total_ops = params.dimensions.iter().product::<usize>() as f64;
441 total_ops / execution_time.as_secs_f64()
442 }
443
444 fn estimate_energy_consumption(
446 &self,
447 kernel: &SpecializedKernel,
448 execution_time: std::time::Duration,
449 ) -> f64 {
450 let base_power = 50.0; let efficiency_multiplier = kernel.performance_profile.energy_efficiency / 1e12;
453 base_power * execution_time.as_secs_f64() / efficiency_multiplier
454 }
455
    /// Returns the kernel's static cache-efficiency figure; `_params` is
    /// accepted for future data-dependent modeling but currently unused.
    fn estimate_cache_efficiency(&self, kernel: &SpecializedKernel, _params: &KernelParams) -> f64 {
        kernel.performance_profile.cache_efficiency
    }
460
    /// Registers the built-in kernel set: a blocked matmul kernel and a
    /// cache-friendly element-wise kernel.
    ///
    /// # Errors
    /// Propagates failures from `register_kernel` (poisoned mutex).
    fn register_default_kernels(&mut self) -> Result<()> {
        // Blocked matrix multiplication aimed at AVX2-class x86_64 CPUs.
        self.register_kernel(SpecializedKernel {
            id: "matmul_high_perf".to_string(),
            name: "High-Performance Matrix Multiplication".to_string(),
            operation: "matmul".to_string(),
            hardware_requirements: HardwareRequirements {
                required_cpu_features: vec!["avx2".to_string()],
                min_cache_sizes: CacheSizeRequirements {
                    min_l1_size: 32768,   // 32 KiB
                    min_l2_size: 262144,  // 256 KiB
                    min_l3_size: 8388608, // 8 MiB
                },
                min_memory_bandwidth: 50e9,
                min_simd_registers: 16,
                preferred_architecture: vec!["x86_64".to_string()],
            },
            optimal_data_profile: DataProfile {
                size_range: (1024, usize::MAX),
                alignment_requirement: 64,
                access_pattern: AccessPattern::BlockedSequential,
                layout_preference: MemoryLayout::RowMajor,
                sparsity_tolerance: 0.1,
            },
            performance_profile: PerformanceProfile {
                expected_throughput: 2e12,
                expected_latency: 1e-6,
                memory_efficiency: 0.9,
                cache_efficiency: 0.85,
                energy_efficiency: 1e12,
                scalability_factor: 0.95,
            },
            implementation: KernelImplementation::Vectorized(high_perf_matmul),
            validator: Some(validate_matmul_result),
        })?;

        // Portable element-wise kernel with looser hardware requirements.
        self.register_kernel(SpecializedKernel {
            id: "elementwise_cache_friendly".to_string(),
            name: "Cache-Friendly Element-wise Operations".to_string(),
            operation: "elementwise".to_string(),
            hardware_requirements: HardwareRequirements {
                required_cpu_features: vec![],
                min_cache_sizes: CacheSizeRequirements {
                    min_l1_size: 16384,   // 16 KiB
                    min_l2_size: 131072,  // 128 KiB
                    min_l3_size: 4194304, // 4 MiB
                },
                min_memory_bandwidth: 25e9,
                min_simd_registers: 8,
                preferred_architecture: vec!["x86_64".to_string(), "aarch64".to_string()],
            },
            optimal_data_profile: DataProfile {
                size_range: (64, usize::MAX),
                alignment_requirement: 32,
                access_pattern: AccessPattern::Sequential,
                layout_preference: MemoryLayout::RowMajor,
                sparsity_tolerance: 0.5,
            },
            performance_profile: PerformanceProfile {
                expected_throughput: 4e12,
                expected_latency: 5e-7,
                memory_efficiency: 0.95,
                cache_efficiency: 0.9,
                energy_efficiency: 2e12,
                scalability_factor: 0.98,
            },
            implementation: KernelImplementation::Vectorized(cache_friendly_elementwise),
            validator: Some(validate_elementwise_result),
        })?;

        Ok(())
    }
535
536 pub fn get_registry_statistics(&self) -> Result<KernelRegistryStatistics> {
538 let kernels = self.kernels.lock().map_err(|_| {
539 TensorError::compute_error_simple("Failed to lock kernel registry".to_string())
540 })?;
541
542 let cache = self.performance_cache.lock().map_err(|_| {
543 TensorError::compute_error_simple("Failed to lock performance cache".to_string())
544 })?;
545
546 let total_kernels: usize = kernels.values().map(|v| v.len()).sum();
547 let total_operations = kernels.len();
548 let cached_performance_data = cache.len();
549
550 Ok(KernelRegistryStatistics {
551 total_kernels,
552 total_operations,
553 cached_performance_data,
554 selection_strategy: self.selection_strategy.clone(),
555 average_kernel_throughput: self.calculate_average_throughput(&kernels),
556 cache_hit_rate: self.calculate_cache_hit_rate(&cache),
557 })
558 }
559
560 fn calculate_average_throughput(
561 &self,
562 kernels: &HashMap<String, Vec<SpecializedKernel>>,
563 ) -> f64 {
564 let mut total_throughput = 0.0;
565 let mut kernel_count = 0;
566
567 for kernel_list in kernels.values() {
568 for kernel in kernel_list {
569 total_throughput += kernel.performance_profile.expected_throughput;
570 kernel_count += 1;
571 }
572 }
573
574 if kernel_count > 0 {
575 total_throughput / kernel_count as f64
576 } else {
577 0.0
578 }
579 }
580
581 fn calculate_cache_hit_rate(&self, cache: &HashMap<String, KernelPerformanceData>) -> f64 {
582 let total_executions: u64 = cache.values().map(|data| data.execution_count).sum();
583 let successful_executions: f64 = cache
584 .values()
585 .map(|data| data.execution_count as f64 * data.success_rate)
586 .sum();
587
588 if total_executions > 0 {
589 successful_executions / total_executions as f64
590 } else {
591 0.0
592 }
593 }
594}
595
/// Outcome and derived metrics of a single `execute_kernel` call.
#[derive(Debug, Clone)]
pub struct KernelExecutionResult {
    /// Whether the kernel function itself returned `Ok`.
    pub success: bool,
    /// Wall-clock duration of the kernel call.
    pub execution_time: std::time::Duration,
    /// Product of `params.dimensions` divided by the execution time.
    pub throughput: f64,
    /// Rough energy estimate (see `estimate_energy_consumption`).
    pub energy_estimate: f64,
    /// Static cache-efficiency figure from the kernel's profile.
    pub cache_efficiency: f64,
}
605
/// Snapshot of registry-wide counters returned by `get_registry_statistics`.
#[derive(Debug, Clone)]
pub struct KernelRegistryStatistics {
    /// Total kernels across all operations.
    pub total_kernels: usize,
    /// Number of distinct operations with at least one kernel.
    pub total_operations: usize,
    /// Number of kernel ids with cached measurements.
    pub cached_performance_data: usize,
    /// Strategy the registry was configured with.
    pub selection_strategy: KernelOptimizationStrategy,
    /// Mean expected throughput over all registered kernels.
    pub average_kernel_throughput: f64,
    /// Success-weighted execution ratio (see `calculate_cache_hit_rate` —
    /// despite the name it is not a cache-hit metric).
    pub cache_hit_rate: f64,
}
616
617fn high_perf_matmul(a: &[f32], b: &[f32], c: &mut [f32], params: &KernelParams) -> Result<()> {
621 let (m, n, k) = if params.dimensions.len() >= 3 {
622 (
623 params.dimensions[0],
624 params.dimensions[1],
625 params.dimensions[2],
626 )
627 } else {
628 return Err(TensorError::compute_error_simple(
629 "Invalid dimensions for matmul".to_string(),
630 ));
631 };
632
633 const BLOCK_SIZE: usize = 64;
635
636 for i in (0..m).step_by(BLOCK_SIZE) {
637 for j in (0..n).step_by(BLOCK_SIZE) {
638 for l in (0..k).step_by(BLOCK_SIZE) {
639 let i_end = (i + BLOCK_SIZE).min(m);
640 let j_end = (j + BLOCK_SIZE).min(n);
641 let l_end = (l + BLOCK_SIZE).min(k);
642
643 for ii in i..i_end {
644 for jj in j..j_end {
645 let mut sum = 0.0;
646 for ll in l..l_end {
647 sum += a[ii * k + ll] * b[ll * n + jj];
648 }
649 c[ii * n + jj] += sum;
650 }
651 }
652 }
653 }
654 }
655
656 Ok(())
657}
658
659fn cache_friendly_elementwise(
661 a: &[f32],
662 b: &[f32],
663 c: &mut [f32],
664 params: &KernelParams,
665) -> Result<()> {
666 let operation = params.operation_params.get("operation").unwrap_or(&0.0) as &f64;
667
668 match *operation as i32 {
669 0 => {
670 for i in 0..a.len() {
672 c[i] = a[i] + b[i];
673 }
674 }
675 1 => {
676 for i in 0..a.len() {
678 c[i] = a[i] * b[i];
679 }
680 }
681 _ => {
682 return Err(TensorError::compute_error_simple(
683 "Unsupported element-wise operation".to_string(),
684 ));
685 }
686 }
687
688 Ok(())
689}
690
691fn validate_matmul_result(a: &[f32], b: &[f32], c: &[f32], _params: &KernelParams) -> bool {
693 let has_nonzero_input = a.iter().any(|&x| x != 0.0) && b.iter().any(|&x| x != 0.0);
695 let has_nonzero_output = c.iter().any(|&x| x != 0.0);
696
697 !has_nonzero_input || has_nonzero_output
698}
699
700fn validate_elementwise_result(a: &[f32], b: &[f32], c: &[f32], _params: &KernelParams) -> bool {
702 a.len() == b.len() && b.len() == c.len()
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    // Fix: two call sites below had `&params` mangled into a `¶` glyph by an
    // encoding pass, which does not compile; restored to `&params`.

    #[test]
    fn test_kernel_registry_creation() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::MaxThroughput);
        let stats = registry
            .get_registry_statistics()
            .expect("test: get_registry_statistics should succeed");

        // The constructor pre-registers the default matmul/elementwise kernels.
        assert!(stats.total_kernels > 0);
        assert!(stats.total_operations > 0);
    }

    #[test]
    fn test_kernel_selection() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::MaxThroughput);

        let data_profile = DataProfile {
            size_range: (1024, usize::MAX),
            alignment_requirement: 64,
            access_pattern: AccessPattern::Sequential,
            layout_preference: MemoryLayout::RowMajor,
            sparsity_tolerance: 0.1,
        };

        let kernel = registry.select_optimal_kernel("matmul", 2048, &data_profile);
        assert!(kernel.is_ok());
    }

    #[test]
    fn test_kernel_execution() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::MaxThroughput);

        let data_profile = DataProfile {
            size_range: (64, usize::MAX),
            alignment_requirement: 32,
            access_pattern: AccessPattern::Sequential,
            layout_preference: MemoryLayout::RowMajor,
            sparsity_tolerance: 0.5,
        };

        let kernel = registry
            .select_optimal_kernel("matmul", 512, &data_profile)
            .expect("test: operation should succeed");

        // 8x8 matrices flattened into 64-element buffers.
        let a = vec![1.0; 64];
        let b = vec![2.0; 64];
        let mut c = vec![0.0; 64];

        let params = KernelParams {
            dimensions: vec![8, 8, 8],
            strides: vec![8, 8, 8],
            data_type: "f32".to_string(),
            operation_params: HashMap::new(),
            performance_hints: vec![],
        };

        let result = registry.execute_kernel(&kernel, &a, &b, &mut c, &params);
        assert!(result.is_ok());

        let execution_result = result.expect("test: operation should succeed");
        assert!(execution_result.success);
        assert!(execution_result.throughput > 0.0);
    }

    #[test]
    fn test_performance_cache_update() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::Adaptive);

        let data_profile = DataProfile {
            size_range: (64, usize::MAX),
            alignment_requirement: 32,
            access_pattern: AccessPattern::Sequential,
            layout_preference: MemoryLayout::RowMajor,
            sparsity_tolerance: 0.5,
        };

        let kernel = registry
            .select_optimal_kernel("elementwise", 256, &data_profile)
            .expect("test: operation should succeed");

        let a = vec![1.0; 16];
        let b = vec![2.0; 16];
        let mut c = vec![0.0; 16];

        let mut params = KernelParams {
            dimensions: vec![16],
            strides: vec![1],
            data_type: "f32".to_string(),
            operation_params: HashMap::new(),
            performance_hints: vec![],
        };
        // Operation code 0 selects element-wise addition.
        params.operation_params.insert("operation".to_string(), 0.0);

        for _ in 0..5 {
            let _ = registry.execute_kernel(&kernel, &a, &b, &mut c, &params);
        }

        let stats = registry
            .get_registry_statistics()
            .expect("test: get_registry_statistics should succeed");
        assert!(stats.cached_performance_data > 0);
    }
}