Skip to main content

torsh_cluster/utils/
memory_efficient.rs

1//! Memory-efficient clustering operations for large-scale datasets
2//!
3//! This module provides utilities for clustering datasets that don't fit in memory
4//! by processing data in chunks and using lazy evaluation strategies.
5//!
6//! # Key Features
7//!
8//! - **Chunked Processing**: Process large datasets in manageable chunks
9//! - **Streaming K-Means**: Cluster streaming data without loading all into memory
10//! - **Memory Profiling**: Track and optimize memory usage during clustering
11//! - **Lazy Centroid Updates**: Delay expensive operations until necessary
12//!
13//! # SciRS2 POLICY Compliance
14//!
15//! All operations use `scirs2_core` abstractions:
16//! - Array operations via `scirs2_core::ndarray`
17//! - Parallel processing via `scirs2_core::parallel_ops`
18//! - Random number generation via `scirs2_core::random`
19
20use crate::error::{ClusterError, ClusterResult};
21use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
22use scirs2_core::parallel_ops::{IntoParallelIterator, ParallelIterator};
23use std::sync::Arc;
24use torsh_tensor::Tensor;
25
/// Configuration for memory-efficient operations
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryEfficientConfig {
    /// Maximum chunk size for processing (number of samples)
    pub chunk_size: usize,
    /// Whether to use parallel processing for chunks
    pub parallel: bool,
    /// Memory limit in megabytes (approximate); `None` means no limit is imposed
    pub memory_limit_mb: Option<usize>,
}
36
37impl Default for MemoryEfficientConfig {
38    fn default() -> Self {
39        Self {
40            chunk_size: 1000,
41            parallel: true,
42            memory_limit_mb: None,
43        }
44    }
45}
46
/// Memory-efficient chunked data processor
///
/// Processes a `(n_samples, n_features)` dataset in row chunks of bounded size, so the work
/// done per step (and anything that work allocates) stays proportional to the chunk size.
///
/// Note: the processing methods first copy the whole tensor into a contiguous array, so the
/// dataset itself must fit in memory. Chunking bounds the working set of each step, not the
/// size of the input.
///
/// # Example
///
/// ```rust
/// use torsh_cluster::utils::memory_efficient::ChunkedDataProcessor;
/// use torsh_tensor::Tensor;
///
/// let large_data = Tensor::from_vec(
///     (0..10000).map(|i| i as f32).collect(),
///     &[1000, 10]
/// )?;
///
/// let processor = ChunkedDataProcessor::new(100); // Process 100 samples at a time
///
/// let mut sum = 0.0;
/// processor.process(&large_data, |chunk| {
///     // Process each chunk
///     let chunk_sum: f32 = chunk.iter().sum();
///     sum += chunk_sum;
///     Ok(())
/// })?;
/// assert!(sum > 0.0);
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Debug, Clone)]
pub struct ChunkedDataProcessor {
    /// Maximum number of samples (rows) handed to the callback at once.
    chunk_size: usize,
    /// Whether `process_parallel` may run chunks concurrently.
    parallel: bool,
}
77
78impl ChunkedDataProcessor {
79    /// Create a new chunked data processor
80    pub fn new(chunk_size: usize) -> Self {
81        Self {
82            chunk_size,
83            parallel: true,
84        }
85    }
86
87    /// Set whether to use parallel processing
88    pub fn parallel(mut self, parallel: bool) -> Self {
89        self.parallel = parallel;
90        self
91    }
92
93    /// Process data in chunks
94    ///
95    /// Applies the given function to each chunk of data sequentially.
96    pub fn process<F>(&self, data: &Tensor, mut f: F) -> ClusterResult<()>
97    where
98        F: FnMut(ArrayView2<f32>) -> ClusterResult<()>,
99    {
100        let shape = data.shape();
101        let n_samples = shape.dims()[0];
102        let n_features = shape.dims()[1];
103
104        // Convert tensor to ndarray for efficient slicing
105        let data_vec = data.to_vec()?;
106        let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
107            .map_err(|e| ClusterError::InvalidInput(format!("Shape error: {}", e)))?;
108
109        // Process in chunks
110        for start_idx in (0..n_samples).step_by(self.chunk_size) {
111            let end_idx = (start_idx + self.chunk_size).min(n_samples);
112            let chunk = data_array.slice(s![start_idx..end_idx, ..]);
113            f(chunk)?;
114        }
115
116        Ok(())
117    }
118
119    /// Process data in parallel chunks
120    ///
121    /// Applies the given function to each chunk in parallel.
122    /// Results from all chunks are collected and returned.
123    pub fn process_parallel<F, R>(&self, data: &Tensor, f: F) -> ClusterResult<Vec<R>>
124    where
125        F: Fn(ArrayView2<f32>) -> ClusterResult<R> + Send + Sync,
126        R: Send,
127    {
128        let shape = data.shape();
129        let n_samples = shape.dims()[0];
130        let n_features = shape.dims()[1];
131
132        // Convert tensor to ndarray
133        let data_vec = data.to_vec()?;
134        let data_array = Array2::from_shape_vec((n_samples, n_features), data_vec)
135            .map_err(|e| ClusterError::InvalidInput(format!("Shape error: {}", e)))?;
136
137        // Create Arc for thread-safe sharing
138        let data_arc = Arc::new(data_array);
139
140        // Collect chunk indices
141        let chunks: Vec<(usize, usize)> = (0..n_samples)
142            .step_by(self.chunk_size)
143            .map(|start| {
144                let end = (start + self.chunk_size).min(n_samples);
145                (start, end)
146            })
147            .collect();
148
149        if !self.parallel || chunks.len() <= 1 {
150            // Sequential processing
151            let results: Result<Vec<R>, ClusterError> = chunks
152                .iter()
153                .map(|(start, end)| {
154                    let chunk = data_arc.slice(s![*start..*end, ..]);
155                    f(chunk)
156                })
157                .collect();
158            return results;
159        }
160
161        // Parallel processing using scirs2_core parallel_ops
162        let results: Result<Vec<R>, ClusterError> = chunks
163            .into_par_iter()
164            .map(|(start, end)| {
165                let chunk = data_arc.slice(s![start..end, ..]);
166                f(chunk)
167            })
168            .collect();
169
170        results
171    }
172
173    /// Calculate optimal chunk size based on available memory and data dimensions
174    pub fn optimal_chunk_size(
175        n_samples: usize,
176        n_features: usize,
177        available_memory_mb: usize,
178    ) -> usize {
179        // Estimate memory per sample (assuming f32)
180        let bytes_per_sample = n_features * std::mem::size_of::<f32>();
181        let available_bytes = available_memory_mb * 1024 * 1024;
182
183        // Use 80% of available memory for safety
184        let safe_bytes = (available_bytes as f64 * 0.8) as usize;
185
186        // Calculate chunk size
187        let chunk_size = safe_bytes / bytes_per_sample;
188
189        // Ensure at least 10 samples per chunk, but not more than total samples
190        chunk_size.max(10).min(n_samples)
191    }
192}
193
194/// Memory-efficient incremental centroid updater
195///
196/// Updates centroids incrementally as new data arrives, minimizing memory overhead.
197///
198/// # Mathematical Formulation
199///
200/// For online centroid updates, we use Welford's algorithm:
201///
202/// ```text
203/// μ_{n+1} = μ_n + (x_{n+1} - μ_n) / (n + 1)
204/// ```
205///
206/// where:
207/// - `μ_n` is the mean after n samples
208/// - `x_{n+1}` is the new sample
209/// - `n` is the number of samples seen so far
210pub struct IncrementalCentroidUpdater {
211    /// Current centroids
212    centroids: Array2<f64>,
213    /// Number of samples assigned to each centroid
214    counts: Array1<usize>,
215    /// Total samples processed
216    n_samples: usize,
217}
218
219impl IncrementalCentroidUpdater {
220    /// Create a new incremental centroid updater
221    pub fn new(n_clusters: usize, n_features: usize) -> Self {
222        Self {
223            centroids: Array2::zeros((n_clusters, n_features)),
224            counts: Array1::zeros(n_clusters),
225            n_samples: 0,
226        }
227    }
228
229    /// Initialize centroids from initial samples
230    pub fn initialize(&mut self, initial_centroids: ArrayView2<f64>) -> ClusterResult<()> {
231        let (n_clusters, n_features) = initial_centroids.dim();
232
233        if (n_clusters, n_features) != self.centroids.dim() {
234            return Err(ClusterError::InvalidInput(format!(
235                "Expected {} clusters and {} features, got {} and {}",
236                self.centroids.nrows(),
237                self.centroids.ncols(),
238                n_clusters,
239                n_features
240            )));
241        }
242
243        self.centroids.assign(&initial_centroids);
244        self.counts.fill(1); // Assume each centroid initialized with one sample
245        self.n_samples = n_clusters;
246
247        Ok(())
248    }
249
250    /// Update centroids with new sample batch
251    ///
252    /// Uses incremental averaging to update centroids without storing all data.
253    pub fn update_batch(
254        &mut self,
255        samples: ArrayView2<f64>,
256        labels: &[usize],
257    ) -> ClusterResult<()> {
258        if samples.nrows() != labels.len() {
259            return Err(ClusterError::InvalidInput(format!(
260                "Sample count {} doesn't match label count {}",
261                samples.nrows(),
262                labels.len()
263            )));
264        }
265
266        // Update each centroid incrementally
267        for (sample, &label) in samples.outer_iter().zip(labels.iter()) {
268            if label >= self.centroids.nrows() {
269                return Err(ClusterError::InvalidInput(format!(
270                    "Label {} exceeds number of clusters {}",
271                    label,
272                    self.centroids.nrows()
273                )));
274            }
275
276            let count = self.counts[label];
277            let mut centroid = self.centroids.row_mut(label);
278
279            // Welford's algorithm for incremental mean update
280            for (i, &value) in sample.iter().enumerate() {
281                centroid[i] += (value - centroid[i]) / (count + 1) as f64;
282            }
283
284            self.counts[label] += 1;
285        }
286
287        self.n_samples += samples.nrows();
288
289        Ok(())
290    }
291
292    /// Get current centroids
293    pub fn centroids(&self) -> ArrayView2<'_, f64> {
294        self.centroids.view()
295    }
296
297    /// Get cluster counts
298    pub fn counts(&self) -> &Array1<usize> {
299        &self.counts
300    }
301
302    /// Get total samples processed
303    pub fn n_samples(&self) -> usize {
304        self.n_samples
305    }
306}
307
/// Estimate memory usage for clustering operation
///
/// Returns estimated memory usage in megabytes (MiB, 1024 × 1024 bytes), summing:
///
/// - the data matrix: `n_samples × n_features` `f32` values,
/// - the centroids: `n_clusters × n_features` `f64` values,
/// - the labels: `n_samples` `usize` values,
/// - a distance matrix (used by some algorithms): `n_samples × n_clusters` `f32` values.
///
/// The arithmetic is done in `f64`, so very large dimensions produce a large (possibly
/// rounded) estimate rather than overflowing `usize`.
pub fn estimate_memory_usage(n_samples: usize, n_features: usize, n_clusters: usize) -> f64 {
    const BYTES_PER_MB: f64 = 1024.0 * 1024.0;

    // Byte size of `count` elements of `elem_size` bytes each, without integer overflow.
    let bytes = |count: f64, elem_size: usize| count * elem_size as f64;

    let samples = n_samples as f64;
    let features = n_features as f64;
    let clusters = n_clusters as f64;

    let data_size = bytes(samples * features, std::mem::size_of::<f32>());
    let centroids_size = bytes(clusters * features, std::mem::size_of::<f64>());
    let labels_size = bytes(samples, std::mem::size_of::<usize>());
    let distances_size = bytes(samples * clusters, std::mem::size_of::<f32>());

    (data_size + centroids_size + labels_size + distances_size) / BYTES_PER_MB
}
328
329/// Suggest optimal clustering strategy based on dataset size and available memory
330pub fn suggest_clustering_strategy(
331    n_samples: usize,
332    n_features: usize,
333    available_memory_mb: usize,
334) -> String {
335    let estimated_mb = estimate_memory_usage(n_samples, n_features, 10); // Assume 10 clusters
336
337    if estimated_mb < available_memory_mb as f64 * 0.5 {
338        format!(
339            "Standard clustering (estimated {:.2} MB, available {} MB)",
340            estimated_mb, available_memory_mb
341        )
342    } else if estimated_mb < available_memory_mb as f64 * 0.8 {
343        format!(
344            "Use parallel processing with caution (estimated {:.2} MB, available {} MB)",
345            estimated_mb, available_memory_mb
346        )
347    } else {
348        let chunk_size =
349            ChunkedDataProcessor::optimal_chunk_size(n_samples, n_features, available_memory_mb);
350        format!(
351            "Use chunked processing with chunk_size={} (estimated {:.2} MB exceeds available {} MB)",
352            chunk_size, estimated_mb, available_memory_mb
353        )
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_chunked_processor_basic() -> Result<(), Box<dyn std::error::Error>> {
        // 10 samples x 10 features holding 0.0..100.0 in row-major order.
        let data = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[10, 10])?;

        let processor = ChunkedDataProcessor::new(3);

        let mut chunk_rows = Vec::new();
        let mut chunk_first_values = Vec::new();
        processor.process(&data, |chunk| {
            chunk_rows.push(chunk.nrows());
            chunk_first_values.push(chunk[[0, 0]]);
            Ok(())
        })?;

        // 10 samples / 3 per chunk -> three full chunks and a 1-row remainder, in row order.
        assert_eq!(chunk_rows, vec![3, 3, 3, 1]);
        assert_eq!(chunk_first_values, vec![0.0, 30.0, 60.0, 90.0]);

        Ok(())
    }

    #[test]
    fn test_chunked_processor_parallel() -> Result<(), Box<dyn std::error::Error>> {
        let data = Tensor::from_vec((0..100).map(|i| i as f32).collect(), &[10, 10])?;

        let parallel = ChunkedDataProcessor::new(3)
            .parallel(true)
            .process_parallel(&data, |chunk| Ok(chunk.nrows()))?;
        let sequential = ChunkedDataProcessor::new(3)
            .parallel(false)
            .process_parallel(&data, |chunk| Ok(chunk.nrows()))?;

        // Results come back in chunk order regardless of execution mode.
        assert_eq!(parallel, vec![3, 3, 3, 1]);
        assert_eq!(sequential, parallel);

        Ok(())
    }

    #[test]
    fn test_optimal_chunk_size() {
        // Each sample is 100 * 4 = 400 bytes.
        // 100 MB * 0.8 = 83,886,080 bytes -> room for 209,715 samples, capped at 1000.
        assert_eq!(ChunkedDataProcessor::optimal_chunk_size(1000, 100, 100), 1000);

        // 400,000-byte samples in a 1 MB budget fit only 2 samples, raised to the floor of 10.
        assert_eq!(
            ChunkedDataProcessor::optimal_chunk_size(1_000_000, 100_000, 1),
            10
        );
    }

    #[test]
    fn test_incremental_centroid_updater() -> Result<(), Box<dyn std::error::Error>> {
        let mut updater = IncrementalCentroidUpdater::new(2, 3);

        // Initialize with some centroids
        let initial = Array2::from_shape_vec((2, 3), vec![0.0, 0.0, 0.0, 5.0, 5.0, 5.0])?;
        updater.initialize(initial.view())?;

        // Add new samples
        let samples = Array2::from_shape_vec((2, 3), vec![1.0, 1.0, 1.0, 6.0, 6.0, 6.0])?;
        let labels = vec![0, 1];
        updater.update_batch(samples.view(), &labels)?;

        // Each centroid is now the mean of its initial value and one new sample.
        let centroids = updater.centroids();
        for col in 0..3 {
            assert_relative_eq!(centroids[[0, col]], 0.5, epsilon = 1e-12);
            assert_relative_eq!(centroids[[1, col]], 5.5, epsilon = 1e-12);
        }

        assert_eq!(updater.counts().to_vec(), vec![2, 2]);
        assert_eq!(updater.n_samples(), 4); // 2 initial + 2 new

        Ok(())
    }

    #[test]
    fn test_memory_estimation() {
        let memory_mb = estimate_memory_usage(1000, 100, 10);

        // Data: 1000 * 100 * 4 = 400,000 bytes
        // Centroids: 10 * 100 * 8 = 8,000 bytes
        // Labels: 1000 * size_of::<usize>() bytes (8,000 on 64-bit targets)
        // Distances: 1000 * 10 * 4 = 40,000 bytes
        let expected_bytes =
            400_000 + 8_000 + 1000 * std::mem::size_of::<usize>() + 40_000;
        let expected_mb = expected_bytes as f64 / (1024.0 * 1024.0);

        assert_relative_eq!(memory_mb, expected_mb, epsilon = 1e-12);
    }

    #[test]
    fn test_suggest_clustering_strategy() {
        // Small dataset that fits in memory
        let strategy = suggest_clustering_strategy(100, 10, 100);
        assert!(strategy.contains("Standard"));

        // Roughly 8-8.4 MB estimated against 15 MB available: between 50% and 80%.
        let strategy = suggest_clustering_strategy(100_000, 10, 15);
        assert!(strategy.contains("caution"));

        // Large dataset that needs chunking
        let strategy = suggest_clustering_strategy(1_000_000, 100, 10);
        assert!(strategy.contains("chunked"));
    }
}