1use crate::error::{ClusterError, ClusterResult};
14use crate::traits::{ClusteringAlgorithm, ClusteringResult, Fit, FitPredict};
15use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
16use scirs2_core::random::{seeded_rng, CoreRandom};
17use scirs2_core::random::rngs::StdRng;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22use torsh_tensor::Tensor;
23
/// A clustering algorithm that can be updated incrementally as new data
/// arrives, one point or one batch at a time (streaming setting).
pub trait IncrementalClustering {
    /// Snapshot type describing the current clustering state.
    type Result: ClusteringResult;

    /// Incorporate a single data point into the model.
    fn update_single(&mut self, point: &Tensor) -> ClusterResult<()>;

    /// Incorporate a 2D batch of points (rows are samples).
    fn update_batch(&mut self, batch: &Tensor) -> ClusterResult<()>;

    /// Return a snapshot of the current clustering state.
    fn get_current_result(&self) -> ClusterResult<Self::Result>;

    /// Discard all learned state and start over.
    fn reset(&mut self);

    /// Heuristic check for concept drift in the incoming stream.
    fn detect_drift(&self) -> bool;

    /// Total number of points processed since creation or last reset.
    fn n_points_seen(&self) -> usize;
}
46
/// Configuration for the online (streaming) k-means model.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OnlineKMeansConfig {
    // Number of clusters to maintain.
    pub n_clusters: usize,
    // Fixed learning rate; `None` selects the adaptive 1/count schedule.
    pub learning_rate: Option<f64>,
    // NOTE(review): stored and reported via get_params but never applied by
    // the update rule in this file — confirm intended use.
    pub decay_factor: f64,
    // Lower bound for the adaptive learning rate.
    pub min_learning_rate: f64,
    // Relative increase of recent vs. historical average assignment distance
    // that counts as drift.
    pub drift_threshold: f64,
    // RNG seed; `None` seeds from the current UNIX time.
    pub random_state: Option<u64>,
    // Number of recent distances kept for drift detection.
    pub drift_window_size: usize,
}
66
67impl Default for OnlineKMeansConfig {
68 fn default() -> Self {
69 Self {
70 n_clusters: 8,
71 learning_rate: None, decay_factor: 0.9,
73 min_learning_rate: 1e-6,
74 drift_threshold: 0.1,
75 random_state: None,
76 drift_window_size: 1000,
77 }
78 }
79}
80
/// Snapshot of an online k-means model.
#[derive(Debug, Clone)]
pub struct OnlineKMeansResult {
    // Centroid matrix, one row per cluster.
    pub centroids: Tensor,
    // Per-point labels; always `None` for online results because past points
    // are not retained.
    pub labels: Option<Tensor>,
    // Number of points assigned to each cluster so far.
    pub cluster_counts: Vec<usize>,
    // Total points processed.
    pub n_points_seen: usize,
    // Learning rate used for the most recent centroid update.
    pub current_learning_rate: f64,
    // Whether drift was flagged at snapshot time.
    pub drift_detected: bool,
    // Mean point-to-assigned-centroid distance over the drift window.
    pub avg_intra_cluster_distance: f64,
}
99
100impl ClusteringResult for OnlineKMeansResult {
101 fn labels(&self) -> &Tensor {
102 self.labels
103 .as_ref()
104 .unwrap_or_else(|| panic!("Labels not available for online clustering result"))
105 }
106
107 fn n_clusters(&self) -> usize {
108 self.centroids.shape().dims()[0]
109 }
110
111 fn centers(&self) -> Option<&Tensor> {
112 Some(&self.centroids)
113 }
114
115 fn converged(&self) -> bool {
116 self.n_points_seen > 100 }
118
119 fn n_iter(&self) -> Option<usize> {
120 Some(self.n_points_seen)
121 }
122
123 fn metadata(&self) -> Option<&HashMap<String, String>> {
124 None
125 }
126}
127
/// Streaming k-means with stochastic per-point centroid updates
/// (c <- c + lr * (x - c)).
#[derive(Debug)]
pub struct OnlineKMeans {
    // Algorithm configuration.
    config: OnlineKMeansConfig,
    // Current centroids (rows), lazily initialized on the first update.
    centroids: Option<Array2<f64>>,
    // Points assigned to each cluster so far (drives the 1/count schedule).
    cluster_counts: Vec<usize>,
    // Total points processed.
    n_points_seen: usize,
    // Learning rate used for the most recent update.
    current_learning_rate: f64,
    // Recent point-to-centroid distances, bounded by drift_window_size.
    drift_history: VecDeque<f64>,
    // Seeded RNG used for random centroid initialization.
    rng: CoreRandom<StdRng>,
    // Dimensionality, fixed by the first point seen.
    n_features: Option<usize>,
}
166
167impl OnlineKMeans {
168 pub fn new(n_clusters: usize) -> ClusterResult<Self> {
170 let config = OnlineKMeansConfig {
171 n_clusters,
172 ..Default::default()
173 };
174
175 let seed = config.random_state.unwrap_or_else(|| {
176 use std::time::{SystemTime, UNIX_EPOCH};
177 SystemTime::now()
178 .duration_since(UNIX_EPOCH)
179 .expect("system time should be after UNIX_EPOCH")
180 .as_secs()
181 });
182 let rng = seeded_rng(seed);
183
184 Ok(Self {
185 config,
186 centroids: None,
187 cluster_counts: vec![0; n_clusters],
188 n_points_seen: 0,
189 current_learning_rate: 1.0,
190 drift_history: VecDeque::with_capacity(1000),
191 rng,
192 n_features: None,
193 })
194 }
195
196 pub fn learning_rate(mut self, learning_rate: Option<f64>) -> Self {
198 self.config.learning_rate = learning_rate;
199 self
200 }
201
202 pub fn drift_threshold(mut self, threshold: f64) -> Self {
204 self.config.drift_threshold = threshold;
205 self
206 }
207
208 pub fn random_state(mut self, seed: u64) -> Self {
210 self.config.random_state = Some(seed);
211 self.rng = seeded_rng(seed);
212 self
213 }
214
215 fn initialize_centroids(&mut self, n_features: usize) -> ClusterResult<()> {
217 if self.centroids.is_none() {
218 self.n_features = Some(n_features);
219
220 let mut centroids = Array2::<f64>::zeros((self.config.n_clusters, n_features));
222 for i in 0..self.config.n_clusters {
223 for j in 0..n_features {
224 centroids[[i, j]] = self.rng.gen_range(-1.0..1.0);
225 }
226 }
227
228 self.centroids = Some(centroids);
229 }
230
231 Ok(())
232 }
233
234 fn find_closest_centroid(&self, point: &ArrayView1<f64>) -> ClusterResult<usize> {
236 let centroids = self
237 .centroids
238 .as_ref()
239 .ok_or_else(|| ClusterError::ConfigError("Centroids not initialized".to_string()))?;
240
241 let mut min_distance = f64::INFINITY;
242 let mut closest_centroid = 0;
243
244 for (i, centroid) in centroids.outer_iter().enumerate() {
245 let distance = self.compute_distance(point, ¢roid)?;
246 if distance < min_distance {
247 min_distance = distance;
248 closest_centroid = i;
249 }
250 }
251
252 Ok(closest_centroid)
253 }
254
255 fn compute_distance(
257 &self,
258 point1: &ArrayView1<f64>,
259 point2: &ArrayView1<f64>,
260 ) -> ClusterResult<f64> {
261 let diff = point1 - point2;
262 let distance = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
263 Ok(distance)
264 }
265
266 fn update_centroid(&mut self, cluster_id: usize, point: &ArrayView1<f64>) -> ClusterResult<()> {
268 let centroids = self
269 .centroids
270 .as_mut()
271 .ok_or_else(|| ClusterError::ConfigError("Centroids not initialized".to_string()))?;
272
273 self.cluster_counts[cluster_id] += 1;
274 let count = self.cluster_counts[cluster_id] as f64;
275
276 let lr = if let Some(fixed_lr) = self.config.learning_rate {
278 fixed_lr
279 } else {
280 (1.0 / count).max(self.config.min_learning_rate)
282 };
283
284 self.current_learning_rate = lr;
285
286 let mut centroid = centroids.row_mut(cluster_id);
288 for (i, &point_val) in point.iter().enumerate() {
289 let current_val = centroid[i];
290 centroid[i] = current_val + lr * (point_val - current_val);
291 }
292
293 Ok(())
294 }
295
296 fn update_drift_detection(
298 &mut self,
299 point: &ArrayView1<f64>,
300 cluster_id: usize,
301 ) -> ClusterResult<()> {
302 let centroids = self
303 .centroids
304 .as_ref()
305 .ok_or_else(|| ClusterError::ConfigError("Centroids not initialized".to_string()))?;
306
307 let centroid = centroids.row(cluster_id);
308 let distance = self.compute_distance(point, ¢roid)?;
309
310 self.drift_history.push_back(distance);
312 if self.drift_history.len() > self.config.drift_window_size {
313 self.drift_history.pop_front();
314 }
315
316 Ok(())
317 }
318
319 fn tensor_to_array1(&self, tensor: &Tensor) -> ClusterResult<Array1<f64>> {
321 let tensor_shape = tensor.shape();
322 let shape = tensor_shape.dims();
323 if shape.len() != 1 && (shape.len() != 2 || shape[0] != 1) {
324 return Err(ClusterError::InvalidInput(
325 "Expected 1D tensor or single-row 2D tensor".to_string(),
326 ));
327 }
328
329 let data_f32: Vec<f32> = tensor.to_vec().map_err(ClusterError::TensorError)?;
330 let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
331
332 let n_features = if shape.len() == 1 { shape[0] } else { shape[1] };
333 Array1::from_vec(data)
334 .to_shape(n_features)
335 .map(|array| array.into_owned())
336 .map_err(|_| ClusterError::InvalidInput("Failed to reshape tensor".to_string()))
337 }
338
339 fn array2_to_tensor(&self, array: &Array2<f64>) -> ClusterResult<Tensor> {
341 let (rows, cols) = array.dim();
342 let data_f64: Vec<f64> = array.iter().copied().collect();
343 let data: Vec<f32> = data_f64.into_iter().map(|x| x as f32).collect();
344 Tensor::from_vec(data, &[rows, cols]).map_err(ClusterError::TensorError)
345 }
346}
347
348impl IncrementalClustering for OnlineKMeans {
349 type Result = OnlineKMeansResult;
350
351 fn update_single(&mut self, point: &Tensor) -> ClusterResult<()> {
352 let point_array = self.tensor_to_array1(point)?;
353 let n_features = point_array.len();
354
355 self.initialize_centroids(n_features)?;
357
358 let closest_centroid = self.find_closest_centroid(&point_array.view())?;
360
361 self.update_centroid(closest_centroid, &point_array.view())?;
363
364 self.update_drift_detection(&point_array.view(), closest_centroid)?;
366
367 self.n_points_seen += 1;
368
369 Ok(())
370 }
371
372 fn update_batch(&mut self, batch: &Tensor) -> ClusterResult<()> {
373 let batch_shape = batch.shape();
374 let shape = batch_shape.dims();
375 if shape.len() != 2 {
376 return Err(ClusterError::InvalidInput(
377 "Expected 2D batch tensor".to_string(),
378 ));
379 }
380
381 let n_samples = shape[0];
382 let n_features = shape[1];
383
384 self.initialize_centroids(n_features)?;
386
387 let data_f32: Vec<f32> = batch.to_vec().map_err(ClusterError::TensorError)?;
388 let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
389 let data_array = Array2::from_shape_vec((n_samples, n_features), data)
390 .map_err(|_| ClusterError::InvalidInput("Failed to reshape batch data".to_string()))?;
391
392 for i in 0..n_samples {
394 let point = data_array.row(i);
395 let closest_centroid = self.find_closest_centroid(&point)?;
396 self.update_centroid(closest_centroid, &point)?;
397 self.update_drift_detection(&point, closest_centroid)?;
398 self.n_points_seen += 1;
399 }
400
401 Ok(())
402 }
403
404 fn get_current_result(&self) -> ClusterResult<Self::Result> {
405 let centroids = self
406 .centroids
407 .as_ref()
408 .ok_or_else(|| ClusterError::ConfigError("No data processed yet".to_string()))?;
409
410 let centroids_tensor = self.array2_to_tensor(centroids)?;
411
412 let avg_distance = if self.drift_history.is_empty() {
414 0.0
415 } else {
416 self.drift_history.iter().sum::<f64>() / self.drift_history.len() as f64
417 };
418
419 Ok(OnlineKMeansResult {
420 centroids: centroids_tensor,
421 labels: None, cluster_counts: self.cluster_counts.clone(),
423 n_points_seen: self.n_points_seen,
424 current_learning_rate: self.current_learning_rate,
425 drift_detected: self.detect_drift(),
426 avg_intra_cluster_distance: avg_distance,
427 })
428 }
429
430 fn reset(&mut self) {
431 self.centroids = None;
432 self.cluster_counts = vec![0; self.config.n_clusters];
433 self.n_points_seen = 0;
434 self.current_learning_rate = 1.0;
435 self.drift_history.clear();
436 self.n_features = None;
437 }
438
439 fn detect_drift(&self) -> bool {
440 if self.drift_history.len() < self.config.drift_window_size / 2 {
441 return false;
442 }
443
444 let recent_window = self.drift_history.len() / 2;
446 let recent_avg: f64 = self
447 .drift_history
448 .iter()
449 .rev()
450 .take(recent_window)
451 .sum::<f64>()
452 / recent_window as f64;
453 let historical_avg: f64 =
454 self.drift_history.iter().take(recent_window).sum::<f64>() / recent_window as f64;
455
456 recent_avg > historical_avg * (1.0 + self.config.drift_threshold)
458 }
459
460 fn n_points_seen(&self) -> usize {
461 self.n_points_seen
462 }
463}
464
465impl ClusteringAlgorithm for OnlineKMeans {
466 fn name(&self) -> &str {
467 "Online K-Means"
468 }
469
470 fn get_params(&self) -> HashMap<String, String> {
471 let mut params = HashMap::new();
472 params.insert("n_clusters".to_string(), self.config.n_clusters.to_string());
473 params.insert(
474 "drift_threshold".to_string(),
475 self.config.drift_threshold.to_string(),
476 );
477 params.insert(
478 "decay_factor".to_string(),
479 self.config.decay_factor.to_string(),
480 );
481 if let Some(lr) = self.config.learning_rate {
482 params.insert("learning_rate".to_string(), lr.to_string());
483 }
484 params
485 }
486
487 fn set_params(&mut self, params: HashMap<String, String>) -> ClusterResult<()> {
488 for (key, value) in params {
489 match key.as_str() {
490 "n_clusters" => {
491 let n_clusters = value.parse().map_err(|_| {
492 ClusterError::ConfigError(format!("Invalid n_clusters: {}", value))
493 })?;
494 self.config.n_clusters = n_clusters;
495 self.cluster_counts = vec![0; n_clusters];
496 }
497 "drift_threshold" => {
498 self.config.drift_threshold = value.parse().map_err(|_| {
499 ClusterError::ConfigError(format!("Invalid drift_threshold: {}", value))
500 })?;
501 }
502 "learning_rate" => {
503 if value == "adaptive" {
504 self.config.learning_rate = None;
505 } else {
506 self.config.learning_rate = Some(value.parse().map_err(|_| {
507 ClusterError::ConfigError(format!("Invalid learning_rate: {}", value))
508 })?);
509 }
510 }
511 _ => {
512 return Err(ClusterError::ConfigError(format!(
513 "Unknown parameter: {}",
514 key
515 )));
516 }
517 }
518 }
519 Ok(())
520 }
521
522 fn is_fitted(&self) -> bool {
523 self.centroids.is_some()
524 }
525}
526
527impl Fit for OnlineKMeans {
528 type Result = OnlineKMeansResult;
529
530 fn fit(&self, data: &Tensor) -> ClusterResult<Self::Result> {
531 let mut online_kmeans = self.clone();
532 online_kmeans.update_batch(data)?;
533 online_kmeans.get_current_result()
534 }
535}
536
impl FitPredict for OnlineKMeans {
    type Result = OnlineKMeansResult;

    /// Fit on `data` and return the result. Identical to `fit` because the
    /// online result carries no per-point labels to "predict".
    fn fit_predict(&self, data: &Tensor) -> ClusterResult<Self::Result> {
        self.fit(data)
    }
}
544
impl Clone for OnlineKMeans {
    /// Manual `Clone` implementation.
    ///
    /// NOTE(review): the RNG stream is NOT carried over — the copy gets a
    /// fresh RNG seeded from `random_state`, falling back to the fixed seed
    /// 42 when no seed was configured. Clones of an unseeded model therefore
    /// do not reproduce the original's random sequence; confirm this is
    /// acceptable for callers of `fit` (which clones internally).
    fn clone(&self) -> Self {
        let rng = seeded_rng(self.config.random_state.unwrap_or(42));

        Self {
            config: self.config.clone(),
            centroids: self.centroids.clone(),
            cluster_counts: self.cluster_counts.clone(),
            n_points_seen: self.n_points_seen,
            current_learning_rate: self.current_learning_rate,
            drift_history: self.drift_history.clone(),
            rng,
            n_features: self.n_features,
        }
    }
}
562
/// Configuration for sliding-window k-means.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SlidingWindowConfig {
    // Number of clusters to fit.
    pub n_clusters: usize,
    // Maximum number of recent points kept in the window.
    pub window_size: usize,
    // Recompute centroids after this many new points.
    pub recompute_frequency: usize,
    // RNG seed; `None` seeds from the current UNIX time.
    pub random_state: Option<u64>,
    // Maximum Lloyd iterations per recomputation.
    pub max_iters: usize,
    // Early-stop threshold on the maximum centroid shift per iteration.
    pub tolerance: f64,
}
580
581impl Default for SlidingWindowConfig {
582 fn default() -> Self {
583 Self {
584 n_clusters: 8,
585 window_size: 1000,
586 recompute_frequency: 100,
587 random_state: None,
588 max_iters: 10,
589 tolerance: 1e-4,
590 }
591 }
592}
593
/// Snapshot of a sliding-window k-means model.
#[derive(Debug, Clone)]
pub struct SlidingWindowResult {
    // Centroid matrix, one row per cluster.
    pub centroids: Tensor,
    // Label (stored as an f32 index) for each point currently in the window.
    pub labels: Tensor,
    // Number of windowed points assigned to each cluster.
    pub cluster_counts: Vec<usize>,
    // Total points processed over the stream's lifetime.
    pub n_points_seen: usize,
    // Number of points currently held in the window.
    pub window_fill: usize,
    // How many times centroids have been recomputed.
    pub n_recomputations: usize,
}
610
611impl ClusteringResult for SlidingWindowResult {
612 fn labels(&self) -> &Tensor {
613 &self.labels
614 }
615
616 fn n_clusters(&self) -> usize {
617 self.centroids.shape().dims()[0]
618 }
619
620 fn centers(&self) -> Option<&Tensor> {
621 Some(&self.centroids)
622 }
623
624 fn converged(&self) -> bool {
625 self.n_points_seen > 100 }
627
628 fn n_iter(&self) -> Option<usize> {
629 Some(self.n_recomputations)
630 }
631
632 fn metadata(&self) -> Option<&HashMap<String, String>> {
633 None
634 }
635}
636
/// K-means over a sliding window of the most recent points: the window is a
/// FIFO buffer and centroids are periodically refit with bounded Lloyd
/// iterations.
#[derive(Debug)]
pub struct SlidingWindowKMeans {
    // Algorithm configuration.
    config: SlidingWindowConfig,
    // FIFO buffer of the most recent points (capped at window_size).
    window: VecDeque<Array1<f64>>,
    // Current centroids (rows); `None` until the first recomputation.
    centroids: Option<Array2<f64>>,
    // Total points processed.
    n_points_seen: usize,
    // Number of full centroid recomputations performed.
    n_recomputations: usize,
    // Points accepted since the last recomputation.
    points_since_recompute: usize,
    // Seeded RNG used for k-means++ style initialization.
    rng: CoreRandom<StdRng>,
    // Dimensionality, fixed by the first point seen.
    n_features: Option<usize>,
}
722
723impl SlidingWindowKMeans {
724 pub fn new(config: SlidingWindowConfig) -> ClusterResult<Self> {
726 let seed = config.random_state.unwrap_or_else(|| {
727 use std::time::{SystemTime, UNIX_EPOCH};
728 SystemTime::now()
729 .duration_since(UNIX_EPOCH)
730 .expect("system time should be after UNIX_EPOCH")
731 .as_secs()
732 });
733 let rng = seeded_rng(seed);
734
735 Ok(Self {
736 config,
737 window: VecDeque::with_capacity(1000),
738 centroids: None,
739 n_points_seen: 0,
740 n_recomputations: 0,
741 points_since_recompute: 0,
742 rng,
743 n_features: None,
744 })
745 }
746
747 pub fn with_params(n_clusters: usize, window_size: usize) -> ClusterResult<Self> {
749 let config = SlidingWindowConfig {
750 n_clusters,
751 window_size,
752 ..Default::default()
753 };
754 Self::new(config)
755 }
756
757 pub fn window_size(mut self, size: usize) -> Self {
759 self.config.window_size = size;
760 self
761 }
762
763 pub fn recompute_frequency(mut self, frequency: usize) -> Self {
765 self.config.recompute_frequency = frequency;
766 self
767 }
768
769 fn initialize_centroids(&mut self) -> ClusterResult<()> {
771 if self.window.is_empty() {
772 return Err(ClusterError::ConfigError(
773 "Cannot initialize centroids from empty window".to_string(),
774 ));
775 }
776
777 let n_features = self.window[0].len();
778 self.n_features = Some(n_features);
779
780 let n_points = self.window.len();
781 let k = self.config.n_clusters.min(n_points);
782
783 let mut window_array = Array2::<f64>::zeros((n_points, n_features));
785 for (i, point) in self.window.iter().enumerate() {
786 for (j, &val) in point.iter().enumerate() {
787 window_array[[i, j]] = val;
788 }
789 }
790
791 let mut centroids = Array2::<f64>::zeros((k, n_features));
793
794 let first_idx = self.rng.gen_range(0..n_points);
796 centroids.row_mut(0).assign(&window_array.row(first_idx));
797
798 for i in 1..k {
800 let mut distances = vec![f64::INFINITY; n_points];
802 for (point_idx, point) in window_array.outer_iter().enumerate() {
803 let mut min_dist = f64::INFINITY;
804 for centroid in centroids.outer_iter().take(i) {
805 let dist = self.euclidean_distance(&point, ¢roid);
806 min_dist = min_dist.min(dist);
807 }
808 distances[point_idx] = min_dist;
809 }
810
811 let sum_sq_dist: f64 = distances.iter().map(|d| d * d).sum();
813 let mut target = self.rng.gen_range(0.0..sum_sq_dist);
814
815 let mut chosen_idx = 0;
816 for (idx, &dist) in distances.iter().enumerate() {
817 target -= dist * dist;
818 if target <= 0.0 {
819 chosen_idx = idx;
820 break;
821 }
822 }
823
824 centroids.row_mut(i).assign(&window_array.row(chosen_idx));
825 }
826
827 self.centroids = Some(centroids);
828 Ok(())
829 }
830
831 fn recompute_centroids(&mut self) -> ClusterResult<()> {
833 if self.window.is_empty() {
834 return Ok(());
835 }
836
837 if self.centroids.is_none() {
839 self.initialize_centroids()?;
840 }
841
842 let n_points = self.window.len();
843 let n_features = self.window[0].len();
844 let k = self.config.n_clusters.min(n_points);
845
846 let mut window_array = Array2::<f64>::zeros((n_points, n_features));
848 for (i, point) in self.window.iter().enumerate() {
849 for (j, &val) in point.iter().enumerate() {
850 window_array[[i, j]] = val;
851 }
852 }
853
854 let mut centroids = self
855 .centroids
856 .clone()
857 .expect("centroids should be initialized before recomputation");
858
859 for _iter in 0..self.config.max_iters {
861 let old_centroids = centroids.clone();
862
863 let mut labels = vec![0usize; n_points];
865 for (i, point) in window_array.outer_iter().enumerate() {
866 let mut min_dist = f64::INFINITY;
867 let mut closest = 0;
868 for (j, centroid) in centroids.outer_iter().enumerate() {
869 let dist = self.euclidean_distance(&point, ¢roid);
870 if dist < min_dist {
871 min_dist = dist;
872 closest = j;
873 }
874 }
875 labels[i] = closest;
876 }
877
878 centroids.fill(0.0);
880 let mut counts = vec![0usize; k];
881
882 for (i, &label) in labels.iter().enumerate() {
883 for (j, &val) in window_array.row(i).iter().enumerate() {
884 centroids[[label, j]] += val;
885 }
886 counts[label] += 1;
887 }
888
889 for i in 0..k {
890 if counts[i] > 0 {
891 for j in 0..n_features {
892 centroids[[i, j]] /= counts[i] as f64;
893 }
894 }
895 }
896
897 let mut max_shift: f64 = 0.0;
899 for (old_row, new_row) in old_centroids.outer_iter().zip(centroids.outer_iter()) {
900 let shift = self.euclidean_distance(&old_row, &new_row);
901 max_shift = max_shift.max(shift);
902 }
903
904 if max_shift < self.config.tolerance {
905 break;
906 }
907 }
908
909 self.centroids = Some(centroids);
910 self.n_recomputations += 1;
911 self.points_since_recompute = 0;
912
913 Ok(())
914 }
915
916 fn euclidean_distance(&self, p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
918 let mut sum_sq = 0.0;
919 for (a, b) in p1.iter().zip(p2.iter()) {
920 let diff = a - b;
921 sum_sq += diff * diff;
922 }
923 sum_sq.sqrt()
924 }
925
926 fn tensor_to_array1(&self, tensor: &Tensor) -> ClusterResult<Array1<f64>> {
928 let tensor_shape = tensor.shape();
929 let shape = tensor_shape.dims();
930 if shape.len() != 1 && (shape.len() != 2 || shape[0] != 1) {
931 return Err(ClusterError::InvalidInput(
932 "Expected 1D tensor or single-row 2D tensor".to_string(),
933 ));
934 }
935
936 let data_f32: Vec<f32> = tensor.to_vec().map_err(ClusterError::TensorError)?;
937 let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
938
939 let n_features = if shape.len() == 1 { shape[0] } else { shape[1] };
940 Array1::from_vec(data)
941 .to_shape(n_features)
942 .map(|array| array.into_owned())
943 .map_err(|_| ClusterError::InvalidInput("Failed to reshape tensor".to_string()))
944 }
945
946 fn array2_to_tensor(&self, array: &Array2<f64>) -> ClusterResult<Tensor> {
948 let (rows, cols) = array.dim();
949 let data_f64: Vec<f64> = array.iter().copied().collect();
950 let data: Vec<f32> = data_f64.into_iter().map(|x| x as f32).collect();
951 Tensor::from_vec(data, &[rows, cols]).map_err(ClusterError::TensorError)
952 }
953
954 fn vec_to_tensor(&self, data: Vec<f64>, shape: &[usize]) -> ClusterResult<Tensor> {
956 let data_f32: Vec<f32> = data.into_iter().map(|x| x as f32).collect();
957 Tensor::from_vec(data_f32, shape).map_err(ClusterError::TensorError)
958 }
959}
960
961impl IncrementalClustering for SlidingWindowKMeans {
962 type Result = SlidingWindowResult;
963
964 fn update_single(&mut self, point: &Tensor) -> ClusterResult<()> {
965 let point_array = self.tensor_to_array1(point)?;
966
967 if self.n_features.is_none() {
969 self.n_features = Some(point_array.len());
970 }
971
972 self.window.push_back(point_array);
974
975 if self.window.len() > self.config.window_size {
977 self.window.pop_front();
978 }
979
980 self.n_points_seen += 1;
981 self.points_since_recompute += 1;
982
983 if self.points_since_recompute >= self.config.recompute_frequency
985 || self.centroids.is_none()
986 {
987 self.recompute_centroids()?;
988 }
989
990 Ok(())
991 }
992
993 fn update_batch(&mut self, batch: &Tensor) -> ClusterResult<()> {
994 let batch_shape = batch.shape();
995 let shape = batch_shape.dims();
996 if shape.len() != 2 {
997 return Err(ClusterError::InvalidInput(
998 "Expected 2D batch tensor".to_string(),
999 ));
1000 }
1001
1002 let n_samples = shape[0];
1003 let n_features = shape[1];
1004
1005 if self.n_features.is_none() {
1006 self.n_features = Some(n_features);
1007 }
1008
1009 let data_f32: Vec<f32> = batch.to_vec().map_err(ClusterError::TensorError)?;
1010 let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
1011 let data_array = Array2::from_shape_vec((n_samples, n_features), data)
1012 .map_err(|_| ClusterError::InvalidInput("Failed to reshape batch data".to_string()))?;
1013
1014 for row in data_array.outer_iter() {
1015 let point_array = row.to_owned();
1016 self.window.push_back(point_array);
1017
1018 if self.window.len() > self.config.window_size {
1019 self.window.pop_front();
1020 }
1021
1022 self.n_points_seen += 1;
1023 self.points_since_recompute += 1;
1024 }
1025
1026 if self.points_since_recompute >= self.config.recompute_frequency
1028 || self.centroids.is_none()
1029 {
1030 self.recompute_centroids()?;
1031 }
1032
1033 Ok(())
1034 }
1035
1036 fn get_current_result(&self) -> ClusterResult<Self::Result> {
1037 let centroids = self
1038 .centroids
1039 .as_ref()
1040 .ok_or_else(|| ClusterError::ConfigError("No data processed yet".to_string()))?;
1041
1042 let centroids_tensor = self.array2_to_tensor(centroids)?;
1043
1044 let mut labels = Vec::with_capacity(self.window.len());
1046 let mut cluster_counts = vec![0usize; self.config.n_clusters];
1047
1048 for point in &self.window {
1049 let mut min_dist = f64::INFINITY;
1050 let mut closest = 0;
1051 for (i, centroid) in centroids.outer_iter().enumerate() {
1052 let dist = self.euclidean_distance(&point.view(), ¢roid);
1053 if dist < min_dist {
1054 min_dist = dist;
1055 closest = i;
1056 }
1057 }
1058 labels.push(closest as f64);
1059 cluster_counts[closest] += 1;
1060 }
1061
1062 let labels_tensor = self.vec_to_tensor(labels, &[self.window.len()])?;
1063
1064 Ok(SlidingWindowResult {
1065 centroids: centroids_tensor,
1066 labels: labels_tensor,
1067 cluster_counts,
1068 n_points_seen: self.n_points_seen,
1069 window_fill: self.window.len(),
1070 n_recomputations: self.n_recomputations,
1071 })
1072 }
1073
1074 fn reset(&mut self) {
1075 self.window.clear();
1076 self.centroids = None;
1077 self.n_points_seen = 0;
1078 self.n_recomputations = 0;
1079 self.points_since_recompute = 0;
1080 self.n_features = None;
1081 }
1082
1083 fn detect_drift(&self) -> bool {
1084 self.n_recomputations > 10 && self.n_points_seen / self.n_recomputations.max(1) < 50
1087 }
1088
1089 fn n_points_seen(&self) -> usize {
1090 self.n_points_seen
1091 }
1092}
1093
#[cfg(test)]
mod tests {
    use super::*;

    // Two well-separated groups fed point-by-point; checks only bookkeeping
    // (cluster count, points seen, centroid shape), not centroid positions.
    #[test]
    fn test_online_kmeans_basic() -> ClusterResult<()> {
        let mut online_kmeans = OnlineKMeans::new(2)?;

        for i in 0..10 {
            let point = if i < 5 {
                Tensor::from_vec(vec![0.0 + i as f32 * 0.1, 0.0], &[2])?
            } else {
                Tensor::from_vec(vec![5.0 + (i - 5) as f32 * 0.1, 5.0], &[2])?
            };

            online_kmeans.update_single(&point)?;
        }

        let result = online_kmeans.get_current_result()?;
        assert_eq!(result.n_clusters(), 2);
        assert_eq!(result.n_points_seen, 10);
        assert!(result.centroids.shape().dims() == &[2, 2]);

        Ok(())
    }

    // Same idea delivered as one 4x2 batch.
    #[test]
    fn test_online_kmeans_batch() -> ClusterResult<()> {
        let mut online_kmeans = OnlineKMeans::new(2)?;

        let batch = Tensor::from_vec(vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1], &[4, 2])?;

        online_kmeans.update_batch(&batch)?;

        let result = online_kmeans.get_current_result()?;
        assert_eq!(result.n_clusters(), 2);
        assert_eq!(result.n_points_seen, 4);

        Ok(())
    }

    // A slow ramp followed by a sudden distribution jump. Only the point
    // count is asserted — the drift flag itself is heuristic.
    #[test]
    fn test_drift_detection() -> ClusterResult<()> {
        let mut online_kmeans = OnlineKMeans::new(2)?.drift_threshold(0.1);

        for i in 0..100 {
            let point = Tensor::from_vec(vec![i as f32 * 0.01, 0.0], &[2])?;
            online_kmeans.update_single(&point)?;
        }

        let _initial_drift = online_kmeans.detect_drift();

        for i in 0..50 {
            let point = Tensor::from_vec(vec![100.0 + i as f32, 100.0], &[2])?;
            online_kmeans.update_single(&point)?;
        }

        let final_result = online_kmeans.get_current_result()?;
        assert!(final_result.n_points_seen == 150);

        Ok(())
    }

    // Alternating points from two groups; the window (50) is half the stream
    // (100), so window_fill must equal window_size afterwards.
    #[test]
    fn test_sliding_window_basic() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 50,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        for i in 0..100 {
            let point = if i % 2 == 0 {
                Tensor::from_vec(vec![0.0 + (i as f32) * 0.01, 0.0], &[2])?
            } else {
                Tensor::from_vec(vec![10.0 + (i as f32) * 0.01, 10.0], &[2])?
            };

            sliding.update_single(&point)?;
        }

        let result = sliding.get_current_result()?;
        // k is capped by the number of points seen at initialization time, so
        // anywhere in [1, 2] is accepted.
        assert!(result.n_clusters() >= 1);
        assert!(result.n_clusters() <= 2);
        assert_eq!(result.window_fill, 50);
        assert_eq!(result.n_points_seen, 100);
        assert!(result.n_recomputations > 0);

        Ok(())
    }

    // Batch smaller than the window: nothing should be evicted.
    #[test]
    fn test_sliding_window_batch() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 20,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        let batch = Tensor::from_vec(
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2, 5.3, 5.3,
            ],
            &[8, 2],
        )?;

        sliding.update_batch(&batch)?;

        let result = sliding.get_current_result()?;
        assert_eq!(result.n_clusters(), 2);
        assert_eq!(result.window_fill, 8);
        assert_eq!(result.n_points_seen, 8);

        Ok(())
    }

    // 20 points into a 10-point window: the first 10 must be evicted and the
    // labels tensor must cover only the retained points.
    #[test]
    fn test_sliding_window_expiration() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 10,
            recompute_frequency: 5,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        for i in 0..20 {
            let point = Tensor::from_vec(vec![i as f32, 0.0], &[2])?;
            sliding.update_single(&point)?;
        }

        let result = sliding.get_current_result()?;

        assert_eq!(result.window_fill, 10);
        assert_eq!(result.n_points_seen, 20);

        assert_eq!(result.labels.shape().dims()[0], 10);

        Ok(())
    }

    // 50 points at recompute_frequency 10 should trigger roughly five
    // recomputations (the first fires immediately when centroids are None).
    #[test]
    fn test_sliding_window_recomputation() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 50,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        for i in 0..50 {
            let point = Tensor::from_vec(vec![i as f32 * 0.1, 0.0], &[2])?;
            sliding.update_single(&point)?;
        }

        let result = sliding.get_current_result()?;

        assert!(result.n_recomputations >= 4);
        assert!(result.n_recomputations <= 6);

        Ok(())
    }

    // reset() must zero the counters and leave the model reusable.
    #[test]
    fn test_sliding_window_reset() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 20,
            recompute_frequency: 5,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        for i in 0..10 {
            let point = Tensor::from_vec(vec![i as f32, 0.0], &[2])?;
            sliding.update_single(&point)?;
        }

        sliding.reset();

        assert_eq!(sliding.n_points_seen(), 0);

        let point = Tensor::from_vec(vec![1.0, 1.0], &[2])?;
        sliding.update_single(&point)?;

        assert_eq!(sliding.n_points_seen(), 1);

        Ok(())
    }

    // After the input distribution moves far away, at least one centroid
    // coordinate should shift by more than 1.0.
    #[test]
    fn test_sliding_window_drift_adaptation() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 30,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        for i in 0..20 {
            let point = if i < 10 {
                Tensor::from_vec(vec![i as f32 * 0.1, 0.0], &[2])?
            } else {
                Tensor::from_vec(vec![5.0 + (i - 10) as f32 * 0.1, 5.0], &[2])?
            };
            sliding.update_single(&point)?;
        }

        let result1 = sliding.get_current_result()?;
        let centroids1 = result1
            .centroids
            .to_vec()
            .expect("centroids conversion should succeed");

        for i in 0..30 {
            let point = if i < 15 {
                Tensor::from_vec(vec![10.0 + i as f32 * 0.1, 10.0], &[2])?
            } else {
                Tensor::from_vec(vec![15.0 + (i - 15) as f32 * 0.1, 15.0], &[2])?
            };
            sliding.update_single(&point)?;
        }

        let result2 = sliding.get_current_result()?;
        let centroids2 = result2
            .centroids
            .to_vec()
            .expect("centroids conversion should succeed");

        let mut changed = false;
        for i in 0..centroids1.len().min(centroids2.len()) {
            if (centroids1[i] - centroids2[i]).abs() > 1.0 {
                changed = true;
                break;
            }
        }

        assert!(changed, "Centroids should adapt to distribution shift");

        Ok(())
    }
}