use crate::Result;
use scirs2_core::profiling::Profiler;
use std::sync::Arc;

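/// Cache-friendly matrix multiplication that blocks work to fit the L1/L2/L3
/// caches of the host CPU. Block sizes are derived from the cache sizes passed
/// to [`CacheFriendlyMatMul::new`].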
#[allow(dead_code)]
pub struct CacheFriendlyMatMul {
    l1_cache_size: usize,
    l2_cache_size: usize,
    l3_cache_size: usize,
    block_sizes: CacheBlockSizes,
    profiler: Arc<Profiler>,
}

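/// Tensor operations tuned from observed memory-access behaviour: access-pattern
/// analysis, cache warming, and prefetch configuration.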
#[allow(dead_code)]
pub struct CacheOptimizedTensorOps {
    access_pattern_analyzer: MemoryAccessPatternAnalyzer,
    cache_warming_strategy: CacheWarmingStrategy,
    prefetch_config: PrefetchConfiguration,
}

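/// Blocking factors (M, N, K tile extents) for each cache level.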
#[derive(Debug, Clone)]
pub struct CacheBlockSizes {
    pub l1_block_m: usize,
    pub l1_block_n: usize,
    pub l1_block_k: usize,
    pub l2_block_m: usize,
    pub l2_block_n: usize,
    pub l2_block_k: usize,
    pub l3_block_m: usize,
    pub l3_block_n: usize,
    pub l3_block_k: usize,
}

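/// Summary of how an operation touches memory (all ratios in `[0.0, 1.0]`).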
#[derive(Debug, Clone)]
pub struct MemoryAccessPattern {
    pub sequential_ratio: f64,
    pub stride_patterns: Vec<StridePattern>,
    pub cache_line_utilization: f64,
    pub bandwidth_saturation: f64,
    pub prefetch_efficiency: f64,
}

#[derive(Debug, Clone)]
pub struct StridePattern {
    pub stride_size: usize,
    pub frequency: f64,
    pub cache_efficiency: f64,
}

#[allow(dead_code)]
struct MemoryAccessPatternAnalyzer {
    access_history: Vec<MemoryAccess>,
    pattern_cache: std::collections::HashMap<String, MemoryAccessPattern>,
}

#[derive(Debug, Clone)]
#[allow(dead_code)]
struct MemoryAccess {
    address: usize,
    size: usize,
    timestamp: std::time::Instant,
    access_type: MemoryAccessType,
}

#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
enum MemoryAccessType {
    Read,
    Write,
    ReadWrite,
}

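/// Strategy for pre-touching data so it is resident in cache before a hot loop runs.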
#[derive(Debug, Clone)]
pub struct CacheWarmingStrategy {
    pub enable_adaptive_warming: bool,
    pub warmup_patterns: Vec<WarmupPattern>,
    pub effectiveness_threshold: f64,
}

#[derive(Debug, Clone)]
pub struct WarmupPattern {
    pub data_size: usize,
    pub access_pattern: Vec<usize>,
    pub expected_improvement: f64,
}

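/// Software/hardware prefetch settings applied to tensor kernels.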
#[derive(Debug, Clone)]
pub struct PrefetchConfiguration {
    pub enable_hardware_prefetch: bool,
    pub prefetch_distance: usize,
    pub prefetch_locality: PrefetchLocality,
    pub enable_adaptive_prefetch: bool,
}

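/// Temporal-locality hint for prefetches, from non-temporal (data used once) to
/// high-temporal (data expected to be reused soon).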
#[derive(Debug, Clone, Copy)]
pub enum PrefetchLocality {
    NonTemporal,
    LowTemporal,
    ModerateTemporal,
    HighTemporal,
}

impl CacheFriendlyMatMul {
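    /// Creates a multiplier tuned for the given L1/L2/L3 data-cache sizes (in bytes).
    ///
    /// Illustrative call with typical cache sizes (mirrors the unit tests below);
    /// the enclosing crate path is omitted, so the snippet is not compiled:
    ///
    /// ```ignore
    /// let matmul = CacheFriendlyMatMul::new(32 * 1024, 256 * 1024, 8 * 1024 * 1024);
    /// ```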
    pub fn new(l1_size: usize, l2_size: usize, l3_size: usize) -> Self {
        let block_sizes = Self::calculate_optimal_block_sizes(l1_size, l2_size, l3_size);
        let profiler = Arc::new(Profiler::new());

        Self {
            l1_cache_size: l1_size,
            l2_cache_size: l2_size,
            l3_cache_size: l3_size,
            block_sizes,
            profiler,
        }
    }

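    /// Derives square M/N/K tile extents per cache level under the rough model
    /// that three `f32` tiles (an A, B, and C block) should fit in each cache,
    /// then clamps the result to a practical range for that level.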
    fn calculate_optimal_block_sizes(
        l1_size: usize,
        l2_size: usize,
        l3_size: usize,
    ) -> CacheBlockSizes {
        let element_size = std::mem::size_of::<f32>();

        let l1_elements_per_block = (l1_size / 3) / element_size;
        let l1_block_size = (l1_elements_per_block as f64).sqrt() as usize;
        let l1_block_size = l1_block_size.clamp(8, 64);

        let l2_elements_per_block = (l2_size / 3) / element_size;
        let l2_block_size = (l2_elements_per_block as f64).sqrt() as usize;
        let l2_block_size = l2_block_size.clamp(64, 512);

        let l3_elements_per_block = (l3_size / 3) / element_size;
        let l3_block_size = (l3_elements_per_block as f64).sqrt() as usize;
        let l3_block_size = l3_block_size.clamp(256, 2048);

        CacheBlockSizes {
            l1_block_m: l1_block_size,
            l1_block_n: l1_block_size,
            l1_block_k: l1_block_size,
            l2_block_m: l2_block_size,
            l2_block_n: l2_block_size,
            l2_block_k: l2_block_size,
            l3_block_m: l3_block_size,
            l3_block_n: l3_block_size,
            l3_block_k: l3_block_size,
        }
    }

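    /// Cache-oblivious GEMM: computes `C += A * B` for row-major `A` (`m x k`),
    /// `B` (`k x n`), and `C` (`m x n`) by recursively halving the largest
    /// dimension until a tile fits within the L1 block size.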
    pub fn cache_oblivious_matmul(
        &self,
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<()> {
        // A has `k` columns while B and C have `n` columns; pass the strides accordingly.
        self.recursive_matmul(a, b, c, m, k, n, 0, 0, 0, 0, 0, 0, m, n, k)
    }

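    /// Recursive splitter behind `cache_oblivious_matmul`: once the current
    /// (m, n, k) tile fits the L1 block it dispatches to the micro-kernel,
    /// otherwise it halves the largest of the three dimensions and recurses.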
    #[allow(clippy::too_many_arguments)]
    fn recursive_matmul(
        &self,
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        _a_rows: usize,
        a_cols: usize,
        b_cols: usize,
        a_row_offset: usize,
        a_col_offset: usize,
        b_row_offset: usize,
        b_col_offset: usize,
        c_row_offset: usize,
        c_col_offset: usize,
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<()> {
        if m <= self.block_sizes.l1_block_m
            && n <= self.block_sizes.l1_block_n
            && k <= self.block_sizes.l1_block_k
        {
            let offsets = MatrixOffsets {
                a_row_offset,
                a_col_offset,
                b_row_offset,
                b_col_offset,
                c_row_offset,
                c_col_offset,
            };
            let dimensions = MatrixDimensions { m, n, k };
            let strides = MatrixStrides {
                a_stride: a_cols,
                b_stride: b_cols,
                c_stride: b_cols,
            };
            return self.micro_kernel_matmul(a, b, c, offsets, dimensions, strides);
        }

        if m >= n && m >= k {
            let m1 = m / 2;
            let m2 = m - m1;

            self.recursive_matmul(
                a,
                b,
                c,
                _a_rows,
                a_cols,
                b_cols,
                a_row_offset,
                a_col_offset,
                b_row_offset,
                b_col_offset,
                c_row_offset,
                c_col_offset,
                m1,
                n,
                k,
            )?;

            self.recursive_matmul(
                a,
                b,
                c,
                _a_rows,
                a_cols,
                b_cols,
                a_row_offset + m1,
                a_col_offset,
                b_row_offset,
                b_col_offset,
                c_row_offset + m1,
                c_col_offset,
                m2,
                n,
                k,
            )?;
        } else if n >= k {
            let n1 = n / 2;
            let n2 = n - n1;

            self.recursive_matmul(
                a,
                b,
                c,
                _a_rows,
                a_cols,
                b_cols,
                a_row_offset,
                a_col_offset,
                b_row_offset,
                b_col_offset,
                c_row_offset,
                c_col_offset,
                m,
                n1,
                k,
            )?;

            self.recursive_matmul(
                a,
                b,
                c,
                _a_rows,
                a_cols,
                b_cols,
                a_row_offset,
                a_col_offset,
                b_row_offset,
                b_col_offset + n1,
                c_row_offset,
                c_col_offset + n1,
                m,
                n2,
                k,
            )?;
        } else {
            let k1 = k / 2;
            let k2 = k - k1;

            self.recursive_matmul(
                a,
                b,
                c,
                _a_rows,
                a_cols,
                b_cols,
                a_row_offset,
                a_col_offset,
                b_row_offset,
                b_col_offset,
                c_row_offset,
                c_col_offset,
                m,
                n,
                k1,
            )?;

            self.recursive_matmul(
                a,
                b,
                c,
                _a_rows,
                a_cols,
                b_cols,
                a_row_offset,
                a_col_offset + k1,
                b_row_offset + k1,
                b_col_offset,
                c_row_offset,
                c_col_offset,
                m,
                n,
                k2,
            )?;
        }

        Ok(())
    }

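    /// Scalar inner kernel for a single L1-resident tile: accumulates
    /// `C[i, j] += sum_l A[i, l] * B[l, j]` using the supplied offsets and strides.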
    fn micro_kernel_matmul(
        &self,
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        offsets: MatrixOffsets,
        dimensions: MatrixDimensions,
        strides: MatrixStrides,
    ) -> Result<()> {
        for i in 0..dimensions.m {
            for j in 0..dimensions.n {
                let mut sum = 0.0;

                // Best-effort software prefetch of the next B column (x86_64 with
                // SSE only); purely a hint and has no effect on the result.
                #[cfg(all(target_arch = "x86_64", target_feature = "sse"))]
                unsafe {
                    if j + 1 < dimensions.n {
                        std::arch::x86_64::_mm_prefetch(
                            &b[(offsets.b_row_offset) * strides.b_stride
                                + offsets.b_col_offset
                                + j
                                + 1] as *const f32 as *const i8,
                            std::arch::x86_64::_MM_HINT_T0,
                        );
                    }
                }

                for l in 0..dimensions.k {
                    let a_idx =
                        (offsets.a_row_offset + i) * strides.a_stride + offsets.a_col_offset + l;
                    let b_idx =
                        (offsets.b_row_offset + l) * strides.b_stride + offsets.b_col_offset + j;
                    sum += a[a_idx] * b[b_idx];
                }

                let c_idx =
                    (offsets.c_row_offset + i) * strides.c_stride + offsets.c_col_offset + j;
                c[c_idx] += sum;
            }
        }

        Ok(())
    }

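    /// Three-level blocked GEMM (`C += A * B`): explicit L3 -> L2 -> L1 tiling
    /// for row-major `A` (`m x k`), `B` (`k x n`), and `C` (`m x n`).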
    pub fn hierarchical_blocked_matmul(
        &self,
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        n: usize,
        k: usize,
    ) -> Result<()> {
        for i3 in (0..m).step_by(self.block_sizes.l3_block_m) {
            for j3 in (0..n).step_by(self.block_sizes.l3_block_n) {
                for k3 in (0..k).step_by(self.block_sizes.l3_block_k) {
                    let m3 = (self.block_sizes.l3_block_m).min(m - i3);
                    let n3 = (self.block_sizes.l3_block_n).min(n - j3);
                    // Keep the tile extent distinct from the `k3` offset so the
                    // L1 block offsets below stay correct.
                    let k3_size = (self.block_sizes.l3_block_k).min(k - k3);

                    for i2 in (0..m3).step_by(self.block_sizes.l2_block_m) {
                        for j2 in (0..n3).step_by(self.block_sizes.l2_block_n) {
                            for k2_offset in (0..k3_size).step_by(self.block_sizes.l2_block_k) {
                                let m2 = (self.block_sizes.l2_block_m).min(m3 - i2);
                                let n2 = (self.block_sizes.l2_block_n).min(n3 - j2);
                                let k2 = (self.block_sizes.l2_block_k).min(k3_size - k2_offset);

                                let l1_offsets = L1BlockOffsets {
                                    i_offset: i3 + i2,
                                    j_offset: j3 + j2,
                                    k_offset: k3 + k2_offset,
                                };
                                let l1_dimensions = MatrixDimensions {
                                    m: m2,
                                    n: n2,
                                    k: k2,
                                };
                                let l1_strides = MatrixStrides {
                                    a_stride: k,
                                    b_stride: n,
                                    c_stride: n,
                                };
                                self.l1_blocked_micro_kernel(
                                    a,
                                    b,
                                    c,
                                    l1_offsets,
                                    l1_dimensions,
                                    l1_strides,
                                )?;
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

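    /// Sweeps an L2-resident block in L1-sized sub-tiles and hands each sub-tile
    /// to the scalar micro-kernel.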
    fn l1_blocked_micro_kernel(
        &self,
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        offsets: L1BlockOffsets,
        dimensions: MatrixDimensions,
        strides: MatrixStrides,
    ) -> Result<()> {
        for i1 in (0..dimensions.m).step_by(self.block_sizes.l1_block_m) {
            for j1 in (0..dimensions.n).step_by(self.block_sizes.l1_block_n) {
                for k1 in (0..dimensions.k).step_by(self.block_sizes.l1_block_k) {
                    let m1 = (self.block_sizes.l1_block_m).min(dimensions.m - i1);
                    let n1 = (self.block_sizes.l1_block_n).min(dimensions.n - j1);
                    // As above, keep the tile extent distinct from the `k1` offset.
                    let k1_size = (self.block_sizes.l1_block_k).min(dimensions.k - k1);

                    let micro_offsets = MatrixOffsets {
                        a_row_offset: offsets.i_offset + i1,
                        a_col_offset: offsets.k_offset + k1,
                        b_row_offset: offsets.k_offset + k1,
                        b_col_offset: offsets.j_offset + j1,
                        c_row_offset: offsets.i_offset + i1,
                        c_col_offset: offsets.j_offset + j1,
                    };
                    let micro_dimensions = MatrixDimensions {
                        m: m1,
                        n: n1,
                        k: k1_size,
                    };
                    self.micro_kernel_matmul(a, b, c, micro_offsets, micro_dimensions, strides)?;
                }
            }
        }

        Ok(())
    }

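    /// Returns cache-efficiency metrics. NOTE: the current implementation reports
    /// fixed, representative estimates rather than measured counters.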
    pub fn get_cache_efficiency_metrics(&self) -> CacheEfficiencyMetrics {
        CacheEfficiencyMetrics {
            l1_hit_rate: 0.95,
            l2_hit_rate: 0.85,
            l3_hit_rate: 0.75,
            memory_bandwidth_utilization: 0.9,
            cache_line_utilization: 0.8,
            prefetch_accuracy: 0.7,
        }
    }
}

#[derive(Debug, Clone)]
pub struct CacheEfficiencyMetrics {
    pub l1_hit_rate: f64,
    pub l2_hit_rate: f64,
    pub l3_hit_rate: f64,
    pub memory_bandwidth_utilization: f64,
    pub cache_line_utilization: f64,
    pub prefetch_accuracy: f64,
}

impl Default for CacheOptimizedTensorOps {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Clone, Copy)]
#[allow(dead_code)]
struct MatrixOffsets {
    a_row_offset: usize,
    a_col_offset: usize,
    b_row_offset: usize,
    b_col_offset: usize,
    c_row_offset: usize,
    c_col_offset: usize,
}

#[derive(Clone, Copy)]
struct MatrixDimensions {
    m: usize,
    n: usize,
    k: usize,
}

#[derive(Clone, Copy)]
struct MatrixStrides {
    a_stride: usize,
    b_stride: usize,
    c_stride: usize,
}

#[derive(Clone, Copy)]
struct L1BlockOffsets {
    i_offset: usize,
    j_offset: usize,
    k_offset: usize,
}

impl CacheOptimizedTensorOps {
    pub fn new() -> Self {
        Self {
            access_pattern_analyzer: MemoryAccessPatternAnalyzer::new(),
            cache_warming_strategy: CacheWarmingStrategy::default(),
            prefetch_config: PrefetchConfiguration::default(),
        }
    }

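    /// Analyzes the memory-access pattern for the named operation. NOTE: the
    /// current implementation ignores its inputs and returns a representative
    /// fixed pattern rather than one derived from recorded accesses.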
    pub fn analyze_access_pattern(
        &mut self,
        _operation: &str,
        _data_sizes: &[usize],
    ) -> MemoryAccessPattern {
        MemoryAccessPattern {
            sequential_ratio: 0.8,
            stride_patterns: vec![
                StridePattern {
                    stride_size: 64,
                    frequency: 0.6,
                    cache_efficiency: 0.9,
                },
                StridePattern {
                    stride_size: 4096,
                    frequency: 0.3,
                    cache_efficiency: 0.6,
                },
            ],
            cache_line_utilization: 0.85,
            bandwidth_saturation: 0.7,
            prefetch_efficiency: 0.8,
        }
    }

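    /// Maps an observed access pattern to an optimization strategy using simple
    /// thresholds: blocking for low sequentiality, prefetching when prefetch
    /// efficiency is high, and cache warming when bandwidth is not yet saturated.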
    pub fn optimize_tensor_operation(
        &self,
        _operation: &str,
        access_pattern: &MemoryAccessPattern,
    ) -> OptimizationStrategy {
        OptimizationStrategy {
            use_blocking: access_pattern.sequential_ratio < 0.7,
            block_size: if access_pattern.cache_line_utilization > 0.8 {
                64
            } else {
                32
            },
            use_prefetching: access_pattern.prefetch_efficiency > 0.6,
            prefetch_distance: (access_pattern.prefetch_efficiency * 128.0) as usize,
            use_cache_warming: access_pattern.bandwidth_saturation < 0.8,
            parallelization_factor: if access_pattern.sequential_ratio > 0.8 {
                4
            } else {
                2
            },
        }
    }
}

#[derive(Debug, Clone)]
pub struct OptimizationStrategy {
    pub use_blocking: bool,
    pub block_size: usize,
    pub use_prefetching: bool,
    pub prefetch_distance: usize,
    pub use_cache_warming: bool,
    pub parallelization_factor: usize,
}

impl MemoryAccessPatternAnalyzer {
    fn new() -> Self {
        Self {
            access_history: Vec::new(),
            pattern_cache: std::collections::HashMap::new(),
        }
    }
}

impl Default for CacheWarmingStrategy {
    fn default() -> Self {
        Self {
            enable_adaptive_warming: true,
            warmup_patterns: vec![WarmupPattern {
                data_size: 1024,
                access_pattern: vec![0, 64, 128, 192, 256, 320, 384, 448, 512],
                expected_improvement: 0.15,
            }],
            effectiveness_threshold: 0.1,
        }
    }
}

impl Default for PrefetchConfiguration {
    fn default() -> Self {
        Self {
            enable_hardware_prefetch: true,
            prefetch_distance: 64,
            prefetch_locality: PrefetchLocality::ModerateTemporal,
            enable_adaptive_prefetch: true,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_friendly_matmul_creation() {
        let matmul = CacheFriendlyMatMul::new(32768, 262144, 8388608);
        assert!(matmul.block_sizes.l1_block_m > 0);
        assert!(matmul.block_sizes.l2_block_m > matmul.block_sizes.l1_block_m);
        assert!(matmul.block_sizes.l3_block_m > matmul.block_sizes.l2_block_m);
    }

    #[test]
    fn test_cache_oblivious_matmul() {
        let matmul = CacheFriendlyMatMul::new(32768, 262144, 8388608);

        let a = vec![1.0; 64];
        let b = vec![2.0; 64];
        let mut c = vec![0.0; 64];

        let result = matmul.cache_oblivious_matmul(&a, &b, &mut c, 8, 8, 8);
        assert!(result.is_ok());

        for value in &c {
            assert!(*value > 0.0);
        }
    }

    #[test]
    fn test_hierarchical_blocked_matmul() {
        let matmul = CacheFriendlyMatMul::new(32768, 262144, 8388608);

        let a = vec![1.0; 16];
        let b = vec![2.0; 16];
        let mut c = vec![0.0; 16];

        let result = matmul.hierarchical_blocked_matmul(&a, &b, &mut c, 4, 4, 4);
        assert!(result.is_ok());

        for value in &c {
            assert!(*value > 0.0);
        }
    }

    #[test]
    fn test_cache_optimized_tensor_ops() {
        let mut tensor_ops = CacheOptimizedTensorOps::new();

        let data_sizes = vec![1024, 2048, 4096];
        let pattern = tensor_ops.analyze_access_pattern("matmul", &data_sizes);

        assert!(pattern.sequential_ratio > 0.0);
        assert!(pattern.cache_line_utilization > 0.0);
        assert!(!pattern.stride_patterns.is_empty());

        let strategy = tensor_ops.optimize_tensor_operation("matmul", &pattern);
        assert!(strategy.block_size > 0);
        assert!(strategy.parallelization_factor > 0);
    }

    #[test]
    fn test_cache_efficiency_metrics() {
        let matmul = CacheFriendlyMatMul::new(32768, 262144, 8388608);
        let metrics = matmul.get_cache_efficiency_metrics();

        assert!(metrics.l1_hit_rate > 0.0);
        assert!(metrics.l2_hit_rate > 0.0);
        assert!(metrics.l3_hit_rate > 0.0);
        assert!(metrics.memory_bandwidth_utilization > 0.0);
    }
}