Skip to main content

tensorlogic_sklears_kernels/
online.rs

1//! Online kernel updates for streaming and incremental learning.
2//!
3//! This module provides kernel matrices that can be efficiently updated
4//! incrementally as new samples arrive, without recomputing the entire matrix.
5//!
6//! ## Features
7//!
8//! - **OnlineKernelMatrix** - Incrementally add samples with O(n) updates
9//! - **WindowedKernelMatrix** - Sliding window for bounded memory in time series
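//! - **ForgetfulKernelMatrix** - Exponential decay for concept drift adaptation
//! - **AdaptiveKernelMatrix** - Data-driven bandwidth adaptation (rebuilds the matrix on each update)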
11//!
12//! ## Use Cases
13//!
14//! - Streaming data classification
15//! - Online learning with kernel methods
16//! - Time series with non-stationarity
17//! - Memory-constrained environments
18
19use crate::error::{KernelError, Result};
20use crate::types::Kernel;
21use std::collections::VecDeque;
22use std::sync::Arc;
23
/// Tuning knobs for an online kernel matrix.
///
/// Obtain one via `OnlineConfig::default()` or `OnlineConfig::with_capacity(..)`.
#[derive(Clone, Debug)]
pub struct OnlineConfig {
    /// Number of samples to pre-allocate room for when the matrix is created.
    pub initial_capacity: usize,
    /// Whether to compute the full matrix or only the entries that are needed.
    ///
    /// NOTE(review): nothing in this module reads this flag — confirm whether it is
    /// consulted elsewhere before relying on it.
    pub compute_full_matrix: bool,
}

impl Default for OnlineConfig {
    /// Room for 64 samples; full-matrix computation enabled.
    fn default() -> Self {
        OnlineConfig {
            initial_capacity: 64,
            compute_full_matrix: true,
        }
    }
}

impl OnlineConfig {
    /// Build a configuration that pre-allocates room for `capacity` samples.
    ///
    /// Every other field keeps its `Default` value.
    pub fn with_capacity(capacity: usize) -> Self {
        let defaults = Self::default();
        OnlineConfig {
            initial_capacity: capacity,
            ..defaults
        }
    }
}
51
/// Running counters describing the work performed by an online kernel matrix.
///
/// Every counter starts at zero via the derived `Default`.
#[derive(Clone, Debug, Default)]
pub struct OnlineStats {
    /// How many samples have been inserted.
    pub samples_added: usize,
    /// How many samples have been dropped (explicit removal, window eviction, or forgetting).
    pub samples_removed: usize,
    /// How many kernel evaluations have been performed, self-similarity included.
    pub kernel_computations: usize,
    /// How many times the matrix storage has been resized.
    ///
    /// NOTE(review): no code in this module increments this counter.
    pub resizes: usize,
}
64
65/// Incrementally updatable kernel matrix.
66///
67/// Supports efficient O(n) updates when adding new samples,
68/// avoiding O(n²) recomputation of the entire matrix.
69///
70/// # Example
71///
72/// ```
73/// use tensorlogic_sklears_kernels::{OnlineKernelMatrix, RbfKernel, RbfKernelConfig, Kernel};
74///
75/// let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
76/// let mut online = OnlineKernelMatrix::new(Box::new(kernel));
77///
78/// // Add samples incrementally
79/// online.add_sample(vec![1.0, 2.0, 3.0]).unwrap();
80/// online.add_sample(vec![4.0, 5.0, 6.0]).unwrap();
81/// online.add_sample(vec![7.0, 8.0, 9.0]).unwrap();
82///
83/// // Get the kernel matrix
84/// let matrix = online.get_matrix();
85/// assert_eq!(matrix.len(), 3);
86/// ```
87pub struct OnlineKernelMatrix {
88    /// The underlying kernel function
89    kernel: Box<dyn Kernel>,
90    /// Stored samples
91    samples: Vec<Vec<f64>>,
92    /// Current kernel matrix (upper triangular stored as full for simplicity)
93    matrix: Vec<Vec<f64>>,
94    /// Configuration
95    config: OnlineConfig,
96    /// Statistics
97    stats: OnlineStats,
98}
99
100impl OnlineKernelMatrix {
101    /// Create a new online kernel matrix.
102    pub fn new(kernel: Box<dyn Kernel>) -> Self {
103        Self::with_config(kernel, OnlineConfig::default())
104    }
105
106    /// Create with custom configuration.
107    pub fn with_config(kernel: Box<dyn Kernel>, config: OnlineConfig) -> Self {
108        Self {
109            kernel,
110            samples: Vec::with_capacity(config.initial_capacity),
111            matrix: Vec::with_capacity(config.initial_capacity),
112            config,
113            stats: OnlineStats::default(),
114        }
115    }
116
117    /// Add a new sample and update the kernel matrix.
118    ///
119    /// Time complexity: O(n) where n is current number of samples.
120    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
121        // Validate dimensions
122        if let Some(first) = self.samples.first() {
123            if sample.len() != first.len() {
124                return Err(KernelError::DimensionMismatch {
125                    expected: vec![first.len()],
126                    got: vec![sample.len()],
127                    context: "online kernel matrix".to_string(),
128                });
129            }
130        }
131
132        let n = self.samples.len();
133
134        // Compute kernel values between new sample and all existing samples
135        let mut new_row = Vec::with_capacity(n + 1);
136        for existing in &self.samples {
137            let k = self.kernel.compute(&sample, existing)?;
138            new_row.push(k);
139            self.stats.kernel_computations += 1;
140        }
141
142        // Self-similarity (usually 1.0 for normalized kernels)
143        let k_self = self.kernel.compute(&sample, &sample)?;
144        new_row.push(k_self);
145        self.stats.kernel_computations += 1;
146
147        // Update existing rows with new column
148        for (i, row) in self.matrix.iter_mut().enumerate() {
149            row.push(new_row[i]);
150        }
151
152        // Add new row
153        self.matrix.push(new_row);
154        self.samples.push(sample);
155        self.stats.samples_added += 1;
156
157        Ok(())
158    }
159
160    /// Add multiple samples at once (batch update).
161    ///
162    /// More efficient than adding one by one due to better cache utilization.
163    pub fn add_samples(&mut self, samples: Vec<Vec<f64>>) -> Result<()> {
164        for sample in samples {
165            self.add_sample(sample)?;
166        }
167        Ok(())
168    }
169
170    /// Remove a sample by index and update the matrix.
171    ///
172    /// Time complexity: O(n²) due to matrix restructuring.
173    pub fn remove_sample(&mut self, index: usize) -> Result<Vec<f64>> {
174        if index >= self.samples.len() {
175            return Err(KernelError::ComputationError(format!(
176                "Index {} out of bounds for {} samples",
177                index,
178                self.samples.len()
179            )));
180        }
181
182        // Remove from samples
183        let removed = self.samples.remove(index);
184
185        // Remove row
186        self.matrix.remove(index);
187
188        // Remove column from all remaining rows
189        for row in &mut self.matrix {
190            row.remove(index);
191        }
192
193        self.stats.samples_removed += 1;
194        Ok(removed)
195    }
196
197    /// Get the current kernel matrix.
198    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
199        &self.matrix
200    }
201
202    /// Get the current samples.
203    pub fn get_samples(&self) -> &Vec<Vec<f64>> {
204        &self.samples
205    }
206
207    /// Get a specific kernel value.
208    pub fn get(&self, i: usize, j: usize) -> Option<f64> {
209        self.matrix.get(i).and_then(|row| row.get(j).copied())
210    }
211
212    /// Get the number of samples.
213    pub fn len(&self) -> usize {
214        self.samples.len()
215    }
216
217    /// Check if empty.
218    pub fn is_empty(&self) -> bool {
219        self.samples.is_empty()
220    }
221
222    /// Get statistics.
223    pub fn stats(&self) -> &OnlineStats {
224        &self.stats
225    }
226
227    /// Reset the matrix.
228    pub fn clear(&mut self) {
229        self.samples.clear();
230        self.matrix.clear();
231        self.stats = OnlineStats::default();
232    }
233
234    /// Get the underlying kernel.
235    pub fn kernel(&self) -> &dyn Kernel {
236        self.kernel.as_ref()
237    }
238
239    /// Get the configuration.
240    pub fn config(&self) -> &OnlineConfig {
241        &self.config
242    }
243
244    /// Compute kernel value between a query point and stored sample.
245    pub fn compute_with_sample(&self, query: &[f64], sample_idx: usize) -> Result<f64> {
246        let sample = self.samples.get(sample_idx).ok_or_else(|| {
247            KernelError::ComputationError(format!("Sample index {} not found", sample_idx))
248        })?;
249        self.kernel.compute(query, sample)
250    }
251
252    /// Compute kernel values between query and all stored samples.
253    pub fn compute_with_all(&self, query: &[f64]) -> Result<Vec<f64>> {
254        let mut result = Vec::with_capacity(self.samples.len());
255        for sample in &self.samples {
256            let k = self.kernel.compute(query, sample)?;
257            result.push(k);
258        }
259        Ok(result)
260    }
261
262    /// Clone the current matrix as a standalone 2D vector.
263    pub fn to_matrix(&self) -> Vec<Vec<f64>> {
264        self.matrix.clone()
265    }
266}
267
268/// Sliding window kernel matrix for bounded-memory streaming.
269///
270/// Maintains only the most recent `window_size` samples, automatically
271/// removing oldest samples when the window is full.
272///
273/// # Example
274///
275/// ```
276/// use tensorlogic_sklears_kernels::{WindowedKernelMatrix, LinearKernel, Kernel};
277///
278/// let kernel = LinearKernel::new();
279/// let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
280///
281/// // Add samples (window size = 3)
282/// windowed.add_sample(vec![1.0]).unwrap();
283/// windowed.add_sample(vec![2.0]).unwrap();
284/// windowed.add_sample(vec![3.0]).unwrap();
285/// windowed.add_sample(vec![4.0]).unwrap(); // First sample evicted
286///
287/// assert_eq!(windowed.len(), 3);
288/// ```
289pub struct WindowedKernelMatrix {
290    /// The underlying kernel function
291    kernel: Box<dyn Kernel>,
292    /// Window size
293    window_size: usize,
294    /// Stored samples in window (circular buffer)
295    samples: VecDeque<Vec<f64>>,
296    /// Current kernel matrix
297    matrix: Vec<Vec<f64>>,
298    /// Statistics
299    stats: OnlineStats,
300}
301
302impl WindowedKernelMatrix {
303    /// Create a new windowed kernel matrix.
304    pub fn new(kernel: Box<dyn Kernel>, window_size: usize) -> Self {
305        assert!(window_size > 0, "Window size must be positive");
306        Self {
307            kernel,
308            window_size,
309            samples: VecDeque::with_capacity(window_size),
310            matrix: Vec::with_capacity(window_size),
311            stats: OnlineStats::default(),
312        }
313    }
314
315    /// Add a sample, evicting oldest if window is full.
316    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<Option<Vec<f64>>> {
317        // Validate dimensions
318        if let Some(first) = self.samples.front() {
319            if sample.len() != first.len() {
320                return Err(KernelError::DimensionMismatch {
321                    expected: vec![first.len()],
322                    got: vec![sample.len()],
323                    context: "windowed kernel matrix".to_string(),
324                });
325            }
326        }
327
328        let evicted = if self.samples.len() >= self.window_size {
329            // Remove oldest sample
330            let removed = self.samples.pop_front();
331
332            // Update matrix: remove first row and first column
333            self.matrix.remove(0);
334            for row in &mut self.matrix {
335                row.remove(0);
336            }
337
338            self.stats.samples_removed += 1;
339            removed
340        } else {
341            None
342        };
343
344        // Compute kernel values with existing samples
345        let n = self.samples.len();
346        let mut new_row = Vec::with_capacity(n + 1);
347
348        for existing in &self.samples {
349            let k = self.kernel.compute(&sample, existing)?;
350            new_row.push(k);
351            self.stats.kernel_computations += 1;
352        }
353
354        // Self-similarity
355        let k_self = self.kernel.compute(&sample, &sample)?;
356        new_row.push(k_self);
357        self.stats.kernel_computations += 1;
358
359        // Update existing rows
360        for (i, row) in self.matrix.iter_mut().enumerate() {
361            row.push(new_row[i]);
362        }
363
364        // Add new row
365        self.matrix.push(new_row);
366        self.samples.push_back(sample);
367        self.stats.samples_added += 1;
368
369        Ok(evicted)
370    }
371
372    /// Get the current kernel matrix.
373    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
374        &self.matrix
375    }
376
377    /// Get samples in the window.
378    pub fn get_samples(&self) -> &VecDeque<Vec<f64>> {
379        &self.samples
380    }
381
382    /// Get the window size.
383    pub fn window_size(&self) -> usize {
384        self.window_size
385    }
386
387    /// Get current number of samples.
388    pub fn len(&self) -> usize {
389        self.samples.len()
390    }
391
392    /// Check if window is empty.
393    pub fn is_empty(&self) -> bool {
394        self.samples.is_empty()
395    }
396
397    /// Check if window is full.
398    pub fn is_full(&self) -> bool {
399        self.samples.len() >= self.window_size
400    }
401
402    /// Get statistics.
403    pub fn stats(&self) -> &OnlineStats {
404        &self.stats
405    }
406
407    /// Clear the window.
408    pub fn clear(&mut self) {
409        self.samples.clear();
410        self.matrix.clear();
411        self.stats = OnlineStats::default();
412    }
413
414    /// Compute kernel values between query and all samples in window.
415    pub fn compute_with_all(&self, query: &[f64]) -> Result<Vec<f64>> {
416        let mut result = Vec::with_capacity(self.samples.len());
417        for sample in &self.samples {
418            let k = self.kernel.compute(query, sample)?;
419            result.push(k);
420        }
421        Ok(result)
422    }
423}
424
425/// Configuration for forgetful kernel matrix.
426#[derive(Debug, Clone)]
427pub struct ForgetfulConfig {
428    /// Forgetting factor (0 < λ <= 1)
429    /// λ = 1: no forgetting (infinite memory)
430    /// λ < 1: older samples weighted less
431    pub lambda: f64,
432    /// Threshold below which samples are removed
433    pub removal_threshold: Option<f64>,
434    /// Maximum number of samples to keep
435    pub max_samples: Option<usize>,
436}
437
438impl Default for ForgetfulConfig {
439    fn default() -> Self {
440        Self {
441            lambda: 0.99,
442            removal_threshold: Some(0.01),
443            max_samples: None,
444        }
445    }
446}
447
448impl ForgetfulConfig {
449    /// Create configuration with specified forgetting factor.
450    pub fn with_lambda(lambda: f64) -> Result<Self> {
451        if lambda <= 0.0 || lambda > 1.0 {
452            return Err(KernelError::InvalidParameter {
453                parameter: "lambda".to_string(),
454                value: lambda.to_string(),
455                reason: "lambda must be in (0, 1]".to_string(),
456            });
457        }
458        Ok(Self {
459            lambda,
460            ..Default::default()
461        })
462    }
463
464    /// Set maximum samples limit.
465    pub fn with_max_samples(mut self, max: usize) -> Self {
466        self.max_samples = Some(max);
467        self
468    }
469
470    /// Set removal threshold.
471    pub fn with_threshold(mut self, threshold: f64) -> Self {
472        self.removal_threshold = Some(threshold);
473        self
474    }
475}
476
477/// Kernel matrix with exponential forgetting for concept drift adaptation.
478///
479/// Older samples are weighted by λ^age, allowing the model to adapt
480/// to changing data distributions.
481///
482/// # Example
483///
484/// ```
485/// use tensorlogic_sklears_kernels::{ForgetfulKernelMatrix, ForgetfulConfig, RbfKernel, RbfKernelConfig, Kernel};
486///
487/// let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
488/// let config = ForgetfulConfig::with_lambda(0.95).unwrap();
489/// let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
490///
491/// // Add samples (older ones get downweighted)
492/// forgetful.add_sample(vec![1.0, 2.0]).unwrap();
493/// forgetful.add_sample(vec![3.0, 4.0]).unwrap();
494/// forgetful.add_sample(vec![5.0, 6.0]).unwrap();
495///
496/// // Get weighted kernel matrix
497/// let weighted = forgetful.get_weighted_matrix();
498/// ```
499pub struct ForgetfulKernelMatrix {
500    /// The underlying kernel function
501    kernel: Box<dyn Kernel>,
502    /// Configuration
503    config: ForgetfulConfig,
504    /// Stored samples
505    samples: Vec<Vec<f64>>,
506    /// Sample weights (λ^age)
507    weights: Vec<f64>,
508    /// Raw kernel matrix (unweighted)
509    matrix: Vec<Vec<f64>>,
510    /// Statistics
511    stats: OnlineStats,
512}
513
514impl ForgetfulKernelMatrix {
515    /// Create a new forgetful kernel matrix.
516    pub fn new(kernel: Box<dyn Kernel>, config: ForgetfulConfig) -> Self {
517        Self {
518            kernel,
519            config,
520            samples: Vec::new(),
521            weights: Vec::new(),
522            matrix: Vec::new(),
523            stats: OnlineStats::default(),
524        }
525    }
526
527    /// Add a sample with forgetting.
528    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
529        // Validate dimensions
530        if let Some(first) = self.samples.first() {
531            if sample.len() != first.len() {
532                return Err(KernelError::DimensionMismatch {
533                    expected: vec![first.len()],
534                    got: vec![sample.len()],
535                    context: "forgetful kernel matrix".to_string(),
536                });
537            }
538        }
539
540        // Age existing samples (multiply weights by lambda)
541        for weight in &mut self.weights {
542            *weight *= self.config.lambda;
543        }
544
545        // Remove samples below threshold
546        if let Some(threshold) = self.config.removal_threshold {
547            let mut i = 0;
548            while i < self.weights.len() {
549                if self.weights[i] < threshold {
550                    self.remove_at(i);
551                } else {
552                    i += 1;
553                }
554            }
555        }
556
557        // Enforce max samples
558        if let Some(max) = self.config.max_samples {
559            while self.samples.len() >= max && !self.samples.is_empty() {
560                // Remove the lowest weight sample
561                if let Some((min_idx, _)) = self
562                    .weights
563                    .iter()
564                    .enumerate()
565                    .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
566                {
567                    self.remove_at(min_idx);
568                }
569            }
570        }
571
572        // Compute kernel values with existing samples
573        let n = self.samples.len();
574        let mut new_row = Vec::with_capacity(n + 1);
575
576        for existing in &self.samples {
577            let k = self.kernel.compute(&sample, existing)?;
578            new_row.push(k);
579            self.stats.kernel_computations += 1;
580        }
581
582        // Self-similarity
583        let k_self = self.kernel.compute(&sample, &sample)?;
584        new_row.push(k_self);
585        self.stats.kernel_computations += 1;
586
587        // Update existing rows
588        for (i, row) in self.matrix.iter_mut().enumerate() {
589            row.push(new_row[i]);
590        }
591
592        // Add new row and sample
593        self.matrix.push(new_row);
594        self.samples.push(sample);
595        self.weights.push(1.0); // New sample has full weight
596        self.stats.samples_added += 1;
597
598        Ok(())
599    }
600
601    /// Remove sample at specific index.
602    fn remove_at(&mut self, index: usize) {
603        self.samples.remove(index);
604        self.weights.remove(index);
605        self.matrix.remove(index);
606        for row in &mut self.matrix {
607            row.remove(index);
608        }
609        self.stats.samples_removed += 1;
610    }
611
612    /// Get the raw (unweighted) kernel matrix.
613    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
614        &self.matrix
615    }
616
617    /// Get the weighted kernel matrix.
618    ///
619    /// Each entry `K[i,j]` is multiplied by `sqrt(w_i * w_j)` to maintain PSD property.
620    pub fn get_weighted_matrix(&self) -> Vec<Vec<f64>> {
621        let n = self.matrix.len();
622        let mut weighted = vec![vec![0.0; n]; n];
623
624        for (i, (row, &weight_i)) in self.matrix.iter().zip(&self.weights).enumerate() {
625            let sqrt_wi = weight_i.sqrt();
626            for (j, (&k_val, &weight_j)) in row.iter().zip(&self.weights).enumerate() {
627                let sqrt_wj = weight_j.sqrt();
628                weighted[i][j] = k_val * sqrt_wi * sqrt_wj;
629            }
630        }
631
632        weighted
633    }
634
635    /// Get sample weights.
636    pub fn get_weights(&self) -> &Vec<f64> {
637        &self.weights
638    }
639
640    /// Get the current samples.
641    pub fn get_samples(&self) -> &Vec<Vec<f64>> {
642        &self.samples
643    }
644
645    /// Get the number of samples.
646    pub fn len(&self) -> usize {
647        self.samples.len()
648    }
649
650    /// Check if empty.
651    pub fn is_empty(&self) -> bool {
652        self.samples.is_empty()
653    }
654
655    /// Get statistics.
656    pub fn stats(&self) -> &OnlineStats {
657        &self.stats
658    }
659
660    /// Get the forgetting factor.
661    pub fn lambda(&self) -> f64 {
662        self.config.lambda
663    }
664
665    /// Clear all samples.
666    pub fn clear(&mut self) {
667        self.samples.clear();
668        self.weights.clear();
669        self.matrix.clear();
670        self.stats = OnlineStats::default();
671    }
672
673    /// Compute weighted kernel values between query and all stored samples.
674    pub fn compute_weighted(&self, query: &[f64]) -> Result<Vec<f64>> {
675        let mut result = Vec::with_capacity(self.samples.len());
676        for (sample, weight) in self.samples.iter().zip(&self.weights) {
677            let k = self.kernel.compute(query, sample)?;
678            result.push(k * weight.sqrt());
679        }
680        Ok(result)
681    }
682
683    /// Get effective sample size (sum of weights).
684    pub fn effective_size(&self) -> f64 {
685        self.weights.iter().sum()
686    }
687}
688
689/// Adaptive kernel with automatic bandwidth adjustment.
690///
691/// Updates kernel parameters based on incoming data statistics.
692pub struct AdaptiveKernelMatrix {
693    /// Base kernel (must be RBF-like with adjustable bandwidth)
694    kernel: Arc<dyn Fn(f64) -> Box<dyn Kernel + Send + Sync> + Send + Sync>,
695    /// Current bandwidth parameter
696    current_bandwidth: f64,
697    /// Online mean of pairwise distances
698    distance_sum: f64,
699    /// Count of distance observations
700    distance_count: usize,
701    /// Inner online matrix
702    inner: OnlineKernelMatrix,
703    /// Adaptation rate
704    adaptation_rate: f64,
705}
706
707impl AdaptiveKernelMatrix {
708    /// Create adaptive kernel with bandwidth factory function.
709    pub fn new<F>(kernel_factory: F, initial_bandwidth: f64, adaptation_rate: f64) -> Self
710    where
711        F: Fn(f64) -> Box<dyn Kernel + Send + Sync> + Send + Sync + 'static,
712    {
713        let factory = Arc::new(kernel_factory);
714        let kernel = factory(initial_bandwidth);
715
716        Self {
717            kernel: factory,
718            current_bandwidth: initial_bandwidth,
719            distance_sum: 0.0,
720            distance_count: 0,
721            inner: OnlineKernelMatrix::new(kernel),
722            adaptation_rate,
723        }
724    }
725
726    /// Add sample with adaptive bandwidth update.
727    pub fn add_sample(&mut self, sample: Vec<f64>) -> Result<()> {
728        // Compute distances to existing samples for bandwidth adaptation
729        for existing in self.inner.get_samples() {
730            let dist_sq: f64 = sample
731                .iter()
732                .zip(existing.iter())
733                .map(|(a, b)| (a - b) * (a - b))
734                .sum();
735            let dist = dist_sq.sqrt();
736            self.distance_sum += dist;
737            self.distance_count += 1;
738        }
739
740        // Update bandwidth using median heuristic approximation
741        if self.distance_count > 0 {
742            let mean_dist = self.distance_sum / self.distance_count as f64;
743            let new_bandwidth = mean_dist / 2.0_f64.sqrt();
744
745            // Exponential moving average update
746            self.current_bandwidth = (1.0 - self.adaptation_rate) * self.current_bandwidth
747                + self.adaptation_rate * new_bandwidth;
748
749            // Rebuild kernel with new bandwidth
750            let new_kernel = (self.kernel)(self.current_bandwidth);
751
752            // Rebuild matrix with new kernel (expensive but necessary for adaptation)
753            let samples: Vec<Vec<f64>> = self.inner.get_samples().clone();
754            self.inner = OnlineKernelMatrix::new(new_kernel);
755            for s in samples {
756                self.inner.add_sample(s)?;
757            }
758        }
759
760        self.inner.add_sample(sample)
761    }
762
763    /// Get current bandwidth.
764    pub fn bandwidth(&self) -> f64 {
765        self.current_bandwidth
766    }
767
768    /// Get the underlying matrix.
769    pub fn get_matrix(&self) -> &Vec<Vec<f64>> {
770        self.inner.get_matrix()
771    }
772
773    /// Get number of samples.
774    pub fn len(&self) -> usize {
775        self.inner.len()
776    }
777
778    /// Check if empty.
779    pub fn is_empty(&self) -> bool {
780        self.inner.is_empty()
781    }
782}
783
784#[cfg(test)]
785#[allow(clippy::needless_range_loop)]
786mod tests {
787    use super::*;
788    use crate::{LinearKernel, RbfKernel, RbfKernelConfig};
789
790    // ===== OnlineKernelMatrix Tests =====
791
792    #[test]
793    fn test_online_kernel_matrix_basic() {
794        let kernel = LinearKernel::new();
795        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
796
797        assert!(online.is_empty());
798
799        online.add_sample(vec![1.0, 2.0]).unwrap();
800        assert_eq!(online.len(), 1);
801
802        online.add_sample(vec![3.0, 4.0]).unwrap();
803        assert_eq!(online.len(), 2);
804
805        let matrix = online.get_matrix();
806        assert_eq!(matrix.len(), 2);
807        assert_eq!(matrix[0].len(), 2);
808    }
809
810    #[test]
811    fn test_online_kernel_matrix_values() {
812        let kernel = LinearKernel::new();
813        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
814
815        online.add_sample(vec![1.0, 0.0]).unwrap();
816        online.add_sample(vec![0.0, 1.0]).unwrap();
817
818        let matrix = online.get_matrix();
819
820        // K[0,0] = [1,0]·[1,0] = 1
821        assert!((matrix[0][0] - 1.0).abs() < 1e-10);
822        // K[1,1] = [0,1]·[0,1] = 1
823        assert!((matrix[1][1] - 1.0).abs() < 1e-10);
824        // K[0,1] = [1,0]·[0,1] = 0
825        assert!((matrix[0][1]).abs() < 1e-10);
826        // K[1,0] = K[0,1]
827        assert!((matrix[1][0]).abs() < 1e-10);
828    }
829
830    #[test]
831    fn test_online_kernel_matrix_symmetry() {
832        let kernel = RbfKernel::new(RbfKernelConfig::new(0.5)).unwrap();
833        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
834
835        online.add_sample(vec![1.0, 2.0, 3.0]).unwrap();
836        online.add_sample(vec![4.0, 5.0, 6.0]).unwrap();
837        online.add_sample(vec![7.0, 8.0, 9.0]).unwrap();
838
839        let matrix = online.get_matrix();
840
841        for i in 0..3 {
842            for j in 0..3 {
843                assert!(
844                    (matrix[i][j] - matrix[j][i]).abs() < 1e-10,
845                    "Matrix not symmetric at ({}, {})",
846                    i,
847                    j
848                );
849            }
850        }
851    }
852
853    #[test]
854    fn test_online_kernel_matrix_remove() {
855        let kernel = LinearKernel::new();
856        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
857
858        online.add_sample(vec![1.0]).unwrap();
859        online.add_sample(vec![2.0]).unwrap();
860        online.add_sample(vec![3.0]).unwrap();
861
862        let removed = online.remove_sample(1).unwrap();
863        assert_eq!(removed, vec![2.0]);
864        assert_eq!(online.len(), 2);
865
866        let matrix = online.get_matrix();
867        assert_eq!(matrix.len(), 2);
868        assert_eq!(matrix[0].len(), 2);
869    }
870
871    #[test]
872    fn test_online_kernel_matrix_dimension_mismatch() {
873        let kernel = LinearKernel::new();
874        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
875
876        online.add_sample(vec![1.0, 2.0]).unwrap();
877        let result = online.add_sample(vec![1.0, 2.0, 3.0]);
878        assert!(result.is_err());
879    }
880
881    #[test]
882    fn test_online_kernel_matrix_compute_with_all() {
883        let kernel = LinearKernel::new();
884        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
885
886        online.add_sample(vec![1.0, 0.0]).unwrap();
887        online.add_sample(vec![0.0, 1.0]).unwrap();
888
889        let query = vec![1.0, 1.0];
890        let result = online.compute_with_all(&query).unwrap();
891
892        // [1,1]·[1,0] = 1
893        assert!((result[0] - 1.0).abs() < 1e-10);
894        // [1,1]·[0,1] = 1
895        assert!((result[1] - 1.0).abs() < 1e-10);
896    }
897
898    #[test]
899    fn test_online_kernel_matrix_stats() {
900        let kernel = LinearKernel::new();
901        let mut online = OnlineKernelMatrix::new(Box::new(kernel));
902
903        online.add_sample(vec![1.0]).unwrap();
904        online.add_sample(vec![2.0]).unwrap();
905        online.add_sample(vec![3.0]).unwrap();
906
907        let stats = online.stats();
908        assert_eq!(stats.samples_added, 3);
909        // 1 + 2 + 3 = 6 kernel computations (including self)
910        assert_eq!(stats.kernel_computations, 6);
911    }
912
913    // ===== WindowedKernelMatrix Tests =====
914
915    #[test]
916    fn test_windowed_kernel_matrix_basic() {
917        let kernel = LinearKernel::new();
918        let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
919
920        assert_eq!(windowed.window_size(), 3);
921        assert!(!windowed.is_full());
922
923        windowed.add_sample(vec![1.0]).unwrap();
924        windowed.add_sample(vec![2.0]).unwrap();
925        windowed.add_sample(vec![3.0]).unwrap();
926
927        assert!(windowed.is_full());
928        assert_eq!(windowed.len(), 3);
929    }
930
931    #[test]
932    fn test_windowed_kernel_matrix_eviction() {
933        let kernel = LinearKernel::new();
934        let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 2);
935
936        windowed.add_sample(vec![1.0]).unwrap();
937        windowed.add_sample(vec![2.0]).unwrap();
938
939        // This should evict [1.0]
940        let evicted = windowed.add_sample(vec![3.0]).unwrap();
941        assert_eq!(evicted, Some(vec![1.0]));
942        assert_eq!(windowed.len(), 2);
943
944        // Check samples
945        let samples: Vec<_> = windowed.get_samples().iter().cloned().collect();
946        assert_eq!(samples, vec![vec![2.0], vec![3.0]]);
947    }
948
949    #[test]
950    fn test_windowed_kernel_matrix_values() {
951        let kernel = LinearKernel::new();
952        let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 2);
953
954        windowed.add_sample(vec![1.0, 0.0]).unwrap();
955        windowed.add_sample(vec![0.0, 1.0]).unwrap();
956
957        let matrix = windowed.get_matrix();
958
959        assert!((matrix[0][0] - 1.0).abs() < 1e-10);
960        assert!((matrix[1][1] - 1.0).abs() < 1e-10);
961        assert!((matrix[0][1]).abs() < 1e-10);
962
963        // Evict first and add new
964        windowed.add_sample(vec![1.0, 1.0]).unwrap();
965
966        let matrix = windowed.get_matrix();
967        // Now have [0,1] and [1,1]
968        // K[0,0] = 1, K[1,1] = 2, K[0,1] = 1
969        assert!((matrix[0][0] - 1.0).abs() < 1e-10);
970        assert!((matrix[1][1] - 2.0).abs() < 1e-10);
971        assert!((matrix[0][1] - 1.0).abs() < 1e-10);
972    }
973
974    #[test]
975    fn test_windowed_kernel_matrix_dimension_mismatch() {
976        let kernel = LinearKernel::new();
977        let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
978
979        windowed.add_sample(vec![1.0, 2.0]).unwrap();
980        let result = windowed.add_sample(vec![1.0]);
981        assert!(result.is_err());
982    }
983
984    // ===== ForgetfulKernelMatrix Tests =====
985
986    #[test]
987    fn test_forgetful_kernel_matrix_basic() {
988        let kernel = LinearKernel::new();
989        let config = ForgetfulConfig::with_lambda(0.9).unwrap();
990        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
991
992        forgetful.add_sample(vec![1.0]).unwrap();
993        forgetful.add_sample(vec![2.0]).unwrap();
994
995        assert_eq!(forgetful.len(), 2);
996        assert!((forgetful.lambda() - 0.9).abs() < 1e-10);
997    }
998
999    #[test]
1000    fn test_forgetful_kernel_matrix_weights() {
1001        let kernel = LinearKernel::new();
1002        let config = ForgetfulConfig {
1003            lambda: 0.8,
1004            removal_threshold: None,
1005            max_samples: None,
1006        };
1007        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1008
1009        forgetful.add_sample(vec![1.0]).unwrap();
1010        forgetful.add_sample(vec![2.0]).unwrap();
1011        forgetful.add_sample(vec![3.0]).unwrap();
1012
1013        let weights = forgetful.get_weights();
1014        // Newest has weight 1.0
1015        assert!((weights[2] - 1.0).abs() < 1e-10);
1016        // Middle has weight 0.8
1017        assert!((weights[1] - 0.8).abs() < 1e-10);
1018        // Oldest has weight 0.8^2 = 0.64
1019        assert!((weights[0] - 0.64).abs() < 1e-10);
1020    }
1021
1022    #[test]
1023    fn test_forgetful_kernel_matrix_weighted_matrix() {
1024        let kernel = LinearKernel::new();
1025        let config = ForgetfulConfig {
1026            lambda: 0.5,
1027            removal_threshold: None,
1028            max_samples: None,
1029        };
1030        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1031
1032        forgetful.add_sample(vec![1.0]).unwrap();
1033        forgetful.add_sample(vec![1.0]).unwrap();
1034
1035        let weighted = forgetful.get_weighted_matrix();
1036
1037        // w[0] = 0.5, w[1] = 1.0
1038        // K[0,0] = 1 * sqrt(0.5) * sqrt(0.5) = 0.5
1039        // K[1,1] = 1 * sqrt(1.0) * sqrt(1.0) = 1.0
1040        // K[0,1] = 1 * sqrt(0.5) * sqrt(1.0) = sqrt(0.5)
1041        assert!((weighted[0][0] - 0.5).abs() < 1e-10);
1042        assert!((weighted[1][1] - 1.0).abs() < 1e-10);
1043        assert!((weighted[0][1] - 0.5_f64.sqrt()).abs() < 1e-10);
1044    }
1045
1046    #[test]
1047    fn test_forgetful_kernel_matrix_removal_threshold() {
1048        let kernel = LinearKernel::new();
1049        let config = ForgetfulConfig {
1050            lambda: 0.5,
1051            removal_threshold: Some(0.3),
1052            max_samples: None,
1053        };
1054        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1055
1056        forgetful.add_sample(vec![1.0]).unwrap();
1057        forgetful.add_sample(vec![2.0]).unwrap();
1058        // First sample now has weight 0.5
1059
1060        forgetful.add_sample(vec![3.0]).unwrap();
1061        // First sample would have weight 0.25 < 0.3, should be removed
1062
1063        assert_eq!(forgetful.len(), 2);
1064    }
1065
1066    #[test]
1067    fn test_forgetful_kernel_matrix_max_samples() {
1068        let kernel = LinearKernel::new();
1069        let config = ForgetfulConfig {
1070            lambda: 1.0, // No decay
1071            removal_threshold: None,
1072            max_samples: Some(2),
1073        };
1074        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1075
1076        forgetful.add_sample(vec![1.0]).unwrap();
1077        forgetful.add_sample(vec![2.0]).unwrap();
1078        forgetful.add_sample(vec![3.0]).unwrap();
1079
1080        assert_eq!(forgetful.len(), 2);
1081    }
1082
1083    #[test]
1084    fn test_forgetful_kernel_matrix_effective_size() {
1085        let kernel = LinearKernel::new();
1086        let config = ForgetfulConfig {
1087            lambda: 0.9,
1088            removal_threshold: None,
1089            max_samples: None,
1090        };
1091        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1092
1093        forgetful.add_sample(vec![1.0]).unwrap();
1094        forgetful.add_sample(vec![2.0]).unwrap();
1095        forgetful.add_sample(vec![3.0]).unwrap();
1096
1097        // Weights: 0.81, 0.9, 1.0
1098        let eff_size = forgetful.effective_size();
1099        assert!((eff_size - 2.71).abs() < 1e-10);
1100    }
1101
1102    #[test]
1103    fn test_forgetful_kernel_matrix_invalid_lambda() {
1104        let result = ForgetfulConfig::with_lambda(0.0);
1105        assert!(result.is_err());
1106
1107        let result = ForgetfulConfig::with_lambda(1.5);
1108        assert!(result.is_err());
1109    }
1110
1111    #[test]
1112    fn test_forgetful_kernel_matrix_dimension_mismatch() {
1113        let kernel = LinearKernel::new();
1114        let config = ForgetfulConfig::with_lambda(0.9).unwrap();
1115        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1116
1117        forgetful.add_sample(vec![1.0, 2.0]).unwrap();
1118        let result = forgetful.add_sample(vec![1.0]);
1119        assert!(result.is_err());
1120    }
1121
1122    // ===== AdaptiveKernelMatrix Tests =====
1123
1124    #[test]
1125    fn test_adaptive_kernel_matrix_basic() {
1126        let mut adaptive = AdaptiveKernelMatrix::new(
1127            |gamma| Box::new(RbfKernel::new(RbfKernelConfig::new(gamma)).unwrap()),
1128            1.0,
1129            0.1,
1130        );
1131
1132        adaptive.add_sample(vec![1.0, 2.0]).unwrap();
1133        adaptive.add_sample(vec![3.0, 4.0]).unwrap();
1134        adaptive.add_sample(vec![5.0, 6.0]).unwrap();
1135
1136        assert_eq!(adaptive.len(), 3);
1137        assert!(adaptive.bandwidth() > 0.0);
1138    }
1139
1140    #[test]
1141    fn test_adaptive_kernel_matrix_bandwidth_update() {
1142        let mut adaptive = AdaptiveKernelMatrix::new(
1143            |gamma| Box::new(RbfKernel::new(RbfKernelConfig::new(gamma)).unwrap()),
1144            1.0,
1145            0.5, // High adaptation rate
1146        );
1147
1148        let initial = adaptive.bandwidth();
1149
1150        adaptive.add_sample(vec![0.0]).unwrap();
1151        adaptive.add_sample(vec![10.0]).unwrap();
1152
1153        // Bandwidth should have changed
1154        let after = adaptive.bandwidth();
1155        assert_ne!(initial, after);
1156    }
1157
1158    // ===== Edge case tests =====
1159
1160    #[test]
1161    fn test_online_empty_operations() {
1162        let kernel = LinearKernel::new();
1163        let online = OnlineKernelMatrix::new(Box::new(kernel));
1164
1165        assert!(online.is_empty());
1166        assert!(online.get_matrix().is_empty());
1167        assert!(online.get_samples().is_empty());
1168    }
1169
1170    #[test]
1171    fn test_windowed_clear() {
1172        let kernel = LinearKernel::new();
1173        let mut windowed = WindowedKernelMatrix::new(Box::new(kernel), 3);
1174
1175        windowed.add_sample(vec![1.0]).unwrap();
1176        windowed.add_sample(vec![2.0]).unwrap();
1177        windowed.clear();
1178
1179        assert!(windowed.is_empty());
1180        assert_eq!(windowed.len(), 0);
1181    }
1182
1183    #[test]
1184    fn test_forgetful_clear() {
1185        let kernel = LinearKernel::new();
1186        let config = ForgetfulConfig::with_lambda(0.9).unwrap();
1187        let mut forgetful = ForgetfulKernelMatrix::new(Box::new(kernel), config);
1188
1189        forgetful.add_sample(vec![1.0]).unwrap();
1190        forgetful.add_sample(vec![2.0]).unwrap();
1191        forgetful.clear();
1192
1193        assert!(forgetful.is_empty());
1194        assert_eq!(forgetful.len(), 0);
1195    }
1196}