sklears_semi_supervised/
streaming_graph_learning.rs

1//! Streaming Graph Learning for Dynamic Semi-Supervised Learning
2//!
3//! This module provides algorithms for learning and updating graph structures
4//! incrementally as new data arrives in streaming scenarios.
5
6use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8    error::{Result as SklResult, SklearsError},
9    traits::{Estimator, Fit, Predict, PredictProba, Untrained},
10    types::Float,
11};
12use std::collections::{HashMap, VecDeque};
13
14/// Streaming Graph Learning for Dynamic Semi-Supervised Learning
15///
16/// This method continuously updates graph structures as new data points arrive,
17/// making it suitable for dynamic environments where the data distribution
18/// may change over time. It maintains a sliding window of recent data points
19/// and efficiently updates the graph structure and label propagation.
20///
21/// # Parameters
22///
23/// * `window_size` - Size of the sliding window for maintaining recent data
24/// * `lambda_sparse` - Sparsity regularization parameter for graph learning
25/// * `alpha_decay` - Decay factor for edge weights over time
26/// * `update_frequency` - Frequency of full graph reconstruction
27/// * `forgetting_factor` - Factor for exponential forgetting of old connections
28/// * `adaptive_threshold` - Whether to use adaptive thresholds for edge addition
29/// * `min_samples_update` - Minimum samples required before updating the graph
30///
31/// # Examples
32///
33/// ```
34/// use scirs2_core::array;
35/// use sklears_semi_supervised::StreamingGraphLearning;
36/// use sklears_core::traits::{Predict, Fit};
37///
38///
39/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
40/// let y = array![0, 1, -1, -1]; // -1 indicates unlabeled
41///
42/// let mut sgl = StreamingGraphLearning::new()
43///     .window_size(100)
44///     .lambda_sparse(0.1)
45///     .alpha_decay(0.95);
46///
47/// let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();
48/// let predictions = fitted.predict(&X.view()).unwrap();
49///
50/// // Update with new data
51/// let X_new = array![[5.0, 6.0], [6.0, 7.0]];
52/// let y_new = array![-1, 0];
53/// let updated = fitted.update(&X_new.view(), &y_new.view()).unwrap();
54/// ```
55#[derive(Debug, Clone)]
56pub struct StreamingGraphLearning<S = Untrained> {
57    state: S,
58    window_size: usize,
59    lambda_sparse: f64,
60    alpha_decay: f64,
61    update_frequency: usize,
62    forgetting_factor: f64,
63    adaptive_threshold: bool,
64    min_samples_update: usize,
65    k_neighbors: usize,
66    similarity_threshold: f64,
67}
68
69impl StreamingGraphLearning<Untrained> {
70    /// Create a new StreamingGraphLearning instance
71    pub fn new() -> Self {
72        Self {
73            state: Untrained,
74            window_size: 1000,
75            lambda_sparse: 0.1,
76            alpha_decay: 0.95,
77            update_frequency: 50,
78            forgetting_factor: 0.99,
79            adaptive_threshold: true,
80            min_samples_update: 10,
81            k_neighbors: 5,
82            similarity_threshold: 0.5,
83        }
84    }
85
86    /// Set the sliding window size
87    pub fn window_size(mut self, window_size: usize) -> Self {
88        self.window_size = window_size;
89        self
90    }
91
92    /// Set the sparsity regularization parameter
93    pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
94        self.lambda_sparse = lambda_sparse;
95        self
96    }
97
98    /// Set the decay factor for edge weights
99    pub fn alpha_decay(mut self, alpha_decay: f64) -> Self {
100        self.alpha_decay = alpha_decay;
101        self
102    }
103
104    /// Set the frequency of full graph reconstruction
105    pub fn update_frequency(mut self, frequency: usize) -> Self {
106        self.update_frequency = frequency;
107        self
108    }
109
110    /// Set the forgetting factor for old connections
111    pub fn forgetting_factor(mut self, factor: f64) -> Self {
112        self.forgetting_factor = factor;
113        self
114    }
115
116    /// Enable/disable adaptive threshold for edge addition
117    pub fn adaptive_threshold(mut self, adaptive: bool) -> Self {
118        self.adaptive_threshold = adaptive;
119        self
120    }
121
122    /// Set minimum samples required before updating the graph
123    pub fn min_samples_update(mut self, min_samples: usize) -> Self {
124        self.min_samples_update = min_samples;
125        self
126    }
127
128    /// Set the number of nearest neighbors to consider
129    pub fn k_neighbors(mut self, k: usize) -> Self {
130        self.k_neighbors = k;
131        self
132    }
133
134    /// Set the similarity threshold for edge creation
135    pub fn similarity_threshold(mut self, threshold: f64) -> Self {
136        self.similarity_threshold = threshold;
137        self
138    }
139
140    fn compute_similarity(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
141        let diff = x1 - x2;
142        let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
143        (-dist / (2.0 * 1.0_f64.powi(2))).exp()
144    }
145
146    fn build_initial_graph(&self, X: &Array2<f64>) -> Array2<f64> {
147        let n_samples = X.nrows();
148        let mut W = Array2::zeros((n_samples, n_samples));
149
150        for i in 0..n_samples {
151            let mut similarities: Vec<(usize, f64)> = Vec::new();
152
153            for j in 0..n_samples {
154                if i != j {
155                    let sim = self.compute_similarity(&X.row(i), &X.row(j));
156                    similarities.push((j, sim));
157                }
158            }
159
160            // Sort by similarity (descending)
161            similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
162
163            // Connect to k nearest neighbors
164            for &(j, sim) in similarities.iter().take(self.k_neighbors) {
165                if sim > self.similarity_threshold {
166                    W[[i, j]] = sim;
167                    W[[j, i]] = sim; // Ensure symmetry
168                }
169            }
170        }
171
172        // Apply sparsity threshold
173        let threshold = self.lambda_sparse;
174        W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
175        W.mapv_inplace(|x| x.max(0.0));
176
177        // Zero diagonal
178        for i in 0..n_samples {
179            W[[i, i]] = 0.0;
180        }
181
182        W
183    }
184
185    #[allow(non_snake_case)]
186    fn propagate_labels(&self, W: &Array2<f64>, Y_init: &Array2<f64>) -> SklResult<Array2<f64>> {
187        let n_samples = W.nrows();
188        let n_classes = Y_init.ncols();
189
190        // Compute transition matrix
191        let D = W.sum_axis(Axis(1));
192        let mut P = Array2::zeros((n_samples, n_samples));
193        for i in 0..n_samples {
194            if D[i] > 0.0 {
195                for j in 0..n_samples {
196                    P[[i, j]] = W[[i, j]] / D[i];
197                }
198            }
199        }
200
201        let mut Y = Y_init.clone();
202        let Y_static = Y_init.clone();
203
204        // Label propagation iterations
205        for _iter in 0..30 {
206            let prev_Y = Y.clone();
207            Y = 0.8 * P.dot(&Y) + 0.2 * &Y_static;
208
209            // Check convergence
210            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
211            if diff < 1e-6 {
212                break;
213            }
214        }
215
216        Ok(Y)
217    }
218}
219
220impl Default for StreamingGraphLearning<Untrained> {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226impl Estimator for StreamingGraphLearning<Untrained> {
227    type Config = ();
228    type Error = SklearsError;
229    type Float = Float;
230
231    fn config(&self) -> &Self::Config {
232        &()
233    }
234}
235
236impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for StreamingGraphLearning<Untrained> {
237    type Fitted = StreamingGraphLearning<StreamingGraphLearningTrained>;
238
239    #[allow(non_snake_case)]
240    fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
241        let X = X.to_owned();
242        let y = y.to_owned();
243        let (n_samples, n_features) = X.dim();
244
245        // Identify labeled samples and classes
246        let mut labeled_indices = Vec::new();
247        let mut classes = std::collections::HashSet::new();
248
249        for (i, &label) in y.iter().enumerate() {
250            if label != -1 {
251                labeled_indices.push(i);
252                classes.insert(label);
253            }
254        }
255
256        if labeled_indices.is_empty() {
257            return Err(SklearsError::InvalidInput(
258                "No labeled samples provided".to_string(),
259            ));
260        }
261
262        let classes: Vec<i32> = classes.into_iter().collect();
263        let n_classes = classes.len();
264
265        // Build initial graph
266        let W = self.build_initial_graph(&X);
267
268        // Initialize label matrix
269        let mut Y = Array2::zeros((n_samples, n_classes));
270        for &idx in &labeled_indices {
271            if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
272                Y[[idx, class_idx]] = 1.0;
273            }
274        }
275
276        // Perform initial label propagation
277        let Y_final = self.propagate_labels(&W, &Y)?;
278
279        // Initialize sliding window with current data
280        let mut data_window = VecDeque::with_capacity(self.window_size);
281        let mut label_window = VecDeque::with_capacity(self.window_size);
282
283        for i in 0..n_samples {
284            data_window.push_back(X.row(i).to_owned());
285            label_window.push_back(y[i]);
286        }
287
288        Ok(StreamingGraphLearning {
289            state: StreamingGraphLearningTrained {
290                X_train: X,
291                y_train: y,
292                classes: Array1::from(classes),
293                current_graph: W,
294                label_distributions: Y_final,
295                data_window,
296                label_window,
297                update_count: 0,
298                edge_ages: HashMap::new(),
299                adaptive_threshold_value: self.similarity_threshold,
300            },
301            window_size: self.window_size,
302            lambda_sparse: self.lambda_sparse,
303            alpha_decay: self.alpha_decay,
304            update_frequency: self.update_frequency,
305            forgetting_factor: self.forgetting_factor,
306            adaptive_threshold: self.adaptive_threshold,
307            min_samples_update: self.min_samples_update,
308            k_neighbors: self.k_neighbors,
309            similarity_threshold: self.similarity_threshold,
310        })
311    }
312}
313
314impl StreamingGraphLearning<StreamingGraphLearningTrained> {
315    fn compute_similarity(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
316        let diff = x1 - x2;
317        let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
318        (-dist / (2.0 * 1.0_f64.powi(2))).exp()
319    }
320
321    fn build_initial_graph(&self, X: &Array2<f64>) -> Array2<f64> {
322        let n_samples = X.nrows();
323        let mut W = Array2::zeros((n_samples, n_samples));
324
325        for i in 0..n_samples {
326            let mut similarities: Vec<(usize, f64)> = Vec::new();
327
328            for j in 0..n_samples {
329                if i != j {
330                    let sim = self.compute_similarity(&X.row(i), &X.row(j));
331                    similarities.push((j, sim));
332                }
333            }
334
335            // Sort by similarity (descending)
336            similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
337
338            // Connect to k nearest neighbors
339            for &(j, sim) in similarities.iter().take(self.k_neighbors) {
340                if sim > self.similarity_threshold {
341                    W[[i, j]] = sim;
342                    W[[j, i]] = sim; // Ensure symmetry
343                }
344            }
345        }
346
347        // Apply sparsity threshold
348        let threshold = self.lambda_sparse;
349        W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
350        W.mapv_inplace(|x| x.max(0.0));
351
352        // Zero diagonal
353        for i in 0..n_samples {
354            W[[i, i]] = 0.0;
355        }
356
357        W
358    }
359
360    #[allow(non_snake_case)]
361    fn propagate_labels(&self, W: &Array2<f64>, Y_init: &Array2<f64>) -> SklResult<Array2<f64>> {
362        let n_samples = W.nrows();
363        let n_classes = Y_init.ncols();
364
365        // Compute transition matrix
366        let D = W.sum_axis(Axis(1));
367        let mut P = Array2::zeros((n_samples, n_samples));
368        for i in 0..n_samples {
369            if D[i] > 0.0 {
370                for j in 0..n_samples {
371                    P[[i, j]] = W[[i, j]] / D[i];
372                }
373            }
374        }
375
376        let mut Y = Y_init.clone();
377        let Y_static = Y_init.clone();
378
379        // Label propagation iterations
380        for _iter in 0..30 {
381            let prev_Y = Y.clone();
382            Y = 0.8 * P.dot(&Y) + 0.2 * &Y_static;
383
384            // Check convergence
385            let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
386            if diff < 1e-6 {
387                break;
388            }
389        }
390
391        Ok(Y)
392    }
393    /// Update the model with new streaming data
394    #[allow(non_snake_case)]
395    pub fn update(
396        &mut self,
397        X_new: &ArrayView2<'_, Float>,
398        y_new: &ArrayView1<'_, i32>,
399    ) -> SklResult<()> {
400        let X_new = X_new.to_owned();
401        let y_new = y_new.to_owned();
402        let (n_new, _) = X_new.dim();
403
404        // Add new data to sliding window
405        for i in 0..n_new {
406            // Remove oldest data if window is full
407            if self.state.data_window.len() >= self.window_size {
408                self.state.data_window.pop_front();
409                self.state.label_window.pop_front();
410            }
411
412            self.state.data_window.push_back(X_new.row(i).to_owned());
413            self.state.label_window.push_back(y_new[i]);
414        }
415
416        self.state.update_count += n_new;
417
418        // Decay existing edge weights
419        self.state
420            .current_graph
421            .mapv_inplace(|x| x * self.alpha_decay);
422
423        // Update adaptive threshold if enabled
424        if self.adaptive_threshold {
425            self.update_adaptive_threshold();
426        }
427
428        // Age all edges
429        let mut aged_edges = HashMap::new();
430        for ((i, j), age) in &self.state.edge_ages {
431            aged_edges.insert((*i, *j), age + 1);
432        }
433        self.state.edge_ages = aged_edges;
434
435        // Incremental graph update
436        self.incremental_graph_update(&X_new, &y_new)?;
437
438        // Full reconstruction if update frequency is reached
439        if self.state.update_count % self.update_frequency == 0 {
440            self.full_graph_reconstruction()?;
441        }
442
443        Ok(())
444    }
445
446    fn update_adaptive_threshold(&mut self) {
447        let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
448        if current_data.len() < 2 {
449            return;
450        }
451
452        let mut similarities = Vec::new();
453        for i in 0..current_data.len().min(100) {
454            for j in (i + 1)..current_data.len().min(100) {
455                let sim = self.compute_similarity(&current_data[i].view(), &current_data[j].view());
456                similarities.push(sim);
457            }
458        }
459
460        if !similarities.is_empty() {
461            similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());
462            let median_idx = similarities.len() / 2;
463            self.state.adaptive_threshold_value = similarities[median_idx] * 0.8;
464        }
465    }
466
467    fn incremental_graph_update(
468        &mut self,
469        X_new: &Array2<f64>,
470        y_new: &Array1<i32>,
471    ) -> SklResult<()> {
472        let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
473        let current_labels: Vec<i32> = self.state.label_window.iter().cloned().collect();
474        let n_current = current_data.len();
475        let n_new = X_new.nrows();
476
477        // Extend current graph to accommodate new nodes
478        let mut new_graph = Array2::zeros((n_current, n_current));
479
480        // Copy existing graph (with aging applied)
481        let old_size = self.state.current_graph.nrows().min(n_current);
482        for i in 0..old_size {
483            for j in 0..old_size {
484                new_graph[[i, j]] = self.state.current_graph[[i, j]];
485            }
486        }
487
488        // Add connections for new nodes
489        let start_idx = n_current - n_new;
490        for i in start_idx..n_current {
491            let mut similarities: Vec<(usize, f64)> = Vec::new();
492
493            for j in 0..n_current {
494                if i != j {
495                    let sim =
496                        self.compute_similarity(&current_data[i].view(), &current_data[j].view());
497                    similarities.push((j, sim));
498                }
499            }
500
501            // Sort by similarity (descending)
502            similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
503
504            // Connect to k nearest neighbors
505            let threshold = if self.adaptive_threshold {
506                self.state.adaptive_threshold_value
507            } else {
508                self.similarity_threshold
509            };
510
511            for &(j, sim) in similarities.iter().take(self.k_neighbors) {
512                if sim > threshold {
513                    new_graph[[i, j]] = sim;
514                    new_graph[[j, i]] = sim; // Ensure symmetry
515
516                    // Track edge age
517                    self.state.edge_ages.insert((i, j), 0);
518                    self.state.edge_ages.insert((j, i), 0);
519                }
520            }
521        }
522
523        // Apply forgetting to old edges
524        for ((i, j), age) in &self.state.edge_ages {
525            if *i < n_current && *j < n_current {
526                let forgetting_weight = self.forgetting_factor.powi(*age as i32);
527                new_graph[[*i, *j]] *= forgetting_weight;
528            }
529        }
530
531        // Apply sparsity threshold
532        let threshold = self.lambda_sparse;
533        new_graph.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
534        new_graph.mapv_inplace(|x| x.max(0.0));
535
536        // Zero diagonal
537        for i in 0..n_current {
538            new_graph[[i, i]] = 0.0;
539        }
540
541        self.state.current_graph = new_graph;
542
543        // Update label propagation
544        self.update_label_propagation(&current_data, &current_labels)?;
545
546        Ok(())
547    }
548
549    fn full_graph_reconstruction(&mut self) -> SklResult<()> {
550        let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
551        let current_labels: Vec<i32> = self.state.label_window.iter().cloned().collect();
552
553        if current_data.is_empty() {
554            return Ok(());
555        }
556
557        let n_samples = current_data.len();
558
559        // Convert data to Array2
560        let mut X = Array2::zeros((n_samples, current_data[0].len()));
561        for (i, data_point) in current_data.iter().enumerate() {
562            X.row_mut(i).assign(data_point);
563        }
564
565        // Rebuild graph from scratch
566        self.state.current_graph = self.build_initial_graph(&X);
567
568        // Clear edge ages
569        self.state.edge_ages.clear();
570
571        // Update label propagation
572        self.update_label_propagation(&current_data, &current_labels)?;
573
574        Ok(())
575    }
576
577    #[allow(non_snake_case)]
578    fn update_label_propagation(
579        &mut self,
580        current_data: &[Array1<f64>],
581        current_labels: &[i32],
582    ) -> SklResult<()> {
583        let n_samples = current_data.len();
584        let n_classes = self.state.classes.len();
585
586        if n_samples == 0 {
587            return Ok(());
588        }
589
590        // Initialize label matrix
591        let mut Y = Array2::zeros((n_samples, n_classes));
592        for (i, &label) in current_labels.iter().enumerate() {
593            if label != -1 {
594                if let Some(class_idx) = self.state.classes.iter().position(|&c| c == label) {
595                    Y[[i, class_idx]] = 1.0;
596                }
597            }
598        }
599
600        // Perform label propagation
601        let Y_final = self.propagate_labels(&self.state.current_graph, &Y)?;
602        self.state.label_distributions = Y_final;
603
604        Ok(())
605    }
606}
607
608impl Predict<ArrayView2<'_, Float>, Array1<i32>>
609    for StreamingGraphLearning<StreamingGraphLearningTrained>
610{
611    #[allow(non_snake_case)]
612    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
613        let X = X.to_owned();
614        let n_test = X.nrows();
615        let mut predictions = Array1::zeros(n_test);
616
617        let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
618
619        for i in 0..n_test {
620            let mut max_sim = -1.0;
621            let mut best_idx = 0;
622
623            // Find most similar sample in current window
624            for (j, data_point) in current_data.iter().enumerate() {
625                let sim = self.compute_similarity(&X.row(i), &data_point.view());
626                if sim > max_sim {
627                    max_sim = sim;
628                    best_idx = j;
629                }
630            }
631
632            // Use the label distribution of the most similar sample
633            if best_idx < self.state.label_distributions.nrows() {
634                let distributions = self.state.label_distributions.row(best_idx);
635                let max_idx = distributions
636                    .iter()
637                    .enumerate()
638                    .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
639                    .unwrap()
640                    .0;
641
642                predictions[i] = self.state.classes[max_idx];
643            }
644        }
645
646        Ok(predictions)
647    }
648}
649
650impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
651    for StreamingGraphLearning<StreamingGraphLearningTrained>
652{
653    #[allow(non_snake_case)]
654    fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
655        let X = X.to_owned();
656        let n_test = X.nrows();
657        let n_classes = self.state.classes.len();
658        let mut probas = Array2::zeros((n_test, n_classes));
659
660        let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
661
662        for i in 0..n_test {
663            let mut max_sim = -1.0;
664            let mut best_idx = 0;
665
666            // Find most similar sample in current window
667            for (j, data_point) in current_data.iter().enumerate() {
668                let sim = self.compute_similarity(&X.row(i), &data_point.view());
669                if sim > max_sim {
670                    max_sim = sim;
671                    best_idx = j;
672                }
673            }
674
675            // Copy the label distribution
676            if best_idx < self.state.label_distributions.nrows() {
677                for k in 0..n_classes {
678                    probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
679                }
680            }
681        }
682
683        Ok(probas)
684    }
685}
686
687/// Trained state for StreamingGraphLearning
688#[derive(Debug, Clone)]
689pub struct StreamingGraphLearningTrained {
690    /// X_train
691    pub X_train: Array2<f64>,
692    /// y_train
693    pub y_train: Array1<i32>,
694    /// classes
695    pub classes: Array1<i32>,
696    /// current_graph
697    pub current_graph: Array2<f64>,
698    /// label_distributions
699    pub label_distributions: Array2<f64>,
700    /// data_window
701    pub data_window: VecDeque<Array1<f64>>,
702    /// label_window
703    pub label_window: VecDeque<i32>,
704    /// update_count
705    pub update_count: usize,
706    /// edge_ages
707    pub edge_ages: HashMap<(usize, usize), usize>,
708    /// adaptive_threshold_value
709    pub adaptive_threshold_value: f64,
710}
711
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1]; // -1 indicates unlabeled

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .lambda_sparse(0.1)
            .alpha_decay(0.9)
            .update_frequency(5);
        let fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Check that labeled samples maintain their labels
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_requires_labeled_samples() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![-1, -1];

        let result = StreamingGraphLearning::new().fit(&X.view(), &y.view());
        assert!(result.is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_update() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .update_frequency(3)
            .alpha_decay(0.95);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        // Initial graph size
        let initial_graph_size = fitted.state.current_graph.dim();
        assert_eq!(initial_graph_size, (4, 4));

        // Add new streaming data
        let X_new = array![[5.0, 6.0], [6.0, 7.0]];
        let y_new = array![-1, 0];
        fitted.update(&X_new.view(), &y_new.view()).unwrap();

        // Check that data window is updated
        assert_eq!(fitted.state.data_window.len(), 6);
        assert_eq!(fitted.state.label_window.len(), 6);

        // Graph should be updated to accommodate new data
        let updated_graph_size = fitted.state.current_graph.dim();
        assert_eq!(updated_graph_size, (6, 6));

        // Test predictions with updated model
        let predictions = fitted.predict(&X_new.view()).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_window_overflow() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];

        let sgl = StreamingGraphLearning::new()
            .window_size(3) // Small window size
            .update_frequency(2);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        // Add more data than window size
        let X_new1 = array![[3.0, 4.0]];
        let y_new1 = array![-1];
        fitted.update(&X_new1.view(), &y_new1.view()).unwrap();

        let X_new2 = array![[4.0, 5.0]];
        let y_new2 = array![0];
        fitted.update(&X_new2.view(), &y_new2.view()).unwrap();

        // Window should maintain size limit
        assert_eq!(fitted.state.data_window.len(), 3);
        assert_eq!(fitted.state.label_window.len(), 3);

        // Should still be able to make predictions
        let predictions = fitted.predict(&X_new2.view()).unwrap();
        assert_eq!(predictions.len(), 1);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_adaptive_threshold() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .adaptive_threshold(true)
            .similarity_threshold(0.5);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        // Before any update the adaptive threshold equals the configured one
        let initial_threshold = fitted.state.adaptive_threshold_value;
        assert!((initial_threshold - 0.5).abs() < 1e-12);

        // Add new data with different characteristics
        let X_new = array![[10.0, 20.0], [20.0, 30.0]];
        let y_new = array![-1, 1];
        fitted.update(&X_new.view(), &y_new.view()).unwrap();

        // Adaptive threshold should potentially change
        // (depends on the similarity distribution)
        assert!(fitted.state.adaptive_threshold_value > 0.0);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_edge_aging() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .forgetting_factor(0.8)
            .alpha_decay(0.9);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        // Check initial state
        assert_eq!(fitted.state.update_count, 0);

        // Add new data multiple times to age edges
        for i in 0..3 {
            let X_new = array![[3.0 + i as f64, 4.0 + i as f64]];
            let y_new = array![-1];
            fitted.update(&X_new.view(), &y_new.view()).unwrap();
        }

        // Update count should be incremented
        assert_eq!(fitted.state.update_count, 3);

        // Some edges should have aged
        assert!(!fitted.state.edge_ages.is_empty());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_streaming_graph_learning_full_reconstruction() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .update_frequency(2); // Trigger full reconstruction frequently
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        // Add data to trigger full reconstruction
        let X_new1 = array![[3.0, 4.0]];
        let y_new1 = array![-1];
        fitted.update(&X_new1.view(), &y_new1.view()).unwrap();

        let X_new2 = array![[4.0, 5.0]];
        let y_new2 = array![0];
        fitted.update(&X_new2.view(), &y_new2.view()).unwrap();

        // Full reconstruction should have been triggered
        // Edge ages should be cleared
        assert!(
            fitted.state.edge_ages.is_empty()
                || fitted.state.edge_ages.values().all(|&age| age == 0)
        );

        // Should still be able to make predictions
        let predictions = fitted.predict(&X_new2.view()).unwrap();
        assert_eq!(predictions.len(), 1);
    }
}