tensorlogic_sklears_kernels/
cache.rs1use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::sync::{Arc, Mutex};
8
9use crate::error::Result;
10use crate::types::Kernel;
11
/// Key identifying one ordered `(x, y)` input pair in the [`CachedKernel`] cache.
///
/// Only 64-bit hashes of the two input vectors are stored, not the vectors
/// themselves, so two distinct pairs whose hashes collide would share an entry.
/// `(x, y)` and `(y, x)` produce different keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash of the first input vector.
    x_hash: u64,
    /// Hash of the second input vector.
    y_hash: u64,
}

impl CacheKey {
    /// Builds the key for the ordered pair `(x, y)`.
    fn from_inputs(x: &[f64], y: &[f64]) -> Self {
        let x_hash = Self::hash_vector(x);
        let y_hash = Self::hash_vector(y);
        CacheKey { x_hash, y_hash }
    }

    /// Hashes a vector by feeding the raw bit pattern of each element, in order,
    /// into a `DefaultHasher`.
    ///
    /// Hashing bits (rather than the float value) makes e.g. `0.0` and `-0.0`
    /// distinct inputs.
    fn hash_vector(v: &[f64]) -> u64 {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        v.iter().for_each(|elem| elem.to_bits().hash(&mut state));
        state.finish()
    }
}
40
41pub struct CachedKernel {
66 inner: Box<dyn Kernel>,
68 cache: Arc<Mutex<HashMap<CacheKey, f64>>>,
70 stats: Arc<Mutex<CacheStats>>,
72}
73
/// Hit/miss counters reported by [`CachedKernel::stats`].
#[derive(Debug, Default, Clone)]
pub struct CacheStats {
    /// Number of lookups answered from the cache.
    pub hits: usize,
    /// Number of lookups that fell through to the wrapped kernel.
    pub misses: usize,
    /// Number of cached entries, as recorded at the most recent miss.
    pub size: usize,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0, 1]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet, instead of dividing by zero.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f64 / lookups as f64,
        }
    }
}
96
97impl CachedKernel {
98 pub fn new(inner: Box<dyn Kernel>) -> Self {
100 Self {
101 inner,
102 cache: Arc::new(Mutex::new(HashMap::new())),
103 stats: Arc::new(Mutex::new(CacheStats::default())),
104 }
105 }
106
107 pub fn stats(&self) -> CacheStats {
109 self.stats.lock().unwrap().clone()
110 }
111
112 pub fn clear(&mut self) {
114 self.cache.lock().unwrap().clear();
115 let mut stats = self.stats.lock().unwrap();
116 stats.hits = 0;
117 stats.misses = 0;
118 stats.size = 0;
119 }
120
121 pub fn cache_size(&self) -> usize {
123 self.cache.lock().unwrap().len()
124 }
125}
126
127impl Kernel for CachedKernel {
128 fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
129 let key = CacheKey::from_inputs(x, y);
130
131 {
133 let cache = self.cache.lock().unwrap();
134 if let Some(&value) = cache.get(&key) {
135 let mut stats = self.stats.lock().unwrap();
136 stats.hits += 1;
137 return Ok(value);
138 }
139 }
140
141 let value = self.inner.compute(x, y)?;
143
144 {
146 let mut cache = self.cache.lock().unwrap();
147 cache.insert(key, value);
148
149 let mut stats = self.stats.lock().unwrap();
150 stats.misses += 1;
151 stats.size = cache.len();
152 }
153
154 Ok(value)
155 }
156
157 fn name(&self) -> &str {
158 self.inner.name()
159 }
160
161 fn is_psd(&self) -> bool {
162 self.inner.is_psd()
163 }
164}
165
/// Memoizes whole kernel matrices, keyed by a 64-bit hash of the input data.
///
/// The key is computed from the data alone; the kernel used to build a matrix
/// is not part of it.
pub struct KernelMatrixCache {
    /// Map from data hash to the matrix computed for that data.
    cache: HashMap<u64, Vec<Vec<f64>>>,
}
196
197impl KernelMatrixCache {
198 pub fn new() -> Self {
200 Self {
201 cache: HashMap::new(),
202 }
203 }
204
205 fn hash_data(data: &[Vec<f64>]) -> u64 {
207 let mut hasher = std::collections::hash_map::DefaultHasher::new();
208 for row in data {
209 for &val in row {
210 val.to_bits().hash(&mut hasher);
211 }
212 }
213 hasher.finish()
214 }
215
216 pub fn get_or_compute(
218 &mut self,
219 data: &[Vec<f64>],
220 kernel: &dyn Kernel,
221 ) -> Result<Vec<Vec<f64>>> {
222 let key = Self::hash_data(data);
223
224 if let Some(matrix) = self.cache.get(&key) {
225 return Ok(matrix.clone());
226 }
227
228 let matrix = kernel.compute_matrix(data)?;
230 self.cache.insert(key, matrix.clone());
231
232 Ok(matrix)
233 }
234
235 pub fn clear(&mut self) {
237 self.cache.clear();
238 }
239
240 pub fn size(&self) -> usize {
242 self.cache.len()
243 }
244}
245
246impl Default for KernelMatrixCache {
247 fn default() -> Self {
248 Self::new()
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::LinearKernel;

    /// A `CachedKernel` wrapping a fresh `LinearKernel`.
    fn cached_linear() -> CachedKernel {
        CachedKernel::new(Box::new(LinearKernel::new()))
    }

    #[test]
    fn test_cached_kernel() {
        let cached = cached_linear();
        let (x, y) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);

        // First call misses and populates the cache.
        let first = cached.compute(&x, &y).unwrap();
        let after_first = cached.stats();
        assert_eq!(after_first.misses, 1);
        assert_eq!(after_first.hits, 0);

        // Identical second call is served from the cache.
        let second = cached.compute(&x, &y).unwrap();
        let after_second = cached.stats();
        assert_eq!(after_second.misses, 1);
        assert_eq!(after_second.hits, 1);

        assert_eq!(first, second);
    }

    #[test]
    fn test_cached_kernel_clear() {
        let mut cached = cached_linear();
        let (x, y) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);

        cached.compute(&x, &y).unwrap();
        assert_eq!(cached.cache_size(), 1);

        cached.clear();
        assert_eq!(cached.cache_size(), 0);

        let reset = cached.stats();
        assert_eq!(reset.hits, 0);
        assert_eq!(reset.misses, 0);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 7,
            misses: 3,
            size: 10,
        };
        assert!((stats.hit_rate() - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_cache_stats_empty() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_kernel_matrix_cache() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        let matrix1 = cache.get_or_compute(&data, &kernel).unwrap();
        assert_eq!(cache.size(), 1);

        // Same data again: served from the cache, no new entry.
        let matrix2 = cache.get_or_compute(&data, &kernel).unwrap();
        assert_eq!(cache.size(), 1);

        assert_eq!(matrix1.len(), matrix2.len());
        for (i, row) in matrix1.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert_eq!(*value, matrix2[i][j]);
            }
        }
    }

    #[test]
    fn test_kernel_matrix_cache_clear() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        cache.get_or_compute(&data, &kernel).unwrap();
        assert_eq!(cache.size(), 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cached_kernel_name() {
        let cached = cached_linear();
        assert_eq!(cached.name(), "Linear");
    }

    #[test]
    fn test_cached_kernel_psd() {
        let cached = cached_linear();
        assert!(cached.is_psd());
    }
}