Skip to main content

tensorlogic_sklears_kernels/
cache.rs

1//! Kernel caching infrastructure for performance optimization.
2//!
3//! Provides caching mechanisms to avoid redundant kernel computations.
4
5use std::collections::HashMap;
6use std::hash::{Hash, Hasher};
7use std::sync::{Arc, Mutex};
8
9use crate::error::Result;
10use crate::types::Kernel;
11
/// Hash key for caching kernel computations.
///
/// Two keys compare equal exactly when both input slices hashed to the
/// same values, so the key can be used directly as a `HashMap` key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hash of first input
    x_hash: u64,
    /// Hash of second input
    y_hash: u64,
}

impl CacheKey {
    /// Build a key from the two kernel inputs.
    fn from_inputs(x: &[f64], y: &[f64]) -> Self {
        CacheKey {
            x_hash: Self::hash_vector(x),
            y_hash: Self::hash_vector(y),
        }
    }

    /// Hash a slice of f64 values.
    ///
    /// Values are hashed via their IEEE-754 bit patterns (`to_bits`) so the
    /// result is deterministic; plain `f64` does not implement `Hash`.
    fn hash_vector(values: &[f64]) -> u64 {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        values
            .iter()
            .for_each(|value| value.to_bits().hash(&mut state));
        state.finish()
    }
}
40
/// Cached kernel wrapper that stores computed values
///
/// Wraps any [`Kernel`] and memoizes `compute` results, keyed by the bit
/// patterns of the input vectors. Interior mutability (`Arc<Mutex<..>>`)
/// lets the cache and statistics be updated through `&self` in `compute`.
///
/// # Example
///
/// ```rust
/// use tensorlogic_sklears_kernels::{LinearKernel, CachedKernel, Kernel};
///
/// let base_kernel = LinearKernel::new();
/// let cached = CachedKernel::new(Box::new(base_kernel));
///
/// let x = vec![1.0, 2.0, 3.0];
/// let y = vec![4.0, 5.0, 6.0];
///
/// // First call computes and caches
/// let result1 = cached.compute(&x, &y).unwrap();
///
/// // Second call retrieves from cache
/// let result2 = cached.compute(&x, &y).unwrap();
/// assert_eq!(result1, result2);
///
/// // Check cache statistics
/// let stats = cached.stats();
/// assert!(stats.hits > 0);
/// ```
pub struct CachedKernel {
    /// Underlying kernel that performs the actual computation on a miss
    inner: Box<dyn Kernel>,
    /// Cache storage: input-pair hash key -> computed kernel value
    cache: Arc<Mutex<HashMap<CacheKey, f64>>>,
    /// Cache statistics (hits/misses/size), updated on every `compute`
    stats: Arc<Mutex<CacheStats>>,
}
73
/// Cache statistics reported by [`CachedKernel::stats`].
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: usize,
    /// Number of cache misses
    pub misses: usize,
    /// Number of entries in cache
    pub size: usize,
}

impl CacheStats {
    /// Fraction of lookups served from the cache.
    ///
    /// Returns 0.0 when no lookups have happened yet (avoids 0/0).
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
96
97impl CachedKernel {
98    /// Create a new cached kernel
99    pub fn new(inner: Box<dyn Kernel>) -> Self {
100        Self {
101            inner,
102            cache: Arc::new(Mutex::new(HashMap::new())),
103            stats: Arc::new(Mutex::new(CacheStats::default())),
104        }
105    }
106
107    /// Get cache statistics
108    pub fn stats(&self) -> CacheStats {
109        self.stats.lock().unwrap().clone()
110    }
111
112    /// Clear the cache
113    pub fn clear(&mut self) {
114        self.cache.lock().unwrap().clear();
115        let mut stats = self.stats.lock().unwrap();
116        stats.hits = 0;
117        stats.misses = 0;
118        stats.size = 0;
119    }
120
121    /// Get cache size
122    pub fn cache_size(&self) -> usize {
123        self.cache.lock().unwrap().len()
124    }
125}
126
127impl Kernel for CachedKernel {
128    fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
129        let key = CacheKey::from_inputs(x, y);
130
131        // Check cache first
132        {
133            let cache = self.cache.lock().unwrap();
134            if let Some(&value) = cache.get(&key) {
135                let mut stats = self.stats.lock().unwrap();
136                stats.hits += 1;
137                return Ok(value);
138            }
139        }
140
141        // Cache miss - compute value
142        let value = self.inner.compute(x, y)?;
143
144        // Store in cache
145        {
146            let mut cache = self.cache.lock().unwrap();
147            cache.insert(key, value);
148
149            let mut stats = self.stats.lock().unwrap();
150            stats.misses += 1;
151            stats.size = cache.len();
152        }
153
154        Ok(value)
155    }
156
157    fn name(&self) -> &str {
158        self.inner.name()
159    }
160
161    fn is_psd(&self) -> bool {
162        self.inner.is_psd()
163    }
164}
165
/// Kernel matrix cache for efficient matrix operations
///
/// Stores entire kernel matrices to avoid recomputation. Entries are keyed
/// by a 64-bit hash of the input data, not by the data itself, so distinct
/// datasets could in principle collide on the same key.
/// NOTE(review): a collision would silently return the wrong matrix —
/// confirm this trade-off is acceptable for the intended workloads.
///
/// # Example
///
/// ```rust
/// use tensorlogic_sklears_kernels::{LinearKernel, KernelMatrixCache};
///
/// let kernel = LinearKernel::new();
/// let mut cache = KernelMatrixCache::new();
///
/// let data = vec![
///     vec![1.0, 2.0],
///     vec![3.0, 4.0],
///     vec![5.0, 6.0],
/// ];
///
/// // Compute and cache
/// let matrix1 = cache.get_or_compute(&data, &kernel).unwrap();
///
/// // Retrieve from cache
/// let matrix2 = cache.get_or_compute(&data, &kernel).unwrap();
///
/// assert_eq!(matrix1.len(), matrix2.len());
/// ```
pub struct KernelMatrixCache {
    /// Cache storage: data hash -> precomputed kernel matrix
    cache: HashMap<u64, Vec<Vec<f64>>>,
}
196
197impl KernelMatrixCache {
198    /// Create a new matrix cache
199    pub fn new() -> Self {
200        Self {
201            cache: HashMap::new(),
202        }
203    }
204
205    /// Hash input data
206    fn hash_data(data: &[Vec<f64>]) -> u64 {
207        let mut hasher = std::collections::hash_map::DefaultHasher::new();
208        for row in data {
209            for &val in row {
210                val.to_bits().hash(&mut hasher);
211            }
212        }
213        hasher.finish()
214    }
215
216    /// Get or compute kernel matrix
217    pub fn get_or_compute(
218        &mut self,
219        data: &[Vec<f64>],
220        kernel: &dyn Kernel,
221    ) -> Result<Vec<Vec<f64>>> {
222        let key = Self::hash_data(data);
223
224        if let Some(matrix) = self.cache.get(&key) {
225            return Ok(matrix.clone());
226        }
227
228        // Compute matrix
229        let matrix = kernel.compute_matrix(data)?;
230        self.cache.insert(key, matrix.clone());
231
232        Ok(matrix)
233    }
234
235    /// Clear the cache
236    pub fn clear(&mut self) {
237        self.cache.clear();
238    }
239
240    /// Get cache size
241    pub fn size(&self) -> usize {
242        self.cache.len()
243    }
244}
245
246impl Default for KernelMatrixCache {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::LinearKernel;

    #[test]
    fn test_cached_kernel() {
        let cached = CachedKernel::new(Box::new(LinearKernel::new()));

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        // First evaluation must be a miss.
        let first = cached.compute(&a, &b).unwrap();
        let after_miss = cached.stats();
        assert_eq!(after_miss.misses, 1);
        assert_eq!(after_miss.hits, 0);

        // Same inputs again must be a hit and agree exactly.
        let second = cached.compute(&a, &b).unwrap();
        let after_hit = cached.stats();
        assert_eq!(after_hit.misses, 1);
        assert_eq!(after_hit.hits, 1);

        assert_eq!(first, second);
    }

    #[test]
    fn test_cached_kernel_clear() {
        let mut cached = CachedKernel::new(Box::new(LinearKernel::new()));

        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];

        cached.compute(&a, &b).unwrap();
        assert_eq!(cached.cache_size(), 1);

        // Clearing empties the cache and zeroes the statistics.
        cached.clear();
        assert_eq!(cached.cache_size(), 0);

        let stats = cached.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_cache_stats_hit_rate() {
        let stats = CacheStats {
            hits: 7,
            misses: 3,
            size: 10,
        };
        assert!((stats.hit_rate() - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_cache_stats_empty() {
        // No lookups yet: rate is defined as 0.0, not NaN.
        assert_eq!(CacheStats::default().hit_rate(), 0.0);
    }

    #[test]
    fn test_kernel_matrix_cache() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();

        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];

        // First request computes and stores the matrix.
        let computed = cache.get_or_compute(&data, &kernel).unwrap();
        assert_eq!(cache.size(), 1);

        // Second request is served from the cache and matches elementwise.
        let fetched = cache.get_or_compute(&data, &kernel).unwrap();
        assert_eq!(cache.size(), 1);

        assert_eq!(computed.len(), fetched.len());
        for (row_a, row_b) in computed.iter().zip(fetched.iter()) {
            for (lhs, rhs) in row_a.iter().zip(row_b.iter()) {
                assert_eq!(lhs, rhs);
            }
        }
    }

    #[test]
    fn test_kernel_matrix_cache_clear() {
        let kernel = LinearKernel::new();
        let mut cache = KernelMatrixCache::new();

        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        cache.get_or_compute(&data, &kernel).unwrap();
        assert_eq!(cache.size(), 1);

        cache.clear();
        assert_eq!(cache.size(), 0);
    }

    #[test]
    fn test_cached_kernel_name() {
        let cached = CachedKernel::new(Box::new(LinearKernel::new()));
        assert_eq!(cached.name(), "Linear");
    }

    #[test]
    fn test_cached_kernel_psd() {
        let cached = CachedKernel::new(Box::new(LinearKernel::new()));
        assert!(cached.is_psd());
    }
}