1use serde::{Deserialize, Serialize};
31use thiserror::Error;
32
33#[derive(Error, Debug, Clone, PartialEq)]
35pub enum CacheOptimizerError {
36 #[error("Invalid cache configuration: {0}")]
37 InvalidConfig(String),
38
39 #[error("Optimization failed: {0}")]
40 OptimizationFailed(String),
41
42 #[error("Insufficient cache size: required {required} KB, available {available} KB")]
43 InsufficientCache { required: usize, available: usize },
44
45 #[error("Invalid tiling parameters: {0}")]
46 InvalidTiling(String),
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
51pub enum CacheLevel {
52 L1,
53 L2,
54 L3,
55 LLC, }
57
58impl CacheLevel {
59 pub fn typical_size_kb(&self) -> usize {
61 match self {
62 CacheLevel::L1 => 32,
63 CacheLevel::L2 => 256,
64 CacheLevel::L3 => 8192,
65 CacheLevel::LLC => 32768,
66 }
67 }
68
69 pub fn typical_latency_cycles(&self) -> usize {
71 match self {
72 CacheLevel::L1 => 4,
73 CacheLevel::L2 => 12,
74 CacheLevel::L3 => 40,
75 CacheLevel::LLC => 100,
76 }
77 }
78}
79
80#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
82pub struct CacheConfig {
83 pub l1_size_kb: usize,
85
86 pub l2_size_kb: usize,
88
89 pub l3_size_kb: usize,
91
92 pub cache_line_size: usize,
94
95 pub associativity: usize,
97
98 pub enable_tiling: bool,
100
101 pub enable_prefetch: bool,
103
104 pub prefetch_distance: usize,
106
107 pub enable_layout_optimization: bool,
109}
110
111impl Default for CacheConfig {
112 fn default() -> Self {
113 Self {
114 l1_size_kb: 32,
115 l2_size_kb: 256,
116 l3_size_kb: 8192,
117 cache_line_size: 64,
118 associativity: 8,
119 enable_tiling: true,
120 enable_prefetch: true,
121 prefetch_distance: 8,
122 enable_layout_optimization: true,
123 }
124 }
125}
126
127impl CacheConfig {
128 pub fn from_system() -> Self {
130 Self::default()
132 }
133
134 pub fn with_l1_size(mut self, size_kb: usize) -> Self {
136 self.l1_size_kb = size_kb;
137 self
138 }
139
140 pub fn with_l2_size(mut self, size_kb: usize) -> Self {
142 self.l2_size_kb = size_kb;
143 self
144 }
145
146 pub fn with_l3_size(mut self, size_kb: usize) -> Self {
148 self.l3_size_kb = size_kb;
149 self
150 }
151
152 pub fn with_tiling_enabled(mut self, enabled: bool) -> Self {
154 self.enable_tiling = enabled;
155 self
156 }
157
158 pub fn with_prefetch_enabled(mut self, enabled: bool) -> Self {
160 self.enable_prefetch = enabled;
161 self
162 }
163
164 pub fn with_prefetch_distance(mut self, distance: usize) -> Self {
166 self.prefetch_distance = distance;
167 self
168 }
169
170 pub fn total_size_kb(&self) -> usize {
172 self.l1_size_kb + self.l2_size_kb + self.l3_size_kb
173 }
174}
175
176#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
178pub struct TilingParams {
179 pub tile_i: usize,
181
182 pub tile_j: usize,
184
185 pub tile_k: usize,
187
188 pub target_level: CacheLevel,
190}
191
192impl TilingParams {
193 pub fn for_cache_size(cache_size_kb: usize, element_size: usize) -> Self {
195 let cache_bytes = cache_size_kb * 1024;
197 let elements_per_tile = (cache_bytes / 3) / element_size; let tile_size = (elements_per_tile as f64).sqrt() as usize;
199
200 Self {
201 tile_i: tile_size,
202 tile_j: tile_size,
203 tile_k: tile_size,
204 target_level: CacheLevel::L2,
205 }
206 }
207
208 pub fn validate(&self) -> Result<(), CacheOptimizerError> {
210 if self.tile_i == 0 || self.tile_j == 0 || self.tile_k == 0 {
211 return Err(CacheOptimizerError::InvalidTiling(
212 "Tile sizes must be > 0".to_string(),
213 ));
214 }
215 Ok(())
216 }
217}
218
219#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
221pub struct CacheMetrics {
222 pub hit_rate: f64,
224
225 pub l1_hits: usize,
227
228 pub l2_hits: usize,
230
231 pub l3_hits: usize,
233
234 pub misses: usize,
236
237 pub total_accesses: usize,
239
240 pub memory_bandwidth_gbs: f64,
242
243 pub avg_latency_cycles: f64,
245}
246
247impl CacheMetrics {
248 pub fn new() -> Self {
250 Self {
251 hit_rate: 0.0,
252 l1_hits: 0,
253 l2_hits: 0,
254 l3_hits: 0,
255 misses: 0,
256 total_accesses: 0,
257 memory_bandwidth_gbs: 0.0,
258 avg_latency_cycles: 0.0,
259 }
260 }
261
262 pub fn calculate_hit_rate(&mut self) {
264 let hits = self.l1_hits + self.l2_hits + self.l3_hits;
265 self.total_accesses = hits + self.misses;
266
267 if self.total_accesses > 0 {
268 self.hit_rate = hits as f64 / self.total_accesses as f64;
269 }
270 }
271
272 pub fn calculate_avg_latency(&mut self) {
274 if self.total_accesses == 0 {
275 return;
276 }
277
278 let total_latency = self.l1_hits * CacheLevel::L1.typical_latency_cycles()
279 + self.l2_hits * CacheLevel::L2.typical_latency_cycles()
280 + self.l3_hits * CacheLevel::L3.typical_latency_cycles()
281 + self.misses * 200; self.avg_latency_cycles = total_latency as f64 / self.total_accesses as f64;
284 }
285
286 pub fn estimate_bandwidth(&mut self, data_size_bytes: usize, time_secs: f64) {
288 if time_secs > 0.0 {
289 let gb = data_size_bytes as f64 / (1024.0 * 1024.0 * 1024.0);
290 self.memory_bandwidth_gbs = gb / time_secs;
291 }
292 }
293}
294
295impl Default for CacheMetrics {
296 fn default() -> Self {
297 Self::new()
298 }
299}
300
301impl std::fmt::Display for CacheMetrics {
302 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303 writeln!(f, "Cache Metrics")?;
304 writeln!(f, "=============")?;
305 writeln!(f, "Hit rate: {:.2}%", self.hit_rate * 100.0)?;
306 writeln!(f, "L1 hits: {}", self.l1_hits)?;
307 writeln!(f, "L2 hits: {}", self.l2_hits)?;
308 writeln!(f, "L3 hits: {}", self.l3_hits)?;
309 writeln!(f, "Misses: {}", self.misses)?;
310 writeln!(f, "Total accesses: {}", self.total_accesses)?;
311 writeln!(f, "Avg latency: {:.1} cycles", self.avg_latency_cycles)?;
312 writeln!(f, "Bandwidth: {:.2} GB/s", self.memory_bandwidth_gbs)?;
313 Ok(())
314 }
315}
316
317#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
319pub enum DataLayout {
320 RowMajor,
322
323 ColumnMajor,
325
326 Blocked { block_size: usize },
328
329 ZOrder,
331
332 Hilbert,
334}
335
336impl DataLayout {
337 pub fn cache_efficiency(&self, access_pattern: AccessPattern) -> f64 {
339 match (self, access_pattern) {
340 (DataLayout::RowMajor, AccessPattern::Sequential) => 1.0,
341 (DataLayout::RowMajor, AccessPattern::Strided) => 0.5,
342 (DataLayout::ColumnMajor, AccessPattern::Sequential) => 0.5,
343 (DataLayout::Blocked { .. }, _) => 0.8,
344 (DataLayout::ZOrder, _) => 0.7,
345 (DataLayout::Hilbert, _) => 0.75,
346 _ => 0.3,
347 }
348 }
349}
350
351#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
353pub enum AccessPattern {
354 Sequential,
356
357 Strided,
359
360 Random,
362
363 Block,
365}
366
367pub struct CacheOptimizer {
369 config: CacheConfig,
371
372 stats: OptimizationStats,
374}
375
376impl CacheOptimizer {
377 pub fn new(config: CacheConfig) -> Self {
379 Self {
380 config,
381 stats: OptimizationStats::default(),
382 }
383 }
384
385 pub fn estimate_cache_metrics(&self, data_size_bytes: usize) -> CacheMetrics {
387 let mut metrics = CacheMetrics::new();
388
389 let cache_size_bytes = self.config.l1_size_kb * 1024;
391
392 if data_size_bytes <= cache_size_bytes {
393 metrics.l1_hits = 100;
395 metrics.l2_hits = 0;
396 metrics.l3_hits = 0;
397 metrics.misses = 10;
398 } else if data_size_bytes <= self.config.l2_size_kb * 1024 {
399 metrics.l1_hits = 50;
401 metrics.l2_hits = 40;
402 metrics.l3_hits = 0;
403 metrics.misses = 10;
404 } else {
405 metrics.l1_hits = 30;
407 metrics.l2_hits = 30;
408 metrics.l3_hits = 20;
409 metrics.misses = 20;
410 }
411
412 metrics.calculate_hit_rate();
413 metrics.calculate_avg_latency();
414
415 metrics
416 }
417
418 pub fn compute_tiling_params(
420 &self,
421 _matrix_size: (usize, usize),
422 element_size: usize,
423 ) -> TilingParams {
424 let target_cache_kb = self.config.l2_size_kb / 2; TilingParams::for_cache_size(target_cache_kb, element_size)
427 }
428
429 pub fn recommend_layout(&self, access_pattern: AccessPattern) -> DataLayout {
431 match access_pattern {
432 AccessPattern::Sequential => DataLayout::RowMajor,
433 AccessPattern::Strided => DataLayout::Blocked { block_size: 64 },
434 AccessPattern::Random => DataLayout::ZOrder,
435 AccessPattern::Block => DataLayout::Blocked { block_size: 128 },
436 }
437 }
438
439 pub fn stats(&self) -> &OptimizationStats {
441 &self.stats
442 }
443}
444
445#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
447pub struct OptimizationStats {
448 pub graphs_optimized: usize,
450
451 pub tiling_applied: usize,
453
454 pub layout_optimizations: usize,
456
457 pub prefetch_insertions: usize,
459
460 pub estimated_improvement_pct: f64,
462}
463
464impl std::fmt::Display for OptimizationStats {
465 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
466 writeln!(f, "Cache Optimization Statistics")?;
467 writeln!(f, "=============================")?;
468 writeln!(f, "Graphs optimized: {}", self.graphs_optimized)?;
469 writeln!(f, "Tiling applied: {}", self.tiling_applied)?;
470 writeln!(f, "Layout opts: {}", self.layout_optimizations)?;
471 writeln!(f, "Prefetch inserts: {}", self.prefetch_insertions)?;
472 writeln!(
473 f,
474 "Est. improvement: {:.1}%",
475 self.estimated_improvement_pct
476 )?;
477 Ok(())
478 }
479}
480
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an optimizer over the default cache configuration.
    fn default_optimizer() -> CacheOptimizer {
        CacheOptimizer::new(CacheConfig::default())
    }

    /// Builds tiling parameters targeting L2 with the given tile extents.
    fn l2_tiles(tile_i: usize, tile_j: usize, tile_k: usize) -> TilingParams {
        TilingParams {
            tile_i,
            tile_j,
            tile_k,
            target_level: CacheLevel::L2,
        }
    }

    #[test]
    fn test_cache_level_sizes() {
        let expected = [
            (CacheLevel::L1, 32),
            (CacheLevel::L2, 256),
            (CacheLevel::L3, 8192),
        ];
        for (level, size_kb) in expected {
            assert_eq!(level.typical_size_kb(), size_kb);
        }
    }

    #[test]
    fn test_cache_level_latency() {
        let expected = [
            (CacheLevel::L1, 4),
            (CacheLevel::L2, 12),
            (CacheLevel::L3, 40),
        ];
        for (level, cycles) in expected {
            assert_eq!(level.typical_latency_cycles(), cycles);
        }
    }

    #[test]
    fn test_cache_config_default() {
        let CacheConfig {
            l1_size_kb,
            l2_size_kb,
            cache_line_size,
            ..
        } = CacheConfig::default();

        assert_eq!(l1_size_kb, 32);
        assert_eq!(l2_size_kb, 256);
        assert_eq!(cache_line_size, 64);
    }

    #[test]
    fn test_cache_config_builders() {
        let base = CacheConfig::default();
        let config = base
            .with_l1_size(64)
            .with_l2_size(512)
            .with_tiling_enabled(true)
            .with_prefetch_distance(16);

        assert_eq!(config.l1_size_kb, 64);
        assert_eq!(config.l2_size_kb, 512);
        assert!(config.enable_tiling);
        assert_eq!(config.prefetch_distance, 16);
    }

    #[test]
    fn test_cache_config_total_size() {
        assert_eq!(CacheConfig::default().total_size_kb(), 32 + 256 + 8192);
    }

    #[test]
    fn test_tiling_params_for_cache_size() {
        let TilingParams {
            tile_i,
            tile_j,
            tile_k,
            ..
        } = TilingParams::for_cache_size(256, 8);

        assert!(tile_i > 0);
        assert!(tile_j > 0);
        assert!(tile_k > 0);
    }

    #[test]
    fn test_tiling_params_validate() {
        assert!(l2_tiles(64, 64, 64).validate().is_ok());
        assert!(l2_tiles(0, 64, 64).validate().is_err());
    }

    #[test]
    fn test_cache_metrics_calculate_hit_rate() {
        let mut metrics = CacheMetrics {
            l1_hits: 70,
            l2_hits: 20,
            l3_hits: 5,
            misses: 5,
            ..CacheMetrics::new()
        };

        metrics.calculate_hit_rate();

        assert_eq!(metrics.total_accesses, 100);
        assert!((metrics.hit_rate - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_cache_metrics_calculate_latency() {
        // Only L1 hits, so the average must equal the L1 latency.
        let mut metrics = CacheMetrics {
            l1_hits: 100,
            l2_hits: 0,
            l3_hits: 0,
            misses: 0,
            total_accesses: 100,
            ..CacheMetrics::new()
        };

        metrics.calculate_avg_latency();

        assert_eq!(metrics.avg_latency_cycles, 4.0);
    }

    #[test]
    fn test_cache_metrics_estimate_bandwidth() {
        let mut metrics = CacheMetrics::new();

        // 1024^3 bytes in one second.
        metrics.estimate_bandwidth(1024 * 1024 * 1024, 1.0);

        assert!((metrics.memory_bandwidth_gbs - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_metrics_display() {
        let mut metrics = CacheMetrics {
            l1_hits: 70,
            l2_hits: 20,
            misses: 10,
            ..CacheMetrics::new()
        };
        metrics.calculate_hit_rate();

        let display = metrics.to_string();

        assert!(display.contains("Hit rate:"));
        assert!(display.contains("L1 hits:"));
    }

    #[test]
    fn test_data_layout_cache_efficiency() {
        let layout = DataLayout::RowMajor;

        assert_eq!(layout.cache_efficiency(AccessPattern::Sequential), 1.0);
        assert_eq!(layout.cache_efficiency(AccessPattern::Strided), 0.5);
    }

    #[test]
    fn test_cache_optimizer_creation() {
        let optimizer = default_optimizer();

        assert_eq!(optimizer.stats().graphs_optimized, 0);
    }

    #[test]
    fn test_cache_optimizer_estimate_metrics() {
        // 16 KB of data.
        let metrics = default_optimizer().estimate_cache_metrics(16 * 1024);

        assert!(metrics.hit_rate > 0.0);
    }

    #[test]
    fn test_cache_optimizer_compute_tiling() {
        let params = default_optimizer().compute_tiling_params((1000, 1000), 8);

        assert!(params.tile_i > 0);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_cache_optimizer_recommend_layout() {
        let optimizer = default_optimizer();

        assert_eq!(
            optimizer.recommend_layout(AccessPattern::Sequential),
            DataLayout::RowMajor
        );
        assert_eq!(
            optimizer.recommend_layout(AccessPattern::Random),
            DataLayout::ZOrder
        );
    }

    #[test]
    fn test_optimization_stats_display() {
        let stats = OptimizationStats {
            graphs_optimized: 10,
            tiling_applied: 5,
            estimated_improvement_pct: 25.0,
            ..OptimizationStats::default()
        };

        let display = stats.to_string();

        assert!(display.contains("Graphs optimized: 10"));
        assert!(display.contains("25.0%"));
    }
}