use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{
    error::{Result, SklearsError},
    types::Float,
};
use std::alloc::{alloc, dealloc, Layout};
use std::mem::{align_of, size_of};
use std::ptr::NonNull;

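/// Tunable cache parameters used by the tiled kernels and heuristics in this module.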
#[derive(Debug, Clone)]
pub struct CacheOptimizationConfig {
    /// Cache line size in bytes.
    pub cache_line_size: usize,
    /// L1 data cache size in bytes.
    pub l1_cache_size: usize,
    /// L2 cache size in bytes.
    pub l2_cache_size: usize,
    /// L3 cache size in bytes.
    pub l3_cache_size: usize,
    /// Edge length of the square tiles used by the blocked kernels.
    pub tile_size: usize,
    /// Whether to emit software prefetch hints on x86_64.
    pub enable_prefetch: bool,
    /// Whether to prefer NUMA-aware memory placement.
    pub numa_aware: bool,
    /// Byte alignment to request for aligned buffers.
    pub memory_alignment: usize,
}

impl Default for CacheOptimizationConfig {
    fn default() -> Self {
        Self {
            cache_line_size: 64,
            l1_cache_size: 32 * 1024,       // 32 KiB
            l2_cache_size: 256 * 1024,      // 256 KiB
            l3_cache_size: 8 * 1024 * 1024, // 8 MiB
            tile_size: 64,
            enable_prefetch: true,
            numa_aware: false,
            memory_alignment: 64,
        }
    }
}

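/// A heap-allocated, row-major matrix whose backing buffer is aligned to a
/// caller-chosen boundary (for example a 64-byte cache line) for cache- and
/// SIMD-friendly access.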
pub struct AlignedMatrix<T>
where
    T: Copy,
{
    data: NonNull<T>,
    shape: (usize, usize),
    capacity: usize,
    alignment: usize,
    layout: Layout,
}

impl<T> AlignedMatrix<T>
where
    T: Copy + Default,
{
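    /// Allocates a `rows x cols` matrix aligned to at least `alignment` bytes
    /// (rounded up to a power of two) and initializes every element with
    /// `T::default()`.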
    pub fn new(rows: usize, cols: usize, alignment: usize) -> Result<Self> {
        if rows == 0 || cols == 0 {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions must be non-zero".to_string(),
            ));
        }

        let capacity = rows * cols;
        let size = capacity * size_of::<T>();

        // Alignment must be a power of two and at least the natural alignment of T.
        let alignment = alignment.max(align_of::<T>()).next_power_of_two();

        let layout = Layout::from_size_align(size, alignment)
            .map_err(|_| SklearsError::InvalidInput("Invalid memory layout".to_string()))?;

        let data = unsafe {
            let ptr = alloc(layout);
            if ptr.is_null() {
                return Err(SklearsError::InvalidInput(
                    "Memory allocation failed".to_string(),
                ));
            }

            // Initialize every element so the buffer is never read uninitialized.
            let typed_ptr = ptr as *mut T;
            for i in 0..capacity {
                typed_ptr.add(i).write(T::default());
            }

            NonNull::new_unchecked(typed_ptr)
        };

        Ok(Self {
            data,
            shape: (rows, cols),
            capacity,
            alignment,
            layout,
        })
    }

    /// Returns the matrix shape as `(rows, cols)`.
    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    /// Reads the element at `(row, col)`, bounds-checked.
    pub fn get(&self, row: usize, col: usize) -> Result<T> {
        if row >= self.shape.0 || col >= self.shape.1 {
            return Err(SklearsError::InvalidInput(
                "Index out of bounds".to_string(),
            ));
        }

        let index = row * self.shape.1 + col;
        unsafe { Ok(*self.data.as_ptr().add(index)) }
    }

    /// Writes `value` at `(row, col)`, bounds-checked.
    pub fn set(&mut self, row: usize, col: usize, value: T) -> Result<()> {
        if row >= self.shape.0 || col >= self.shape.1 {
            return Err(SklearsError::InvalidInput(
                "Index out of bounds".to_string(),
            ));
        }

        let index = row * self.shape.1 + col;
        unsafe {
            *self.data.as_ptr().add(index) = value;
        }
        Ok(())
    }

    /// Returns a raw const pointer to the underlying buffer.
    pub fn as_ptr(&self) -> *const T {
        self.data.as_ptr()
    }

    /// Returns a raw mutable pointer to the underlying buffer.
    pub fn as_mut_ptr(&mut self) -> *mut T {
        self.data.as_ptr()
    }

    /// Views the buffer as a contiguous row-major slice.
    pub fn as_slice(&self) -> &[T] {
        unsafe { std::slice::from_raw_parts(self.data.as_ptr(), self.capacity) }
    }

    /// Views the buffer as a mutable contiguous row-major slice.
    pub fn as_mut_slice(&mut self) -> &mut [T] {
        unsafe { std::slice::from_raw_parts_mut(self.data.as_ptr(), self.capacity) }
    }

    /// Returns `true` if the buffer start satisfies the requested alignment.
    pub fn is_aligned(&self) -> bool {
        (self.data.as_ptr() as usize) % self.alignment == 0
    }

    /// Returns the alignment (in bytes) used for the allocation.
    pub fn alignment(&self) -> usize {
        self.alignment
    }
}

impl<T> Drop for AlignedMatrix<T>
where
    T: Copy,
{
    fn drop(&mut self) {
        unsafe {
            dealloc(self.data.as_ptr() as *mut u8, self.layout);
        }
    }
}

// SAFETY: the buffer is uniquely owned and freed exactly once in `Drop`, so the
// matrix can be sent or shared across threads whenever `T` itself allows it.
unsafe impl<T: Send + Copy> Send for AlignedMatrix<T> {}

unsafe impl<T: Sync + Copy> Sync for AlignedMatrix<T> {}

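/// Cache-blocked (tiled) implementations of common dense matrix kernels.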
pub struct TiledMatrixOps {
    config: CacheOptimizationConfig,
}

impl TiledMatrixOps {
    /// Creates a `TiledMatrixOps` with the default cache configuration.
    pub fn new() -> Self {
        Self {
            config: CacheOptimizationConfig::default(),
        }
    }

    /// Creates a `TiledMatrixOps` with a caller-supplied cache configuration.
    pub fn with_config(config: CacheOptimizationConfig) -> Self {
        Self { config }
    }

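    /// Multiplies `a * b` using loop tiling so each working block of the three
    /// matrices stays resident in cache.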
    pub fn tiled_matrix_multiply(
        &self,
        a: &Array2<Float>,
        b: &Array2<Float>,
    ) -> Result<Array2<Float>> {
        let (m, k1) = a.dim();
        let (k2, n) = b.dim();

        if k1 != k2 {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions incompatible for multiplication".to_string(),
            ));
        }

        let k = k1;
        let mut result = Array2::<Float>::zeros((m, n));
        let tile_size = self.config.tile_size;

        // Iterate over (i, j, k) tiles so each block fits in cache.
        for ii in (0..m).step_by(tile_size) {
            for jj in (0..n).step_by(tile_size) {
                for kk in (0..k).step_by(tile_size) {
                    let i_end = (ii + tile_size).min(m);
                    let j_end = (jj + tile_size).min(n);
                    let k_end = (kk + tile_size).min(k);

                    self.multiply_tile(&mut result, a, b, ii, i_end, jj, j_end, kk, k_end);
                }
            }
        }

        Ok(result)
    }

    fn multiply_tile(
        &self,
        result: &mut Array2<Float>,
        a: &Array2<Float>,
        b: &Array2<Float>,
        i_start: usize,
        i_end: usize,
        j_start: usize,
        j_end: usize,
        k_start: usize,
        k_end: usize,
    ) {
        for i in i_start..i_end {
            for j in j_start..j_end {
                let mut sum = 0.0;

                // Software-prefetch data a few iterations ahead; this assumes both
                // operands are contiguous row-major arrays.
                #[cfg(target_arch = "x86_64")]
                if self.config.enable_prefetch && k_start + 8 < k_end {
                    unsafe {
                        let a_ptr = a.as_ptr().add(i * a.ncols() + k_start + 8);
                        let b_ptr = b.as_ptr().add((k_start + 8) * b.ncols() + j);
                        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(
                            a_ptr as *const i8,
                        );
                        std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(
                            b_ptr as *const i8,
                        );
                    }
                }

                for k in k_start..k_end {
                    sum += a[[i, k]] * b[[k, j]];
                }

                result[[i, j]] += sum;
            }
        }
    }

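    /// Transposes `input` tile by tile to avoid the strided, cache-hostile
    /// access pattern of a naive transpose.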
    pub fn cache_friendly_transpose(&self, input: &Array2<Float>) -> Result<Array2<Float>> {
        let (rows, cols) = input.dim();
        let mut output = Array2::<Float>::zeros((cols, rows));
        let tile_size = self.config.tile_size;

        // Copy tile by tile so both the source and destination stay cache-resident.
        for i in (0..rows).step_by(tile_size) {
            for j in (0..cols).step_by(tile_size) {
                let i_end = (i + tile_size).min(rows);
                let j_end = (j + tile_size).min(cols);

                for ii in i..i_end {
                    for jj in j..j_end {
                        output[[jj, ii]] = input[[ii, jj]];
                    }
                }
            }
        }

        Ok(output)
    }

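    /// Computes a truncated SVD, choosing between a blocked and a standard code
    /// path based on whether the matrix fits in the configured L2 cache.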
    pub fn cache_optimized_svd(
        &self,
        matrix: &Array2<Float>,
        n_components: usize,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = matrix.dim();
        let min_dim = m.min(n).min(n_components);

        // Matrices that spill out of the L2 cache take the row-blocked path.
        if m * n > self.config.l2_cache_size / size_of::<Float>() {
            self.blocked_svd(matrix, min_dim)
        } else {
            self.standard_svd(matrix, min_dim)
        }
    }

    fn blocked_svd(
        &self,
        matrix: &Array2<Float>,
        n_components: usize,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = matrix.dim();

        // Choose a block size whose working set roughly fits in the L2 cache.
        let block_size = ((self.config.l2_cache_size / size_of::<Float>()) as f64).sqrt() as usize;

        let mut u_blocks = Vec::new();
        let mut s_values = Vec::new();
        let mut vt_blocks = Vec::new();

        // Process the matrix one row block at a time.
        for i in (0..m).step_by(block_size) {
            let i_end = (i + block_size).min(m);
            let block = matrix.slice(scirs2_core::ndarray::s![i..i_end, ..]);

            let block_owned = block.to_owned();
            let (u_block, s_block, vt_block) = self.standard_svd(&block_owned, n_components)?;

            u_blocks.push(u_block);
            s_values.push(s_block);
            vt_blocks.push(vt_block);
        }

        // Simplified aggregation: only the factors of the first block are returned;
        // a complete implementation would merge the per-block factors.
        let u = if let Some(first_u) = u_blocks.first() {
            first_u.clone()
        } else {
            Array2::eye(m)
        };

        let s = if let Some(first_s) = s_values.first() {
            first_s.clone()
        } else {
            Array1::ones(n_components)
        };

        let vt = if let Some(first_vt) = vt_blocks.first() {
            first_vt.clone()
        } else {
            Array2::eye(n)
        };

        Ok((u, s, vt))
    }

    fn standard_svd(
        &self,
        matrix: &Array2<Float>,
        n_components: usize,
    ) -> Result<(Array2<Float>, Array1<Float>, Array2<Float>)> {
        let (m, n) = matrix.dim();
        // Clamp so the slices below never exceed the matrix dimensions.
        let k = n_components.min(m).min(n);

        // Placeholder factorization: identity factors and unit singular values
        // stand in for a real SVD kernel.
        let u = Array2::eye(m);
        let s = Array1::ones(k);
        let vt = Array2::eye(n);

        Ok((
            u.slice(scirs2_core::ndarray::s![.., ..k]).to_owned(),
            s,
            vt.slice(scirs2_core::ndarray::s![..k, ..]).to_owned(),
        ))
    }

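    /// Computes `matrix * vector` in row tiles, optionally prefetching upcoming
    /// row elements to make better use of memory bandwidth.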
    pub fn bandwidth_efficient_matvec(
        &self,
        matrix: &Array2<Float>,
        vector: &Array1<Float>,
    ) -> Result<Array1<Float>> {
        let (m, n) = matrix.dim();
        if n != vector.len() {
            return Err(SklearsError::InvalidInput(
                "Matrix columns must match vector length".to_string(),
            ));
        }

        let mut result = Array1::<Float>::zeros(m);
        let tile_size = self.config.tile_size;

        // Walk the rows in tiles so the output vector stays cache-resident.
        for i in (0..m).step_by(tile_size) {
            let i_end = (i + tile_size).min(m);

            for ii in i..i_end {
                let mut sum = 0.0;

                for (_j, (&matrix_val, &vec_val)) in
                    matrix.row(ii).iter().zip(vector.iter()).enumerate()
                {
                    // Prefetch a few elements ahead; assumes contiguous row-major storage.
                    #[cfg(target_arch = "x86_64")]
                    if self.config.enable_prefetch && _j + 8 < n {
                        unsafe {
                            let next_ptr = matrix.as_ptr().add(ii * n + _j + 8);
                            std::arch::x86_64::_mm_prefetch::<{ std::arch::x86_64::_MM_HINT_T0 }>(
                                next_ptr as *const i8,
                            );
                        }
                    }

                    sum += matrix_val * vec_val;
                }

                result[ii] = sum;
            }
        }

        Ok(result)
    }
}

impl Default for TiledMatrixOps {
    fn default() -> Self {
        Self::new()
    }
}

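/// A simple free-list pool of `AlignedMatrix<Float>` buffers, keyed by shape,
/// that lets callers reuse allocations across repeated operations.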
pub struct MatrixMemoryPool {
    pools: Vec<Vec<AlignedMatrix<Float>>>,
    sizes: Vec<(usize, usize)>,
    alignment: usize,
}

impl MatrixMemoryPool {
    /// Creates an empty pool whose matrices use the given byte alignment.
    pub fn new(alignment: usize) -> Self {
        Self {
            pools: Vec::new(),
            sizes: Vec::new(),
            alignment,
        }
    }

    /// Returns a matrix of the requested shape, reusing a pooled one when available.
    pub fn get_matrix(&mut self, rows: usize, cols: usize) -> Result<AlignedMatrix<Float>> {
        let size = (rows, cols);

        if let Some(pool_index) = self.sizes.iter().position(|&s| s == size) {
            if let Some(matrix) = self.pools[pool_index].pop() {
                return Ok(matrix);
            }
        } else {
            // First request for this shape: register a pool slot for it.
            self.sizes.push(size);
            self.pools.push(Vec::new());
        }

        AlignedMatrix::new(rows, cols, self.alignment)
    }

    /// Returns a matrix to the pool after zeroing its contents.
    pub fn return_matrix(&mut self, mut matrix: AlignedMatrix<Float>) {
        let size = matrix.shape();

        if let Some(pool_index) = self.sizes.iter().position(|&s| s == size) {
            matrix.as_mut_slice().fill(0.0);
            self.pools[pool_index].push(matrix);
        }
    }

    /// Drops every pooled matrix and forgets the registered shapes.
    pub fn clear(&mut self) {
        self.pools.clear();
        self.sizes.clear();
    }

    /// Reports how many matrices and distinct shapes the pool currently holds.
    pub fn get_statistics(&self) -> PoolStatistics {
        let total_matrices: usize = self.pools.iter().map(|pool| pool.len()).sum();
        let unique_sizes = self.sizes.len();

        PoolStatistics {
            total_matrices,
            unique_sizes,
            sizes: self.sizes.clone(),
        }
    }
}

#[derive(Debug, Clone)]
pub struct PoolStatistics {
    pub total_matrices: usize,
    pub unique_sizes: usize,
    pub sizes: Vec<(usize, usize)>,
}

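/// Heuristic model that estimates cache behaviour of dense matrix operations
/// from their size and access pattern.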
pub struct CachePerformanceAnalyzer {
    config: CacheOptimizationConfig,
}

impl CachePerformanceAnalyzer {
    /// Creates an analyzer with the default cache configuration.
    pub fn new() -> Self {
        Self {
            config: CacheOptimizationConfig::default(),
        }
    }

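    /// Estimates per-level cache misses and the resulting stall penalty for a
    /// previously analyzed operation.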
    pub fn estimate_cache_misses(&self, operation: &CacheAnalysis) -> CacheMissEstimate {
        let total_accesses = operation.memory_accesses;
        let working_set_size = operation.working_set_size;

        // Heuristic miss rates: penalize working sets that spill out of each level.
        let l1_misses = if working_set_size > self.config.l1_cache_size {
            (total_accesses as f64 * 0.1) as usize
        } else {
            (total_accesses as f64 * 0.01) as usize
        };

        let l2_misses = if working_set_size > self.config.l2_cache_size {
            (l1_misses as f64 * 0.5) as usize
        } else {
            (l1_misses as f64 * 0.1) as usize
        };

        let l3_misses = if working_set_size > self.config.l3_cache_size {
            (l2_misses as f64 * 0.8) as usize
        } else {
            (l2_misses as f64 * 0.2) as usize
        };

        CacheMissEstimate {
            l1_misses,
            l2_misses,
            l3_misses,
            // Roughly 300 cycles per last-level miss that goes to main memory.
            estimated_penalty_cycles: (l3_misses as f64 * 300.0) as usize,
        }
    }

    pub fn analyze_matrix_operation(
        &self,
        rows: usize,
        cols: usize,
        operation_type: MatrixOperationType,
    ) -> CacheAnalysis {
        let matrix_size = rows * cols * size_of::<Float>();

        // Rough per-operation access counts.
        let memory_accesses = match operation_type {
            MatrixOperationType::Transpose => rows * cols,
            MatrixOperationType::MatrixMultiply(k) => rows * cols * k,
            MatrixOperationType::SVD => rows * cols * 10,
            MatrixOperationType::Eigendecomposition => rows * rows * 5,
        };

        // Approximate working-set sizes (inputs, outputs, and scratch buffers).
        let working_set_size = match operation_type {
            MatrixOperationType::Transpose => matrix_size * 2,
            MatrixOperationType::MatrixMultiply(_) => matrix_size * 3,
            MatrixOperationType::SVD => matrix_size * 4,
            MatrixOperationType::Eigendecomposition => matrix_size * 3,
        };

        let cache_efficiency = if working_set_size <= self.config.l1_cache_size {
            0.95
        } else if working_set_size <= self.config.l2_cache_size {
            0.80
        } else if working_set_size <= self.config.l3_cache_size {
            0.60
        } else {
            0.30
        };

        CacheAnalysis {
            matrix_size,
            memory_accesses,
            working_set_size,
            cache_efficiency,
            recommended_tile_size: self.calculate_optimal_tile_size(working_set_size),
        }
    }

    fn calculate_optimal_tile_size(&self, working_set_size: usize) -> usize {
        if working_set_size <= self.config.l1_cache_size {
            32
        } else if working_set_size <= self.config.l2_cache_size {
            64
        } else {
            128
        }
    }
}

impl Default for CachePerformanceAnalyzer {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Clone, Copy)]
pub enum MatrixOperationType {
    Transpose,
    /// Matrix multiplication with the given inner dimension `k`.
    MatrixMultiply(usize),
    SVD,
    Eigendecomposition,
}

#[derive(Debug, Clone)]
pub struct CacheAnalysis {
    pub matrix_size: usize,
    pub memory_accesses: usize,
    pub working_set_size: usize,
    pub cache_efficiency: Float,
    pub recommended_tile_size: usize,
}

#[derive(Debug, Clone)]
pub struct CacheMissEstimate {
    pub l1_misses: usize,
    pub l2_misses: usize,
    pub l3_misses: usize,
    pub estimated_penalty_cycles: usize,
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aligned_matrix_creation() {
        let matrix = AlignedMatrix::<f64>::new(10, 10, 64).unwrap();
        assert_eq!(matrix.shape(), (10, 10));
        assert!(matrix.is_aligned());
        assert_eq!(matrix.alignment(), 64);
    }

    #[test]
    fn test_aligned_matrix_get_set() {
        let mut matrix = AlignedMatrix::<f64>::new(3, 3, 32).unwrap();

        matrix.set(1, 2, 42.0).unwrap();
        let value = matrix.get(1, 2).unwrap();
        assert_eq!(value, 42.0);

        // Out-of-bounds accesses must be rejected.
        assert!(matrix.set(3, 0, 1.0).is_err());
        assert!(matrix.get(0, 3).is_err());
    }

    #[test]
    fn test_cache_optimization_config() {
        let config = CacheOptimizationConfig::default();
        assert_eq!(config.cache_line_size, 64);
        assert_eq!(config.tile_size, 64);
        assert!(config.enable_prefetch);
        assert_eq!(config.memory_alignment, 64);
    }

    #[test]
    fn test_tiled_matrix_operations() {
        let tiled_ops = TiledMatrixOps::new();

        let a = Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();

        let b = Array2::from_shape_vec((3, 3), vec![9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0])
            .unwrap();

        let result = tiled_ops.tiled_matrix_multiply(&a, &b).unwrap();
        assert_eq!(result.shape(), &[3, 3]);
    }

    #[test]
    fn test_cache_friendly_transpose() {
        let tiled_ops = TiledMatrixOps::new();

        let matrix = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();

        let transposed = tiled_ops.cache_friendly_transpose(&matrix).unwrap();
        assert_eq!(transposed.shape(), &[3, 2]);
        assert_eq!(transposed[[0, 0]], 1.0);
        assert_eq!(transposed[[1, 0]], 2.0);
        assert_eq!(transposed[[2, 0]], 3.0);
        assert_eq!(transposed[[0, 1]], 4.0);
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = MatrixMemoryPool::new(32);

        let matrix1 = pool.get_matrix(5, 5).unwrap();
        assert_eq!(matrix1.shape(), (5, 5));

        pool.return_matrix(matrix1);

        // The second request should be satisfied from the pool.
        let matrix2 = pool.get_matrix(5, 5).unwrap();
        assert_eq!(matrix2.shape(), (5, 5));

        let stats = pool.get_statistics();
        assert_eq!(stats.unique_sizes, 1);
        // `matrix2` is still checked out, so at most one matrix remains pooled.
        assert!(stats.total_matrices <= 1);
    }

    #[test]
    fn test_cache_performance_analyzer() {
        let analyzer = CachePerformanceAnalyzer::new();

        let analysis =
            analyzer.analyze_matrix_operation(100, 100, MatrixOperationType::MatrixMultiply(100));

        assert!(analysis.memory_accesses > 0);
        assert!(analysis.working_set_size > 0);
        assert!(analysis.cache_efficiency > 0.0);
        assert!(analysis.cache_efficiency <= 1.0);

        let miss_estimate = analyzer.estimate_cache_misses(&analysis);
        assert!(miss_estimate.l1_misses <= analysis.memory_accesses);
    }
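
    // Added sketch: exercises the SVD dispatch path end to end. It assumes the
    // current placeholder `standard_svd`, which returns identity-based factors,
    // so only the returned shapes are checked here.
    #[test]
    fn test_cache_optimized_svd_shapes() {
        let tiled_ops = TiledMatrixOps::new();

        // A 10 x 6 matrix is small enough to take the standard (in-cache) path.
        let matrix = Array2::<Float>::zeros((10, 6));
        let (u, s, vt) = tiled_ops.cache_optimized_svd(&matrix, 4).unwrap();

        assert_eq!(u.shape(), &[10, 4]);
        assert_eq!(s.len(), 4);
        assert_eq!(vt.shape(), &[4, 6]);
    }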

    #[test]
    fn test_bandwidth_efficient_matvec() {
        let tiled_ops = TiledMatrixOps::new();

        let matrix = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap();

        let vector = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let result = tiled_ops
            .bandwidth_efficient_matvec(&matrix, &vector)
            .unwrap();
        assert_eq!(result.len(), 3);

        assert_eq!(result[0], 30.0); // 1*1 + 2*2 + 3*3 + 4*4
        assert_eq!(result[1], 70.0); // 5*1 + 6*2 + 7*3 + 8*4
        assert_eq!(result[2], 110.0); // 9*1 + 10*2 + 11*3 + 12*4
    }
}