// sklears_inspection/memory/layout_manager.rs
use crate::types::*;
use std::sync::{Arc, Mutex};
8
/// Describes how explanation data is laid out in memory.
#[derive(Clone, Debug)]
pub struct ExplanationDataLayout {
    /// `true` for feature-major storage, `false` for sample-major.
    pub feature_major: bool,
    /// Number of elements processed per block/tile.
    pub block_size: usize,
    /// Requested byte alignment for buffers managed under this layout.
    pub alignment: usize,
}
19
20impl Default for ExplanationDataLayout {
21 fn default() -> Self {
22 Self {
23 feature_major: true,
24 block_size: 64,
25 alignment: 64,
26 }
27 }
28}
29
/// Hands out and recycles `Float` buffers according to a configured
/// `ExplanationDataLayout`, and exposes low-level copy/arithmetic helpers.
pub struct MemoryLayoutManager {
    /// Layout parameters used when sizing and handing out allocations.
    layout: ExplanationDataLayout,
    /// LIFO pool of previously allocated buffers awaiting reuse; wrapped in
    /// `Arc<Mutex<..>>` so the manager can be shared across threads.
    memory_pool: Arc<Mutex<Vec<Vec<Float>>>>,
}
37
38impl MemoryLayoutManager {
39 pub fn new(layout: ExplanationDataLayout) -> Self {
40 Self {
41 layout,
42 memory_pool: Arc::new(Mutex::new(Vec::new())),
43 }
44 }
45
46 pub fn get_optimized_layout(
48 &self,
49 n_samples: usize,
50 n_features: usize,
51 ) -> ExplanationDataLayout {
52 let feature_major = if n_features < n_samples {
54 true
56 } else {
57 false
59 };
60
61 ExplanationDataLayout {
62 feature_major,
63 block_size: self.layout.block_size,
64 alignment: self.layout.alignment,
65 }
66 }
67
68 pub fn allocate_aligned(&self, size: usize) -> Vec<Float> {
70 {
72 let mut pool = self.memory_pool.lock().unwrap();
73 if let Some(memory) = pool.pop() {
74 if memory.len() >= size {
75 return memory;
76 }
77 }
78 }
79
80 unsafe { self.allocate_aligned_unsafe(size) }
82 }
83
84 unsafe fn allocate_aligned_unsafe(&self, size: usize) -> Vec<Float> {
93 use std::alloc::{alloc, Layout};
94
95 let alignment = self.layout.alignment.max(std::mem::align_of::<Float>());
97 let alignment = alignment.next_power_of_two();
98
99 let total_size = size * std::mem::size_of::<Float>();
101
102 let layout = Layout::from_size_align_unchecked(total_size, alignment);
104
105 let ptr = alloc(layout) as *mut Float;
107
108 if ptr.is_null() {
109 let mut memory = Vec::with_capacity(size);
111 memory.resize(size, 0.0);
112 return memory;
113 }
114
115 std::ptr::write_bytes(ptr, 0, size);
117
118 Vec::from_raw_parts(ptr, size, size)
120 }
121
122 pub fn deallocate(&self, memory: Vec<Float>) {
124 let mut pool = self.memory_pool.lock().unwrap();
125 pool.push(memory);
126
127 if pool.len() > 10 {
129 pool.truncate(5);
130 }
131 }
132
133 pub unsafe fn copy_with_prefetch(&self, src: *const Float, dst: *mut Float, len: usize) {
142 let prefetch_distance = self.layout.alignment / std::mem::size_of::<Float>();
143
144 for i in 0..len {
145 if i + prefetch_distance < len {
147 #[cfg(target_arch = "x86_64")]
148 {
149 use std::arch::x86_64::*;
150 _mm_prefetch(src.add(i + prefetch_distance) as *const i8, _MM_HINT_T0);
151 }
152 }
153
154 *dst.add(i) = *src.add(i);
156 }
157 }
158
159 pub unsafe fn vectorized_add(
168 &self,
169 a: *const Float,
170 b: *const Float,
171 result: *mut Float,
172 len: usize,
173 ) {
174 #[cfg(target_arch = "x86_64")]
175 {
176 use std::arch::x86_64::*;
177
178 if std::mem::size_of::<Float>() == 8 && is_x86_feature_detected!("avx2") {
180 let chunks = len / 4;
182 let a_ptr = a as *const f64;
183 let b_ptr = b as *const f64;
184 let result_ptr = result as *mut f64;
185
186 for i in 0..chunks {
187 let a_vec = _mm256_load_pd(a_ptr.add(i * 4));
188 let b_vec = _mm256_load_pd(b_ptr.add(i * 4));
189 let sum = _mm256_add_pd(a_vec, b_vec);
190 _mm256_store_pd(result_ptr.add(i * 4), sum);
191 }
192
193 for i in (chunks * 4)..len {
195 *result.add(i) = *a.add(i) + *b.add(i);
196 }
197 } else if std::mem::size_of::<Float>() == 4 && is_x86_feature_detected!("sse") {
198 let chunks = len / 4;
200 let a_ptr = a as *const f32;
201 let b_ptr = b as *const f32;
202 let result_ptr = result as *mut f32;
203
204 for i in 0..chunks {
205 let a_vec = _mm_load_ps(a_ptr.add(i * 4));
206 let b_vec = _mm_load_ps(b_ptr.add(i * 4));
207 let sum = _mm_add_ps(a_vec, b_vec);
208 _mm_store_ps(result_ptr.add(i * 4), sum);
209 }
210
211 for i in (chunks * 4)..len {
213 *result.add(i) = *a.add(i) + *b.add(i);
214 }
215 } else {
216 for i in 0..len {
218 *result.add(i) = *a.add(i) + *b.add(i);
219 }
220 }
221 }
222
223 #[cfg(not(target_arch = "x86_64"))]
224 {
225 for i in 0..len {
227 *result.add(i) = *a.add(i) + *b.add(i);
228 }
229 }
230 }
231
232 pub unsafe fn fast_dot_product(&self, a: *const Float, b: *const Float, len: usize) -> Float {
241 let mut result = 0.0;
242
243 #[cfg(target_arch = "x86_64")]
244 {
245 use std::arch::x86_64::*;
246
247 if std::mem::size_of::<Float>() == 8 && is_x86_feature_detected!("avx2") {
248 let chunks = len / 4;
250 let a_ptr = a as *const f64;
251 let b_ptr = b as *const f64;
252
253 let mut sum_vec = _mm256_setzero_pd();
254
255 for i in 0..chunks {
256 let a_vec = _mm256_load_pd(a_ptr.add(i * 4));
257 let b_vec = _mm256_load_pd(b_ptr.add(i * 4));
258 let prod = _mm256_mul_pd(a_vec, b_vec);
259 sum_vec = _mm256_add_pd(sum_vec, prod);
260 }
261
262 let sum_arr = [0.0; 4];
264 _mm256_store_pd(sum_arr.as_ptr() as *mut f64, sum_vec);
265 result = sum_arr[0] + sum_arr[1] + sum_arr[2] + sum_arr[3];
266
267 for i in (chunks * 4)..len {
269 result += (*a.add(i)) * (*b.add(i));
270 }
271 } else if std::mem::size_of::<Float>() == 4 && is_x86_feature_detected!("sse") {
272 let chunks = len / 4;
274 let a_ptr = a as *const f32;
275 let b_ptr = b as *const f32;
276
277 let mut sum_vec = _mm_setzero_ps();
278
279 for i in 0..chunks {
280 let a_vec = _mm_load_ps(a_ptr.add(i * 4));
281 let b_vec = _mm_load_ps(b_ptr.add(i * 4));
282 let prod = _mm_mul_ps(a_vec, b_vec);
283 sum_vec = _mm_add_ps(sum_vec, prod);
284 }
285
286 let sum_arr = [0.0; 4];
288 _mm_store_ps(sum_arr.as_ptr() as *mut f32, sum_vec);
289 result = (sum_arr[0] + sum_arr[1] + sum_arr[2] + sum_arr[3]) as Float;
290
291 for i in (chunks * 4)..len {
293 result += (*a.add(i)) * (*b.add(i));
294 }
295 } else {
296 for i in 0..len {
298 result += (*a.add(i)) * (*b.add(i));
299 }
300 }
301 }
302
303 #[cfg(not(target_arch = "x86_64"))]
304 {
305 for i in 0..len {
307 result += (*a.add(i)) * (*b.add(i));
308 }
309 }
310
311 result
312 }
313}
314
#[cfg(test)]
mod tests {
    use super::*;

    // Construction plus layout optimization: 100 samples x 10 features
    // (fewer features than samples) should yield feature-major order.
    #[test]
    fn test_memory_layout_manager() {
        let layout = ExplanationDataLayout {
            feature_major: true,
            block_size: 64,
            alignment: 32,
        };

        let manager = MemoryLayoutManager::new(layout);
        let optimized = manager.get_optimized_layout(100, 10);

        assert!(optimized.feature_major);
    }

    // A fresh allocation has exactly the requested length and can be
    // returned to the pool without panicking.
    #[test]
    fn test_aligned_memory_allocation() {
        let layout = ExplanationDataLayout {
            feature_major: true,
            block_size: 64,
            alignment: 32,
        };

        let manager = MemoryLayoutManager::new(layout);
        let memory = manager.allocate_aligned(100);

        assert_eq!(memory.len(), 100);

        manager.deallocate(memory);
    }

    // feature_major flips with the relative sizes of the two dimensions.
    #[test]
    fn test_layout_optimization() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        // 10 features < 1000 samples -> feature-major.
        let opt1 = manager.get_optimized_layout(1000, 10);
        assert!(opt1.feature_major);

        // 1000 features >= 10 samples -> sample-major.
        let opt2 = manager.get_optimized_layout(10, 1000);
        assert!(!opt2.feature_major);
    }

    // Pool reuse is LIFO and returns any buffer at least as long as the
    // request, so asking for 75 after pooling a 100-element buffer hands
    // back the full 100-element buffer.
    #[test]
    fn test_memory_pool() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        let mem1 = manager.allocate_aligned(50);
        let mem2 = manager.allocate_aligned(100);

        manager.deallocate(mem1);
        manager.deallocate(mem2);

        let mem3 = manager.allocate_aligned(75);
        assert_eq!(mem3.len(), 100);
    }

    // Defaults documented on ExplanationDataLayout::default().
    #[test]
    fn test_explanation_data_layout_default() {
        let layout = ExplanationDataLayout::default();
        assert!(layout.feature_major);
        assert_eq!(layout.block_size, 64);
        assert_eq!(layout.alignment, 64);
    }

    // NOTE(review): presumably ignored because the SIMD helpers used ALIGNED
    // loads/stores on Vec-backed pointers, which Vec does not guarantee to be
    // 32-byte aligned — confirm and re-enable once unaligned intrinsics are
    // in use.
    #[test]
    #[ignore]
    fn test_unsafe_operations_safety() {
        let layout = ExplanationDataLayout::default();
        let manager = MemoryLayoutManager::new(layout);

        let mut vec_a = vec![1.0, 2.0, 3.0, 4.0];
        let mut vec_b = vec![5.0, 6.0, 7.0, 8.0];
        let mut result = vec![0.0; 4];

        unsafe {
            manager.vectorized_add(vec_a.as_ptr(), vec_b.as_ptr(), result.as_mut_ptr(), 4);

            let dot = manager.fast_dot_product(vec_a.as_ptr(), vec_b.as_ptr(), 4);
            assert!(dot > 0.0);
        }

        // Each result element should equal the pairwise sum.
        for i in 0..4 {
            assert!((result[i] - (vec_a[i] + vec_b[i])).abs() < 1e-6);
        }
    }
}