// tenflowers_core/simd/advanced_kernels.rs

//! Advanced High-Performance Kernel Registry
//!
//! This module provides a comprehensive registry of specialized kernels optimized
//! for different hardware platforms, data sizes, and operation types.

use crate::{Result, TensorError};
use scirs2_core::profiling::Profiler;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Advanced kernel registry for automatic kernel selection.
///
/// Thread-safe: the kernel table and the performance cache are each guarded
/// by their own `Mutex` and shared via `Arc`, so a registry handle can be
/// used from multiple threads.
#[allow(dead_code)]
pub struct AdvancedKernelRegistry {
    /// Registered kernels, keyed by operation name (e.g. "matmul").
    /// Each list is kept sorted by descending expected throughput.
    kernels: Arc<Mutex<HashMap<String, Vec<SpecializedKernel>>>>,
    /// Performance profiler (not consulted by any method in this module yet).
    profiler: Arc<Profiler>,
    /// Strategy applied when scoring candidate kernels during selection.
    selection_strategy: KernelOptimizationStrategy,
    /// Runtime performance measurements, keyed by kernel id.
    performance_cache: Arc<Mutex<HashMap<String, KernelPerformanceData>>>,
}

/// Specialized high-performance kernel.
///
/// Bundles an implementation together with the hardware and data
/// characteristics under which it is expected to perform best, so the
/// registry can score and select it automatically.
#[derive(Debug, Clone)]
pub struct SpecializedKernel {
    /// Unique kernel identifier (used as the key in the performance cache).
    pub id: String,
    /// Human-readable kernel name.
    pub name: String,
    /// Target operation name this kernel implements (e.g. "matmul").
    pub operation: String,
    /// Hardware the kernel requires or prefers.
    pub hardware_requirements: HardwareRequirements,
    /// Data characteristics under which the kernel performs best.
    pub optimal_data_profile: DataProfile,
    /// Expected performance characteristics, used for scoring.
    pub performance_profile: PerformanceProfile,
    /// The actual kernel implementation variant.
    pub implementation: KernelImplementation,
    /// Optional post-execution sanity check on the produced output.
    pub validator: Option<ValidationFunction>,
}

/// Hardware requirements for kernel execution.
#[derive(Debug, Clone)]
pub struct HardwareRequirements {
    /// Required CPU feature flags (e.g. "avx2").
    pub required_cpu_features: Vec<String>,
    /// Minimum cache sizes the kernel was tuned for.
    pub min_cache_sizes: CacheSizeRequirements,
    /// Memory bandwidth requirement — presumably bytes/sec (default kernels
    /// use values like 50e9); confirm the unit with consumers.
    pub min_memory_bandwidth: f64,
    /// Minimum number of SIMD registers the kernel needs.
    pub min_simd_registers: usize,
    /// Preferred target architectures (e.g. "x86_64", "aarch64").
    pub preferred_architecture: Vec<String>,
}

/// Minimum CPU cache sizes required by a kernel.
///
/// Values appear to be in bytes (the default kernels use 32768 / 262144 /
/// 8388608 for L1/L2/L3) — TODO confirm.
#[derive(Debug, Clone)]
pub struct CacheSizeRequirements {
    /// Minimum L1 cache size.
    pub min_l1_size: usize,
    /// Minimum L2 cache size.
    pub min_l2_size: usize,
    /// Minimum L3 cache size.
    pub min_l3_size: usize,
}

/// Data characteristics for optimal kernel performance.
///
/// Used both to describe a kernel's sweet spot
/// (`SpecializedKernel::optimal_data_profile`) and to describe an incoming
/// workload when selecting a kernel.
#[derive(Debug, Clone)]
pub struct DataProfile {
    /// Inclusive (min, max) element-count range the kernel targets.
    pub size_range: (usize, usize),
    /// Optimal data alignment — presumably in bytes (defaults use 32/64);
    /// confirm with consumers.
    pub alignment_requirement: usize,
    /// Expected memory access pattern.
    pub access_pattern: AccessPattern,
    /// Preferred memory layout.
    pub layout_preference: MemoryLayout,
    /// Fraction of zero elements tolerated before performance degrades (0-1).
    pub sparsity_tolerance: f64,
}

/// Memory access pattern types.
#[derive(Debug, Clone, Copy)]
pub enum AccessPattern {
    /// Contiguous, linear traversal.
    Sequential,
    /// Fixed-stride traversal.
    Strided,
    /// Unpredictable access order.
    Random,
    /// Sequential access within blocks/tiles.
    BlockedSequential,
    /// Layout/traversal intended to perform well regardless of cache size.
    CacheOblivious,
}

/// Memory layout preferences.
#[derive(Debug, Clone, Copy)]
pub enum MemoryLayout {
    /// Row-major (C-style) layout.
    RowMajor,
    /// Column-major (Fortran-style) layout.
    ColumnMajor,
    /// Data stored in contiguous blocks.
    Blocked,
    /// Data stored in 2D tiles.
    Tiled,
    /// Elements of multiple streams interleaved.
    Interleaved,
}

/// Kernel performance characteristics.
///
/// Expected (a-priori) values supplied at registration time; scoring in the
/// registry normalizes the throughput/efficiency figures by 1e12.
#[derive(Debug, Clone)]
pub struct PerformanceProfile {
    /// Expected throughput (ops/sec).
    pub expected_throughput: f64,
    /// Expected latency (seconds).
    pub expected_latency: f64,
    /// Memory efficiency (0-1).
    pub memory_efficiency: f64,
    /// Cache efficiency (0-1).
    pub cache_efficiency: f64,
    /// Energy efficiency (ops/joule).
    pub energy_efficiency: f64,
    /// Scalability factor (0-1); how well performance scales with resources.
    pub scalability_factor: f64,
}

/// Kernel implementation variants.
///
/// NOTE: only `Native` and `Vectorized` are currently executed by
/// `AdvancedKernelRegistry::execute_kernel`; the other variants are declared
/// but rejected at execution time.
#[derive(Debug, Clone)]
pub enum KernelImplementation {
    /// Native Rust implementation.
    Native(NativeKernelFn),
    /// Assembly-optimized implementation (unsafe, raw-pointer signature).
    Assembly(AssemblyKernelFn),
    /// SIMD-vectorized implementation.
    Vectorized(VectorizedKernelFn),
    /// GPU-accelerated implementation.
    Gpu(GpuKernelFn),
    /// Hybrid CPU-GPU implementation.
    Hybrid(HybridKernelFn),
}

/// Function type for native Rust kernels: `(a, b, out, params) -> Result`.
pub type NativeKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Function type for assembly-optimized kernels.
///
/// # Safety
/// Implementations receive raw pointers; callers must guarantee the pointers
/// are valid for the extents implied by `KernelParams`.
pub type AssemblyKernelFn =
    unsafe fn(*const f32, *const f32, *mut f32, &KernelParams) -> Result<()>;

/// Function type for SIMD-vectorized kernels (same signature as native).
pub type VectorizedKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Function type for GPU kernels.
pub type GpuKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Function type for hybrid CPU-GPU kernels.
pub type HybridKernelFn = fn(&[f32], &[f32], &mut [f32], &KernelParams) -> Result<()>;

/// Kernel validation function: `(a, b, out, params) -> output is plausible`.
pub type ValidationFunction = fn(&[f32], &[f32], &[f32], &KernelParams) -> bool;

/// Kernel execution parameters.
#[derive(Debug, Clone)]
pub struct KernelParams {
    /// Matrix/tensor dimensions; e.g. `[m, n, k]` for matmul. The product of
    /// all entries is used as the operation count for throughput reporting.
    pub dimensions: Vec<usize>,
    /// Stride information for each dimension.
    pub strides: Vec<usize>,
    /// Data type name (e.g. "f32").
    pub data_type: String,
    /// Operation-specific parameters; e.g. the element-wise kernel reads the
    /// "operation" key (0.0 = add, 1.0 = multiply).
    pub operation_params: HashMap<String, f64>,
    /// Free-form performance hints.
    pub performance_hints: Vec<String>,
}

/// Kernel optimization strategy.
///
/// Controls which performance figure dominates the strategy-dependent term
/// of `AdvancedKernelRegistry::score_kernel`.
#[derive(Debug, Clone)]
pub enum KernelOptimizationStrategy {
    /// Maximize throughput.
    MaxThroughput,
    /// Minimize latency.
    MinLatency,
    /// Optimize for energy efficiency.
    EnergyEfficient,
    /// Balance throughput and energy efficiency.
    Balanced,
    /// Adaptive: prefer measured runtime performance when history exists.
    Adaptive,
}

/// Runtime performance data for kernel selection.
///
/// Accumulated in `AdvancedKernelRegistry::update_performance_cache` after
/// each kernel execution and consulted by the `Adaptive` strategy.
#[derive(Debug, Clone)]
pub struct KernelPerformanceData {
    /// Measured throughput (ops/sec).
    pub measured_throughput: f64,
    /// Measured latency (seconds).
    pub measured_latency: f64,
    /// Fraction of executions that succeeded (0-1).
    pub success_rate: f64,
    /// Total number of recorded executions.
    pub execution_count: u64,
    /// Timestamp of the most recent update.
    pub last_updated: std::time::Instant,
}

199impl AdvancedKernelRegistry {
200    /// Create new advanced kernel registry
201    pub fn new(strategy: KernelOptimizationStrategy) -> Self {
202        let kernels = Arc::new(Mutex::new(HashMap::new()));
203        let profiler = Arc::new(Profiler::new());
204        let performance_cache = Arc::new(Mutex::new(HashMap::new()));
205
206        let mut registry = Self {
207            kernels,
208            profiler,
209            selection_strategy: strategy,
210            performance_cache,
211        };
212
213        // Register default high-performance kernels
214        registry
215            .register_default_kernels()
216            .expect("Failed to register default kernels");
217
218        registry
219    }
220
221    /// Register a new specialized kernel
222    pub fn register_kernel(&self, kernel: SpecializedKernel) -> Result<()> {
223        let mut kernels = self.kernels.lock().map_err(|_| {
224            TensorError::compute_error_simple("Failed to lock kernel registry".to_string())
225        })?;
226
227        let operation_kernels = kernels
228            .entry(kernel.operation.clone())
229            .or_insert_with(Vec::new);
230        operation_kernels.push(kernel);
231
232        // Sort kernels by expected performance
233        operation_kernels.sort_by(|a, b| {
234            b.performance_profile
235                .expected_throughput
236                .partial_cmp(&a.performance_profile.expected_throughput)
237                .expect("Throughput values must be valid floating-point numbers")
238        });
239
240        Ok(())
241    }
242
243    /// Select optimal kernel for given operation and data characteristics
244    pub fn select_optimal_kernel(
245        &self,
246        operation: &str,
247        data_size: usize,
248        data_profile: &DataProfile,
249    ) -> Result<SpecializedKernel> {
250        let kernels = self.kernels.lock().map_err(|_| {
251            TensorError::compute_error_simple("Failed to lock kernel registry".to_string())
252        })?;
253
254        let operation_kernels = kernels.get(operation).ok_or_else(|| {
255            TensorError::compute_error_simple(format!(
256                "No kernels registered for operation: {}",
257                operation
258            ))
259        })?;
260
261        // Score each kernel based on suitability
262        let mut scored_kernels: Vec<(f64, &SpecializedKernel)> = operation_kernels
263            .iter()
264            .map(|kernel| (self.score_kernel(kernel, data_size, data_profile), kernel))
265            .collect();
266
267        // Sort by score (highest first)
268        scored_kernels.sort_by(|a, b| {
269            b.0.partial_cmp(&a.0)
270                .expect("partial_cmp should not return None for valid values")
271        });
272
273        if let Some((score, kernel)) = scored_kernels.first() {
274            if *score > 0.0 {
275                return Ok((*kernel).clone());
276            }
277        }
278
279        Err(TensorError::compute_error_simple(
280            "No suitable kernel found".to_string(),
281        ))
282    }
283
284    /// Score kernel suitability for given data characteristics
285    fn score_kernel(
286        &self,
287        kernel: &SpecializedKernel,
288        data_size: usize,
289        data_profile: &DataProfile,
290    ) -> f64 {
291        let mut score = 0.0;
292
293        // Size compatibility score
294        if data_size >= kernel.optimal_data_profile.size_range.0
295            && data_size <= kernel.optimal_data_profile.size_range.1
296        {
297            score += 0.3;
298        }
299
300        // Access pattern compatibility score
301        if std::mem::discriminant(&kernel.optimal_data_profile.access_pattern)
302            == std::mem::discriminant(&data_profile.access_pattern)
303        {
304            score += 0.2;
305        }
306
307        // Memory layout compatibility score
308        if std::mem::discriminant(&kernel.optimal_data_profile.layout_preference)
309            == std::mem::discriminant(&data_profile.layout_preference)
310        {
311            score += 0.2;
312        }
313
314        // Performance score based on strategy
315        match self.selection_strategy {
316            KernelOptimizationStrategy::MaxThroughput => {
317                score += kernel.performance_profile.expected_throughput / 1e12 * 0.3;
318            }
319            KernelOptimizationStrategy::MinLatency => {
320                score += (1.0 / kernel.performance_profile.expected_latency.max(1e-9)) / 1e9 * 0.3;
321            }
322            KernelOptimizationStrategy::EnergyEfficient => {
323                score += kernel.performance_profile.energy_efficiency / 1e12 * 0.3;
324            }
325            KernelOptimizationStrategy::Balanced => {
326                score += (kernel.performance_profile.expected_throughput / 1e12
327                    + kernel.performance_profile.energy_efficiency / 1e12)
328                    * 0.15;
329            }
330            KernelOptimizationStrategy::Adaptive => {
331                // Use historical performance data if available
332                score += self.get_adaptive_score(kernel) * 0.3;
333            }
334        }
335
336        score.clamp(0.0, 1.0)
337    }
338
339    /// Get adaptive score based on historical performance
340    fn get_adaptive_score(&self, kernel: &SpecializedKernel) -> f64 {
341        if let Ok(cache) = self.performance_cache.lock() {
342            if let Some(perf_data) = cache.get(&kernel.id) {
343                return perf_data.measured_throughput / 1e12 * perf_data.success_rate;
344            }
345        }
346
347        // Fallback to expected performance
348        kernel.performance_profile.expected_throughput / 1e12
349    }
350
351    /// Execute kernel with performance monitoring
352    pub fn execute_kernel(
353        &self,
354        kernel: &SpecializedKernel,
355        input_a: &[f32],
356        input_b: &[f32],
357        output: &mut [f32],
358        params: &KernelParams,
359    ) -> Result<KernelExecutionResult> {
360        let start_time = std::time::Instant::now();
361
362        // Execute kernel based on implementation type
363        let result = match &kernel.implementation {
364            KernelImplementation::Native(kernel_fn) => kernel_fn(input_a, input_b, output, params),
365            KernelImplementation::Vectorized(kernel_fn) => {
366                kernel_fn(input_a, input_b, output, params)
367            }
368            _ => {
369                // Fallback to native implementation for unsupported types
370                Err(TensorError::compute_error_simple(
371                    "Unsupported kernel implementation".to_string(),
372                ))
373            }
374        };
375
376        let execution_time = start_time.elapsed();
377
378        // Update performance cache
379        self.update_performance_cache(&kernel.id, &result, execution_time);
380
381        // Validate result if validator is provided
382        if let Some(validator) = &kernel.validator {
383            let is_valid = validator(input_a, input_b, output, params);
384            if !is_valid {
385                return Err(TensorError::compute_error_simple(
386                    "Kernel validation failed".to_string(),
387                ));
388            }
389        }
390
391        Ok(KernelExecutionResult {
392            success: result.is_ok(),
393            execution_time,
394            throughput: self.calculate_throughput(params, execution_time),
395            energy_estimate: self.estimate_energy_consumption(kernel, execution_time),
396            cache_efficiency: self.estimate_cache_efficiency(kernel, params),
397        })
398    }
399
400    /// Update performance cache with execution results
401    fn update_performance_cache(
402        &self,
403        kernel_id: &str,
404        result: &Result<()>,
405        execution_time: std::time::Duration,
406    ) {
407        if let Ok(mut cache) = self.performance_cache.lock() {
408            let entry = cache
409                .entry(kernel_id.to_string())
410                .or_insert(KernelPerformanceData {
411                    measured_throughput: 0.0,
412                    measured_latency: 0.0,
413                    success_rate: 0.0,
414                    execution_count: 0,
415                    last_updated: std::time::Instant::now(),
416                });
417
418            entry.execution_count += 1;
419            entry.measured_latency = execution_time.as_secs_f64();
420
421            if result.is_ok() {
422                entry.success_rate = (entry.success_rate * (entry.execution_count - 1) as f64
423                    + 1.0)
424                    / entry.execution_count as f64;
425            } else {
426                entry.success_rate = (entry.success_rate * (entry.execution_count - 1) as f64)
427                    / entry.execution_count as f64;
428            }
429
430            entry.last_updated = std::time::Instant::now();
431        }
432    }
433
434    /// Calculate throughput based on operation parameters
435    fn calculate_throughput(
436        &self,
437        params: &KernelParams,
438        execution_time: std::time::Duration,
439    ) -> f64 {
440        let total_ops = params.dimensions.iter().product::<usize>() as f64;
441        total_ops / execution_time.as_secs_f64()
442    }
443
444    /// Estimate energy consumption for kernel execution
445    fn estimate_energy_consumption(
446        &self,
447        kernel: &SpecializedKernel,
448        execution_time: std::time::Duration,
449    ) -> f64 {
450        // Simple energy estimation based on performance profile
451        let base_power = 50.0; // Watts
452        let efficiency_multiplier = kernel.performance_profile.energy_efficiency / 1e12;
453        base_power * execution_time.as_secs_f64() / efficiency_multiplier
454    }
455
456    /// Estimate cache efficiency for kernel execution
457    fn estimate_cache_efficiency(&self, kernel: &SpecializedKernel, _params: &KernelParams) -> f64 {
458        kernel.performance_profile.cache_efficiency
459    }
460
461    /// Register default high-performance kernels
462    fn register_default_kernels(&mut self) -> Result<()> {
463        // High-performance matrix multiplication kernel
464        self.register_kernel(SpecializedKernel {
465            id: "matmul_high_perf".to_string(),
466            name: "High-Performance Matrix Multiplication".to_string(),
467            operation: "matmul".to_string(),
468            hardware_requirements: HardwareRequirements {
469                required_cpu_features: vec!["avx2".to_string()],
470                min_cache_sizes: CacheSizeRequirements {
471                    min_l1_size: 32768,
472                    min_l2_size: 262144,
473                    min_l3_size: 8388608,
474                },
475                min_memory_bandwidth: 50e9,
476                min_simd_registers: 16,
477                preferred_architecture: vec!["x86_64".to_string()],
478            },
479            optimal_data_profile: DataProfile {
480                size_range: (1024, usize::MAX),
481                alignment_requirement: 64,
482                access_pattern: AccessPattern::BlockedSequential,
483                layout_preference: MemoryLayout::RowMajor,
484                sparsity_tolerance: 0.1,
485            },
486            performance_profile: PerformanceProfile {
487                expected_throughput: 2e12,
488                expected_latency: 1e-6,
489                memory_efficiency: 0.9,
490                cache_efficiency: 0.85,
491                energy_efficiency: 1e12,
492                scalability_factor: 0.95,
493            },
494            implementation: KernelImplementation::Vectorized(high_perf_matmul),
495            validator: Some(validate_matmul_result),
496        })?;
497
498        // Cache-friendly element-wise operations kernel
499        self.register_kernel(SpecializedKernel {
500            id: "elementwise_cache_friendly".to_string(),
501            name: "Cache-Friendly Element-wise Operations".to_string(),
502            operation: "elementwise".to_string(),
503            hardware_requirements: HardwareRequirements {
504                required_cpu_features: vec![],
505                min_cache_sizes: CacheSizeRequirements {
506                    min_l1_size: 16384,
507                    min_l2_size: 131072,
508                    min_l3_size: 4194304,
509                },
510                min_memory_bandwidth: 25e9,
511                min_simd_registers: 8,
512                preferred_architecture: vec!["x86_64".to_string(), "aarch64".to_string()],
513            },
514            optimal_data_profile: DataProfile {
515                size_range: (64, usize::MAX),
516                alignment_requirement: 32,
517                access_pattern: AccessPattern::Sequential,
518                layout_preference: MemoryLayout::RowMajor,
519                sparsity_tolerance: 0.5,
520            },
521            performance_profile: PerformanceProfile {
522                expected_throughput: 4e12,
523                expected_latency: 5e-7,
524                memory_efficiency: 0.95,
525                cache_efficiency: 0.9,
526                energy_efficiency: 2e12,
527                scalability_factor: 0.98,
528            },
529            implementation: KernelImplementation::Vectorized(cache_friendly_elementwise),
530            validator: Some(validate_elementwise_result),
531        })?;
532
533        Ok(())
534    }
535
536    /// Get comprehensive kernel registry statistics
537    pub fn get_registry_statistics(&self) -> Result<KernelRegistryStatistics> {
538        let kernels = self.kernels.lock().map_err(|_| {
539            TensorError::compute_error_simple("Failed to lock kernel registry".to_string())
540        })?;
541
542        let cache = self.performance_cache.lock().map_err(|_| {
543            TensorError::compute_error_simple("Failed to lock performance cache".to_string())
544        })?;
545
546        let total_kernels: usize = kernels.values().map(|v| v.len()).sum();
547        let total_operations = kernels.len();
548        let cached_performance_data = cache.len();
549
550        Ok(KernelRegistryStatistics {
551            total_kernels,
552            total_operations,
553            cached_performance_data,
554            selection_strategy: self.selection_strategy.clone(),
555            average_kernel_throughput: self.calculate_average_throughput(&kernels),
556            cache_hit_rate: self.calculate_cache_hit_rate(&cache),
557        })
558    }
559
560    fn calculate_average_throughput(
561        &self,
562        kernels: &HashMap<String, Vec<SpecializedKernel>>,
563    ) -> f64 {
564        let mut total_throughput = 0.0;
565        let mut kernel_count = 0;
566
567        for kernel_list in kernels.values() {
568            for kernel in kernel_list {
569                total_throughput += kernel.performance_profile.expected_throughput;
570                kernel_count += 1;
571            }
572        }
573
574        if kernel_count > 0 {
575            total_throughput / kernel_count as f64
576        } else {
577            0.0
578        }
579    }
580
581    fn calculate_cache_hit_rate(&self, cache: &HashMap<String, KernelPerformanceData>) -> f64 {
582        let total_executions: u64 = cache.values().map(|data| data.execution_count).sum();
583        let successful_executions: f64 = cache
584            .values()
585            .map(|data| data.execution_count as f64 * data.success_rate)
586            .sum();
587
588        if total_executions > 0 {
589            successful_executions / total_executions as f64
590        } else {
591            0.0
592        }
593    }
594}
595
/// Kernel execution result returned by `AdvancedKernelRegistry::execute_kernel`.
#[derive(Debug, Clone)]
pub struct KernelExecutionResult {
    /// Whether the kernel implementation itself reported success.
    pub success: bool,
    /// Wall-clock duration of the kernel call.
    pub execution_time: std::time::Duration,
    /// Throughput derived from the dimension product and execution time.
    pub throughput: f64,
    /// Estimated energy consumption — presumably joules; confirm with the
    /// energy model in `estimate_energy_consumption`.
    pub energy_estimate: f64,
    /// Estimated cache efficiency (0-1), taken from the kernel's profile.
    pub cache_efficiency: f64,
}

/// Kernel registry statistics returned by
/// `AdvancedKernelRegistry::get_registry_statistics`.
#[derive(Debug, Clone)]
pub struct KernelRegistryStatistics {
    /// Total number of registered kernels across all operations.
    pub total_kernels: usize,
    /// Number of distinct operations with at least one kernel.
    pub total_operations: usize,
    /// Number of kernels with cached runtime performance data.
    pub cached_performance_data: usize,
    /// The registry's configured selection strategy.
    pub selection_strategy: KernelOptimizationStrategy,
    /// Mean expected throughput over all registered kernels.
    pub average_kernel_throughput: f64,
    /// Execution-weighted success rate from the performance cache (the field
    /// name says "cache hit" but the computation measures kernel success).
    pub cache_hit_rate: f64,
}

617// High-performance kernel implementations
618
619/// High-performance matrix multiplication kernel
620fn high_perf_matmul(a: &[f32], b: &[f32], c: &mut [f32], params: &KernelParams) -> Result<()> {
621    let (m, n, k) = if params.dimensions.len() >= 3 {
622        (
623            params.dimensions[0],
624            params.dimensions[1],
625            params.dimensions[2],
626        )
627    } else {
628        return Err(TensorError::compute_error_simple(
629            "Invalid dimensions for matmul".to_string(),
630        ));
631    };
632
633    // Simple blocked implementation
634    const BLOCK_SIZE: usize = 64;
635
636    for i in (0..m).step_by(BLOCK_SIZE) {
637        for j in (0..n).step_by(BLOCK_SIZE) {
638            for l in (0..k).step_by(BLOCK_SIZE) {
639                let i_end = (i + BLOCK_SIZE).min(m);
640                let j_end = (j + BLOCK_SIZE).min(n);
641                let l_end = (l + BLOCK_SIZE).min(k);
642
643                for ii in i..i_end {
644                    for jj in j..j_end {
645                        let mut sum = 0.0;
646                        for ll in l..l_end {
647                            sum += a[ii * k + ll] * b[ll * n + jj];
648                        }
649                        c[ii * n + jj] += sum;
650                    }
651                }
652            }
653        }
654    }
655
656    Ok(())
657}
658
659/// Cache-friendly element-wise operations kernel
660fn cache_friendly_elementwise(
661    a: &[f32],
662    b: &[f32],
663    c: &mut [f32],
664    params: &KernelParams,
665) -> Result<()> {
666    let operation = params.operation_params.get("operation").unwrap_or(&0.0) as &f64;
667
668    match *operation as i32 {
669        0 => {
670            // Add
671            for i in 0..a.len() {
672                c[i] = a[i] + b[i];
673            }
674        }
675        1 => {
676            // Multiply
677            for i in 0..a.len() {
678                c[i] = a[i] * b[i];
679            }
680        }
681        _ => {
682            return Err(TensorError::compute_error_simple(
683                "Unsupported element-wise operation".to_string(),
684            ));
685        }
686    }
687
688    Ok(())
689}
690
691/// Validate matrix multiplication result
692fn validate_matmul_result(a: &[f32], b: &[f32], c: &[f32], _params: &KernelParams) -> bool {
693    // Simple validation: check that result is not all zeros (for non-zero inputs)
694    let has_nonzero_input = a.iter().any(|&x| x != 0.0) && b.iter().any(|&x| x != 0.0);
695    let has_nonzero_output = c.iter().any(|&x| x != 0.0);
696
697    !has_nonzero_input || has_nonzero_output
698}
699
700/// Validate element-wise operation result
701fn validate_elementwise_result(a: &[f32], b: &[f32], c: &[f32], _params: &KernelParams) -> bool {
702    // Simple validation: check dimensions match
703    a.len() == b.len() && b.len() == c.len()
704}
705
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh registry must come pre-populated with the default kernels.
    #[test]
    fn test_kernel_registry_creation() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::MaxThroughput);
        let stats = registry
            .get_registry_statistics()
            .expect("test: get_registry_statistics should succeed");

        assert!(stats.total_kernels > 0);
        assert!(stats.total_operations > 0);
    }

    /// Selection must find a matmul kernel for an in-range data size.
    #[test]
    fn test_kernel_selection() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::MaxThroughput);

        let data_profile = DataProfile {
            size_range: (1024, usize::MAX),
            alignment_requirement: 64,
            access_pattern: AccessPattern::Sequential,
            layout_preference: MemoryLayout::RowMajor,
            sparsity_tolerance: 0.1,
        };

        let kernel = registry.select_optimal_kernel("matmul", 2048, &data_profile);
        assert!(kernel.is_ok());
    }

    /// End-to-end: select the matmul kernel and execute it on 8x8 matrices,
    /// checking the monitoring fields of the execution result.
    #[test]
    fn test_kernel_execution() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::MaxThroughput);

        let data_profile = DataProfile {
            size_range: (64, usize::MAX),
            alignment_requirement: 32,
            access_pattern: AccessPattern::Sequential,
            layout_preference: MemoryLayout::RowMajor,
            sparsity_tolerance: 0.5,
        };

        let kernel = registry
            .select_optimal_kernel("matmul", 512, &data_profile)
            .expect("test: operation should succeed");

        // 8x8 matrices flattened row-major: 64 elements each.
        let a = vec![1.0; 64];
        let b = vec![2.0; 64];
        let mut c = vec![0.0; 64];

        let params = KernelParams {
            dimensions: vec![8, 8, 8],
            strides: vec![8, 8, 8],
            data_type: "f32".to_string(),
            operation_params: HashMap::new(),
            performance_hints: vec![],
        };

        let result = registry.execute_kernel(&kernel, &a, &b, &mut c, &params);
        assert!(result.is_ok());

        let execution_result = result.expect("test: operation should succeed");
        assert!(execution_result.success);
        assert!(execution_result.throughput > 0.0);
    }

    /// Repeated executions under the Adaptive strategy must populate the
    /// runtime performance cache.
    #[test]
    fn test_performance_cache_update() {
        let registry = AdvancedKernelRegistry::new(KernelOptimizationStrategy::Adaptive);

        // Execute a kernel multiple times to populate the cache.
        let data_profile = DataProfile {
            size_range: (64, usize::MAX),
            alignment_requirement: 32,
            access_pattern: AccessPattern::Sequential,
            layout_preference: MemoryLayout::RowMajor,
            sparsity_tolerance: 0.5,
        };

        let kernel = registry
            .select_optimal_kernel("elementwise", 256, &data_profile)
            .expect("test: operation should succeed");

        let a = vec![1.0; 16];
        let b = vec![2.0; 16];
        let mut c = vec![0.0; 16];

        let mut params = KernelParams {
            dimensions: vec![16],
            strides: vec![1],
            data_type: "f32".to_string(),
            operation_params: HashMap::new(),
            performance_hints: vec![],
        };
        params.operation_params.insert("operation".to_string(), 0.0); // Add operation

        for _ in 0..5 {
            let _ = registry.execute_kernel(&kernel, &a, &b, &mut c, &params);
        }

        let stats = registry
            .get_registry_statistics()
            .expect("test: get_registry_statistics should succeed");
        assert!(stats.cached_performance_data > 0);
    }
}