Skip to main content

tensorlogic_sklears_kernels/
cache.rs

1//! Kernel caching infrastructure for performance optimization.
2//!
3//! Provides caching mechanisms to avoid redundant kernel computations.
4
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::sync::{Arc, Mutex};
8
9use crate::error::Result;
10use crate::types::Kernel;
11
/// Cache key for a single `(x, y)` kernel evaluation.
///
/// Only 64-bit digests of the two inputs are stored, not the inputs themselves.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Digest of the first input vector
    x_hash: u64,
    /// Digest of the second input vector
    y_hash: u64,
}

impl CacheKey {
    /// Builds a key by digesting `x` and `y` independently.
    ///
    /// Argument order matters: `(x, y)` and `(y, x)` map to different keys
    /// unless the two digests happen to coincide.
    fn from_inputs(x: &[f64], y: &[f64]) -> Self {
        let x_hash = Self::hash_vector(x);
        let y_hash = Self::hash_vector(y);
        Self { x_hash, y_hash }
    }

    /// Digests a slice of `f64` by feeding each element's IEEE-754 bit
    /// pattern, in order, to a `DefaultHasher` (`f64` itself does not
    /// implement `Hash`).
    fn hash_vector(v: &[f64]) -> u64 {
        v.iter()
            .map(|val| val.to_bits())
            .fold(
                std::collections::hash_map::DefaultHasher::new(),
                |mut state, bits| {
                    bits.hash(&mut state);
                    state
                },
            )
            .finish()
    }
}
40
41/// Cached kernel wrapper that stores computed values
42///
43/// # Example
44///
45/// ```rust
46/// use tensorlogic_sklears_kernels::{LinearKernel, CachedKernel, Kernel};
47///
48/// let base_kernel = LinearKernel::new();
49/// let mut cached = CachedKernel::new(Box::new(base_kernel));
50///
51/// let x = vec![1.0, 2.0, 3.0];
52/// let y = vec![4.0, 5.0, 6.0];
53///
54/// // First call computes and caches
55/// let result1 = cached.compute(&x, &y).expect("unwrap");
56///
57/// // Second call retrieves from cache
58/// let result2 = cached.compute(&x, &y).expect("unwrap");
59/// assert_eq!(result1, result2);
60///
61/// // Check cache statistics
62/// let stats = cached.stats();
63/// assert!(stats.hits > 0);
64/// ```
65pub struct CachedKernel {
66    /// Underlying kernel
67    inner: Box<dyn Kernel>,
68    /// Cache storage
69    cache: Arc<Mutex<HashMap<CacheKey, f64>>>,
70    /// Cache statistics
71    stats: Arc<Mutex<CacheStats>>,
72}
73
/// Snapshot of a cached kernel's cache counters.
///
/// Implements `PartialEq`/`Eq` so snapshots can be compared directly.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Number of lookups answered from the cache
    pub hits: usize,
    /// Number of lookups that had to evaluate the underlying kernel
    pub misses: usize,
    /// Number of entries in the cache when the counters were last updated
    pub size: usize,
}

impl CacheStats {
    /// Fraction of recorded lookups that were cache hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` (rather than `NaN`) when no lookups have been recorded.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
96
97impl CachedKernel {
98    /// Create a new cached kernel
99    pub fn new(inner: Box<dyn Kernel>) -> Self {
100        Self {
101            inner,
102            cache: Arc::new(Mutex::new(HashMap::new())),
103            stats: Arc::new(Mutex::new(CacheStats::default())),
104        }
105    }
106
107    /// Get cache statistics
108    pub fn stats(&self) -> CacheStats {
109        self.stats
110            .lock()
111            .expect("lock should not be poisoned")
112            .clone()
113    }
114
115    /// Clear the cache
116    pub fn clear(&mut self) {
117        self.cache
118            .lock()
119            .expect("lock should not be poisoned")
120            .clear();
121        let mut stats = self.stats.lock().expect("lock should not be poisoned");
122        stats.hits = 0;
123        stats.misses = 0;
124        stats.size = 0;
125    }
126
127    /// Get cache size
128    pub fn cache_size(&self) -> usize {
129        self.cache
130            .lock()
131            .expect("lock should not be poisoned")
132            .len()
133    }
134}
135
136impl Kernel for CachedKernel {
137    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
138        let key = CacheKey::from_inputs(x, y);
139
140        // Check cache first
141        {
142            let cache = self.cache.lock().expect("lock should not be poisoned");
143            if let Some(&value) = cache.get(&key) {
144                let mut stats = self.stats.lock().expect("lock should not be poisoned");
145                stats.hits += 1;
146                return Ok(value);
147            }
148        }
149
150        // Cache miss - compute value
151        let value = self.inner.compute(x, y)?;
152
153        // Store in cache
154        {
155            let mut cache = self.cache.lock().expect("lock should not be poisoned");
156            cache.insert(key, value);
157
158            let mut stats = self.stats.lock().expect("lock should not be poisoned");
159            stats.misses += 1;
160            stats.size = cache.len();
161        }
162
163        Ok(value)
164    }
165
166    fn name(&self) -> &str {
167        self.inner.name()
168    }
169
170    fn is_psd(&self) -> bool {
171        self.inner.is_psd()
172    }
173}
174
175/// Kernel matrix cache for efficient matrix operations
176///
177/// Stores entire kernel matrices to avoid recomputation.
178///
179/// # Example
180///
181/// ```rust
182/// use tensorlogic_sklears_kernels::{LinearKernel, KernelMatrixCache};
183///
184/// let kernel = LinearKernel::new();
185/// let mut cache = KernelMatrixCache::new();
186///
187/// let data = vec![
188///     vec![1.0, 2.0],
189///     vec![3.0, 4.0],
190///     vec![5.0, 6.0],
191/// ];
192///
193/// // Compute and cache
194/// let matrix1 = cache.get_or_compute(&data, &kernel).expect("unwrap");
195///
196/// // Retrieve from cache
197/// let matrix2 = cache.get_or_compute(&data, &kernel).expect("unwrap");
198///
199/// assert_eq!(matrix1.len(), matrix2.len());
200/// ```
201pub struct KernelMatrixCache {
202    /// Cache storage
203    cache: HashMap<u64, Vec<Vec<f64>>>,
204}
205
206impl KernelMatrixCache {
207    /// Create a new matrix cache
208    pub fn new() -> Self {
209        Self {
210            cache: HashMap::new(),
211        }
212    }
213
214    /// Hash input data
215    fn hash_data(data: &[Vec<f64>]) -> u64 {
216        let mut hasher = std::collections::hash_map::DefaultHasher::new();
217        for row in data {
218            for &val in row {
219                val.to_bits().hash(&mut hasher);
220            }
221        }
222        hasher.finish()
223    }
224
225    /// Get or compute kernel matrix
226    pub fn get_or_compute(
227        &mut self,
228        data: &[Vec<f64>],
229        kernel: &dyn Kernel,
230    ) -> Result<Vec<Vec<f64>>> {
231        let key = Self::hash_data(data);
232
233        if let Some(matrix) = self.cache.get(&key) {
234            return Ok(matrix.clone());
235        }
236
237        // Compute matrix
238        let matrix = kernel.compute_matrix(data)?;
239        self.cache.insert(key, matrix.clone());
240
241        Ok(matrix)
242    }
243
244    /// Clear the cache
245    pub fn clear(&mut self) {
246        self.cache.clear();
247    }
248
249    /// Get cache size
250    pub fn size(&self) -> usize {
251        self.cache.len()
252    }
253}
254
255impl Default for KernelMatrixCache {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::LinearKernel;

    /// Wraps a fresh `LinearKernel` in a `CachedKernel`.
    fn cached_linear() -> CachedKernel {
        CachedKernel::new(Box::new(LinearKernel::new()))
    }

    #[test]
    fn test_cached_kernel() {
        let kernel = cached_linear();
        let (a, b) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);

        // A first evaluation misses and populates the cache.
        let first = kernel.compute(&a, &b).expect("unwrap");
        let after_first = kernel.stats();
        assert_eq!(after_first.misses, 1);
        assert_eq!(after_first.hits, 0);

        // Re-evaluating the same pair is served from the cache.
        let second = kernel.compute(&a, &b).expect("unwrap");
        let after_second = kernel.stats();
        assert_eq!(after_second.misses, 1);
        assert_eq!(after_second.hits, 1);

        assert_eq!(first, second);
    }

    #[test]
    fn test_cached_kernel_clear() {
        let mut kernel = cached_linear();
        let (a, b) = (vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]);

        kernel.compute(&a, &b).expect("unwrap");
        assert_eq!(kernel.cache_size(), 1);

        kernel.clear();
        assert_eq!(kernel.cache_size(), 0);

        let reset = kernel.stats();
        assert_eq!(reset.hits, 0);
        assert_eq!(reset.misses, 0);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 7,
            misses: 3,
            size: 10,
        };
        assert!((stats.hit_rate() - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_cache_stats_empty() {
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_kernel_matrix_cache() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        // The first request computes the matrix...
        let computed = cache.get_or_compute(&data, &kernel).expect("unwrap");
        assert_eq!(cache.size(), 1);

        // ...and an identical request reuses the stored entry.
        let reused = cache.get_or_compute(&data, &kernel).expect("unwrap");
        assert_eq!(cache.size(), 1);

        assert_eq!(computed.len(), reused.len());
        for (i, row) in computed.iter().enumerate() {
            for (j, &value) in row.iter().enumerate() {
                assert_eq!(value, reused[i][j]);
            }
        }
    }

    #[test]
    fn test_kernel_matrix_cache_clear() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        cache.get_or_compute(&data, &kernel).expect("unwrap");
        assert_eq!(cache.size(), 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cached_kernel_name() {
        assert_eq!(cached_linear().name(), "Linear");
    }

    #[test]
    fn test_cached_kernel_psd() {
        assert!(cached_linear().is_psd());
    }
}