1use scirs2_core::ndarray_ext::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8 error::{Result as SklResult, SklearsError},
9 traits::{Estimator, Fit, Predict, PredictProba, Untrained},
10 types::Float,
11};
12use std::collections::{HashMap, VecDeque};
13
14#[derive(Debug, Clone)]
56pub struct StreamingGraphLearning<S = Untrained> {
57 state: S,
58 window_size: usize,
59 lambda_sparse: f64,
60 alpha_decay: f64,
61 update_frequency: usize,
62 forgetting_factor: f64,
63 adaptive_threshold: bool,
64 min_samples_update: usize,
65 k_neighbors: usize,
66 similarity_threshold: f64,
67}
68
69impl StreamingGraphLearning<Untrained> {
70 pub fn new() -> Self {
72 Self {
73 state: Untrained,
74 window_size: 1000,
75 lambda_sparse: 0.1,
76 alpha_decay: 0.95,
77 update_frequency: 50,
78 forgetting_factor: 0.99,
79 adaptive_threshold: true,
80 min_samples_update: 10,
81 k_neighbors: 5,
82 similarity_threshold: 0.5,
83 }
84 }
85
86 pub fn window_size(mut self, window_size: usize) -> Self {
88 self.window_size = window_size;
89 self
90 }
91
92 pub fn lambda_sparse(mut self, lambda_sparse: f64) -> Self {
94 self.lambda_sparse = lambda_sparse;
95 self
96 }
97
98 pub fn alpha_decay(mut self, alpha_decay: f64) -> Self {
100 self.alpha_decay = alpha_decay;
101 self
102 }
103
104 pub fn update_frequency(mut self, frequency: usize) -> Self {
106 self.update_frequency = frequency;
107 self
108 }
109
110 pub fn forgetting_factor(mut self, factor: f64) -> Self {
112 self.forgetting_factor = factor;
113 self
114 }
115
116 pub fn adaptive_threshold(mut self, adaptive: bool) -> Self {
118 self.adaptive_threshold = adaptive;
119 self
120 }
121
122 pub fn min_samples_update(mut self, min_samples: usize) -> Self {
124 self.min_samples_update = min_samples;
125 self
126 }
127
128 pub fn k_neighbors(mut self, k: usize) -> Self {
130 self.k_neighbors = k;
131 self
132 }
133
134 pub fn similarity_threshold(mut self, threshold: f64) -> Self {
136 self.similarity_threshold = threshold;
137 self
138 }
139
140 fn compute_similarity(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
141 let diff = x1 - x2;
142 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
143 (-dist / (2.0 * 1.0_f64.powi(2))).exp()
144 }
145
146 fn build_initial_graph(&self, X: &Array2<f64>) -> Array2<f64> {
147 let n_samples = X.nrows();
148 let mut W = Array2::zeros((n_samples, n_samples));
149
150 for i in 0..n_samples {
151 let mut similarities: Vec<(usize, f64)> = Vec::new();
152
153 for j in 0..n_samples {
154 if i != j {
155 let sim = self.compute_similarity(&X.row(i), &X.row(j));
156 similarities.push((j, sim));
157 }
158 }
159
160 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
162
163 for &(j, sim) in similarities.iter().take(self.k_neighbors) {
165 if sim > self.similarity_threshold {
166 W[[i, j]] = sim;
167 W[[j, i]] = sim; }
169 }
170 }
171
172 let threshold = self.lambda_sparse;
174 W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
175 W.mapv_inplace(|x| x.max(0.0));
176
177 for i in 0..n_samples {
179 W[[i, i]] = 0.0;
180 }
181
182 W
183 }
184
185 #[allow(non_snake_case)]
186 fn propagate_labels(&self, W: &Array2<f64>, Y_init: &Array2<f64>) -> SklResult<Array2<f64>> {
187 let n_samples = W.nrows();
188 let n_classes = Y_init.ncols();
189
190 let D = W.sum_axis(Axis(1));
192 let mut P = Array2::zeros((n_samples, n_samples));
193 for i in 0..n_samples {
194 if D[i] > 0.0 {
195 for j in 0..n_samples {
196 P[[i, j]] = W[[i, j]] / D[i];
197 }
198 }
199 }
200
201 let mut Y = Y_init.clone();
202 let Y_static = Y_init.clone();
203
204 for _iter in 0..30 {
206 let prev_Y = Y.clone();
207 Y = 0.8 * P.dot(&Y) + 0.2 * &Y_static;
208
209 let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
211 if diff < 1e-6 {
212 break;
213 }
214 }
215
216 Ok(Y)
217 }
218}
219
220impl Default for StreamingGraphLearning<Untrained> {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
226impl Estimator for StreamingGraphLearning<Untrained> {
227 type Config = ();
228 type Error = SklearsError;
229 type Float = Float;
230
231 fn config(&self) -> &Self::Config {
232 &()
233 }
234}
235
236impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, i32>> for StreamingGraphLearning<Untrained> {
237 type Fitted = StreamingGraphLearning<StreamingGraphLearningTrained>;
238
239 #[allow(non_snake_case)]
240 fn fit(self, X: &ArrayView2<'_, Float>, y: &ArrayView1<'_, i32>) -> SklResult<Self::Fitted> {
241 let X = X.to_owned();
242 let y = y.to_owned();
243 let (n_samples, n_features) = X.dim();
244
245 let mut labeled_indices = Vec::new();
247 let mut classes = std::collections::HashSet::new();
248
249 for (i, &label) in y.iter().enumerate() {
250 if label != -1 {
251 labeled_indices.push(i);
252 classes.insert(label);
253 }
254 }
255
256 if labeled_indices.is_empty() {
257 return Err(SklearsError::InvalidInput(
258 "No labeled samples provided".to_string(),
259 ));
260 }
261
262 let classes: Vec<i32> = classes.into_iter().collect();
263 let n_classes = classes.len();
264
265 let W = self.build_initial_graph(&X);
267
268 let mut Y = Array2::zeros((n_samples, n_classes));
270 for &idx in &labeled_indices {
271 if let Some(class_idx) = classes.iter().position(|&c| c == y[idx]) {
272 Y[[idx, class_idx]] = 1.0;
273 }
274 }
275
276 let Y_final = self.propagate_labels(&W, &Y)?;
278
279 let mut data_window = VecDeque::with_capacity(self.window_size);
281 let mut label_window = VecDeque::with_capacity(self.window_size);
282
283 for i in 0..n_samples {
284 data_window.push_back(X.row(i).to_owned());
285 label_window.push_back(y[i]);
286 }
287
288 Ok(StreamingGraphLearning {
289 state: StreamingGraphLearningTrained {
290 X_train: X,
291 y_train: y,
292 classes: Array1::from(classes),
293 current_graph: W,
294 label_distributions: Y_final,
295 data_window,
296 label_window,
297 update_count: 0,
298 edge_ages: HashMap::new(),
299 adaptive_threshold_value: self.similarity_threshold,
300 },
301 window_size: self.window_size,
302 lambda_sparse: self.lambda_sparse,
303 alpha_decay: self.alpha_decay,
304 update_frequency: self.update_frequency,
305 forgetting_factor: self.forgetting_factor,
306 adaptive_threshold: self.adaptive_threshold,
307 min_samples_update: self.min_samples_update,
308 k_neighbors: self.k_neighbors,
309 similarity_threshold: self.similarity_threshold,
310 })
311 }
312}
313
314impl StreamingGraphLearning<StreamingGraphLearningTrained> {
315 fn compute_similarity(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
316 let diff = x1 - x2;
317 let dist = diff.mapv(|x: f64| x * x).sum().sqrt();
318 (-dist / (2.0 * 1.0_f64.powi(2))).exp()
319 }
320
321 fn build_initial_graph(&self, X: &Array2<f64>) -> Array2<f64> {
322 let n_samples = X.nrows();
323 let mut W = Array2::zeros((n_samples, n_samples));
324
325 for i in 0..n_samples {
326 let mut similarities: Vec<(usize, f64)> = Vec::new();
327
328 for j in 0..n_samples {
329 if i != j {
330 let sim = self.compute_similarity(&X.row(i), &X.row(j));
331 similarities.push((j, sim));
332 }
333 }
334
335 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
337
338 for &(j, sim) in similarities.iter().take(self.k_neighbors) {
340 if sim > self.similarity_threshold {
341 W[[i, j]] = sim;
342 W[[j, i]] = sim; }
344 }
345 }
346
347 let threshold = self.lambda_sparse;
349 W.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
350 W.mapv_inplace(|x| x.max(0.0));
351
352 for i in 0..n_samples {
354 W[[i, i]] = 0.0;
355 }
356
357 W
358 }
359
360 #[allow(non_snake_case)]
361 fn propagate_labels(&self, W: &Array2<f64>, Y_init: &Array2<f64>) -> SklResult<Array2<f64>> {
362 let n_samples = W.nrows();
363 let n_classes = Y_init.ncols();
364
365 let D = W.sum_axis(Axis(1));
367 let mut P = Array2::zeros((n_samples, n_samples));
368 for i in 0..n_samples {
369 if D[i] > 0.0 {
370 for j in 0..n_samples {
371 P[[i, j]] = W[[i, j]] / D[i];
372 }
373 }
374 }
375
376 let mut Y = Y_init.clone();
377 let Y_static = Y_init.clone();
378
379 for _iter in 0..30 {
381 let prev_Y = Y.clone();
382 Y = 0.8 * P.dot(&Y) + 0.2 * &Y_static;
383
384 let diff = (&Y - &prev_Y).mapv(|x| x.abs()).sum();
386 if diff < 1e-6 {
387 break;
388 }
389 }
390
391 Ok(Y)
392 }
393 #[allow(non_snake_case)]
395 pub fn update(
396 &mut self,
397 X_new: &ArrayView2<'_, Float>,
398 y_new: &ArrayView1<'_, i32>,
399 ) -> SklResult<()> {
400 let X_new = X_new.to_owned();
401 let y_new = y_new.to_owned();
402 let (n_new, _) = X_new.dim();
403
404 for i in 0..n_new {
406 if self.state.data_window.len() >= self.window_size {
408 self.state.data_window.pop_front();
409 self.state.label_window.pop_front();
410 }
411
412 self.state.data_window.push_back(X_new.row(i).to_owned());
413 self.state.label_window.push_back(y_new[i]);
414 }
415
416 self.state.update_count += n_new;
417
418 self.state
420 .current_graph
421 .mapv_inplace(|x| x * self.alpha_decay);
422
423 if self.adaptive_threshold {
425 self.update_adaptive_threshold();
426 }
427
428 let mut aged_edges = HashMap::new();
430 for ((i, j), age) in &self.state.edge_ages {
431 aged_edges.insert((*i, *j), age + 1);
432 }
433 self.state.edge_ages = aged_edges;
434
435 self.incremental_graph_update(&X_new, &y_new)?;
437
438 if self.state.update_count % self.update_frequency == 0 {
440 self.full_graph_reconstruction()?;
441 }
442
443 Ok(())
444 }
445
446 fn update_adaptive_threshold(&mut self) {
447 let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
448 if current_data.len() < 2 {
449 return;
450 }
451
452 let mut similarities = Vec::new();
453 for i in 0..current_data.len().min(100) {
454 for j in (i + 1)..current_data.len().min(100) {
455 let sim = self.compute_similarity(¤t_data[i].view(), ¤t_data[j].view());
456 similarities.push(sim);
457 }
458 }
459
460 if !similarities.is_empty() {
461 similarities.sort_by(|a, b| a.partial_cmp(b).unwrap());
462 let median_idx = similarities.len() / 2;
463 self.state.adaptive_threshold_value = similarities[median_idx] * 0.8;
464 }
465 }
466
467 fn incremental_graph_update(
468 &mut self,
469 X_new: &Array2<f64>,
470 y_new: &Array1<i32>,
471 ) -> SklResult<()> {
472 let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
473 let current_labels: Vec<i32> = self.state.label_window.iter().cloned().collect();
474 let n_current = current_data.len();
475 let n_new = X_new.nrows();
476
477 let mut new_graph = Array2::zeros((n_current, n_current));
479
480 let old_size = self.state.current_graph.nrows().min(n_current);
482 for i in 0..old_size {
483 for j in 0..old_size {
484 new_graph[[i, j]] = self.state.current_graph[[i, j]];
485 }
486 }
487
488 let start_idx = n_current - n_new;
490 for i in start_idx..n_current {
491 let mut similarities: Vec<(usize, f64)> = Vec::new();
492
493 for j in 0..n_current {
494 if i != j {
495 let sim =
496 self.compute_similarity(¤t_data[i].view(), ¤t_data[j].view());
497 similarities.push((j, sim));
498 }
499 }
500
501 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
503
504 let threshold = if self.adaptive_threshold {
506 self.state.adaptive_threshold_value
507 } else {
508 self.similarity_threshold
509 };
510
511 for &(j, sim) in similarities.iter().take(self.k_neighbors) {
512 if sim > threshold {
513 new_graph[[i, j]] = sim;
514 new_graph[[j, i]] = sim; self.state.edge_ages.insert((i, j), 0);
518 self.state.edge_ages.insert((j, i), 0);
519 }
520 }
521 }
522
523 for ((i, j), age) in &self.state.edge_ages {
525 if *i < n_current && *j < n_current {
526 let forgetting_weight = self.forgetting_factor.powi(*age as i32);
527 new_graph[[*i, *j]] *= forgetting_weight;
528 }
529 }
530
531 let threshold = self.lambda_sparse;
533 new_graph.mapv_inplace(|x| if x > threshold { x - threshold } else { 0.0 });
534 new_graph.mapv_inplace(|x| x.max(0.0));
535
536 for i in 0..n_current {
538 new_graph[[i, i]] = 0.0;
539 }
540
541 self.state.current_graph = new_graph;
542
543 self.update_label_propagation(¤t_data, ¤t_labels)?;
545
546 Ok(())
547 }
548
549 fn full_graph_reconstruction(&mut self) -> SklResult<()> {
550 let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
551 let current_labels: Vec<i32> = self.state.label_window.iter().cloned().collect();
552
553 if current_data.is_empty() {
554 return Ok(());
555 }
556
557 let n_samples = current_data.len();
558
559 let mut X = Array2::zeros((n_samples, current_data[0].len()));
561 for (i, data_point) in current_data.iter().enumerate() {
562 X.row_mut(i).assign(data_point);
563 }
564
565 self.state.current_graph = self.build_initial_graph(&X);
567
568 self.state.edge_ages.clear();
570
571 self.update_label_propagation(¤t_data, ¤t_labels)?;
573
574 Ok(())
575 }
576
577 #[allow(non_snake_case)]
578 fn update_label_propagation(
579 &mut self,
580 current_data: &[Array1<f64>],
581 current_labels: &[i32],
582 ) -> SklResult<()> {
583 let n_samples = current_data.len();
584 let n_classes = self.state.classes.len();
585
586 if n_samples == 0 {
587 return Ok(());
588 }
589
590 let mut Y = Array2::zeros((n_samples, n_classes));
592 for (i, &label) in current_labels.iter().enumerate() {
593 if label != -1 {
594 if let Some(class_idx) = self.state.classes.iter().position(|&c| c == label) {
595 Y[[i, class_idx]] = 1.0;
596 }
597 }
598 }
599
600 let Y_final = self.propagate_labels(&self.state.current_graph, &Y)?;
602 self.state.label_distributions = Y_final;
603
604 Ok(())
605 }
606}
607
608impl Predict<ArrayView2<'_, Float>, Array1<i32>>
609 for StreamingGraphLearning<StreamingGraphLearningTrained>
610{
611 #[allow(non_snake_case)]
612 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<i32>> {
613 let X = X.to_owned();
614 let n_test = X.nrows();
615 let mut predictions = Array1::zeros(n_test);
616
617 let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
618
619 for i in 0..n_test {
620 let mut max_sim = -1.0;
621 let mut best_idx = 0;
622
623 for (j, data_point) in current_data.iter().enumerate() {
625 let sim = self.compute_similarity(&X.row(i), &data_point.view());
626 if sim > max_sim {
627 max_sim = sim;
628 best_idx = j;
629 }
630 }
631
632 if best_idx < self.state.label_distributions.nrows() {
634 let distributions = self.state.label_distributions.row(best_idx);
635 let max_idx = distributions
636 .iter()
637 .enumerate()
638 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
639 .unwrap()
640 .0;
641
642 predictions[i] = self.state.classes[max_idx];
643 }
644 }
645
646 Ok(predictions)
647 }
648}
649
650impl PredictProba<ArrayView2<'_, Float>, Array2<f64>>
651 for StreamingGraphLearning<StreamingGraphLearningTrained>
652{
653 #[allow(non_snake_case)]
654 fn predict_proba(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
655 let X = X.to_owned();
656 let n_test = X.nrows();
657 let n_classes = self.state.classes.len();
658 let mut probas = Array2::zeros((n_test, n_classes));
659
660 let current_data: Vec<Array1<f64>> = self.state.data_window.iter().cloned().collect();
661
662 for i in 0..n_test {
663 let mut max_sim = -1.0;
664 let mut best_idx = 0;
665
666 for (j, data_point) in current_data.iter().enumerate() {
668 let sim = self.compute_similarity(&X.row(i), &data_point.view());
669 if sim > max_sim {
670 max_sim = sim;
671 best_idx = j;
672 }
673 }
674
675 if best_idx < self.state.label_distributions.nrows() {
677 for k in 0..n_classes {
678 probas[[i, k]] = self.state.label_distributions[[best_idx, k]];
679 }
680 }
681 }
682
683 Ok(probas)
684 }
685}
686
687#[derive(Debug, Clone)]
689pub struct StreamingGraphLearningTrained {
690 pub X_train: Array2<f64>,
692 pub y_train: Array1<i32>,
694 pub classes: Array1<i32>,
696 pub current_graph: Array2<f64>,
698 pub label_distributions: Array2<f64>,
700 pub data_window: VecDeque<Array1<f64>>,
702 pub label_window: VecDeque<i32>,
704 pub update_count: usize,
706 pub edge_ages: HashMap<(usize, usize), usize>,
708 pub adaptive_threshold_value: f64,
710}
711
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::array;

    #[test]
    fn test_streaming_graph_learning_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];
        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .lambda_sparse(0.1)
            .alpha_decay(0.9)
            .update_frequency(5);
        let fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        let predictions = fitted.predict(&X.view()).unwrap();
        assert_eq!(predictions.len(), 4);

        let probas = fitted.predict_proba(&X.view()).unwrap();
        assert_eq!(probas.dim(), (4, 2));

        // Labeled samples are their own nearest neighbours, so they keep their labels.
        assert_eq!(predictions[0], 0);
        assert_eq!(predictions[1], 1);
    }

    #[test]
    fn test_streaming_graph_learning_update() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .update_frequency(3)
            .alpha_decay(0.95);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        assert_eq!(fitted.state.current_graph.dim(), (4, 4));

        let X_new = array![[5.0, 6.0], [6.0, 7.0]];
        let y_new = array![-1, 0];
        fitted.update(&X_new.view(), &y_new.view()).unwrap();

        assert_eq!(fitted.state.data_window.len(), 6);
        assert_eq!(fitted.state.label_window.len(), 6);
        assert_eq!(fitted.state.current_graph.dim(), (6, 6));

        let predictions = fitted.predict(&X_new.view()).unwrap();
        assert_eq!(predictions.len(), 2);
    }

    #[test]
    fn test_streaming_graph_learning_window_overflow() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];

        let sgl = StreamingGraphLearning::new()
            .window_size(3)
            .update_frequency(2);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        let X_new1 = array![[3.0, 4.0]];
        let y_new1 = array![-1];
        fitted.update(&X_new1.view(), &y_new1.view()).unwrap();

        let X_new2 = array![[4.0, 5.0]];
        let y_new2 = array![0];
        fitted.update(&X_new2.view(), &y_new2.view()).unwrap();

        // The window is capped at `window_size`.
        assert_eq!(fitted.state.data_window.len(), 3);
        assert_eq!(fitted.state.label_window.len(), 3);

        let predictions = fitted.predict(&X_new2.view()).unwrap();
        assert_eq!(predictions.len(), 1);
    }

    #[test]
    fn test_streaming_graph_learning_adaptive_threshold() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![0, 1, -1, -1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .adaptive_threshold(true)
            .similarity_threshold(0.5);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        let initial_threshold = fitted.state.adaptive_threshold_value;

        let X_new = array![[10.0, 20.0], [20.0, 30.0]];
        let y_new = array![-1, 1];
        fitted.update(&X_new.view(), &y_new.view()).unwrap();

        // The two distant points drag the median pairwise similarity far below the
        // initial 0.5, so the adapted threshold must shrink but stay positive.
        assert!(fitted.state.adaptive_threshold_value > 0.0);
        assert!(fitted.state.adaptive_threshold_value < initial_threshold);
    }

    #[test]
    fn test_streaming_graph_learning_edge_aging() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .forgetting_factor(0.8)
            .alpha_decay(0.9);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        assert_eq!(fitted.state.update_count, 0);

        for i in 0..3 {
            let X_new = array![[3.0 + i as f64, 4.0 + i as f64]];
            let y_new = array![-1];
            fitted.update(&X_new.view(), &y_new.view()).unwrap();
        }

        assert_eq!(fitted.state.update_count, 3);
        // Incrementally added edges are tracked for forgetting.
        assert!(!fitted.state.edge_ages.is_empty());
    }

    #[test]
    fn test_streaming_graph_learning_full_reconstruction() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![0, 1];

        let sgl = StreamingGraphLearning::new()
            .window_size(10)
            .update_frequency(2);
        let mut fitted = sgl.fit(&X.view(), &y.view()).unwrap();

        let X_new1 = array![[3.0, 4.0]];
        let y_new1 = array![-1];
        fitted.update(&X_new1.view(), &y_new1.view()).unwrap();

        let X_new2 = array![[4.0, 5.0]];
        let y_new2 = array![0];
        fitted.update(&X_new2.view(), &y_new2.view()).unwrap();

        // The second update reaches `update_frequency`, so the graph was rebuilt and
        // edge ages were reset.
        assert!(
            fitted.state.edge_ages.is_empty()
                || fitted.state.edge_ages.values().all(|&age| age == 0)
        );

        let predictions = fitted.predict(&X_new2.view()).unwrap();
        assert_eq!(predictions.len(), 1);
    }
}