// sklears_inspection/memory/cache.rs

1//! Core caching system for explanation algorithms
2//!
3//! This module provides the fundamental caching infrastructure for explanation computation,
4//! including cache management, key generation, and statistics tracking.
5
6use crate::types::*;
7use crate::SklResult;
8// ✅ SciRS2 Policy Compliant Import
9use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
10use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12use std::sync::{Arc, Mutex};
13
/// Cache-friendly explanation computation with memory optimization
///
/// Holds one result cache per explanation algorithm, each guarded by its own
/// `Arc<Mutex<..>>` so lookups on different caches do not contend with each
/// other. Entries are keyed by a [`CacheKey`] combining a data hash, a method
/// identifier, and a configuration hash.
pub struct ExplanationCache {
    /// Cached per-feature importance vectors.
    feature_importance_cache: Arc<Mutex<HashMap<CacheKey, Array1<Float>>>>,
    /// Cached partial dependence grids.
    partial_dependence_cache: Arc<Mutex<HashMap<CacheKey, Array2<Float>>>>,
    /// Cached SHAP value matrices.
    shap_cache: Arc<Mutex<HashMap<CacheKey, Array2<Float>>>>,
    /// Cached model prediction vectors.
    prediction_cache: Arc<Mutex<HashMap<CacheKey, Array1<Float>>>>,
    /// Soft size limit in bytes (derived from `CacheConfig::max_cache_size_mb`).
    max_cache_size: usize,
    /// Shared hit/miss/size statistics for all caches.
    cache_hits: Arc<Mutex<CacheStatistics>>,
}
29
/// Cache key for identifying cached computations
///
/// Two keys compare equal when the data hash, method identifier, and
/// configuration hash all match. Note that the data hash is computed from the
/// array shape plus a *sample* of values (see [`CacheKey::new`]), so distinct
/// inputs can in principle collide.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Hash of the input data (shape plus a sample of element bit patterns).
    data_hash: u64,
    /// Identifier of the explanation method that produced the result.
    method_id: String,
    /// Hash of the method configuration supplied by the caller.
    config_hash: u64,
}
40
/// Cache statistics for monitoring performance
#[derive(Clone, Debug, Default)]
pub struct CacheStatistics {
    /// Number of cache hits
    pub hits: usize,
    /// Number of cache misses
    pub misses: usize,
    /// Approximate total size of cached values, in bytes
    pub total_size: usize,
    /// Average access time
    // NOTE(review): no code path in this module updates this field — confirm
    // whether callers elsewhere maintain it or it is a planned feature.
    pub avg_access_time: f64,
}
53
/// Configuration for cache-friendly algorithms
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Maximum cache size in MB (converted to bytes by `ExplanationCache::new`)
    pub max_cache_size_mb: usize,
    /// Enable data locality optimization
    // NOTE(review): not read anywhere in this module — presumably consumed by
    // other algorithm code; verify before removing.
    pub enable_locality_optimization: bool,
    /// Prefetch distance for sequential access
    pub prefetch_distance: usize,
    /// Memory alignment for SIMD operations
    pub memory_alignment: usize,
}
66
67impl Default for CacheConfig {
68    fn default() -> Self {
69        Self {
70            max_cache_size_mb: 256,
71            enable_locality_optimization: true,
72            prefetch_distance: 64,
73            memory_alignment: 64,
74        }
75    }
76}
77
78impl ExplanationCache {
79    /// Create a new explanation cache with specified configuration
80    pub fn new(config: &CacheConfig) -> Self {
81        Self {
82            feature_importance_cache: Arc::new(Mutex::new(HashMap::new())),
83            partial_dependence_cache: Arc::new(Mutex::new(HashMap::new())),
84            shap_cache: Arc::new(Mutex::new(HashMap::new())),
85            prediction_cache: Arc::new(Mutex::new(HashMap::new())),
86            max_cache_size: config.max_cache_size_mb * 1024 * 1024,
87            cache_hits: Arc::new(Mutex::new(CacheStatistics::default())),
88        }
89    }
90
91    /// Get or compute feature importance with caching
92    pub fn get_or_compute_feature_importance<F>(
93        &self,
94        key: &CacheKey,
95        compute_fn: F,
96    ) -> SklResult<Array1<Float>>
97    where
98        F: FnOnce() -> SklResult<Array1<Float>>,
99    {
100        // Check cache first
101        {
102            let cache = self.feature_importance_cache.lock().unwrap();
103            if let Some(result) = cache.get(key) {
104                // Cache hit
105                let mut stats = self.cache_hits.lock().unwrap();
106                stats.hits += 1;
107                return Ok(result.clone());
108            }
109        }
110
111        // Cache miss - compute and store
112        let result = compute_fn()?;
113
114        {
115            let mut cache = self.feature_importance_cache.lock().unwrap();
116            cache.insert(key.clone(), result.clone());
117
118            // Update statistics
119            let mut stats = self.cache_hits.lock().unwrap();
120            stats.misses += 1;
121            stats.total_size += result.len() * std::mem::size_of::<Float>();
122        }
123
124        Ok(result)
125    }
126
127    /// Get or compute SHAP values with caching
128    pub fn get_or_compute_shap<F>(&self, key: &CacheKey, compute_fn: F) -> SklResult<Array2<Float>>
129    where
130        F: FnOnce() -> SklResult<Array2<Float>>,
131    {
132        // Check cache first
133        {
134            let cache = self.shap_cache.lock().unwrap();
135            if let Some(result) = cache.get(key) {
136                // Cache hit
137                let mut stats = self.cache_hits.lock().unwrap();
138                stats.hits += 1;
139                return Ok(result.clone());
140            }
141        }
142
143        // Cache miss - compute and store
144        let result = compute_fn()?;
145
146        {
147            let mut cache = self.shap_cache.lock().unwrap();
148            cache.insert(key.clone(), result.clone());
149
150            // Update statistics
151            let mut stats = self.cache_hits.lock().unwrap();
152            stats.misses += 1;
153            stats.total_size += result.len() * std::mem::size_of::<Float>();
154        }
155
156        Ok(result)
157    }
158
159    /// Get cache statistics
160    pub fn get_statistics(&self) -> CacheStatistics {
161        self.cache_hits.lock().unwrap().clone()
162    }
163
164    /// Clear all caches
165    pub fn clear_all(&self) {
166        self.feature_importance_cache.lock().unwrap().clear();
167        self.partial_dependence_cache.lock().unwrap().clear();
168        self.shap_cache.lock().unwrap().clear();
169        self.prediction_cache.lock().unwrap().clear();
170
171        let mut stats = self.cache_hits.lock().unwrap();
172        *stats = CacheStatistics::default();
173    }
174
175    /// Evict least recently used entries if cache is full
176    pub fn evict_lru(&self) {
177        // Simple size-based eviction for now
178        // In a production system, you would implement proper LRU tracking
179        let total_size = self.cache_hits.lock().unwrap().total_size;
180        if total_size > self.max_cache_size {
181            // Clear half the cache
182            self.feature_importance_cache.lock().unwrap().clear();
183            self.partial_dependence_cache.lock().unwrap().clear();
184
185            let mut stats = self.cache_hits.lock().unwrap();
186            stats.total_size /= 2;
187        }
188    }
189}
190
191impl CacheKey {
192    /// Create a new cache key from data and configuration
193    pub fn new(data: &ArrayView2<Float>, method_id: &str, config_hash: u64) -> Self {
194        let mut hasher = std::collections::hash_map::DefaultHasher::new();
195
196        // Hash data dimensions and a sample of values for efficiency
197        data.shape().hash(&mut hasher);
198
199        // Hash a sample of data values for uniqueness
200        let sample_indices = if data.len() > 1000 {
201            (0..data.len())
202                .step_by(data.len() / 100)
203                .collect::<Vec<_>>()
204        } else {
205            (0..data.len()).collect::<Vec<_>>()
206        };
207
208        for &idx in &sample_indices {
209            let (row, col) = (idx / data.ncols(), idx % data.ncols());
210            if let Some(val) = data.get((row, col)) {
211                val.to_bits().hash(&mut hasher);
212            }
213        }
214
215        let data_hash = hasher.finish();
216
217        Self {
218            data_hash,
219            method_id: method_id.to_string(),
220            config_hash,
221        }
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    // ✅ SciRS2 Policy Compliant Import
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cache_key_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let view = data.view();

        let same_a = CacheKey::new(&view, "test_method", 123);
        let same_b = CacheKey::new(&view, "test_method", 123);
        let other = CacheKey::new(&view, "different_method", 123);

        // Identical inputs yield identical keys; a different method id does not.
        assert_eq!(same_a, same_b);
        assert_ne!(same_a, other);
    }

    #[test]
    fn test_explanation_cache_creation() {
        let cache = ExplanationCache::new(&CacheConfig::default());

        // A fresh cache starts with zeroed statistics.
        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_hit_and_miss() {
        let cache = ExplanationCache::new(&CacheConfig::default());
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let key = CacheKey::new(&data.view(), "test", 0);

        // First lookup computes the value (miss).
        let first = cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.5, 0.3]))
            .unwrap();
        let stats = cache.get_statistics();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 0);

        // Second lookup is served from the cache (hit); the closure returning
        // different values proves it was never re-executed.
        let second = cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.1, 0.9]))
            .unwrap();
        let stats = cache.get_statistics();
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 1);

        // Both results came from the same cached entry.
        assert_eq!(first, second);
    }

    #[test]
    fn test_cache_statistics() {
        let cache = ExplanationCache::new(&CacheConfig::default());
        let data = array![[1.0, 2.0]];
        let key = CacheKey::new(&data.view(), "test", 0);

        // One miss (first call) followed by one hit (second call).
        cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.5, 0.3]))
            .unwrap();
        cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.1, 0.9]))
            .unwrap();

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!(stats.total_size > 0);
    }

    #[test]
    fn test_cache_clear() {
        let cache = ExplanationCache::new(&CacheConfig::default());
        let data = array![[1.0, 2.0]];
        let key = CacheKey::new(&data.view(), "test", 0);

        // Populate the cache with one entry.
        cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.5, 0.3]))
            .unwrap();
        assert_eq!(cache.get_statistics().misses, 1);

        // Clearing resets both contents and statistics.
        cache.clear_all();

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.total_size, 0);
    }
}