Skip to main content

scirs2_cluster/
streaming.rs

1//! Streaming and memory-efficient clustering algorithms
2//!
3//! This module provides implementations of clustering algorithms that can handle
4//! large datasets that don't fit entirely in memory, using streaming and
5//! progressive processing techniques.
6
7use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
8use scirs2_core::numeric::{Float, FromPrimitive};
9use std::collections::{HashMap, VecDeque};
10use std::fmt::Debug;
11use std::fs::{File, OpenOptions};
12use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
13use std::path::{Path, PathBuf};
14
15use crate::error::{ClusteringError, Result};
16use crate::vq::euclidean_distance;
17
/// Settings shared by the streaming clustering algorithms in this module.
///
/// Start from [`StreamingConfig::default`] and override individual fields as needed.
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Upper bound on how many samples are held in memory simultaneously
    pub max_memory_samples: usize,
    /// Number of rows handled per processing chunk
    pub batch_size: usize,
    /// How many cluster centers the algorithm keeps track of
    pub n_centers: usize,
    /// Convergence threshold used by iterative algorithms
    pub tolerance: f64,
    /// Iteration cap for iterative algorithms
    pub max_iterations: usize,
}

impl Default for StreamingConfig {
    /// 10 000 in-memory samples, batches of 1 000, 10 centers, tolerance `1e-4`,
    /// and at most 100 iterations.
    fn default() -> Self {
        let max_memory_samples = 10_000;
        let batch_size = 1_000;
        StreamingConfig {
            max_memory_samples,
            batch_size,
            n_centers: 10,
            tolerance: 1e-4,
            max_iterations: 100,
        }
    }
}
44
45/// Streaming K-means clustering for large datasets
46///
47/// This implementation processes data in chunks and maintains a fixed number
48/// of cluster centers, updating them incrementally as new data arrives.
49pub struct StreamingKMeans<F: Float> {
50    config: StreamingConfig,
51    centers: Option<Array2<F>>,
52    weights: Option<Array1<F>>,
53    n_samples_processed: usize,
54    initialized: bool,
55}
56
57impl<F: Float + FromPrimitive + Debug> StreamingKMeans<F> {
58    /// Create a new streaming K-means instance
59    pub fn new(config: StreamingConfig) -> Self {
60        Self {
61            config,
62            centers: None,
63            weights: None,
64            n_samples_processed: 0,
65            initialized: false,
66        }
67    }
68
69    /// Initialize the clustering with the first batch of data
70    pub fn initialize(&mut self, data: ArrayView2<F>) -> Result<()> {
71        let n_samples = data.shape()[0];
72        let n_features = data.shape()[1];
73
74        if n_samples == 0 {
75            return Err(ClusteringError::InvalidInput(
76                "Cannot initialize with empty data".into(),
77            ));
78        }
79
80        let k = self.config.n_centers.min(n_samples);
81
82        // Initialize centers using K-means++ method
83        let mut centers = Array2::zeros((k, n_features));
84        let weights = Array1::ones(k);
85
86        // Choose first center randomly
87        let first_center_idx = 0; // For deterministic behavior, choose first point
88        centers.row_mut(0).assign(&data.row(first_center_idx));
89
90        // Choose remaining centers using K-means++ initialization
91        for i in 1..k {
92            let mut distances = Array1::zeros(n_samples);
93            let mut total_distance = F::zero();
94
95            // Calculate distances to nearest existing center
96            for j in 0..n_samples {
97                let mut min_dist = F::infinity();
98                for center_idx in 0..i {
99                    let dist = euclidean_distance(data.row(j), centers.row(center_idx));
100                    if dist < min_dist {
101                        min_dist = dist;
102                    }
103                }
104                distances[j] = min_dist * min_dist; // Squared distance for K-means++
105                total_distance = total_distance + distances[j];
106            }
107
108            // Choose next center with probability proportional to squared distance
109            let mut cumsum = F::zero();
110            let target =
111                total_distance * F::from(0.5).expect("Failed to convert constant to float"); // Simplified selection
112
113            for j in 0..n_samples {
114                cumsum = cumsum + distances[j];
115                if cumsum >= target {
116                    centers.row_mut(i).assign(&data.row(j));
117                    break;
118                }
119            }
120        }
121
122        self.centers = Some(centers);
123        self.weights = Some(weights);
124        self.n_samples_processed = n_samples;
125        self.initialized = true;
126
127        Ok(())
128    }
129
130    /// Process a new batch of data and update cluster centers
131    pub fn partial_fit(&mut self, data: ArrayView2<F>) -> Result<()> {
132        if !self.initialized {
133            return self.initialize(data);
134        }
135
136        let n_samples = data.shape()[0];
137        if n_samples == 0 {
138            return Ok(());
139        }
140
141        let centers = self.centers.as_mut().expect("Operation failed");
142        let weights = self.weights.as_mut().expect("Operation failed");
143
144        // Assign each point to the nearest center and update centers
145        for i in 0..n_samples {
146            let point = data.row(i);
147
148            // Find nearest center
149            let mut min_dist = F::infinity();
150            let mut nearest_center = 0;
151
152            for j in 0..centers.shape()[0] {
153                let dist = euclidean_distance(point, centers.row(j));
154                if dist < min_dist {
155                    min_dist = dist;
156                    nearest_center = j;
157                }
158            }
159
160            // Update center using online mean update
161            let weight = weights[nearest_center];
162            let new_weight = weight + F::one();
163            let learning_rate = F::one() / new_weight;
164
165            // Update center: new_center = old_center + lr * (point - old_center)
166            let mut center_row = centers.row_mut(nearest_center);
167            for k in 0..center_row.len() {
168                let diff = point[k] - center_row[k];
169                center_row[k] = center_row[k] + learning_rate * diff;
170            }
171
172            weights[nearest_center] = new_weight;
173        }
174
175        self.n_samples_processed += n_samples;
176        Ok(())
177    }
178
179    /// Get the current cluster centers
180    pub fn cluster_centers(&self) -> Option<&Array2<F>> {
181        self.centers.as_ref()
182    }
183
184    /// Predict cluster assignments for new data
185    pub fn predict(&self, data: ArrayView2<F>) -> Result<Array1<usize>> {
186        if !self.initialized {
187            return Err(ClusteringError::InvalidInput(
188                "Model must be initialized before prediction".into(),
189            ));
190        }
191
192        let centers = self.centers.as_ref().expect("Operation failed");
193        let n_samples = data.shape()[0];
194        let mut labels = Array1::zeros(n_samples);
195
196        for i in 0..n_samples {
197            let point = data.row(i);
198            let mut min_dist = F::infinity();
199            let mut nearest_center = 0;
200
201            for j in 0..centers.shape()[0] {
202                let dist = euclidean_distance(point, centers.row(j));
203                if dist < min_dist {
204                    min_dist = dist;
205                    nearest_center = j;
206                }
207            }
208
209            labels[i] = nearest_center;
210        }
211
212        Ok(labels)
213    }
214
215    /// Get the number of samples processed so far
216    pub fn n_samples_seen(&self) -> usize {
217        self.n_samples_processed
218    }
219}
220
221/// Progressive hierarchical clustering for large datasets
222///
223/// This implementation builds a hierarchy incrementally by processing data
224/// in chunks and maintaining a compressed representation of the clustering.
225pub struct ProgressiveHierarchical<F: Float> {
226    #[allow(dead_code)]
227    config: StreamingConfig,
228    representative_points: VecDeque<Array1<F>>,
229    cluster_sizes: VecDeque<usize>,
230    max_representatives: usize,
231}
232
233impl<F: Float + FromPrimitive + Debug> ProgressiveHierarchical<F> {
234    /// Create a new progressive hierarchical clustering instance
235    pub fn new(config: StreamingConfig) -> Self {
236        let max_representatives = config.max_memory_samples / 10; // Keep 10% as representatives
237
238        Self {
239            config,
240            representative_points: VecDeque::new(),
241            cluster_sizes: VecDeque::new(),
242            max_representatives,
243        }
244    }
245
246    /// Process a new batch of data
247    pub fn partial_fit(&mut self, data: ArrayView2<F>) -> Result<()> {
248        let n_samples = data.shape()[0];
249        if n_samples == 0 {
250            return Ok(());
251        }
252
253        // If this is the first batch, just add some representative points
254        if self.representative_points.is_empty() {
255            let step_size = (n_samples / self.max_representatives.min(n_samples)).max(1);
256
257            for i in (0..n_samples).step_by(step_size) {
258                self.representative_points.push_back(data.row(i).to_owned());
259                self.cluster_sizes.push_back(1);
260
261                if self.representative_points.len() >= self.max_representatives {
262                    break;
263                }
264            }
265            return Ok(());
266        }
267
268        // For subsequent batches, merge new points with existing representatives
269        let mut new_representatives = Vec::new();
270        let mut new_sizes = Vec::new();
271
272        // Process new data points
273        for i in 0..n_samples {
274            let point = data.row(i);
275
276            // Find closest representative
277            let mut min_dist = F::infinity();
278            let mut closest_idx = 0;
279
280            for (j, repr) in self.representative_points.iter().enumerate() {
281                let dist = euclidean_distance(point, repr.view());
282                if dist < min_dist {
283                    min_dist = dist;
284                    closest_idx = j;
285                }
286            }
287
288            // Merge with closest representative or create new one
289            let threshold = F::from(0.1).expect("Failed to convert constant to float"); // Distance threshold for merging
290
291            if min_dist < threshold && closest_idx < self.representative_points.len() {
292                // Merge with existing representative
293                let old_size = self.cluster_sizes[closest_idx];
294                let new_size = old_size + 1;
295                let weight = F::from(old_size).expect("Failed to convert to float")
296                    / F::from(new_size).expect("Failed to convert to float");
297
298                // Update representative as weighted average
299                let mut repr = self.representative_points[closest_idx].clone();
300                for k in 0..repr.len() {
301                    repr[k] = weight * repr[k] + (F::one() - weight) * point[k];
302                }
303
304                new_representatives.push(repr);
305                new_sizes.push(new_size);
306            } else {
307                // Create new representative
308                new_representatives.push(point.to_owned());
309                new_sizes.push(1);
310            }
311        }
312
313        // Replace old representatives with updated ones
314        self.representative_points.clear();
315        self.cluster_sizes.clear();
316
317        for (repr, size) in new_representatives.into_iter().zip(new_sizes.into_iter()) {
318            self.representative_points.push_back(repr);
319            self.cluster_sizes.push_back(size);
320        }
321
322        // If we have too many representatives, compress by merging similar ones
323        if self.representative_points.len() > self.max_representatives {
324            self.compress_representatives()?;
325        }
326
327        Ok(())
328    }
329
330    /// Compress the representation by merging similar representative points
331    fn compress_representatives(&mut self) -> Result<()> {
332        let _n_repr = self.representative_points.len();
333        let target_size = self.max_representatives * 3 / 4; // Reduce to 75% of max
334
335        while self.representative_points.len() > target_size {
336            // Find the two closest representatives to merge
337            let mut min_dist = F::infinity();
338            let mut merge_i = 0;
339            let mut merge_j = 1;
340
341            for i in 0..self.representative_points.len() {
342                for j in (i + 1)..self.representative_points.len() {
343                    let dist = euclidean_distance(
344                        self.representative_points[i].view(),
345                        self.representative_points[j].view(),
346                    );
347                    if dist < min_dist {
348                        min_dist = dist;
349                        merge_i = i;
350                        merge_j = j;
351                    }
352                }
353            }
354
355            // Merge the two closest representatives
356            let size_i = self.cluster_sizes[merge_i];
357            let size_j = self.cluster_sizes[merge_j];
358            let total_size = size_i + size_j;
359
360            let weight_i = F::from(size_i).expect("Failed to convert to float")
361                / F::from(total_size).expect("Failed to convert to float");
362            let weight_j = F::from(size_j).expect("Failed to convert to float")
363                / F::from(total_size).expect("Failed to convert to float");
364
365            // Create merged representative
366            let repr_i = &self.representative_points[merge_i];
367            let repr_j = &self.representative_points[merge_j];
368            let mut merged_repr = Array1::zeros(repr_i.len());
369
370            for k in 0..merged_repr.len() {
371                merged_repr[k] = weight_i * repr_i[k] + weight_j * repr_j[k];
372            }
373
374            // Remove the two old representatives (remove larger index first)
375            if merge_j > merge_i {
376                self.representative_points.remove(merge_j);
377                self.cluster_sizes.remove(merge_j);
378                self.representative_points.remove(merge_i);
379                self.cluster_sizes.remove(merge_i);
380            } else {
381                self.representative_points.remove(merge_i);
382                self.cluster_sizes.remove(merge_i);
383                self.representative_points.remove(merge_j);
384                self.cluster_sizes.remove(merge_j);
385            }
386
387            // Add the merged representative
388            self.representative_points.push_back(merged_repr);
389            self.cluster_sizes.push_back(total_size);
390        }
391
392        Ok(())
393    }
394
395    /// Get the current representative points
396    pub fn get_representatives(&self) -> (Vec<Array1<F>>, Vec<usize>) {
397        (
398            self.representative_points.iter().cloned().collect(),
399            self.cluster_sizes.iter().cloned().collect(),
400        )
401    }
402
403    /// Get the number of representative points
404    pub fn n_representatives(&self) -> usize {
405        self.representative_points.len()
406    }
407}
408
409/// Memory-efficient distance matrix computation
410///
411/// Computes distances between points in chunks to avoid storing the full
412/// distance matrix in memory.
413pub struct ChunkedDistanceMatrix<F: Float> {
414    chunk_size: usize,
415    n_samples: usize,
416    _phantom: std::marker::PhantomData<F>,
417}
418
419impl<F: Float + FromPrimitive> ChunkedDistanceMatrix<F> {
420    /// Create a new chunked distance matrix
421    pub fn new(n_samples: usize, max_memory_mb: usize) -> Self {
422        // Estimate chunk size based on memory limit
423        let memory_per_float = std::mem::size_of::<F>();
424        let max_elements = (max_memory_mb * 1024 * 1024) / memory_per_float;
425        let chunk_size = (max_elements / n_samples).max(1).min(n_samples);
426
427        Self {
428            chunk_size,
429            n_samples,
430            _phantom: std::marker::PhantomData,
431        }
432    }
433
434    /// Process distances in chunks and apply a function to each chunk
435    pub fn process_chunks<Func>(&self, data: ArrayView2<F>, mut processor: Func) -> Result<()>
436    where
437        Func: FnMut(usize, usize, F) -> Result<()>,
438    {
439        for i in (0..self.n_samples).step_by(self.chunk_size) {
440            let end_i = (i + self.chunk_size).min(self.n_samples);
441
442            for j in (i..self.n_samples).step_by(self.chunk_size) {
443                let end_j = (j + self.chunk_size).min(self.n_samples);
444
445                // Process this chunk of distances
446                for row in i..end_i {
447                    for col in j.max(row + 1)..end_j {
448                        let dist = euclidean_distance(data.row(row), data.row(col));
449                        processor(row, col, dist)?;
450                    }
451                }
452            }
453        }
454        Ok(())
455    }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_streaming_kmeans() {
        let mut model = StreamingKMeans::new(StreamingConfig {
            max_memory_samples: 100,
            batch_size: 10,
            n_centers: 2,
            tolerance: 1e-4,
            max_iterations: 10,
        });

        // Two tight groups: one around the origin, one around (1, 1).
        let first = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 0.1, 0.1, 1.0, 1.0, 1.1, 1.1])
            .expect("Operation failed");
        model.partial_fit(first.view()).expect("Operation failed");
        assert!(model.cluster_centers().is_some());

        // More points from the same two groups refine the existing centers.
        let second = Array2::from_shape_vec((4, 2), vec![0.2, 0.2, 0.0, 0.1, 1.2, 1.0, 1.0, 1.2])
            .expect("Operation failed");
        model.partial_fit(second.view()).expect("Operation failed");

        let queries =
            Array2::from_shape_vec((2, 2), vec![0.05, 0.05, 1.05, 1.05]).expect("Operation failed");
        let labels = model.predict(queries.view()).expect("Operation failed");

        assert_eq!(labels.len(), 2);
        // One query sits in each group, so they must land in different clusters.
        assert_ne!(labels[0], labels[1]);
    }

    #[test]
    fn test_progressive_hierarchical() {
        let mut model = ProgressiveHierarchical::new(StreamingConfig::default());

        let batch = Array2::from_shape_vec(
            (6, 2),
            vec![0.0, 0.0, 0.1, 0.1, 0.2, 0.2, 5.0, 5.0, 5.1, 5.1, 5.2, 5.2],
        )
        .expect("Test: operation failed");
        model.partial_fit(batch.view()).expect("Operation failed");

        let (points, sizes) = model.get_representatives();
        assert!(!points.is_empty());
        assert_eq!(points.len(), sizes.len());
        assert!(model.n_representatives() > 0);
    }

    #[test]
    fn test_chunked_distance_matrix() {
        // Corners of the unit square.
        let data = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("Operation failed");

        let matrix = ChunkedDistanceMatrix::new(4, 1); // 1 MB limit
        let mut visited = 0;

        matrix
            .process_chunks(data.view(), |row, col, dist| {
                assert!(row < col);
                assert!(dist >= 0.0);
                visited += 1;
                Ok(())
            })
            .expect("Test: operation failed");

        // 4 points have C(4, 2) = 6 unordered pairs.
        assert_eq!(visited, 6);
    }

    #[test]
    fn test_streaming_config_default() {
        let defaults = StreamingConfig::default();
        assert_eq!(defaults.max_memory_samples, 10000);
        assert_eq!(defaults.batch_size, 1000);
        assert_eq!(defaults.n_centers, 10);
        assert_eq!(defaults.tolerance, 1e-4);
        assert_eq!(defaults.max_iterations, 100);
    }
}
558
559/// Enhanced memory management for out-of-core clustering
560pub mod memory_management {
561    use super::*;
562
563    /// Adaptive memory manager that monitors system resources
564    #[derive(Debug, Clone)]
565    pub struct AdaptiveMemoryManager {
566        /// Current memory usage estimate (bytes)
567        current_usage: usize,
568        /// Maximum allowed memory usage (bytes)
569        max_memory: usize,
570        /// Memory pressure threshold (0.0 to 1.0)
571        pressure_threshold: f64,
572        /// Enable disk-based storage when memory is full
573        enable_disk_storage: bool,
574        /// Temporary directory for disk storage
575        temp_dir: Option<PathBuf>,
576    }
577
578    impl AdaptiveMemoryManager {
579        /// Create a new adaptive memory manager
580        pub fn new(max_memory_mb: usize) -> Self {
581            Self {
582                current_usage: 0,
583                max_memory: max_memory_mb * 1024 * 1024,
584                pressure_threshold: 0.8,
585                enable_disk_storage: true,
586                temp_dir: std::env::temp_dir().into(),
587            }
588        }
589
590        /// Check if memory pressure is high
591        pub fn is_memory_pressure_high(&self) -> bool {
592            self.current_usage as f64 / self.max_memory as f64 > self.pressure_threshold
593        }
594
595        /// Estimate memory usage for storing data
596        pub fn estimate_memory_usage<F: Float>(
597            &self,
598            n_samples: usize,
599            n_features: usize,
600        ) -> usize {
601            std::mem::size_of::<F>() * n_samples * n_features
602        }
603
604        /// Allocate memory for data
605        pub fn allocate<F: Float>(&mut self, n_samples: usize, n_features: usize) -> Result<()> {
606            let required = self.estimate_memory_usage::<F>(n_samples, n_features);
607
608            if self.current_usage + required > self.max_memory {
609                if self.enable_disk_storage {
610                    // Allow allocation but mark that we need disk storage
611                    Ok(())
612                } else {
613                    Err(ClusteringError::InvalidInput(
614                        "Not enough memory and disk storage is disabled".to_string(),
615                    ))
616                }
617            } else {
618                self.current_usage += required;
619                Ok(())
620            }
621        }
622
623        /// Deallocate memory
624        pub fn deallocate(&mut self, amount: usize) {
625            self.current_usage = self.current_usage.saturating_sub(amount);
626        }
627
628        /// Get available memory
629        pub fn available_memory(&self) -> usize {
630            self.max_memory.saturating_sub(self.current_usage)
631        }
632
633        /// Get optimal batch size based on available memory
634        pub fn optimal_batch_size<F: Float>(&self, n_features: usize) -> usize {
635            let available = self.available_memory();
636            let bytes_per_sample = std::mem::size_of::<F>() * n_features;
637
638            if bytes_per_sample == 0 {
639                1000 // Default fallback
640            } else {
641                (available / bytes_per_sample).max(1).min(10000)
642            }
643        }
644    }
645
646    /// Disk-based storage for large intermediate results
647    #[derive(Debug)]
648    pub struct DiskBackedStorage<F: Float + FromPrimitive> {
649        temp_files: Vec<PathBuf>,
650        temp_dir: PathBuf,
651        buffer_size: usize,
652        _phantom: std::marker::PhantomData<F>,
653    }
654
655    impl<F: Float + FromPrimitive> DiskBackedStorage<F> {
656        /// Create a new disk-backed storage
657        pub fn new(temp_dir: Option<PathBuf>, buffer_size: usize) -> Self {
658            let temp_dir = temp_dir.unwrap_or_else(std::env::temp_dir);
659
660            Self {
661                temp_files: Vec::new(),
662                temp_dir,
663                buffer_size,
664                _phantom: std::marker::PhantomData,
665            }
666        }
667
668        /// Write data chunk to disk
669        pub fn write_chunk(&mut self, data: ArrayView2<F>) -> Result<usize> {
670            let chunk_id = self.temp_files.len();
671            let file_path = self
672                .temp_dir
673                .join(format!("cluster_chunk_{}.bin", chunk_id));
674
675            let file = File::create(&file_path).map_err(|e| {
676                ClusteringError::InvalidInput(format!("Failed to create temp file: {}", e))
677            })?;
678            let mut writer = BufWriter::new(file);
679
680            // Write dimensions
681            let n_rows = data.shape()[0] as u64;
682            let n_cols = data.shape()[1] as u64;
683            writer.write_all(&n_rows.to_le_bytes()).map_err(|e| {
684                ClusteringError::InvalidInput(format!("Failed to write dimensions: {}", e))
685            })?;
686            writer.write_all(&n_cols.to_le_bytes()).map_err(|e| {
687                ClusteringError::InvalidInput(format!("Failed to write dimensions: {}", e))
688            })?;
689
690            // Write data (simplified - in practice, you'd want more sophisticated serialization)
691            for row in data.rows() {
692                for &value in row.iter() {
693                    let bytes = value.to_f64().unwrap_or(0.0).to_le_bytes();
694                    writer.write_all(&bytes).map_err(|e| {
695                        ClusteringError::InvalidInput(format!("Failed to write data: {}", e))
696                    })?;
697                }
698            }
699
700            writer.flush().map_err(|e| {
701                ClusteringError::InvalidInput(format!("Failed to flush data: {}", e))
702            })?;
703
704            self.temp_files.push(file_path);
705            Ok(chunk_id)
706        }
707
708        /// Read data chunk from disk
709        pub fn read_chunk(&self, chunk_id: usize) -> Result<Array2<F>> {
710            if chunk_id >= self.temp_files.len() {
711                return Err(ClusteringError::InvalidInput(
712                    "Invalid chunk ID".to_string(),
713                ));
714            }
715
716            let file = File::open(&self.temp_files[chunk_id]).map_err(|e| {
717                ClusteringError::InvalidInput(format!("Failed to open temp file: {}", e))
718            })?;
719            let mut reader = BufReader::new(file);
720
721            // Read dimensions
722            let mut dim_bytes = [0u8; 8];
723            reader.read_exact(&mut dim_bytes).map_err(|e| {
724                ClusteringError::InvalidInput(format!("Failed to read dimensions: {}", e))
725            })?;
726            let n_rows = u64::from_le_bytes(dim_bytes) as usize;
727
728            reader.read_exact(&mut dim_bytes).map_err(|e| {
729                ClusteringError::InvalidInput(format!("Failed to read dimensions: {}", e))
730            })?;
731            let n_cols = u64::from_le_bytes(dim_bytes) as usize;
732
733            // Read data
734            let mut data = Array2::zeros((n_rows, n_cols));
735            for mut row in data.rows_mut() {
736                for element in row.iter_mut() {
737                    let mut value_bytes = [0u8; 8];
738                    reader.read_exact(&mut value_bytes).map_err(|e| {
739                        ClusteringError::InvalidInput(format!("Failed to read data: {}", e))
740                    })?;
741                    let value = f64::from_le_bytes(value_bytes);
742                    *element = F::from(value).unwrap_or(F::zero());
743                }
744            }
745
746            Ok(data)
747        }
748
749        /// Clean up temporary files
750        pub fn cleanup(&mut self) -> Result<()> {
751            for file_path in &self.temp_files {
752                if file_path.exists() {
753                    std::fs::remove_file(file_path).map_err(|e| {
754                        ClusteringError::InvalidInput(format!("Failed to remove temp file: {}", e))
755                    })?;
756                }
757            }
758            self.temp_files.clear();
759            Ok(())
760        }
761
762        /// Get number of chunks stored
763        pub fn num_chunks(&self) -> usize {
764            self.temp_files.len()
765        }
766    }
767
768    impl<F: Float + FromPrimitive> Drop for DiskBackedStorage<F> {
769        fn drop(&mut self) {
770            let _ = self.cleanup(); // Best effort cleanup
771        }
772    }
773}
774
775/// Advanced streaming algorithms for out-of-core processing
776pub mod advanced_streaming {
777    use super::*;
778
779    /// Count-Min Sketch for approximate frequency counting
780    /// Useful for identifying heavy hitters in streaming data
781    #[derive(Debug, Clone)]
782    pub struct CountMinSketch {
783        /// Hash tables (width x depth)
784        tables: Vec<Vec<u64>>,
785        /// Width of each hash table
786        width: usize,
787        /// Depth (number of hash tables)
788        depth: usize,
789        /// Hash functions (simple linear congruential generators)
790        hash_params: Vec<(u64, u64)>,
791    }
792
793    impl CountMinSketch {
794        /// Create a new Count-Min Sketch
795        pub fn new(epsilon: f64, delta: f64) -> Self {
796            let width = (std::f64::consts::E / epsilon).ceil() as usize;
797            let depth = (1.0 / delta).ln().ceil() as usize;
798
799            let mut tables = Vec::new();
800            let mut hash_params = Vec::new();
801
802            for i in 0..depth {
803                tables.push(vec![0u64; width]);
804                // Simple hash parameters (in practice, use better hash functions)
805                hash_params.push((
806                    1000000007 + i as u64 * 1000000009,
807                    1000000021 + i as u64 * 1000000033,
808                ));
809            }
810
811            Self {
812                tables,
813                width,
814                depth,
815                hash_params,
816            }
817        }
818
819        /// Add an item to the sketch
820        pub fn add(&mut self, item: u64) {
821            for i in 0..self.depth {
822                let hash = self.hash(item, i);
823                let idx = (hash as usize) % self.width;
824                self.tables[i][idx] += 1;
825            }
826        }
827
828        /// Estimate the frequency of an item
829        pub fn estimate(&self, item: u64) -> u64 {
830            let mut min_count = u64::MAX;
831
832            for i in 0..self.depth {
833                let hash = self.hash(item, i);
834                let idx = (hash as usize) % self.width;
835                min_count = min_count.min(self.tables[i][idx]);
836            }
837
838            min_count
839        }
840
841        /// Simple hash function
842        fn hash(&self, item: u64, table_idx: usize) -> u64 {
843            let (a, b) = self.hash_params[table_idx];
844            a.wrapping_mul(item).wrapping_add(b)
845        }
846
847        /// Get heavy hitters (items with frequency above threshold)
848        pub fn heavy_hitters(&self, threshold: u64) -> Vec<u64> {
849            // This is a simplified implementation
850            // In practice, you'd need to track candidates more carefully
851            Vec::new()
852        }
853    }
854
855    /// Reservoir sampling for maintaining a random sample from a stream
856    #[derive(Debug, Clone)]
857    pub struct ReservoirSampler<T> {
858        reservoir: Vec<T>,
859        capacity: usize,
860        seen_count: usize,
861    }
862
863    impl<T: Clone> ReservoirSampler<T> {
864        /// Create a new reservoir sampler
865        pub fn new(capacity: usize) -> Self {
866            Self {
867                reservoir: Vec::with_capacity(capacity),
868                capacity,
869                seen_count: 0,
870            }
871        }
872
873        /// Add an item to the reservoir
874        pub fn add(&mut self, item: T) {
875            self.seen_count += 1;
876
877            if self.reservoir.len() < self.capacity {
878                self.reservoir.push(item);
879            } else {
880                // Replace random item with probability k/n
881                let random_idx = (self.seen_count - 1) % self.capacity; // Simplified random selection
882                if random_idx < self.capacity {
883                    self.reservoir[random_idx] = item;
884                }
885            }
886        }
887
888        /// Get the current sample
889        pub fn sample(&self) -> &[T] {
890            &self.reservoir
891        }
892
893        /// Get the number of items seen
894        pub fn items_seen(&self) -> usize {
895            self.seen_count
896        }
897    }
898
899    /// Progressive learning framework for online clustering
900    #[derive(Debug)]
901    pub struct ProgressiveLearner<F: Float> {
902        /// Current model state
903        model_state: HashMap<String, Vec<F>>,
904        /// Learning rate schedule
905        learning_rate: F,
906        /// Decay factor for learning rate
907        decay_factor: F,
908        /// Number of updates performed
909        update_count: usize,
910        /// Memory for recent gradients (for momentum)
911        gradient_memory: HashMap<String, Vec<F>>,
912        /// Momentum coefficient
913        momentum: F,
914    }
915
916    impl<F: Float + FromPrimitive + std::fmt::Debug> ProgressiveLearner<F> {
917        /// Create a new progressive learner
918        pub fn new(initial_lr: F, decay: F, momentum: F) -> Self {
919            Self {
920                model_state: HashMap::new(),
921                learning_rate: initial_lr,
922                decay_factor: decay,
923                update_count: 0,
924                gradient_memory: HashMap::new(),
925                momentum,
926            }
927        }
928
929        /// Update model parameters with gradient
930        pub fn update(&mut self, param_name: &str, gradient: &[F]) -> Result<()> {
931            self.update_count += 1;
932
933            // Update learning rate with decay
934            if self.update_count.is_multiple_of(100) {
935                self.learning_rate = self.learning_rate * self.decay_factor;
936            }
937
938            // Initialize parameter if not exists
939            if !self.model_state.contains_key(param_name) {
940                self.model_state
941                    .insert(param_name.to_string(), vec![F::zero(); gradient.len()]);
942                self.gradient_memory
943                    .insert(param_name.to_string(), vec![F::zero(); gradient.len()]);
944            }
945
946            let params = self
947                .model_state
948                .get_mut(param_name)
949                .expect("Operation failed");
950            let momentum_grad = self
951                .gradient_memory
952                .get_mut(param_name)
953                .expect("Operation failed");
954
955            // Update with momentum
956            for i in 0..params.len() {
957                momentum_grad[i] = self.momentum * momentum_grad[i] + gradient[i];
958                params[i] = params[i] - self.learning_rate * momentum_grad[i];
959            }
960
961            Ok(())
962        }
963
964        /// Get current parameter values
965        pub fn get_parameters(&self, param_name: &str) -> Option<&[F]> {
966            self.model_state.get(param_name).map(|v| v.as_slice())
967        }
968
969        /// Get current learning rate
970        pub fn current_learning_rate(&self) -> F {
971            self.learning_rate
972        }
973
974        /// Get update count
975        pub fn update_count(&self) -> usize {
976            self.update_count
977        }
978    }
979}
980
981/// Intelligent data loading and preprocessing for streaming
982pub mod intelligent_loading {
983    use super::*;
984
985    /// Adaptive data loader that adjusts batch sizes based on system performance
986    #[derive(Debug)]
987    pub struct AdaptiveDataLoader {
988        /// Current batch size
989        current_batch_size: usize,
990        /// Minimum batch size
991        min_batch_size: usize,
992        /// Maximum batch size
993        max_batch_size: usize,
994        /// Performance history (processing times)
995        performance_history: VecDeque<f64>,
996        /// Target processing time per batch (seconds)
997        target_time: f64,
998        /// Adjustment factor for batch size changes
999        adjustment_factor: f64,
1000    }
1001
1002    impl AdaptiveDataLoader {
1003        /// Create a new adaptive data loader
1004        pub fn new(initial_batch_size: usize, target_time_seconds: f64) -> Self {
1005            Self {
1006                current_batch_size: initial_batch_size,
1007                min_batch_size: initial_batch_size / 10,
1008                max_batch_size: initial_batch_size * 10,
1009                performance_history: VecDeque::with_capacity(10),
1010                target_time: target_time_seconds,
1011                adjustment_factor: 0.1,
1012            }
1013        }
1014
1015        /// Report batch processing time and adjust batch size
1016        pub fn report_batch_time(&mut self, processing_time: f64) {
1017            self.performance_history.push_back(processing_time);
1018            if self.performance_history.len() > 10 {
1019                self.performance_history.pop_front();
1020            }
1021
1022            // Calculate moving average
1023            let avg_time = self.performance_history.iter().sum::<f64>()
1024                / self.performance_history.len() as f64;
1025
1026            // Adjust batch size based on performance
1027            if avg_time > self.target_time * 1.2 {
1028                // Too slow, reduce batch size
1029                let new_size =
1030                    (self.current_batch_size as f64 * (1.0 - self.adjustment_factor)) as usize;
1031                self.current_batch_size = new_size.max(self.min_batch_size);
1032            } else if avg_time < self.target_time * 0.8 {
1033                // Too fast, increase batch size
1034                let new_size =
1035                    (self.current_batch_size as f64 * (1.0 + self.adjustment_factor)) as usize;
1036                self.current_batch_size = new_size.min(self.max_batch_size);
1037            }
1038        }
1039
1040        /// Get current optimal batch size
1041        pub fn current_batch_size(&self) -> usize {
1042            self.current_batch_size
1043        }
1044
1045        /// Get performance statistics
1046        pub fn get_stats(&self) -> (f64, f64, usize) {
1047            let avg_time = if self.performance_history.is_empty() {
1048                0.0
1049            } else {
1050                self.performance_history.iter().sum::<f64>() / self.performance_history.len() as f64
1051            };
1052
1053            let efficiency = if avg_time > 0.0 {
1054                self.target_time / avg_time
1055            } else {
1056                1.0
1057            };
1058
1059            (avg_time, efficiency, self.current_batch_size)
1060        }
1061    }
1062
1063    /// Smart preprocessing pipeline for streaming data
1064    #[derive(Debug, Clone)]
1065    pub struct StreamingPreprocessor<F: Float> {
1066        /// Running statistics for normalization
1067        running_mean: Option<Array1<F>>,
1068        running_var: Option<Array1<F>>,
1069        sample_count: usize,
1070        /// Enable online normalization
1071        normalize: bool,
1072        /// Outlier detection threshold (standard deviations)
1073        outlier_threshold: F,
1074        /// Missing value strategy
1075        missing_value_strategy: MissingValueStrategy,
1076    }
1077
1078    #[derive(Debug, Clone)]
1079    pub enum MissingValueStrategy {
1080        Drop,
1081        FillMean,
1082        FillZero,
1083        Interpolate,
1084    }
1085
1086    impl<F: Float + FromPrimitive + std::fmt::Debug> StreamingPreprocessor<F> {
1087        /// Create a new streaming preprocessor
1088        pub fn new(normalize: bool, outlier_threshold: F) -> Self {
1089            Self {
1090                running_mean: None,
1091                running_var: None,
1092                sample_count: 0,
1093                normalize,
1094                outlier_threshold,
1095                missing_value_strategy: MissingValueStrategy::FillMean,
1096            }
1097        }
1098
1099        /// Process a batch of data
1100        pub fn process_batch(&mut self, mut data: Array2<F>) -> Result<Array2<F>> {
1101            let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
1102
1103            if n_samples == 0 {
1104                return Ok(data);
1105            }
1106
1107            // Initialize statistics if first batch
1108            if self.running_mean.is_none() {
1109                self.running_mean = Some(Array1::zeros(n_features));
1110                self.running_var = Some(Array1::zeros(n_features));
1111            }
1112
1113            // Update running statistics
1114            if self.normalize {
1115                self.update_statistics(&data)?;
1116            }
1117
1118            // Handle missing values (Drop may reduce the number of rows).
1119            data = self.handle_missing_values(data)?;
1120
1121            // Apply normalization
1122            if self.normalize {
1123                self.apply_normalization(&mut data)?;
1124            }
1125
1126            // Detect and handle outliers
1127            self.handle_outliers(&mut data)?;
1128
1129            Ok(data)
1130        }
1131
1132        /// Update running mean and variance
1133        fn update_statistics(&mut self, data: &Array2<F>) -> Result<()> {
1134            let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
1135            let mean = self.running_mean.as_mut().expect("Operation failed");
1136            let var = self.running_var.as_mut().expect("Operation failed");
1137
1138            for i in 0..n_samples {
1139                self.sample_count += 1;
1140                let sample = data.row(i);
1141
1142                for j in 0..n_features {
1143                    if sample[j].is_finite() {
1144                        // Online update of mean and variance (Welford's algorithm)
1145                        let delta = sample[j] - mean[j];
1146                        mean[j] = mean[j]
1147                            + delta
1148                                / F::from(self.sample_count).expect("Failed to convert to float");
1149                        let delta2 = sample[j] - mean[j];
1150                        var[j] = var[j] + delta * delta2;
1151                    }
1152                }
1153            }
1154
1155            Ok(())
1156        }
1157
1158        /// Handle missing values.
1159        ///
1160        /// - `FillZero`: replace every non-finite element with zero in place.
1161        /// - `FillMean`: replace non-finite elements with the running mean.
1162        /// - `Drop`: discard any row that contains at least one non-finite element.
1163        /// - `Interpolate`: replace non-finite elements with the running mean for
1164        ///   the first/last positions, and with the linear interpolation between the
1165        ///   nearest finite neighbours for interior positions.  Falls back to the
1166        ///   running-mean value when no finite neighbour exists on one side.
1167        fn handle_missing_values(&self, mut data: Array2<F>) -> Result<Array2<F>> {
1168            let (n_samples, n_features) = (data.shape()[0], data.shape()[1]);
1169            match self.missing_value_strategy {
1170                MissingValueStrategy::FillZero => {
1171                    for elem in data.iter_mut() {
1172                        if !elem.is_finite() {
1173                            *elem = F::zero();
1174                        }
1175                    }
1176                }
1177                MissingValueStrategy::FillMean => {
1178                    if let Some(ref mean) = self.running_mean {
1179                        for mut row in data.rows_mut().into_iter() {
1180                            for (j, elem) in row.iter_mut().enumerate() {
1181                                if !elem.is_finite() && j < mean.len() {
1182                                    *elem = mean[j];
1183                                }
1184                            }
1185                        }
1186                    }
1187                }
1188                MissingValueStrategy::Drop => {
1189                    // Keep only rows that are entirely finite.
1190                    let valid_rows: Vec<usize> = (0..n_samples)
1191                        .filter(|&i| data.row(i).iter().all(|v| v.is_finite()))
1192                        .collect();
1193                    if valid_rows.len() == n_samples {
1194                        // Nothing to drop — return unchanged.
1195                        return Ok(data);
1196                    }
1197                    let mut kept = Array2::zeros((valid_rows.len(), n_features));
1198                    for (new_idx, &old_idx) in valid_rows.iter().enumerate() {
1199                        kept.row_mut(new_idx).assign(&data.row(old_idx));
1200                    }
1201                    return Ok(kept);
1202                }
1203                MissingValueStrategy::Interpolate => {
1204                    // Per-column linear interpolation: for each non-finite element,
1205                    // scan left and right for the nearest finite neighbours and
1206                    // linearly interpolate between them.  Falls back to the running
1207                    // mean when no finite neighbour exists on one or both sides.
1208                    let fallback = |j: usize| -> F {
1209                        self.running_mean
1210                            .as_ref()
1211                            .and_then(|m| if j < m.len() { Some(m[j]) } else { None })
1212                            .unwrap_or(F::zero())
1213                    };
1214
1215                    for col in 0..n_features {
1216                        // Collect column values.
1217                        let vals: Vec<F> = (0..n_samples).map(|r| data[[r, col]]).collect();
1218
1219                        for row in 0..n_samples {
1220                            if !vals[row].is_finite() {
1221                                // Find previous finite value.
1222                                let prev = (0..row)
1223                                    .rev()
1224                                    .find(|&r| vals[r].is_finite())
1225                                    .map(|r| (r, vals[r]));
1226                                // Find next finite value.
1227                                let next = ((row + 1)..n_samples)
1228                                    .find(|&r| vals[r].is_finite())
1229                                    .map(|r| (r, vals[r]));
1230
1231                                data[[row, col]] = match (prev, next) {
1232                                    (Some((pr, pv)), Some((nr, nv))) => {
1233                                        // Linear interpolation.
1234                                        let span = F::from(nr - pr).unwrap_or(F::one());
1235                                        let offset = F::from(row - pr).unwrap_or(F::zero());
1236                                        pv + (nv - pv) * offset / span
1237                                    }
1238                                    (Some((_, pv)), None) => pv,
1239                                    (None, Some((_, nv))) => nv,
1240                                    (None, None) => fallback(col),
1241                                };
1242                            }
1243                        }
1244                    }
1245                }
1246            }
1247            Ok(data)
1248        }
1249
1250        /// Apply normalization
1251        fn apply_normalization(&self, data: &mut Array2<F>) -> Result<()> {
1252            if let (Some(ref mean), Some(ref var)) = (&self.running_mean, &self.running_var) {
1253                if self.sample_count > 1 {
1254                    for mut row in data.rows_mut().into_iter() {
1255                        for (j, elem) in row.iter_mut().enumerate() {
1256                            if j < mean.len() && var[j] > F::zero() {
1257                                let std_dev = (var[j]
1258                                    / F::from(self.sample_count - 1)
1259                                        .expect("Failed to convert to float"))
1260                                .sqrt();
1261                                if std_dev > F::zero() {
1262                                    *elem = (*elem - mean[j]) / std_dev;
1263                                }
1264                            }
1265                        }
1266                    }
1267                }
1268            }
1269            Ok(())
1270        }
1271
1272        /// Handle outliers
1273        fn handle_outliers(&self, data: &mut Array2<F>) -> Result<()> {
1274            // Simple outlier detection: clip values beyond threshold standard deviations
1275            for elem in data.iter_mut() {
1276                if elem.abs() > self.outlier_threshold {
1277                    *elem = if *elem > F::zero() {
1278                        self.outlier_threshold
1279                    } else {
1280                        -self.outlier_threshold
1281                    };
1282                }
1283            }
1284            Ok(())
1285        }
1286
1287        /// Get current statistics
1288        pub fn get_statistics(&self) -> Option<(Array1<F>, Array1<F>)> {
1289            if let (Some(ref mean), Some(ref var)) = (&self.running_mean, &self.running_var) {
1290                Some((mean.clone(), var.clone()))
1291            } else {
1292                None
1293            }
1294        }
1295    }
1296
1297    #[cfg(test)]
1298    mod tests_preprocessor {
1299        use super::*;
1300        use scirs2_core::ndarray::Array2;
1301
1302        /// Build a 3×2 array: rows 0 and 2 have finite values; row 1 has an NaN.
1303        fn make_nan_data() -> Array2<f64> {
1304            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0])
1305                .expect("shape error")
1306        }
1307
1308        #[test]
1309        fn test_fill_zero_replaces_nan() {
1310            // Use large outlier_threshold so test values are not clipped.
1311            let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1312            pp.missing_value_strategy = MissingValueStrategy::FillZero;
1313            let data = make_nan_data();
1314            let out = pp.process_batch(data).expect("process_batch failed");
1315            assert_eq!(out.shape(), &[3, 2]);
1316            assert_eq!(out[[1, 0]], 0.0, "NaN should be replaced with 0");
1317            assert!(out[[1, 0]].is_finite());
1318        }
1319
1320        #[test]
1321        fn test_fill_mean_uses_running_mean() {
1322            let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1323            pp.missing_value_strategy = MissingValueStrategy::FillMean;
1324            pp.running_mean = Some(scirs2_core::ndarray::array![2.0, 3.0]);
1325            let data = make_nan_data();
1326            let out = pp.process_batch(data).expect("process_batch failed");
1327            // Row 1, col 0 was NaN → filled with running mean for col 0 (2.0).
1328            assert!(
1329                (out[[1, 0]] - 2.0).abs() < 1e-9,
1330                "NaN should be replaced with running mean 2.0, got {}",
1331                out[[1, 0]]
1332            );
1333        }
1334
1335        #[test]
1336        fn test_drop_removes_nan_rows() {
1337            let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1338            pp.missing_value_strategy = MissingValueStrategy::Drop;
1339            let data = make_nan_data();
1340            let out = pp.process_batch(data).expect("process_batch failed");
1341            // Row 1 had NaN → dropped; expect 2 rows.
1342            assert_eq!(out.shape(), &[2, 2], "NaN row should have been dropped");
1343            // Remaining rows should be all-finite.
1344            for row in out.rows() {
1345                for &v in row.iter() {
1346                    assert!(v.is_finite(), "all remaining values should be finite");
1347                }
1348            }
1349        }
1350
1351        #[test]
1352        fn test_drop_all_clean_data_unchanged() {
1353            let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1354            pp.missing_value_strategy = MissingValueStrategy::Drop;
1355            let data = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
1356                .expect("shape error");
1357            let out = pp.process_batch(data).expect("process_batch failed");
1358            assert_eq!(
1359                out.shape(),
1360                &[3, 2],
1361                "no rows should be dropped when data is clean"
1362            );
1363        }
1364
1365        #[test]
1366        fn test_interpolate_fills_nan_between_finite() {
1367            // Use outlier_threshold = 1000.0 so small test values are not clipped.
1368            let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1369            pp.missing_value_strategy = MissingValueStrategy::Interpolate;
1370            // col 0: 1.0, NaN, 3.0  → linearly interpolated → 2.0
1371            // col 1: 4.0, NaN, 8.0  → linearly interpolated → 6.0
1372            let data = Array2::from_shape_vec((3, 2), vec![1.0, 4.0, f64::NAN, f64::NAN, 3.0, 8.0])
1373                .expect("shape error");
1374            let out = pp.process_batch(data).expect("process_batch failed");
1375            assert_eq!(out.shape(), &[3, 2]);
1376            assert!(
1377                (out[[1, 0]] - 2.0).abs() < 1e-9,
1378                "interpolated col 0 should be 2.0, got {}",
1379                out[[1, 0]]
1380            );
1381            assert!(
1382                (out[[1, 1]] - 6.0).abs() < 1e-9,
1383                "interpolated col 1 should be 6.0, got {}",
1384                out[[1, 1]]
1385            );
1386        }
1387
1388        #[test]
1389        fn test_interpolate_leading_nan_uses_next_finite() {
1390            // Use outlier_threshold = 1000.0 so values are not clipped.
1391            let mut pp = StreamingPreprocessor::<f64>::new(false, 1000.0);
1392            pp.missing_value_strategy = MissingValueStrategy::Interpolate;
1393            // col 0: NaN, 5.0, 7.0  → leading NaN filled with next finite (5.0).
1394            let data = Array2::from_shape_vec((3, 2), vec![f64::NAN, 0.0, 5.0, 0.0, 7.0, 0.0])
1395                .expect("shape error");
1396            let out = pp.process_batch(data).expect("process_batch failed");
1397            assert!(
1398                (out[[0, 0]] - 5.0).abs() < 1e-9,
1399                "leading NaN should be filled with 5.0, got {}",
1400                out[[0, 0]]
1401            );
1402        }
1403    }
1404}
1405
1406/// Advanced online clustering algorithms for streaming data
1407pub mod online_algorithms {
1408    use super::*;
1409    use std::collections::VecDeque;
1410
1411    /// Online K-means with adaptive learning rate
1412    #[derive(Debug, Clone)]
1413    pub struct AdaptiveOnlineKMeans<F: Float> {
1414        /// Current cluster centers
1415        centers: Array2<F>,
1416        /// Learning rate schedule
1417        learning_rate_schedule: LearningRateSchedule,
1418        /// Current iteration count
1419        iteration: usize,
1420        /// Adaptive parameters
1421        adaptive_params: AdaptiveParams<F>,
1422        /// Performance metrics
1423        metrics: OnlineMetrics,
1424    }
1425
1426    /// Learning rate scheduling strategies
1427    #[derive(Debug, Clone)]
1428    pub enum LearningRateSchedule {
1429        /// Constant learning rate
1430        Constant(f64),
1431        /// Decreasing with iteration: lr / (1 + decay * iteration)
1432        Decay { initial_lr: f64, decay: f64 },
1433        /// Step decay: lr * factor every step_size iterations
1434        StepDecay {
1435            initial_lr: f64,
1436            factor: f64,
1437            step_size: usize,
1438        },
1439        /// Adaptive based on cluster stability
1440        Adaptive {
1441            min_lr: f64,
1442            max_lr: f64,
1443            stability_window: usize,
1444        },
1445    }
1446
1447    /// Adaptive parameters for online learning
1448    #[derive(Debug, Clone)]
1449    pub struct AdaptiveParams<F: Float> {
1450        /// Momentum for center updates
1451        pub momentum: F,
1452        /// Cluster stability tracking
1453        pub stability_scores: Vec<F>,
1454        /// Recent center movements
1455        pub center_movements: VecDeque<F>,
1456        /// Automatic cluster count adjustment
1457        pub auto_k_adjustment: bool,
1458        /// Split threshold for creating new clusters
1459        pub split_threshold: F,
1460        /// Merge threshold for combining clusters
1461        pub merge_threshold: F,
1462    }
1463
1464    /// Online performance metrics
1465    #[derive(Debug, Clone, Default)]
1466    pub struct OnlineMetrics {
1467        /// Running estimate of within-cluster sum of squares
1468        pub wcss: f64,
1469        /// Number of samples processed
1470        pub samples_processed: usize,
1471        /// Center update frequency
1472        pub update_frequency: f64,
1473        /// Cluster assignments distribution
1474        pub cluster_distribution: Vec<usize>,
1475        /// Processing time per batch
1476        pub batch_processing_times: VecDeque<f64>,
1477    }
1478
1479    impl<F: Float + FromPrimitive + Debug> AdaptiveOnlineKMeans<F> {
1480        /// Create a new adaptive online K-means instance
1481        pub fn new(
1482            initial_centers: Array2<F>,
1483            learning_rate_schedule: LearningRateSchedule,
1484        ) -> Self {
1485            let n_clusters = initial_centers.nrows();
1486            let adaptive_params = AdaptiveParams {
1487                momentum: F::from(0.9).expect("Failed to convert constant to float"),
1488                stability_scores: vec![F::zero(); n_clusters],
1489                center_movements: VecDeque::with_capacity(100),
1490                auto_k_adjustment: false,
1491                split_threshold: F::from(2.0).expect("Failed to convert constant to float"),
1492                merge_threshold: F::from(0.5).expect("Failed to convert constant to float"),
1493            };
1494
1495            Self {
1496                centers: initial_centers,
1497                learning_rate_schedule,
1498                iteration: 0,
1499                adaptive_params,
1500                metrics: OnlineMetrics::default(),
1501            }
1502        }
1503
1504        /// Process a new sample and update clusters
1505        pub fn update(&mut self, sample: ArrayView1<F>) -> Result<usize> {
1506            let start_time = std::time::Instant::now();
1507
1508            // Find nearest cluster
1509            let (nearest_cluster, min_distance) = self.find_nearest_cluster(sample)?;
1510
1511            // Get current learning rate
1512            let lr = self.get_current_learning_rate();
1513
1514            // Update cluster center
1515            let old_center = self.centers.row(nearest_cluster).to_owned();
1516            self.update_center(nearest_cluster, sample, lr)?;
1517
1518            // Track center movement for adaptive learning
1519            let movement = euclidean_distance(old_center.view(), self.centers.row(nearest_cluster));
1520            self.adaptive_params.center_movements.push_back(movement);
1521            if self.adaptive_params.center_movements.len() > 100 {
1522                self.adaptive_params.center_movements.pop_front();
1523            }
1524
1525            // Update metrics
1526            self.update_metrics(
1527                nearest_cluster,
1528                min_distance,
1529                start_time.elapsed().as_secs_f64(),
1530            );
1531
1532            // Check for adaptive cluster adjustments
1533            if self.adaptive_params.auto_k_adjustment {
1534                self.maybe_adjust_clusters(sample, min_distance)?;
1535            }
1536
1537            self.iteration += 1;
1538            Ok(nearest_cluster)
1539        }
1540
1541        /// Adaptively adjust cluster count based on distance thresholds
1542        fn maybe_adjust_clusters(&mut self, sample: ArrayView1<F>, min_distance: F) -> Result<()> {
1543            // Split: Create new cluster if sample is far from all existing clusters
1544            if min_distance > self.adaptive_params.split_threshold {
1545                // Add new cluster at sample location
1546                let n_clusters = self.centers.nrows();
1547                let n_features = self.centers.ncols();
1548
1549                let mut new_centers = Array2::<F>::zeros((n_clusters + 1, n_features));
1550
1551                // Copy existing centers
1552                for i in 0..n_clusters {
1553                    for j in 0..n_features {
1554                        new_centers[[i, j]] = self.centers[[i, j]];
1555                    }
1556                }
1557
1558                // Add new center at sample location
1559                for (j, &val) in sample.iter().enumerate() {
1560                    if j < n_features {
1561                        new_centers[[n_clusters, j]] = val;
1562                    }
1563                }
1564
1565                self.centers = new_centers;
1566                self.adaptive_params.stability_scores.push(F::zero());
1567            }
1568
1569            // Merge: Combine clusters that are too close
1570            let n_clusters = self.centers.nrows();
1571            if n_clusters > 1 {
1572                let mut clusters_to_merge: Option<(usize, usize)> = None;
1573                let mut min_inter_distance = F::infinity();
1574
1575                // Find closest pair of clusters
1576                for i in 0..n_clusters {
1577                    for j in (i + 1)..n_clusters {
1578                        let dist = euclidean_distance(self.centers.row(i), self.centers.row(j));
1579
1580                        if dist < min_inter_distance {
1581                            min_inter_distance = dist;
1582                            clusters_to_merge = Some((i, j));
1583                        }
1584                    }
1585                }
1586
1587                // Merge if below threshold
1588                if let Some((i, j)) = clusters_to_merge {
1589                    if min_inter_distance < self.adaptive_params.merge_threshold {
1590                        let n_features = self.centers.ncols();
1591                        let mut new_centers = Array2::<F>::zeros((n_clusters - 1, n_features));
1592
1593                        // Merge clusters i and j by averaging
1594                        let mut merged_center = Array1::<F>::zeros(n_features);
1595                        for k in 0..n_features {
1596                            merged_center[k] = (self.centers[[i, k]] + self.centers[[j, k]])
1597                                / (F::one() + F::one());
1598                        }
1599
1600                        // Build new center matrix
1601                        let mut new_idx = 0;
1602                        for old_idx in 0..n_clusters {
1603                            if old_idx == i {
1604                                // Add merged center
1605                                for k in 0..n_features {
1606                                    new_centers[[new_idx, k]] = merged_center[k];
1607                                }
1608                                new_idx += 1;
1609                            } else if old_idx != j {
1610                                // Copy unchanged center
1611                                for k in 0..n_features {
1612                                    new_centers[[new_idx, k]] = self.centers[[old_idx, k]];
1613                                }
1614                                new_idx += 1;
1615                            }
1616                        }
1617
1618                        self.centers = new_centers;
1619
1620                        // Update stability scores (remove merged cluster)
1621                        if j < self.adaptive_params.stability_scores.len() {
1622                            self.adaptive_params.stability_scores.remove(j);
1623                        }
1624                    }
1625                }
1626            }
1627
1628            Ok(())
1629        }
1630
1631        /// Find the nearest cluster to a sample
1632        fn find_nearest_cluster(&self, sample: ArrayView1<F>) -> Result<(usize, F)> {
1633            let mut min_distance = F::infinity();
1634            let mut nearest_cluster = 0;
1635
1636            for (i, center) in self.centers.rows().into_iter().enumerate() {
1637                let distance = euclidean_distance(sample, center);
1638                if distance < min_distance {
1639                    min_distance = distance;
1640                    nearest_cluster = i;
1641                }
1642            }
1643
1644            Ok((nearest_cluster, min_distance))
1645        }
1646
1647        /// Update a cluster center using momentum
1648        fn update_center(
1649            &mut self,
1650            cluster_idx: usize,
1651            sample: ArrayView1<F>,
1652            lr: f64,
1653        ) -> Result<()> {
1654            let learning_rate = F::from(lr).expect("Failed to convert to float");
1655            let momentum = self.adaptive_params.momentum;
1656
1657            let mut center = self.centers.row_mut(cluster_idx);
1658            for (i, &sample_val) in sample.iter().enumerate() {
1659                if i < center.len() {
1660                    let old_val = center[i];
1661                    let gradient = sample_val - old_val;
1662                    let update = learning_rate * gradient;
1663                    center[i] = momentum * old_val + (F::one() - momentum) * (old_val + update);
1664                }
1665            }
1666
1667            Ok(())
1668        }
1669
1670        /// Get current learning rate based on schedule
1671        fn get_current_learning_rate(&self) -> f64 {
1672            match &self.learning_rate_schedule {
1673                LearningRateSchedule::Constant(lr) => *lr,
1674                LearningRateSchedule::Decay { initial_lr, decay } => {
1675                    initial_lr / (1.0 + decay * self.iteration as f64)
1676                }
1677                LearningRateSchedule::StepDecay {
1678                    initial_lr,
1679                    factor,
1680                    step_size,
1681                } => {
1682                    let steps = self.iteration / step_size;
1683                    initial_lr * factor.powi(steps as i32)
1684                }
1685                LearningRateSchedule::Adaptive {
1686                    min_lr,
1687                    max_lr,
1688                    stability_window,
1689                } => {
1690                    let recent_movements: Vec<F> = self
1691                        .adaptive_params
1692                        .center_movements
1693                        .iter()
1694                        .rev()
1695                        .take(*stability_window)
1696                        .cloned()
1697                        .collect();
1698
1699                    if recent_movements.is_empty() {
1700                        return *max_lr;
1701                    }
1702
1703                    let avg_movement = recent_movements.iter().fold(F::zero(), |acc, x| acc + *x)
1704                        / F::from(recent_movements.len()).expect("Operation failed");
1705                    let stability = F::one() / (F::one() + avg_movement);
1706
1707                    // High stability = low learning rate, low stability = high learning rate
1708                    let adaptive_lr = min_lr
1709                        + (max_lr - min_lr)
1710                            * (F::one() - stability).to_f64().expect("Operation failed");
1711                    adaptive_lr.clamp(*min_lr, *max_lr)
1712                }
1713            }
1714        }
1715
1716        /// Update performance metrics
1717        fn update_metrics(&mut self, cluster_idx: usize, distance: F, processing_time: f64) {
1718            self.metrics.samples_processed += 1;
1719
1720            // Update WCSS estimate
1721            let distance_sq = distance.to_f64().expect("Operation failed").powi(2);
1722            let n = self.metrics.samples_processed as f64;
1723            self.metrics.wcss = ((n - 1.0) * self.metrics.wcss + distance_sq) / n;
1724
1725            // Update cluster distribution
1726            if cluster_idx >= self.metrics.cluster_distribution.len() {
1727                self.metrics.cluster_distribution.resize(cluster_idx + 1, 0);
1728            }
1729            self.metrics.cluster_distribution[cluster_idx] += 1;
1730
1731            // Track processing times
1732            self.metrics
1733                .batch_processing_times
1734                .push_back(processing_time);
1735            if self.metrics.batch_processing_times.len() > 1000 {
1736                self.metrics.batch_processing_times.pop_front();
1737            }
1738
1739            // Update frequency calculation
1740            let total_updates = self.metrics.cluster_distribution.iter().sum::<usize>() as f64;
1741            self.metrics.update_frequency = total_updates / self.iteration.max(1) as f64;
1742        }
1743
1744        /// Get current cluster centers
1745        pub fn get_centers(&self) -> &Array2<F> {
1746            &self.centers
1747        }
1748
1749        /// Get current metrics
1750        pub fn get_metrics(&self) -> &OnlineMetrics {
1751            &self.metrics
1752        }
1753
1754        /// Predict cluster for new samples
1755        pub fn predict(&self, samples: ArrayView2<F>) -> Result<Array1<usize>> {
1756            let n_samples = samples.nrows();
1757            let mut predictions = Array1::zeros(n_samples);
1758
1759            for (i, sample) in samples.rows().into_iter().enumerate() {
1760                let (cluster_id, _distance) = self.find_nearest_cluster(sample)?;
1761                predictions[i] = cluster_id;
1762            }
1763
1764            Ok(predictions)
1765        }
1766    }
1767}