// streaming_kernel.rs — sklears_kernel_approximation

1use scirs2_core::ndarray::{s, Array1, Array2, Axis};
2use scirs2_core::random::rngs::StdRng;
3use scirs2_core::random::RngExt;
4use scirs2_core::random::{thread_rng, SeedableRng};
5use scirs2_core::StandardNormal;
6use sklears_core::error::{Result, SklearsError};
7use std::collections::VecDeque;
8
// Streaming kernel approximation methods for online processing.
//
// This module provides online learning capabilities for kernel
// approximations, enabling processing of data streams where samples
// arrive continuously and memory is limited.

/// Buffer management strategy for streaming data.
#[derive(Debug, Clone)]
pub enum BufferStrategy {
    /// Fixed-size buffer with FIFO replacement.
    FixedSize(usize),
    /// Sliding window bounded both by sample count and by elapsed time.
    SlidingWindow { size: usize, time_window: f64 },
    /// Reservoir sampling keeping a uniformly representative subset.
    ReservoirSampling(usize),
    /// Exponential decay of sample weights; samples below `min_weight`
    /// are dropped.
    ExponentialDecay { alpha: f64, min_weight: f64 },
    /// Importance-weighted sampling: a new sample replaces the least
    /// important buffered one only if it is more important by `threshold`.
    ImportanceWeighted { capacity: usize, threshold: f64 },
}
29
/// Update frequency for model parameters.
#[derive(Debug, Clone)]
pub enum UpdateFrequency {
    /// Update after every sample.
    PerSample,
    /// Update after every N samples.
    BatchSize(usize),
    /// Update based on time intervals.
    /// NOTE(review): `should_update` currently approximates this with a
    /// fixed sample-count proxy rather than wall-clock time.
    TimeInterval(f64),
    /// Update when error exceeds threshold.
    /// NOTE(review): currently implemented as a fixed periodic update.
    ErrorThreshold(f64),
    /// Adaptive update frequency.
    /// NOTE(review): only `initial` is consulted by `should_update` today.
    Adaptive {
        initial: usize,
        max: usize,
        min: usize,
    },
}
48
/// Forgetting mechanism for down-weighting or discarding old data.
#[derive(Debug, Clone)]
pub enum ForgettingMechanism {
    /// No forgetting - keep all data.
    None,
    /// Linear decay: each update multiplies weights by `1 - rate`.
    LinearDecay(f64),
    /// Exponential decay: each update multiplies weights by this factor.
    ExponentialDecay(f64),
    /// Abrupt forgetting: drop samples older than the time window.
    AbruptForgetting(f64),
    /// Gradual forgetting with a sigmoid function of sample age.
    SigmoidDecay { steepness: f64, midpoint: f64 },
}
63
/// Configuration for streaming kernel approximation.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// How buffered samples are retained and evicted.
    pub buffer_strategy: BufferStrategy,
    /// How often model parameters are refreshed from the buffer.
    pub update_frequency: UpdateFrequency,
    /// How the influence of old samples is reduced over time.
    pub forgetting_mechanism: ForgettingMechanism,
    /// Optional memory budget in megabytes.
    /// NOTE(review): not consulted anywhere in this file — confirm intent.
    pub max_memory_mb: Option<usize>,
    /// Whether the number of components may adapt over time.
    /// NOTE(review): not consulted anywhere in this file.
    pub adaptive_components: bool,
    /// Whether feature-quality statistics should be tracked.
    /// NOTE(review): statistics are always updated regardless of this flag.
    pub quality_monitoring: bool,
    /// Whether concept-drift detection is enabled (checked in
    /// `StreamingRBFSampler::detect_drift`).
    pub drift_detection: bool,
    /// Threshold used when flagging concept drift.
    /// NOTE(review): `FeatureStatistics::detect_drift` currently uses a
    /// hardcoded 0.1 instead of this value — confirm which is intended.
    pub concept_drift_threshold: f64,
}
84
impl Default for StreamingConfig {
    /// Conservative defaults: 1000-sample FIFO buffer, refresh every 100
    /// samples, exponential forgetting (factor 0.99), drift detection off.
    fn default() -> Self {
        Self {
            buffer_strategy: BufferStrategy::FixedSize(1000),
            update_frequency: UpdateFrequency::BatchSize(100),
            forgetting_mechanism: ForgettingMechanism::ExponentialDecay(0.99),
            max_memory_mb: Some(100),
            adaptive_components: true,
            quality_monitoring: true,
            drift_detection: false,
            concept_drift_threshold: 0.1,
        }
    }
}
99
/// Sample with metadata for streaming processing.
#[derive(Debug, Clone)]
pub struct StreamingSample {
    /// Feature vector of the sample.
    pub data: Array1<f64>,
    /// Arrival time (callers in this file use the sample index as a proxy).
    pub timestamp: f64,
    /// Current weight after forgetting/decay (1.0 initially).
    pub weight: f64,
    /// Importance score used by importance-weighted buffering (1.0 initially).
    pub importance: f64,
    /// Optional supervised label.
    pub label: Option<f64>,
}
114
115impl StreamingSample {
116    pub fn new(data: Array1<f64>, timestamp: f64) -> Self {
117        Self {
118            data,
119            timestamp,
120            weight: 1.0,
121            importance: 1.0,
122            label: None,
123        }
124    }
125
126    pub fn with_weight(mut self, weight: f64) -> Self {
127        self.weight = weight;
128        self
129    }
130
131    pub fn with_importance(mut self, importance: f64) -> Self {
132        self.importance = importance;
133        self
134    }
135
136    pub fn with_label(mut self, label: f64) -> Self {
137        self.label = Some(label);
138        self
139    }
140}
141
/// Streaming RBF kernel approximation using Random Fourier Features
///
/// Maintains an online approximation of RBF kernel features that
/// adapts to data streams with concept drift and memory constraints.
pub struct StreamingRBFSampler {
    // Number of random Fourier features produced per sample.
    n_components: usize,
    // RBF kernel bandwidth parameter.
    gamma: f64,
    // Streaming behavior configuration.
    config: StreamingConfig,
    // Random projection matrix (n_components x n_features); None until fitted.
    weights: Option<Array2<f64>>,
    // Random phase offsets (length n_components); None until fitted.
    bias: Option<Array1<f64>>,
    // Retained samples, managed per `config.buffer_strategy`.
    buffer: VecDeque<StreamingSample>,
    // Total number of samples seen so far.
    sample_count: usize,
    // Value of `sample_count` at the last model update.
    last_update: usize,
    // Running statistics of the produced features.
    feature_statistics: FeatureStatistics,
    // Seed used for reproducibility, if any.
    random_state: Option<u64>,
    // RNG driving weight/bias generation and reservoir sampling.
    rng: StdRng,
}
159
/// Statistics for monitoring feature quality.
#[derive(Debug, Clone)]
pub struct FeatureStatistics {
    /// Running per-component mean.
    pub mean: Array1<f64>,
    /// Running per-component variance.
    pub variance: Array1<f64>,
    /// Per-component minimum seen so far.
    pub min: Array1<f64>,
    /// Per-component maximum seen so far.
    pub max: Array1<f64>,
    /// Total number of samples folded into the statistics.
    pub update_count: usize,
    /// Approximation-error estimate.
    /// NOTE(review): never written after construction in this file.
    pub approximation_error: f64,
    /// Most recent drift score (mean absolute per-component mean shift).
    pub drift_score: f64,
}
178
179impl FeatureStatistics {
180    pub fn new(n_components: usize) -> Self {
181        Self {
182            mean: Array1::zeros(n_components),
183            variance: Array1::zeros(n_components),
184            min: Array1::from_elem(n_components, f64::INFINITY),
185            max: Array1::from_elem(n_components, f64::NEG_INFINITY),
186            update_count: 0,
187            approximation_error: 0.0,
188            drift_score: 0.0,
189        }
190    }
191
192    pub fn update(&mut self, features: &Array2<f64>) {
193        let n_samples = features.nrows();
194
195        for i in 0..features.ncols() {
196            let col = features.column(i);
197            let new_mean = col.mean().unwrap_or(0.0);
198            let new_var = col.mapv(|x| (x - new_mean).powi(2)).mean().unwrap_or(0.0);
199            let new_min = col.iter().fold(f64::INFINITY, |a, &b| a.min(b));
200            let new_max = col.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
201
202            // Online update of statistics
203            let old_count = self.update_count;
204            let new_count = old_count + n_samples;
205
206            if old_count == 0 {
207                self.mean[i] = new_mean;
208                self.variance[i] = new_var;
209            } else {
210                let alpha = n_samples as f64 / new_count as f64;
211                self.mean[i] = (1.0 - alpha) * self.mean[i] + alpha * new_mean;
212                self.variance[i] = (1.0 - alpha) * self.variance[i] + alpha * new_var;
213            }
214
215            self.min[i] = self.min[i].min(new_min);
216            self.max[i] = self.max[i].max(new_max);
217        }
218
219        self.update_count += n_samples;
220    }
221
222    pub fn detect_drift(&mut self, new_features: &Array2<f64>) -> bool {
223        let old_mean = self.mean.clone();
224        self.update(new_features);
225
226        // Simple drift detection based on mean shift
227        let mean_shift = (&self.mean - &old_mean).mapv(f64::abs).sum();
228        self.drift_score = mean_shift / self.mean.len() as f64;
229
230        self.drift_score > 0.1 // Simple threshold
231    }
232}
233
234impl StreamingRBFSampler {
235    /// Create a new streaming RBF sampler
236    pub fn new(n_components: usize, gamma: f64) -> Self {
237        let rng = StdRng::from_seed(thread_rng().random());
238        Self {
239            n_components,
240            gamma,
241            config: StreamingConfig::default(),
242            weights: None,
243            bias: None,
244            buffer: VecDeque::new(),
245            sample_count: 0,
246            last_update: 0,
247            feature_statistics: FeatureStatistics::new(n_components),
248            random_state: None,
249            rng,
250        }
251    }
252
253    /// Set the streaming configuration
254    pub fn with_config(mut self, config: StreamingConfig) -> Self {
255        self.config = config;
256        self
257    }
258
259    /// Set random state for reproducibility
260    pub fn with_random_state(mut self, random_state: u64) -> Self {
261        self.random_state = Some(random_state);
262        self.rng = StdRng::seed_from_u64(random_state);
263        self
264    }
265
266    /// Initialize the streaming sampler with initial data
267    pub fn fit(&mut self, x: &Array2<f64>) -> Result<()> {
268        let (_, n_features) = x.dim();
269
270        // Initialize random weights and bias
271        self.weights = Some(self.generate_weights(n_features)?);
272        self.bias = Some(self.generate_bias()?);
273
274        // Process initial batch
275        for (i, row) in x.rows().into_iter().enumerate() {
276            let sample = StreamingSample::new(row.to_owned(), i as f64);
277            self.add_sample(sample)?;
278        }
279
280        Ok(())
281    }
282
283    /// Add a new sample to the stream
284    pub fn add_sample(&mut self, sample: StreamingSample) -> Result<()> {
285        // Check if initialization is needed
286        if self.weights.is_none() {
287            let n_features = sample.data.len();
288            self.weights = Some(self.generate_weights(n_features)?);
289            self.bias = Some(self.generate_bias()?);
290        }
291
292        // Add to buffer based on strategy
293        self.manage_buffer(sample)?;
294
295        self.sample_count += 1;
296
297        // Check if update is needed
298        if self.should_update()? {
299            self.update_model()?;
300            self.last_update = self.sample_count;
301        }
302
303        Ok(())
304    }
305
306    /// Transform data using current model
307    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
308        let weights = self
309            .weights
310            .as_ref()
311            .ok_or_else(|| SklearsError::NotFitted {
312                operation: "transform".to_string(),
313            })?;
314        let bias = self.bias.as_ref().ok_or_else(|| SklearsError::NotFitted {
315            operation: "transform".to_string(),
316        })?;
317
318        self.compute_features(x, weights, bias)
319    }
320
321    /// Transform a single sample
322    pub fn transform_sample(&self, sample: &Array1<f64>) -> Result<Array1<f64>> {
323        let weights = self
324            .weights
325            .as_ref()
326            .ok_or_else(|| SklearsError::NotFitted {
327                operation: "transform_sample".to_string(),
328            })?;
329        let bias = self.bias.as_ref().ok_or_else(|| SklearsError::NotFitted {
330            operation: "transform_sample".to_string(),
331        })?;
332
333        // Compute features for single sample
334        let projection = sample.dot(&weights.t()) + bias;
335        let norm_factor = (2.0 / self.n_components as f64).sqrt();
336
337        Ok(projection.mapv(|x| norm_factor * x.cos()))
338    }
339
340    /// Get current buffer statistics
341    pub fn buffer_stats(&self) -> (usize, f64, f64) {
342        let size = self.buffer.len();
343        let avg_weight = if size > 0 {
344            self.buffer.iter().map(|s| s.weight).sum::<f64>() / size as f64
345        } else {
346            0.0
347        };
348        let avg_importance = if size > 0 {
349            self.buffer.iter().map(|s| s.importance).sum::<f64>() / size as f64
350        } else {
351            0.0
352        };
353
354        (size, avg_weight, avg_importance)
355    }
356
    /// Borrow the running feature statistics.
    pub fn feature_stats(&self) -> &FeatureStatistics {
        &self.feature_statistics
    }
361
    /// Check the batch `x` for concept drift.
    ///
    /// Returns `Ok(false)` immediately when drift detection is disabled in
    /// the config; otherwise transforms `x` and delegates to the feature
    /// statistics (note: this also folds the new features into the running
    /// statistics as a side effect).
    pub fn detect_drift(&mut self, x: &Array2<f64>) -> Result<bool> {
        if !self.config.drift_detection {
            return Ok(false);
        }

        let features = self.transform(x)?;
        Ok(self.feature_statistics.detect_drift(&features))
    }
371
    /// Insert `sample` into the buffer according to the configured strategy.
    ///
    /// Each `BufferStrategy` arm owns its eviction policy; the incoming
    /// sample is always considered but may be discarded (reservoir /
    /// importance-weighted strategies).
    fn manage_buffer(&mut self, sample: StreamingSample) -> Result<()> {
        match &self.config.buffer_strategy {
            BufferStrategy::FixedSize(max_size) => {
                // FIFO eviction once the capacity is reached.
                if self.buffer.len() >= *max_size {
                    self.buffer.pop_front();
                }
                self.buffer.push_back(sample);
            }
            BufferStrategy::SlidingWindow { size, time_window } => {
                // Remove old samples based on time window
                let current_time = sample.timestamp;
                while let Some(front) = self.buffer.front() {
                    if current_time - front.timestamp > *time_window {
                        self.buffer.pop_front();
                    } else {
                        break;
                    }
                }

                // Add new sample and maintain size limit
                if self.buffer.len() >= *size {
                    self.buffer.pop_front();
                }
                self.buffer.push_back(sample);
            }
            BufferStrategy::ReservoirSampling(capacity) => {
                if self.buffer.len() < *capacity {
                    self.buffer.push_back(sample);
                } else {
                    // Algorithm-R style replacement: `sample_count` has not
                    // yet been incremented for this sample (see `add_sample`),
                    // so `0..=sample_count` spans the incoming sample's
                    // position in the stream.
                    let replace_idx = self.rng.random_range(0..self.sample_count + 1);
                    if replace_idx < *capacity {
                        self.buffer[replace_idx] = sample;
                    }
                }
            }
            BufferStrategy::ExponentialDecay { alpha, min_weight } => {
                // Decay weights of existing samples
                for existing_sample in &mut self.buffer {
                    existing_sample.weight *= alpha;
                }

                // Remove samples below minimum weight
                self.buffer.retain(|s| s.weight >= *min_weight);

                self.buffer.push_back(sample);
            }
            BufferStrategy::ImportanceWeighted {
                capacity,
                threshold,
            } => {
                if self.buffer.len() < *capacity {
                    self.buffer.push_back(sample);
                } else {
                    // Replace the least-important buffered sample, but only
                    // if the newcomer beats it by at least `threshold`.
                    if let Some((min_idx, _)) =
                        self.buffer.iter().enumerate().min_by(|(_, a), (_, b)| {
                            a.importance
                                .partial_cmp(&b.importance)
                                .expect("operation should succeed")
                        })
                    {
                        if sample.importance > self.buffer[min_idx].importance + threshold {
                            self.buffer[min_idx] = sample;
                        }
                    }
                }
            }
        }

        Ok(())
    }
444
    /// Decide whether the model should be refreshed from the buffer.
    ///
    /// NOTE(review): the `TimeInterval` and `ErrorThreshold` arms are
    /// placeholders that fall back to fixed sample-count intervals, and
    /// `Adaptive` consults only its `initial` value.
    fn should_update(&self) -> Result<bool> {
        match &self.config.update_frequency {
            UpdateFrequency::PerSample => Ok(true),
            UpdateFrequency::BatchSize(batch_size) => {
                Ok(self.sample_count - self.last_update >= *batch_size)
            }
            UpdateFrequency::TimeInterval(_time_interval) => {
                // For simplicity, use sample count as proxy for time
                Ok(self.sample_count - self.last_update >= 100)
            }
            UpdateFrequency::ErrorThreshold(_threshold) => {
                // For simplicity, update periodically
                Ok(self.sample_count - self.last_update >= 50)
            }
            UpdateFrequency::Adaptive {
                initial,
                max: _,
                min: _,
            } => Ok(self.sample_count - self.last_update >= *initial),
        }
    }
467
    /// Refresh derived state (feature statistics, forgetting) from the buffer.
    ///
    /// The random projection itself is kept fixed; only the running feature
    /// statistics are recomputed and the forgetting weights applied.
    fn update_model(&mut self) -> Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        // Stack buffered samples into a (n_buffered, n_features) matrix.
        let mut data_matrix = Array2::zeros((self.buffer.len(), self.buffer[0].data.len()));
        for (i, sample) in self.buffer.iter().enumerate() {
            data_matrix.row_mut(i).assign(&sample.data);
        }

        // Compute features and update statistics.
        // `expect` is acceptable here: callers reach this only after the
        // projection has been initialized (see `add_sample`).
        let weights = self.weights.as_ref().expect("operation should succeed");
        let bias = self.bias.as_ref().expect("operation should succeed");
        let features = self.compute_features(&data_matrix, weights, bias)?;

        self.feature_statistics.update(&features);

        // Apply forgetting mechanism
        self.apply_forgetting()?;

        Ok(())
    }
492
    /// Apply forgetting mechanism to reduce influence of old data.
    ///
    /// Reweights (or evicts) buffered samples in place; invoked from
    /// `update_model` after the statistics are refreshed.
    fn apply_forgetting(&mut self) -> Result<()> {
        match &self.config.forgetting_mechanism {
            ForgettingMechanism::None => {
                // No forgetting
            }
            ForgettingMechanism::LinearDecay(decay_rate) => {
                // Multiply every weight by (1 - rate) on each update.
                for sample in &mut self.buffer {
                    sample.weight *= 1.0 - decay_rate;
                }
            }
            ForgettingMechanism::ExponentialDecay(decay_rate) => {
                for sample in &mut self.buffer {
                    sample.weight *= decay_rate;
                }
            }
            ForgettingMechanism::AbruptForgetting(time_threshold) => {
                // Drop everything older than the threshold relative to the
                // newest buffered sample.
                if let Some(newest) = self.buffer.back() {
                    let cutoff_time = newest.timestamp - time_threshold;
                    self.buffer.retain(|s| s.timestamp >= cutoff_time);
                }
            }
            ForgettingMechanism::SigmoidDecay {
                steepness,
                midpoint,
            } => {
                // Smoothly down-weight by age: the factor approaches 0 for
                // samples much older than `midpoint`.
                if let Some(newest_timestamp) = self.buffer.back().map(|s| s.timestamp) {
                    for sample in &mut self.buffer {
                        let age = newest_timestamp - sample.timestamp;
                        let sigmoid_weight = 1.0 / (1.0 + (steepness * (age - midpoint)).exp());
                        sample.weight *= sigmoid_weight;
                    }
                }
            }
        }

        Ok(())
    }
531
532    /// Generate random weights for RBF features
533    fn generate_weights(&mut self, n_features: usize) -> Result<Array2<f64>> {
534        let mut weights = Array2::zeros((self.n_components, n_features));
535
536        for i in 0..self.n_components {
537            for j in 0..n_features {
538                weights[[i, j]] =
539                    self.rng.sample::<f64, _>(StandardNormal) * (2.0 * self.gamma).sqrt();
540            }
541        }
542
543        Ok(weights)
544    }
545
546    /// Generate random bias for RBF features
547    fn generate_bias(&mut self) -> Result<Array1<f64>> {
548        let mut bias = Array1::zeros(self.n_components);
549
550        for i in 0..self.n_components {
551            bias[i] = self.rng.random_range(0.0..2.0 * std::f64::consts::PI);
552        }
553
554        Ok(bias)
555    }
556
557    /// Compute RBF features for given data
558    fn compute_features(
559        &self,
560        x: &Array2<f64>,
561        weights: &Array2<f64>,
562        bias: &Array1<f64>,
563    ) -> Result<Array2<f64>> {
564        let (n_samples, _) = x.dim();
565        let n_components = weights.nrows();
566
567        // Compute X @ W^T + b
568        let projection = x.dot(&weights.t()) + bias;
569
570        // Apply cosine transformation with normalization
571        let mut features = Array2::zeros((n_samples, n_components));
572        let norm_factor = (2.0 / n_components as f64).sqrt();
573
574        for i in 0..n_samples {
575            for j in 0..n_components {
576                features[[i, j]] = norm_factor * projection[[i, j]].cos();
577            }
578        }
579
580        Ok(features)
581    }
582}
583
584/// Streaming Nyström method for kernel approximation
585///
586/// Maintains an online Nyström approximation that adapts to
587/// streaming data with efficient inducing point management.
588pub struct StreamingNystroem {
589    n_components: usize,
590    gamma: f64,
591    config: StreamingConfig,
592    inducing_points: Option<Array2<f64>>,
593    eigenvalues: Option<Array1<f64>>,
594    eigenvectors: Option<Array2<f64>>,
595    buffer: VecDeque<StreamingSample>,
596    sample_count: usize,
597    last_update: usize,
598    random_state: Option<u64>,
599    rng: StdRng,
600}
601
602impl StreamingNystroem {
603    /// Create a new streaming Nyström approximation
604    pub fn new(n_components: usize, gamma: f64) -> Self {
605        let rng = StdRng::from_seed(thread_rng().random());
606        Self {
607            n_components,
608            gamma,
609            config: StreamingConfig::default(),
610            inducing_points: None,
611            eigenvalues: None,
612            eigenvectors: None,
613            buffer: VecDeque::new(),
614            sample_count: 0,
615            last_update: 0,
616            random_state: None,
617            rng,
618        }
619    }
620
621    /// Set the streaming configuration
622    pub fn with_config(mut self, config: StreamingConfig) -> Self {
623        self.config = config;
624        self
625    }
626
627    /// Set random state for reproducibility
628    pub fn with_random_state(mut self, random_state: u64) -> Self {
629        self.random_state = Some(random_state);
630        self.rng = StdRng::seed_from_u64(random_state);
631        self
632    }
633
    /// Initialize with initial data.
    ///
    /// Selects inducing points, builds their kernel matrix, stores the
    /// (simplified) eigendecomposition, and seeds the buffer with every row
    /// of `x`, using the row index as the timestamp.
    pub fn fit(&mut self, x: &Array2<f64>) -> Result<()> {
        // Select initial inducing points
        let inducing_indices = self.select_inducing_points(x)?;
        let inducing_points = x.select(Axis(0), &inducing_indices);

        // Compute initial eigendecomposition
        let kernel_matrix = self.compute_kernel_matrix(&inducing_points)?;
        let (eigenvalues, eigenvectors) = self.eigendecomposition(&kernel_matrix)?;

        self.inducing_points = Some(inducing_points);
        self.eigenvalues = Some(eigenvalues);
        self.eigenvectors = Some(eigenvectors);

        // Add samples to buffer.
        // NOTE(review): this bypasses the buffer strategy entirely, so the
        // initial batch is stored unbounded — confirm intended.
        for (i, row) in x.rows().into_iter().enumerate() {
            let sample = StreamingSample::new(row.to_owned(), i as f64);
            self.buffer.push_back(sample);
        }

        self.sample_count = x.nrows();

        Ok(())
    }
658
    /// Add a new sample to the stream.
    ///
    /// Only `BufferStrategy::FixedSize` is enforced here; all other
    /// strategies currently leave the buffer unbounded.
    pub fn add_sample(&mut self, sample: StreamingSample) -> Result<()> {
        self.buffer.push_back(sample);
        self.sample_count += 1;

        // Manage buffer size
        match &self.config.buffer_strategy {
            BufferStrategy::FixedSize(max_size) => {
                if self.buffer.len() > *max_size {
                    self.buffer.pop_front();
                }
            }
            _ => {
                // Implement other buffer strategies as needed
            }
        }

        // Check if update is needed
        if self.should_update()? {
            self.update_model()?;
            self.last_update = self.sample_count;
        }

        Ok(())
    }
684
685    /// Transform data using current approximation
686    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
687        let inducing_points =
688            self.inducing_points
689                .as_ref()
690                .ok_or_else(|| SklearsError::NotFitted {
691                    operation: "transform".to_string(),
692                })?;
693        let eigenvalues = self
694            .eigenvalues
695            .as_ref()
696            .ok_or_else(|| SklearsError::NotFitted {
697                operation: "transform".to_string(),
698            })?;
699        let eigenvectors = self
700            .eigenvectors
701            .as_ref()
702            .ok_or_else(|| SklearsError::NotFitted {
703                operation: "transform".to_string(),
704            })?;
705
706        // Compute kernel between x and inducing points
707        let kernel_x_inducing = self.compute_kernel(x, inducing_points)?;
708
709        // Apply Nyström transformation
710        let mut features = kernel_x_inducing.dot(eigenvectors);
711
712        // Scale by eigenvalues
713        for i in 0..eigenvalues.len() {
714            if eigenvalues[i] > 1e-12 {
715                let scale = 1.0 / eigenvalues[i].sqrt();
716                for j in 0..features.nrows() {
717                    features[[j, i]] *= scale;
718                }
719            }
720        }
721
722        Ok(features)
723    }
724
725    /// Check if model should be updated
726    fn should_update(&self) -> Result<bool> {
727        // Simple heuristic: update every 100 samples
728        Ok(self.sample_count - self.last_update >= 100)
729    }
730
    /// Update inducing points and eigendecomposition.
    ///
    /// Rebuilds the model from scratch over the current buffer contents:
    /// reselects inducing points, then recomputes their kernel matrix and
    /// its (simplified) eigendecomposition.
    fn update_model(&mut self) -> Result<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        // Extract current data from buffer
        let n_samples = self.buffer.len();
        let n_features = self.buffer[0].data.len();
        let mut data_matrix = Array2::zeros((n_samples, n_features));

        for (i, sample) in self.buffer.iter().enumerate() {
            data_matrix.row_mut(i).assign(&sample.data);
        }

        // Reselect inducing points
        let inducing_indices = self.select_inducing_points(&data_matrix)?;
        let inducing_points = data_matrix.select(Axis(0), &inducing_indices);

        // Recompute eigendecomposition
        let kernel_matrix = self.compute_kernel_matrix(&inducing_points)?;
        let (eigenvalues, eigenvectors) = self.eigendecomposition(&kernel_matrix)?;

        self.inducing_points = Some(inducing_points);
        self.eigenvalues = Some(eigenvalues);
        self.eigenvectors = Some(eigenvectors);

        Ok(())
    }
760
761    /// Select inducing points from current data
762    fn select_inducing_points(&mut self, x: &Array2<f64>) -> Result<Vec<usize>> {
763        let n_samples = x.nrows();
764        let n_inducing = self.n_components.min(n_samples);
765
766        let mut indices = Vec::new();
767        for _ in 0..n_inducing {
768            indices.push(self.rng.random_range(0..n_samples));
769        }
770
771        Ok(indices)
772    }
773
774    /// Compute kernel matrix
775    fn compute_kernel_matrix(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
776        let n_samples = x.nrows();
777        let mut kernel_matrix = Array2::zeros((n_samples, n_samples));
778
779        for i in 0..n_samples {
780            for j in i..n_samples {
781                let diff = &x.row(i) - &x.row(j);
782                let squared_dist = diff.mapv(|x| x * x).sum();
783                let kernel_val = (-self.gamma * squared_dist).exp();
784                kernel_matrix[[i, j]] = kernel_val;
785                kernel_matrix[[j, i]] = kernel_val;
786            }
787        }
788
789        Ok(kernel_matrix)
790    }
791
792    /// Compute kernel between two matrices
793    fn compute_kernel(&self, x: &Array2<f64>, y: &Array2<f64>) -> Result<Array2<f64>> {
794        let (n_samples_x, _) = x.dim();
795        let (n_samples_y, _) = y.dim();
796        let mut kernel_matrix = Array2::zeros((n_samples_x, n_samples_y));
797
798        for i in 0..n_samples_x {
799            for j in 0..n_samples_y {
800                let diff = &x.row(i) - &y.row(j);
801                let squared_dist = diff.mapv(|x| x * x).sum();
802                let kernel_val = (-self.gamma * squared_dist).exp();
803                kernel_matrix[[i, j]] = kernel_val;
804            }
805        }
806
807        Ok(kernel_matrix)
808    }
809
    /// Perform eigendecomposition (simplified).
    ///
    /// NOTE(review): this is a placeholder — it does NOT diagonalize
    /// `matrix`. It returns unit eigenvalues and the first
    /// `min(n_components, n)` columns of the identity, so `transform`
    /// effectively selects raw kernel columns. Replace with a real
    /// symmetric eigensolver for a faithful Nyström approximation.
    fn eigendecomposition(&self, matrix: &Array2<f64>) -> Result<(Array1<f64>, Array2<f64>)> {
        let n = matrix.nrows();
        let eigenvalues = Array1::ones(self.n_components.min(n));
        let eigenvectors = Array2::eye(n)
            .slice(s![.., ..self.n_components.min(n)])
            .to_owned();

        Ok((eigenvalues, eigenvectors))
    }
820}
821
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Fit + transform round-trip: output shape must be (n_samples, n_components).
    #[test]
    fn test_streaming_rbf_sampler_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let mut sampler = StreamingRBFSampler::new(50, 0.1).with_random_state(42);

        sampler.fit(&x).expect("operation should succeed");
        let features = sampler.transform(&x).expect("operation should succeed");

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 50);
    }

    // Builder methods should set exactly the requested metadata fields.
    #[test]
    fn test_streaming_sample() {
        let data = array![1.0, 2.0, 3.0];
        let sample = StreamingSample::new(data.clone(), 1.0)
            .with_weight(0.8)
            .with_importance(0.9)
            .with_label(1.0);

        assert_eq!(sample.data, data);
        assert_eq!(sample.timestamp, 1.0);
        assert_eq!(sample.weight, 0.8);
        assert_eq!(sample.importance, 0.9);
        assert_eq!(sample.label, Some(1.0));
    }

    // A FixedSize(3) buffer must never exceed three samples.
    #[test]
    fn test_buffer_strategies() {
        let mut sampler = StreamingRBFSampler::new(10, 0.1).with_config(StreamingConfig {
            buffer_strategy: BufferStrategy::FixedSize(3),
            ..Default::default()
        });

        // Add samples beyond buffer capacity
        for i in 0..5 {
            let data = array![i as f64, (i + 1) as f64];
            let sample = StreamingSample::new(data, i as f64);
            sampler
                .add_sample(sample)
                .expect("operation should succeed");
        }

        let (size, _, _) = sampler.buffer_stats();
        assert_eq!(size, 3); // Buffer should be limited to 3
    }

    // Nyström fit + transform: output width equals n_components.
    #[test]
    fn test_streaming_nystroem_basic() {
        let x = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]];

        let mut nystroem = StreamingNystroem::new(3, 0.1).with_random_state(42);

        nystroem.fit(&x).expect("operation should succeed");
        let features = nystroem.transform(&x).expect("operation should succeed");

        assert_eq!(features.nrows(), 4);
        assert_eq!(features.ncols(), 3);
    }

    // Running means after one batch must equal the per-column batch means.
    #[test]
    fn test_feature_statistics() {
        let mut stats = FeatureStatistics::new(3);
        let features = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];

        stats.update(&features);

        assert_eq!(stats.update_count, 2);
        assert!((stats.mean[0] - 2.5).abs() < 1e-10);
        assert!((stats.mean[1] - 3.5).abs() < 1e-10);
        assert!((stats.mean[2] - 4.5).abs() < 1e-10);
    }

    // Single-sample transform must produce n_components features.
    #[test]
    fn test_transform_sample() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let mut sampler = StreamingRBFSampler::new(10, 0.1).with_random_state(42);

        sampler.fit(&x).expect("operation should succeed");

        let sample = array![5.0, 6.0];
        let features = sampler
            .transform_sample(&sample)
            .expect("operation should succeed");

        assert_eq!(features.len(), 10);
    }

    // Struct-update construction should preserve explicitly set fields.
    #[test]
    fn test_streaming_config() {
        let config = StreamingConfig {
            buffer_strategy: BufferStrategy::SlidingWindow {
                size: 100,
                time_window: 10.0,
            },
            update_frequency: UpdateFrequency::BatchSize(50),
            forgetting_mechanism: ForgettingMechanism::LinearDecay(0.01),
            adaptive_components: true,
            ..Default::default()
        };

        assert!(matches!(
            config.buffer_strategy,
            BufferStrategy::SlidingWindow { .. }
        ));
        assert!(matches!(
            config.update_frequency,
            UpdateFrequency::BatchSize(50)
        ));
        assert!(config.adaptive_components);
    }

    // Incremental add_sample calls with a small batch interval must keep
    // the buffer populated without errors.
    #[test]
    fn test_online_updates() {
        let mut sampler = StreamingRBFSampler::new(20, 0.1)
            .with_config(StreamingConfig {
                update_frequency: UpdateFrequency::BatchSize(2),
                ..Default::default()
            })
            .with_random_state(42);

        // Initialize with small batch
        let x_init = array![[1.0, 2.0], [3.0, 4.0]];
        sampler.fit(&x_init).expect("operation should succeed");

        // Add samples one by one
        for i in 5..10 {
            let data = array![i as f64, (i + 1) as f64];
            let sample = StreamingSample::new(data, i as f64);
            sampler
                .add_sample(sample)
                .expect("operation should succeed");
        }

        let (buffer_size, _, _) = sampler.buffer_stats();
        assert!(buffer_size > 0);
    }

    // Drift detection should run cleanly and leave a non-negative score.
    #[test]
    fn test_drift_detection() {
        let x1 = array![[1.0, 2.0], [1.1, 2.1], [0.9, 1.9]];
        let x2 = array![[5.0, 6.0], [5.1, 6.1], [4.9, 5.9]]; // Different distribution

        let mut sampler = StreamingRBFSampler::new(20, 0.1)
            .with_config(StreamingConfig {
                drift_detection: true,
                ..Default::default()
            })
            .with_random_state(42);

        sampler.fit(&x1).expect("operation should succeed");
        let _drift_detected = sampler.detect_drift(&x2).expect("operation should succeed");

        // Should detect some drift in feature statistics
        assert!(sampler.feature_stats().drift_score >= 0.0);
    }
}
981
982        // Should detect some drift in feature statistics
983        assert!(sampler.feature_stats().drift_score >= 0.0);
984    }
985}