1use ndarray::Array2;
23
/// Cache-size parameters used by this module's kernels to choose tile and
/// block sizes. All cache sizes are in bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheAwareConfig {
    /// L1 data cache capacity in bytes.
    pub l1_cache_size: usize,
    /// L2 cache capacity in bytes.
    pub l2_cache_size: usize,
    /// L3 cache capacity in bytes.
    pub l3_cache_size: usize,
    /// Size in bytes of one matrix element (8 for `f64`).
    pub element_size: usize,
}
44
45impl Default for CacheAwareConfig {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl CacheAwareConfig {
52 pub fn new() -> Self {
54 Self {
55 l1_cache_size: 32 * 1024, l2_cache_size: 256 * 1024, l3_cache_size: 8 * 1024 * 1024, element_size: 8, }
60 }
61
62 pub fn detect() -> Self {
70 let defaults = Self::new();
71
72 #[cfg(target_os = "linux")]
73 {
74 if let Some(cfg) = detect_linux() {
75 return cfg;
76 }
77 }
78
79 #[cfg(target_os = "macos")]
80 {
81 if let Some(cfg) = detect_macos() {
82 return cfg;
83 }
84 }
85
86 defaults
87 }
88
89 pub fn tile_size_for_matmul(&self, n: usize) -> usize {
96 let max_elements = self.l2_cache_size / (3 * self.element_size.max(1));
99 let tile = (max_elements as f64).sqrt() as usize;
100 tile.clamp(4, n.max(4))
101 }
102
103 pub fn block_size_for_scan(&self) -> usize {
106 (self.l1_cache_size / self.element_size.max(1)).max(1)
107 }
108}
109
/// Parses a Linux sysfs cache-size string into bytes.
///
/// Accepts a plain byte count (`"4096"`) or a count with a `K`, `M`, or `G`
/// suffix (`"32K"`, `"8M"`, `"1G"`). Suffixes are matched case-insensitively;
/// sysfs itself emits uppercase, but lowercase input is tolerated for
/// robustness. Surrounding whitespace (including the trailing newline sysfs
/// appends) is ignored. Returns `None` for malformed input.
fn parse_sysfs_size(s: &str) -> Option<usize> {
    let s = s.trim();
    // (suffix, multiplier) pairs checked in order; first match wins.
    let units = [
        ('K', 1024usize),
        ('M', 1024 * 1024),
        ('G', 1024 * 1024 * 1024),
    ];
    for (suffix, multiplier) in units {
        let stripped = s
            .strip_suffix(suffix)
            .or_else(|| s.strip_suffix(suffix.to_ascii_lowercase()));
        if let Some(digits) = stripped {
            return digits.trim().parse::<usize>().ok().map(|v| v * multiplier);
        }
    }
    // No suffix: interpret as a raw byte count.
    s.parse::<usize>().ok()
}
129
/// Reads cache sizes for cpu0 from Linux sysfs
/// (`/sys/devices/system/cpu/cpu0/cache/index*`).
///
/// Returns `None` when no cache level at all could be read (e.g. sysfs not
/// mounted or restricted in a container), so the caller falls back to
/// defaults. Missing individual levels are filled from defaults instead.
#[cfg(target_os = "linux")]
fn detect_linux() -> Option<CacheAwareConfig> {
    use std::fs;

    // cpu0 is taken as representative of the whole package.
    let base = "/sys/devices/system/cpu/cpu0/cache";
    let mut l1: Option<usize> = None;
    let mut l2: Option<usize> = None;
    let mut l3: Option<usize> = None;

    for idx in 0..8usize {
        let level_path = format!("{base}/index{idx}/level");
        let size_path = format!("{base}/index{idx}/size");
        let type_path = format!("{base}/index{idx}/type");

        // A missing `level` file is treated as the end of the index list:
        // stop scanning. Other per-entry failures just skip that entry.
        let level_str = match fs::read_to_string(&level_path) {
            Ok(s) => s,
            Err(_) => break,
        };
        let level: usize = match level_str.trim().parse() {
            Ok(v) => v,
            Err(_) => continue,
        };
        let size_str = match fs::read_to_string(&size_path) {
            Ok(s) => s,
            Err(_) => continue,
        };
        let size = match parse_sysfs_size(&size_str) {
            Some(s) => s,
            None => continue,
        };
        // Skip the L1 instruction cache: tiling decisions only care about
        // data capacity. If the `type` file is unreadable we keep the entry.
        let cache_type = fs::read_to_string(&type_path).unwrap_or_default();
        let cache_type = cache_type.trim();
        if level == 1 && cache_type == "Instruction" {
            continue;
        }

        // Later entries for the same level overwrite earlier ones.
        match level {
            1 => l1 = Some(size),
            2 => l2 = Some(size),
            3 => l3 = Some(size),
            _ => {}
        }
    }

    // Nothing detected at all: report failure to the caller.
    if l1.is_none() && l2.is_none() && l3.is_none() {
        return None;
    }

    // Fill any missing level from the built-in defaults.
    let defaults = CacheAwareConfig::new();
    Some(CacheAwareConfig {
        l1_cache_size: l1.unwrap_or(defaults.l1_cache_size),
        l2_cache_size: l2.unwrap_or(defaults.l2_cache_size),
        l3_cache_size: l3.unwrap_or(defaults.l3_cache_size),
        element_size: defaults.element_size,
    })
}
188
/// Queries cache sizes on macOS via the `sysctl` command-line tool.
///
/// Returns `None` when none of the three queries yields a value; partially
/// missing levels (some keys may be absent on some hardware) are filled from
/// the built-in defaults.
#[cfg(target_os = "macos")]
fn detect_macos() -> Option<CacheAwareConfig> {
    // Runs `sysctl -n <name>` and parses its stdout as a usize.
    fn sysctl_usize(name: &str) -> Option<usize> {
        let output = std::process::Command::new("sysctl")
            .arg("-n")
            .arg(name)
            .output()
            .ok()?;
        std::str::from_utf8(&output.stdout)
            .ok()?
            .trim()
            .parse::<usize>()
            .ok()
    }

    let l1 = sysctl_usize("hw.l1dcachesize");
    let l2 = sysctl_usize("hw.l2cachesize");
    let l3 = sysctl_usize("hw.l3cachesize");

    // All three failed: signal detection failure to the caller.
    if (l1, l2, l3) == (None, None, None) {
        return None;
    }

    let fallback = CacheAwareConfig::new();
    Some(CacheAwareConfig {
        l1_cache_size: l1.unwrap_or(fallback.l1_cache_size),
        l2_cache_size: l2.unwrap_or(fallback.l2_cache_size),
        l3_cache_size: l3.unwrap_or(fallback.l3_cache_size),
        element_size: fallback.element_size,
    })
}
217
218pub fn cache_oblivious_transpose(a: &mut Array2<f64>) {
229 let (rows, cols) = a.dim();
230 if rows != cols {
231 let transposed = a.t().to_owned();
233 *a = transposed;
234 return;
235 }
236 let n = rows;
237 let ptr = a.as_mut_ptr();
239 let slice = unsafe { std::slice::from_raw_parts_mut(ptr, n * n) };
241 recursive_transpose(slice, 0, n, 0, n, n);
242}
243
/// In-place transpose of the `[row_start, row_end) × [col_start, col_end)`
/// window of a square row-major matrix stored in `buf` with row stride
/// `stride`.
///
/// Recursively halves the longer axis until the window fits a small base
/// tile, then swaps each strictly-upper-triangular element with its mirror.
/// Sub-windows entirely below the diagonal perform no swaps themselves —
/// their elements are exchanged when the mirrored window above the diagonal
/// is processed, so every off-diagonal pair is swapped exactly once.
fn recursive_transpose(
    buf: &mut [f64],
    row_start: usize,
    row_end: usize,
    col_start: usize,
    col_end: usize,
    stride: usize,
) {
    const BASE: usize = 32;
    let height = row_end - row_start;
    let width = col_end - col_start;

    // Recurse while either dimension exceeds the base tile.
    if height > BASE || width > BASE {
        if height >= width {
            let split = row_start + height / 2;
            recursive_transpose(buf, row_start, split, col_start, col_end, stride);
            recursive_transpose(buf, split, row_end, col_start, col_end, stride);
        } else {
            let split = col_start + width / 2;
            recursive_transpose(buf, row_start, row_end, col_start, split, stride);
            recursive_transpose(buf, row_start, row_end, split, col_end, stride);
        }
        return;
    }

    // Base case: swap only strictly-upper-triangular entries inside the
    // column window, so each (i, j)/(j, i) pair is touched exactly once.
    for i in row_start..row_end {
        for j in col_start.max(i + 1)..col_end {
            buf.swap(i * stride + j, j * stride + i);
        }
    }
}
280
281pub fn tiled_matmul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
294 let (m, k) = a.dim();
295 let (kb, n) = b.dim();
296 assert_eq!(
297 k, kb,
298 "tiled_matmul: inner dimensions must match ({k} vs {kb})"
299 );
300
301 let config = CacheAwareConfig::detect();
302 let tile = config.tile_size_for_matmul(m.max(n).max(k));
303
304 let mut c = Array2::<f64>::zeros((m, n));
305
306 let mut ii = 0;
308 while ii < m {
309 let i_end = (ii + tile).min(m);
310 let mut kk = 0;
311 while kk < k {
312 let k_end = (kk + tile).min(k);
313 let mut jj = 0;
314 while jj < n {
315 let j_end = (jj + tile).min(n);
316 for i in ii..i_end {
318 for kp in kk..k_end {
319 let a_ik = a[[i, kp]];
320 for j in jj..j_end {
321 c[[i, j]] += a_ik * b[[kp, j]];
322 }
323 }
324 }
325 jj += tile;
326 }
327 kk += tile;
328 }
329 ii += tile;
330 }
331
332 c
333}
334
335pub fn prefetch_matmul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
350 let (m, k) = a.dim();
351 let (kb, n) = b.dim();
352 assert_eq!(
353 k, kb,
354 "prefetch_matmul: inner dimensions must match ({k} vs {kb})"
355 );
356
357 let config = CacheAwareConfig::detect();
358 let tile = config.tile_size_for_matmul(m.max(n).max(k));
359
360 let mut c = Array2::<f64>::zeros((m, n));
361
362 let mut ii = 0;
363 while ii < m {
364 let i_end = (ii + tile).min(m);
365 let mut kk = 0;
366 while kk < k {
367 let k_end = (kk + tile).min(k);
368 let mut jj = 0;
369 while jj < n {
370 let j_end = (jj + tile).min(n);
371
372 let next_jj = jj + tile;
374 if next_jj < n {
375 let next_j_end = (next_jj + tile).min(n);
376 prefetch_b_tile(b, kk, k_end, next_jj, next_j_end);
377 }
378
379 for i in ii..i_end {
380 for kp in kk..k_end {
381 let a_ik = a[[i, kp]];
382 for j in jj..j_end {
383 c[[i, j]] += a_ik * b[[kp, j]];
384 }
385 }
386 }
387 jj += tile;
388 }
389 kk += tile;
390 }
391 ii += tile;
392 }
393
394 c
395}
396
397#[inline]
399fn prefetch_b_tile(b: &Array2<f64>, k_start: usize, k_end: usize, j_start: usize, j_end: usize) {
400 const STRIDE: usize = 8;
402
403 for kp in k_start..k_end {
404 let mut j = j_start;
405 while j < j_end {
406 let ptr: *const f64 = &b[[kp, j]];
408 #[cfg(target_arch = "x86_64")]
409 {
410 unsafe {
413 std::arch::x86_64::_mm_prefetch(
414 ptr as *const i8,
415 std::arch::x86_64::_MM_HINT_T1, );
417 }
418 }
419 #[cfg(not(target_arch = "x86_64"))]
420 {
421 let _ = std::hint::black_box(ptr);
423 }
424 j += STRIDE;
425 }
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    // ---- CacheAwareConfig ------------------------------------------------

    #[test]
    fn test_config_defaults_are_reasonable() {
        let cfg = CacheAwareConfig::new();
        assert!(cfg.l1_cache_size >= 8 * 1024, "L1 should be at least 8 KiB");
        assert!(cfg.l2_cache_size > cfg.l1_cache_size, "L2 > L1");
        assert!(cfg.l3_cache_size > cfg.l2_cache_size, "L3 > L2");
        assert_eq!(cfg.element_size, 8);
    }

    #[test]
    fn test_config_detect_returns_nonzero_sizes() {
        // Whether detection succeeds or falls back to defaults, every size
        // must be positive.
        let cfg = CacheAwareConfig::detect();
        assert!(cfg.l1_cache_size > 0);
        assert!(cfg.l2_cache_size > 0);
        assert!(cfg.l3_cache_size > 0);
        assert!(cfg.element_size > 0);
    }

    #[test]
    fn test_tile_size_within_bounds_small() {
        let cfg = CacheAwareConfig::new();
        let n = 16;
        let tile = cfg.tile_size_for_matmul(n);
        assert!(tile >= 4, "tile_size >= 4");
        assert!(tile <= n, "tile_size <= n");
    }

    #[test]
    fn test_tile_size_within_bounds_large() {
        let cfg = CacheAwareConfig::new();
        for n in [64, 128, 512, 1024] {
            let tile = cfg.tile_size_for_matmul(n);
            assert!(tile >= 4);
            assert!(tile <= n);
        }
    }

    #[test]
    fn test_block_size_for_scan_is_positive() {
        let cfg = CacheAwareConfig::new();
        assert!(cfg.block_size_for_scan() > 0);
    }

    #[test]
    fn test_block_size_for_scan_fits_in_l1() {
        let cfg = CacheAwareConfig::new();
        let block = cfg.block_size_for_scan();
        assert!(block * cfg.element_size <= cfg.l1_cache_size);
    }

    // ---- cache_oblivious_transpose ---------------------------------------

    #[test]
    fn test_cache_oblivious_transpose_4x4() {
        let mut a = Array2::<f64>::from_shape_vec((4, 4), (0..16).map(|x| x as f64).collect())
            .expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    #[test]
    fn test_cache_oblivious_transpose_8x8() {
        let data: Vec<f64> = (0..64).map(|x| x as f64).collect();
        let mut a = Array2::<f64>::from_shape_vec((8, 8), data).expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    #[test]
    fn test_cache_oblivious_transpose_involutory() {
        // Transposing twice must restore the original matrix.
        let data: Vec<f64> = (0..64).map(|x| x as f64 * 0.5).collect();
        let mut a = Array2::<f64>::from_shape_vec((8, 8), data.clone()).expect("valid shape");
        let original = a.clone();
        cache_oblivious_transpose(&mut a);
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, original);
    }

    #[test]
    fn test_cache_oblivious_transpose_large() {
        // 64 > the recursive base tile, so this exercises the recursion.
        let n = 64;
        let data: Vec<f64> = (0..(n * n)).map(|x| x as f64).collect();
        let mut a = Array2::<f64>::from_shape_vec((n, n), data).expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    #[test]
    fn test_cache_oblivious_transpose_non_square_fallback() {
        // Non-square input takes the out-of-place fallback path.
        let mut a = Array2::<f64>::from_shape_vec((3, 5), (0..15).map(|x| x as f64).collect())
            .expect("valid shape");
        let expected = a.t().to_owned();
        cache_oblivious_transpose(&mut a);
        assert_eq!(a, expected);
    }

    // ---- tiled_matmul ----------------------------------------------------

    #[test]
    fn test_tiled_matmul_identity_4x4() {
        let a = Array2::<f64>::eye(4);
        let b = Array2::<f64>::eye(4);
        let c = tiled_matmul(&a, &b);
        assert_eq!(c, Array2::<f64>::eye(4));
    }

    #[test]
    fn test_tiled_matmul_known_result_2x2() {
        let a = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("ok");
        let b = Array2::from_shape_vec((2, 2), vec![5.0, 6.0, 7.0, 8.0]).expect("ok");
        let c = tiled_matmul(&a, &b);
        let expected = Array2::from_shape_vec((2, 2), vec![19.0, 22.0, 43.0, 50.0]).expect("ok");
        for ((i, j), v) in c.indexed_iter() {
            assert!(
                (v - expected[[i, j]]).abs() < 1e-12,
                "mismatch at [{i},{j}]: {v} != {}",
                expected[[i, j]]
            );
        }
    }

    #[test]
    fn test_tiled_matmul_matches_naive_16x16() {
        use ndarray::Array2;
        let n = 16;
        let a = Array2::from_shape_fn((n, n), |(i, j)| (i * n + j) as f64 * 0.01);
        let b = Array2::from_shape_fn((n, n), |(i, j)| (i + j) as f64 * 0.01);
        let tiled = tiled_matmul(&a, &b);
        // `dot` is the reference implementation.
        let naive = a.dot(&b);
        for ((i, j), v) in tiled.indexed_iter() {
            assert!(
                (v - naive[[i, j]]).abs() < 1e-9,
                "tiled vs naive mismatch at [{i},{j}]"
            );
        }
    }

    // ---- prefetch_matmul -------------------------------------------------

    #[test]
    fn test_prefetch_matmul_matches_tiled_8x8() {
        // Prefetching is a hint only; results must match the plain tiled path.
        let n = 8;
        let a = Array2::from_shape_fn((n, n), |(i, j)| (i * n + j) as f64);
        let b = Array2::from_shape_fn((n, n), |(i, j)| (i + j + 1) as f64);
        let tiled = tiled_matmul(&a, &b);
        let prefetched = prefetch_matmul(&a, &b);
        for ((i, j), v) in prefetched.indexed_iter() {
            assert!(
                (v - tiled[[i, j]]).abs() < 1e-9,
                "prefetch vs tiled mismatch at [{i},{j}]"
            );
        }
    }

    #[test]
    fn test_prefetch_matmul_correctness_64x64() {
        let n = 64;
        let a = Array2::from_shape_fn((n, n), |(i, j)| ((i + 1) * (j + 1)) as f64 * 0.001);
        let b = Array2::from_shape_fn((n, n), |(i, j)| (i as f64 - j as f64).abs() * 0.001);
        let reference = a.dot(&b);
        let result = prefetch_matmul(&a, &b);
        for ((i, j), v) in result.indexed_iter() {
            assert!(
                (v - reference[[i, j]]).abs() < 1e-8,
                "prefetch_matmul wrong at [{i},{j}]"
            );
        }
    }

    #[test]
    fn test_prefetch_matmul_identity_8x8() {
        let eye = Array2::<f64>::eye(8);
        let a = Array2::from_shape_fn((8, 8), |(i, j)| (i * j) as f64 + 1.0);
        let result = prefetch_matmul(&a, &eye);
        for ((i, j), v) in result.indexed_iter() {
            assert!(
                (v - a[[i, j]]).abs() < 1e-12,
                "A×I should equal A at [{i},{j}]"
            );
        }
    }

    #[test]
    fn test_tiled_matmul_rect_2x3_times_3x4() {
        // Rectangular shapes check the tile-edge clamping on all three dims.
        let a = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).expect("ok");
        let b = Array2::from_shape_vec(
            (3, 4),
            vec![
                7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
            ],
        )
        .expect("ok");
        let tiled = tiled_matmul(&a, &b);
        let naive = a.dot(&b);
        for ((i, j), v) in tiled.indexed_iter() {
            assert!(
                (v - naive[[i, j]]).abs() < 1e-9,
                "rect mismatch at [{i},{j}]"
            );
        }
    }
}