// scirs2_core/cache_ops.rs
1//! Cache-aware matrix operations for high-performance computing.
2//!
3//! This module provides cache-optimized implementations of common matrix
4//! operations, including tiled matrix multiplication and cache-oblivious
5//! transpose. The [`CacheAwareConfig`] struct exposes cache topology information
6//! and derives optimal blocking parameters so that working sets fit in the
7//! appropriate cache level.
8//!
9//! # Examples
10//!
11//! ```rust
12//! use scirs2_core::cache_ops::{CacheAwareConfig, tiled_matmul};
13//! use ndarray::Array2;
14//!
//! let _config = CacheAwareConfig::detect();
16//! let a = Array2::<f64>::eye(4);
17//! let b = Array2::<f64>::eye(4);
18//! let c = tiled_matmul(&a, &b);
19//! assert_eq!(c, a);
20//! ```
21
22use ndarray::Array2;
23
24// ──────────────────────────────────────────────────────────────────────────────
25// CacheAwareConfig
26// ──────────────────────────────────────────────────────────────────────────────
27
/// Cache topology description used to derive blocking parameters.
///
/// All sizes are in bytes.  The defaults (L1 = 32 KiB, L2 = 256 KiB,
/// L3 = 8 MiB) are representative of modern x86-64 server CPUs and are
/// used whenever hardware detection is not available.
///
/// Construct via [`CacheAwareConfig::new`] (fixed defaults) or
/// [`CacheAwareConfig::detect`] (best-effort hardware probe with the same
/// defaults as fallback).  `element_size` feeds the blocking math in
/// [`CacheAwareConfig::tile_size_for_matmul`] and
/// [`CacheAwareConfig::block_size_for_scan`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheAwareConfig {
    /// L1 data-cache size in bytes (default 32 KiB)
    pub l1_cache_size: usize,
    /// L2 unified-cache size in bytes (default 256 KiB)
    pub l2_cache_size: usize,
    /// L3 shared-cache size in bytes (default 8 MiB)
    pub l3_cache_size: usize,
    /// Size of a single element in bytes (default 8 for `f64`)
    pub element_size: usize,
}
44
45impl Default for CacheAwareConfig {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl CacheAwareConfig {
52    /// Construct with well-known default cache sizes and `element_size = 8`.
53    pub fn new() -> Self {
54        Self {
55            l1_cache_size: 32 * 1024,       // 32 KiB
56            l2_cache_size: 256 * 1024,      // 256 KiB
57            l3_cache_size: 8 * 1024 * 1024, // 8 MiB
58            element_size: 8,                // f64
59        }
60    }
61
62    /// Attempt to detect L1/L2/L3 sizes from the host hardware.
63    ///
64    /// On Linux the kernel exposes per-cache information under
65    /// `/sys/devices/system/cpu/cpu0/cache/index*/size`.  On macOS
66    /// the same data is available through `sysctl`.  If detection
67    /// fails for any reason the function silently returns the same
68    /// defaults as [`CacheAwareConfig::new`].
69    pub fn detect() -> Self {
70        let defaults = Self::new();
71
72        #[cfg(target_os = "linux")]
73        {
74            if let Some(cfg) = detect_linux() {
75                return cfg;
76            }
77        }
78
79        #[cfg(target_os = "macos")]
80        {
81            if let Some(cfg) = detect_macos() {
82                return cfg;
83            }
84        }
85
86        defaults
87    }
88
89    /// Compute the optimal square tile edge length for a matrix-multiply
90    /// blocking scheme so that **three** tiles fit simultaneously in the L2
91    /// cache.
92    ///
93    /// `n` is the largest matrix dimension; the returned tile size is
94    /// clamped to `[4, n]`.
95    pub fn tile_size_for_matmul(&self, n: usize) -> usize {
96        // We want: 3 * tile^2 * element_size <= l2_cache_size
97        // => tile <= sqrt(l2_cache_size / (3 * element_size))
98        let max_elements = self.l2_cache_size / (3 * self.element_size.max(1));
99        let tile = (max_elements as f64).sqrt() as usize;
100        tile.clamp(4, n.max(4))
101    }
102
103    /// Compute the block size for a sequential scan so that one block fits
104    /// comfortably in the L1 data cache.
105    pub fn block_size_for_scan(&self) -> usize {
106        (self.l1_cache_size / self.element_size.max(1)).max(1)
107    }
108}
109
110// ──────────────────────────────────────────────────────────────────────────────
111// Hardware detection helpers (platform-specific)
112// ──────────────────────────────────────────────────────────────────────────────
113
/// Parse a cache size string like `"32K"`, `"8192K"` or `"8M"` into bytes.
///
/// Linux sysfs reports sizes with an uppercase `K` suffix; this parser is
/// deliberately a little more lenient and accepts an optional `K`/`M`/`G`
/// suffix in either case, or a bare byte count.  Multiplication is checked
/// so absurd inputs return `None` instead of overflowing (which would
/// panic in debug builds).
///
/// Returns `None` for empty or otherwise unparsable input.
fn parse_sysfs_size(s: &str) -> Option<usize> {
    let s = s.trim();
    if s.is_empty() {
        return None;
    }
    // The last byte decides the multiplier. A multi-byte UTF-8 tail falls
    // into the `_` arm and then fails the numeric parse, yielding None.
    let (digits, multiplier) = match s.as_bytes()[s.len() - 1].to_ascii_uppercase() {
        b'K' => (&s[..s.len() - 1], 1024usize),
        b'M' => (&s[..s.len() - 1], 1024 * 1024),
        b'G' => (&s[..s.len() - 1], 1024 * 1024 * 1024),
        _ => (s, 1),
    };
    digits
        .trim()
        .parse::<usize>()
        .ok()
        .and_then(|v| v.checked_mul(multiplier))
}
129
/// Best-effort cache-size detection via Linux sysfs.
///
/// Reads `/sys/devices/system/cpu/cpu0/cache/index*/{level,size,type}` and
/// returns `None` only when no cache level at all could be read; otherwise
/// missing levels are filled from [`CacheAwareConfig::new`] defaults.
#[cfg(target_os = "linux")]
fn detect_linux() -> Option<CacheAwareConfig> {
    use std::fs;

    // cpu0 is assumed representative of all cores.
    let base = "/sys/devices/system/cpu/cpu0/cache";
    let mut l1: Option<usize> = None;
    let mut l2: Option<usize> = None;
    let mut l3: Option<usize> = None;

    // Index directories are numbered contiguously; 8 is a generous upper
    // bound (typical CPUs expose 3-4 indices).
    for idx in 0..8usize {
        let level_path = format!("{base}/index{idx}/level");
        let size_path = format!("{base}/index{idx}/size");
        let type_path = format!("{base}/index{idx}/type");

        // A missing `level` file means we ran past the last index: stop.
        let level_str = match fs::read_to_string(&level_path) {
            Ok(s) => s,
            Err(_) => break,
        };
        let level: usize = match level_str.trim().parse() {
            Ok(v) => v,
            Err(_) => continue,
        };
        // Unreadable or unparsable entries are skipped, not fatal.
        let size_str = match fs::read_to_string(&size_path) {
            Ok(s) => s,
            Err(_) => continue,
        };
        let size = match parse_sysfs_size(&size_str) {
            Some(s) => s,
            None => continue,
        };
        // Skip instruction caches for L1.
        let cache_type = fs::read_to_string(&type_path).unwrap_or_default();
        let cache_type = cache_type.trim();
        if level == 1 && cache_type == "Instruction" {
            continue;
        }

        // Later indices of the same level overwrite earlier ones.
        match level {
            1 => l1 = Some(size),
            2 => l2 = Some(size),
            3 => l3 = Some(size),
            _ => {}
        }
    }

    // Detection counts as failed only if *nothing* was found.
    if l1.is_none() && l2.is_none() && l3.is_none() {
        return None;
    }

    // Fill any missing level from the built-in defaults.
    let defaults = CacheAwareConfig::new();
    Some(CacheAwareConfig {
        l1_cache_size: l1.unwrap_or(defaults.l1_cache_size),
        l2_cache_size: l2.unwrap_or(defaults.l2_cache_size),
        l3_cache_size: l3.unwrap_or(defaults.l3_cache_size),
        element_size: defaults.element_size,
    })
}
188
/// Best-effort cache-size detection on macOS via the `sysctl` CLI.
///
/// Spawns one `sysctl -n <key>` process per cache level; any key that is
/// missing or non-numeric simply yields `None` for that level.
///
/// NOTE(review): Apple Silicon machines may not report `hw.l3cachesize`
/// (or may report 0) — confirm behaviour there; 0 would propagate into
/// `l3_cache_size` as-is.
#[cfg(target_os = "macos")]
fn detect_macos() -> Option<CacheAwareConfig> {
    // Query a single numeric sysctl value, None on any failure.
    fn sysctl_usize(name: &str) -> Option<usize> {
        let out = std::process::Command::new("sysctl")
            .arg("-n")
            .arg(name)
            .output()
            .ok()?;
        let s = std::str::from_utf8(&out.stdout).ok()?.trim();
        s.parse::<usize>().ok()
    }

    let l1 = sysctl_usize("hw.l1dcachesize");
    let l2 = sysctl_usize("hw.l2cachesize");
    let l3 = sysctl_usize("hw.l3cachesize");

    // Detection counts as failed only if no level could be read.
    if l1.is_none() && l2.is_none() && l3.is_none() {
        return None;
    }

    // Fill any missing level from the built-in defaults.
    let defaults = CacheAwareConfig::new();
    Some(CacheAwareConfig {
        l1_cache_size: l1.unwrap_or(defaults.l1_cache_size),
        l2_cache_size: l2.unwrap_or(defaults.l2_cache_size),
        l3_cache_size: l3.unwrap_or(defaults.l3_cache_size),
        element_size: defaults.element_size,
    })
}
217
218// ──────────────────────────────────────────────────────────────────────────────
219// Cache-oblivious transpose
220// ──────────────────────────────────────────────────────────────────────────────
221
222/// In-place cache-oblivious transpose of a **square** `Array2<f64>`.
223///
224/// For non-square matrices the function falls back to `a.t().to_owned()`.
225///
226/// The recursive divide-and-conquer decomposition achieves optimal cache
227/// performance without knowing the actual cache size at compile time.
228pub fn cache_oblivious_transpose(a: &mut Array2<f64>) {
229    let (rows, cols) = a.dim();
230    if rows != cols {
231        // Non-square: replace with transpose clone.
232        let transposed = a.t().to_owned();
233        *a = transposed;
234        return;
235    }
236    let n = rows;
237    // Work on a raw slice; safe because we have exclusive access via &mut.
238    let ptr = a.as_mut_ptr();
239    // SAFETY: Array2 with standard layout gives a contiguous row-major buffer.
240    let slice = unsafe { std::slice::from_raw_parts_mut(ptr, n * n) };
241    recursive_transpose(slice, 0, n, 0, n, n);
242}
243
/// Recursive helper: transpose the region `[row_start..row_end) × [col_start..col_end)`
/// of the flat row-major buffer `buf`, whose row stride is `stride` (= total columns).
///
/// Only strictly-upper-triangle pairs (`j > i`) are swapped, so every
/// element pair is exchanged exactly once across the whole recursion.
fn recursive_transpose(
    buf: &mut [f64],
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    stride: usize,
) {
    /// Subblocks no larger than BASE×BASE are handled directly.
    const BASE: usize = 32;
    let height = row_end - row_start;
    let width = col_end - col_start;

    if height > BASE || width > BASE {
        // Halve the longer side and recurse on both pieces; this keeps the
        // decomposition roughly square, which is what makes it cache-oblivious.
        if height >= width {
            let split = row_start + height / 2;
            recursive_transpose(buf, row_start, split, col_start, col_end, stride);
            recursive_transpose(buf, split, row_end, col_start, col_end, stride);
        } else {
            let split = col_start + width / 2;
            recursive_transpose(buf, row_start, row_end, col_start, split, stride);
            recursive_transpose(buf, row_start, row_end, split, col_end, stride);
        }
        return;
    }

    // Base case: swap the subblock's elements that lie strictly above the
    // (global) diagonal with their mirror images.
    for i in row_start..row_end {
        for j in col_start.max(i + 1)..col_end {
            buf.swap(i * stride + j, j * stride + i);
        }
    }
}
280
281// ──────────────────────────────────────────────────────────────────────────────
282// Tiled matrix multiply
283// ──────────────────────────────────────────────────────────────────────────────
284
285/// Cache-efficient tiled matrix multiplication `C = A × B`.
286///
287/// Tile size is derived from [`CacheAwareConfig::detect`] so that three
288/// tiles fit in the L2 cache simultaneously, maximising reuse.
289///
290/// # Panics
291///
292/// Panics if the inner dimensions do not match (`a.ncols() != b.nrows()`).
293pub fn tiled_matmul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
294    let (m, k) = a.dim();
295    let (kb, n) = b.dim();
296    assert_eq!(
297        k, kb,
298        "tiled_matmul: inner dimensions must match ({k} vs {kb})"
299    );
300
301    let config = CacheAwareConfig::detect();
302    let tile = config.tile_size_for_matmul(m.max(n).max(k));
303
304    let mut c = Array2::<f64>::zeros((m, n));
305
306    // Blocked i-k-j loop for cache reuse of B tiles.
307    let mut ii = 0;
308    while ii < m {
309        let i_end = (ii + tile).min(m);
310        let mut kk = 0;
311        while kk < k {
312            let k_end = (kk + tile).min(k);
313            let mut jj = 0;
314            while jj < n {
315                let j_end = (jj + tile).min(n);
316                // Micro-kernel: accumulate into the C tile.
317                for i in ii..i_end {
318                    for kp in kk..k_end {
319                        let a_ik = a[[i, kp]];
320                        for j in jj..j_end {
321                            c[[i, j]] += a_ik * b[[kp, j]];
322                        }
323                    }
324                }
325                jj += tile;
326            }
327            kk += tile;
328        }
329        ii += tile;
330    }
331
332    c
333}
334
335// ──────────────────────────────────────────────────────────────────────────────
336// Prefetch-hinted matrix multiply
337// ──────────────────────────────────────────────────────────────────────────────
338
339/// Matrix multiplication with software prefetch hints for pipelined execution.
340///
341/// On `x86_64` the implementation inserts `_mm_prefetch` intrinsics to pull
342/// the next tile of `B` into L2 cache before it is needed.  On other
343/// architectures this falls back to the same tiled algorithm as
344/// [`tiled_matmul`].
345///
346/// # Panics
347///
348/// Panics if the inner dimensions do not match.
349pub fn prefetch_matmul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
350    let (m, k) = a.dim();
351    let (kb, n) = b.dim();
352    assert_eq!(
353        k, kb,
354        "prefetch_matmul: inner dimensions must match ({k} vs {kb})"
355    );
356
357    let config = CacheAwareConfig::detect();
358    let tile = config.tile_size_for_matmul(m.max(n).max(k));
359
360    let mut c = Array2::<f64>::zeros((m, n));
361
362    let mut ii = 0;
363    while ii < m {
364        let i_end = (ii + tile).min(m);
365        let mut kk = 0;
366        while kk < k {
367            let k_end = (kk + tile).min(k);
368            let mut jj = 0;
369            while jj < n {
370                let j_end = (jj + tile).min(n);
371
372                // Issue prefetch for the *next* B tile.
373                let next_jj = jj + tile;
374                if next_jj < n {
375                    let next_j_end = (next_jj + tile).min(n);
376                    prefetch_b_tile(b, kk, k_end, next_jj, next_j_end);
377                }
378
379                for i in ii..i_end {
380                    for kp in kk..k_end {
381                        let a_ik = a[[i, kp]];
382                        for j in jj..j_end {
383                            c[[i, j]] += a_ik * b[[kp, j]];
384                        }
385                    }
386                }
387                jj += tile;
388            }
389            kk += tile;
390        }
391        ii += tile;
392    }
393
394    c
395}
396
397/// Issue cache prefetch hints for a tile of `b`.
398#[inline]
399fn prefetch_b_tile(b: &Array2<f64>, k_start: usize, k_end: usize, j_start: usize, j_end: usize) {
400    // Stride between contiguous prefetch hints (one cache line = 64 bytes = 8 f64s).
401    const STRIDE: usize = 8;
402
403    for kp in k_start..k_end {
404        let mut j = j_start;
405        while j < j_end {
406            // Obtain a raw pointer to b[[kp, j]] and issue the prefetch.
407            let ptr: *const f64 = &b[[kp, j]];
408            #[cfg(target_arch = "x86_64")]
409            {
410                // SAFETY: _mm_prefetch only reads the cache line; it never
411                // dereferences the pointer beyond a speculative load.
412                unsafe {
413                    std::arch::x86_64::_mm_prefetch(
414                        ptr as *const i8,
415                        std::arch::x86_64::_MM_HINT_T1, // L2 cache
416                    );
417                }
418            }
419            #[cfg(not(target_arch = "x86_64"))]
420            {
421                // On other architectures use a harmless identity hint.
422                let _ = std::hint::black_box(ptr);
423            }
424            j += STRIDE;
425        }
426    }
427}
428
429// ──────────────────────────────────────────────────────────────────────────────
430// Tests
431// ──────────────────────────────────────────────────────────────────────────────
432
#[cfg(test)]
mod tests {
    //! Unit tests covering configuration defaults/detection, the
    //! cache-oblivious transpose, and both matmul variants.  Matmul
    //! results are checked against `ndarray`'s `dot` as the reference.

    use super::*;
    use ndarray::Array2;

    // ── CacheAwareConfig ──────────────────────────────────────────────────────

    #[test]
    fn test_config_defaults_are_reasonable() {
        let cfg = CacheAwareConfig::new();
        assert!(cfg.l1_cache_size >= 8 * 1024, "L1 should be at least 8 KiB");
        assert!(cfg.l2_cache_size > cfg.l1_cache_size, "L2 > L1");
        assert!(cfg.l3_cache_size > cfg.l2_cache_size, "L3 > L2");
        assert_eq!(cfg.element_size, 8);
    }

    #[test]
    fn test_config_detect_returns_nonzero_sizes() {
        // Whatever the host reports (or the defaults), no level may be zero.
        let cfg = CacheAwareConfig::detect();
        assert!(cfg.l1_cache_size > 0);
        assert!(cfg.l2_cache_size > 0);
        assert!(cfg.l3_cache_size > 0);
        assert!(cfg.element_size > 0);
    }

    #[test]
    fn test_tile_size_within_bounds_small() {
        let cfg = CacheAwareConfig::new();
        let n = 16;
        let tile = cfg.tile_size_for_matmul(n);
        assert!(tile >= 4, "tile_size >= 4");
        assert!(tile <= n, "tile_size <= n");
    }

    #[test]
    fn test_tile_size_within_bounds_large() {
        let cfg = CacheAwareConfig::new();
        for n in [64, 128, 512, 1024] {
            let tile = cfg.tile_size_for_matmul(n);
            assert!(tile >= 4);
            assert!(tile <= n);
        }
    }

    #[test]
    fn test_block_size_for_scan_is_positive() {
        let cfg = CacheAwareConfig::new();
        assert!(cfg.block_size_for_scan() > 0);
    }

    #[test]
    fn test_block_size_for_scan_fits_in_l1() {
        let cfg = CacheAwareConfig::new();
        let block = cfg.block_size_for_scan();
        // block * element_size should be <= l1_cache_size
        assert!(block * cfg.element_size <= cfg.l1_cache_size);
    }

    // ── cache_oblivious_transpose ─────────────────────────────────────────────

    #[test]
    fn test_cache_oblivious_transpose_4x4() {
        let mut a = Array2::<f64>::from_shape_vec((4, 4), (0..16).map(|x| x as f64).collect())
            .expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    #[test]
    fn test_cache_oblivious_transpose_8x8() {
        let data: Vec<f64> = (0..64).map(|x| x as f64).collect();
        let mut a = Array2::<f64>::from_shape_vec((8, 8), data).expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    #[test]
    fn test_cache_oblivious_transpose_involutory() {
        // Applying transpose twice should return the original matrix.
        let data: Vec<f64> = (0..64).map(|x| x as f64 * 0.5).collect();
        let mut a = Array2::<f64>::from_shape_vec((8, 8), data.clone()).expect("valid shape");
        let original = a.clone();
        cache_oblivious_transpose(&mut a);
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, original);
    }

    #[test]
    fn test_cache_oblivious_transpose_large() {
        // 64 x 64 exceeds the recursion's 32-element base case, so the
        // divide-and-conquer path is exercised.
        let n = 64;
        let data: Vec<f64> = (0..(n * n)).map(|x| x as f64).collect();
        let mut a = Array2::<f64>::from_shape_vec((n, n), data).expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    #[test]
    fn test_cache_oblivious_transpose_non_square_fallback() {
        let mut a = Array2::<f64>::from_shape_vec((3, 5), (0..15).map(|x| x as f64).collect())
            .expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    // ── tiled_matmul ──────────────────────────────────────────────────────────

    #[test]
    fn test_tiled_matmul_identity_4x4() {
        let a = Array2::<f64>::eye(4);
        let b = Array2::<f64>::eye(4);
        let c = tiled_matmul(&a, &b);
        assert_eq!(c, Array2::<f64>::eye(4));
    }

    #[test]
    fn test_tiled_matmul_known_result_2x2() {
        // [1 2] × [5 6]  =  [19 22]
        // [3 4]   [7 8]     [43 50]
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("ok");
        let b = Array2::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).expect("ok");
        let c = tiled_matmul(&a, &b);
        let expected = Array2::from_shape_vec((2, 2), vec![19.0, 22.0, 43.0, 50.0]).expect("ok");
        // Tight tolerance: small integers are exact in f64.
        for ((i, j), v) in c.indexed_iter() {
            assert!(
                (v - expected[[i, j]]).abs() < 1e-12,
                "mismatch at [{i},{j}]: {v} != {}",
                expected[[i, j]]
            );
        }
    }

    #[test]
    fn test_tiled_matmul_matches_naive_16x16() {
        use ndarray::Array2;
        let n = 16;
        let a = Array2::from_shape_fn((n, n), |(i, j)| (i * n + j) as f64 * 0.01);
        let b = Array2::from_shape_fn((n, n), |(i, j)| (i + j) as f64 * 0.01);
        let tiled = tiled_matmul(&a, &b);
        let naive = a.dot(&b);
        // Looser tolerance: summation order differs between the two paths.
        for ((i, j), v) in tiled.indexed_iter() {
            assert!(
                (v - naive[[i, j]]).abs() < 1e-9,
                "tiled vs naive mismatch at [{i},{j}]"
            );
        }
    }

    // ── prefetch_matmul ───────────────────────────────────────────────────────

    #[test]
    fn test_prefetch_matmul_matches_tiled_8x8() {
        let n = 8;
        let a = Array2::from_shape_fn((n, n), |(i, j)| (i * n + j) as f64);
        let b = Array2::from_shape_fn((n, n), |(i, j)| (i + j + 1) as f64);
        let tiled = tiled_matmul(&a, &b);
        let prefetched = prefetch_matmul(&a, &b);
        for ((i, j), v) in prefetched.indexed_iter() {
            assert!(
                (v - tiled[[i, j]]).abs() < 1e-9,
                "prefetch vs tiled mismatch at [{i},{j}]"
            );
        }
    }

    #[test]
    fn test_prefetch_matmul_correctness_64x64() {
        // Large enough to span several tiles, so prefetch hints are issued.
        let n = 64;
        let a = Array2::from_shape_fn((n, n), |(i, j)| ((i + 1) * (j + 1)) as f64 * 0.001);
        let b = Array2::from_shape_fn((n, n), |(i, j)| (i as f64 - j as f64).abs() * 0.001);
        let reference = a.dot(&b);
        let result = prefetch_matmul(&a, &b);
        for ((i, j), v) in result.indexed_iter() {
            assert!(
                (v - reference[[i, j]]).abs() < 1e-8,
                "prefetch_matmul wrong at [{i},{j}]"
            );
        }
    }

    #[test]
    fn test_prefetch_matmul_identity_8x8() {
        let eye = Array2::<f64>::eye(8);
        let a = Array2::from_shape_fn((8, 8), |(i, j)| (i * j) as f64 + 1.0);
        let result = prefetch_matmul(&a, &eye);
        for ((i, j), v) in result.indexed_iter() {
            assert!(
                (v - a[[i, j]]).abs() < 1e-12,
                "A×I should equal A at [{i},{j}]"
            );
        }
    }

    #[test]
    fn test_tiled_matmul_rect_2x3_times_3x4() {
        // Verify non-square multiplication shapes.
        let a = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("ok");
        let b = Array2::from_shape_vec(
            (3, 4),
            vec![
                7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
            ],
        )
        .expect("ok");
        let tiled = tiled_matmul(&a, &b);
        let naive = a.dot(&b);
        for ((i, j), v) in tiled.indexed_iter() {
            assert!(
                (v - naive[[i, j]]).abs() < 1e-9,
                "rect mismatch at [{i},{j}]"
            );
        }
    }
}