1use crate::error::{KernelError, Result};
20use crate::types::Kernel;
21use std::collections::VecDeque;
22use std::sync::Arc;
23
/// Construction-time options for [`OnlineKernelMatrix`].
#[derive(Clone, Debug)]
pub struct OnlineConfig {
    /// Number of samples (and matrix rows) to reserve storage for up front.
    pub initial_capacity: usize,
    /// Whether the full matrix should be computed.
    ///
    /// NOTE(review): no code visible in this file reads this flag; confirm
    /// whether it is consumed elsewhere before relying on it.
    pub compute_full_matrix: bool,
}

impl Default for OnlineConfig {
    /// Returns a configuration with room for 64 samples and
    /// `compute_full_matrix` enabled.
    fn default() -> Self {
        OnlineConfig {
            initial_capacity: 64,
            compute_full_matrix: true,
        }
    }
}

impl OnlineConfig {
    /// Returns the default configuration with `initial_capacity` replaced by
    /// `capacity`.
    pub fn with_capacity(capacity: usize) -> Self {
        OnlineConfig {
            initial_capacity: capacity,
            ..Self::default()
        }
    }
}
51
/// Running counters maintained by the incremental kernel matrices in this
/// file. All counters start at zero.
#[derive(Clone, Debug, Default)]
pub struct OnlineStats {
    /// Samples successfully inserted.
    pub samples_added: usize,
    /// Samples removed, whether explicitly or through eviction.
    pub samples_removed: usize,
    /// Successful kernel evaluations performed while inserting samples.
    pub kernel_computations: usize,
    /// Resize counter.
    ///
    /// NOTE(review): nothing visible in this file ever increments it.
    pub resizes: usize,
}
64
65pub struct OnlineKernelMatrix {
88 kernel: Box<dyn Kernel>,
90 samples: Vec<Vec<f64>>,
92 matrix: Vec<Vec<f64>>,
94 config: OnlineConfig,
96 stats: OnlineStats,
98}
99
100impl OnlineKernelMatrix {
101 pub fn new(kernel: Box<dyn Kernel>) -> Self {
103 Self::with_config(kernel, OnlineConfig::default())
104 }
105
106 pub fn with_config(kernel: Box<dyn Kernel>, config: OnlineConfig) -> Self {
108 Self {
109 kernel,
110 samples: Vec::with_capacity(config.initial_capacity),
111 matrix: Vec::with_capacity(config.initial_capacity),
112 config,
113 stats: OnlineStats::default(),
114 }
115 }
116
117 pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
121 if let Some(first) = self.samples.first() {
123 if sample.len() != first.len() {
124 return Err(KernelError::DimensionMismatch {
125 expected: vec![first.len()],
126 got: vec![sample.len()],
127 context: "online kernel matrix".to_string(),
128 });
129 }
130 }
131
132 let n = self.samples.len();
133
134 let mut new_row = Vec::with_capacity(n + 1);
136 for existing in &self.samples {
137 let k = self.kernel.compute(&sample, existing)?;
138 new_row.push(k);
139 self.stats.kernel_computations += 1;
140 }
141
142 let k_self = self.kernel.compute(&sample, &sample)?;
144 new_row.push(k_self);
145 self.stats.kernel_computations += 1;
146
147 for (i, row) in self.matrix.iter_mut().enumerate() {
149 row.push(new_row[i]);
150 }
151
152 self.matrix.push(new_row);
154 self.samples.push(sample);
155 self.stats.samples_added += 1;
156
157 Ok(())
158 }
159
160 pub fn add_samples(&mut self, samples: Vec<Vec<f64>>) -> Result<()> {
164 for sample in samples {
165 self.add_sample(sample)?;
166 }
167 Ok(())
168 }
169
170 pub fn remove_sample(&mut self, index: usize) -> Result<Vec<f64>> {
174 if index >= self.samples.len() {
175 return Err(KernelError::ComputationError(format!(
176 "Index {} out of bounds for {} samples",
177 index,
178 self.samples.len()
179 )));
180 }
181
182 let removed = self.samples.remove(index);
184
185 self.matrix.remove(index);
187
188 for row in &mut self.matrix {
190 row.remove(index);
191 }
192
193 self.stats.samples_removed += 1;
194 Ok(removed)
195 }
196
197 pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
199 &self.matrix
200 }
201
202 pub fn get_samples(&self) -> &Vec<Vec<f64>> {
204 &self.samples
205 }
206
207 pub fn get(&self, i: usize, j: usize) -> Option<f64> {
209 self.matrix.get(i).and_then(|row| row.get(j).copied())
210 }
211
212 pub fn len(&self) -> usize {
214 self.samples.len()
215 }
216
217 pub fn is_empty(&self) -> bool {
219 self.samples.is_empty()
220 }
221
222 pub fn stats(&self) -> &OnlineStats {
224 &self.stats
225 }
226
227 pub fn clear(&mut self) {
229 self.samples.clear();
230 self.matrix.clear();
231 self.stats = OnlineStats::default();
232 }
233
234 pub fn kernel(&self) -> &dyn Kernel {
236 self.kernel.as_ref()
237 }
238
239 pub fn config(&self) -> &OnlineConfig {
241 &self.config
242 }
243
244 pub fn compute_with_sample(&self, query: &[f64], sample_idx: usize) -> Result<f64> {
246 let sample = self.samples.get(sample_idx).ok_or_else(|| {
247 KernelError::ComputationError(format!("Sample index {} not found", sample_idx))
248 })?;
249 self.kernel.compute(query, sample)
250 }
251
252 pub fn compute_with_all(&self, query: &[f64]) -> Result<Vec<f64>> {
254 let mut result = Vec::with_capacity(self.samples.len());
255 for sample in &self.samples {
256 let k = self.kernel.compute(query, sample)?;
257 result.push(k);
258 }
259 Ok(result)
260 }
261
262 pub fn to_matrix(&self) -> Vec<Vec<f64>> {
264 self.matrix.clone()
265 }
266}
267
268pub struct WindowedKernelMatrix {
290 kernel: Box<dyn Kernel>,
292 window_size: usize,
294 samples: VecDeque<Vec<f64>>,
296 matrix: Vec<Vec<f64>>,
298 stats: OnlineStats,
300}
301
302impl WindowedKernelMatrix {
303 pub fn new(kernel: Box<dyn Kernel>, window_size: usize) -> Self {
305 assert!(window_size > 0, "Window size must be positive");
306 Self {
307 kernel,
308 window_size,
309 samples: VecDeque::with_capacity(window_size),
310 matrix: Vec::with_capacity(window_size),
311 stats: OnlineStats::default(),
312 }
313 }
314
315 pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<Option<Vec<f64>>> {
317 if let Some(first) = self.samples.front() {
319 if sample.len() != first.len() {
320 return Err(KernelError::DimensionMismatch {
321 expected: vec![first.len()],
322 got: vec![sample.len()],
323 context: "windowed kernel matrix".to_string(),
324 });
325 }
326 }
327
328 let evicted = if self.samples.len() >= self.window_size {
329 let removed = self.samples.pop_front();
331
332 self.matrix.remove(0);
334 for row in &mut self.matrix {
335 row.remove(0);
336 }
337
338 self.stats.samples_removed += 1;
339 removed
340 } else {
341 None
342 };
343
344 let n = self.samples.len();
346 let mut new_row = Vec::with_capacity(n + 1);
347
348 for existing in &self.samples {
349 let k = self.kernel.compute(&sample, existing)?;
350 new_row.push(k);
351 self.stats.kernel_computations += 1;
352 }
353
354 let k_self = self.kernel.compute(&sample, &sample)?;
356 new_row.push(k_self);
357 self.stats.kernel_computations += 1;
358
359 for (i, row) in self.matrix.iter_mut().enumerate() {
361 row.push(new_row[i]);
362 }
363
364 self.matrix.push(new_row);
366 self.samples.push_back(sample);
367 self.stats.samples_added += 1;
368
369 Ok(evicted)
370 }
371
372 pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
374 &self.matrix
375 }
376
377 pub fn get_samples(&self) -> &VecDeque<Vec<f64>> {
379 &self.samples
380 }
381
382 pub fn window_size(&self) -> usize {
384 self.window_size
385 }
386
387 pub fn len(&self) -> usize {
389 self.samples.len()
390 }
391
392 pub fn is_empty(&self) -> bool {
394 self.samples.is_empty()
395 }
396
397 pub fn is_full(&self) -> bool {
399 self.samples.len() >= self.window_size
400 }
401
402 pub fn stats(&self) -> &OnlineStats {
404 &self.stats
405 }
406
407 pub fn clear(&mut self) {
409 self.samples.clear();
410 self.matrix.clear();
411 self.stats = OnlineStats::default();
412 }
413
414 pub fn compute_with_all(&self, query: &[f64]) -> Result<Vec<f64>> {
416 let mut result = Vec::with_capacity(self.samples.len());
417 for sample in &self.samples {
418 let k = self.kernel.compute(query, sample)?;
419 result.push(k);
420 }
421 Ok(result)
422 }
423}
424
425#[derive(Debug, Clone)]
427pub struct ForgetfulConfig {
428 pub lambda: f64,
432 pub removal_threshold: Option<f64>,
434 pub max_samples: Option<usize>,
436}
437
438impl Default for ForgetfulConfig {
439 fn default() -> Self {
440 Self {
441 lambda: 0.99,
442 removal_threshold: Some(0.01),
443 max_samples: None,
444 }
445 }
446}
447
448impl ForgetfulConfig {
449 pub fn with_lambda(lambda: f64) -> Result<Self> {
451 if lambda <= 0.0 || lambda > 1.0 {
452 return Err(KernelError::InvalidParameter {
453 parameter: "lambda".to_string(),
454 value: lambda.to_string(),
455 reason: "lambda must be in (0, 1]".to_string(),
456 });
457 }
458 Ok(Self {
459 lambda,
460 ..Default::default()
461 })
462 }
463
464 pub fn with_max_samples(mut self, max: usize) -> Self {
466 self.max_samples = Some(max);
467 self
468 }
469
470 pub fn with_threshold(mut self, threshold: f64) -> Self {
472 self.removal_threshold = Some(threshold);
473 self
474 }
475}
476
477pub struct ForgetfulKernelMatrix {
500 kernel: Box<dyn Kernel>,
502 config: ForgetfulConfig,
504 samples: Vec<Vec<f64>>,
506 weights: Vec<f64>,
508 matrix: Vec<Vec<f64>>,
510 stats: OnlineStats,
512}
513
514impl ForgetfulKernelMatrix {
515 pub fn new(kernel: Box<dyn Kernel>, config: ForgetfulConfig) -> Self {
517 Self {
518 kernel,
519 config,
520 samples: Vec::new(),
521 weights: Vec::new(),
522 matrix: Vec::new(),
523 stats: OnlineStats::default(),
524 }
525 }
526
527 pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
529 if let Some(first) = self.samples.first() {
531 if sample.len() != first.len() {
532 return Err(KernelError::DimensionMismatch {
533 expected: vec![first.len()],
534 got: vec![sample.len()],
535 context: "forgetful kernel matrix".to_string(),
536 });
537 }
538 }
539
540 for weight in &mut self.weights {
542 *weight *= self.config.lambda;
543 }
544
545 if let Some(threshold) = self.config.removal_threshold {
547 let mut i = 0;
548 while i < self.weights.len() {
549 if self.weights[i] < threshold {
550 self.remove_at(i);
551 } else {
552 i += 1;
553 }
554 }
555 }
556
557 if let Some(max) = self.config.max_samples {
559 while self.samples.len() >= max && !self.samples.is_empty() {
560 if let Some((min_idx, _)) = self
562 .weights
563 .iter()
564 .enumerate()
565 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
566 {
567 self.remove_at(min_idx);
568 }
569 }
570 }
571
572 let n = self.samples.len();
574 let mut new_row = Vec::with_capacity(n + 1);
575
576 for existing in &self.samples {
577 let k = self.kernel.compute(&sample, existing)?;
578 new_row.push(k);
579 self.stats.kernel_computations += 1;
580 }
581
582 let k_self = self.kernel.compute(&sample, &sample)?;
584 new_row.push(k_self);
585 self.stats.kernel_computations += 1;
586
587 for (i, row) in self.matrix.iter_mut().enumerate() {
589 row.push(new_row[i]);
590 }
591
592 self.matrix.push(new_row);
594 self.samples.push(sample);
595 self.weights.push(1.0); self.stats.samples_added += 1;
597
598 Ok(())
599 }
600
601 fn remove_at(&mut self, index: usize) {
603 self.samples.remove(index);
604 self.weights.remove(index);
605 self.matrix.remove(index);
606 for row in &mut self.matrix {
607 row.remove(index);
608 }
609 self.stats.samples_removed += 1;
610 }
611
612 pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
614 &self.matrix
615 }
616
617 pub fn get_weighted_matrix(&self) -> Vec<Vec<f64>> {
621 let n = self.matrix.len();
622 let mut weighted = vec![vec![0.0; n]; n];
623
624 for (i, (row, &weight_i)) in self.matrix.iter().zip(&self.weights).enumerate() {
625 let sqrt_wi = weight_i.sqrt();
626 for (j, (&k_val, &weight_j)) in row.iter().zip(&self.weights).enumerate() {
627 let sqrt_wj = weight_j.sqrt();
628 weighted[i][j] = k_val * sqrt_wi * sqrt_wj;
629 }
630 }
631
632 weighted
633 }
634
635 pub fn get_weights(&self) -> &Vec<f64> {
637 &self.weights
638 }
639
640 pub fn get_samples(&self) -> &Vec<Vec<f64>> {
642 &self.samples
643 }
644
645 pub fn len(&self) -> usize {
647 self.samples.len()
648 }
649
650 pub fn is_empty(&self) -> bool {
652 self.samples.is_empty()
653 }
654
655 pub fn stats(&self) -> &OnlineStats {
657 &self.stats
658 }
659
660 pub fn lambda(&self) -> f64 {
662 self.config.lambda
663 }
664
665 pub fn clear(&mut self) {
667 self.samples.clear();
668 self.weights.clear();
669 self.matrix.clear();
670 self.stats = OnlineStats::default();
671 }
672
673 pub fn compute_weighted(&self, query: &[f64]) -> Result<Vec<f64>> {
675 let mut result = Vec::with_capacity(self.samples.len());
676 for (sample, weight) in self.samples.iter().zip(&self.weights) {
677 let k = self.kernel.compute(query, sample)?;
678 result.push(k * weight.sqrt());
679 }
680 Ok(result)
681 }
682
683 pub fn effective_size(&self) -> f64 {
685 self.weights.iter().sum()
686 }
687}
688
689pub struct AdaptiveKernelMatrix {
693 kernel: Arc<dyn Fn(f64) -> Box<dyn Kernel + Send + Sync> + Send + Sync>,
695 current_bandwidth: f64,
697 distance_sum: f64,
699 distance_count: usize,
701 inner: OnlineKernelMatrix,
703 adaptation_rate: f64,
705}
706
707impl AdaptiveKernelMatrix {
708 pub fn new<F>(kernel_factory: F, initial_bandwidth: f64, adaptation_rate: f64) -> Self
710 where
711 F: Fn(f64) -> Box<dyn Kernel + Send + Sync> + Send + Sync + 'static,
712 {
713 let factory = Arc::new(kernel_factory);
714 let kernel = factory(initial_bandwidth);
715
716 Self {
717 kernel: factory,
718 current_bandwidth: initial_bandwidth,
719 distance_sum: 0.0,
720 distance_count: 0,
721 inner: OnlineKernelMatrix::new(kernel),
722 adaptation_rate,
723 }
724 }
725
726 pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
728 for existing in self.inner.get_samples() {
730 let dist_sq: f64 = sample
731 .iter()
732 .zip(existing.iter())
733 .map(|(a, b)| (a - b) * (a - b))
734 .sum();
735 let dist = dist_sq.sqrt();
736 self.distance_sum += dist;
737 self.distance_count += 1;
738 }
739
740 if self.distance_count > 0 {
742 let mean_dist = self.distance_sum / self.distance_count as f64;
743 let new_bandwidth = mean_dist / 2.0_f64.sqrt();
744
745 self.current_bandwidth = (1.0 - self.adaptation_rate) * self.current_bandwidth
747 + self.adaptation_rate * new_bandwidth;
748
749 let new_kernel = (self.kernel)(self.current_bandwidth);
751
752 let samples: Vec<Vec<f64>> = self.inner.get_samples().clone();
754 self.inner = OnlineKernelMatrix::new(new_kernel);
755 for s in samples {
756 self.inner.add_sample(s)?;
757 }
758 }
759
760 self.inner.add_sample(sample)
761 }
762
763 pub fn bandwidth(&self) -> f64 {
765 self.current_bandwidth
766 }
767
768 pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
770 self.inner.get_matrix()
771 }
772
773 pub fn len(&self) -> usize {
775 self.inner.len()
776 }
777
778 pub fn is_empty(&self) -> bool {
780 self.inner.is_empty()
781 }
782}
783
#[cfg(test)]
// Index loops are kept where a test compares matrix[i][j] with matrix[j][i].
#[allow(clippy::needless_range_loop)]
mod tests {
    use super::*;
    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};

    /// Absolute tolerance used for every floating-point comparison.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    /// Empty online matrix backed by a linear kernel.
    fn linear_online() -> OnlineKernelMatrix {
        OnlineKernelMatrix::new(Box::new(LinearKernel::new()))
    }

    /// Empty windowed matrix backed by a linear kernel.
    fn linear_windowed(window: usize) -> WindowedKernelMatrix {
        WindowedKernelMatrix::new(Box::new(LinearKernel::new()), window)
    }

    /// Empty forgetful matrix backed by a linear kernel.
    fn linear_forgetful(config: ForgetfulConfig) -> ForgetfulKernelMatrix {
        ForgetfulKernelMatrix::new(Box::new(LinearKernel::new()), config)
    }

    /// Configuration that only decays weights and never prunes.
    fn decay_only(lambda: f64) -> ForgetfulConfig {
        ForgetfulConfig {
            lambda,
            removal_threshold: None,
            max_samples: None,
        }
    }

    /// Adaptive matrix with an RBF kernel factory and initial bandwidth 1.0.
    fn rbf_adaptive(rate: f64) -> AdaptiveKernelMatrix {
        AdaptiveKernelMatrix::new(
            |gamma| Box::new(RbfKernel::new(RbfKernelConfig::new(gamma)).unwrap()),
            1.0,
            rate,
        )
    }

    // ----- OnlineKernelMatrix -----

    #[test]
    fn test_online_kernel_matrix_basic() {
        let mut km = linear_online();
        assert!(km.is_empty());

        km.add_sample(vec![1.0, 2.0]).unwrap();
        assert_eq!(km.len(), 1);

        km.add_sample(vec![3.0, 4.0]).unwrap();
        assert_eq!(km.len(), 2);

        let gram = km.get_matrix();
        assert_eq!(gram.len(), 2);
        assert_eq!(gram[0].len(), 2);
    }

    #[test]
    fn test_online_kernel_matrix_values() {
        let mut km = linear_online();
        km.add_sample(vec![1.0, 0.0]).unwrap();
        km.add_sample(vec![0.0, 1.0]).unwrap();

        let gram = km.get_matrix();
        assert_close(gram[0][0], 1.0);
        assert_close(gram[1][1], 1.0);
        assert_close(gram[0][1], 0.0);
        assert_close(gram[1][0], 0.0);
    }

    #[test]
    fn test_online_kernel_matrix_symmetry() {
        let rbf = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
        let mut km = OnlineKernelMatrix::new(Box::new(rbf));

        km.add_sample(vec![1.0, 2.0, 3.0]).unwrap();
        km.add_sample(vec![4.0, 5.0, 6.0]).unwrap();
        km.add_sample(vec![7.0, 8.0, 9.0]).unwrap();

        let gram = km.get_matrix();
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (gram[i][j] - gram[j][i]).abs() < 1e-10,
                    "Matrix not symmetric at ({}, {})",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_online_kernel_matrix_remove() {
        let mut km = linear_online();
        km.add_sample(vec![1.0]).unwrap();
        km.add_sample(vec![2.0]).unwrap();
        km.add_sample(vec![3.0]).unwrap();

        let taken = km.remove_sample(1).unwrap();
        assert_eq!(taken, vec![2.0]);
        assert_eq!(km.len(), 2);

        let gram = km.get_matrix();
        assert_eq!(gram.len(), 2);
        assert_eq!(gram[0].len(), 2);
    }

    #[test]
    fn test_online_kernel_matrix_dimension_mismatch() {
        let mut km = linear_online();
        km.add_sample(vec![1.0, 2.0]).unwrap();

        assert!(km.add_sample(vec![1.0, 2.0, 3.0]).is_err());
    }

    #[test]
    fn test_online_kernel_matrix_compute_with_all() {
        let mut km = linear_online();
        km.add_sample(vec![1.0, 0.0]).unwrap();
        km.add_sample(vec![0.0, 1.0]).unwrap();

        let values = km.compute_with_all(&[1.0, 1.0]).unwrap();
        assert_close(values[0], 1.0);
        assert_close(values[1], 1.0);
    }

    #[test]
    fn test_online_kernel_matrix_stats() {
        let mut km = linear_online();
        km.add_sample(vec![1.0]).unwrap();
        km.add_sample(vec![2.0]).unwrap();
        km.add_sample(vec![3.0]).unwrap();

        // 1 + 2 + 3 evaluations: each insert touches every stored sample
        // plus itself.
        let counters = km.stats();
        assert_eq!(counters.samples_added, 3);
        assert_eq!(counters.kernel_computations, 6);
    }

    // ----- WindowedKernelMatrix -----

    #[test]
    fn test_windowed_kernel_matrix_basic() {
        let mut window = linear_windowed(3);
        assert_eq!(window.window_size(), 3);
        assert!(!window.is_full());

        window.add_sample(vec![1.0]).unwrap();
        window.add_sample(vec![2.0]).unwrap();
        window.add_sample(vec![3.0]).unwrap();

        assert!(window.is_full());
        assert_eq!(window.len(), 3);
    }

    #[test]
    fn test_windowed_kernel_matrix_eviction() {
        let mut window = linear_windowed(2);
        window.add_sample(vec![1.0]).unwrap();
        window.add_sample(vec![2.0]).unwrap();

        let dropped = window.add_sample(vec![3.0]).unwrap();
        assert_eq!(dropped, Some(vec![1.0]));
        assert_eq!(window.len(), 2);

        let held: Vec<_> = window.get_samples().iter().cloned().collect();
        assert_eq!(held, vec![vec![2.0], vec![3.0]]);
    }

    #[test]
    fn test_windowed_kernel_matrix_values() {
        let mut window = linear_windowed(2);
        window.add_sample(vec![1.0, 0.0]).unwrap();
        window.add_sample(vec![0.0, 1.0]).unwrap();

        let before = window.get_matrix();
        assert_close(before[0][0], 1.0);
        assert_close(before[1][1], 1.0);
        assert_close(before[0][1], 0.0);

        // Evicts [1, 0]; the window now holds [0, 1] and [1, 1].
        window.add_sample(vec![1.0, 1.0]).unwrap();

        let after = window.get_matrix();
        assert_close(after[0][0], 1.0);
        assert_close(after[1][1], 2.0);
        assert_close(after[0][1], 1.0);
    }

    #[test]
    fn test_windowed_kernel_matrix_dimension_mismatch() {
        let mut window = linear_windowed(3);
        window.add_sample(vec![1.0, 2.0]).unwrap();

        assert!(window.add_sample(vec![1.0]).is_err());
    }

    // ----- ForgetfulKernelMatrix -----

    #[test]
    fn test_forgetful_kernel_matrix_basic() {
        let mut fading = linear_forgetful(ForgetfulConfig::with_lambda(0.9).unwrap());
        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![2.0]).unwrap();

        assert_eq!(fading.len(), 2);
        assert_close(fading.lambda(), 0.9);
    }

    #[test]
    fn test_forgetful_kernel_matrix_weights() {
        let mut fading = linear_forgetful(decay_only(0.8));
        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![2.0]).unwrap();
        fading.add_sample(vec![3.0]).unwrap();

        // Newest sample is undecayed; older ones have decayed once and twice.
        let weights = fading.get_weights();
        assert_close(weights[2], 1.0);
        assert_close(weights[1], 0.8);
        assert_close(weights[0], 0.64);
    }

    #[test]
    fn test_forgetful_kernel_matrix_weighted_matrix() {
        let mut fading = linear_forgetful(decay_only(0.5));
        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![1.0]).unwrap();

        let scaled = fading.get_weighted_matrix();
        assert_close(scaled[0][0], 0.5);
        assert_close(scaled[1][1], 1.0);
        assert_close(scaled[0][1], 0.5_f64.sqrt());
    }

    #[test]
    fn test_forgetful_kernel_matrix_removal_threshold() {
        let config = ForgetfulConfig {
            removal_threshold: Some(0.3),
            ..decay_only(0.5)
        };
        let mut fading = linear_forgetful(config);

        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![2.0]).unwrap();
        // The first sample decays to 0.25 < 0.3 and is pruned here.
        fading.add_sample(vec![3.0]).unwrap();

        assert_eq!(fading.len(), 2);
    }

    #[test]
    fn test_forgetful_kernel_matrix_max_samples() {
        let config = ForgetfulConfig {
            max_samples: Some(2),
            ..decay_only(1.0)
        };
        let mut fading = linear_forgetful(config);

        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![2.0]).unwrap();
        fading.add_sample(vec![3.0]).unwrap();

        assert_eq!(fading.len(), 2);
    }

    #[test]
    fn test_forgetful_kernel_matrix_effective_size() {
        let mut fading = linear_forgetful(decay_only(0.9));
        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![2.0]).unwrap();
        fading.add_sample(vec![3.0]).unwrap();

        // 1.0 + 0.9 + 0.81
        assert_close(fading.effective_size(), 2.71);
    }

    #[test]
    fn test_forgetful_kernel_matrix_invalid_lambda() {
        assert!(ForgetfulConfig::with_lambda(0.0).is_err());
        assert!(ForgetfulConfig::with_lambda(1.5).is_err());
    }

    #[test]
    fn test_forgetful_kernel_matrix_dimension_mismatch() {
        let mut fading = linear_forgetful(ForgetfulConfig::with_lambda(0.9).unwrap());
        fading.add_sample(vec![1.0, 2.0]).unwrap();

        assert!(fading.add_sample(vec![1.0]).is_err());
    }

    // ----- AdaptiveKernelMatrix -----

    #[test]
    fn test_adaptive_kernel_matrix_basic() {
        let mut adaptive = rbf_adaptive(0.1);
        adaptive.add_sample(vec![1.0, 2.0]).unwrap();
        adaptive.add_sample(vec![3.0, 4.0]).unwrap();
        adaptive.add_sample(vec![5.0, 6.0]).unwrap();

        assert_eq!(adaptive.len(), 3);
        assert!(adaptive.bandwidth() > 0.0);
    }

    #[test]
    fn test_adaptive_kernel_matrix_bandwidth_update() {
        let mut adaptive = rbf_adaptive(0.5);
        let before = adaptive.bandwidth();

        adaptive.add_sample(vec![0.0]).unwrap();
        adaptive.add_sample(vec![10.0]).unwrap();

        let after = adaptive.bandwidth();
        assert_ne!(before, after);
    }

    // ----- Empty and cleared containers -----

    #[test]
    fn test_online_empty_operations() {
        let km = linear_online();

        assert!(km.is_empty());
        assert!(km.get_matrix().is_empty());
        assert!(km.get_samples().is_empty());
    }

    #[test]
    fn test_windowed_clear() {
        let mut window = linear_windowed(3);
        window.add_sample(vec![1.0]).unwrap();
        window.add_sample(vec![2.0]).unwrap();
        window.clear();

        assert!(window.is_empty());
        assert_eq!(window.len(), 0);
    }

    #[test]
    fn test_forgetful_clear() {
        let mut fading = linear_forgetful(ForgetfulConfig::with_lambda(0.9).unwrap());
        fading.add_sample(vec![1.0]).unwrap();
        fading.add_sample(vec![2.0]).unwrap();
        fading.clear();

        assert!(fading.is_empty());
        assert_eq!(fading.len(), 0);
    }
}