scirs2_cluster/advanced/
online.rs

1//! Adaptive online clustering with concept drift detection
2//!
3//! This module provides implementations of online clustering algorithms that
4//! automatically adapt to changing data distributions, create new clusters when
5//! needed, merge similar clusters, and detect concept drift in streaming data.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11use std::fmt::Debug;
12
13use crate::error::{ClusteringError, Result};
14use crate::vq::euclidean_distance;
15
/// Configuration for [`AdaptiveOnlineClustering`].
///
/// All distances are Euclidean distances measured against cluster centroids.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveOnlineConfig {
    /// Starting value of the global learning rate. It is multiplied by
    /// `learning_rate_decay` after every processed sample.
    pub initial_learning_rate: f64,
    /// Lower bound the decayed learning rate is clamped to.
    pub min_learning_rate: f64,
    /// Multiplicative decay applied to the learning rate once per sample.
    pub learning_rate_decay: f64,
    /// Factor a cluster's weight is multiplied by before each newly assigned
    /// point adds `1.0` to it, so older assignments count progressively less.
    pub forgetting_factor: f64,
    /// A point farther than this from its nearest centroid starts a new
    /// cluster (as long as fewer than `max_clusters` clusters exist).
    pub cluster_creation_threshold: f64,
    /// Maximum number of clusters; once reached, every point is assigned to
    /// its nearest cluster regardless of distance.
    pub max_clusters: usize,
    /// Minimum cluster *weight* that protects a cluster without recent
    /// updates from being removed during periodic maintenance.
    pub min_cluster_size: usize,
    /// Clusters whose centroids are at most this far apart are merged during
    /// periodic maintenance.
    pub merge_threshold: f64,
    /// Number of recent nearest-centroid distances retained for drift
    /// detection; also bounds the drift detector's history length.
    pub concept_drift_window: usize,
    /// Relative increase of the mean distance over its baseline (e.g. `0.3`
    /// means 30 %) that is treated as concept drift.
    pub drift_detection_threshold: f64,
}
40
41impl Default for AdaptiveOnlineConfig {
42    fn default() -> Self {
43        Self {
44            initial_learning_rate: 0.1,
45            min_learning_rate: 0.001,
46            learning_rate_decay: 0.999,
47            forgetting_factor: 0.95,
48            cluster_creation_threshold: 2.0,
49            max_clusters: 50,
50            min_cluster_size: 10,
51            merge_threshold: 0.5,
52            concept_drift_window: 1000,
53            drift_detection_threshold: 0.3,
54        }
55    }
56}
57
/// Adaptive online clustering with concept drift detection
///
/// This algorithm automatically adapts to changing data distributions,
/// creates new clusters when needed, merges similar clusters, and detects
/// concept drift in streaming data.
///
/// Cluster indices are positions in the internal cluster list; they can shift
/// when clusters are merged or removed during periodic maintenance.
pub struct AdaptiveOnlineClustering<F: Float> {
    /// Settings supplied at construction.
    config: AdaptiveOnlineConfig,
    /// Current clusters; assignment indices refer to positions in this vector.
    clusters: Vec<OnlineCluster<F>>,
    /// Current global learning rate (decays per sample, raised on drift).
    learning_rate: f64,
    /// Number of samples processed so far; also used as a logical clock for
    /// cluster timestamps.
    samples_processed: usize,
    /// Sliding window of recent nearest-centroid distances, capped at
    /// `config.concept_drift_window` entries.
    recent_distances: VecDeque<f64>,
    /// State of the concept drift detector.
    drift_detector: ConceptDriftDetector,
}
71
/// Represents an online cluster with adaptive properties
#[derive(Debug, Clone)]
struct OnlineCluster<F: Float> {
    /// Cluster centroid
    centroid: Array1<F>,
    /// Decayed assignment count: multiplied by the forgetting factor and
    /// incremented by one on every update (starts at `1.0` on creation).
    weight: f64,
    /// Value of the sample counter when this cluster was last updated or created
    last_update: usize,
    /// Exponential moving average of the squared distance between assigned
    /// points and the (updated) centroid
    variance: f64,
    /// Number of updates applied to this cluster
    age: usize,
    /// Sample-counter values of (at most) the last 100 assignments
    recent_assignments: VecDeque<usize>,
}
88
/// Simple concept drift detector
///
/// Keeps a history of mean nearest-centroid distances and compares their
/// average against a slowly adapting baseline.
#[derive(Debug, Clone)]
struct ConceptDriftDetector {
    /// History of mean distances, one entry pushed per drift check
    recent_errors: VecDeque<f64>,
    /// Reference level the current average is compared against
    baseline_error: f64,
    /// Maximum number of entries kept in `recent_errors`
    window_size: usize,
}
99
100impl<F: Float + FromPrimitive + Debug> AdaptiveOnlineClustering<F> {
101    /// Create a new adaptive online clustering instance
102    pub fn new(config: AdaptiveOnlineConfig) -> Self {
103        Self {
104            config: config.clone(),
105            clusters: Vec::new(),
106            learning_rate: config.initial_learning_rate,
107            samples_processed: 0,
108            recent_distances: VecDeque::with_capacity(config.concept_drift_window),
109            drift_detector: ConceptDriftDetector {
110                recent_errors: VecDeque::with_capacity(config.concept_drift_window),
111                baseline_error: 1.0,
112                window_size: config.concept_drift_window,
113            },
114        }
115    }
116
117    /// Process a single data point online
118    pub fn partial_fit(&mut self, point: ArrayView1<F>) -> Result<usize> {
119        self.samples_processed += 1;
120
121        // Find nearest cluster
122        let (nearest_cluster_idx, nearest_distance) = self.find_nearest_cluster(point);
123
124        let assigned_cluster = if let Some(cluster_idx) = nearest_cluster_idx {
125            let distance_threshold = F::from(self.config.cluster_creation_threshold).unwrap();
126
127            if nearest_distance <= distance_threshold {
128                // Update existing cluster
129                self.update_cluster(cluster_idx, point)?;
130                cluster_idx
131            } else if self.clusters.len() < self.config.max_clusters {
132                // Create new cluster
133                self.create_new_cluster(point)?
134            } else {
135                // Force assignment to nearest cluster and update threshold
136                self.update_cluster(cluster_idx, point)?;
137                cluster_idx
138            }
139        } else {
140            // No clusters exist, create first one
141            self.create_new_cluster(point)?
142        };
143
144        // Update learning rate
145        self.learning_rate = (self.learning_rate * self.config.learning_rate_decay)
146            .max(self.config.min_learning_rate);
147
148        // Track distance for concept drift detection
149        self.recent_distances
150            .push_back(nearest_distance.to_f64().unwrap_or(0.0));
151        if self.recent_distances.len() > self.config.concept_drift_window {
152            self.recent_distances.pop_front();
153        }
154
155        // Detect concept drift
156        if self.samples_processed.is_multiple_of(100) {
157            self.detect_concept_drift()?;
158        }
159
160        // Periodic maintenance
161        if self.samples_processed.is_multiple_of(1000) {
162            self.merge_similar_clusters()?;
163            self.remove_old_clusters()?;
164        }
165
166        Ok(assigned_cluster)
167    }
168
169    /// Find the nearest cluster to a point
170    fn find_nearest_cluster(&self, point: ArrayView1<F>) -> (Option<usize>, F) {
171        if self.clusters.is_empty() {
172            return (None, F::infinity());
173        }
174
175        let mut min_distance = F::infinity();
176        let mut nearest_idx = 0;
177
178        for (i, cluster) in self.clusters.iter().enumerate() {
179            let distance = euclidean_distance(point, cluster.centroid.view());
180            if distance < min_distance {
181                min_distance = distance;
182                nearest_idx = i;
183            }
184        }
185
186        (Some(nearest_idx), min_distance)
187    }
188
189    /// Update an existing cluster with a new point
190    fn update_cluster(&mut self, clusteridx: usize, point: ArrayView1<F>) -> Result<()> {
191        let cluster = &mut self.clusters[clusteridx];
192
193        // Update weight with forgetting factor
194        cluster.weight = cluster.weight * self.config.forgetting_factor + 1.0;
195
196        // Update centroid using online mean
197        let learning_rate = F::from(self.learning_rate / cluster.weight).unwrap();
198
199        Zip::from(&mut cluster.centroid)
200            .and(point)
201            .for_each(|centroid_val, &point_val| {
202                let diff = point_val - *centroid_val;
203                *centroid_val = *centroid_val + learning_rate * diff;
204            });
205
206        // Update variance estimate
207        let distance = euclidean_distance(point, cluster.centroid.view());
208        let distance_squared = distance * distance;
209        cluster.variance = cluster.variance * 0.9 + distance_squared.to_f64().unwrap_or(0.0) * 0.1;
210
211        // Update metadata
212        cluster.last_update = self.samples_processed;
213        cluster.age += 1;
214        cluster.recent_assignments.push_back(self.samples_processed);
215
216        if cluster.recent_assignments.len() > 100 {
217            cluster.recent_assignments.pop_front();
218        }
219
220        Ok(())
221    }
222
223    /// Create a new cluster
224    fn create_new_cluster(&mut self, point: ArrayView1<F>) -> Result<usize> {
225        let new_cluster = OnlineCluster {
226            centroid: point.to_owned(),
227            weight: 1.0,
228            last_update: self.samples_processed,
229            variance: 0.0,
230            age: 0,
231            recent_assignments: VecDeque::new(),
232        };
233
234        self.clusters.push(new_cluster);
235        Ok(self.clusters.len() - 1)
236    }
237
238    /// Detect concept drift in the data stream
239    fn detect_concept_drift(&mut self) -> Result<()> {
240        if self.recent_distances.len() < self.config.concept_drift_window / 2 {
241            return Ok(());
242        }
243
244        // Calculate recent mean distance
245        let recent_mean: f64 =
246            self.recent_distances.iter().sum::<f64>() / self.recent_distances.len() as f64;
247
248        // Update drift detector
249        self.drift_detector.recent_errors.push_back(recent_mean);
250        if self.drift_detector.recent_errors.len() > self.drift_detector.window_size {
251            self.drift_detector.recent_errors.pop_front();
252        }
253
254        // Calculate current error rate
255        let current_error: f64 = self.drift_detector.recent_errors.iter().sum::<f64>()
256            / self.drift_detector.recent_errors.len() as f64;
257
258        // Detect drift if current error is significantly higher than baseline
259        if current_error
260            > self.drift_detector.baseline_error * (1.0 + self.config.drift_detection_threshold)
261        {
262            // Concept drift detected - adapt by increasing learning rate temporarily
263            self.learning_rate = (self.learning_rate * 2.0).min(0.5);
264            self.drift_detector.baseline_error = current_error;
265        } else {
266            // Update baseline gradually
267            self.drift_detector.baseline_error =
268                self.drift_detector.baseline_error * 0.99 + current_error * 0.01;
269        }
270
271        Ok(())
272    }
273
274    /// Merge clusters that are too similar
275    fn merge_similar_clusters(&mut self) -> Result<()> {
276        let mut to_merge = Vec::new();
277        let merge_threshold = F::from(self.config.merge_threshold).unwrap();
278
279        // Find pairs of clusters to merge
280        for i in 0..self.clusters.len() {
281            for j in (i + 1)..self.clusters.len() {
282                let distance = euclidean_distance(
283                    self.clusters[i].centroid.view(),
284                    self.clusters[j].centroid.view(),
285                );
286
287                if distance <= merge_threshold {
288                    to_merge.push((i, j));
289                }
290            }
291        }
292
293        // Merge clusters (process in reverse order to maintain indices)
294        for (i, j) in to_merge.into_iter().rev() {
295            self.merge_clusters(i, j)?;
296        }
297
298        Ok(())
299    }
300
301    /// Merge two clusters
302    fn merge_clusters(&mut self, i: usize, j: usize) -> Result<()> {
303        if i >= self.clusters.len() || j >= self.clusters.len() || i == j {
304            return Ok(());
305        }
306
307        let (cluster_i, cluster_j) = if i < j {
308            let (left, right) = self.clusters.split_at_mut(j);
309            (&mut left[i], &mut right[0])
310        } else {
311            let (left, right) = self.clusters.split_at_mut(i);
312            (&mut right[0], &mut left[j])
313        };
314
315        // Weighted merge of centroids
316        let total_weight = cluster_i.weight + cluster_j.weight;
317        let weight_i = F::from(cluster_i.weight / total_weight).unwrap();
318        let weight_j = F::from(cluster_j.weight / total_weight).unwrap();
319
320        Zip::from(&mut cluster_i.centroid)
321            .and(&cluster_j.centroid)
322            .for_each(|cent_i, &cent_j| {
323                *cent_i = *cent_i * weight_i + cent_j * weight_j;
324            });
325
326        // Merge other properties
327        cluster_i.weight = total_weight;
328        cluster_i.variance = (cluster_i.variance + cluster_j.variance) / 2.0;
329        cluster_i.age = cluster_i.age.max(cluster_j.age);
330        cluster_i.last_update = cluster_i.last_update.max(cluster_j.last_update);
331
332        // Remove the merged cluster
333        let remove_idx = if i < j { j } else { i };
334        self.clusters.remove(remove_idx);
335
336        Ok(())
337    }
338
339    /// Remove old, inactive clusters
340    fn remove_old_clusters(&mut self) -> Result<()> {
341        let current_time = self.samples_processed;
342        let max_age = 10000; // Maximum age before considering removal
343
344        self.clusters.retain(|cluster| {
345            let age_ok = cluster.age < max_age;
346            let recent_activity = current_time - cluster.last_update < 5000;
347            let sufficient_size = cluster.weight >= self.config.min_cluster_size as f64;
348
349            age_ok && (recent_activity || sufficient_size)
350        });
351
352        Ok(())
353    }
354
355    /// Predict cluster assignment for new data
356    pub fn predict(&self, point: ArrayView1<F>) -> Result<usize> {
357        let (nearest_cluster_idx_, _distance) = self.find_nearest_cluster(point);
358
359        nearest_cluster_idx_.ok_or_else(|| {
360            ClusteringError::InvalidInput("No clusters available for prediction".to_string())
361        })
362    }
363
364    /// Get current cluster centroids
365    pub fn cluster_centers(&self) -> Array2<F> {
366        if self.clusters.is_empty() {
367            return Array2::zeros((0, 0));
368        }
369
370        let n_features = self.clusters[0].centroid.len();
371        let mut centers = Array2::zeros((self.clusters.len(), n_features));
372
373        for (i, cluster) in self.clusters.iter().enumerate() {
374            centers.row_mut(i).assign(&cluster.centroid);
375        }
376
377        centers
378    }
379
380    /// Get cluster information for analysis
381    pub fn cluster_info(&self) -> Vec<(f64, f64, usize)> {
382        self.clusters
383            .iter()
384            .map(|cluster| (cluster.weight, cluster.variance, cluster.age))
385            .collect()
386    }
387
388    /// Get number of active clusters
389    pub fn n_clusters(&self) -> usize {
390        self.clusters.len()
391    }
392}
393
394/// Convenience function for adaptive online clustering
395pub fn adaptive_online_clustering<F: Float + FromPrimitive + Debug>(
396    data: ArrayView2<F>,
397    config: Option<AdaptiveOnlineConfig>,
398) -> Result<(Array2<F>, Array1<usize>)> {
399    let config = config.unwrap_or_default();
400    let mut clusterer = AdaptiveOnlineClustering::new(config);
401
402    let n_samples = data.nrows();
403    let mut labels = Array1::zeros(n_samples);
404
405    // Process data points sequentially
406    for (i, point) in data.rows().into_iter().enumerate() {
407        labels[i] = clusterer.partial_fit(point)?;
408    }
409
410    let centers = clusterer.cluster_centers();
411    Ok((centers, labels))
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_adaptive_online_config_default() {
        let config = AdaptiveOnlineConfig::default();
        assert_eq!(config.initial_learning_rate, 0.1);
        assert_eq!(config.max_clusters, 50);
        assert_eq!(config.concept_drift_window, 1000);
    }

    #[test]
    fn test_adaptive_online_clustering_creation() {
        let config = AdaptiveOnlineConfig::default();
        let clusterer = AdaptiveOnlineClustering::<f64>::new(config);
        assert_eq!(clusterer.n_clusters(), 0);
    }

    #[test]
    fn test_adaptive_online_clustering_simple() {
        // Two well-separated pairs of points.
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0])
            .unwrap();

        let config = AdaptiveOnlineConfig {
            cluster_creation_threshold: 2.0,
            max_clusters: 10,
            ..Default::default()
        };

        let result = adaptive_online_clustering(data.view(), Some(config));
        assert!(result.is_ok());

        let (centers, labels) = result.unwrap();
        assert_eq!(labels.len(), 4);
        // Each pair shares a cluster and the two pairs are kept apart.
        assert_eq!(centers.nrows(), 2);
        assert_eq!(labels.to_vec(), vec![0, 0, 1, 1]);
        // Every label must index a returned centroid.
        assert!(labels.iter().all(|&label| label < centers.nrows()));
    }

    #[test]
    fn test_online_cluster_creation() {
        let config = AdaptiveOnlineConfig::default();
        let mut clusterer = AdaptiveOnlineClustering::<f64>::new(config);

        let point = Array1::from_vec(vec![1.0, 2.0]);
        let cluster_id = clusterer.partial_fit(point.view()).unwrap();

        assert_eq!(cluster_id, 0);
        assert_eq!(clusterer.n_clusters(), 1);
    }

    #[test]
    fn test_predict_without_clusters_errors() {
        let clusterer = AdaptiveOnlineClustering::<f64>::new(AdaptiveOnlineConfig::default());
        let point = Array1::from_vec(vec![1.0, 2.0]);
        assert!(clusterer.predict(point.view()).is_err());
    }

    #[test]
    fn test_concept_drift_detection() {
        let config = AdaptiveOnlineConfig {
            concept_drift_window: 10,
            drift_detection_threshold: 0.1,
            ..Default::default()
        };

        let mut clusterer = AdaptiveOnlineClustering::<f64>::new(config);

        // Process some initial points
        for i in 0..5 {
            let point = Array1::from_vec(vec![i as f64, i as f64]);
            clusterer.partial_fit(point.view()).unwrap();
        }

        // The drift detection should run without errors
        assert!(clusterer.detect_concept_drift().is_ok());
    }

    #[test]
    fn test_cluster_merging() {
        let config = AdaptiveOnlineConfig {
            merge_threshold: 1.0,
            // Smaller than the ~0.42 gap between the two points below, so each
            // point seeds its own cluster and the merge is actually exercised.
            cluster_creation_threshold: 0.2,
            ..Default::default()
        };

        let mut clusterer = AdaptiveOnlineClustering::<f64>::new(config);

        let point1 = Array1::from_vec(vec![0.0, 0.0]);
        let point2 = Array1::from_vec(vec![0.3, 0.3]);

        assert_eq!(clusterer.partial_fit(point1.view()).unwrap(), 0);
        assert_eq!(clusterer.partial_fit(point2.view()).unwrap(), 1);
        assert_eq!(clusterer.n_clusters(), 2);

        // Force merge check
        clusterer.merge_similar_clusters().unwrap();

        // The centroids are within merge_threshold, so they collapse into one
        // cluster at the weight-averaged position (equal weights -> midpoint).
        assert_eq!(clusterer.n_clusters(), 1);
        let centers = clusterer.cluster_centers();
        assert!((centers[[0, 0]] - 0.15).abs() < 1e-12);
        assert!((centers[[0, 1]] - 0.15).abs() < 1e-12);
    }
}