// tensorlogic_infer/cache_optimizer.rs

//! Memory hierarchy and cache-aware optimization.
//!
//! This module provides cache-aware optimizations for better memory performance:
//! - **Cache modeling**: Model L1/L2/L3 cache behavior
//! - **Data layout optimization**: Arrange data for cache efficiency
//! - **Loop tiling**: Optimize loop nests for cache reuse
//! - **Prefetching**: Software prefetch directives
//! - **NUMA optimization**: Optimize for non-uniform memory access
//!
//! ## Example
//!
//! ```rust,ignore
//! use tensorlogic_infer::{AccessPattern, CacheConfig, CacheOptimizer};
//!
//! // Configure the cache optimizer
//! let config = CacheConfig::from_system()
//!     .with_tiling_enabled(true)
//!     .with_prefetch_distance(8);
//!
//! let optimizer = CacheOptimizer::new(config);
//!
//! // Derive cache-friendly tiling and data layout for the workload
//! let tiling = optimizer.compute_tiling_params((1024, 1024), std::mem::size_of::<f64>());
//! let layout = optimizer.recommend_layout(AccessPattern::Sequential);
//!
//! // Check estimated cache metrics for the working-set size (in bytes)
//! let metrics = optimizer.estimate_cache_metrics(16 * 1024 * 1024);
//! println!("Estimated cache hit rate: {:.2}%", metrics.hit_rate * 100.0);
//! ```

use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Cache optimization errors.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum CacheOptimizerError {
    #[error("Invalid cache configuration: {0}")]
    InvalidConfig(String),

    #[error("Optimization failed: {0}")]
    OptimizationFailed(String),

    #[error("Insufficient cache size: required {required} KB, available {available} KB")]
    InsufficientCache { required: usize, available: usize },

    #[error("Invalid tiling parameters: {0}")]
    InvalidTiling(String),
}

/// Cache level.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum CacheLevel {
    L1,
    L2,
    L3,
    LLC, // Last Level Cache
}

impl CacheLevel {
    /// Get typical cache size (KB).
    pub fn typical_size_kb(&self) -> usize {
        match self {
            CacheLevel::L1 => 32,
            CacheLevel::L2 => 256,
            CacheLevel::L3 => 8192,
            CacheLevel::LLC => 32768,
        }
    }

    /// Get typical cache latency (cycles).
    pub fn typical_latency_cycles(&self) -> usize {
        match self {
            CacheLevel::L1 => 4,
            CacheLevel::L2 => 12,
            CacheLevel::L3 => 40,
            CacheLevel::LLC => 100,
        }
    }
}

/// Cache configuration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CacheConfig {
    /// L1 cache size (KB)
    pub l1_size_kb: usize,

    /// L2 cache size (KB)
    pub l2_size_kb: usize,

    /// L3 cache size (KB)
    pub l3_size_kb: usize,

    /// Cache line size (bytes)
    pub cache_line_size: usize,

    /// Cache associativity
    pub associativity: usize,

    /// Enable loop tiling
    pub enable_tiling: bool,

    /// Enable prefetching
    pub enable_prefetch: bool,

    /// Prefetch distance (cache lines)
    pub prefetch_distance: usize,

    /// Enable data layout optimization
    pub enable_layout_optimization: bool,
}

impl Default for CacheConfig {
    fn default() -> Self {
        Self {
            l1_size_kb: 32,
            l2_size_kb: 256,
            l3_size_kb: 8192,
            cache_line_size: 64,
            associativity: 8,
            enable_tiling: true,
            enable_prefetch: true,
            prefetch_distance: 8,
            enable_layout_optimization: true,
        }
    }
}

impl CacheConfig {
    /// Detect cache configuration from the running system.
    pub fn from_system() -> Self {
        // In a real implementation this would query the OS/CPU for the cache topology;
        // fall back to typical desktop-class defaults here.
        Self::default()
    }

    /// Set L1 cache size.
    pub fn with_l1_size(mut self, size_kb: usize) -> Self {
        self.l1_size_kb = size_kb;
        self
    }

    /// Set L2 cache size.
    pub fn with_l2_size(mut self, size_kb: usize) -> Self {
        self.l2_size_kb = size_kb;
        self
    }

    /// Set L3 cache size.
    pub fn with_l3_size(mut self, size_kb: usize) -> Self {
        self.l3_size_kb = size_kb;
        self
    }

    /// Enable or disable tiling.
    pub fn with_tiling_enabled(mut self, enabled: bool) -> Self {
        self.enable_tiling = enabled;
        self
    }

    /// Enable or disable prefetching.
    pub fn with_prefetch_enabled(mut self, enabled: bool) -> Self {
        self.enable_prefetch = enabled;
        self
    }

    /// Set prefetch distance.
    pub fn with_prefetch_distance(mut self, distance: usize) -> Self {
        self.prefetch_distance = distance;
        self
    }

    /// Get total cache size (KB) across all levels.
    pub fn total_size_kb(&self) -> usize {
        self.l1_size_kb + self.l2_size_kb + self.l3_size_kb
    }
}
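
// Illustrative builder usage (the cache sizes below are hypothetical, shown only to
// demonstrate the chained setters):
//
//     let cfg = CacheConfig::from_system()
//         .with_l1_size(48)      // 48 KB L1
//         .with_l2_size(1024)    // 1 MB L2
//         .with_prefetch_enabled(false);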

/// Loop tiling parameters.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TilingParams {
    /// Tile size for outermost dimension
    pub tile_i: usize,

    /// Tile size for middle dimension
    pub tile_j: usize,

    /// Tile size for innermost dimension
    pub tile_k: usize,

    /// Target cache level
    pub target_level: CacheLevel,
}

impl TilingParams {
    /// Create tiling parameters for a given cache size.
    pub fn for_cache_size(cache_size_kb: usize, element_size: usize) -> Self {
        // Simple heuristic: use square tiles that fit in cache
        let cache_bytes = cache_size_kb * 1024;
        let elements_per_tile = (cache_bytes / 3) / element_size; // Divide by 3 for 3 arrays
        let tile_size = (elements_per_tile as f64).sqrt() as usize;

        Self {
            tile_i: tile_size,
            tile_j: tile_size,
            tile_k: tile_size,
            target_level: CacheLevel::L2,
        }
    }
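
    // Worked example (illustrative): for a 256 KB cache and 8-byte elements,
    // (256 * 1024 / 3) / 8 ≈ 10_922 elements per array and sqrt(10_922) ≈ 104,
    // so this heuristic yields roughly 104 x 104 tiles targeting L2.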

    /// Validate tiling parameters.
    pub fn validate(&self) -> Result<(), CacheOptimizerError> {
        if self.tile_i == 0 || self.tile_j == 0 || self.tile_k == 0 {
            return Err(CacheOptimizerError::InvalidTiling(
                "Tile sizes must be > 0".to_string(),
            ));
        }
        Ok(())
    }
}

/// Cache metrics for a computation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct CacheMetrics {
    /// Estimated cache hit rate (0.0-1.0)
    pub hit_rate: f64,

    /// L1 cache hits
    pub l1_hits: usize,

    /// L2 cache hits
    pub l2_hits: usize,

    /// L3 cache hits
    pub l3_hits: usize,

    /// Cache misses
    pub misses: usize,

    /// Total accesses
    pub total_accesses: usize,

    /// Estimated memory bandwidth (GB/s)
    pub memory_bandwidth_gbs: f64,

    /// Estimated latency (cycles)
    pub avg_latency_cycles: f64,
}

impl CacheMetrics {
    /// Create new cache metrics.
    pub fn new() -> Self {
        Self {
            hit_rate: 0.0,
            l1_hits: 0,
            l2_hits: 0,
            l3_hits: 0,
            misses: 0,
            total_accesses: 0,
            memory_bandwidth_gbs: 0.0,
            avg_latency_cycles: 0.0,
        }
    }

    /// Calculate hit rate.
    pub fn calculate_hit_rate(&mut self) {
        let hits = self.l1_hits + self.l2_hits + self.l3_hits;
        self.total_accesses = hits + self.misses;

        if self.total_accesses > 0 {
            self.hit_rate = hits as f64 / self.total_accesses as f64;
        }
    }

    /// Calculate average latency.
    pub fn calculate_avg_latency(&mut self) {
        if self.total_accesses == 0 {
            return;
        }

        let total_latency = self.l1_hits * CacheLevel::L1.typical_latency_cycles()
            + self.l2_hits * CacheLevel::L2.typical_latency_cycles()
            + self.l3_hits * CacheLevel::L3.typical_latency_cycles()
            + self.misses * 200; // Memory access ~200 cycles

        self.avg_latency_cycles = total_latency as f64 / self.total_accesses as f64;
    }
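
    // Illustrative arithmetic: with 70 L1 hits, 20 L2 hits, 5 L3 hits, and 5 misses,
    // the weighted latency is (70*4 + 20*12 + 5*40 + 5*200) / 100 = 17.2 cycles.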

    /// Estimate memory bandwidth usage.
    pub fn estimate_bandwidth(&mut self, data_size_bytes: usize, time_secs: f64) {
        if time_secs > 0.0 {
            let gb = data_size_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
            self.memory_bandwidth_gbs = gb / time_secs;
        }
    }
}

impl Default for CacheMetrics {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Display for CacheMetrics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Cache Metrics")?;
        writeln!(f, "=============")?;
        writeln!(f, "Hit rate:       {:.2}%", self.hit_rate * 100.0)?;
        writeln!(f, "L1 hits:        {}", self.l1_hits)?;
        writeln!(f, "L2 hits:        {}", self.l2_hits)?;
        writeln!(f, "L3 hits:        {}", self.l3_hits)?;
        writeln!(f, "Misses:         {}", self.misses)?;
        writeln!(f, "Total accesses: {}", self.total_accesses)?;
        writeln!(f, "Avg latency:    {:.1} cycles", self.avg_latency_cycles)?;
        writeln!(f, "Bandwidth:      {:.2} GB/s", self.memory_bandwidth_gbs)?;
        Ok(())
    }
}

/// Data layout strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum DataLayout {
    /// Row-major layout (C-style)
    RowMajor,

    /// Column-major layout (Fortran-style)
    ColumnMajor,

    /// Blocked/tiled layout
    Blocked { block_size: usize },

    /// Z-order (Morton) layout
    ZOrder,

    /// Hilbert curve layout
    Hilbert,
}

impl DataLayout {
    /// Get cache efficiency score (0.0-1.0).
    pub fn cache_efficiency(&self, access_pattern: AccessPattern) -> f64 {
        match (self, access_pattern) {
            (DataLayout::RowMajor, AccessPattern::Sequential) => 1.0,
            (DataLayout::RowMajor, AccessPattern::Strided) => 0.5,
            (DataLayout::ColumnMajor, AccessPattern::Sequential) => 0.5,
            (DataLayout::Blocked { .. }, _) => 0.8,
            (DataLayout::ZOrder, _) => 0.7,
            (DataLayout::Hilbert, _) => 0.75,
            _ => 0.3,
        }
    }
}
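
// For example, `RowMajor` with `Sequential` access scores 1.0, while combinations not
// listed above (e.g. `ColumnMajor` with `Strided` access) fall through to the 0.3 arm.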

/// Memory access pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum AccessPattern {
    /// Sequential access
    Sequential,

    /// Strided access
    Strided,

    /// Random access
    Random,

    /// Block access
    Block,
}

/// Cache-aware optimizer.
pub struct CacheOptimizer {
    /// Cache configuration
    config: CacheConfig,

    /// Optimization statistics
    stats: OptimizationStats,
}

impl CacheOptimizer {
    /// Create a new cache optimizer.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            config,
            stats: OptimizationStats::default(),
        }
    }

    /// Estimate cache metrics for a workload of the given size (bytes).
    pub fn estimate_cache_metrics(&self, data_size_bytes: usize) -> CacheMetrics {
        let mut metrics = CacheMetrics::new();

        // Simplified model: assume a fixed hit/miss distribution depending on which
        // cache level the working set fits in (counts are placeholders, not measurements).
        let l1_size_bytes = self.config.l1_size_kb * 1024;

        if data_size_bytes <= l1_size_bytes {
            // Fits in L1
            metrics.l1_hits = 100;
            metrics.l2_hits = 0;
            metrics.l3_hits = 0;
            metrics.misses = 10;
        } else if data_size_bytes <= self.config.l2_size_kb * 1024 {
            // Fits in L2
            metrics.l1_hits = 50;
            metrics.l2_hits = 40;
            metrics.l3_hits = 0;
            metrics.misses = 10;
        } else {
            // Doesn't fit in cache
            metrics.l1_hits = 30;
            metrics.l2_hits = 30;
            metrics.l3_hits = 20;
            metrics.misses = 20;
        }

        metrics.calculate_hit_rate();
        metrics.calculate_avg_latency();

        metrics
    }

    /// Compute optimal tiling parameters.
    pub fn compute_tiling_params(
        &self,
        _matrix_size: (usize, usize),
        element_size: usize,
    ) -> TilingParams {
        // Target L2 cache
        let target_cache_kb = self.config.l2_size_kb / 2; // Use half for safety
        TilingParams::for_cache_size(target_cache_kb, element_size)
    }

    /// Recommend data layout for access pattern.
    pub fn recommend_layout(&self, access_pattern: AccessPattern) -> DataLayout {
        match access_pattern {
            AccessPattern::Sequential => DataLayout::RowMajor,
            AccessPattern::Strided => DataLayout::Blocked { block_size: 64 },
            AccessPattern::Random => DataLayout::ZOrder,
            AccessPattern::Block => DataLayout::Blocked { block_size: 128 },
        }
    }

    /// Get optimization statistics.
    pub fn stats(&self) -> &OptimizationStats {
        &self.stats
    }
}

/// Optimization statistics.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct OptimizationStats {
    /// Number of graphs optimized
    pub graphs_optimized: usize,

    /// Number of tiling transformations applied
    pub tiling_applied: usize,

    /// Number of layout optimizations
    pub layout_optimizations: usize,

    /// Number of prefetch insertions
    pub prefetch_insertions: usize,

    /// Estimated performance improvement (%)
    pub estimated_improvement_pct: f64,
}

impl std::fmt::Display for OptimizationStats {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Cache Optimization Statistics")?;
        writeln!(f, "=============================")?;
        writeln!(f, "Graphs optimized:    {}", self.graphs_optimized)?;
        writeln!(f, "Tiling applied:      {}", self.tiling_applied)?;
        writeln!(f, "Layout opts:         {}", self.layout_optimizations)?;
        writeln!(f, "Prefetch inserts:    {}", self.prefetch_insertions)?;
        writeln!(
            f,
            "Est. improvement:    {:.1}%",
            self.estimated_improvement_pct
        )?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_level_sizes() {
        assert_eq!(CacheLevel::L1.typical_size_kb(), 32);
        assert_eq!(CacheLevel::L2.typical_size_kb(), 256);
        assert_eq!(CacheLevel::L3.typical_size_kb(), 8192);
    }

    #[test]
    fn test_cache_level_latency() {
        assert_eq!(CacheLevel::L1.typical_latency_cycles(), 4);
        assert_eq!(CacheLevel::L2.typical_latency_cycles(), 12);
        assert_eq!(CacheLevel::L3.typical_latency_cycles(), 40);
    }

    #[test]
    fn test_cache_config_default() {
        let config = CacheConfig::default();
        assert_eq!(config.l1_size_kb, 32);
        assert_eq!(config.l2_size_kb, 256);
        assert_eq!(config.cache_line_size, 64);
    }

    #[test]
    fn test_cache_config_builders() {
        let config = CacheConfig::default()
            .with_l1_size(64)
            .with_l2_size(512)
            .with_tiling_enabled(true)
            .with_prefetch_distance(16);

        assert_eq!(config.l1_size_kb, 64);
        assert_eq!(config.l2_size_kb, 512);
        assert!(config.enable_tiling);
        assert_eq!(config.prefetch_distance, 16);
    }

    #[test]
    fn test_cache_config_total_size() {
        let config = CacheConfig::default();
        let total = config.total_size_kb();
        assert_eq!(total, 32 + 256 + 8192);
    }

    #[test]
    fn test_tiling_params_for_cache_size() {
        let params = TilingParams::for_cache_size(256, 8);
        assert!(params.tile_i > 0);
        assert!(params.tile_j > 0);
        assert!(params.tile_k > 0);
    }

    #[test]
    fn test_tiling_params_validate() {
        let params = TilingParams {
            tile_i: 64,
            tile_j: 64,
            tile_k: 64,
            target_level: CacheLevel::L2,
        };
        assert!(params.validate().is_ok());

        let invalid = TilingParams {
            tile_i: 0,
            tile_j: 64,
            tile_k: 64,
            target_level: CacheLevel::L2,
        };
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_cache_metrics_calculate_hit_rate() {
        let mut metrics = CacheMetrics::new();
        metrics.l1_hits = 70;
        metrics.l2_hits = 20;
        metrics.l3_hits = 5;
        metrics.misses = 5;

        metrics.calculate_hit_rate();
        assert_eq!(metrics.total_accesses, 100);
        assert!((metrics.hit_rate - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_cache_metrics_calculate_latency() {
        let mut metrics = CacheMetrics::new();
        metrics.l1_hits = 100;
        metrics.l2_hits = 0;
        metrics.l3_hits = 0;
        metrics.misses = 0;
        metrics.total_accesses = 100;

        metrics.calculate_avg_latency();
        assert_eq!(metrics.avg_latency_cycles, 4.0);
    }

    #[test]
    fn test_cache_metrics_estimate_bandwidth() {
        let mut metrics = CacheMetrics::new();
        metrics.estimate_bandwidth(1024 * 1024 * 1024, 1.0); // 1 GB in 1 second
        assert!((metrics.memory_bandwidth_gbs - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_metrics_display() {
        let mut metrics = CacheMetrics::new();
        metrics.l1_hits = 70;
        metrics.l2_hits = 20;
        metrics.misses = 10;
        metrics.calculate_hit_rate();

        let display = format!("{}", metrics);
        assert!(display.contains("Hit rate:"));
        assert!(display.contains("L1 hits:"));
    }

    #[test]
    fn test_data_layout_cache_efficiency() {
        let eff = DataLayout::RowMajor.cache_efficiency(AccessPattern::Sequential);
        assert_eq!(eff, 1.0);

        let eff = DataLayout::RowMajor.cache_efficiency(AccessPattern::Strided);
        assert_eq!(eff, 0.5);
    }

    #[test]
    fn test_cache_optimizer_creation() {
        let config = CacheConfig::default();
        let optimizer = CacheOptimizer::new(config);
        assert_eq!(optimizer.stats().graphs_optimized, 0);
    }

    #[test]
    fn test_cache_optimizer_estimate_metrics() {
        let config = CacheConfig::default();
        let optimizer = CacheOptimizer::new(config);

        let metrics = optimizer.estimate_cache_metrics(16 * 1024); // 16 KB
        assert!(metrics.hit_rate > 0.0);
    }

    #[test]
    fn test_cache_optimizer_compute_tiling() {
        let config = CacheConfig::default();
        let optimizer = CacheOptimizer::new(config);

        let params = optimizer.compute_tiling_params((1000, 1000), 8);
        assert!(params.tile_i > 0);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_cache_optimizer_recommend_layout() {
        let config = CacheConfig::default();
        let optimizer = CacheOptimizer::new(config);

        let layout = optimizer.recommend_layout(AccessPattern::Sequential);
        assert_eq!(layout, DataLayout::RowMajor);

        let layout = optimizer.recommend_layout(AccessPattern::Random);
        assert_eq!(layout, DataLayout::ZOrder);
    }

    #[test]
    fn test_optimization_stats_display() {
        let stats = OptimizationStats {
            graphs_optimized: 10,
            tiling_applied: 5,
            estimated_improvement_pct: 25.0,
            ..Default::default()
        };

        let display = format!("{}", stats);
        assert!(display.contains("Graphs optimized:    10"));
        assert!(display.contains("25.0%"));
    }
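
    // Illustrative end-to-end sketch: configure the optimizer, derive tiling and
    // layout choices, and sanity-check the estimated metrics for a working set
    // larger than L2. The sizes used here are arbitrary example values.
    #[test]
    fn test_cache_optimizer_end_to_end_sketch() {
        let config = CacheConfig::from_system()
            .with_tiling_enabled(true)
            .with_prefetch_distance(8);
        let optimizer = CacheOptimizer::new(config);

        let tiling = optimizer.compute_tiling_params((512, 512), std::mem::size_of::<f64>());
        assert!(tiling.validate().is_ok());

        let layout = optimizer.recommend_layout(AccessPattern::Strided);
        assert_eq!(layout, DataLayout::Blocked { block_size: 64 });

        let metrics = optimizer.estimate_cache_metrics(4 * 1024 * 1024); // ~4 MiB working set
        assert!(metrics.hit_rate > 0.0 && metrics.hit_rate <= 1.0);
    }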
}