// torsh_cluster/algorithms/incremental.rs
1//! Incremental and Online Clustering Algorithms
2//!
3//! This module provides streaming and incremental clustering algorithms that can
4//! process data points one at a time or in mini-batches, maintaining cluster
5//! models that adapt to concept drift and evolving data distributions.
6//!
7//! # Algorithms Included
8//!
9//! - **Online K-Means**: Incremental update of centroids with adaptive learning
10//! - **Sliding Window Clustering**: Maintains clusters over a temporal window
11//! - **Concept Drift Detection**: Detects changes in data distribution
12
13use crate::error::{ClusterError, ClusterResult};
14use crate::traits::{ClusteringAlgorithm, ClusteringResult, Fit, FitPredict};
15use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
16use scirs2_core::random::{seeded_rng, CoreRandom};
17// Using SciRS2 re-exported StdRng to avoid direct rand dependency (SciRS2 POLICY)
18use scirs2_core::random::rngs::StdRng;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21use std::collections::{HashMap, VecDeque};
22use torsh_tensor::Tensor;
23
/// Trait for incremental clustering algorithms that can process streaming data
/// one point or mini-batch at a time while keeping the model up to date.
pub trait IncrementalClustering {
    /// Concrete result type returned by [`Self::get_current_result`].
    type Result: ClusteringResult;

    /// Process a single data point and update the model in place
    fn update_single(&mut self, point: &Tensor) -> ClusterResult<()>;

    /// Process a batch of data points (one row per point)
    fn update_batch(&mut self, batch: &Tensor) -> ClusterResult<()>;

    /// Get a snapshot of the current clustering state
    fn get_current_result(&self) -> ClusterResult<Self::Result>;

    /// Reset the clustering model to its initial, unfitted state
    fn reset(&mut self);

    /// Check if concept drift is detected in the recent stream
    fn detect_drift(&self) -> bool;

    /// Get the number of points processed so far
    fn n_points_seen(&self) -> usize;
}
46
/// Online K-Means clustering configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OnlineKMeansConfig {
    /// Number of clusters
    pub n_clusters: usize,
    /// Fixed learning rate for centroid updates; `None` selects the adaptive
    /// 1/count schedule
    pub learning_rate: Option<f64>,
    /// Decay factor for learning rate adaptation
    /// NOTE(review): appears unused by the update logic in this module — confirm
    pub decay_factor: f64,
    /// Lower bound applied to the adaptive learning rate
    pub min_learning_rate: f64,
    /// Relative worsening of recent intra-cluster distance that counts as drift
    pub drift_threshold: f64,
    /// Random seed for reproducibility; `None` seeds from the system clock
    pub random_state: Option<u64>,
    /// Number of recent distances kept for drift detection
    pub drift_window_size: usize,
}
66
67impl Default for OnlineKMeansConfig {
68    fn default() -> Self {
69        Self {
70            n_clusters: 8,
71            learning_rate: None, // Adaptive
72            decay_factor: 0.9,
73            min_learning_rate: 1e-6,
74            drift_threshold: 0.1,
75            random_state: None,
76            drift_window_size: 1000,
77        }
78    }
79}
80
/// Online K-Means clustering result
#[derive(Debug, Clone)]
pub struct OnlineKMeansResult {
    /// Current cluster centroids (n_clusters x n_features)
    pub centroids: Tensor,
    /// Cluster assignments for recent points, if tracked (`None` for pure
    /// online updates, which do not keep per-point labels)
    pub labels: Option<Tensor>,
    /// Number of points assigned to each cluster
    pub cluster_counts: Vec<usize>,
    /// Total points processed
    pub n_points_seen: usize,
    /// Learning rate used for the most recent centroid update
    pub current_learning_rate: f64,
    /// Whether concept drift was detected
    pub drift_detected: bool,
    /// Average point-to-assigned-centroid distance over the drift window
    pub avg_intra_cluster_distance: f64,
}
99
100impl ClusteringResult for OnlineKMeansResult {
101    fn labels(&self) -> &Tensor {
102        self.labels
103            .as_ref()
104            .unwrap_or_else(|| panic!("Labels not available for online clustering result"))
105    }
106
107    fn n_clusters(&self) -> usize {
108        self.centroids.shape().dims()[0]
109    }
110
111    fn centers(&self) -> Option<&Tensor> {
112        Some(&self.centroids)
113    }
114
115    fn converged(&self) -> bool {
116        self.n_points_seen > 100 // Consider "converged" after processing enough points
117    }
118
119    fn n_iter(&self) -> Option<usize> {
120        Some(self.n_points_seen)
121    }
122
123    fn metadata(&self) -> Option<&HashMap<String, String>> {
124        None
125    }
126}
127
128/// Online K-Means clustering algorithm for streaming data
129///
130/// This implementation can process data points incrementally and adapt to
131/// concept drift in the data distribution.
132///
133/// # Example
134///
135/// ```rust
136/// use torsh_cluster::algorithms::incremental::{OnlineKMeans, IncrementalClustering};
137/// use torsh_tensor::creation::randn;
138///
139/// let mut online_kmeans = OnlineKMeans::new(3)?;
140///
141/// // Process streaming data points
142/// for i in 0..1000 {
143///     let point = randn::<f32>(&[1, 2])?;
144///     online_kmeans.update_single(&point)?;
145///
146///     if online_kmeans.detect_drift() {
147///         println!("Concept drift detected at point {}", i);
148///     }
149/// }
150///
151/// let result = online_kmeans.get_current_result()?;
152/// println!("Final centroids: {:?}", result.centroids);
153/// # Ok::<(), Box<dyn std::error::Error>>(())
154/// ```
#[derive(Debug)]
pub struct OnlineKMeans {
    /// Algorithm configuration.
    config: OnlineKMeansConfig,
    /// Centroid matrix (n_clusters x n_features); `None` until data is seen.
    centroids: Option<Array2<f64>>,
    /// Number of points assigned to each cluster so far.
    cluster_counts: Vec<usize>,
    /// Total number of points processed.
    n_points_seen: usize,
    /// Learning rate used for the most recent centroid update.
    current_learning_rate: f64,
    /// Recent point-to-centroid distances used for drift detection.
    drift_history: VecDeque<f64>,
    /// Seeded RNG for centroid initialization.
    rng: CoreRandom<StdRng>,
    /// Feature dimensionality, fixed by the first observed point.
    n_features: Option<usize>,
}
166
167impl OnlineKMeans {
168    /// Create a new Online K-Means algorithm
169    pub fn new(n_clusters: usize) -> ClusterResult<Self> {
170        let config = OnlineKMeansConfig {
171            n_clusters,
172            ..Default::default()
173        };
174
175        let seed = config.random_state.unwrap_or_else(|| {
176            use std::time::{SystemTime, UNIX_EPOCH};
177            SystemTime::now()
178                .duration_since(UNIX_EPOCH)
179                .expect("system time should be after UNIX_EPOCH")
180                .as_secs()
181        });
182        let rng = seeded_rng(seed);
183
184        Ok(Self {
185            config,
186            centroids: None,
187            cluster_counts: vec![0; n_clusters],
188            n_points_seen: 0,
189            current_learning_rate: 1.0,
190            drift_history: VecDeque::with_capacity(1000),
191            rng,
192            n_features: None,
193        })
194    }
195
196    /// Set learning rate (None for adaptive)
197    pub fn learning_rate(mut self, learning_rate: Option<f64>) -> Self {
198        self.config.learning_rate = learning_rate;
199        self
200    }
201
202    /// Set drift detection threshold
203    pub fn drift_threshold(mut self, threshold: f64) -> Self {
204        self.config.drift_threshold = threshold;
205        self
206    }
207
208    /// Set random seed
209    pub fn random_state(mut self, seed: u64) -> Self {
210        self.config.random_state = Some(seed);
211        self.rng = seeded_rng(seed);
212        self
213    }
214
215    /// Initialize centroids if not already done
216    fn initialize_centroids(&mut self, n_features: usize) -> ClusterResult<()> {
217        if self.centroids.is_none() {
218            self.n_features = Some(n_features);
219
220            // Initialize centroids randomly
221            let mut centroids = Array2::<f64>::zeros((self.config.n_clusters, n_features));
222            for i in 0..self.config.n_clusters {
223                for j in 0..n_features {
224                    centroids[[i, j]] = self.rng.gen_range(-1.0..1.0);
225                }
226            }
227
228            self.centroids = Some(centroids);
229        }
230
231        Ok(())
232    }
233
234    /// Find the closest centroid to a point
235    fn find_closest_centroid(&self, point: &ArrayView1<f64>) -> ClusterResult<usize> {
236        let centroids = self
237            .centroids
238            .as_ref()
239            .ok_or_else(|| ClusterError::ConfigError("Centroids not initialized".to_string()))?;
240
241        let mut min_distance = f64::INFINITY;
242        let mut closest_centroid = 0;
243
244        for (i, centroid) in centroids.outer_iter().enumerate() {
245            let distance = self.compute_distance(point, &centroid)?;
246            if distance < min_distance {
247                min_distance = distance;
248                closest_centroid = i;
249            }
250        }
251
252        Ok(closest_centroid)
253    }
254
255    /// Compute Euclidean distance between two points
256    fn compute_distance(
257        &self,
258        point1: &ArrayView1<f64>,
259        point2: &ArrayView1<f64>,
260    ) -> ClusterResult<f64> {
261        let diff = point1 - point2;
262        let distance = diff.iter().map(|x| x * x).sum::<f64>().sqrt();
263        Ok(distance)
264    }
265
266    /// Update centroid with new point using online learning
267    fn update_centroid(&mut self, cluster_id: usize, point: &ArrayView1<f64>) -> ClusterResult<()> {
268        let centroids = self
269            .centroids
270            .as_mut()
271            .ok_or_else(|| ClusterError::ConfigError("Centroids not initialized".to_string()))?;
272
273        self.cluster_counts[cluster_id] += 1;
274        let count = self.cluster_counts[cluster_id] as f64;
275
276        // Compute learning rate
277        let lr = if let Some(fixed_lr) = self.config.learning_rate {
278            fixed_lr
279        } else {
280            // Adaptive learning rate: 1/count
281            (1.0 / count).max(self.config.min_learning_rate)
282        };
283
284        self.current_learning_rate = lr;
285
286        // Update centroid: centroid = centroid + lr * (point - centroid)
287        let mut centroid = centroids.row_mut(cluster_id);
288        for (i, &point_val) in point.iter().enumerate() {
289            let current_val = centroid[i];
290            centroid[i] = current_val + lr * (point_val - current_val);
291        }
292
293        Ok(())
294    }
295
296    /// Detect concept drift based on recent clustering quality
297    fn update_drift_detection(
298        &mut self,
299        point: &ArrayView1<f64>,
300        cluster_id: usize,
301    ) -> ClusterResult<()> {
302        let centroids = self
303            .centroids
304            .as_ref()
305            .ok_or_else(|| ClusterError::ConfigError("Centroids not initialized".to_string()))?;
306
307        let centroid = centroids.row(cluster_id);
308        let distance = self.compute_distance(point, &centroid)?;
309
310        // Add to drift history
311        self.drift_history.push_back(distance);
312        if self.drift_history.len() > self.config.drift_window_size {
313            self.drift_history.pop_front();
314        }
315
316        Ok(())
317    }
318
319    /// Convert ndarray point to Array1
320    fn tensor_to_array1(&self, tensor: &Tensor) -> ClusterResult<Array1<f64>> {
321        let tensor_shape = tensor.shape();
322        let shape = tensor_shape.dims();
323        if shape.len() != 1 && (shape.len() != 2 || shape[0] != 1) {
324            return Err(ClusterError::InvalidInput(
325                "Expected 1D tensor or single-row 2D tensor".to_string(),
326            ));
327        }
328
329        let data_f32: Vec<f32> = tensor.to_vec().map_err(ClusterError::TensorError)?;
330        let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
331
332        let n_features = if shape.len() == 1 { shape[0] } else { shape[1] };
333        Array1::from_vec(data)
334            .to_shape(n_features)
335            .map(|array| array.into_owned())
336            .map_err(|_| ClusterError::InvalidInput("Failed to reshape tensor".to_string()))
337    }
338
339    /// Convert Array2 to Tensor
340    fn array2_to_tensor(&self, array: &Array2<f64>) -> ClusterResult<Tensor> {
341        let (rows, cols) = array.dim();
342        let data_f64: Vec<f64> = array.iter().copied().collect();
343        let data: Vec<f32> = data_f64.into_iter().map(|x| x as f32).collect();
344        Tensor::from_vec(data, &[rows, cols]).map_err(ClusterError::TensorError)
345    }
346}
347
348impl IncrementalClustering for OnlineKMeans {
349    type Result = OnlineKMeansResult;
350
351    fn update_single(&mut self, point: &Tensor) -> ClusterResult<()> {
352        let point_array = self.tensor_to_array1(point)?;
353        let n_features = point_array.len();
354
355        // Initialize centroids if this is the first point
356        self.initialize_centroids(n_features)?;
357
358        // Find closest centroid
359        let closest_centroid = self.find_closest_centroid(&point_array.view())?;
360
361        // Update centroid
362        self.update_centroid(closest_centroid, &point_array.view())?;
363
364        // Update drift detection
365        self.update_drift_detection(&point_array.view(), closest_centroid)?;
366
367        self.n_points_seen += 1;
368
369        Ok(())
370    }
371
372    fn update_batch(&mut self, batch: &Tensor) -> ClusterResult<()> {
373        let batch_shape = batch.shape();
374        let shape = batch_shape.dims();
375        if shape.len() != 2 {
376            return Err(ClusterError::InvalidInput(
377                "Expected 2D batch tensor".to_string(),
378            ));
379        }
380
381        let n_samples = shape[0];
382        let n_features = shape[1];
383
384        // Initialize centroids if this is the first batch
385        self.initialize_centroids(n_features)?;
386
387        let data_f32: Vec<f32> = batch.to_vec().map_err(ClusterError::TensorError)?;
388        let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
389        let data_array = Array2::from_shape_vec((n_samples, n_features), data)
390            .map_err(|_| ClusterError::InvalidInput("Failed to reshape batch data".to_string()))?;
391
392        // Process each point in the batch
393        for i in 0..n_samples {
394            let point = data_array.row(i);
395            let closest_centroid = self.find_closest_centroid(&point)?;
396            self.update_centroid(closest_centroid, &point)?;
397            self.update_drift_detection(&point, closest_centroid)?;
398            self.n_points_seen += 1;
399        }
400
401        Ok(())
402    }
403
404    fn get_current_result(&self) -> ClusterResult<Self::Result> {
405        let centroids = self
406            .centroids
407            .as_ref()
408            .ok_or_else(|| ClusterError::ConfigError("No data processed yet".to_string()))?;
409
410        let centroids_tensor = self.array2_to_tensor(centroids)?;
411
412        // Compute average intra-cluster distance for drift detection
413        let avg_distance = if self.drift_history.is_empty() {
414            0.0
415        } else {
416            self.drift_history.iter().sum::<f64>() / self.drift_history.len() as f64
417        };
418
419        Ok(OnlineKMeansResult {
420            centroids: centroids_tensor,
421            labels: None, // Not available for online clustering
422            cluster_counts: self.cluster_counts.clone(),
423            n_points_seen: self.n_points_seen,
424            current_learning_rate: self.current_learning_rate,
425            drift_detected: self.detect_drift(),
426            avg_intra_cluster_distance: avg_distance,
427        })
428    }
429
430    fn reset(&mut self) {
431        self.centroids = None;
432        self.cluster_counts = vec![0; self.config.n_clusters];
433        self.n_points_seen = 0;
434        self.current_learning_rate = 1.0;
435        self.drift_history.clear();
436        self.n_features = None;
437    }
438
439    fn detect_drift(&self) -> bool {
440        if self.drift_history.len() < self.config.drift_window_size / 2 {
441            return false;
442        }
443
444        // Simple drift detection: compare recent vs. historical performance
445        let recent_window = self.drift_history.len() / 2;
446        let recent_avg: f64 = self
447            .drift_history
448            .iter()
449            .rev()
450            .take(recent_window)
451            .sum::<f64>()
452            / recent_window as f64;
453        let historical_avg: f64 =
454            self.drift_history.iter().take(recent_window).sum::<f64>() / recent_window as f64;
455
456        // Drift detected if recent performance significantly worse
457        recent_avg > historical_avg * (1.0 + self.config.drift_threshold)
458    }
459
460    fn n_points_seen(&self) -> usize {
461        self.n_points_seen
462    }
463}
464
465impl ClusteringAlgorithm for OnlineKMeans {
466    fn name(&self) -> &str {
467        "Online K-Means"
468    }
469
470    fn get_params(&self) -> HashMap<String, String> {
471        let mut params = HashMap::new();
472        params.insert("n_clusters".to_string(), self.config.n_clusters.to_string());
473        params.insert(
474            "drift_threshold".to_string(),
475            self.config.drift_threshold.to_string(),
476        );
477        params.insert(
478            "decay_factor".to_string(),
479            self.config.decay_factor.to_string(),
480        );
481        if let Some(lr) = self.config.learning_rate {
482            params.insert("learning_rate".to_string(), lr.to_string());
483        }
484        params
485    }
486
487    fn set_params(&mut self, params: HashMap<String, String>) -> ClusterResult<()> {
488        for (key, value) in params {
489            match key.as_str() {
490                "n_clusters" => {
491                    let n_clusters = value.parse().map_err(|_| {
492                        ClusterError::ConfigError(format!("Invalid n_clusters: {}", value))
493                    })?;
494                    self.config.n_clusters = n_clusters;
495                    self.cluster_counts = vec![0; n_clusters];
496                }
497                "drift_threshold" => {
498                    self.config.drift_threshold = value.parse().map_err(|_| {
499                        ClusterError::ConfigError(format!("Invalid drift_threshold: {}", value))
500                    })?;
501                }
502                "learning_rate" => {
503                    if value == "adaptive" {
504                        self.config.learning_rate = None;
505                    } else {
506                        self.config.learning_rate = Some(value.parse().map_err(|_| {
507                            ClusterError::ConfigError(format!("Invalid learning_rate: {}", value))
508                        })?);
509                    }
510                }
511                _ => {
512                    return Err(ClusterError::ConfigError(format!(
513                        "Unknown parameter: {}",
514                        key
515                    )));
516                }
517            }
518        }
519        Ok(())
520    }
521
522    fn is_fitted(&self) -> bool {
523        self.centroids.is_some()
524    }
525}
526
527impl Fit for OnlineKMeans {
528    type Result = OnlineKMeansResult;
529
530    fn fit(&self, data: &Tensor) -> ClusterResult<Self::Result> {
531        let mut online_kmeans = self.clone();
532        online_kmeans.update_batch(data)?;
533        online_kmeans.get_current_result()
534    }
535}
536
537impl FitPredict for OnlineKMeans {
538    type Result = OnlineKMeansResult;
539
540    fn fit_predict(&self, data: &Tensor) -> ClusterResult<Self::Result> {
541        self.fit(data)
542    }
543}
544
// Manual Clone impl: the RNG is recreated from the configured seed rather than
// cloned (presumably because the RNG wrapper is not `Clone` — confirm against
// scirs2_core).
impl Clone for OnlineKMeans {
    fn clone(&self) -> Self {
        // NOTE(review): when `random_state` is `None` the clone is seeded with a
        // fixed 42, so it does NOT continue the original's RNG stream; clones of
        // an unseeded model will draw identical random initializations.
        let rng = seeded_rng(self.config.random_state.unwrap_or(42));

        Self {
            config: self.config.clone(),
            centroids: self.centroids.clone(),
            cluster_counts: self.cluster_counts.clone(),
            n_points_seen: self.n_points_seen,
            current_learning_rate: self.current_learning_rate,
            drift_history: self.drift_history.clone(),
            rng,
            n_features: self.n_features,
        }
    }
}
562
/// Sliding Window K-Means configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct SlidingWindowConfig {
    /// Number of clusters
    pub n_clusters: usize,
    /// Window size (number of most recent points retained)
    pub window_size: usize,
    /// How often to recompute centroids, in number of points
    pub recompute_frequency: usize,
    /// Random seed for reproducibility; `None` seeds from the system clock
    pub random_state: Option<u64>,
    /// Maximum Lloyd iterations per centroid recomputation
    pub max_iters: usize,
    /// Convergence tolerance (maximum centroid shift) for recomputation
    pub tolerance: f64,
}
580
581impl Default for SlidingWindowConfig {
582    fn default() -> Self {
583        Self {
584            n_clusters: 8,
585            window_size: 1000,
586            recompute_frequency: 100,
587            random_state: None,
588            max_iters: 10,
589            tolerance: 1e-4,
590        }
591    }
592}
593
/// Sliding Window K-Means result
#[derive(Debug, Clone)]
pub struct SlidingWindowResult {
    /// Current cluster centroids (n_clusters x n_features)
    pub centroids: Tensor,
    /// Labels for the points currently held in the window
    pub labels: Tensor,
    /// Number of points in each cluster (from the current window only)
    pub cluster_counts: Vec<usize>,
    /// Total points processed, including ones that have expired from the window
    pub n_points_seen: usize,
    /// Number of points currently in the window
    pub window_fill: usize,
    /// Number of times centroids have been recomputed
    pub n_recomputations: usize,
}
610
611impl ClusteringResult for SlidingWindowResult {
612    fn labels(&self) -> &Tensor {
613        &self.labels
614    }
615
616    fn n_clusters(&self) -> usize {
617        self.centroids.shape().dims()[0]
618    }
619
620    fn centers(&self) -> Option<&Tensor> {
621        Some(&self.centroids)
622    }
623
624    fn converged(&self) -> bool {
625        self.n_points_seen > 100 // Consider converged after processing enough points
626    }
627
628    fn n_iter(&self) -> Option<usize> {
629        Some(self.n_recomputations)
630    }
631
632    fn metadata(&self) -> Option<&HashMap<String, String>> {
633        None
634    }
635}
636
637/// Sliding Window K-Means clustering for non-stationary streams
638///
639/// Maintains a fixed-size window of recent data points and periodically
640/// recomputes centroids from this window. Old points automatically expire
641/// when the window is full.
642///
643/// # Mathematical Foundation
644///
645/// Unlike Online K-Means which updates centroids incrementally, Sliding Window
646/// K-Means maintains explicit storage of recent points:
647///
648/// ```text
649/// Window W(t) = {x_{t-w+1}, x_{t-w+2}, ..., x_t}
650/// ```
651///
652/// When a new point x_{t+1} arrives and window is full:
653/// 1. Remove oldest point x_{t-w+1}
654/// 2. Add new point x_{t+1}
655/// 3. If recomputation triggered, run full K-Means on W(t+1)
656///
657/// # Advantages over Online K-Means
658///
659/// - **Adapts to drift**: Old data is discarded, preventing outdated patterns
660/// - **Full optimization**: Periodic recomputation finds better centroids
661/// - **Stable clusters**: Less sensitive to individual outliers
662///
663/// # Disadvantages
664///
665/// - **Memory usage**: O(window_size × n_features)
666/// - **Computation cost**: Periodic full K-Means on window
667/// - **Latency spikes**: Recomputation can cause delays
668///
669/// # Parameters
670///
671/// - **window_size**: Number of recent points to maintain (default: 1000)
672/// - **recompute_frequency**: Recompute centroids every N points (default: 100)
673/// - **n_clusters**: Number of clusters to find
674///
675/// # Example
676///
677/// ```rust
678/// use torsh_cluster::algorithms::incremental::{
679///     SlidingWindowKMeans, IncrementalClustering, SlidingWindowConfig
680/// };
681/// use torsh_tensor::Tensor;
682///
683/// let config = SlidingWindowConfig {
684///     n_clusters: 3,
685///     window_size: 500,
686///     recompute_frequency: 50,
687///     ..Default::default()
688/// };
689///
690/// let mut sliding_window = SlidingWindowKMeans::new(config)?;
691///
692/// // Process streaming data
693/// for i in 0..1000 {
694///     let point = Tensor::from_vec(vec![(i % 10) as f32, (i / 10) as f32], &[2])?;
695///     sliding_window.update_single(&point)?;
696///
697///     if i % 100 == 0 {
698///         let result = sliding_window.get_current_result()?;
699///         println!("Iteration {}: {} points in window", i, result.window_fill);
700///     }
701/// }
702/// # Ok::<(), Box<dyn std::error::Error>>(())
703/// ```
#[derive(Debug)]
pub struct SlidingWindowKMeans {
    /// Algorithm configuration.
    config: SlidingWindowConfig,
    /// Sliding window of recent points (oldest at the front).
    window: VecDeque<Array1<f64>>,
    /// Current centroids; `None` until the first (re)computation.
    centroids: Option<Array2<f64>>,
    /// Number of points processed, including ones that have expired.
    n_points_seen: usize,
    /// Number of centroid recomputations performed.
    n_recomputations: usize,
    /// Points received since the last recomputation.
    points_since_recompute: usize,
    /// Seeded RNG for K-means++ initialization.
    rng: CoreRandom<StdRng>,
    /// Feature dimensionality, fixed by the first observed point.
    n_features: Option<usize>,
}
722
723impl SlidingWindowKMeans {
724    /// Create a new Sliding Window K-Means algorithm
725    pub fn new(config: SlidingWindowConfig) -> ClusterResult<Self> {
726        let seed = config.random_state.unwrap_or_else(|| {
727            use std::time::{SystemTime, UNIX_EPOCH};
728            SystemTime::now()
729                .duration_since(UNIX_EPOCH)
730                .expect("system time should be after UNIX_EPOCH")
731                .as_secs()
732        });
733        let rng = seeded_rng(seed);
734
735        Ok(Self {
736            config,
737            window: VecDeque::with_capacity(1000),
738            centroids: None,
739            n_points_seen: 0,
740            n_recomputations: 0,
741            points_since_recompute: 0,
742            rng,
743            n_features: None,
744        })
745    }
746
747    /// Create with default config and specified parameters
748    pub fn with_params(n_clusters: usize, window_size: usize) -> ClusterResult<Self> {
749        let config = SlidingWindowConfig {
750            n_clusters,
751            window_size,
752            ..Default::default()
753        };
754        Self::new(config)
755    }
756
757    /// Set window size
758    pub fn window_size(mut self, size: usize) -> Self {
759        self.config.window_size = size;
760        self
761    }
762
763    /// Set recompute frequency
764    pub fn recompute_frequency(mut self, frequency: usize) -> Self {
765        self.config.recompute_frequency = frequency;
766        self
767    }
768
769    /// Initialize centroids using K-means++ on current window
770    fn initialize_centroids(&mut self) -> ClusterResult<()> {
771        if self.window.is_empty() {
772            return Err(ClusterError::ConfigError(
773                "Cannot initialize centroids from empty window".to_string(),
774            ));
775        }
776
777        let n_features = self.window[0].len();
778        self.n_features = Some(n_features);
779
780        let n_points = self.window.len();
781        let k = self.config.n_clusters.min(n_points);
782
783        // Convert window to Array2
784        let mut window_array = Array2::<f64>::zeros((n_points, n_features));
785        for (i, point) in self.window.iter().enumerate() {
786            for (j, &val) in point.iter().enumerate() {
787                window_array[[i, j]] = val;
788            }
789        }
790
791        // K-means++ initialization
792        let mut centroids = Array2::<f64>::zeros((k, n_features));
793
794        // Choose first centroid randomly
795        let first_idx = self.rng.gen_range(0..n_points);
796        centroids.row_mut(0).assign(&window_array.row(first_idx));
797
798        // Choose remaining centroids
799        for i in 1..k {
800            // Compute distances to nearest centroid
801            let mut distances = vec![f64::INFINITY; n_points];
802            for (point_idx, point) in window_array.outer_iter().enumerate() {
803                let mut min_dist = f64::INFINITY;
804                for centroid in centroids.outer_iter().take(i) {
805                    let dist = self.euclidean_distance(&point, &centroid);
806                    min_dist = min_dist.min(dist);
807                }
808                distances[point_idx] = min_dist;
809            }
810
811            // Choose next centroid with probability proportional to distance²
812            let sum_sq_dist: f64 = distances.iter().map(|d| d * d).sum();
813            let mut target = self.rng.gen_range(0.0..sum_sq_dist);
814
815            let mut chosen_idx = 0;
816            for (idx, &dist) in distances.iter().enumerate() {
817                target -= dist * dist;
818                if target <= 0.0 {
819                    chosen_idx = idx;
820                    break;
821                }
822            }
823
824            centroids.row_mut(i).assign(&window_array.row(chosen_idx));
825        }
826
827        self.centroids = Some(centroids);
828        Ok(())
829    }
830
831    /// Recompute centroids from current window using Lloyd's algorithm
832    fn recompute_centroids(&mut self) -> ClusterResult<()> {
833        if self.window.is_empty() {
834            return Ok(());
835        }
836
837        // Initialize if needed
838        if self.centroids.is_none() {
839            self.initialize_centroids()?;
840        }
841
842        let n_points = self.window.len();
843        let n_features = self.window[0].len();
844        let k = self.config.n_clusters.min(n_points);
845
846        // Convert window to Array2
847        let mut window_array = Array2::<f64>::zeros((n_points, n_features));
848        for (i, point) in self.window.iter().enumerate() {
849            for (j, &val) in point.iter().enumerate() {
850                window_array[[i, j]] = val;
851            }
852        }
853
854        let mut centroids = self
855            .centroids
856            .clone()
857            .expect("centroids should be initialized before recomputation");
858
859        // Lloyd's algorithm iterations
860        for _iter in 0..self.config.max_iters {
861            let old_centroids = centroids.clone();
862
863            // Assignment step
864            let mut labels = vec![0usize; n_points];
865            for (i, point) in window_array.outer_iter().enumerate() {
866                let mut min_dist = f64::INFINITY;
867                let mut closest = 0;
868                for (j, centroid) in centroids.outer_iter().enumerate() {
869                    let dist = self.euclidean_distance(&point, &centroid);
870                    if dist < min_dist {
871                        min_dist = dist;
872                        closest = j;
873                    }
874                }
875                labels[i] = closest;
876            }
877
878            // Update step
879            centroids.fill(0.0);
880            let mut counts = vec![0usize; k];
881
882            for (i, &label) in labels.iter().enumerate() {
883                for (j, &val) in window_array.row(i).iter().enumerate() {
884                    centroids[[label, j]] += val;
885                }
886                counts[label] += 1;
887            }
888
889            for i in 0..k {
890                if counts[i] > 0 {
891                    for j in 0..n_features {
892                        centroids[[i, j]] /= counts[i] as f64;
893                    }
894                }
895            }
896
897            // Check convergence
898            let mut max_shift: f64 = 0.0;
899            for (old_row, new_row) in old_centroids.outer_iter().zip(centroids.outer_iter()) {
900                let shift = self.euclidean_distance(&old_row, &new_row);
901                max_shift = max_shift.max(shift);
902            }
903
904            if max_shift < self.config.tolerance {
905                break;
906            }
907        }
908
909        self.centroids = Some(centroids);
910        self.n_recomputations += 1;
911        self.points_since_recompute = 0;
912
913        Ok(())
914    }
915
916    /// Compute Euclidean distance between two points
917    fn euclidean_distance(&self, p1: &ArrayView1<f64>, p2: &ArrayView1<f64>) -> f64 {
918        let mut sum_sq = 0.0;
919        for (a, b) in p1.iter().zip(p2.iter()) {
920            let diff = a - b;
921            sum_sq += diff * diff;
922        }
923        sum_sq.sqrt()
924    }
925
926    /// Convert tensor to Array1
927    fn tensor_to_array1(&self, tensor: &Tensor) -> ClusterResult<Array1<f64>> {
928        let tensor_shape = tensor.shape();
929        let shape = tensor_shape.dims();
930        if shape.len() != 1 && (shape.len() != 2 || shape[0] != 1) {
931            return Err(ClusterError::InvalidInput(
932                "Expected 1D tensor or single-row 2D tensor".to_string(),
933            ));
934        }
935
936        let data_f32: Vec<f32> = tensor.to_vec().map_err(ClusterError::TensorError)?;
937        let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
938
939        let n_features = if shape.len() == 1 { shape[0] } else { shape[1] };
940        Array1::from_vec(data)
941            .to_shape(n_features)
942            .map(|array| array.into_owned())
943            .map_err(|_| ClusterError::InvalidInput("Failed to reshape tensor".to_string()))
944    }
945
946    /// Convert Array2 to Tensor
947    fn array2_to_tensor(&self, array: &Array2<f64>) -> ClusterResult<Tensor> {
948        let (rows, cols) = array.dim();
949        let data_f64: Vec<f64> = array.iter().copied().collect();
950        let data: Vec<f32> = data_f64.into_iter().map(|x| x as f32).collect();
951        Tensor::from_vec(data, &[rows, cols]).map_err(ClusterError::TensorError)
952    }
953
954    /// Convert Vec to Tensor
955    fn vec_to_tensor(&self, data: Vec<f64>, shape: &[usize]) -> ClusterResult<Tensor> {
956        let data_f32: Vec<f32> = data.into_iter().map(|x| x as f32).collect();
957        Tensor::from_vec(data_f32, shape).map_err(ClusterError::TensorError)
958    }
959}
960
961impl IncrementalClustering for SlidingWindowKMeans {
962    type Result = SlidingWindowResult;
963
964    fn update_single(&mut self, point: &Tensor) -> ClusterResult<()> {
965        let point_array = self.tensor_to_array1(point)?;
966
967        // Initialize n_features if first point
968        if self.n_features.is_none() {
969            self.n_features = Some(point_array.len());
970        }
971
972        // Add point to window
973        self.window.push_back(point_array);
974
975        // Remove oldest point if window is full
976        if self.window.len() > self.config.window_size {
977            self.window.pop_front();
978        }
979
980        self.n_points_seen += 1;
981        self.points_since_recompute += 1;
982
983        // Recompute centroids if needed
984        if self.points_since_recompute >= self.config.recompute_frequency
985            || self.centroids.is_none()
986        {
987            self.recompute_centroids()?;
988        }
989
990        Ok(())
991    }
992
993    fn update_batch(&mut self, batch: &Tensor) -> ClusterResult<()> {
994        let batch_shape = batch.shape();
995        let shape = batch_shape.dims();
996        if shape.len() != 2 {
997            return Err(ClusterError::InvalidInput(
998                "Expected 2D batch tensor".to_string(),
999            ));
1000        }
1001
1002        let n_samples = shape[0];
1003        let n_features = shape[1];
1004
1005        if self.n_features.is_none() {
1006            self.n_features = Some(n_features);
1007        }
1008
1009        let data_f32: Vec<f32> = batch.to_vec().map_err(ClusterError::TensorError)?;
1010        let data: Vec<f64> = data_f32.into_iter().map(|x| x as f64).collect();
1011        let data_array = Array2::from_shape_vec((n_samples, n_features), data)
1012            .map_err(|_| ClusterError::InvalidInput("Failed to reshape batch data".to_string()))?;
1013
1014        for row in data_array.outer_iter() {
1015            let point_array = row.to_owned();
1016            self.window.push_back(point_array);
1017
1018            if self.window.len() > self.config.window_size {
1019                self.window.pop_front();
1020            }
1021
1022            self.n_points_seen += 1;
1023            self.points_since_recompute += 1;
1024        }
1025
1026        // Recompute after processing batch
1027        if self.points_since_recompute >= self.config.recompute_frequency
1028            || self.centroids.is_none()
1029        {
1030            self.recompute_centroids()?;
1031        }
1032
1033        Ok(())
1034    }
1035
1036    fn get_current_result(&self) -> ClusterResult<Self::Result> {
1037        let centroids = self
1038            .centroids
1039            .as_ref()
1040            .ok_or_else(|| ClusterError::ConfigError("No data processed yet".to_string()))?;
1041
1042        let centroids_tensor = self.array2_to_tensor(centroids)?;
1043
1044        // Compute labels for current window
1045        let mut labels = Vec::with_capacity(self.window.len());
1046        let mut cluster_counts = vec![0usize; self.config.n_clusters];
1047
1048        for point in &self.window {
1049            let mut min_dist = f64::INFINITY;
1050            let mut closest = 0;
1051            for (i, centroid) in centroids.outer_iter().enumerate() {
1052                let dist = self.euclidean_distance(&point.view(), &centroid);
1053                if dist < min_dist {
1054                    min_dist = dist;
1055                    closest = i;
1056                }
1057            }
1058            labels.push(closest as f64);
1059            cluster_counts[closest] += 1;
1060        }
1061
1062        let labels_tensor = self.vec_to_tensor(labels, &[self.window.len()])?;
1063
1064        Ok(SlidingWindowResult {
1065            centroids: centroids_tensor,
1066            labels: labels_tensor,
1067            cluster_counts,
1068            n_points_seen: self.n_points_seen,
1069            window_fill: self.window.len(),
1070            n_recomputations: self.n_recomputations,
1071        })
1072    }
1073
1074    fn reset(&mut self) {
1075        self.window.clear();
1076        self.centroids = None;
1077        self.n_points_seen = 0;
1078        self.n_recomputations = 0;
1079        self.points_since_recompute = 0;
1080        self.n_features = None;
1081    }
1082
1083    fn detect_drift(&self) -> bool {
1084        // Simplified drift detection: check if recomputations are happening frequently
1085        // In a stationary distribution, centroids would stabilize
1086        self.n_recomputations > 10 && self.n_points_seen / self.n_recomputations.max(1) < 50
1087    }
1088
1089    fn n_points_seen(&self) -> usize {
1090        self.n_points_seen
1091    }
1092}
1093
#[cfg(test)]
mod tests {
    use super::*;

    // Online K-Means: points fed one at a time from two well-separated
    // groups should yield 2 clusters with 2x2 centroids.
    #[test]
    fn test_online_kmeans_basic() -> ClusterResult<()> {
        let mut online_kmeans = OnlineKMeans::new(2)?;

        // Process some points: first 5 near (0, 0), last 5 near (5, 5)
        for i in 0..10 {
            let point = if i < 5 {
                Tensor::from_vec(vec![0.0 + i as f32 * 0.1, 0.0], &[2])?
            } else {
                Tensor::from_vec(vec![5.0 + (i - 5) as f32 * 0.1, 5.0], &[2])?
            };

            online_kmeans.update_single(&point)?;
        }

        let result = online_kmeans.get_current_result()?;
        assert_eq!(result.n_clusters(), 2);
        assert_eq!(result.n_points_seen, 10);
        assert!(result.centroids.shape().dims() == &[2, 2]);

        Ok(())
    }

    // Online K-Means: a single 4x2 batch should register 4 points seen.
    #[test]
    fn test_online_kmeans_batch() -> ClusterResult<()> {
        let mut online_kmeans = OnlineKMeans::new(2)?;

        let batch = Tensor::from_vec(vec![0.0, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1], &[4, 2])?;

        online_kmeans.update_batch(&batch)?;

        let result = online_kmeans.get_current_result()?;
        assert_eq!(result.n_clusters(), 2);
        assert_eq!(result.n_points_seen, 4);

        Ok(())
    }

    // Drift: feed a stable stream then a far-away stream; only asserts the
    // point count since drift detection itself is heuristic.
    #[test]
    fn test_drift_detection() -> ClusterResult<()> {
        let mut online_kmeans = OnlineKMeans::new(2)?.drift_threshold(0.1);

        // Process normal points
        for i in 0..100 {
            let point = Tensor::from_vec(vec![i as f32 * 0.01, 0.0], &[2])?;
            online_kmeans.update_single(&point)?;
        }

        let _initial_drift = online_kmeans.detect_drift();

        // Introduce outliers (potential drift)
        for i in 0..50 {
            let point = Tensor::from_vec(vec![100.0 + i as f32, 100.0], &[2])?;
            online_kmeans.update_single(&point)?;
        }

        // Drift detection should eventually trigger
        // (Note: Simple test - in practice drift detection is complex)
        let final_result = online_kmeans.get_current_result()?;
        assert!(final_result.n_points_seen == 150);

        Ok(())
    }

    // Sliding window: with alternating cluster membership both groups stay
    // in the window; verifies window fill, counters, and recomputation.
    #[test]
    fn test_sliding_window_basic() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 50,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        // Process points from two clusters, alternating to keep both in window
        for i in 0..100 {
            let point = if i % 2 == 0 {
                Tensor::from_vec(vec![0.0 + (i as f32) * 0.01, 0.0], &[2])?
            } else {
                Tensor::from_vec(vec![10.0 + (i as f32) * 0.01, 10.0], &[2])?
            };

            sliding.update_single(&point)?;
        }

        let result = sliding.get_current_result()?;
        // May find 1 or 2 clusters depending on initialization
        assert!(result.n_clusters() >= 1);
        assert!(result.n_clusters() <= 2);
        assert_eq!(result.window_fill, 50); // Window size is 50
        assert_eq!(result.n_points_seen, 100);
        assert!(result.n_recomputations > 0);

        Ok(())
    }

    // Sliding window: a single 8x2 batch fits inside the window entirely.
    #[test]
    fn test_sliding_window_batch() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 20,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        let batch = Tensor::from_vec(
            vec![
                0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2, 5.3, 5.3,
            ],
            &[8, 2],
        )?;

        sliding.update_batch(&batch)?;

        let result = sliding.get_current_result()?;
        assert_eq!(result.n_clusters(), 2);
        assert_eq!(result.window_fill, 8);
        assert_eq!(result.n_points_seen, 8);

        Ok(())
    }

    // Sliding window: overfilling the window must expire the oldest points,
    // so only `window_size` points (and labels) remain.
    #[test]
    fn test_sliding_window_expiration() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 10,
            recompute_frequency: 5,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        // Add more points than window size
        for i in 0..20 {
            let point = Tensor::from_vec(vec![i as f32, 0.0], &[2])?;
            sliding.update_single(&point)?;
        }

        let result = sliding.get_current_result()?;

        // Window should only contain last 10 points
        assert_eq!(result.window_fill, 10);
        assert_eq!(result.n_points_seen, 20);

        // Labels should match window size
        assert_eq!(result.labels.shape().dims()[0], 10);

        Ok(())
    }

    // Sliding window: recomputation count should track points processed
    // divided by the recompute frequency (with slack for the initial one).
    #[test]
    fn test_sliding_window_recomputation() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 50,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        // Process points
        for i in 0..50 {
            let point = Tensor::from_vec(vec![i as f32 * 0.1, 0.0], &[2])?;
            sliding.update_single(&point)?;
        }

        let result = sliding.get_current_result()?;

        // Should have recomputed centroids multiple times
        // (50 points / 10 recompute_frequency = 5 recomputations)
        assert!(result.n_recomputations >= 4);
        assert!(result.n_recomputations <= 6);

        Ok(())
    }

    // Sliding window: reset() must clear counters and leave the model usable.
    #[test]
    fn test_sliding_window_reset() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 20,
            recompute_frequency: 5,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        // Process some points
        for i in 0..10 {
            let point = Tensor::from_vec(vec![i as f32, 0.0], &[2])?;
            sliding.update_single(&point)?;
        }

        // Reset
        sliding.reset();

        // Check that everything is reset
        assert_eq!(sliding.n_points_seen(), 0);

        // Processing after reset should work
        let point = Tensor::from_vec(vec![1.0, 1.0], &[2])?;
        sliding.update_single(&point)?;

        assert_eq!(sliding.n_points_seen(), 1);

        Ok(())
    }

    // Sliding window: after the distribution shifts and old points expire,
    // the centroids must move to follow the new data.
    #[test]
    fn test_sliding_window_drift_adaptation() -> ClusterResult<()> {
        let config = SlidingWindowConfig {
            n_clusters: 2,
            window_size: 30,
            recompute_frequency: 10,
            ..Default::default()
        };

        let mut sliding = SlidingWindowKMeans::new(config)?;

        // Phase 1: Cluster around (0, 0) and (5, 5)
        for i in 0..20 {
            let point = if i < 10 {
                Tensor::from_vec(vec![i as f32 * 0.1, 0.0], &[2])?
            } else {
                Tensor::from_vec(vec![5.0 + (i - 10) as f32 * 0.1, 5.0], &[2])?
            };
            sliding.update_single(&point)?;
        }

        let result1 = sliding.get_current_result()?;
        let centroids1 = result1
            .centroids
            .to_vec()
            .expect("centroids conversion should succeed");

        // Phase 2: Shift clusters to (10, 10) and (15, 15)
        for i in 0..30 {
            let point = if i < 15 {
                Tensor::from_vec(vec![10.0 + i as f32 * 0.1, 10.0], &[2])?
            } else {
                Tensor::from_vec(vec![15.0 + (i - 15) as f32 * 0.1, 15.0], &[2])?
            };
            sliding.update_single(&point)?;
        }

        let result2 = sliding.get_current_result()?;
        let centroids2 = result2
            .centroids
            .to_vec()
            .expect("centroids conversion should succeed");

        // Centroids should have adapted to new distribution
        // (Old points expired from window)
        // Check that centroids changed significantly
        let mut changed = false;
        for i in 0..centroids1.len().min(centroids2.len()) {
            if (centroids1[i] - centroids2[i]).abs() > 1.0 {
                changed = true;
                break;
            }
        }

        assert!(changed, "Centroids should adapt to distribution shift");

        Ok(())
    }
}