1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Zip};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11use std::fmt::Debug;
12
13use crate::error::{ClusteringError, Result};
14use crate::vq::euclidean_distance;
15
/// Configuration for [`AdaptiveOnlineClustering`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptiveOnlineConfig {
    /// Learning rate the clusterer starts with.
    pub initial_learning_rate: f64,
    /// Lower bound below which the learning rate never decays.
    pub min_learning_rate: f64,
    /// Factor the learning rate is multiplied by after every processed sample.
    pub learning_rate_decay: f64,
    /// Decay applied to a cluster's weight each time it is updated
    /// (`weight = weight * forgetting_factor + 1`).
    pub forgetting_factor: f64,
    /// A sample whose Euclidean distance to its nearest centroid exceeds this value seeds a
    /// new cluster, provided `max_clusters` has not been reached.
    pub cluster_creation_threshold: f64,
    /// Upper bound on the number of clusters; once reached, distant samples update their
    /// nearest cluster instead of creating a new one.
    pub max_clusters: usize,
    /// Inactive clusters are only retained by periodic cleanup if their weight is at least
    /// this value.
    pub min_cluster_size: usize,
    /// Clusters whose centroids are at most this (Euclidean) distance apart are merged.
    pub merge_threshold: f64,
    /// Length of the sliding windows of recent nearest-cluster distances and drift error
    /// snapshots used for concept-drift detection.
    pub concept_drift_window: usize,
    /// Relative increase of the windowed error over its baseline that is treated as drift
    /// (e.g. `0.3` means 30% above baseline).
    pub drift_detection_threshold: f64,
}
40
41impl Default for AdaptiveOnlineConfig {
42 fn default() -> Self {
43 Self {
44 initial_learning_rate: 0.1,
45 min_learning_rate: 0.001,
46 learning_rate_decay: 0.999,
47 forgetting_factor: 0.95,
48 cluster_creation_threshold: 2.0,
49 max_clusters: 50,
50 min_cluster_size: 10,
51 merge_threshold: 0.5,
52 concept_drift_window: 1000,
53 drift_detection_threshold: 0.3,
54 }
55 }
56}
57
58pub struct AdaptiveOnlineClustering<F: Float> {
64 config: AdaptiveOnlineConfig,
65 clusters: Vec<OnlineCluster<F>>,
66 learning_rate: f64,
67 samples_processed: usize,
68 recent_distances: VecDeque<f64>,
69 drift_detector: ConceptDriftDetector,
70}
71
/// State of a single cluster maintained by [`AdaptiveOnlineClustering`].
#[derive(Debug, Clone)]
struct OnlineCluster<F: Float> {
    /// Cluster center: initialized to the point that seeded the cluster and moved toward
    /// every point assigned to it.
    centroid: Array1<F>,
    /// Decayed assignment count: `1.0` at creation, then `weight * forgetting_factor + 1`
    /// on each update; summed when two clusters merge.
    weight: f64,
    /// Value of `samples_processed` when the cluster was created or last updated.
    last_update: usize,
    /// Exponential moving average (0.9 / 0.1) of the squared distance between each
    /// assigned point and the centroid after it has been moved.
    variance: f64,
    /// Number of updates since creation (`0` for a freshly created cluster).
    age: usize,
    /// Sample indices of up to the 100 most recent updates.
    // NOTE(review): not read anywhere in this file's visible code — confirm it is still needed.
    recent_assignments: VecDeque<usize>,
}
88
/// Rolling statistics used to decide whether the input distribution has drifted.
#[derive(Debug, Clone)]
struct ConceptDriftDetector {
    /// Snapshots of the mean recent nearest-cluster distance, one per drift check.
    recent_errors: VecDeque<f64>,
    /// Reference error level (starts at `1.0`); drift is flagged when the averaged
    /// snapshots exceed it by the configured relative threshold.
    baseline_error: f64,
    /// Maximum number of snapshots kept in `recent_errors`.
    window_size: usize,
}
99
100impl<F: Float + FromPrimitive + Debug> AdaptiveOnlineClustering<F> {
101 pub fn new(config: AdaptiveOnlineConfig) -> Self {
103 Self {
104 config: config.clone(),
105 clusters: Vec::new(),
106 learning_rate: config.initial_learning_rate,
107 samples_processed: 0,
108 recent_distances: VecDeque::with_capacity(config.concept_drift_window),
109 drift_detector: ConceptDriftDetector {
110 recent_errors: VecDeque::with_capacity(config.concept_drift_window),
111 baseline_error: 1.0,
112 window_size: config.concept_drift_window,
113 },
114 }
115 }
116
117 pub fn partial_fit(&mut self, point: ArrayView1<F>) -> Result<usize> {
119 self.samples_processed += 1;
120
121 let (nearest_cluster_idx, nearest_distance) = self.find_nearest_cluster(point);
123
124 let assigned_cluster = if let Some(cluster_idx) = nearest_cluster_idx {
125 let distance_threshold = F::from(self.config.cluster_creation_threshold).unwrap();
126
127 if nearest_distance <= distance_threshold {
128 self.update_cluster(cluster_idx, point)?;
130 cluster_idx
131 } else if self.clusters.len() < self.config.max_clusters {
132 self.create_new_cluster(point)?
134 } else {
135 self.update_cluster(cluster_idx, point)?;
137 cluster_idx
138 }
139 } else {
140 self.create_new_cluster(point)?
142 };
143
144 self.learning_rate = (self.learning_rate * self.config.learning_rate_decay)
146 .max(self.config.min_learning_rate);
147
148 self.recent_distances
150 .push_back(nearest_distance.to_f64().unwrap_or(0.0));
151 if self.recent_distances.len() > self.config.concept_drift_window {
152 self.recent_distances.pop_front();
153 }
154
155 if self.samples_processed.is_multiple_of(100) {
157 self.detect_concept_drift()?;
158 }
159
160 if self.samples_processed.is_multiple_of(1000) {
162 self.merge_similar_clusters()?;
163 self.remove_old_clusters()?;
164 }
165
166 Ok(assigned_cluster)
167 }
168
169 fn find_nearest_cluster(&self, point: ArrayView1<F>) -> (Option<usize>, F) {
171 if self.clusters.is_empty() {
172 return (None, F::infinity());
173 }
174
175 let mut min_distance = F::infinity();
176 let mut nearest_idx = 0;
177
178 for (i, cluster) in self.clusters.iter().enumerate() {
179 let distance = euclidean_distance(point, cluster.centroid.view());
180 if distance < min_distance {
181 min_distance = distance;
182 nearest_idx = i;
183 }
184 }
185
186 (Some(nearest_idx), min_distance)
187 }
188
189 fn update_cluster(&mut self, clusteridx: usize, point: ArrayView1<F>) -> Result<()> {
191 let cluster = &mut self.clusters[clusteridx];
192
193 cluster.weight = cluster.weight * self.config.forgetting_factor + 1.0;
195
196 let learning_rate = F::from(self.learning_rate / cluster.weight).unwrap();
198
199 Zip::from(&mut cluster.centroid)
200 .and(point)
201 .for_each(|centroid_val, &point_val| {
202 let diff = point_val - *centroid_val;
203 *centroid_val = *centroid_val + learning_rate * diff;
204 });
205
206 let distance = euclidean_distance(point, cluster.centroid.view());
208 let distance_squared = distance * distance;
209 cluster.variance = cluster.variance * 0.9 + distance_squared.to_f64().unwrap_or(0.0) * 0.1;
210
211 cluster.last_update = self.samples_processed;
213 cluster.age += 1;
214 cluster.recent_assignments.push_back(self.samples_processed);
215
216 if cluster.recent_assignments.len() > 100 {
217 cluster.recent_assignments.pop_front();
218 }
219
220 Ok(())
221 }
222
223 fn create_new_cluster(&mut self, point: ArrayView1<F>) -> Result<usize> {
225 let new_cluster = OnlineCluster {
226 centroid: point.to_owned(),
227 weight: 1.0,
228 last_update: self.samples_processed,
229 variance: 0.0,
230 age: 0,
231 recent_assignments: VecDeque::new(),
232 };
233
234 self.clusters.push(new_cluster);
235 Ok(self.clusters.len() - 1)
236 }
237
238 fn detect_concept_drift(&mut self) -> Result<()> {
240 if self.recent_distances.len() < self.config.concept_drift_window / 2 {
241 return Ok(());
242 }
243
244 let recent_mean: f64 =
246 self.recent_distances.iter().sum::<f64>() / self.recent_distances.len() as f64;
247
248 self.drift_detector.recent_errors.push_back(recent_mean);
250 if self.drift_detector.recent_errors.len() > self.drift_detector.window_size {
251 self.drift_detector.recent_errors.pop_front();
252 }
253
254 let current_error: f64 = self.drift_detector.recent_errors.iter().sum::<f64>()
256 / self.drift_detector.recent_errors.len() as f64;
257
258 if current_error
260 > self.drift_detector.baseline_error * (1.0 + self.config.drift_detection_threshold)
261 {
262 self.learning_rate = (self.learning_rate * 2.0).min(0.5);
264 self.drift_detector.baseline_error = current_error;
265 } else {
266 self.drift_detector.baseline_error =
268 self.drift_detector.baseline_error * 0.99 + current_error * 0.01;
269 }
270
271 Ok(())
272 }
273
274 fn merge_similar_clusters(&mut self) -> Result<()> {
276 let mut to_merge = Vec::new();
277 let merge_threshold = F::from(self.config.merge_threshold).unwrap();
278
279 for i in 0..self.clusters.len() {
281 for j in (i + 1)..self.clusters.len() {
282 let distance = euclidean_distance(
283 self.clusters[i].centroid.view(),
284 self.clusters[j].centroid.view(),
285 );
286
287 if distance <= merge_threshold {
288 to_merge.push((i, j));
289 }
290 }
291 }
292
293 for (i, j) in to_merge.into_iter().rev() {
295 self.merge_clusters(i, j)?;
296 }
297
298 Ok(())
299 }
300
301 fn merge_clusters(&mut self, i: usize, j: usize) -> Result<()> {
303 if i >= self.clusters.len() || j >= self.clusters.len() || i == j {
304 return Ok(());
305 }
306
307 let (cluster_i, cluster_j) = if i < j {
308 let (left, right) = self.clusters.split_at_mut(j);
309 (&mut left[i], &mut right[0])
310 } else {
311 let (left, right) = self.clusters.split_at_mut(i);
312 (&mut right[0], &mut left[j])
313 };
314
315 let total_weight = cluster_i.weight + cluster_j.weight;
317 let weight_i = F::from(cluster_i.weight / total_weight).unwrap();
318 let weight_j = F::from(cluster_j.weight / total_weight).unwrap();
319
320 Zip::from(&mut cluster_i.centroid)
321 .and(&cluster_j.centroid)
322 .for_each(|cent_i, ¢_j| {
323 *cent_i = *cent_i * weight_i + cent_j * weight_j;
324 });
325
326 cluster_i.weight = total_weight;
328 cluster_i.variance = (cluster_i.variance + cluster_j.variance) / 2.0;
329 cluster_i.age = cluster_i.age.max(cluster_j.age);
330 cluster_i.last_update = cluster_i.last_update.max(cluster_j.last_update);
331
332 let remove_idx = if i < j { j } else { i };
334 self.clusters.remove(remove_idx);
335
336 Ok(())
337 }
338
339 fn remove_old_clusters(&mut self) -> Result<()> {
341 let current_time = self.samples_processed;
342 let max_age = 10000; self.clusters.retain(|cluster| {
345 let age_ok = cluster.age < max_age;
346 let recent_activity = current_time - cluster.last_update < 5000;
347 let sufficient_size = cluster.weight >= self.config.min_cluster_size as f64;
348
349 age_ok && (recent_activity || sufficient_size)
350 });
351
352 Ok(())
353 }
354
355 pub fn predict(&self, point: ArrayView1<F>) -> Result<usize> {
357 let (nearest_cluster_idx_, _distance) = self.find_nearest_cluster(point);
358
359 nearest_cluster_idx_.ok_or_else(|| {
360 ClusteringError::InvalidInput("No clusters available for prediction".to_string())
361 })
362 }
363
364 pub fn cluster_centers(&self) -> Array2<F> {
366 if self.clusters.is_empty() {
367 return Array2::zeros((0, 0));
368 }
369
370 let n_features = self.clusters[0].centroid.len();
371 let mut centers = Array2::zeros((self.clusters.len(), n_features));
372
373 for (i, cluster) in self.clusters.iter().enumerate() {
374 centers.row_mut(i).assign(&cluster.centroid);
375 }
376
377 centers
378 }
379
380 pub fn cluster_info(&self) -> Vec<(f64, f64, usize)> {
382 self.clusters
383 .iter()
384 .map(|cluster| (cluster.weight, cluster.variance, cluster.age))
385 .collect()
386 }
387
388 pub fn n_clusters(&self) -> usize {
390 self.clusters.len()
391 }
392}
393
394pub fn adaptive_online_clustering<F: Float + FromPrimitive + Debug>(
396 data: ArrayView2<F>,
397 config: Option<AdaptiveOnlineConfig>,
398) -> Result<(Array2<F>, Array1<usize>)> {
399 let config = config.unwrap_or_default();
400 let mut clusterer = AdaptiveOnlineClustering::new(config);
401
402 let n_samples = data.nrows();
403 let mut labels = Array1::zeros(n_samples);
404
405 for (i, point) in data.rows().into_iter().enumerate() {
407 labels[i] = clusterer.partial_fit(point)?;
408 }
409
410 let centers = clusterer.cluster_centers();
411 Ok((centers, labels))
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Convenience constructor for an `f64` clusterer.
    fn clusterer_with(config: AdaptiveOnlineConfig) -> AdaptiveOnlineClustering<f64> {
        AdaptiveOnlineClustering::new(config)
    }

    #[test]
    fn test_adaptive_online_config_default() {
        let defaults = AdaptiveOnlineConfig::default();
        assert_eq!(defaults.initial_learning_rate, 0.1);
        assert_eq!(defaults.max_clusters, 50);
        assert_eq!(defaults.concept_drift_window, 1000);
    }

    #[test]
    fn test_adaptive_online_clustering_creation() {
        let model = clusterer_with(AdaptiveOnlineConfig::default());
        assert_eq!(model.n_clusters(), 0);
    }

    #[test]
    fn test_adaptive_online_clustering_simple() {
        // Two well-separated pairs of points.
        let data = Array2::from_shape_vec(
            (4, 2),
            vec![0.0, 0.0, 1.0, 1.0, 10.0, 10.0, 11.0, 11.0],
        )
        .unwrap();

        let config = AdaptiveOnlineConfig {
            cluster_creation_threshold: 2.0,
            max_clusters: 10,
            ..Default::default()
        };

        let (centers, labels) = adaptive_online_clustering(data.view(), Some(config))
            .expect("clustering a small well-formed dataset should succeed");

        assert_eq!(labels.len(), 4);
        // There can never be more centers than input points.
        assert!(centers.nrows() <= 4);
    }

    #[test]
    fn test_online_cluster_creation() {
        let mut model = clusterer_with(AdaptiveOnlineConfig::default());

        let first = Array1::from_vec(vec![1.0, 2.0]);
        let id = model.partial_fit(first.view()).unwrap();

        assert_eq!(id, 0);
        assert_eq!(model.n_clusters(), 1);
    }

    #[test]
    fn test_concept_drift_detection() {
        let mut model = clusterer_with(AdaptiveOnlineConfig {
            concept_drift_window: 10,
            drift_detection_threshold: 0.1,
            ..Default::default()
        });

        // Feed a few points along the diagonal.
        for step in 0..5 {
            let v = step as f64;
            model.partial_fit(Array1::from_vec(vec![v, v]).view()).unwrap();
        }

        assert!(model.detect_concept_drift().is_ok());
    }

    #[test]
    fn test_cluster_merging() {
        let mut model = clusterer_with(AdaptiveOnlineConfig {
            merge_threshold: 1.0,
            cluster_creation_threshold: 0.5,
            ..Default::default()
        });

        // Two nearby points that are candidates for a single cluster.
        for coords in [[0.0, 0.0], [0.3, 0.3]] {
            let p = Array1::from_vec(coords.to_vec());
            model.partial_fit(p.view()).unwrap();
        }

        let before = model.n_clusters();
        model.merge_similar_clusters().unwrap();

        assert!(model.n_clusters() <= before);
    }
}