1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::fs::{File, OpenOptions};
12use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
13use std::path::{Path, PathBuf};
14
15use crate::error::{ClusteringError, Result};
16use crate::vq::euclidean_distance;
17
/// Tunable limits shared by the streaming clustering types in this file.
///
/// Within this file `n_centers` is read by `StreamingKMeans` and
/// `max_memory_samples` by `ProgressiveHierarchical`; the remaining fields are
/// carried along for callers (NOTE(review): no reader for `batch_size`,
/// `tolerance` or `max_iterations` is visible here — confirm against callers).
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Budget for the number of samples held in memory.
    pub max_memory_samples: usize,
    /// Number of samples per processing batch.
    pub batch_size: usize,
    /// Number of cluster centers to maintain.
    pub n_centers: usize,
    /// Convergence tolerance.
    pub tolerance: f64,
    /// Iteration cap.
    pub max_iterations: usize,
}

impl Default for StreamingConfig {
    /// Returns the stock configuration: 10 000 in-memory samples, batches of
    /// 1 000, 10 centers, tolerance `1e-4` and at most 100 iterations.
    fn default() -> Self {
        StreamingConfig {
            max_memory_samples: 10_000,
            batch_size: 1_000,
            n_centers: 10,
            tolerance: 1.0e-4,
            max_iterations: 100,
        }
    }
}
44
45pub struct StreamingKMeans<F: Float> {
50 config: StreamingConfig,
51 centers: Option<Array2<F>>,
52 weights: Option<Array1<F>>,
53 n_samples_processed: usize,
54 initialized: bool,
55}
56
57impl<F: Float + FromPrimitive + Debug> StreamingKMeans<F> {
58 pub fn new(config: StreamingConfig) -> Self {
60 Self {
61 config,
62 centers: None,
63 weights: None,
64 n_samples_processed: 0,
65 initialized: false,
66 }
67 }
68
69 pub fn initialize(&mut self, data: ArrayView2<F>) -> Result<()> {
71 let n_samples = data.shape()[0];
72 let n_features = data.shape()[1];
73
74 if n_samples == 0 {
75 return Err(ClusteringError::InvalidInput(
76 "Cannot initialize with empty data".into(),
77 ));
78 }
79
80 let k = self.config.n_centers.min(n_samples);
81
82 let mut centers = Array2::zeros((k, n_features));
84 let weights = Array1::ones(k);
85
86 let first_center_idx = 0; centers.row_mut(0).assign(&data.row(first_center_idx));
89
90 for i in 1..k {
92 let mut distances = Array1::zeros(n_samples);
93 let mut total_distance = F::zero();
94
95 for j in 0..n_samples {
97 let mut min_dist = F::infinity();
98 for center_idx in 0..i {
99 let dist = euclidean_distance(data.row(j), centers.row(center_idx));
100 if dist < min_dist {
101 min_dist = dist;
102 }
103 }
104 distances[j] = min_dist * min_dist; total_distance = total_distance + distances[j];
106 }
107
108 let mut cumsum = F::zero();
110 let target =
111 total_distance * F::from(0.5).expect("Failed to convert constant to float"); for j in 0..n_samples {
114 cumsum = cumsum + distances[j];
115 if cumsum >= target {
116 centers.row_mut(i).assign(&data.row(j));
117 break;
118 }
119 }
120 }
121
122 self.centers = Some(centers);
123 self.weights = Some(weights);
124 self.n_samples_processed = n_samples;
125 self.initialized = true;
126
127 Ok(())
128 }
129
130 pub fn partial_fit(&mut self, data: ArrayView2<F>) -> Result<()> {
132 if !self.initialized {
133 return self.initialize(data);
134 }
135
136 let n_samples = data.shape()[0];
137 if n_samples == 0 {
138 return Ok(());
139 }
140
141 let centers = self.centers.as_mut().expect("Operation failed");
142 let weights = self.weights.as_mut().expect("Operation failed");
143
144 for i in 0..n_samples {
146 let point = data.row(i);
147
148 let mut min_dist = F::infinity();
150 let mut nearest_center = 0;
151
152 for j in 0..centers.shape()[0] {
153 let dist = euclidean_distance(point, centers.row(j));
154 if dist < min_dist {
155 min_dist = dist;
156 nearest_center = j;
157 }
158 }
159
160 let weight = weights[nearest_center];
162 let new_weight = weight + F::one();
163 let learning_rate = F::one() / new_weight;
164
165 let mut center_row = centers.row_mut(nearest_center);
167 for k in 0..center_row.len() {
168 let diff = point[k] - center_row[k];
169 center_row[k] = center_row[k] + learning_rate * diff;
170 }
171
172 weights[nearest_center] = new_weight;
173 }
174
175 self.n_samples_processed += n_samples;
176 Ok(())
177 }
178
179 pub fn cluster_centers(&self) -> Option<&Array2<F>> {
181 self.centers.as_ref()
182 }
183
184 pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
186 if !self.initialized {
187 return Err(ClusteringError::InvalidInput(
188 "Model must be initialized before prediction".into(),
189 ));
190 }
191
192 let centers = self.centers.as_ref().expect("Operation failed");
193 let n_samples = data.shape()[0];
194 let mut labels = Array1::zeros(n_samples);
195
196 for i in 0..n_samples {
197 let point = data.row(i);
198 let mut min_dist = F::infinity();
199 let mut nearest_center = 0;
200
201 for j in 0..centers.shape()[0] {
202 let dist = euclidean_distance(point, centers.row(j));
203 if dist < min_dist {
204 min_dist = dist;
205 nearest_center = j;
206 }
207 }
208
209 labels[i] = nearest_center;
210 }
211
212 Ok(labels)
213 }
214
215 pub fn n_samples_seen(&self) -> usize {
217 self.n_samples_processed
218 }
219}
220
221pub struct ProgressiveHierarchical<F: Float> {
226 #[allow(dead_code)]
227 config: StreamingConfig,
228 representative_points: VecDeque<Array1<F>>,
229 cluster_sizes: VecDeque<usize>,
230 max_representatives: usize,
231}
232
233impl<F: Float + FromPrimitive + Debug> ProgressiveHierarchical<F> {
234 pub fn new(config: StreamingConfig) -> Self {
236 let max_representatives = config.max_memory_samples / 10; Self {
239 config,
240 representative_points: VecDeque::new(),
241 cluster_sizes: VecDeque::new(),
242 max_representatives,
243 }
244 }
245
246 pub fn partial_fit(&mut self, data: ArrayView2<F>) -> Result<()> {
248 let n_samples = data.shape()[0];
249 if n_samples == 0 {
250 return Ok(());
251 }
252
253 if self.representative_points.is_empty() {
255 let step_size = (n_samples / self.max_representatives.min(n_samples)).max(1);
256
257 for i in (0..n_samples).step_by(step_size) {
258 self.representative_points.push_back(data.row(i).to_owned());
259 self.cluster_sizes.push_back(1);
260
261 if self.representative_points.len() >= self.max_representatives {
262 break;
263 }
264 }
265 return Ok(());
266 }
267
268 let mut new_representatives = Vec::new();
270 let mut new_sizes = Vec::new();
271
272 for i in 0..n_samples {
274 let point = data.row(i);
275
276 let mut min_dist = F::infinity();
278 let mut closest_idx = 0;
279
280 for (j, repr) in self.representative_points.iter().enumerate() {
281 let dist = euclidean_distance(point, repr.view());
282 if dist < min_dist {
283 min_dist = dist;
284 closest_idx = j;
285 }
286 }
287
288 let threshold = F::from(0.1).expect("Failed to convert constant to float"); if min_dist < threshold && closest_idx < self.representative_points.len() {
292 let old_size = self.cluster_sizes[closest_idx];
294 let new_size = old_size + 1;
295 let weight = F::from(old_size).expect("Failed to convert to float")
296 / F::from(new_size).expect("Failed to convert to float");
297
298 let mut repr = self.representative_points[closest_idx].clone();
300 for k in 0..repr.len() {
301 repr[k] = weight * repr[k] + (F::one() - weight) * point[k];
302 }
303
304 new_representatives.push(repr);
305 new_sizes.push(new_size);
306 } else {
307 new_representatives.push(point.to_owned());
309 new_sizes.push(1);
310 }
311 }
312
313 self.representative_points.clear();
315 self.cluster_sizes.clear();
316
317 for (repr, size) in new_representatives.into_iter().zip(new_sizes.into_iter()) {
318 self.representative_points.push_back(repr);
319 self.cluster_sizes.push_back(size);
320 }
321
322 if self.representative_points.len() > self.max_representatives {
324 self.compress_representatives()?;
325 }
326
327 Ok(())
328 }
329
330 fn compress_representatives(&mut self) -> Result<()> {
332 let _n_repr = self.representative_points.len();
333 let target_size = self.max_representatives * 3 / 4; while self.representative_points.len() > target_size {
336 let mut min_dist = F::infinity();
338 let mut merge_i = 0;
339 let mut merge_j = 1;
340
341 for i in 0..self.representative_points.len() {
342 for j in (i + 1)..self.representative_points.len() {
343 let dist = euclidean_distance(
344 self.representative_points[i].view(),
345 self.representative_points[j].view(),
346 );
347 if dist < min_dist {
348 min_dist = dist;
349 merge_i = i;
350 merge_j = j;
351 }
352 }
353 }
354
355 let size_i = self.cluster_sizes[merge_i];
357 let size_j = self.cluster_sizes[merge_j];
358 let total_size = size_i + size_j;
359
360 let weight_i = F::from(size_i).expect("Failed to convert to float")
361 / F::from(total_size).expect("Failed to convert to float");
362 let weight_j = F::from(size_j).expect("Failed to convert to float")
363 / F::from(total_size).expect("Failed to convert to float");
364
365 let repr_i = &self.representative_points[merge_i];
367 let repr_j = &self.representative_points[merge_j];
368 let mut merged_repr = Array1::zeros(repr_i.len());
369
370 for k in 0..merged_repr.len() {
371 merged_repr[k] = weight_i * repr_i[k] + weight_j * repr_j[k];
372 }
373
374 if merge_j > merge_i {
376 self.representative_points.remove(merge_j);
377 self.cluster_sizes.remove(merge_j);
378 self.representative_points.remove(merge_i);
379 self.cluster_sizes.remove(merge_i);
380 } else {
381 self.representative_points.remove(merge_i);
382 self.cluster_sizes.remove(merge_i);
383 self.representative_points.remove(merge_j);
384 self.cluster_sizes.remove(merge_j);
385 }
386
387 self.representative_points.push_back(merged_repr);
389 self.cluster_sizes.push_back(total_size);
390 }
391
392 Ok(())
393 }
394
395 pub fn get_representatives(&self) -> (Vec<Array1<F>>, Vec<usize>) {
397 (
398 self.representative_points.iter().cloned().collect(),
399 self.cluster_sizes.iter().cloned().collect(),
400 )
401 }
402
403 pub fn n_representatives(&self) -> usize {
405 self.representative_points.len()
406 }
407}
408
409pub struct ChunkedDistanceMatrix<F: Float> {
414 chunk_size: usize,
415 n_samples: usize,
416 _phantom: std::marker::PhantomData<F>,
417}
418
419impl<F: Float + FromPrimitive> ChunkedDistanceMatrix<F> {
420 pub fn new(n_samples: usize, max_memory_mb: usize) -> Self {
422 let memory_per_float = std::mem::size_of::<F>();
424 let max_elements = (max_memory_mb * 1024 * 1024) / memory_per_float;
425 let chunk_size = (max_elements / n_samples).max(1).min(n_samples);
426
427 Self {
428 chunk_size,
429 n_samples,
430 _phantom: std::marker::PhantomData,
431 }
432 }
433
434 pub fn process_chunks<Func>(&self, data: ArrayView2<F>, mut processor: Func) -> Result<()>
436 where
437 Func: FnMut(usize, usize, F) -> Result<()>,
438 {
439 for i in (0..self.n_samples).step_by(self.chunk_size) {
440 let end_i = (i + self.chunk_size).min(self.n_samples);
441
442 for j in (i..self.n_samples).step_by(self.chunk_size) {
443 let end_j = (j + self.chunk_size).min(self.n_samples);
444
445 for row in i..end_i {
447 for col in j.max(row + 1)..end_j {
448 let dist = euclidean_distance(data.row(row), data.row(col));
449 processor(row, col, dist)?;
450 }
451 }
452 }
453 }
454 Ok(())
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Two well-separated blobs fed in two batches must end up in different
    /// clusters.
    #[test]
    fn test_streaming_kmeans() {
        let mut model = StreamingKMeans::new(StreamingConfig {
            max_memory_samples: 100,
            batch_size: 10,
            n_centers: 2,
            tolerance: 1e-4,
            max_iterations: 10,
        });

        let first_values = vec![0.0, 0.0, 0.1, 0.1, 1.0, 1.0, 1.1, 1.1];
        let first_batch = Array2::from_shape_vec((4, 2), first_values).expect("Operation failed");
        model
            .partial_fit(first_batch.view())
            .expect("Operation failed");
        assert!(model.cluster_centers().is_some());

        let second_values = vec![0.2, 0.2, 0.0, 0.1, 1.2, 1.0, 1.0, 1.2];
        let second_batch =
            Array2::from_shape_vec((4, 2), second_values).expect("Operation failed");
        model
            .partial_fit(second_batch.view())
            .expect("Operation failed");

        let probes =
            Array2::from_shape_vec((2, 2), vec![0.05, 0.05, 1.05, 1.05]).expect("Operation failed");
        let labels = model.predict(probes.view()).expect("Operation failed");

        assert_eq!(labels.len(), 2);
        // One probe per blob, so the labels must differ.
        assert_ne!(labels[0], labels[1]);
    }

    /// The first batch seeds a non-empty, index-aligned set of representatives.
    #[test]
    fn test_progressive_hierarchical() {
        let mut summary = ProgressiveHierarchical::new(StreamingConfig::default());

        let values = vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2];
        let batch = Array2::from_shape_vec((6, 2), values).expect("Test: operation failed");

        summary.partial_fit(batch.view()).expect("Operation failed");
        let (points, counts) = summary.get_representatives();

        assert!(!points.is_empty());
        assert_eq!(points.len(), counts.len());
        assert!(summary.n_representatives() > 0);
    }

    /// Four samples have exactly C(4, 2) = 6 upper-triangular pairs.
    #[test]
    fn test_chunked_distance_matrix() {
        let corners = vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let data = Array2::from_shape_vec((4, 2), corners).expect("Operation failed");

        let matrix = ChunkedDistanceMatrix::new(4, 1);
        let mut visited_pairs = 0;

        matrix
            .process_chunks(data.view(), |row, col, dist| {
                assert!(row < col);
                assert!(dist >= 0.0);
                visited_pairs += 1;
                Ok(())
            })
            .expect("Test: operation failed");

        assert_eq!(visited_pairs, 6);
    }

    /// Pins the documented default configuration values.
    #[test]
    fn test_streaming_config_default() {
        let StreamingConfig {
            max_memory_samples,
            batch_size,
            n_centers,
            tolerance,
            max_iterations,
        } = StreamingConfig::default();

        assert_eq!(max_memory_samples, 10000);
        assert_eq!(batch_size, 1000);
        assert_eq!(n_centers, 10);
        assert_eq!(tolerance, 1e-4);
        assert_eq!(max_iterations, 100);
    }
}
558
559pub mod memory_management {
561 use super::*;
562
563 #[derive(Debug, Clone)]
565 pub struct AdaptiveMemoryManager {
566 current_usage: usize,
568 max_memory: usize,
570 pressure_threshold: f64,
572 enable_disk_storage: bool,
574 temp_dir: Option<PathBuf>,
576 }
577
578 impl AdaptiveMemoryManager {
579 pub fn new(max_memory_mb: usize) -> Self {
581 Self {
582 current_usage: 0,
583 max_memory: max_memory_mb * 1024 * 1024,
584 pressure_threshold: 0.8,
585 enable_disk_storage: true,
586 temp_dir: std::env::temp_dir().into(),
587 }
588 }
589
590 pub fn is_memory_pressure_high(&self) -> bool {
592 self.current_usage as f64 / self.max_memory as f64 > self.pressure_threshold
593 }
594
595 pub fn estimate_memory_usage<F: Float>(
597 &self,
598 n_samples: usize,
599 n_features: usize,
600 ) -> usize {
601 std::mem::size_of::<F>() * n_samples * n_features
602 }
603
604 pub fn allocate<F: Float>(&mut self, n_samples: usize, n_features: usize) -> Result<()> {
606 let required = self.estimate_memory_usage::<F>(n_samples, n_features);
607
608 if self.current_usage + required > self.max_memory {
609 if self.enable_disk_storage {
610 Ok(())
612 } else {
613 Err(ClusteringError::InvalidInput(
614 "Not enough memory and disk storage is disabled".to_string(),
615 ))
616 }
617 } else {
618 self.current_usage += required;
619 Ok(())
620 }
621 }
622
623 pub fn deallocate(&mut self, amount: usize) {
625 self.current_usage = self.current_usage.saturating_sub(amount);
626 }
627
628 pub fn available_memory(&self) -> usize {
630 self.max_memory.saturating_sub(self.current_usage)
631 }
632
633 pub fn optimal_batch_size<F: Float>(&self, n_features: usize) -> usize {
635 let available = self.available_memory();
636 let bytes_per_sample = std::mem::size_of::<F>() * n_features;
637
638 if bytes_per_sample == 0 {
639 1000 } else {
641 (available / bytes_per_sample).max(1).min(10000)
642 }
643 }
644 }
645
646 #[derive(Debug)]
648 pub struct DiskBackedStorage<F: Float + FromPrimitive> {
649 temp_files: Vec<PathBuf>,
650 temp_dir: PathBuf,
651 buffer_size: usize,
652 _phantom: std::marker::PhantomData<F>,
653 }
654
655 impl<F: Float + FromPrimitive> DiskBackedStorage<F> {
656 pub fn new(temp_dir: Option<PathBuf>, buffer_size: usize) -> Self {
658 let temp_dir = temp_dir.unwrap_or_else(std::env::temp_dir);
659
660 Self {
661 temp_files: Vec::new(),
662 temp_dir,
663 buffer_size,
664 _phantom: std::marker::PhantomData,
665 }
666 }
667
668 pub fn write_chunk(&mut self, data: ArrayView2<F>) -> Result<usize> {
670 let chunk_id = self.temp_files.len();
671 let file_path = self
672 .temp_dir
673 .join(format!("cluster_chunk_{}.bin", chunk_id));
674
675 let file = File::create(&file_path).map_err(|e| {
676 ClusteringError::InvalidInput(format!("Failed to create temp file: {}", e))
677 })?;
678 let mut writer = BufWriter::new(file);
679
680 let n_rows = data.shape()[0] as u64;
682 let n_cols = data.shape()[1] as u64;
683 writer.write_all(&n_rows.to_le_bytes()).map_err(|e| {
684 ClusteringError::InvalidInput(format!("Failed to write dimensions: {}", e))
685 })?;
686 writer.write_all(&n_cols.to_le_bytes()).map_err(|e| {
687 ClusteringError::InvalidInput(format!("Failed to write dimensions: {}", e))
688 })?;
689
690 for row in data.rows() {
692 for &value in row.iter() {
693 let bytes = value.to_f64().unwrap_or(0.0).to_le_bytes();
694 writer.write_all(&bytes).map_err(|e| {
695 ClusteringError::InvalidInput(format!("Failed to write data: {}", e))
696 })?;
697 }
698 }
699
700 writer.flush().map_err(|e| {
701 ClusteringError::InvalidInput(format!("Failed to flush data: {}", e))
702 })?;
703
704 self.temp_files.push(file_path);
705 Ok(chunk_id)
706 }
707
708 pub fn read_chunk(&self, chunk_id: usize) -> Result<Array2<F>> {
710 if chunk_id >= self.temp_files.len() {
711 return Err(ClusteringError::InvalidInput(
712 "Invalid chunk ID".to_string(),
713 ));
714 }
715
716 let file = File::open(&self.temp_files[chunk_id]).map_err(|e| {
717 ClusteringError::InvalidInput(format!("Failed to open temp file: {}", e))
718 })?;
719 let mut reader = BufReader::new(file);
720
721 let mut dim_bytes = [0u8; 8];
723 reader.read_exact(&mut dim_bytes).map_err(|e| {
724 ClusteringError::InvalidInput(format!("Failed to read dimensions: {}", e))
725 })?;
726 let n_rows = u64::from_le_bytes(dim_bytes) as usize;
727
728 reader.read_exact(&mut dim_bytes).map_err(|e| {
729 ClusteringError::InvalidInput(format!("Failed to read dimensions: {}", e))
730 })?;
731 let n_cols = u64::from_le_bytes(dim_bytes) as usize;
732
733 let mut data = Array2::zeros((n_rows, n_cols));
735 for mut row in data.rows_mut() {
736 for element in row.iter_mut() {
737 let mut value_bytes = [0u8; 8];
738 reader.read_exact(&mut value_bytes).map_err(|e| {
739 ClusteringError::InvalidInput(format!("Failed to read data: {}", e))
740 })?;
741 let value = f64::from_le_bytes(value_bytes);
742 *element = F::from(value).unwrap_or(F::zero());
743 }
744 }
745
746 Ok(data)
747 }
748
749 pub fn cleanup(&mut self) -> Result<()> {
751 for file_path in &self.temp_files {
752 if file_path.exists() {
753 std::fs::remove_file(file_path).map_err(|e| {
754 ClusteringError::InvalidInput(format!("Failed to remove temp file: {}", e))
755 })?;
756 }
757 }
758 self.temp_files.clear();
759 Ok(())
760 }
761
762 pub fn num_chunks(&self) -> usize {
764 self.temp_files.len()
765 }
766 }
767
768 impl<F: Float + FromPrimitive> Drop for DiskBackedStorage<F> {
769 fn drop(&mut self) {
770 let _ = self.cleanup(); }
772 }
773}
774
775pub mod advanced_streaming {
777 use super::*;
778
779 #[derive(Debug, Clone)]
782 pub struct CountMinSketch {
783 tables: Vec<Vec<u64>>,
785 width: usize,
787 depth: usize,
789 hash_params: Vec<(u64, u64)>,
791 }
792
793 impl CountMinSketch {
794 pub fn new(epsilon: f64, delta: f64) -> Self {
796 let width = (std::f64::consts::E / epsilon).ceil() as usize;
797 let depth = (1.0 / delta).ln().ceil() as usize;
798
799 let mut tables = Vec::new();
800 let mut hash_params = Vec::new();
801
802 for i in 0..depth {
803 tables.push(vec![0u64; width]);
804 hash_params.push((
806 1000000007 + i as u64 * 1000000009,
807 1000000021 + i as u64 * 1000000033,
808 ));
809 }
810
811 Self {
812 tables,
813 width,
814 depth,
815 hash_params,
816 }
817 }
818
819 pub fn add(&mut self, item: u64) {
821 for i in 0..self.depth {
822 let hash = self.hash(item, i);
823 let idx = (hash as usize) % self.width;
824 self.tables[i][idx] += 1;
825 }
826 }
827
828 pub fn estimate(&self, item: u64) -> u64 {
830 let mut min_count = u64::MAX;
831
832 for i in 0..self.depth {
833 let hash = self.hash(item, i);
834 let idx = (hash as usize) % self.width;
835 min_count = min_count.min(self.tables[i][idx]);
836 }
837
838 min_count
839 }
840
841 fn hash(&self, item: u64, table_idx: usize) -> u64 {
843 let (a, b) = self.hash_params[table_idx];
844 a.wrapping_mul(item).wrapping_add(b)
845 }
846
847 pub fn heavy_hitters(&self, threshold: u64) -> Vec<u64> {
849 Vec::new()
852 }
853 }
854
855 #[derive(Debug, Clone)]
857 pub struct ReservoirSampler<T> {
858 reservoir: Vec<T>,
859 capacity: usize,
860 seen_count: usize,
861 }
862
863 impl<T: Clone> ReservoirSampler<T> {
864 pub fn new(capacity: usize) -> Self {
866 Self {
867 reservoir: Vec::with_capacity(capacity),
868 capacity,
869 seen_count: 0,
870 }
871 }
872
873 pub fn add(&mut self, item: T) {
875 self.seen_count += 1;
876
877 if self.reservoir.len() < self.capacity {
878 self.reservoir.push(item);
879 } else {
880 let random_idx = (self.seen_count - 1) % self.capacity; if random_idx < self.capacity {
883 self.reservoir[random_idx] = item;
884 }
885 }
886 }
887
888 pub fn sample(&self) -> &[T] {
890 &self.reservoir
891 }
892
893 pub fn items_seen(&self) -> usize {
895 self.seen_count
896 }
897 }
898
899 #[derive(Debug)]
901 pub struct ProgressiveLearner<F: Float> {
902 model_state: HashMap<String, Vec<F>>,
904 learning_rate: F,
906 decay_factor: F,
908 update_count: usize,
910 gradient_memory: HashMap<String, Vec<F>>,
912 momentum: F,
914 }
915
916 impl<F: Float + FromPrimitive + std::fmt::Debug> ProgressiveLearner<F> {
917 pub fn new(initial_lr: F, decay: F, momentum: F) -> Self {
919 Self {
920 model_state: HashMap::new(),
921 learning_rate: initial_lr,
922 decay_factor: decay,
923 update_count: 0,
924 gradient_memory: HashMap::new(),
925 momentum,
926 }
927 }
928
929 pub fn update(&mut self, param_name: &str, gradient: &[F]) -> Result<()> {
931 self.update_count += 1;
932
933 if self.update_count.is_multiple_of(100) {
935 self.learning_rate = self.learning_rate * self.decay_factor;
936 }
937
938 if !self.model_state.contains_key(param_name) {
940 self.model_state
941 .insert(param_name.to_string(), vec![F::zero(); gradient.len()]);
942 self.gradient_memory
943 .insert(param_name.to_string(), vec![F::zero(); gradient.len()]);
944 }
945
946 let params = self
947 .model_state
948 .get_mut(param_name)
949 .expect("Operation failed");
950 let momentum_grad = self
951 .gradient_memory
952 .get_mut(param_name)
953 .expect("Operation failed");
954
955 for i in 0..params.len() {
957 momentum_grad[i] = self.momentum * momentum_grad[i] + gradient[i];
958 params[i] = params[i] - self.learning_rate * momentum_grad[i];
959 }
960
961 Ok(())
962 }
963
964 pub fn get_parameters(&self, param_name: &str) -> Option<&[F]> {
966 self.model_state.get(param_name).map(|v| v.as_slice())
967 }
968
969 pub fn current_learning_rate(&self) -> F {
971 self.learning_rate
972 }
973
974 pub fn update_count(&self) -> usize {
976 self.update_count
977 }
978 }
979}
980
981pub mod intelligent_loading {
983 use super::*;
984
985 #[derive(Debug)]
987 pub struct AdaptiveDataLoader {
988 current_batch_size: usize,
990 min_batch_size: usize,
992 max_batch_size: usize,
994 performance_history: VecDeque<f64>,
996 target_time: f64,
998 adjustment_factor: f64,
1000 }
1001
1002 impl AdaptiveDataLoader {
1003 pub fn new(initial_batch_size: usize, target_time_seconds: f64) -> Self {
1005 Self {
1006 current_batch_size: initial_batch_size,
1007 min_batch_size: initial_batch_size / 10,
1008 max_batch_size: initial_batch_size * 10,
1009 performance_history: VecDeque::with_capacity(10),
1010 target_time: target_time_seconds,
1011 adjustment_factor: 0.1,
1012 }
1013 }
1014
1015 pub fn report_batch_time(&mut self, processing_time: f64) {
1017 self.performance_history.push_back(processing_time);
1018 if self.performance_history.len() > 10 {
1019 self.performance_history.pop_front();
1020 }
1021
1022 let avg_time = self.performance_history.iter().sum::<f64>()
1024 / self.performance_history.len() as f64;
1025
1026 if avg_time > self.target_time * 1.2 {
1028 let new_size =
1030 (self.current_batch_size as f64 * (1.0 - self.adjustment_factor)) as usize;
1031 self.current_batch_size = new_size.max(self.min_batch_size);
1032 } else if avg_time < self.target_time * 0.8 {
1033 let new_size =
1035 (self.current_batch_size as f64 * (1.0 + self.adjustment_factor)) as usize;
1036 self.current_batch_size = new_size.min(self.max_batch_size);
1037 }
1038 }
1039
1040 pub fn current_batch_size(&self) -> usize {
1042 self.current_batch_size
1043 }
1044
1045 pub fn get_stats(&self) -> (f64, f64, usize) {
1047 let avg_time = if self.performance_history.is_empty() {
1048 0.0
1049 } else {
1050 self.performance_history.iter().sum::<f64>() / self.performance_history.len() as f64
1051 };
1052
1053 let efficiency = if avg_time > 0.0 {
1054 self.target_time / avg_time
1055 } else {
1056 1.0
1057 };
1058
1059 (avg_time, efficiency, self.current_batch_size)
1060 }
1061 }
1062
1063 #[derive(Debug, Clone)]
1065 pub struct StreamingPreprocessor<F: Float> {
1066 running_mean: Option<Array1<F>>,
1068 running_var: Option<Array1<F>>,
1069 sample_count: usize,
1070 normalize: bool,
1072 outlier_threshold: F,
1074 missing_value_strategy: MissingValueStrategy,
1076 }
1077
1078 #[derive(Debug, Clone)]
1079 pub enum MissingValueStrategy {
1080 Drop,
1081 FillMean,
1082 FillZero,
1083 Interpolate,
1084 }
1085
1086 impl<F: Float + FromPrimitive + std::fmt::Debug> StreamingPreprocessor<F> {
1087 pub fn new(normalize: bool, outlier_threshold: F) -> Self {
1089 Self {
1090 running_mean: None,
1091 running_var: None,
1092 sample_count: 0,
1093 normalize,
1094 outlier_threshold,
1095 missing_value_strategy: MissingValueStrategy::FillMean,
1096 }
1097 }
1098
1099 pub fn process_batch(&mut self, mut data: Array2<F>) -> Result<Array2<F>> {
1101 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
1102
1103 if n_samples == 0 {
1104 return Ok(data);
1105 }
1106
1107 if self.running_mean.is_none() {
1109 self.running_mean = Some(Array1::zeros(n_features));
1110 self.running_var = Some(Array1::zeros(n_features));
1111 }
1112
1113 if self.normalize {
1115 self.update_statistics(&data)?;
1116 }
1117
1118 data = self.handle_missing_values(data)?;
1120
1121 if self.normalize {
1123 self.apply_normalization(&mut data)?;
1124 }
1125
1126 self.handle_outliers(&mut data)?;
1128
1129 Ok(data)
1130 }
1131
1132 fn update_statistics(&mut self, data: &Array2<F>) -> Result<()> {
1134 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
1135 let mean = self.running_mean.as_mut().expect("Operation failed");
1136 let var = self.running_var.as_mut().expect("Operation failed");
1137
1138 for i in 0..n_samples {
1139 self.sample_count += 1;
1140 let sample = data.row(i);
1141
1142 for j in 0..n_features {
1143 if sample[j].is_finite() {
1144 let delta = sample[j] - mean[j];
1146 mean[j] = mean[j]
1147 + delta
1148 / F::from(self.sample_count).expect("Failed to convert to float");
1149 let delta2 = sample[j] - mean[j];
1150 var[j] = var[j] + delta * delta2;
1151 }
1152 }
1153 }
1154
1155 Ok(())
1156 }
1157
1158 fn handle_missing_values(&self, mut data: Array2<F>) -> Result<Array2<F>> {
1168 let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
1169 match self.missing_value_strategy {
1170 MissingValueStrategy::FillZero => {
1171 for elem in data.iter_mut() {
1172 if !elem.is_finite() {
1173 *elem = F::zero();
1174 }
1175 }
1176 }
1177 MissingValueStrategy::FillMean => {
1178 if let Some(ref mean) = self.running_mean {
1179 for mut row in data.rows_mut().into_iter() {
1180 for (j, elem) in row.iter_mut().enumerate() {
1181 if !elem.is_finite() && j < mean.len() {
1182 *elem = mean[j];
1183 }
1184 }
1185 }
1186 }
1187 }
1188 MissingValueStrategy::Drop => {
1189 let valid_rows: Vec<usize> = (0..n_samples)
1191 .filter(|&i| data.row(i).iter().all(|v| v.is_finite()))
1192 .collect();
1193 if valid_rows.len() == n_samples {
1194 return Ok(data);
1196 }
1197 let mut kept = Array2::zeros((valid_rows.len(), n_features));
1198 for (new_idx, &old_idx) in valid_rows.iter().enumerate() {
1199 kept.row_mut(new_idx).assign(&data.row(old_idx));
1200 }
1201 return Ok(kept);
1202 }
1203 MissingValueStrategy::Interpolate => {
1204 let fallback = |j: usize| -> F {
1209 self.running_mean
1210 .as_ref()
1211 .and_then(|m| if j < m.len() { Some(m[j]) } else { None })
1212 .unwrap_or(F::zero())
1213 };
1214
1215 for col in 0..n_features {
1216 let vals: Vec<F> = (0..n_samples).map(|r| data[[r, col]]).collect();
1218
1219 for row in 0..n_samples {
1220 if !vals[row].is_finite() {
1221 let prev = (0..row)
1223 .rev()
1224 .find(|&r| vals[r].is_finite())
1225 .map(|r| (r, vals[r]));
1226 let next = ((row + 1)..n_samples)
1228 .find(|&r| vals[r].is_finite())
1229 .map(|r| (r, vals[r]));
1230
1231 data[[row, col]] = match (prev, next) {
1232 (Some((pr, pv)), Some((nr, nv))) => {
1233 let span = F::from(nr - pr).unwrap_or(F::one());
1235 let offset = F::from(row - pr).unwrap_or(F::zero());
1236 pv + (nv - pv) * offset / span
1237 }
1238 (Some((_, pv)), None) => pv,
1239 (None, Some((_, nv))) => nv,
1240 (None, None) => fallback(col),
1241 };
1242 }
1243 }
1244 }
1245 }
1246 }
1247 Ok(data)
1248 }
1249
1250 fn apply_normalization(&self, data: &mut Array2<F>) -> Result<()> {
1252 if let (Some(ref mean), Some(ref var)) = (&self.running_mean, &self.running_var) {
1253 if self.sample_count > 1 {
1254 for mut row in data.rows_mut().into_iter() {
1255 for (j, elem) in row.iter_mut().enumerate() {
1256 if j < mean.len() && var[j] > F::zero() {
1257 let std_dev = (var[j]
1258 / F::from(self.sample_count - 1)
1259 .expect("Failed to convert to float"))
1260 .sqrt();
1261 if std_dev > F::zero() {
1262 *elem = (*elem - mean[j]) / std_dev;
1263 }
1264 }
1265 }
1266 }
1267 }
1268 }
1269 Ok(())
1270 }
1271
1272 fn handle_outliers(&self, data: &mut Array2<F>) -> Result<()> {
1274 for elem in data.iter_mut() {
1276 if elem.abs() > self.outlier_threshold {
1277 *elem = if *elem > F::zero() {
1278 self.outlier_threshold
1279 } else {
1280 -self.outlier_threshold
1281 };
1282 }
1283 }
1284 Ok(())
1285 }
1286
1287 pub fn get_statistics(&self) -> Option<(Array1<F>, Array1<F>)> {
1289 if let (Some(ref mean), Some(ref var)) = (&self.running_mean, &self.running_var) {
1290 Some((mean.clone(), var.clone()))
1291 } else {
1292 None
1293 }
1294 }
1295 }
1296
1297 #[cfg(test)]
1298 mod tests_preprocessor {
1299 use super::*;
1300 use scirs2_core::ndarray::Array2;
1301
1302 fn make_nan_data() -> Array2<f64> {
1304 Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0])
1305 .expect("shape error")
1306 }
1307
1308 #[test]
1309 fn test_fill_zero_replaces_nan() {
1310 let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1312 pp.missing_value_strategy = MissingValueStrategy::FillZero;
1313 let data = make_nan_data();
1314 let out = pp.process_batch(data).expect("process_batch failed");
1315 assert_eq!(out.shape(), &[3, 2]);
1316 assert_eq!(out[[1, 0]], 0.0, "NaN should be replaced with 0");
1317 assert!(out[[1, 0]].is_finite());
1318 }
1319
1320 #[test]
1321 fn test_fill_mean_uses_running_mean() {
1322 let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1323 pp.missing_value_strategy = MissingValueStrategy::FillMean;
1324 pp.running_mean = Some(scirs2_core::ndarray::array![2.0, 3.0]);
1325 let data = make_nan_data();
1326 let out = pp.process_batch(data).expect("process_batch failed");
1327 assert!(
1329 (out[[1, 0]] - 2.0).abs() < 1e-9,
1330 "NaN should be replaced with running mean 2.0, got {}",
1331 out[[1, 0]]
1332 );
1333 }
1334
1335 #[test]
1336 fn test_drop_removes_nan_rows() {
1337 let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1338 pp.missing_value_strategy = MissingValueStrategy::Drop;
1339 let data = make_nan_data();
1340 let out = pp.process_batch(data).expect("process_batch failed");
1341 assert_eq!(out.shape(), &[2, 2], "NaN row should have been dropped");
1343 for row in out.rows() {
1345 for &v in row.iter() {
1346 assert!(v.is_finite(), "all remaining values should be finite");
1347 }
1348 }
1349 }
1350
1351 #[test]
1352 fn test_drop_all_clean_data_unchanged() {
1353 let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1354 pp.missing_value_strategy = MissingValueStrategy::Drop;
1355 let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1356 .expect("shape error");
1357 let out = pp.process_batch(data).expect("process_batch failed");
1358 assert_eq!(
1359 out.shape(),
1360 &[3, 2],
1361 "no rows should be dropped when data is clean"
1362 );
1363 }
1364
1365 #[test]
1366 fn test_interpolate_fills_nan_between_finite() {
1367 let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1369 pp.missing_value_strategy = MissingValueStrategy::Interpolate;
1370 let data = Array2::from_shape_vec((3, 2), vec![1.0, 4.0, f64::NAN, f64::NAN, 3.0, 8.0])
1373 .expect("shape error");
1374 let out = pp.process_batch(data).expect("process_batch failed");
1375 assert_eq!(out.shape(), &[3, 2]);
1376 assert!(
1377 (out[[1, 0]] - 2.0).abs() < 1e-9,
1378 "interpolated col 0 should be 2.0, got {}",
1379 out[[1, 0]]
1380 );
1381 assert!(
1382 (out[[1, 1]] - 6.0).abs() < 1e-9,
1383 "interpolated col 1 should be 6.0, got {}",
1384 out[[1, 1]]
1385 );
1386 }
1387
1388 #[test]
1389 fn test_interpolate_leading_nan_uses_next_finite() {
1390 let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1392 pp.missing_value_strategy = MissingValueStrategy::Interpolate;
1393 let data = Array2::from_shape_vec((3, 2), vec![f64::NAN, 0.0, 5.0, 0.0, 7.0, 0.0])
1395 .expect("shape error");
1396 let out = pp.process_batch(data).expect("process_batch failed");
1397 assert!(
1398 (out[[0, 0]] - 5.0).abs() < 1e-9,
1399 "leading NaN should be filled with 5.0, got {}",
1400 out[[0, 0]]
1401 );
1402 }
1403 }
1404}
1405
1406pub mod online_algorithms {
1408 use super::*;
1409 use std::collections::VecDeque;
1410
    /// Online k-means clusterer that refines its centers one sample at a time.
    ///
    /// Each `update` call moves the center nearest to the sample towards it,
    /// using a step size produced by the configured [`LearningRateSchedule`].
    #[derive(Debug, Clone)]
    pub struct AdaptiveOnlineKMeans<F: Float> {
        /// Current cluster centers, one row per cluster.
        centers: Array2<F>,
        /// Schedule from which the per-update learning rate is derived.
        learning_rate_schedule: LearningRateSchedule,
        /// Number of completed `update` calls; drives the decaying schedules.
        iteration: usize,
        /// Momentum, center-movement history and split/merge settings.
        adaptive_params: AdaptiveParams<F>,
        /// Running statistics collected while processing samples.
        metrics: OnlineMetrics,
    }
1425
    /// Strategy used to derive the learning rate for each online update.
    #[derive(Debug, Clone)]
    pub enum LearningRateSchedule {
        /// Always use the given learning rate.
        Constant(f64),
        /// Inverse decay: `initial_lr / (1 + decay * iteration)`.
        Decay { initial_lr: f64, decay: f64 },
        /// Stepwise decay: `initial_lr * factor^(iteration / step_size)`,
        /// using integer division for the exponent.
        StepDecay {
            /// Learning rate before any decay step has been applied.
            initial_lr: f64,
            /// Multiplier applied once per completed step.
            factor: f64,
            /// Number of iterations that make up one step.
            step_size: usize,
        },
        /// Rate interpolated between `min_lr` and `max_lr` from the average of
        /// the most recent center movements: larger movements give a rate
        /// closer to `max_lr`.
        Adaptive {
            /// Lower bound of the learning rate.
            min_lr: f64,
            /// Upper bound of the learning rate; also used while no movement
            /// has been recorded yet.
            max_lr: f64,
            /// Maximum number of most recent movements that are averaged.
            stability_window: usize,
        },
    }
1446
    /// Tunable parameters and bookkeeping for [`AdaptiveOnlineKMeans`].
    #[derive(Debug, Clone)]
    pub struct AdaptiveParams<F: Float> {
        /// Weight given to the previous center position when a center is updated.
        pub momentum: F,
        /// One score per cluster; kept in step with the centers when clusters
        /// are added or merged.
        pub stability_scores: Vec<F>,
        /// Distances moved by the updated center in recent updates (the
        /// clusterer keeps at most the last 100).
        pub center_movements: VecDeque<F>,
        /// When `true`, every update may add a new cluster or merge two existing ones.
        pub auto_k_adjustment: bool,
        /// A sample farther than this from its nearest center becomes a new center.
        pub split_threshold: F,
        /// The two closest centers are merged when they are nearer than this.
        pub merge_threshold: F,
    }
1463
    /// Running statistics gathered by [`AdaptiveOnlineKMeans`] during updates.
    #[derive(Debug, Clone, Default)]
    pub struct OnlineMetrics {
        /// Running mean of the squared distance between each processed sample
        /// and its nearest center (distance measured before the center moves).
        pub wcss: f64,
        /// Total number of samples processed so far.
        pub samples_processed: usize,
        /// Total recorded assignments divided by the iteration counter
        /// (the counter is treated as at least 1).
        pub update_frequency: f64,
        /// Number of samples assigned to each cluster, indexed by cluster.
        pub cluster_distribution: Vec<usize>,
        /// Wall-clock duration in seconds of recent updates (at most the last 1000).
        pub batch_processing_times: VecDeque<f64>,
    }
1478
1479 impl<F: Float + FromPrimitive + Debug> AdaptiveOnlineKMeans<F> {
1480 pub fn new(
1482 initial_centers: Array2<F>,
1483 learning_rate_schedule: LearningRateSchedule,
1484 ) -> Self {
1485 let n_clusters = initial_centers.nrows();
1486 let adaptive_params = AdaptiveParams {
1487 momentum: F::from(0.9).expect("Failed to convert constant to float"),
1488 stability_scores: vec![F::zero(); n_clusters],
1489 center_movements: VecDeque::with_capacity(100),
1490 auto_k_adjustment: false,
1491 split_threshold: F::from(2.0).expect("Failed to convert constant to float"),
1492 merge_threshold: F::from(0.5).expect("Failed to convert constant to float"),
1493 };
1494
1495 Self {
1496 centers: initial_centers,
1497 learning_rate_schedule,
1498 iteration: 0,
1499 adaptive_params,
1500 metrics: OnlineMetrics::default(),
1501 }
1502 }
1503
1504 pub fn update(&mut self, sample: ArrayView1<F>) -> Result<usize> {
1506 let start_time = std::time::Instant::now();
1507
1508 let (nearest_cluster, min_distance) = self.find_nearest_cluster(sample)?;
1510
1511 let lr = self.get_current_learning_rate();
1513
1514 let old_center = self.centers.row(nearest_cluster).to_owned();
1516 self.update_center(nearest_cluster, sample, lr)?;
1517
1518 let movement = euclidean_distance(old_center.view(), self.centers.row(nearest_cluster));
1520 self.adaptive_params.center_movements.push_back(movement);
1521 if self.adaptive_params.center_movements.len() > 100 {
1522 self.adaptive_params.center_movements.pop_front();
1523 }
1524
1525 self.update_metrics(
1527 nearest_cluster,
1528 min_distance,
1529 start_time.elapsed().as_secs_f64(),
1530 );
1531
1532 if self.adaptive_params.auto_k_adjustment {
1534 self.maybe_adjust_clusters(sample, min_distance)?;
1535 }
1536
1537 self.iteration += 1;
1538 Ok(nearest_cluster)
1539 }
1540
1541 fn maybe_adjust_clusters(&mut self, sample: ArrayView1<F>, min_distance: F) -> Result<()> {
1543 if min_distance > self.adaptive_params.split_threshold {
1545 let n_clusters = self.centers.nrows();
1547 let n_features = self.centers.ncols();
1548
1549 let mut new_centers = Array2::<F>::zeros((n_clusters + 1, n_features));
1550
1551 for i in 0..n_clusters {
1553 for j in 0..n_features {
1554 new_centers[[i, j]] = self.centers[[i, j]];
1555 }
1556 }
1557
1558 for (j, &val) in sample.iter().enumerate() {
1560 if j < n_features {
1561 new_centers[[n_clusters, j]] = val;
1562 }
1563 }
1564
1565 self.centers = new_centers;
1566 self.adaptive_params.stability_scores.push(F::zero());
1567 }
1568
1569 let n_clusters = self.centers.nrows();
1571 if n_clusters > 1 {
1572 let mut clusters_to_merge: Option<(usize, usize)> = None;
1573 let mut min_inter_distance = F::infinity();
1574
1575 for i in 0..n_clusters {
1577 for j in (i + 1)..n_clusters {
1578 let dist = euclidean_distance(self.centers.row(i), self.centers.row(j));
1579
1580 if dist < min_inter_distance {
1581 min_inter_distance = dist;
1582 clusters_to_merge = Some((i, j));
1583 }
1584 }
1585 }
1586
1587 if let Some((i, j)) = clusters_to_merge {
1589 if min_inter_distance < self.adaptive_params.merge_threshold {
1590 let n_features = self.centers.ncols();
1591 let mut new_centers = Array2::<F>::zeros((n_clusters - 1, n_features));
1592
1593 let mut merged_center = Array1::<F>::zeros(n_features);
1595 for k in 0..n_features {
1596 merged_center[k] = (self.centers[[i, k]] + self.centers[[j, k]])
1597 / (F::one() + F::one());
1598 }
1599
1600 let mut new_idx = 0;
1602 for old_idx in 0..n_clusters {
1603 if old_idx == i {
1604 for k in 0..n_features {
1606 new_centers[[new_idx, k]] = merged_center[k];
1607 }
1608 new_idx += 1;
1609 } else if old_idx != j {
1610 for k in 0..n_features {
1612 new_centers[[new_idx, k]] = self.centers[[old_idx, k]];
1613 }
1614 new_idx += 1;
1615 }
1616 }
1617
1618 self.centers = new_centers;
1619
1620 if j < self.adaptive_params.stability_scores.len() {
1622 self.adaptive_params.stability_scores.remove(j);
1623 }
1624 }
1625 }
1626 }
1627
1628 Ok(())
1629 }
1630
1631 fn find_nearest_cluster(&self, sample: ArrayView1<F>) -> Result<(usize, F)> {
1633 let mut min_distance = F::infinity();
1634 let mut nearest_cluster = 0;
1635
1636 for (i, center) in self.centers.rows().into_iter().enumerate() {
1637 let distance = euclidean_distance(sample, center);
1638 if distance < min_distance {
1639 min_distance = distance;
1640 nearest_cluster = i;
1641 }
1642 }
1643
1644 Ok((nearest_cluster, min_distance))
1645 }
1646
1647 fn update_center(
1649 &mut self,
1650 cluster_idx: usize,
1651 sample: ArrayView1<F>,
1652 lr: f64,
1653 ) -> Result<()> {
1654 let learning_rate = F::from(lr).expect("Failed to convert to float");
1655 let momentum = self.adaptive_params.momentum;
1656
1657 let mut center = self.centers.row_mut(cluster_idx);
1658 for (i, &sample_val) in sample.iter().enumerate() {
1659 if i < center.len() {
1660 let old_val = center[i];
1661 let gradient = sample_val - old_val;
1662 let update = learning_rate * gradient;
1663 center[i] = momentum * old_val + (F::one() - momentum) * (old_val + update);
1664 }
1665 }
1666
1667 Ok(())
1668 }
1669
1670 fn get_current_learning_rate(&self) -> f64 {
1672 match &self.learning_rate_schedule {
1673 LearningRateSchedule::Constant(lr) => *lr,
1674 LearningRateSchedule::Decay { initial_lr, decay } => {
1675 initial_lr / (1.0 + decay * self.iteration as f64)
1676 }
1677 LearningRateSchedule::StepDecay {
1678 initial_lr,
1679 factor,
1680 step_size,
1681 } => {
1682 let steps = self.iteration / step_size;
1683 initial_lr * factor.powi(steps as i32)
1684 }
1685 LearningRateSchedule::Adaptive {
1686 min_lr,
1687 max_lr,
1688 stability_window,
1689 } => {
1690 let recent_movements: Vec<F> = self
1691 .adaptive_params
1692 .center_movements
1693 .iter()
1694 .rev()
1695 .take(*stability_window)
1696 .cloned()
1697 .collect();
1698
1699 if recent_movements.is_empty() {
1700 return *max_lr;
1701 }
1702
1703 let avg_movement = recent_movements.iter().fold(F::zero(), |acc, x| acc + *x)
1704 / F::from(recent_movements.len()).expect("Operation failed");
1705 let stability = F::one() / (F::one() + avg_movement);
1706
1707 let adaptive_lr = min_lr
1709 + (max_lr - min_lr)
1710 * (F::one() - stability).to_f64().expect("Operation failed");
1711 adaptive_lr.clamp(*min_lr, *max_lr)
1712 }
1713 }
1714 }
1715
1716 fn update_metrics(&mut self, cluster_idx: usize, distance: F, processing_time: f64) {
1718 self.metrics.samples_processed += 1;
1719
1720 let distance_sq = distance.to_f64().expect("Operation failed").powi(2);
1722 let n = self.metrics.samples_processed as f64;
1723 self.metrics.wcss = ((n - 1.0) * self.metrics.wcss + distance_sq) / n;
1724
1725 if cluster_idx >= self.metrics.cluster_distribution.len() {
1727 self.metrics.cluster_distribution.resize(cluster_idx + 1, 0);
1728 }
1729 self.metrics.cluster_distribution[cluster_idx] += 1;
1730
1731 self.metrics
1733 .batch_processing_times
1734 .push_back(processing_time);
1735 if self.metrics.batch_processing_times.len() > 1000 {
1736 self.metrics.batch_processing_times.pop_front();
1737 }
1738
1739 let total_updates = self.metrics.cluster_distribution.iter().sum::<usize>() as f64;
1741 self.metrics.update_frequency = total_updates / self.iteration.max(1) as f64;
1742 }
1743
        /// Returns a reference to the current cluster centers, one row per cluster.
        pub fn get_centers(&self) -> &Array2<F> {
            &self.centers
        }
1748
        /// Returns a reference to the running metrics collected during updates.
        pub fn get_metrics(&self) -> &OnlineMetrics {
            &self.metrics
        }
1753
1754 pub fn predict(&self, samples: ArrayView2<F>) -> Result<Array1<usize>> {
1756 let n_samples = samples.nrows();
1757 let mut predictions = Array1::zeros(n_samples);
1758
1759 for (i, sample) in samples.rows().into_iter().enumerate() {
1760 let (cluster_id, _distance) = self.find_nearest_cluster(sample)?;
1761 predictions[i] = cluster_id;
1762 }
1763
1764 Ok(predictions)
1765 }
1766 }
1767}