1use std::collections::{HashMap, VecDeque};
7
8use scirs2_core::ndarray::Array2;
9
10use crate::error::{KernelError, Result};
11
/// Least-recently-used cache of symmetric kernel values `k(i, j)`.
///
/// Keys are normalized so that `(i, j)` and `(j, i)` share a single slot.
/// Both a hit in [`KernelCache::get`] and any [`KernelCache::insert`] count as
/// a use. When the cache is full, inserting a *new* pair first evicts the least
/// recently used one. A `capacity` of `0` disables eviction entirely, so the
/// cache grows without bound.
///
/// Recency is tracked with an append-only access log of `(key, stamp)` records.
/// This avoids searching a list on every access, so `get`/`insert` run in
/// amortized O(1) rather than O(capacity).
#[derive(Debug, Clone)]
pub struct KernelCache {
    /// Cached value for each normalized key, paired with the stamp of its most
    /// recent use.
    entries: HashMap<(usize, usize), (f64, u64)>,
    /// Access log in increasing stamp order (oldest first). A record is live
    /// only while its stamp equals the one stored in `entries`. Superseded
    /// records are skipped on eviction and dropped by periodic compaction.
    lru_order: VecDeque<((usize, usize), u64)>,
    capacity: usize,
    /// Monotonic counter used to stamp each use.
    clock: u64,
    hits: u64,
    misses: u64,
}

impl KernelCache {
    /// Creates an empty cache that holds at most `capacity` entries
    /// (`0` means unbounded).
    pub fn new(capacity: usize) -> Self {
        // Cap the up-front reservation so that a very large ("effectively
        // unbounded") capacity does not trigger a huge or overflowing allocation.
        const MAX_PREALLOC: usize = 1 << 16;
        let reserve = capacity.min(MAX_PREALLOC);
        Self {
            entries: HashMap::with_capacity(reserve),
            lru_order: VecDeque::with_capacity(reserve),
            capacity,
            clock: 0,
            hits: 0,
            misses: 0,
        }
    }

    /// Orders a pair so that symmetric lookups share one key.
    fn normalize_key(i: usize, j: usize) -> (usize, usize) {
        if i <= j {
            (i, j)
        } else {
            (j, i)
        }
    }

    /// Marks `key` as the most recently used entry.
    fn touch(&mut self, key: (usize, usize)) {
        self.clock += 1;
        let stamp = self.clock;
        if let Some(entry) = self.entries.get_mut(&key) {
            entry.1 = stamp;
        }
        self.lru_order.push_back((key, stamp));
        self.compact_log();
    }

    /// Drops superseded access records once they clearly outnumber live ones.
    ///
    /// After compaction the log holds exactly one record per entry. The next
    /// compaction needs at least `len() + 16` further accesses, so the
    /// linear-time `retain` amortizes to O(1) per access.
    fn compact_log(&mut self) {
        if self.lru_order.len() > 2 * self.entries.len() + 16 {
            let entries = &self.entries;
            self.lru_order
                .retain(|(key, stamp)| matches!(entries.get(key), Some(&(_, s)) if s == *stamp));
        }
    }

    /// Removes the least recently used entry, skipping stale log records.
    fn evict_lru(&mut self) {
        while let Some((key, stamp)) = self.lru_order.pop_front() {
            // The first live record in stamp order belongs to the LRU entry.
            if matches!(self.entries.get(&key), Some(&(_, s)) if s == stamp) {
                self.entries.remove(&key);
                return;
            }
        }
    }

    /// Looks up `k(i, j)` in either argument order.
    ///
    /// Every lookup counts as a hit or a miss. A hit also refreshes the
    /// entry's recency.
    pub fn get(&mut self, i: usize, j: usize) -> Option<f64> {
        let key = Self::normalize_key(i, j);
        let cached = self.entries.get(&key).map(|&(value, _)| value);
        match cached {
            Some(_) => {
                self.hits += 1;
                self.touch(key);
            }
            None => self.misses += 1,
        }
        cached
    }

    /// Stores `k(i, j)` (either argument order) as the most recently used entry.
    ///
    /// Overwriting an existing pair never evicts. Inserting a new pair into a
    /// full cache first evicts the least recently used one.
    pub fn insert(&mut self, i: usize, j: usize, value: f64) {
        let key = Self::normalize_key(i, j);

        if let Some(entry) = self.entries.get_mut(&key) {
            entry.0 = value;
            self.touch(key);
            return;
        }

        if self.capacity > 0 && self.entries.len() >= self.capacity {
            self.evict_lru();
        }

        // The real stamp is assigned by `touch`.
        self.entries.insert(key, (value, 0));
        self.touch(key);
    }

    /// Fraction of lookups that were hits since creation or the last `clear`
    /// (`0.0` when there were no lookups).
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f64 / total as f64
        }
    }

    /// Number of cached entries.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes all entries and resets the hit/miss counters.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.lru_order.clear();
        self.clock = 0;
        self.hits = 0;
        self.misses = 0;
    }

    /// Number of lookup hits since creation or the last `clear`.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookup misses since creation or the last `clear`.
    pub fn misses(&self) -> u64 {
        self.misses
    }
}
127
128#[derive(Debug, Clone)]
133pub struct GramMatrix {
134 data: Array2<f64>,
135}
136
137impl GramMatrix {
138 pub fn new(data: Array2<f64>) -> Result<Self> {
140 if data.nrows() != data.ncols() {
141 return Err(KernelError::DimensionMismatch {
142 expected: vec![data.nrows(), data.nrows()],
143 got: vec![data.nrows(), data.ncols()],
144 context: "GramMatrix must be square".to_string(),
145 });
146 }
147 Ok(GramMatrix { data })
148 }
149
150 pub fn get(&self, i: usize, j: usize) -> f64 {
152 self.data[[i, j]]
153 }
154
155 pub fn dim(&self) -> usize {
157 self.data.nrows()
158 }
159
160 pub fn diagonal(&self) -> Vec<f64> {
162 (0..self.dim()).map(|i| self.data[[i, i]]).collect()
163 }
164
165 pub fn trace(&self) -> f64 {
167 self.diagonal().iter().sum()
168 }
169
170 pub fn is_symmetric(&self, tol: f64) -> bool {
172 let n = self.dim();
173 for i in 0..n {
174 for j in (i + 1)..n {
175 if (self.data[[i, j]] - self.data[[j, i]]).abs() > tol {
176 return false;
177 }
178 }
179 }
180 true
181 }
182
183 pub fn has_nonneg_diagonal(&self) -> bool {
185 self.diagonal().iter().all(|&d| d >= 0.0)
186 }
187
188 pub fn frobenius_norm(&self) -> f64 {
190 self.data.iter().map(|v| v * v).sum::<f64>().sqrt()
191 }
192
193 pub fn as_array(&self) -> &Array2<f64> {
195 &self.data
196 }
197}
198
/// Bookkeeping for one Gram-matrix computation.
///
/// As filled by [`BatchKernelComputer::compute`], the counters cover the upper
/// triangle of the matrix, diagonal included.
#[derive(Debug, Clone, Default)]
pub struct KernelMatrixStats {
    /// Matrix entries filled, whether freshly computed or served from the cache.
    pub evaluations: u64,
    /// Entries served from the cache.
    pub cache_hits: u64,
    /// Cache lookups that found nothing, so the kernel was evaluated.
    pub cache_misses: u64,
    /// Side length `n` of the `n × n` matrix.
    pub matrix_dim: usize,
    /// Wall-clock time spent filling the matrix, in milliseconds.
    pub computation_ms: f64,
}

impl KernelMatrixStats {
    /// Fraction of cache lookups that were hits (`0.0` if no lookups happened).
    pub fn cache_hit_rate(&self) -> f64 {
        match self.cache_hits + self.cache_misses {
            0 => 0.0,
            lookups => self.cache_hits as f64 / lookups as f64,
        }
    }
}
225
226pub struct BatchKernelComputer {
232 cache: Option<KernelCache>,
233}
234
235impl BatchKernelComputer {
236 pub fn new() -> Self {
238 BatchKernelComputer { cache: None }
239 }
240
241 pub fn with_cache(capacity: usize) -> Self {
243 BatchKernelComputer {
244 cache: Some(KernelCache::new(capacity)),
245 }
246 }
247
248 pub fn compute<F>(
257 &mut self,
258 inputs: &[Vec<f64>],
259 kernel_fn: F,
260 ) -> Result<(GramMatrix, KernelMatrixStats)>
261 where
262 F: Fn(&[f64], &[f64]) -> f64,
263 {
264 if inputs.is_empty() {
265 return Err(KernelError::ComputationError(
266 "Empty input batch".to_string(),
267 ));
268 }
269
270 let n = inputs.len();
271 let dim = inputs[0].len();
272
273 for (idx, input) in inputs.iter().enumerate() {
275 if input.len() != dim {
276 return Err(KernelError::DimensionMismatch {
277 expected: vec![dim],
278 got: vec![input.len()],
279 context: format!("Input vector at index {idx} has wrong dimension"),
280 });
281 }
282 }
283
284 let start = std::time::Instant::now();
285 let mut matrix = Array2::<f64>::zeros((n, n));
286 let mut stats = KernelMatrixStats {
287 matrix_dim: n,
288 ..Default::default()
289 };
290
291 for i in 0..n {
292 for j in i..n {
293 let value = if let Some(ref mut cache) = self.cache {
294 if let Some(cached) = cache.get(i, j) {
295 stats.cache_hits += 1;
296 cached
297 } else {
298 stats.cache_misses += 1;
299 let v = kernel_fn(&inputs[i], &inputs[j]);
300 cache.insert(i, j, v);
301 v
302 }
303 } else {
304 kernel_fn(&inputs[i], &inputs[j])
305 };
306 stats.evaluations += 1;
307 matrix[[i, j]] = value;
308 if i != j {
309 matrix[[j, i]] = value;
310 }
311 }
312 }
313
314 stats.computation_ms = start.elapsed().as_secs_f64() * 1000.0;
315
316 let gram = GramMatrix { data: matrix };
317 Ok((gram, stats))
318 }
319
320 pub fn clear_cache(&mut self) {
322 if let Some(ref mut cache) = self.cache {
323 cache.clear();
324 }
325 }
326
327 pub fn cache_hit_rate(&self) -> Option<f64> {
329 self.cache.as_ref().map(|c| {
330 let total = c.hits + c.misses;
331 if total == 0 {
332 0.0
333 } else {
334 c.hits as f64 / total as f64
335 }
336 })
337 }
338}
339
340impl Default for BatchKernelComputer {
341 fn default() -> Self {
342 Self::new()
343 }
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    // ---- KernelCache ----

    #[test]
    fn test_kernel_cache_insert_get() {
        let mut cache = KernelCache::new(16);
        cache.insert(0, 1, 7.53);
        let val = cache.get(0, 1);
        assert_eq!(val, Some(7.53));
    }

    #[test]
    fn test_kernel_cache_symmetric() {
        let mut cache = KernelCache::new(16);
        cache.insert(1, 2, 42.0);
        // (i, j) and (j, i) share one slot.
        assert_eq!(cache.get(2, 1), Some(42.0));
        assert_eq!(cache.get(1, 2), Some(42.0));
    }

    #[test]
    fn test_kernel_cache_miss() {
        let mut cache = KernelCache::new(16);
        assert_eq!(cache.get(5, 7), None);
    }

    #[test]
    fn test_kernel_cache_hit_rate() {
        let mut cache = KernelCache::new(16);
        cache.insert(0, 1, 1.0);
        let _ = cache.get(0, 1); // hit
        let _ = cache.get(2, 3); // miss
        let rate = cache.hit_rate();
        assert!((rate - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_kernel_cache_eviction() {
        let mut cache = KernelCache::new(2);
        cache.insert(0, 1, 1.0);
        cache.insert(2, 3, 2.0);
        cache.insert(4, 5, 3.0);
        // Capacity 2: the oldest pair (0, 1) is evicted.
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(0, 1), None);
        assert_eq!(cache.get(2, 3), Some(2.0));
        assert_eq!(cache.get(4, 5), Some(3.0));
    }

    #[test]
    fn test_kernel_cache_get_refreshes_recency() {
        let mut cache = KernelCache::new(2);
        cache.insert(0, 1, 1.0);
        cache.insert(2, 3, 2.0);
        // Touching (0, 1) makes (2, 3) the least recently used entry.
        assert_eq!(cache.get(0, 1), Some(1.0));
        cache.insert(4, 5, 3.0);
        assert_eq!(cache.get(2, 3), None);
        assert_eq!(cache.get(0, 1), Some(1.0));
        assert_eq!(cache.get(4, 5), Some(3.0));
    }

    #[test]
    fn test_kernel_cache_overwrite_does_not_evict() {
        let mut cache = KernelCache::new(2);
        cache.insert(0, 1, 1.0);
        cache.insert(2, 3, 2.0);
        // Same normalized key as (0, 1): overwrite in place.
        cache.insert(1, 0, 5.0);
        assert_eq!(cache.len(), 2);
        assert_eq!(cache.get(0, 1), Some(5.0));
        assert_eq!(cache.get(2, 3), Some(2.0));
    }

    #[test]
    fn test_kernel_cache_clear() {
        let mut cache = KernelCache::new(16);
        cache.insert(0, 1, 1.0);
        cache.insert(2, 3, 2.0);
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        assert_eq!(cache.hits(), 0);
        assert_eq!(cache.misses(), 0);
    }

    // ---- GramMatrix ----

    #[test]
    fn test_gram_matrix_new_valid() {
        let data = Array2::<f64>::zeros((3, 3));
        let gram = GramMatrix::new(data);
        assert!(gram.is_ok());
        assert_eq!(gram.expect("valid gram matrix").dim(), 3);
    }

    #[test]
    fn test_gram_matrix_not_square() {
        let data = Array2::<f64>::zeros((3, 2));
        let gram = GramMatrix::new(data);
        assert!(gram.is_err());
    }

    #[test]
    fn test_gram_matrix_diagonal() {
        let mut data = Array2::<f64>::zeros((3, 3));
        data[[0, 0]] = 1.0;
        data[[1, 1]] = 2.0;
        data[[2, 2]] = 3.0;
        let gram = GramMatrix::new(data).expect("valid gram matrix");
        assert_eq!(gram.diagonal(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_gram_matrix_trace() {
        let mut data = Array2::<f64>::zeros((3, 3));
        data[[0, 0]] = 1.0;
        data[[1, 1]] = 2.0;
        data[[2, 2]] = 3.0;
        let gram = GramMatrix::new(data).expect("valid gram matrix");
        assert!((gram.trace() - 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_gram_matrix_symmetric() {
        let mut data = Array2::<f64>::zeros((3, 3));
        data[[0, 1]] = 1.5;
        data[[1, 0]] = 1.5;
        data[[0, 2]] = 2.5;
        data[[2, 0]] = 2.5;
        data[[1, 2]] = 3.5;
        data[[2, 1]] = 3.5;
        let gram = GramMatrix::new(data).expect("valid gram matrix");
        assert!(gram.is_symmetric(1e-12));
    }

    #[test]
    fn test_gram_matrix_frobenius() {
        // Identity matrix: ||I||_F = sqrt(n).
        let n = 4;
        let mut data = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            data[[i, i]] = 1.0;
        }
        let gram = GramMatrix::new(data).expect("valid gram matrix");
        let expected = (n as f64).sqrt();
        assert!((gram.frobenius_norm() - expected).abs() < 1e-12);
    }

    #[test]
    fn test_gram_matrix_nonneg_diagonal() {
        let mut data = Array2::<f64>::zeros((3, 3));
        data[[0, 0]] = 1.0;
        data[[1, 1]] = 0.0;
        data[[2, 2]] = 5.0;
        let gram = GramMatrix::new(data).expect("valid gram matrix");
        assert!(gram.has_nonneg_diagonal());
    }

    // ---- BatchKernelComputer ----

    /// Linear kernel used by the batch tests.
    fn dot_product(x: &[f64], y: &[f64]) -> f64 {
        x.iter().zip(y.iter()).map(|(a, b)| a * b).sum()
    }

    #[test]
    fn test_batch_compute_basic() {
        let mut computer = BatchKernelComputer::new();
        let inputs = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let (gram, stats) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert_eq!(gram.dim(), 3);
        assert_eq!(stats.matrix_dim, 3);
        // Orthogonal unit vectors.
        assert!((gram.get(0, 1)).abs() < 1e-12);
        // [1, 0] . [1, 1] = 1
        assert!((gram.get(0, 2) - 1.0).abs() < 1e-12);
        // [1, 1] . [1, 1] = 2
        assert!((gram.get(2, 2) - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_batch_compute_symmetric_result() {
        let mut computer = BatchKernelComputer::new();
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let (gram, _) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert!(gram.is_symmetric(1e-12));
    }

    #[test]
    fn test_batch_compute_empty_batch() {
        let mut computer = BatchKernelComputer::new();
        let inputs: Vec<Vec<f64>> = vec![];
        let result = computer.compute(&inputs, dot_product);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_compute_with_cache() {
        let mut computer = BatchKernelComputer::with_cache(1024);
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        // First pass: cold cache, everything is computed.
        let (_, stats1) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert_eq!(stats1.cache_hits, 0);
        assert!(stats1.cache_misses > 0);

        // Second pass on the same inputs: everything is served from the cache.
        let (_, stats2) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert!(stats2.cache_hits > 0);
        assert_eq!(stats2.cache_misses, 0);
    }

    #[test]
    fn test_batch_compute_cached_matches_uncached() {
        let inputs = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let (plain, _) = BatchKernelComputer::new()
            .compute(&inputs, dot_product)
            .expect("compute ok");
        let mut cached = BatchKernelComputer::with_cache(1024);
        for _ in 0..2 {
            let (gram, _) = cached.compute(&inputs, dot_product).expect("compute ok");
            assert_eq!(gram.as_array(), plain.as_array());
        }
    }

    #[test]
    fn test_batch_stats() {
        let mut computer = BatchKernelComputer::new();
        let inputs = vec![vec![1.0], vec![2.0], vec![3.0]];
        let (_, stats) = computer.compute(&inputs, dot_product).expect("compute ok");
        assert_eq!(stats.matrix_dim, 3);
        // Upper triangle including the diagonal: n * (n + 1) / 2 = 6.
        assert_eq!(stats.evaluations, 6);
        assert!(stats.computation_ms >= 0.0);
        // No cache configured, so no lookups are counted.
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
    }
}