sklears_inspection/memory/layout_manager.rs

//! Memory layout optimization and SIMD operations for explanation algorithms
//!
//! This module provides memory management utilities including pooled, zero-initialized
//! buffer allocation, memory layout selection, and SIMD-vectorized operations.
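//!
//! # Example
//!
//! A minimal usage sketch; the import path below is illustrative and may not
//! match the actual crate layout, so the example is not compiled as a doctest:
//!
//! ```ignore
//! use sklears_inspection::memory::layout_manager::{
//!     ExplanationDataLayout, MemoryLayoutManager,
//! };
//!
//! let manager = MemoryLayoutManager::new(ExplanationDataLayout::default());
//!
//! // Pick a layout orientation based on the shape of the explanation matrix.
//! let layout = manager.get_optimized_layout(1_000, 10);
//! assert!(layout.feature_major);
//!
//! // Allocate a working buffer and return it to the pool when finished.
//! let buffer = manager.allocate_aligned(1_000 * 10);
//! manager.deallocate(buffer);
//! ```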

use crate::types::*;
use std::sync::{Arc, Mutex};

/// Memory layout configuration for explanation results
#[derive(Clone, Debug)]
pub struct ExplanationDataLayout {
    /// Use a feature-major layout for better cache locality
    pub feature_major: bool,
    /// Block size for tiled access patterns
    pub block_size: usize,
    /// Memory alignment in bytes
    pub alignment: usize,
}

impl Default for ExplanationDataLayout {
    fn default() -> Self {
        Self {
            feature_major: true,
            block_size: 64,
            alignment: 64,
        }
    }
}

/// Memory-efficient data layout manager
pub struct MemoryLayoutManager {
    /// Current layout configuration
    layout: ExplanationDataLayout,
    /// Memory pool for buffer reuse
    memory_pool: Arc<Mutex<Vec<Vec<Float>>>>,
}

impl MemoryLayoutManager {
    /// Create a new manager with the given layout configuration
    pub fn new(layout: ExplanationDataLayout) -> Self {
        Self {
            layout,
            memory_pool: Arc::new(Mutex::new(Vec::new())),
        }
    }

    /// Get optimized memory layout for explanation data
    pub fn get_optimized_layout(
        &self,
        n_samples: usize,
        n_features: usize,
    ) -> ExplanationDataLayout {
        // Choose the orientation based on the expected access pattern: with
        // more samples than features, a feature-major layout gives better
        // cache locality; otherwise prefer a sample-major layout.
        let feature_major = n_features < n_samples;

        ExplanationDataLayout {
            feature_major,
            block_size: self.layout.block_size,
            alignment: self.layout.alignment,
        }
    }

    /// Allocate aligned memory for explanation computation
    ///
    /// Buffers reused from the internal pool may be longer than `size` and may
    /// contain stale values from previous computations; freshly allocated
    /// buffers are zero-initialized.
    pub fn allocate_aligned(&self, size: usize) -> Vec<Float> {
        // Try to reuse memory from the pool first
        {
            let mut pool = self.memory_pool.lock().unwrap();
            if let Some(memory) = pool.pop() {
                if memory.len() >= size {
                    return memory;
                }
                // Too small for this request; keep it for smaller requests
                pool.push(memory);
            }
        }

        // Allocate a new zero-initialized buffer
        unsafe { self.allocate_aligned_unsafe(size) }
    }

    /// Raw allocation path for explanation buffers
    ///
    /// # Safety
    ///
    /// The buffer is handed back as a `Vec<Float>`, which will be freed with
    /// `Float`'s natural alignment. The raw-allocation path is therefore only
    /// taken when the requested alignment equals that natural alignment, so
    /// allocation and deallocation layouts always agree; over-aligned requests
    /// fall back to a plain zeroed `Vec`. The SIMD kernels below use unaligned
    /// loads and stores, so correctness does not depend on extra alignment.
    unsafe fn allocate_aligned_unsafe(&self, size: usize) -> Vec<Float> {
        use std::alloc::{alloc_zeroed, Layout};

        if size == 0 {
            return Vec::new();
        }

        // Ensure the alignment is a power of two and at least Float's own
        let alignment = self
            .layout
            .alignment
            .max(std::mem::align_of::<Float>())
            .next_power_of_two();

        // `Vec::from_raw_parts` requires the allocation layout to match the
        // layout `Vec` uses on drop, so over-aligned buffers cannot be wrapped
        // in a `Vec` soundly; fall back to a regular zeroed vector instead.
        if alignment != std::mem::align_of::<Float>() {
            return vec![0.0; size];
        }

        // Build the allocation layout, falling back on arithmetic overflow or
        // an invalid layout
        let total_size = match size.checked_mul(std::mem::size_of::<Float>()) {
            Some(total) => total,
            None => return vec![0.0; size],
        };
        let layout = match Layout::from_size_align(total_size, alignment) {
            Ok(layout) => layout,
            Err(_) => return vec![0.0; size],
        };

        // Allocate zero-initialized memory
        let ptr = alloc_zeroed(layout) as *mut Float;
        if ptr.is_null() {
            // Fallback to regular allocation if the raw allocation fails
            return vec![0.0; size];
        }

        // Create Vec from raw parts; length and capacity both equal `size`
        Vec::from_raw_parts(ptr, size, size)
    }

    /// Return memory to pool for reuse
    pub fn deallocate(&self, memory: Vec<Float>) {
        let mut pool = self.memory_pool.lock().unwrap();
        pool.push(memory);

        // Limit pool size to prevent memory bloat
        if pool.len() > 10 {
            pool.truncate(5);
        }
    }

    /// Unsafe memory copy with prefetching for better cache performance
    ///
    /// # Safety
    ///
    /// Callers must ensure that:
    /// - `src` and `dst` are valid pointers
    /// - `src` and `dst` do not overlap
    /// - `src` and `dst` are each valid for `len` elements
    pub unsafe fn copy_with_prefetch(&self, src: *const Float, dst: *mut Float, len: usize) {
        let prefetch_distance = self.layout.alignment / std::mem::size_of::<Float>();

        for i in 0..len {
            // Prefetch a later element so it is already in cache when reached
            if i + prefetch_distance < len {
                #[cfg(target_arch = "x86_64")]
                {
                    use std::arch::x86_64::*;
                    _mm_prefetch::<_MM_HINT_T0>(src.add(i + prefetch_distance) as *const i8);
                }
            }

            // Copy data
            *dst.add(i) = *src.add(i);
        }
    }

    /// Unsafe vectorized addition with SIMD for maximum performance
    ///
    /// # Safety
    ///
    /// Callers must ensure that:
    /// - `a`, `b`, and `result` are valid, non-overlapping pointers
    /// - All three buffers have at least `len` elements
    ///
    /// No particular alignment is required; the SIMD paths use unaligned
    /// loads and stores.
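    ///
    /// # Example
    ///
    /// A sketch of a safe call from slice-backed storage (not compiled as a
    /// doctest because the surrounding setup is illustrative):
    ///
    /// ```ignore
    /// let a = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    /// let b = vec![5.0, 4.0, 3.0, 2.0, 1.0];
    /// let mut out = vec![0.0; a.len()];
    ///
    /// // SAFETY: all three buffers are distinct and valid for `a.len()` elements.
    /// unsafe {
    ///     manager.vectorized_add(a.as_ptr(), b.as_ptr(), out.as_mut_ptr(), a.len());
    /// }
    /// assert!(out.iter().all(|&x| (x - 6.0).abs() < 1e-6));
    /// ```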
    pub unsafe fn vectorized_add(
        &self,
        a: *const Float,
        b: *const Float,
        result: *mut Float,
        len: usize,
    ) {
        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::*;

            // Check if we can use AVX2 for f64 or SSE for f32
            if std::mem::size_of::<Float>() == 8 && is_x86_feature_detected!("avx2") {
                // Process 4 f64 values at a time with AVX2; unaligned
                // loads/stores avoid faults on arbitrarily aligned buffers
                let chunks = len / 4;
                let a_ptr = a as *const f64;
                let b_ptr = b as *const f64;
                let result_ptr = result as *mut f64;

                for i in 0..chunks {
                    let a_vec = _mm256_loadu_pd(a_ptr.add(i * 4));
                    let b_vec = _mm256_loadu_pd(b_ptr.add(i * 4));
                    let sum = _mm256_add_pd(a_vec, b_vec);
                    _mm256_storeu_pd(result_ptr.add(i * 4), sum);
                }

                // Handle remaining elements
                for i in (chunks * 4)..len {
                    *result.add(i) = *a.add(i) + *b.add(i);
                }
            } else if std::mem::size_of::<Float>() == 4 && is_x86_feature_detected!("sse") {
                // Process 4 f32 values at a time with SSE, again using
                // unaligned loads/stores
                let chunks = len / 4;
                let a_ptr = a as *const f32;
                let b_ptr = b as *const f32;
                let result_ptr = result as *mut f32;

                for i in 0..chunks {
                    let a_vec = _mm_loadu_ps(a_ptr.add(i * 4));
                    let b_vec = _mm_loadu_ps(b_ptr.add(i * 4));
                    let sum = _mm_add_ps(a_vec, b_vec);
                    _mm_storeu_ps(result_ptr.add(i * 4), sum);
                }

                // Handle remaining elements
                for i in (chunks * 4)..len {
                    *result.add(i) = *a.add(i) + *b.add(i);
                }
            } else {
                // Fallback to scalar addition
                for i in 0..len {
                    *result.add(i) = *a.add(i) + *b.add(i);
                }
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            // Fallback to scalar addition for non-x86 architectures
            for i in 0..len {
                *result.add(i) = *a.add(i) + *b.add(i);
            }
        }
    }

    /// Unsafe fast dot product computation with SIMD
    ///
    /// # Safety
    ///
    /// Callers must ensure that:
    /// - `a` and `b` are valid pointers
    /// - Both buffers have at least `len` elements
    ///
    /// No particular alignment is required; the SIMD paths use unaligned loads.
    pub unsafe fn fast_dot_product(&self, a: *const Float, b: *const Float, len: usize) -> Float {
        let mut result = 0.0;

        #[cfg(target_arch = "x86_64")]
        {
            use std::arch::x86_64::*;

            if std::mem::size_of::<Float>() == 8 && is_x86_feature_detected!("avx2") {
                // Process 4 f64 values at a time with AVX2
                let chunks = len / 4;
                let a_ptr = a as *const f64;
                let b_ptr = b as *const f64;

                let mut sum_vec = _mm256_setzero_pd();

                for i in 0..chunks {
                    let a_vec = _mm256_loadu_pd(a_ptr.add(i * 4));
                    let b_vec = _mm256_loadu_pd(b_ptr.add(i * 4));
                    let prod = _mm256_mul_pd(a_vec, b_vec);
                    sum_vec = _mm256_add_pd(sum_vec, prod);
                }

                // Extract and sum the 4 partial sums
                let mut sum_arr = [0.0f64; 4];
                _mm256_storeu_pd(sum_arr.as_mut_ptr(), sum_vec);
                result = (sum_arr[0] + sum_arr[1] + sum_arr[2] + sum_arr[3]) as Float;

                // Handle remaining elements
                for i in (chunks * 4)..len {
                    result += (*a.add(i)) * (*b.add(i));
                }
            } else if std::mem::size_of::<Float>() == 4 && is_x86_feature_detected!("sse") {
                // Process 4 f32 values at a time with SSE
                let chunks = len / 4;
                let a_ptr = a as *const f32;
                let b_ptr = b as *const f32;

                let mut sum_vec = _mm_setzero_ps();

                for i in 0..chunks {
                    let a_vec = _mm_loadu_ps(a_ptr.add(i * 4));
                    let b_vec = _mm_loadu_ps(b_ptr.add(i * 4));
                    let prod = _mm_mul_ps(a_vec, b_vec);
                    sum_vec = _mm_add_ps(sum_vec, prod);
                }

                // Extract and sum the 4 partial sums
                let mut sum_arr = [0.0f32; 4];
                _mm_storeu_ps(sum_arr.as_mut_ptr(), sum_vec);
                result = (sum_arr[0] + sum_arr[1] + sum_arr[2] + sum_arr[3]) as Float;

                // Handle remaining elements
                for i in (chunks * 4)..len {
                    result += (*a.add(i)) * (*b.add(i));
                }
            } else {
                // Fallback to scalar multiplication
                for i in 0..len {
                    result += (*a.add(i)) * (*b.add(i));
                }
            }
        }

        #[cfg(not(target_arch = "x86_64"))]
        {
            // Fallback to scalar multiplication for non-x86 architectures
            for i in 0..len {
                result += (*a.add(i)) * (*b.add(i));
            }
        }

        result
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_layout_manager() {
        let layout = ExplanationDataLayout {
            feature_major: true,
            block_size: 64,
            alignment: 32,
        };

        let manager = MemoryLayoutManager::new(layout);
        let optimized = manager.get_optimized_layout(100, 10);

        // Should prefer feature-major for more samples than features
        assert!(optimized.feature_major);
    }

    #[test]
    fn test_aligned_memory_allocation() {
        let layout = ExplanationDataLayout {
            feature_major: true,
            block_size: 64,
            alignment: 32,
        };

        let manager = MemoryLayoutManager::new(layout);
        let memory = manager.allocate_aligned(100);

        assert_eq!(memory.len(), 100);

        // Return to pool
        manager.deallocate(memory);
    }

    #[test]
    fn test_layout_optimization() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        // More samples than features
        let opt1 = manager.get_optimized_layout(1000, 10);
        assert!(opt1.feature_major);

        // More features than samples
        let opt2 = manager.get_optimized_layout(10, 1000);
        assert!(!opt2.feature_major);
    }

    #[test]
    fn test_memory_pool() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        // Allocate and deallocate memory
        let mem1 = manager.allocate_aligned(50);
        let mem2 = manager.allocate_aligned(100);

        manager.deallocate(mem1);
        manager.deallocate(mem2);

        // Next allocation should reuse from pool
        let mem3 = manager.allocate_aligned(75);
        assert_eq!(mem3.len(), 100); // Should reuse the larger buffer
    }

    #[test]
    fn test_explanation_data_layout_default() {
        let layout = ExplanationDataLayout::default();
        assert!(layout.feature_major);
        assert_eq!(layout.block_size, 64);
        assert_eq!(layout.alignment, 64);
    }

    #[test]
    fn test_unsafe_operations_safety() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        // The SIMD kernels use unaligned loads/stores, so plain Vec-backed
        // buffers are sufficient here
        let vec_a = vec![1.0, 2.0, 3.0, 4.0];
        let vec_b = vec![5.0, 6.0, 7.0, 8.0];
        let mut result = vec![0.0; 4];

        unsafe {
            // Test vectorized addition
            manager.vectorized_add(vec_a.as_ptr(), vec_b.as_ptr(), result.as_mut_ptr(), 4);

            // Test dot product: 1*5 + 2*6 + 3*7 + 4*8 = 70
            let dot = manager.fast_dot_product(vec_a.as_ptr(), vec_b.as_ptr(), 4);
            assert!((dot - 70.0).abs() < 1e-6);
        }

        // Check addition results
        for i in 0..4 {
            assert!((result[i] - (vec_a[i] + vec_b[i])).abs() < 1e-6);
        }
    }
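
    // Additional coverage sketches: these tests only exercise the APIs
    // defined in this module and assume `Float` is a primitive float type
    // (`f32` or `f64`), as the SIMD code above already does.

    #[test]
    fn test_copy_with_prefetch_roundtrip() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        let src: Vec<Float> = (0..32).map(|i| i as Float).collect();
        let mut dst = vec![0.0; src.len()];

        // SAFETY: `src` and `dst` are distinct buffers of equal length.
        unsafe {
            manager.copy_with_prefetch(src.as_ptr(), dst.as_mut_ptr(), src.len());
        }

        assert_eq!(src, dst);
    }

    #[test]
    fn test_simd_tail_handling() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        // Use a length that is not a multiple of the SIMD width so the scalar
        // tail loops are exercised as well.
        let len = 7;
        let a: Vec<Float> = (0..len).map(|i| (i + 1) as Float).collect();
        let b: Vec<Float> = (0..len).map(|i| (2 * i + 1) as Float).collect();
        let mut sum = vec![0.0; len];

        // Scalar reference results
        let expected_sum: Vec<Float> = (0..len).map(|i| a[i] + b[i]).collect();
        let expected_dot: Float = (0..len).map(|i| a[i] * b[i]).sum();

        // SAFETY: all buffers are valid for `len` elements and do not overlap.
        unsafe {
            manager.vectorized_add(a.as_ptr(), b.as_ptr(), sum.as_mut_ptr(), len);
            let dot = manager.fast_dot_product(a.as_ptr(), b.as_ptr(), len);
            assert!((dot - expected_dot).abs() < 1e-4);
        }

        for i in 0..len {
            assert!((sum[i] - expected_sum[i]).abs() < 1e-4);
        }
    }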
}