sklears_inspection/memory/
cache.rs1use crate::types::*;
7use crate::SklResult;
8use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
10use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12use std::sync::{Arc, Mutex};
13
14pub struct ExplanationCache {
16 feature_importance_cache: Arc<Mutex<HashMap<CacheKey, Array1<Float>>>>,
18 partial_dependence_cache: Arc<Mutex<HashMap<CacheKey, Array2<Float>>>>,
20 shap_cache: Arc<Mutex<HashMap<CacheKey, Array2<Float>>>>,
22 prediction_cache: Arc<Mutex<HashMap<CacheKey, Array1<Float>>>>,
24 max_cache_size: usize,
26 cache_hits: Arc<Mutex<CacheStatistics>>,
28}
29
/// Lookup key for [`ExplanationCache`]: identifies *which data*, *which method* and
/// *which configuration* produced a cached result.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Fingerprint of the input array, as computed by `CacheKey::new`.
    data_hash: u64,
    /// Caller-supplied identifier of the explanation method.
    method_id: String,
    /// Caller-supplied hash of the method's configuration.
    config_hash: u64,
}
40
/// Counters describing how an [`ExplanationCache`] has been used.
#[derive(Debug, Default, Clone)]
pub struct CacheStatistics {
    /// Lookups answered from the cache without calling the compute closure.
    pub hits: usize,
    /// Lookups that ran the compute closure successfully and stored its result.
    pub misses: usize,
    /// Approximate bytes held by cached arrays (element count times `size_of::<Float>()`).
    pub total_size: usize,
    /// NOTE(review): `ExplanationCache` never writes this field, so it stays at its default
    /// unless other code sets it; the unit is not established here.
    pub avg_access_time: f64,
}
53
/// Settings used to construct an [`ExplanationCache`].
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Size budget in megabytes; `ExplanationCache::new` converts it to bytes (x 1024 * 1024).
    pub max_cache_size_mb: usize,
    /// NOTE(review): not read by `ExplanationCache::new`; confirm whether anything uses it.
    pub enable_locality_optimization: bool,
    /// NOTE(review): not read by `ExplanationCache::new`; unit not established here.
    pub prefetch_distance: usize,
    /// NOTE(review): not read by `ExplanationCache::new`; presumably bytes — confirm.
    pub memory_alignment: usize,
}
66
67impl Default for CacheConfig {
68 fn default() -> Self {
69 Self {
70 max_cache_size_mb: 256,
71 enable_locality_optimization: true,
72 prefetch_distance: 64,
73 memory_alignment: 64,
74 }
75 }
76}
77
78impl ExplanationCache {
79 pub fn new(config: &CacheConfig) -> Self {
81 Self {
82 feature_importance_cache: Arc::new(Mutex::new(HashMap::new())),
83 partial_dependence_cache: Arc::new(Mutex::new(HashMap::new())),
84 shap_cache: Arc::new(Mutex::new(HashMap::new())),
85 prediction_cache: Arc::new(Mutex::new(HashMap::new())),
86 max_cache_size: config.max_cache_size_mb * 1024 * 1024,
87 cache_hits: Arc::new(Mutex::new(CacheStatistics::default())),
88 }
89 }
90
91 pub fn get_or_compute_feature_importance<F>(
93 &self,
94 key: &CacheKey,
95 compute_fn: F,
96 ) -> SklResult<Array1<Float>>
97 where
98 F: FnOnce() -> SklResult<Array1<Float>>,
99 {
100 {
102 let cache = self.feature_importance_cache.lock().unwrap();
103 if let Some(result) = cache.get(key) {
104 let mut stats = self.cache_hits.lock().unwrap();
106 stats.hits += 1;
107 return Ok(result.clone());
108 }
109 }
110
111 let result = compute_fn()?;
113
114 {
115 let mut cache = self.feature_importance_cache.lock().unwrap();
116 cache.insert(key.clone(), result.clone());
117
118 let mut stats = self.cache_hits.lock().unwrap();
120 stats.misses += 1;
121 stats.total_size += result.len() * std::mem::size_of::<Float>();
122 }
123
124 Ok(result)
125 }
126
127 pub fn get_or_compute_shap<F>(&self, key: &CacheKey, compute_fn: F) -> SklResult<Array2<Float>>
129 where
130 F: FnOnce() -> SklResult<Array2<Float>>,
131 {
132 {
134 let cache = self.shap_cache.lock().unwrap();
135 if let Some(result) = cache.get(key) {
136 let mut stats = self.cache_hits.lock().unwrap();
138 stats.hits += 1;
139 return Ok(result.clone());
140 }
141 }
142
143 let result = compute_fn()?;
145
146 {
147 let mut cache = self.shap_cache.lock().unwrap();
148 cache.insert(key.clone(), result.clone());
149
150 let mut stats = self.cache_hits.lock().unwrap();
152 stats.misses += 1;
153 stats.total_size += result.len() * std::mem::size_of::<Float>();
154 }
155
156 Ok(result)
157 }
158
159 pub fn get_statistics(&self) -> CacheStatistics {
161 self.cache_hits.lock().unwrap().clone()
162 }
163
164 pub fn clear_all(&self) {
166 self.feature_importance_cache.lock().unwrap().clear();
167 self.partial_dependence_cache.lock().unwrap().clear();
168 self.shap_cache.lock().unwrap().clear();
169 self.prediction_cache.lock().unwrap().clear();
170
171 let mut stats = self.cache_hits.lock().unwrap();
172 *stats = CacheStatistics::default();
173 }
174
175 pub fn evict_lru(&self) {
177 let total_size = self.cache_hits.lock().unwrap().total_size;
180 if total_size > self.max_cache_size {
181 self.feature_importance_cache.lock().unwrap().clear();
183 self.partial_dependence_cache.lock().unwrap().clear();
184
185 let mut stats = self.cache_hits.lock().unwrap();
186 stats.total_size /= 2;
187 }
188 }
189}
190
191impl CacheKey {
192 pub fn new(data: &ArrayView2<Float>, method_id: &str, config_hash: u64) -> Self {
194 let mut hasher = std::collections::hash_map::DefaultHasher::new();
195
196 data.shape().hash(&mut hasher);
198
199 let sample_indices = if data.len() > 1000 {
201 (0..data.len())
202 .step_by(data.len() / 100)
203 .collect::<Vec<_>>()
204 } else {
205 (0..data.len()).collect::<Vec<_>>()
206 };
207
208 for &idx in &sample_indices {
209 let (row, col) = (idx / data.ncols(), idx % data.ncols());
210 if let Some(val) = data.get((row, col)) {
211 val.to_bits().hash(&mut hasher);
212 }
213 }
214
215 let data_hash = hasher.finish();
216
217 Self {
218 data_hash,
219 method_id: method_id.to_string(),
220 config_hash,
221 }
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Shared setup: a cache built from `CacheConfig::default()`.
    fn default_cache() -> ExplanationCache {
        ExplanationCache::new(&CacheConfig::default())
    }

    #[test]
    fn test_cache_key_creation() {
        let data = array![[1.0, 2.0], [3.0, 4.0]];

        let first = CacheKey::new(&data.view(), "test_method", 123);
        let repeat = CacheKey::new(&data.view(), "test_method", 123);
        let other_method = CacheKey::new(&data.view(), "different_method", 123);

        // Identical inputs must map to the same key; a different method must not.
        assert_eq!(first, repeat);
        assert_ne!(first, other_method);
    }

    #[test]
    fn test_explanation_cache_creation() {
        let fresh = default_cache().get_statistics();

        assert_eq!((fresh.hits, fresh.misses), (0, 0));
    }

    #[test]
    fn test_cache_hit_and_miss() {
        let cache = default_cache();
        let data = array![[1.0, 2.0], [3.0, 4.0]];
        let key = CacheKey::new(&data.view(), "test", 0);

        // First lookup misses and stores [0.5, 0.3].
        let computed = cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.5, 0.3]))
            .unwrap();
        let after_miss = cache.get_statistics();
        assert_eq!((after_miss.misses, after_miss.hits), (1, 0));

        // Second lookup must hit: the closure's different value is never used.
        let cached = cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.1, 0.9]))
            .unwrap();
        let after_hit = cache.get_statistics();
        assert_eq!((after_hit.misses, after_hit.hits), (1, 1));

        assert_eq!(computed, cached);
    }

    #[test]
    fn test_cache_statistics() {
        let cache = default_cache();
        let data = array![[1.0, 2.0]];
        let key = CacheKey::new(&data.view(), "test", 0);

        for candidate in [array![0.5, 0.3], array![0.1, 0.9]] {
            cache
                .get_or_compute_feature_importance(&key, || Ok(candidate))
                .unwrap();
        }

        let stats = cache.get_statistics();
        assert_eq!((stats.hits, stats.misses), (1, 1));
        assert!(stats.total_size > 0);
    }

    #[test]
    fn test_cache_clear() {
        let cache = default_cache();
        let data = array![[1.0, 2.0]];
        let key = CacheKey::new(&data.view(), "test", 0);

        cache
            .get_or_compute_feature_importance(&key, || Ok(array![0.5, 0.3]))
            .unwrap();
        assert_eq!(cache.get_statistics().misses, 1);

        cache.clear_all();

        let reset = cache.get_statistics();
        assert_eq!((reset.hits, reset.misses, reset.total_size), (0, 0, 0));
    }
}