1use scirs2_core::ndarray::{Array1, Array2};
15use serde::{Deserialize, Serialize};
16use sklears_core::error::SklearsError;
17use std::alloc::{alloc, dealloc, Layout};
18
/// Memory layout strategies for storing matrix data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum MemoryLayout {
    /// Rows stored contiguously (C order); the default.
    #[default]
    RowMajor,
    /// Columns stored contiguously (Fortran order).
    ColumnMajor,
    /// Square tiles of up to `tile_size` x `tile_size` elements stored contiguously.
    Blocked { tile_size: usize },
    /// Structure-of-arrays layout; currently stored identically to `RowMajor`
    /// (see `CacheFriendlyMatrix::from_array`).
    StructureOfArrays,
}
32
/// Configuration describing the target cache hierarchy and layout choices.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Memory layout used when building a `CacheFriendlyMatrix`.
    pub layout: MemoryLayout,
    /// L1 data cache size in bytes.
    pub l1_cache_size: usize,
    /// L2 cache size in bytes.
    pub l2_cache_size: usize,
    /// L3 cache size in bytes.
    pub l3_cache_size: usize,
    /// Cache line size in bytes.
    pub cache_line_size: usize,
    /// Whether software prefetch hints should be issued.
    pub enable_prefetch: bool,
    /// Byte alignment for allocated buffers (must be a power of two).
    pub alignment: usize,
}
51
52impl Default for CacheConfig {
53 fn default() -> Self {
54 Self {
55 layout: MemoryLayout::default(),
56 l1_cache_size: 32 * 1024, l2_cache_size: 256 * 1024, l3_cache_size: 8 * 1024 * 1024, cache_line_size: 64,
60 enable_prefetch: true,
61 alignment: 64, }
63 }
64}
65
66impl CacheConfig {
67 pub fn optimal_tile_size(&self) -> usize {
69 let bytes_per_element = std::mem::size_of::<f64>();
72 let available_cache = self.l2_cache_size / 3;
73 let elements_per_tile = available_cache / bytes_per_element;
74
75 let tile_dim = (elements_per_tile as f64).sqrt() as usize;
77
78 let elements_per_line = self.cache_line_size / bytes_per_element;
80 (tile_dim / elements_per_line) * elements_per_line
81 }
82
83 pub fn optimal_block_size(&self) -> usize {
85 let bytes_per_element = std::mem::size_of::<f64>();
87 let available_cache = self.l1_cache_size / 2; available_cache / bytes_per_element
89 }
90}
91
/// A dense f64 matrix stored in a flat buffer whose element order follows a
/// configurable, cache-friendly `MemoryLayout`.
pub struct CacheFriendlyMatrix {
    // Flat element storage; interpretation depends on `layout`.
    data: Vec<f64>,
    // Number of rows.
    n_rows: usize,
    // Number of columns.
    n_cols: usize,
    // Element ordering used by `data`.
    layout: MemoryLayout,
    // Cache configuration this matrix was built with.
    config: CacheConfig,
}
105
106impl CacheFriendlyMatrix {
107 pub fn from_array(array: &Array2<f64>, config: CacheConfig) -> Result<Self, SklearsError> {
109 let (n_rows, n_cols) = array.dim();
110 let layout = config.layout;
111
112 let data = match layout {
113 MemoryLayout::RowMajor => array.iter().copied().collect(),
114 MemoryLayout::ColumnMajor => {
115 let mut data = Vec::with_capacity(n_rows * n_cols);
116 for col in 0..n_cols {
117 for row in 0..n_rows {
118 data.push(array[[row, col]]);
119 }
120 }
121 data
122 }
123 MemoryLayout::Blocked { tile_size } => Self::convert_to_blocked(array, tile_size)?,
124 MemoryLayout::StructureOfArrays => array.iter().copied().collect(),
125 };
126
127 Ok(Self {
128 data,
129 n_rows,
130 n_cols,
131 layout,
132 config,
133 })
134 }
135
136 fn convert_to_blocked(array: &Array2<f64>, tile_size: usize) -> Result<Vec<f64>, SklearsError> {
138 let (n_rows, n_cols) = array.dim();
139 let mut data = vec![0.0; n_rows * n_cols];
140
141 let n_row_blocks = (n_rows + tile_size - 1) / tile_size;
142 let n_col_blocks = (n_cols + tile_size - 1) / tile_size;
143
144 let mut offset = 0;
145 for block_row in 0..n_row_blocks {
146 for block_col in 0..n_col_blocks {
147 let row_start = block_row * tile_size;
148 let row_end = (row_start + tile_size).min(n_rows);
149 let col_start = block_col * tile_size;
150 let col_end = (col_start + tile_size).min(n_cols);
151
152 for row in row_start..row_end {
153 for col in col_start..col_end {
154 data[offset] = array[[row, col]];
155 offset += 1;
156 }
157 }
158 }
159 }
160
161 Ok(data)
162 }
163
164 pub fn get(&self, row: usize, col: usize) -> Result<f64, SklearsError> {
166 if row >= self.n_rows || col >= self.n_cols {
167 return Err(SklearsError::InvalidInput(
168 "Index out of bounds".to_string(),
169 ));
170 }
171
172 let idx = match self.layout {
173 MemoryLayout::RowMajor => row * self.n_cols + col,
174 MemoryLayout::ColumnMajor => col * self.n_rows + row,
175 MemoryLayout::Blocked { tile_size } => {
176 let block_row = row / tile_size;
177 let block_col = col / tile_size;
178 let in_block_row = row % tile_size;
179 let in_block_col = col % tile_size;
180
181 let n_col_blocks = (self.n_cols + tile_size - 1) / tile_size;
182 let block_idx = block_row * n_col_blocks + block_col;
183 let block_offset = block_idx * tile_size * tile_size;
184
185 block_offset + in_block_row * tile_size + in_block_col
186 }
187 MemoryLayout::StructureOfArrays => row * self.n_cols + col,
188 };
189
190 Ok(self.data[idx])
191 }
192
193 pub fn to_array(&self) -> Result<Array2<f64>, SklearsError> {
195 let mut array = Array2::zeros((self.n_rows, self.n_cols));
196
197 for row in 0..self.n_rows {
198 for col in 0..self.n_cols {
199 array[[row, col]] = self.get(row, col)?;
200 }
201 }
202
203 Ok(array)
204 }
205
206 pub fn dot_vector(&self, vector: &Array1<f64>) -> Result<Array1<f64>, SklearsError> {
208 if vector.len() != self.n_cols {
209 return Err(SklearsError::InvalidInput(
210 "Vector length must match number of columns".to_string(),
211 ));
212 }
213
214 let mut result = Array1::zeros(self.n_rows);
215
216 match self.layout {
217 MemoryLayout::RowMajor => {
218 for row in 0..self.n_rows {
220 let mut sum = 0.0;
221 for col in 0..self.n_cols {
222 sum += self.data[row * self.n_cols + col] * vector[col];
223 }
224 result[row] = sum;
225 }
226 }
227 MemoryLayout::ColumnMajor => {
228 for col in 0..self.n_cols {
230 let v_col = vector[col];
231 for row in 0..self.n_rows {
232 result[row] += self.data[col * self.n_rows + row] * v_col;
233 }
234 }
235 }
236 MemoryLayout::Blocked { tile_size } => {
237 for row_block in (0..self.n_rows).step_by(tile_size) {
239 let row_end = (row_block + tile_size).min(self.n_rows);
240
241 for col_block in (0..self.n_cols).step_by(tile_size) {
242 let col_end = (col_block + tile_size).min(self.n_cols);
243
244 for row in row_block..row_end {
245 let mut sum = 0.0;
246 for col in col_block..col_end {
247 sum += self.get(row, col)? * vector[col];
248 }
249 result[row] += sum;
250 }
251 }
252 }
253 }
254 MemoryLayout::StructureOfArrays => {
255 for row in 0..self.n_rows {
257 let mut sum = 0.0;
258 for col in 0..self.n_cols {
259 sum += self.data[row * self.n_cols + col] * vector[col];
260 }
261 result[row] = sum;
262 }
263 }
264 }
265
266 Ok(result)
267 }
268
269 pub fn dim(&self) -> (usize, usize) {
271 (self.n_rows, self.n_cols)
272 }
273}
274
/// A heap buffer of f64s with a caller-specified byte alignment, allocated
/// through `std::alloc` and freed in `Drop`.
pub struct AlignedBuffer {
    // Raw allocation; null means no allocation was made.
    ptr: *mut f64,
    // Number of f64 elements.
    len: usize,
    // Byte alignment used at allocation time (needed to rebuild the Layout in Drop).
    alignment: usize,
}
281
282impl AlignedBuffer {
283 pub fn new(len: usize, alignment: usize) -> Result<Self, SklearsError> {
285 if alignment == 0 || !alignment.is_power_of_two() {
286 return Err(SklearsError::InvalidInput(
287 "Alignment must be a power of 2".to_string(),
288 ));
289 }
290
291 let layout = Layout::from_size_align(len * std::mem::size_of::<f64>(), alignment)
292 .map_err(|e| SklearsError::InvalidInput(format!("Invalid layout: {}", e)))?;
293
294 let ptr = unsafe { alloc(layout) as *mut f64 };
295
296 if ptr.is_null() {
297 return Err(SklearsError::InvalidInput(
298 "Failed to allocate aligned memory".to_string(),
299 ));
300 }
301
302 unsafe {
304 std::ptr::write_bytes(ptr, 0, len);
305 }
306
307 Ok(Self {
308 ptr,
309 len,
310 alignment,
311 })
312 }
313
314 pub fn as_slice(&self) -> &[f64] {
316 unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
317 }
318
319 pub fn as_mut_slice(&mut self) -> &mut [f64] {
321 unsafe { std::slice::from_raw_parts_mut(self.ptr, self.len) }
322 }
323
324 pub fn len(&self) -> usize {
326 self.len
327 }
328
329 pub fn is_empty(&self) -> bool {
331 self.len == 0
332 }
333
334 pub fn to_array(&self) -> Array1<f64> {
336 Array1::from_vec(self.as_slice().to_vec())
337 }
338}
339
340impl Drop for AlignedBuffer {
341 fn drop(&mut self) {
342 if !self.ptr.is_null() {
343 let layout =
344 Layout::from_size_align(self.len * std::mem::size_of::<f64>(), self.alignment)
345 .expect("Invalid layout in drop");
346 unsafe {
347 dealloc(self.ptr as *mut u8, layout);
348 }
349 }
350 }
351}
352
// SAFETY: AlignedBuffer exclusively owns its heap allocation and f64 data has
// no thread affinity, so moving the buffer to another thread is sound.
unsafe impl Send for AlignedBuffer {}
// SAFETY: shared access only exposes `&[f64]` via `as_slice`; mutation
// requires `&mut self`, so concurrent shared reads are sound.
unsafe impl Sync for AlignedBuffer {}
355
/// A feature transformation that takes cache characteristics into account.
pub trait CacheAwareTransform {
    /// Transforms `features`, using `config` to guide cache-friendly access.
    ///
    /// # Errors
    /// Implementations return an error when the transformation fails.
    fn transform_cached(
        &self,
        features: &Array2<f64>,
        config: &CacheConfig,
    ) -> Result<Array2<f64>, SklearsError>;
}
365
366pub mod utils {
368 use super::*;
369
370 #[inline(always)]
372 pub fn prefetch_read(addr: *const f64) {
373 #[cfg(target_arch = "x86_64")]
374 {
375 #[cfg(target_feature = "sse")]
376 unsafe {
377 use std::arch::x86_64::_mm_prefetch;
378 use std::arch::x86_64::_MM_HINT_T0;
379 _mm_prefetch(addr as *const i8, _MM_HINT_T0);
380 }
381 }
382
383 let _ = addr;
385 }
386
387 pub fn transpose_blocked(
389 matrix: &Array2<f64>,
390 tile_size: usize,
391 ) -> Result<Array2<f64>, SklearsError> {
392 let (n_rows, n_cols) = matrix.dim();
393 let mut result = Array2::zeros((n_cols, n_rows));
394
395 for row_block in (0..n_rows).step_by(tile_size) {
396 for col_block in (0..n_cols).step_by(tile_size) {
397 let row_end = (row_block + tile_size).min(n_rows);
398 let col_end = (col_block + tile_size).min(n_cols);
399
400 for row in row_block..row_end {
401 for col in col_block..col_end {
402 result[[col, row]] = matrix[[row, col]];
403 }
404 }
405 }
406 }
407
408 Ok(result)
409 }
410
411 pub fn optimal_thread_count(
413 n_samples: usize,
414 n_features: usize,
415 config: &CacheConfig,
416 ) -> usize {
417 let data_size = n_samples * n_features * std::mem::size_of::<f64>();
418 let num_cpus = num_cpus::get();
419
420 if data_size <= config.l3_cache_size {
422 (num_cpus / 2).max(1)
423 } else {
424 num_cpus
425 }
426 }
427
428 pub fn align_to_cache_line(size: usize, cache_line_size: usize) -> usize {
430 ((size + cache_line_size - 1) / cache_line_size) * cache_line_size
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cache_config_default() {
        let cfg = CacheConfig::default();
        assert_eq!(cfg.cache_line_size, 64);
        assert!(cfg.enable_prefetch);
    }

    #[test]
    fn test_optimal_tile_size() {
        let tile = CacheConfig::default().optimal_tile_size();
        // Tile must be usable and comfortably smaller than typical matrices.
        assert!(tile > 0);
        assert!(tile < 1024);
    }

    #[test]
    fn test_cache_friendly_matrix_row_major() {
        let source = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let cfg = CacheConfig {
            layout: MemoryLayout::RowMajor,
            ..Default::default()
        };

        let m = CacheFriendlyMatrix::from_array(&source, cfg).unwrap();
        assert_eq!(m.dim(), (2, 3));
        assert_eq!(m.get(0, 0).unwrap(), 1.0);
        assert_eq!(m.get(1, 2).unwrap(), 6.0);
    }

    #[test]
    fn test_cache_friendly_matrix_column_major() {
        let source = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let cfg = CacheConfig {
            layout: MemoryLayout::ColumnMajor,
            ..Default::default()
        };

        let m = CacheFriendlyMatrix::from_array(&source, cfg).unwrap();
        assert_eq!(m.get(0, 0).unwrap(), 1.0);
        assert_eq!(m.get(1, 2).unwrap(), 6.0);
    }

    #[test]
    fn test_cache_friendly_matrix_blocked() {
        let source = array![
            [1.0, 2.0, 3.0, 4.0],
            [5.0, 6.0, 7.0, 8.0],
            [9.0, 10.0, 11.0, 12.0],
            [13.0, 14.0, 15.0, 16.0]
        ];
        let cfg = CacheConfig {
            layout: MemoryLayout::Blocked { tile_size: 2 },
            ..Default::default()
        };

        let m = CacheFriendlyMatrix::from_array(&source, cfg).unwrap();
        assert_eq!(m.get(0, 0).unwrap(), 1.0);
        assert_eq!(m.get(3, 3).unwrap(), 16.0);
    }

    #[test]
    fn test_matrix_vector_multiplication() {
        let source = array![[1.0, 2.0], [3.0, 4.0]];
        let v = array![1.0, 2.0];

        let m = CacheFriendlyMatrix::from_array(&source, CacheConfig::default()).unwrap();
        let product = m.dot_vector(&v).unwrap();

        // [1 2; 3 4] * [1, 2] = [5, 11]
        assert_eq!(product[0], 5.0);
        assert_eq!(product[1], 11.0);
    }

    #[test]
    fn test_aligned_buffer() {
        let buf = AlignedBuffer::new(10, 64).unwrap();
        assert_eq!(buf.len(), 10);
        assert!(!buf.is_empty());
        assert_eq!(buf.as_slice().len(), 10);
    }

    #[test]
    fn test_transpose_blocked() {
        let m = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let t = utils::transpose_blocked(&m, 2).unwrap();

        assert_eq!(t.dim(), (3, 2));
        assert_eq!(t[[0, 0]], 1.0);
        assert_eq!(t[[2, 1]], 6.0);
    }

    #[test]
    fn test_optimal_thread_count() {
        let cfg = CacheConfig::default();
        let n = utils::optimal_thread_count(1000, 100, &cfg);
        assert!(n > 0);
        assert!(n <= num_cpus::get());
    }

    #[test]
    fn test_align_to_cache_line() {
        let aligned = utils::align_to_cache_line(100, 64);
        assert_eq!(aligned, 128);
        assert_eq!(aligned % 64, 0);
    }
}