tensorlogic_sklears_kernels/
cache.rs1use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::sync::{Arc, Mutex};
8
9use crate::error::Result;
10use crate::types::Kernel;
11
/// Identifies one `(x, y)` evaluation in [`CachedKernel`]'s cache.
///
/// Only 64-bit fingerprints of the two inputs are stored, never the inputs
/// themselves, so a key stays 16 bytes regardless of input dimension. Two
/// different inputs that happen to share a fingerprint would alias the same
/// entry; with a 64-bit hash this is unlikely but not impossible.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Fingerprint of the left-hand input.
    x_hash: u64,
    /// Fingerprint of the right-hand input.
    y_hash: u64,
}

impl CacheKey {
    /// Builds the key for evaluating a kernel on `(x, y)`.
    ///
    /// The key is not symmetric: `(x, y)` and `(y, x)` are fingerprinted
    /// independently and generally yield different keys.
    fn from_inputs(x: &[f64], y: &[f64]) -> Self {
        let x_hash = Self::hash_vector(x);
        let y_hash = Self::hash_vector(y);
        Self { x_hash, y_hash }
    }

    /// Fingerprints a vector by feeding the raw IEEE-754 bits of each element,
    /// in order, to the standard library's default hasher.
    ///
    /// Working on bits rather than values means `0.0` and `-0.0` fingerprint
    /// differently and NaNs can be hashed at all.
    fn hash_vector(v: &[f64]) -> u64 {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        v.iter()
            .copied()
            .map(f64::to_bits)
            .for_each(|bits| bits.hash(&mut state));
        state.finish()
    }
}
40
41pub struct CachedKernel {
66 inner: Box<dyn Kernel>,
68 cache: Arc<Mutex<HashMap<CacheKey, f64>>>,
70 stats: Arc<Mutex<CacheStats>>,
72}
73
/// Snapshot of a [`CachedKernel`]'s cache counters.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups answered from the cache.
    pub hits: usize,
    /// Lookups that fell through to the wrapped kernel (as maintained by
    /// `CachedKernel`, only successful evaluations are counted).
    pub misses: usize,
    /// Number of cached entries as of the most recent miss or clear.
    pub size: usize,
}

impl CacheStats {
    /// Fraction of recorded lookups that were served from the cache.
    ///
    /// Returns `0.0` when nothing has been recorded yet, instead of dividing
    /// by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
96
97impl CachedKernel {
98 pub fn new(inner: Box<dyn Kernel>) -> Self {
100 Self {
101 inner,
102 cache: Arc::new(Mutex::new(HashMap::new())),
103 stats: Arc::new(Mutex::new(CacheStats::default())),
104 }
105 }
106
107 pub fn stats(&self) -> CacheStats {
109 self.stats
110 .lock()
111 .expect("lock should not be poisoned")
112 .clone()
113 }
114
115 pub fn clear(&mut self) {
117 self.cache
118 .lock()
119 .expect("lock should not be poisoned")
120 .clear();
121 let mut stats = self.stats.lock().expect("lock should not be poisoned");
122 stats.hits = 0;
123 stats.misses = 0;
124 stats.size = 0;
125 }
126
127 pub fn cache_size(&self) -> usize {
129 self.cache
130 .lock()
131 .expect("lock should not be poisoned")
132 .len()
133 }
134}
135
136impl Kernel for CachedKernel {
137 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
138 let key = CacheKey::from_inputs(x, y);
139
140 {
142 let cache = self.cache.lock().expect("lock should not be poisoned");
143 if let Some(&value) = cache.get(&key) {
144 let mut stats = self.stats.lock().expect("lock should not be poisoned");
145 stats.hits += 1;
146 return Ok(value);
147 }
148 }
149
150 let value = self.inner.compute(x, y)?;
152
153 {
155 let mut cache = self.cache.lock().expect("lock should not be poisoned");
156 cache.insert(key, value);
157
158 let mut stats = self.stats.lock().expect("lock should not be poisoned");
159 stats.misses += 1;
160 stats.size = cache.len();
161 }
162
163 Ok(value)
164 }
165
166 fn name(&self) -> &str {
167 self.inner.name()
168 }
169
170 fn is_psd(&self) -> bool {
171 self.inner.is_psd()
172 }
173}
174
175pub struct KernelMatrixCache {
202 cache: HashMap<u64, Vec<Vec<f64>>>,
204}
205
206impl KernelMatrixCache {
207 pub fn new() -> Self {
209 Self {
210 cache: HashMap::new(),
211 }
212 }
213
214 fn hash_data(data: &[Vec<f64>]) -> u64 {
216 let mut hasher = std::collections::hash_map::DefaultHasher::new();
217 for row in data {
218 for &val in row {
219 val.to_bits().hash(&mut hasher);
220 }
221 }
222 hasher.finish()
223 }
224
225 pub fn get_or_compute(
227 &mut self,
228 data: &[Vec<f64>],
229 kernel: &dyn Kernel,
230 ) -> Result<Vec<Vec<f64>>> {
231 let key = Self::hash_data(data);
232
233 if let Some(matrix) = self.cache.get(&key) {
234 return Ok(matrix.clone());
235 }
236
237 let matrix = kernel.compute_matrix(data)?;
239 self.cache.insert(key, matrix.clone());
240
241 Ok(matrix)
242 }
243
244 pub fn clear(&mut self) {
246 self.cache.clear();
247 }
248
249 pub fn size(&self) -> usize {
251 self.cache.len()
252 }
253}
254
255impl Default for KernelMatrixCache {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::LinearKernel;

    /// A `CachedKernel` wrapping a fresh `LinearKernel`.
    fn cached_linear() -> CachedKernel {
        CachedKernel::new(Box::new(LinearKernel::new()))
    }

    #[test]
    fn test_cached_kernel() {
        let cached = cached_linear();
        let x = [1.0, 2.0, 3.0];
        let y = [4.0, 5.0, 6.0];

        // The first evaluation has to go to the inner kernel.
        let first = cached.compute(&x, &y).expect("unwrap");
        let after_first = cached.stats();
        assert_eq!(after_first.misses, 1);
        assert_eq!(after_first.hits, 0);

        // The same inputs again must be served from the cache.
        let second = cached.compute(&x, &y).expect("unwrap");
        let after_second = cached.stats();
        assert_eq!(after_second.misses, 1);
        assert_eq!(after_second.hits, 1);

        assert_eq!(first, second);
    }

    #[test]
    fn test_cached_kernel_clear() {
        let mut cached = cached_linear();
        let x = [1.0, 2.0, 3.0];
        let y = [4.0, 5.0, 6.0];

        cached.compute(&x, &y).expect("unwrap");
        assert_eq!(cached.cache_size(), 1);

        cached.clear();
        assert_eq!(cached.cache_size(), 0);

        let stats = cached.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 7,
            misses: 3,
            size: 10,
        };
        assert!((stats.hit_rate() - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_cache_stats_empty() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_kernel_matrix_cache() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let first = cache.get_or_compute(&data, &kernel).expect("unwrap");
        assert_eq!(cache.size(), 1);

        // Same data again maps to the same key, so no new entry appears.
        let second = cache.get_or_compute(&data, &kernel).expect("unwrap");
        assert_eq!(cache.size(), 1);

        assert_eq!(first.len(), second.len());
        for (row_a, row_b) in first.iter().zip(&second) {
            for (j, &value) in row_a.iter().enumerate() {
                assert_eq!(value, row_b[j]);
            }
        }
    }

    #[test]
    fn test_kernel_matrix_cache_clear() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        cache.get_or_compute(&data, &kernel).expect("unwrap");
        assert_eq!(cache.size(), 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cached_kernel_name() {
        let cached = cached_linear();
        assert_eq!(cached.name(), "Linear");
    }

    #[test]
    fn test_cached_kernel_psd() {
        let cached = cached_linear();
        assert!(cached.is_psd());
    }
}