1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
25use scirs2_core::random::Random;
26use sklears_core::{
27 error::{Result, SklearsError},
28 traits::{Estimator, Fit, Predict, Trained, Untrained},
29 types::Float,
30};
31use std::collections::VecDeque;
32use std::marker::PhantomData;
33
34#[derive(Debug, Clone)]
36pub struct StreamingConfig {
37 pub max_clusters: usize,
39 pub learning_rate: Float,
41 pub decay_factor: Float,
43 pub window_size: usize,
45 pub creation_threshold: Float,
47 pub merge_threshold: Float,
49 pub min_weight: Float,
51 pub random_state: Option<u64>,
53 pub update_frequency: usize,
55}
56
57impl Default for StreamingConfig {
58 fn default() -> Self {
59 Self {
60 max_clusters: 100,
61 learning_rate: 0.1,
62 decay_factor: 0.95,
63 window_size: 1000,
64 creation_threshold: 1.0,
65 merge_threshold: 0.5,
66 min_weight: 0.01,
67 random_state: None,
68 update_frequency: 10,
69 }
70 }
71}
72
73#[derive(Debug, Clone)]
75pub struct MicroCluster {
76 pub centroid: Array1<Float>,
78 pub weight: Float,
80 pub sum_squared: Float,
82 pub creation_time: usize,
84 pub last_update: usize,
86 pub radius: Float,
88}
89
90impl MicroCluster {
91 pub fn new(point: &ArrayView1<Float>, timestamp: usize) -> Self {
93 Self {
94 centroid: point.to_owned(),
95 weight: 1.0,
96 sum_squared: 0.0,
97 creation_time: timestamp,
98 last_update: timestamp,
99 radius: 0.0,
100 }
101 }
102
103 pub fn update(&mut self, point: &ArrayView1<Float>, timestamp: usize, learning_rate: Float) {
105 let distance = self.distance_to_centroid(point);
106
107 let weight_factor = learning_rate / (self.weight + 1.0);
109 let diff = point - &self.centroid;
110 self.centroid = &self.centroid + &(&diff * weight_factor);
111
112 self.weight += 1.0;
114 self.sum_squared += distance * distance;
115 self.last_update = timestamp;
116
117 self.radius = (self.sum_squared / self.weight.max(1.0)).sqrt();
119 }
120
121 pub fn distance_to_centroid(&self, point: &ArrayView1<Float>) -> Float {
123 let diff = point - &self.centroid;
124 diff.dot(&diff).sqrt()
125 }
126
127 pub fn decay(&mut self, decay_factor: Float) {
129 self.weight *= decay_factor;
130 self.sum_squared *= decay_factor;
131 }
132
133 pub fn should_remove(&self, min_weight: Float) -> bool {
135 self.weight < min_weight
136 }
137
138 pub fn density(&self) -> Float {
140 if self.radius > 0.0 {
141 self.weight / (self.radius * self.radius)
142 } else {
143 self.weight
144 }
145 }
146}
147
148pub struct OnlineKMeans<State = Untrained> {
150 config: StreamingConfig,
151 state: PhantomData<State>,
152 centroids: Option<Array2<Float>>,
154 weights: Option<Array1<Float>>,
155 n_updates: usize,
156 timestamp: usize,
157}
158
159impl<State> OnlineKMeans<State> {
160 pub fn new() -> Self {
162 Self {
163 config: StreamingConfig::default(),
164 state: PhantomData,
165 centroids: None,
166 weights: None,
167 n_updates: 0,
168 timestamp: 0,
169 }
170 }
171
172 pub fn max_clusters(mut self, max_clusters: usize) -> Self {
174 self.config.max_clusters = max_clusters;
175 self
176 }
177
178 pub fn learning_rate(mut self, learning_rate: Float) -> Self {
180 self.config.learning_rate = learning_rate;
181 self
182 }
183
184 pub fn decay_factor(mut self, decay_factor: Float) -> Self {
186 self.config.decay_factor = decay_factor;
187 self
188 }
189
190 pub fn random_state(mut self, seed: u64) -> Self {
192 self.config.random_state = Some(seed);
193 self
194 }
195}
196
197impl OnlineKMeans<Trained> {
198 pub fn partial_fit(&mut self, point: &ArrayView1<Float>) -> Result<()> {
200 if let (Some(ref mut centroids), Some(ref mut weights)) =
201 (&mut self.centroids, &mut self.weights)
202 {
203 let mut min_distance = Float::INFINITY;
205 let mut closest_idx = 0;
206
207 for (i, centroid) in centroids.outer_iter().enumerate() {
208 let diff = point - ¢roid;
209 let distance = diff.dot(&diff).sqrt();
210 if distance < min_distance {
211 min_distance = distance;
212 closest_idx = i;
213 }
214 }
215
216 let lr = self.config.learning_rate / (weights[closest_idx] + 1.0);
218 let diff = point - ¢roids.row(closest_idx);
219 let mut new_centroid = centroids.row(closest_idx).to_owned();
220 new_centroid = &new_centroid + &(&diff * lr);
221 centroids.row_mut(closest_idx).assign(&new_centroid);
222
223 weights[closest_idx] += 1.0;
225
226 for weight in weights.iter_mut() {
228 *weight *= self.config.decay_factor;
229 }
230
231 self.n_updates += 1;
232 self.timestamp += 1;
233
234 Ok(())
235 } else {
236 Err(SklearsError::NotFitted {
237 operation: "partial_fit".to_string(),
238 })
239 }
240 }
241
242 pub fn centroids(&self) -> Result<&Array2<Float>> {
244 self.centroids
245 .as_ref()
246 .ok_or_else(|| SklearsError::NotFitted {
247 operation: "centroids".to_string(),
248 })
249 }
250
251 pub fn weights(&self) -> Result<&Array1<Float>> {
253 self.weights
254 .as_ref()
255 .ok_or_else(|| SklearsError::NotFitted {
256 operation: "weights".to_string(),
257 })
258 }
259}
260
261impl<State> Default for OnlineKMeans<State> {
262 fn default() -> Self {
263 Self::new()
264 }
265}
266
267impl<State> Estimator<State> for OnlineKMeans<State> {
268 type Config = StreamingConfig;
269 type Error = SklearsError;
270 type Float = Float;
271
272 fn config(&self) -> &Self::Config {
273 &self.config
274 }
275}
276
277impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, usize>> for OnlineKMeans<Untrained> {
278 type Fitted = OnlineKMeans<Trained>;
279
280 fn fit(self, x: &ArrayView2<Float>, _y: &ArrayView1<usize>) -> Result<Self::Fitted> {
281 let (n_samples, n_features) = x.dim();
282 let k = self.config.max_clusters.min(n_samples);
283
284 if n_samples == 0 || n_features == 0 {
285 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
286 }
287
288 let mut rng = match self.config.random_state {
289 Some(seed) => Random::default(),
290 None => Random::default(),
291 };
292
293 let mut centroids = Array2::zeros((k, n_features));
295 let weights = Array1::ones(k);
296
297 for i in 0..k {
298 let idx = rng.gen_range(0..n_samples);
299 centroids.row_mut(i).assign(&x.row(idx));
300 }
301
302 Ok(OnlineKMeans {
303 config: self.config,
304 state: PhantomData,
305 centroids: Some(centroids),
306 weights: Some(weights),
307 n_updates: 0,
308 timestamp: 0,
309 })
310 }
311}
312
313impl Predict<ArrayView2<'_, Float>, Array1<usize>> for OnlineKMeans<Trained> {
314 fn predict(&self, x: &ArrayView2<Float>) -> Result<Array1<usize>> {
315 let centroids = self.centroids()?;
316 let mut labels = Array1::zeros(x.nrows());
317
318 for (i, sample) in x.outer_iter().enumerate() {
319 let mut min_distance = Float::INFINITY;
320 let mut best_cluster = 0;
321
322 for (j, centroid) in centroids.outer_iter().enumerate() {
323 let diff = &sample - ¢roid;
324 let distance = diff.dot(&diff).sqrt();
325 if distance < min_distance {
326 min_distance = distance;
327 best_cluster = j;
328 }
329 }
330
331 labels[i] = best_cluster;
332 }
333
334 Ok(labels)
335 }
336}
337
338pub struct CluStream<State = Untrained> {
340 config: StreamingConfig,
341 state: PhantomData<State>,
342 micro_clusters: Option<Vec<MicroCluster>>,
344 macro_clusters: Option<Array2<Float>>,
345 timestamp: usize,
346 update_counter: usize,
347}
348
349impl<State> CluStream<State> {
350 pub fn new() -> Self {
352 Self {
353 config: StreamingConfig::default(),
354 state: PhantomData,
355 micro_clusters: None,
356 macro_clusters: None,
357 timestamp: 0,
358 update_counter: 0,
359 }
360 }
361
362 pub fn max_clusters(mut self, max_clusters: usize) -> Self {
364 self.config.max_clusters = max_clusters;
365 self
366 }
367
368 pub fn creation_threshold(mut self, threshold: Float) -> Self {
370 self.config.creation_threshold = threshold;
371 self
372 }
373
374 pub fn merge_threshold(mut self, threshold: Float) -> Self {
376 self.config.merge_threshold = threshold;
377 self
378 }
379
380 pub fn decay_factor(mut self, decay_factor: Float) -> Self {
382 self.config.decay_factor = decay_factor;
383 self
384 }
385
386 pub fn update_frequency(mut self, frequency: usize) -> Self {
388 self.config.update_frequency = frequency;
389 self
390 }
391
392 pub fn random_state(mut self, seed: u64) -> Self {
394 self.config.random_state = Some(seed);
395 self
396 }
397}
398
399impl CluStream<Trained> {
400 pub fn partial_fit(&mut self, point: &ArrayView1<Float>) -> Result<()> {
402 if let Some(ref mut micro_clusters) = &mut self.micro_clusters {
403 self.timestamp += 1;
404
405 let mut min_distance = Float::INFINITY;
407 let mut closest_idx = None;
408
409 for (i, cluster) in micro_clusters.iter().enumerate() {
410 let distance = cluster.distance_to_centroid(point);
411 if distance < cluster.radius + self.config.creation_threshold
412 && distance < min_distance
413 {
414 min_distance = distance;
415 closest_idx = Some(i);
416 }
417 }
418
419 if let Some(idx) = closest_idx {
420 micro_clusters[idx].update(point, self.timestamp, self.config.learning_rate);
422 } else {
423 if micro_clusters.len() < self.config.max_clusters {
425 micro_clusters.push(MicroCluster::new(point, self.timestamp));
427 } else {
428 let mut min_merge_distance = Float::INFINITY;
430 let mut merge_indices = (0, 1);
431
432 for i in 0..micro_clusters.len() {
433 for j in (i + 1)..micro_clusters.len() {
434 let dist = micro_clusters[i]
435 .distance_to_centroid(µ_clusters[j].centroid.view());
436 if dist < min_merge_distance {
437 min_merge_distance = dist;
438 merge_indices = (i, j);
439 }
440 }
441 }
442
443 if min_merge_distance < self.config.merge_threshold {
445 let (i, j) = merge_indices;
446 let merged_centroid = (µ_clusters[i].centroid
447 * micro_clusters[i].weight
448 + µ_clusters[j].centroid * micro_clusters[j].weight)
449 / (micro_clusters[i].weight + micro_clusters[j].weight);
450 let merged_weight = micro_clusters[i].weight + micro_clusters[j].weight;
451
452 micro_clusters[i].centroid = merged_centroid;
453 micro_clusters[i].weight = merged_weight;
454 micro_clusters[i].last_update = self.timestamp;
455
456 micro_clusters.remove(j);
457
458 micro_clusters.push(MicroCluster::new(point, self.timestamp));
460 } else {
461 let mut oldest_idx = 0;
463 let mut oldest_time = micro_clusters[0].last_update;
464
465 for (i, cluster) in micro_clusters.iter().enumerate() {
466 if cluster.last_update < oldest_time {
467 oldest_time = cluster.last_update;
468 oldest_idx = i;
469 }
470 }
471
472 micro_clusters[oldest_idx] = MicroCluster::new(point, self.timestamp);
473 }
474 }
475 }
476
477 for cluster in micro_clusters.iter_mut() {
479 cluster.decay(self.config.decay_factor);
480 }
481
482 micro_clusters.retain(|cluster| !cluster.should_remove(self.config.min_weight));
484
485 self.update_counter += 1;
487 if self.update_counter % self.config.update_frequency == 0 {
488 self.update_macro_clusters()?;
489 }
490
491 Ok(())
492 } else {
493 Err(SklearsError::NotFitted {
494 operation: "partial_fit".to_string(),
495 })
496 }
497 }
498
499 fn update_macro_clusters(&mut self) -> Result<()> {
501 if let Some(ref micro_clusters) = &self.micro_clusters {
502 if micro_clusters.is_empty() {
503 return Ok(());
504 }
505
506 let n_features = micro_clusters[0].centroid.len();
507 let mut macro_centroids = Array2::zeros((micro_clusters.len(), n_features));
508
509 for (i, cluster) in micro_clusters.iter().enumerate() {
510 macro_centroids.row_mut(i).assign(&cluster.centroid);
511 }
512
513 self.macro_clusters = Some(macro_centroids);
514 }
515
516 Ok(())
517 }
518
519 pub fn micro_clusters(&self) -> Result<&Vec<MicroCluster>> {
521 self.micro_clusters
522 .as_ref()
523 .ok_or_else(|| SklearsError::NotFitted {
524 operation: "micro_clusters".to_string(),
525 })
526 }
527
528 pub fn macro_clusters(&self) -> Result<&Array2<Float>> {
530 self.macro_clusters
531 .as_ref()
532 .ok_or_else(|| SklearsError::NotFitted {
533 operation: "macro_clusters".to_string(),
534 })
535 }
536}
537
538impl<State> Default for CluStream<State> {
539 fn default() -> Self {
540 Self::new()
541 }
542}
543
544impl<State> Estimator<State> for CluStream<State> {
545 type Config = StreamingConfig;
546 type Error = SklearsError;
547 type Float = Float;
548
549 fn config(&self) -> &Self::Config {
550 &self.config
551 }
552}
553
554impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, usize>> for CluStream<Untrained> {
555 type Fitted = CluStream<Trained>;
556
557 fn fit(self, x: &ArrayView2<Float>, _y: &ArrayView1<usize>) -> Result<Self::Fitted> {
558 let (n_samples, n_features) = x.dim();
559
560 if n_samples == 0 || n_features == 0 {
561 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
562 }
563
564 let initial_clusters = self.config.max_clusters.min(n_samples);
566 let mut micro_clusters = Vec::with_capacity(initial_clusters);
567
568 for i in 0..initial_clusters {
569 micro_clusters.push(MicroCluster::new(&x.row(i), i));
570 }
571
572 Ok(CluStream {
573 config: self.config,
574 state: PhantomData,
575 micro_clusters: Some(micro_clusters),
576 macro_clusters: None,
577 timestamp: initial_clusters,
578 update_counter: 0,
579 })
580 }
581}
582
583impl Predict<ArrayView2<'_, Float>, Array1<usize>> for CluStream<Trained> {
584 fn predict(&self, x: &ArrayView2<Float>) -> Result<Array1<usize>> {
585 let micro_clusters = self.micro_clusters()?;
586 let mut labels = Array1::zeros(x.nrows());
587
588 for (i, sample) in x.outer_iter().enumerate() {
589 let mut min_distance = Float::INFINITY;
590 let mut best_cluster = 0;
591
592 for (j, cluster) in micro_clusters.iter().enumerate() {
593 let distance = cluster.distance_to_centroid(&sample);
594 if distance < min_distance {
595 min_distance = distance;
596 best_cluster = j;
597 }
598 }
599
600 labels[i] = best_cluster;
601 }
602
603 Ok(labels)
604 }
605}
606
607pub struct SlidingWindowKMeans<State = Untrained> {
609 config: StreamingConfig,
610 state: PhantomData<State>,
611 window_data: Option<VecDeque<Array1<Float>>>,
613 centroids: Option<Array2<Float>>,
614 timestamps: Option<VecDeque<usize>>,
615 current_time: usize,
616}
617
618impl<State> SlidingWindowKMeans<State> {
619 pub fn new() -> Self {
621 Self {
622 config: StreamingConfig::default(),
623 state: PhantomData,
624 window_data: None,
625 centroids: None,
626 timestamps: None,
627 current_time: 0,
628 }
629 }
630
631 pub fn window_size(mut self, window_size: usize) -> Self {
633 self.config.window_size = window_size;
634 self
635 }
636
637 pub fn max_clusters(mut self, max_clusters: usize) -> Self {
639 self.config.max_clusters = max_clusters;
640 self
641 }
642
643 pub fn random_state(mut self, seed: u64) -> Self {
645 self.config.random_state = Some(seed);
646 self
647 }
648}
649
650impl SlidingWindowKMeans<Trained> {
651 pub fn partial_fit(&mut self, point: &ArrayView1<Float>) -> Result<()> {
653 if let (Some(ref mut window_data), Some(ref mut timestamps)) =
654 (&mut self.window_data, &mut self.timestamps)
655 {
656 window_data.push_back(point.to_owned());
658 timestamps.push_back(self.current_time);
659
660 while window_data.len() > self.config.window_size {
662 window_data.pop_front();
663 timestamps.pop_front();
664 }
665
666 if window_data.len() >= self.config.max_clusters {
668 self.recompute_centroids()?;
669 }
670
671 self.current_time += 1;
672
673 Ok(())
674 } else {
675 Err(SklearsError::NotFitted {
676 operation: "partial_fit".to_string(),
677 })
678 }
679 }
680
681 fn recompute_centroids(&mut self) -> Result<()> {
683 if let Some(ref window_data) = &self.window_data {
684 if window_data.is_empty() {
685 return Ok(());
686 }
687
688 let n_features = window_data[0].len();
689 let k = self.config.max_clusters.min(window_data.len());
690
691 let mut centroids = Array2::zeros((k, n_features));
693 let mut counts = Array1::zeros(k);
694
695 for (i, point) in window_data.iter().take(k).enumerate() {
697 centroids.row_mut(i).assign(point);
698 }
699
700 for _ in 0..10 {
702 counts.fill(0.0);
704 let mut new_centroids = Array2::zeros((k, n_features));
705
706 for point in window_data.iter() {
707 let mut min_distance = Float::INFINITY;
709 let mut closest_idx = 0;
710
711 for (j, centroid) in centroids.outer_iter().enumerate() {
712 let diff = point - ¢roid;
713 let distance = diff.dot(&diff).sqrt();
714 if distance < min_distance {
715 min_distance = distance;
716 closest_idx = j;
717 }
718 }
719
720 let mut row = new_centroids.row_mut(closest_idx);
722 row += point;
723 counts[closest_idx] += 1.0;
724 }
725
726 for i in 0..k {
728 if counts[i] > 0.0 {
729 let mut row = new_centroids.row_mut(i);
730 row /= counts[i];
731 centroids.row_mut(i).assign(&row);
732 }
733 }
734 }
735
736 self.centroids = Some(centroids);
737 }
738
739 Ok(())
740 }
741
742 pub fn centroids(&self) -> Result<&Array2<Float>> {
744 self.centroids
745 .as_ref()
746 .ok_or_else(|| SklearsError::NotFitted {
747 operation: "centroids".to_string(),
748 })
749 }
750
751 pub fn current_window_size(&self) -> usize {
753 self.window_data.as_ref().map_or(0, |data| data.len())
754 }
755}
756
757impl<State> Default for SlidingWindowKMeans<State> {
758 fn default() -> Self {
759 Self::new()
760 }
761}
762
763impl<State> Estimator<State> for SlidingWindowKMeans<State> {
764 type Config = StreamingConfig;
765 type Error = SklearsError;
766 type Float = Float;
767
768 fn config(&self) -> &Self::Config {
769 &self.config
770 }
771}
772
773impl Fit<ArrayView2<'_, Float>, ArrayView1<'_, usize>> for SlidingWindowKMeans<Untrained> {
774 type Fitted = SlidingWindowKMeans<Trained>;
775
776 fn fit(self, x: &ArrayView2<Float>, _y: &ArrayView1<usize>) -> Result<Self::Fitted> {
777 let (n_samples, n_features) = x.dim();
778
779 if n_samples == 0 || n_features == 0 {
780 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
781 }
782
783 let window_size = self.config.window_size.min(n_samples);
785 let mut window_data = VecDeque::with_capacity(self.config.window_size);
786 let mut timestamps = VecDeque::with_capacity(self.config.window_size);
787
788 for (i, row) in x.outer_iter().take(window_size).enumerate() {
789 window_data.push_back(row.to_owned());
790 timestamps.push_back(i);
791 }
792
793 let k = self.config.max_clusters.min(window_size);
795 let mut centroids = Array2::zeros((k, n_features));
796 for i in 0..k {
797 centroids.row_mut(i).assign(&window_data[i]);
798 }
799
800 Ok(SlidingWindowKMeans {
801 config: self.config,
802 state: PhantomData,
803 window_data: Some(window_data),
804 centroids: Some(centroids),
805 timestamps: Some(timestamps),
806 current_time: window_size,
807 })
808 }
809}
810
811impl Predict<ArrayView2<'_, Float>, Array1<usize>> for SlidingWindowKMeans<Trained> {
812 fn predict(&self, x: &ArrayView2<Float>) -> Result<Array1<usize>> {
813 let centroids = self.centroids()?;
814 let mut labels = Array1::zeros(x.nrows());
815
816 for (i, sample) in x.outer_iter().enumerate() {
817 let mut min_distance = Float::INFINITY;
818 let mut best_cluster = 0;
819
820 for (j, centroid) in centroids.outer_iter().enumerate() {
821 let diff = &sample - ¢roid;
822 let distance = diff.dot(&diff).sqrt();
823 if distance < min_distance {
824 min_distance = distance;
825 best_cluster = j;
826 }
827 }
828
829 labels[i] = best_cluster;
830 }
831
832 Ok(labels)
833 }
834}
835
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Placeholder targets: the streaming estimators ignore `y`, but `fit` still takes one.
    fn unused_targets(n: usize) -> Array1<usize> {
        Array1::zeros(n)
    }

    #[test]
    fn test_micro_cluster_creation() {
        let seed_point = array![1.0, 2.0];
        let mc = MicroCluster::new(&seed_point.view(), 0);

        assert_eq!(mc.centroid, seed_point);
        assert_eq!(mc.weight, 1.0);
        assert_eq!(mc.creation_time, 0);
        assert_eq!(mc.last_update, 0);
    }

    #[test]
    fn test_micro_cluster_update() {
        let first = array![1.0, 2.0];
        let second = array![3.0, 4.0];
        let mut mc = MicroCluster::new(&first.view(), 0);

        mc.update(&second.view(), 1, 0.5);

        assert!(mc.weight > 1.0);
        assert_eq!(mc.last_update, 1);
        // The centroid must have moved towards `second` along both axes.
        assert!(mc.centroid[0] > 1.0);
        assert!(mc.centroid[1] > 2.0);
    }

    #[test]
    fn test_online_kmeans_fit() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1]];
        let targets = unused_targets(data.nrows());

        let fitted = OnlineKMeans::new()
            .max_clusters(2)
            .learning_rate(0.1)
            .random_state(42)
            .fit(&data.view(), &targets.view())
            .unwrap();

        assert!(fitted.centroids().is_ok());
        assert!(fitted.weights().is_ok());

        let centres = fitted.centroids().unwrap();
        assert_eq!(centres.dim(), (2, 2));
    }

    #[test]
    fn test_online_kmeans_partial_fit() {
        let data = array![[0.0, 0.0], [1.0, 1.0]];
        let targets = unused_targets(data.nrows());

        let mut fitted = OnlineKMeans::new()
            .max_clusters(2)
            .learning_rate(0.1)
            .random_state(42)
            .fit(&data.view(), &targets.view())
            .unwrap();

        let incoming = array![0.5, 0.5];
        fitted.partial_fit(&incoming.view()).unwrap();

        assert_eq!(fitted.centroids().unwrap().nrows(), 2);
    }

    #[test]
    fn test_online_kmeans_predict() {
        let data = array![[0.0, 0.0], [1.0, 1.0]];
        let targets = unused_targets(data.nrows());

        let fitted = OnlineKMeans::new()
            .max_clusters(2)
            .random_state(42)
            .fit(&data.view(), &targets.view())
            .unwrap();

        let queries = array![[0.1, 0.1], [0.9, 0.9]];
        let labels = fitted.predict(&queries.view()).unwrap();
        assert_eq!(labels.len(), 2);
    }

    #[test]
    fn test_clustream_fit() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1]];
        let targets = unused_targets(data.nrows());

        let fitted = CluStream::new()
            .max_clusters(3)
            .creation_threshold(0.5)
            .merge_threshold(0.3)
            .random_state(42)
            .fit(&data.view(), &targets.view())
            .unwrap();

        assert!(fitted.micro_clusters().is_ok());

        let summary = fitted.micro_clusters().unwrap();
        assert!(!summary.is_empty());
        assert!(summary.len() <= 3);
    }

    #[test]
    fn test_clustream_partial_fit() {
        let data = array![[0.0, 0.0], [1.0, 1.0]];
        let targets = unused_targets(data.nrows());

        let mut fitted = CluStream::new()
            .max_clusters(3)
            .creation_threshold(0.5)
            .random_state(42)
            .fit(&data.view(), &targets.view())
            .unwrap();

        for incoming in [array![0.2, 0.2], array![2.0, 2.0]] {
            fitted.partial_fit(&incoming.view()).unwrap();
        }

        assert!(!fitted.micro_clusters().unwrap().is_empty());
    }

    #[test]
    fn test_sliding_window_kmeans() {
        let data = array![[0.0, 0.0], [0.1, 0.1], [1.0, 1.0], [1.1, 1.1]];
        let targets = unused_targets(data.nrows());

        let mut fitted = SlidingWindowKMeans::new()
            .window_size(3)
            .max_clusters(2)
            .random_state(42)
            .fit(&data.view(), &targets.view())
            .unwrap();

        assert!(fitted.centroids().is_ok());
        assert_eq!(fitted.current_window_size(), 3);

        let incoming = array![2.0, 2.0];
        fitted.partial_fit(&incoming.view()).unwrap();

        // The window stays capped at its configured size.
        assert_eq!(fitted.current_window_size(), 3);
    }

    #[test]
    fn test_micro_cluster_decay() {
        let seed_point = array![1.0, 2.0];
        let mut mc = MicroCluster::new(&seed_point.view(), 0);

        let before = mc.weight;
        mc.decay(0.9);

        assert!(mc.weight < before);
        assert_relative_eq!(mc.weight, before * 0.9);
    }

    #[test]
    fn test_micro_cluster_should_remove() {
        let seed_point = array![1.0, 2.0];
        let mut mc = MicroCluster::new(&seed_point.view(), 0);

        mc.weight = 0.005;

        assert!(mc.should_remove(0.01));
        assert!(!mc.should_remove(0.001));
    }
}