// sklears_model_selection/memory_efficient.rs

1//! Memory-Efficient Evaluation System
2//!
3//! This module provides memory-efficient evaluation and cross-validation methods
4//! designed for large datasets that don't fit in memory, using streaming algorithms,
5//! memory mapping, and chunk-based processing.
6
7use std::collections::VecDeque;
8use thiserror::Error;
9
/// Memory-efficient evaluation errors
#[derive(Error, Debug)]
pub enum MemoryError {
    /// Underlying I/O failure, converted automatically via `#[from]`.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A chunk could not be processed; the payload describes the cause.
    #[error("Chunk processing error: {0}")]
    ChunkProcessing(String),
    /// An allocation request would push tracked usage past the configured cap (both in MB).
    #[error("Memory limit exceeded: requested {requested}MB, limit {limit}MB")]
    MemoryLimitExceeded { requested: usize, limit: usize },
    /// A streaming invariant was violated (e.g. data/label stream mismatch).
    #[error("Streaming error: {0}")]
    Streaming(String),
}
22
/// Memory usage tracking and management
///
/// Pure bookkeeping: tracks a running total of "allocated" MB (it performs no
/// real allocation), the high-water mark, and an optional hard cap.
pub struct MemoryTracker {
    current_usage: usize, // MB currently accounted for
    peak_usage: usize,    // highest value current_usage has reached
    limit: Option<usize>, // optional cap in MB; None = unlimited
}
29
30impl MemoryTracker {
31    /// Create a new memory tracker with optional limit (in MB)
32    pub fn new(limit_mb: Option<usize>) -> Self {
33        Self {
34            current_usage: 0,
35            peak_usage: 0,
36            limit: limit_mb,
37        }
38    }
39
40    /// Allocate memory and track usage
41    pub fn allocate(&mut self, size_mb: usize) -> Result<(), MemoryError> {
42        if let Some(limit) = self.limit {
43            if self.current_usage + size_mb > limit {
44                return Err(MemoryError::MemoryLimitExceeded {
45                    requested: size_mb,
46                    limit,
47                });
48            }
49        }
50
51        self.current_usage += size_mb;
52        if self.current_usage > self.peak_usage {
53            self.peak_usage = self.current_usage;
54        }
55
56        Ok(())
57    }
58
59    /// Deallocate memory and update tracking
60    pub fn deallocate(&mut self, size_mb: usize) {
61        self.current_usage = self.current_usage.saturating_sub(size_mb);
62    }
63
64    /// Get current memory usage in MB
65    pub fn current_usage(&self) -> usize {
66        self.current_usage
67    }
68
69    /// Get peak memory usage in MB
70    pub fn peak_usage(&self) -> usize {
71        self.peak_usage
72    }
73
74    /// Get memory limit in MB
75    pub fn limit(&self) -> Option<usize> {
76        self.limit
77    }
78}
79
/// Tuning knobs for memory-efficient evaluation.
#[derive(Debug, Clone)]
pub struct MemoryEfficientConfig {
    /// Upper bound on the number of samples per chunk.
    pub chunk_size: usize,
    /// Optional hard cap on tracked memory, in MB.
    pub memory_limit: Option<usize>,
    /// Whether large files should be memory-mapped.
    pub use_memory_mapping: bool,
    /// How many chunks the in-memory buffer may hold at once.
    pub buffer_size: usize,
    /// Whether to run in streaming mode for very large datasets.
    pub streaming_mode: bool,
}

impl Default for MemoryEfficientConfig {
    /// Defaults: 1000-sample chunks, a 1 GB limit, memory mapping enabled,
    /// a 3-chunk buffer, and streaming mode disabled.
    fn default() -> Self {
        MemoryEfficientConfig {
            chunk_size: 1000,
            memory_limit: Some(1024),
            use_memory_mapping: true,
            buffer_size: 3,
            streaming_mode: false,
        }
    }
}
106
/// A contiguous run of samples together with its global index range.
#[derive(Debug, Clone)]
pub struct DataChunk<T> {
    pub data: Vec<T>,
    pub start_index: usize,
    pub end_index: usize,
}

impl<T> DataChunk<T> {
    /// Wrap `data` as the chunk covering `[start_index, start_index + data.len())`.
    pub fn new(data: Vec<T>, start_index: usize) -> Self {
        Self {
            end_index: start_index + data.len(),
            start_index,
            data,
        }
    }

    /// Number of samples held by this chunk.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when the chunk holds no samples.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
133
/// Streaming data reader for memory-efficient processing
///
/// Buffers at most `config.buffer_size` chunks of up to `config.chunk_size`
/// samples; when loading produces more chunks, the oldest are evicted (and
/// their samples dropped) while the tracker accounts for the buffered footprint.
pub struct StreamingDataReader<T> {
    chunks: VecDeque<DataChunk<T>>, // buffered chunks, oldest at the front
    current_index: usize,           // global index of the next unread sample
    total_samples: usize,           // count of all samples ever loaded
    config: MemoryEfficientConfig,
    memory_tracker: MemoryTracker,
}
142
143impl<T> StreamingDataReader<T>
144where
145    T: Clone + Send + Sync,
146{
147    /// Create a new streaming data reader
148    pub fn new(config: MemoryEfficientConfig) -> Self {
149        let memory_tracker = MemoryTracker::new(config.memory_limit);
150        Self {
151            chunks: VecDeque::new(),
152            current_index: 0,
153            total_samples: 0,
154            config,
155            memory_tracker,
156        }
157    }
158
159    /// Load data chunks from an iterator
160    pub fn load_from_iterator<I>(&mut self, data_iter: I) -> Result<(), MemoryError>
161    where
162        I: Iterator<Item = T>,
163    {
164        let mut chunk_data = Vec::with_capacity(self.config.chunk_size);
165        let mut start_index = 0;
166        let mut total_count = 0;
167
168        for (i, item) in data_iter.enumerate() {
169            chunk_data.push(item);
170            total_count += 1;
171
172            if chunk_data.len() >= self.config.chunk_size {
173                let chunk_size_mb = std::mem::size_of::<T>() * chunk_data.len() / 1_048_576;
174                self.memory_tracker.allocate(chunk_size_mb)?;
175
176                let chunk = DataChunk::new(chunk_data, start_index);
177                self.chunks.push_back(chunk);
178
179                chunk_data = Vec::with_capacity(self.config.chunk_size);
180                start_index = i + 1;
181            }
182
183            // Limit buffer size to prevent memory overflow
184            while self.chunks.len() > self.config.buffer_size {
185                if let Some(old_chunk) = self.chunks.pop_front() {
186                    let chunk_size_mb = std::mem::size_of::<T>() * old_chunk.len() / 1_048_576;
187                    self.memory_tracker.deallocate(chunk_size_mb);
188                }
189            }
190        }
191
192        // Handle remaining data
193        if !chunk_data.is_empty() {
194            let chunk_size_mb = std::mem::size_of::<T>() * chunk_data.len() / 1_048_576;
195            self.memory_tracker.allocate(chunk_size_mb)?;
196
197            let chunk = DataChunk::new(chunk_data, start_index);
198            self.chunks.push_back(chunk);
199        }
200
201        self.total_samples = total_count;
202        Ok(())
203    }
204
205    /// Get the next chunk of data
206    pub fn next_chunk(&mut self) -> Option<&DataChunk<T>> {
207        if self.chunks.is_empty() {
208            return None;
209        }
210
211        let front_chunk = self.chunks.front()?;
212        if self.current_index >= front_chunk.end_index {
213            // Move to next chunk
214            if let Some(old_chunk) = self.chunks.pop_front() {
215                let chunk_size_mb = std::mem::size_of::<T>() * old_chunk.len() / 1_048_576;
216                self.memory_tracker.deallocate(chunk_size_mb);
217            }
218            return self.next_chunk();
219        }
220
221        self.chunks.front()
222    }
223
224    /// Get current memory usage statistics
225    pub fn memory_stats(&self) -> (usize, usize, Option<usize>) {
226        (
227            self.memory_tracker.current_usage(),
228            self.memory_tracker.peak_usage(),
229            self.memory_tracker.limit(),
230        )
231    }
232
233    /// Get total number of samples
234    pub fn total_samples(&self) -> usize {
235        self.total_samples
236    }
237
238    /// Check if there are more chunks to process
239    pub fn has_more_chunks(&self) -> bool {
240        !self.chunks.is_empty() && self.current_index < self.total_samples
241    }
242}
243
/// Memory-efficient cross-validation evaluator
///
/// Holds one streaming reader for samples and one for labels, plus the
/// precomputed per-fold test index sets produced by `setup_folds`.
pub struct MemoryEfficientCrossValidator<T, L> {
    config: MemoryEfficientConfig, // retained for reference; each reader holds its own copy
    fold_indices: Vec<Vec<usize>>, // one Vec of test-set sample indices per fold
    data_reader: StreamingDataReader<T>,
    label_reader: StreamingDataReader<L>,
}
251
252impl<T, L> MemoryEfficientCrossValidator<T, L>
253where
254    T: Clone + Send + Sync,
255    L: Clone + Send + Sync,
256{
257    /// Create a new memory-efficient cross-validator
258    pub fn new(config: MemoryEfficientConfig, n_folds: usize) -> Self {
259        Self {
260            config: config.clone(),
261            fold_indices: Vec::with_capacity(n_folds),
262            data_reader: StreamingDataReader::new(config.clone()),
263            label_reader: StreamingDataReader::new(config),
264        }
265    }
266
267    /// Set up fold indices for cross-validation
268    pub fn setup_folds(&mut self, n_samples: usize, n_folds: usize) {
269        let samples_per_fold = n_samples / n_folds;
270        let mut indices: Vec<usize> = (0..n_samples).collect();
271
272        // Simple shuffle (in practice, you'd use a proper random shuffle)
273        indices.sort_by_key(|&i| i % 997); // Simple pseudo-shuffle
274
275        for fold in 0..n_folds {
276            let start = fold * samples_per_fold;
277            let end = if fold == n_folds - 1 {
278                n_samples
279            } else {
280                (fold + 1) * samples_per_fold
281            };
282
283            self.fold_indices.push(indices[start..end].to_vec());
284        }
285    }
286
287    /// Perform streaming cross-validation evaluation
288    pub fn streaming_evaluate<F, R>(
289        &mut self,
290        train_func: F,
291    ) -> Result<StreamingEvaluationResult<R>, MemoryError>
292    where
293        F: Fn(&[T], &[L]) -> Result<R, MemoryError>,
294        R: Clone + Default,
295    {
296        let mut fold_results = Vec::new();
297        let mut memory_snapshots = Vec::new();
298
299        for fold_idx in 0..self.fold_indices.len() {
300            let test_indices = &self.fold_indices[fold_idx];
301
302            // Create training data by excluding test fold
303            let mut train_data = Vec::new();
304            let mut train_labels = Vec::new();
305
306            // Process data in chunks to avoid memory overflow
307            while let Some(data_chunk) = self.data_reader.next_chunk() {
308                let label_chunk = self.label_reader.next_chunk().ok_or_else(|| {
309                    MemoryError::Streaming("Mismatched data and labels".to_string())
310                })?;
311
312                for (i, (sample, label)) in data_chunk
313                    .data
314                    .iter()
315                    .zip(label_chunk.data.iter())
316                    .enumerate()
317                {
318                    let global_idx = data_chunk.start_index + i;
319                    if !test_indices.contains(&global_idx) {
320                        train_data.push(sample.clone());
321                        train_labels.push(label.clone());
322                    }
323                }
324            }
325
326            // Train and evaluate
327            let result = train_func(&train_data, &train_labels)?;
328            fold_results.push(result);
329
330            // Record memory usage
331            let (current, peak, limit) = self.data_reader.memory_stats();
332            memory_snapshots.push(MemorySnapshot {
333                fold: fold_idx,
334                current_usage: current,
335                peak_usage: peak,
336                limit,
337            });
338        }
339
340        Ok(StreamingEvaluationResult {
341            fold_results,
342            memory_snapshots,
343            total_folds: self.fold_indices.len(),
344        })
345    }
346}
347
/// Memory usage snapshot taken after one cross-validation fold
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
    pub fold: usize,
    pub current_usage: usize,
    pub peak_usage: usize,
    pub limit: Option<usize>,
}

/// Result from streaming evaluation
#[derive(Debug, Clone)]
pub struct StreamingEvaluationResult<R> {
    pub fold_results: Vec<R>,
    pub memory_snapshots: Vec<MemorySnapshot>,
    pub total_folds: usize,
}

impl<R> StreamingEvaluationResult<R> {
    /// Summarize memory usage across the recorded snapshots.
    ///
    /// Returns all-zero statistics when no snapshots were recorded — the
    /// previous version divided by `memory_snapshots.len()` and panicked on
    /// an empty result. A missing or zero limit yields an efficiency ratio
    /// of 0.0 (avoids a divide-by-zero producing `inf`).
    pub fn memory_efficiency_stats(&self) -> MemoryEfficiencyStats {
        if self.memory_snapshots.is_empty() {
            return MemoryEfficiencyStats {
                avg_peak_usage: 0,
                max_peak_usage: 0,
                total_peak_usage: 0,
                efficiency_ratio: 0.0,
                memory_limit: None,
                folds_processed: self.total_folds,
            };
        }

        let total_peak: usize = self.memory_snapshots.iter().map(|s| s.peak_usage).sum();
        let avg_peak = total_peak / self.memory_snapshots.len();
        let max_peak = self
            .memory_snapshots
            .iter()
            .map(|s| s.peak_usage)
            .max()
            .unwrap_or(0);

        // Ratio of the worst-case peak to the configured limit (first snapshot's).
        let limit = self.memory_snapshots.first().and_then(|s| s.limit);
        let efficiency_ratio = match limit {
            Some(limit) if limit > 0 => max_peak as f64 / limit as f64,
            _ => 0.0,
        };

        MemoryEfficiencyStats {
            avg_peak_usage: avg_peak,
            max_peak_usage: max_peak,
            total_peak_usage: total_peak,
            efficiency_ratio,
            memory_limit: limit,
            folds_processed: self.total_folds,
        }
    }
}

/// Memory efficiency statistics
#[derive(Debug, Clone)]
pub struct MemoryEfficiencyStats {
    pub avg_peak_usage: usize,
    pub max_peak_usage: usize,
    pub total_peak_usage: usize,
    pub efficiency_ratio: f64,
    pub memory_limit: Option<usize>,
    pub folds_processed: usize,
}
405
406/// Convenience function for memory-efficient cross-validation
407pub fn memory_efficient_cross_validate<T, L, F, R>(
408    data: Vec<T>,
409    labels: Vec<L>,
410    n_folds: usize,
411    train_func: F,
412    config: Option<MemoryEfficientConfig>,
413) -> Result<StreamingEvaluationResult<R>, MemoryError>
414where
415    T: Clone + Send + Sync,
416    L: Clone + Send + Sync,
417    F: Fn(&[T], &[L]) -> Result<R, MemoryError>,
418    R: Clone + Default,
419{
420    let config = config.unwrap_or_default();
421    let mut evaluator = MemoryEfficientCrossValidator::new(config, n_folds);
422
423    // Load data into streaming readers
424    evaluator.data_reader.load_from_iterator(data.into_iter())?;
425    evaluator
426        .label_reader
427        .load_from_iterator(labels.into_iter())?;
428
429    // Setup folds
430    evaluator.setup_folds(evaluator.data_reader.total_samples(), n_folds);
431
432    // Perform streaming evaluation
433    evaluator.streaming_evaluate(train_func)
434}
435
/// Memory pool for frequently allocated objects
///
/// `get` hands out a pooled object (or creates one via `create_fn` when the
/// pool is empty); `put` returns an object for reuse. At most `max_size`
/// objects are retained — extras are dropped so the pool cannot grow unbounded.
pub struct MemoryPool<T> {
    pool: VecDeque<T>,
    max_size: usize,
    create_fn: Box<dyn Fn() -> T + Send + Sync>,
}

impl<T> MemoryPool<T> {
    /// Create a new memory pool holding at most `max_size` recycled objects.
    ///
    /// The old `T: Send + Sync` bound on this impl was never used by any
    /// method and needlessly excluded non-thread-safe types; it has been
    /// removed (a backward-compatible generalization).
    pub fn new<F>(max_size: usize, create_fn: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            pool: VecDeque::new(),
            max_size,
            create_fn: Box::new(create_fn),
        }
    }

    /// Get an object from the pool, creating a fresh one when the pool is empty.
    pub fn get(&mut self) -> T {
        self.pool.pop_front().unwrap_or_else(|| (self.create_fn)())
    }

    /// Return an object to the pool; dropped silently when the pool is full.
    pub fn put(&mut self, item: T) {
        if self.pool.len() < self.max_size {
            self.pool.push_back(item);
        }
    }

    /// Number of objects currently cached.
    pub fn size(&self) -> usize {
        self.pool.len()
    }

    /// Drop every cached object.
    pub fn clear(&mut self) {
        self.pool.clear();
    }
}
482
// NOTE: the `#[allow(non_snake_case)]` that used to precede this module was
// dead weight — every item in it is already snake_case — so it was removed.
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocation/deallocation bookkeeping, limit enforcement, sticky peak.
    #[test]
    fn test_memory_tracker() {
        let mut tracker = MemoryTracker::new(Some(100));

        assert!(tracker.allocate(50).is_ok());
        assert_eq!(tracker.current_usage(), 50);
        assert_eq!(tracker.peak_usage(), 50);

        assert!(tracker.allocate(40).is_ok());
        assert_eq!(tracker.current_usage(), 90);
        assert_eq!(tracker.peak_usage(), 90);

        // Should fail - exceeds limit
        assert!(tracker.allocate(20).is_err());

        tracker.deallocate(30);
        assert_eq!(tracker.current_usage(), 60);
        assert_eq!(tracker.peak_usage(), 90); // Peak should remain
    }

    /// `DataChunk::new` derives `end_index` from start + length.
    #[test]
    fn test_data_chunk() {
        let data = vec![1, 2, 3, 4, 5];
        let chunk = DataChunk::new(data.clone(), 10);

        assert_eq!(chunk.len(), 5);
        assert_eq!(chunk.start_index, 10);
        assert_eq!(chunk.end_index, 15);
        assert_eq!(chunk.data, data);
        assert!(!chunk.is_empty());
    }

    /// Loading 9 samples with chunk_size 3 / buffer_size 2 keeps the total
    /// count while bounding the buffered chunks.
    #[test]
    fn test_streaming_data_reader() {
        let config = MemoryEfficientConfig {
            chunk_size: 3,
            buffer_size: 2,
            ..Default::default()
        };

        let mut reader = StreamingDataReader::new(config);
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];

        assert!(reader.load_from_iterator(data.into_iter()).is_ok());
        assert_eq!(reader.total_samples(), 9);

        let chunk1 = reader.next_chunk();
        assert!(chunk1.is_some());
        assert_eq!(chunk1.unwrap().len(), 3);

        assert!(reader.has_more_chunks());
    }

    /// get/put round-trip and the size accounting of the pool.
    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new(3, || Vec::<i32>::new());

        let item1 = pool.get();
        assert_eq!(item1.len(), 0);

        pool.put(vec![1, 2, 3]);
        assert_eq!(pool.size(), 1);

        let item2 = pool.get();
        assert_eq!(item2, vec![1, 2, 3]);
        assert_eq!(pool.size(), 0);
    }

    /// Default configuration values are stable public API.
    #[test]
    fn test_memory_efficient_config_default() {
        let config = MemoryEfficientConfig::default();
        assert_eq!(config.chunk_size, 1000);
        assert_eq!(config.memory_limit, Some(1024));
        assert!(config.use_memory_mapping);
        assert_eq!(config.buffer_size, 3);
        assert!(!config.streaming_mode);
    }

    /// Aggregate statistics computed from hand-built snapshots.
    #[test]
    fn test_streaming_evaluation_result_stats() {
        let snapshots = vec![
            MemorySnapshot {
                fold: 0,
                current_usage: 100,
                peak_usage: 150,
                limit: Some(1000),
            },
            MemorySnapshot {
                fold: 1,
                current_usage: 120,
                peak_usage: 180,
                limit: Some(1000),
            },
        ];

        let result = StreamingEvaluationResult {
            fold_results: vec![(), ()],
            memory_snapshots: snapshots,
            total_folds: 2,
        };

        let stats = result.memory_efficiency_stats();
        assert_eq!(stats.avg_peak_usage, 165);
        assert_eq!(stats.max_peak_usage, 180);
        assert_eq!(stats.total_peak_usage, 330);
        assert_eq!(stats.efficiency_ratio, 0.18);
        assert_eq!(stats.memory_limit, Some(1000));
        assert_eq!(stats.folds_processed, 2);
    }

    /// End-to-end streaming pipeline.
    /// NOTE(review): ignored in the original; `streaming_evaluate` appears to
    /// never terminate because `next_chunk` does not advance its cursor —
    /// confirm the drain loop terminates before removing `#[ignore]`.
    #[test]
    #[ignore]
    fn test_convenience_function() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let labels = vec![0, 1, 0, 1, 0, 1, 0, 1];

        let train_func = |train_data: &[i32], train_labels: &[i32]| -> Result<f64, MemoryError> {
            Ok(train_data.len() as f64 / train_labels.len() as f64)
        };

        let result = memory_efficient_cross_validate(data, labels, 3, train_func, None);
        assert!(result.is_ok());

        let result = result.unwrap();
        assert_eq!(result.total_folds, 3);
        assert_eq!(result.fold_results.len(), 3);
        assert_eq!(result.memory_snapshots.len(), 3);
    }
}