Skip to main content

tensorlogic_sklears_kernels/
batch.rs

1//! Kernel matrix caching and batch computation.
2//!
3//! Provides efficient computation of kernel matrices for batches of inputs,
4//! with symmetric key normalization and LRU caching to avoid redundant evaluations.
5
6use std::collections::{HashMap, VecDeque};
7
8use scirs2_core::ndarray::Array2;
9
10use crate::error::{KernelError, Result};
11
/// Cache for kernel evaluation results.
///
/// Keys are pairs of batch indices. Because kernels are symmetric
/// (k(i,j) = k(j,i)), `(i, j)` and `(j, i)` share the same entry. Once the
/// cache holds `capacity` entries, inserting a new key evicts the
/// least-recently-used one. A capacity of 0 disables eviction, so the cache
/// then grows without bound.
///
/// Recency is tracked with monotonically increasing stamps. Every access
/// records its stamp on the entry and pushes `(key, stamp)` onto `lru_order`.
/// Older records for the same key thereby become stale and are skipped at
/// eviction time. This keeps `get` and `insert` amortized O(1) instead of
/// linearly scanning the recency list on every access.
#[derive(Debug, Clone)]
pub struct KernelCache {
    /// Cached value plus the stamp of its most recent access, per normalized key.
    entries: HashMap<(usize, usize), (f64, u64)>,
    /// Access records, oldest first. A record is live only while its stamp
    /// equals the stamp currently stored on the entry.
    lru_order: VecDeque<((usize, usize), u64)>,
    /// Stamp handed out to the next access.
    next_stamp: u64,
    capacity: usize,
    hits: u64,
    misses: u64,
}

impl KernelCache {
    /// Create a new kernel cache holding at most `capacity` entries.
    ///
    /// The cache evicts the least-recently-used entry when full. A `capacity`
    /// of 0 disables eviction (unbounded cache).
    pub fn new(capacity: usize) -> Self {
        Self {
            entries: HashMap::with_capacity(capacity),
            lru_order: VecDeque::with_capacity(capacity),
            next_stamp: 0,
            capacity,
            hits: 0,
            misses: 0,
        }
    }

    /// Normalize cache key so that (i,j) and (j,i) map to the same entry.
    fn normalize_key(i: usize, j: usize) -> (usize, usize) {
        if i <= j {
            (i, j)
        } else {
            (j, i)
        }
    }

    /// Append an access record for `key`. The entry must already carry `stamp`.
    fn record_access(&mut self, key: (usize, usize), stamp: u64) {
        self.lru_order.push_back((key, stamp));
        // Each refresh leaves a stale record behind. Once stale records clearly
        // outnumber live ones, drop them so the queue stays proportional to `len()`.
        if self.lru_order.len() > 2 * self.entries.len() + 16 {
            let entries = &self.entries;
            self.lru_order
                .retain(|(k, s)| matches!(entries.get(k), Some(&(_, live)) if live == *s));
        }
    }

    /// Remove the least-recently-used entry, discarding stale records on the way.
    fn evict_lru(&mut self) {
        while let Some((key, stamp)) = self.lru_order.pop_front() {
            if matches!(self.entries.get(&key), Some(&(_, live)) if live == stamp) {
                self.entries.remove(&key);
                return;
            }
        }
    }

    /// Retrieve a cached value, marking it most recently used on a hit.
    ///
    /// Every call counts toward the hit/miss statistics.
    pub fn get(&mut self, i: usize, j: usize) -> Option<f64> {
        let key = Self::normalize_key(i, j);
        let stamp = self.next_stamp;
        match self.entries.get_mut(&key) {
            Some(slot) => {
                slot.1 = stamp;
                let value = slot.0;
                self.next_stamp += 1;
                self.hits += 1;
                self.record_access(key, stamp);
                Some(value)
            }
            None => {
                self.misses += 1;
                None
            }
        }
    }

    /// Insert a value into the cache, evicting the LRU entry if at capacity.
    ///
    /// Re-inserting an existing key overwrites its value and marks it most
    /// recently used without evicting anything.
    pub fn insert(&mut self, i: usize, j: usize, value: f64) {
        let key = Self::normalize_key(i, j);
        let stamp = self.next_stamp;
        self.next_stamp += 1;

        if let Some(slot) = self.entries.get_mut(&key) {
            *slot = (value, stamp);
        } else {
            // Capacity 0 means "no eviction" (unbounded), as in earlier versions.
            if self.capacity > 0 && self.entries.len() >= self.capacity {
                self.evict_lru();
            }
            self.entries.insert(key, (value, stamp));
        }
        self.record_access(key, stamp);
    }

    /// Return the cache hit rate as a fraction in [0.0, 1.0].
    ///
    /// Returns 0.0 when no lookups have been made.
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Number of entries currently in the cache.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Clear all entries and reset statistics.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.lru_order.clear();
        self.next_stamp = 0;
        self.hits = 0;
        self.misses = 0;
    }

    /// Total number of cache hits.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Total number of cache misses.
    pub fn misses(&self) -> u64 {
        self.misses
    }
}
127
128/// A Gram matrix (symmetric kernel matrix) wrapper.
129///
130/// Wraps an `Array2<f64>` that is expected to be square and symmetric,
131/// providing convenient accessors for common operations.
132#[derive(Debug, Clone)]
133pub struct GramMatrix {
134    data: Array2<f64>,
135}
136
137impl GramMatrix {
138    /// Create a new Gram matrix, verifying that it is square.
139    pub fn new(data: Array2<f64>) -> Result<Self> {
140        if data.nrows() != data.ncols() {
141            return Err(KernelError::DimensionMismatch {
142                expected: vec![data.nrows(), data.nrows()],
143                got: vec![data.nrows(), data.ncols()],
144                context: "GramMatrix must be square".to_string(),
145            });
146        }
147        Ok(GramMatrix { data })
148    }
149
150    /// Get entry (i, j).
151    pub fn get(&self, i: usize, j: usize) -> f64 {
152        self.data[[i, j]]
153    }
154
155    /// Matrix dimension (n for an n x n matrix).
156    pub fn dim(&self) -> usize {
157        self.data.nrows()
158    }
159
160    /// Diagonal entries as a vector.
161    pub fn diagonal(&self) -> Vec<f64> {
162        (0..self.dim()).map(|i| self.data[[i, i]]).collect()
163    }
164
165    /// Matrix trace (sum of diagonal entries).
166    pub fn trace(&self) -> f64 {
167        self.diagonal().iter().sum()
168    }
169
170    /// Check if the matrix is approximately symmetric within a given tolerance.
171    pub fn is_symmetric(&self, tol: f64) -> bool {
172        let n = self.dim();
173        for i in 0..n {
174            for j in (i + 1)..n {
175                if (self.data[[i, j]] - self.data[[j, i]]).abs() > tol {
176                    return false;
177                }
178            }
179        }
180        true
181    }
182
183    /// Check if all diagonal entries are non-negative (necessary condition for PSD).
184    pub fn has_nonneg_diagonal(&self) -> bool {
185        self.diagonal().iter().all(|&d| d >= 0.0)
186    }
187
188    /// Frobenius norm: sqrt(sum of squares of all entries).
189    pub fn frobenius_norm(&self) -> f64 {
190        self.data.iter().map(|v| v * v).sum::<f64>().sqrt()
191    }
192
193    /// Access the underlying array.
194    pub fn as_array(&self) -> &Array2<f64> {
195        &self.data
196    }
197}
198
/// Summary of a single kernel-matrix computation.
#[derive(Debug, Clone, Default)]
pub struct KernelMatrixStats {
    /// Total number of kernel evaluations performed.
    pub evaluations: u64,
    /// Number of cache lookups that found a value (0 when caching is disabled).
    pub cache_hits: u64,
    /// Number of cache lookups that found nothing (0 when caching is disabled).
    pub cache_misses: u64,
    /// Side length `n` of the computed `n x n` matrix.
    pub matrix_dim: usize,
    /// Wall-clock time spent on the computation, in milliseconds.
    pub computation_ms: f64,
}

impl KernelMatrixStats {
    /// Fraction of cache lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups were recorded.
    pub fn cache_hit_rate(&self) -> f64 {
        let lookups = self.cache_hits + self.cache_misses;
        match lookups {
            0 => 0.0,
            _ => self.cache_hits as f64 / lookups as f64,
        }
    }
}
225
226/// Batch kernel matrix computation engine.
227///
228/// Computes the full n x n kernel matrix for a batch of n input vectors,
229/// exploiting symmetry (only computing the upper triangle) and optionally
230/// caching results for repeated computations.
231pub struct BatchKernelComputer {
232    cache: Option<KernelCache>,
233}
234
235impl BatchKernelComputer {
236    /// Create a new batch computer without caching.
237    pub fn new() -> Self {
238        BatchKernelComputer { cache: None }
239    }
240
241    /// Create a new batch computer with an LRU cache of the given capacity.
242    pub fn with_cache(capacity: usize) -> Self {
243        BatchKernelComputer {
244            cache: Some(KernelCache::new(capacity)),
245        }
246    }
247
248    /// Compute the full kernel matrix for a batch of input vectors.
249    ///
250    /// The kernel function `kernel_fn` takes two vectors (as slices) and returns
251    /// the kernel value. The resulting matrix is symmetric: `K[i,j] = K[j,i]`.
252    ///
253    /// # Errors
254    ///
255    /// Returns `BatchKernelError::EmptyBatch` if `inputs` is empty.
256    pub fn compute<F>(
257        &mut self,
258        inputs: &[Vec<f64>],
259        kernel_fn: F,
260    ) -> Result<(GramMatrix, KernelMatrixStats)>
261    where
262        F: Fn(&[f64], &[f64]) -> f64,
263    {
264        if inputs.is_empty() {
265            return Err(KernelError::ComputationError(
266                "Empty input batch".to_string(),
267            ));
268        }
269
270        let n = inputs.len();
271        let dim = inputs[0].len();
272
273        // Validate consistent dimensions
274        for (idx, input) in inputs.iter().enumerate() {
275            if input.len() != dim {
276                return Err(KernelError::DimensionMismatch {
277                    expected: vec![dim],
278                    got: vec![input.len()],
279                    context: format!("Input vector at index {idx} has wrong dimension"),
280                });
281            }
282        }
283
284        let start = std::time::Instant::now();
285        let mut matrix = Array2::<f64>::zeros((n, n));
286        let mut stats = KernelMatrixStats {
287            matrix_dim: n,
288            ..Default::default()
289        };
290
291        for i in 0..n {
292            for j in i..n {
293                let value = if let Some(ref mut cache) = self.cache {
294                    if let Some(cached) = cache.get(i, j) {
295                        stats.cache_hits += 1;
296                        cached
297                    } else {
298                        stats.cache_misses += 1;
299                        let v = kernel_fn(&inputs[i], &inputs[j]);
300                        cache.insert(i, j, v);
301                        v
302                    }
303                } else {
304                    kernel_fn(&inputs[i], &inputs[j])
305                };
306                stats.evaluations += 1;
307                matrix[[i, j]] = value;
308                if i != j {
309                    matrix[[j, i]] = value;
310                }
311            }
312        }
313
314        stats.computation_ms = start.elapsed().as_secs_f64() * 1000.0;
315
316        let gram = GramMatrix { data: matrix };
317        Ok((gram, stats))
318    }
319
320    /// Clear the internal cache (no-op if caching is disabled).
321    pub fn clear_cache(&mut self) {
322        if let Some(ref mut cache) = self.cache {
323            cache.clear();
324        }
325    }
326
327    /// Return cache statistics, if caching is enabled.
328    pub fn cache_hit_rate(&self) -> Option<f64> {
329        self.cache.as_ref().map(|c| {
330            let total = c.hits + c.misses;
331            if total == 0 {
332                0.0
333            } else {
334                c.hits as f64 / total as f64
335            }
336        })
337    }
338}
339
340impl Default for BatchKernelComputer {
341    fn default() -> Self {
342        Self::new()
343    }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-12;

    /// Square matrix that is zero everywhere except on the diagonal.
    fn diagonal_matrix(diag: &[f64]) -> Array2<f64> {
        let n = diag.len();
        let mut m = Array2::<f64>::zeros((n, n));
        for (k, &d) in diag.iter().enumerate() {
            m[[k, k]] = d;
        }
        m
    }

    /// Linear kernel used by the batch tests.
    fn dot_product(x: &[f64], y: &[f64]) -> f64 {
        x.iter().zip(y).map(|(a, b)| a * b).sum()
    }

    // ── KernelCache ────────────────────────────────────────────────

    #[test]
    fn test_kernel_cache_insert_get() {
        let mut cache = KernelCache::new(16);
        cache.insert(0, 1, 7.53);
        assert_eq!(cache.get(0, 1), Some(7.53));
    }

    #[test]
    fn test_kernel_cache_symmetric() {
        let mut cache = KernelCache::new(16);
        cache.insert(1, 2, 42.0);
        // Both orientations of the pair resolve to the same entry.
        for (a, b) in [(2, 1), (1, 2)] {
            assert_eq!(cache.get(a, b), Some(42.0));
        }
    }

    #[test]
    fn test_kernel_cache_miss() {
        let mut cache = KernelCache::new(16);
        assert!(cache.get(5, 7).is_none());
    }

    #[test]
    fn test_kernel_cache_hit_rate() {
        let mut cache = KernelCache::new(16);
        cache.insert(0, 1, 1.0);
        let _hit = cache.get(0, 1);
        let _miss = cache.get(2, 3);
        assert!((cache.hit_rate() - 0.5).abs() < EPS);
    }

    #[test]
    fn test_kernel_cache_eviction() {
        let mut cache = KernelCache::new(2);
        // Inserts (0,1)=1.0, (2,3)=2.0, (4,5)=3.0; the third exceeds capacity 2.
        for (k, value) in [1.0, 2.0, 3.0].iter().enumerate() {
            cache.insert(2 * k, 2 * k + 1, *value);
        }
        assert_eq!(cache.len(), 2);
        // (0, 1) was least recently used when (4, 5) arrived.
        assert_eq!(cache.get(0, 1), None);
        assert_eq!(cache.get(2, 3), Some(2.0));
        assert_eq!(cache.get(4, 5), Some(3.0));
    }

    #[test]
    fn test_kernel_cache_clear() {
        let mut cache = KernelCache::new(16);
        cache.insert(0, 1, 1.0);
        cache.insert(2, 3, 2.0);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!((cache.hits(), cache.misses()), (0, 0));
    }

    // ── GramMatrix ─────────────────────────────────────────────────

    #[test]
    fn test_gram_matrix_new_valid() {
        let gram = GramMatrix::new(Array2::<f64>::zeros((3, 3))).expect("valid gram matrix");
        assert_eq!(gram.dim(), 3);
    }

    #[test]
    fn test_gram_matrix_not_square() {
        assert!(GramMatrix::new(Array2::<f64>::zeros((3, 2))).is_err());
    }

    #[test]
    fn test_gram_matrix_diagonal() {
        let gram = GramMatrix::new(diagonal_matrix(&[1.0, 2.0, 3.0])).expect("valid gram matrix");
        assert_eq!(gram.diagonal(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_gram_matrix_trace() {
        let gram = GramMatrix::new(diagonal_matrix(&[1.0, 2.0, 3.0])).expect("valid gram matrix");
        assert!((gram.trace() - 6.0).abs() < EPS);
    }

    #[test]
    fn test_gram_matrix_symmetric() {
        let pairs: [(usize, usize, f64); 3] = [(0, 1, 1.5), (0, 2, 2.5), (1, 2, 3.5)];
        let mut data = Array2::<f64>::zeros((3, 3));
        for &(r, c, v) in &pairs {
            data[[r, c]] = v;
            data[[c, r]] = v;
        }
        let gram = GramMatrix::new(data).expect("valid gram matrix");
        assert!(gram.is_symmetric(EPS));
    }

    #[test]
    fn test_gram_matrix_frobenius() {
        // The N x N identity has Frobenius norm sqrt(N).
        const N: usize = 4;
        let gram = GramMatrix::new(diagonal_matrix(&[1.0; N])).expect("valid gram matrix");
        let expected = (N as f64).sqrt();
        assert!((gram.frobenius_norm() - expected).abs() < EPS);
    }

    #[test]
    fn test_gram_matrix_nonneg_diagonal() {
        let gram = GramMatrix::new(diagonal_matrix(&[1.0, 0.0, 5.0])).expect("valid gram matrix");
        assert!(gram.has_nonneg_diagonal());
    }

    // ── BatchKernelComputer ────────────────────────────────────────

    #[test]
    fn test_batch_compute_basic() {
        let mut computer = BatchKernelComputer::new();
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let (gram, stats) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert_eq!(gram.dim(), 3);
        assert_eq!(stats.matrix_dim, 3);
        // ((row, col), expected dot product)
        let expected = [((0, 1), 0.0), ((0, 2), 1.0), ((2, 2), 2.0)];
        for ((i, j), want) in expected {
            assert!((gram.get(i, j) - want).abs() < EPS);
        }
    }

    #[test]
    fn test_batch_compute_symmetric_result() {
        let mut computer = BatchKernelComputer::new();
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let (gram, _) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert!(gram.is_symmetric(EPS));
    }

    #[test]
    fn test_batch_compute_empty_batch() {
        let mut computer = BatchKernelComputer::new();
        let empty: Vec<Vec<f64>> = Vec::new();
        assert!(computer.compute(&empty, dot_product).is_err());
    }

    #[test]
    fn test_batch_compute_with_cache() {
        let mut computer = BatchKernelComputer::with_cache(1024);
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        // Cold cache: every lookup misses.
        let (_, cold) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert_eq!(cold.cache_hits, 0);
        assert!(cold.cache_misses > 0);

        // Same batch again: every lookup hits.
        let (_, warm) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert!(warm.cache_hits > 0);
        assert_eq!(warm.cache_misses, 0);
    }

    #[test]
    fn test_batch_stats() {
        let mut computer = BatchKernelComputer::new();
        let inputs: Vec<Vec<f64>> = (1..=3).map(|k| vec![f64::from(k)]).collect();
        let (_, stats) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert_eq!(stats.matrix_dim, 3);
        // Upper triangle including the diagonal: n * (n + 1) / 2 = 6.
        assert_eq!(stats.evaluations, 6);
        assert!(stats.computation_ms >= 0.0);
        // Caching is disabled, so no lookups are recorded.
        assert_eq!((stats.cache_hits, stats.cache_misses), (0, 0));
    }
}
546        // No cache, so hits/misses are 0
547        assert_eq!(stats.cache_hits, 0);
548        assert_eq!(stats.cache_misses, 0);
549    }
550}