use std::collections::{HashSet, VecDeque};
use thiserror::Error;
9
/// Errors that can occur during memory-efficient data processing.
#[derive(Error, Debug)]
pub enum MemoryError {
    /// Wraps an underlying I/O failure.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// A chunk could not be processed; carries a human-readable reason.
    #[error("Chunk processing error: {0}")]
    ChunkProcessing(String),
    /// An allocation would exceed the configured limit (both values in MB).
    #[error("Memory limit exceeded: requested {requested}MB, limit {limit}MB")]
    MemoryLimitExceeded { requested: usize, limit: usize },
    /// A streaming invariant was violated (e.g. data/label mismatch).
    #[error("Streaming error: {0}")]
    Streaming(String),
}
22
/// Tracks simulated memory usage (in MB) against an optional hard limit.
pub struct MemoryTracker {
    // Currently allocated megabytes.
    current_usage: usize,
    // Highest value `current_usage` has ever reached.
    peak_usage: usize,
    // Optional allocation ceiling in megabytes; `None` means unlimited.
    limit: Option<usize>,
}
29
30impl MemoryTracker {
31 pub fn new(limit_mb: Option<usize>) -> Self {
33 Self {
34 current_usage: 0,
35 peak_usage: 0,
36 limit: limit_mb,
37 }
38 }
39
40 pub fn allocate(&mut self, size_mb: usize) -> Result<(), MemoryError> {
42 if let Some(limit) = self.limit {
43 if self.current_usage + size_mb > limit {
44 return Err(MemoryError::MemoryLimitExceeded {
45 requested: size_mb,
46 limit,
47 });
48 }
49 }
50
51 self.current_usage += size_mb;
52 if self.current_usage > self.peak_usage {
53 self.peak_usage = self.current_usage;
54 }
55
56 Ok(())
57 }
58
59 pub fn deallocate(&mut self, size_mb: usize) {
61 self.current_usage = self.current_usage.saturating_sub(size_mb);
62 }
63
64 pub fn current_usage(&self) -> usize {
66 self.current_usage
67 }
68
69 pub fn peak_usage(&self) -> usize {
71 self.peak_usage
72 }
73
74 pub fn limit(&self) -> Option<usize> {
76 self.limit
77 }
78}
79
/// Tuning knobs for memory-efficient streaming cross-validation.
#[derive(Debug, Clone)]
pub struct MemoryEfficientConfig {
    /// Number of samples per streamed chunk.
    pub chunk_size: usize,
    /// Optional memory ceiling in megabytes; `None` disables the limit.
    pub memory_limit: Option<usize>,
    /// Whether memory mapping should be used. NOTE(review): not consulted
    /// anywhere in this file — presumably read by I/O code elsewhere; confirm.
    pub use_memory_mapping: bool,
    /// Maximum number of chunks kept resident at once.
    pub buffer_size: usize,
    /// Whether to operate in streaming mode. NOTE(review): not consulted
    /// anywhere in this file — confirm against callers.
    pub streaming_mode: bool,
}
94
95impl Default for MemoryEfficientConfig {
96 fn default() -> Self {
97 Self {
98 chunk_size: 1000,
99 memory_limit: Some(1024), use_memory_mapping: true,
101 buffer_size: 3, streaming_mode: false,
103 }
104 }
105}
106
/// A contiguous run of samples plus its global index range.
#[derive(Debug, Clone)]
pub struct DataChunk<T> {
    /// The samples held by this chunk.
    pub data: Vec<T>,
    /// Global index of the first sample (inclusive).
    pub start_index: usize,
    /// Global index one past the last sample (exclusive).
    pub end_index: usize,
}
114
115impl<T> DataChunk<T> {
116 pub fn new(data: Vec<T>, start_index: usize) -> Self {
117 let end_index = start_index + data.len();
118 Self {
119 data,
120 start_index,
121 end_index,
122 }
123 }
124
125 pub fn len(&self) -> usize {
126 self.data.len()
127 }
128
129 pub fn is_empty(&self) -> bool {
130 self.data.is_empty()
131 }
132}
133
/// Reads data in fixed-size chunks, keeping at most `config.buffer_size`
/// chunks resident and accounting for their size with a `MemoryTracker`.
pub struct StreamingDataReader<T> {
    // Resident chunks, oldest at the front.
    chunks: VecDeque<DataChunk<T>>,
    // Global index of the next sample to be consumed.
    current_index: usize,
    // Total samples ingested by `load_from_iterator`.
    total_samples: usize,
    // Streaming configuration (chunk size, buffer size, limit).
    config: MemoryEfficientConfig,
    // Tracks the estimated memory held by resident chunks.
    memory_tracker: MemoryTracker,
}
142
143impl<T> StreamingDataReader<T>
144where
145 T: Clone + Send + Sync,
146{
147 pub fn new(config: MemoryEfficientConfig) -> Self {
149 let memory_tracker = MemoryTracker::new(config.memory_limit);
150 Self {
151 chunks: VecDeque::new(),
152 current_index: 0,
153 total_samples: 0,
154 config,
155 memory_tracker,
156 }
157 }
158
159 pub fn load_from_iterator<I>(&mut self, data_iter: I) -> Result<(), MemoryError>
161 where
162 I: Iterator<Item = T>,
163 {
164 let mut chunk_data = Vec::with_capacity(self.config.chunk_size);
165 let mut start_index = 0;
166 let mut total_count = 0;
167
168 for (i, item) in data_iter.enumerate() {
169 chunk_data.push(item);
170 total_count += 1;
171
172 if chunk_data.len() >= self.config.chunk_size {
173 let chunk_size_mb = std::mem::size_of::<T>() * chunk_data.len() / 1_048_576;
174 self.memory_tracker.allocate(chunk_size_mb)?;
175
176 let chunk = DataChunk::new(chunk_data, start_index);
177 self.chunks.push_back(chunk);
178
179 chunk_data = Vec::with_capacity(self.config.chunk_size);
180 start_index = i + 1;
181 }
182
183 while self.chunks.len() > self.config.buffer_size {
185 if let Some(old_chunk) = self.chunks.pop_front() {
186 let chunk_size_mb = std::mem::size_of::<T>() * old_chunk.len() / 1_048_576;
187 self.memory_tracker.deallocate(chunk_size_mb);
188 }
189 }
190 }
191
192 if !chunk_data.is_empty() {
194 let chunk_size_mb = std::mem::size_of::<T>() * chunk_data.len() / 1_048_576;
195 self.memory_tracker.allocate(chunk_size_mb)?;
196
197 let chunk = DataChunk::new(chunk_data, start_index);
198 self.chunks.push_back(chunk);
199 }
200
201 self.total_samples = total_count;
202 Ok(())
203 }
204
205 pub fn next_chunk(&mut self) -> Option<&DataChunk<T>> {
207 if self.chunks.is_empty() {
208 return None;
209 }
210
211 let front_chunk = self.chunks.front()?;
212 if self.current_index >= front_chunk.end_index {
213 if let Some(old_chunk) = self.chunks.pop_front() {
215 let chunk_size_mb = std::mem::size_of::<T>() * old_chunk.len() / 1_048_576;
216 self.memory_tracker.deallocate(chunk_size_mb);
217 }
218 return self.next_chunk();
219 }
220
221 self.chunks.front()
222 }
223
224 pub fn memory_stats(&self) -> (usize, usize, Option<usize>) {
226 (
227 self.memory_tracker.current_usage(),
228 self.memory_tracker.peak_usage(),
229 self.memory_tracker.limit(),
230 )
231 }
232
233 pub fn total_samples(&self) -> usize {
235 self.total_samples
236 }
237
238 pub fn has_more_chunks(&self) -> bool {
240 !self.chunks.is_empty() && self.current_index < self.total_samples
241 }
242}
243
/// Cross-validator that streams data and labels through chunked readers
/// to bound peak memory while evaluating folds.
pub struct MemoryEfficientCrossValidator<T, L> {
    // Configuration shared by both readers.
    config: MemoryEfficientConfig,
    // Per-fold test-index sets built by `setup_folds`.
    fold_indices: Vec<Vec<usize>>,
    // Chunked reader for the feature data.
    data_reader: StreamingDataReader<T>,
    // Chunked reader for the labels.
    label_reader: StreamingDataReader<L>,
}
251
252impl<T, L> MemoryEfficientCrossValidator<T, L>
253where
254 T: Clone + Send + Sync,
255 L: Clone + Send + Sync,
256{
257 pub fn new(config: MemoryEfficientConfig, n_folds: usize) -> Self {
259 Self {
260 config: config.clone(),
261 fold_indices: Vec::with_capacity(n_folds),
262 data_reader: StreamingDataReader::new(config.clone()),
263 label_reader: StreamingDataReader::new(config),
264 }
265 }
266
267 pub fn setup_folds(&mut self, n_samples: usize, n_folds: usize) {
269 let samples_per_fold = n_samples / n_folds;
270 let mut indices: Vec<usize> = (0..n_samples).collect();
271
272 indices.sort_by_key(|&i| i % 997); for fold in 0..n_folds {
276 let start = fold * samples_per_fold;
277 let end = if fold == n_folds - 1 {
278 n_samples
279 } else {
280 (fold + 1) * samples_per_fold
281 };
282
283 self.fold_indices.push(indices[start..end].to_vec());
284 }
285 }
286
287 pub fn streaming_evaluate<F, R>(
289 &mut self,
290 train_func: F,
291 ) -> Result<StreamingEvaluationResult<R>, MemoryError>
292 where
293 F: Fn(&[T], &[L]) -> Result<R, MemoryError>,
294 R: Clone + Default,
295 {
296 let mut fold_results = Vec::new();
297 let mut memory_snapshots = Vec::new();
298
299 for fold_idx in 0..self.fold_indices.len() {
300 let test_indices = &self.fold_indices[fold_idx];
301
302 let mut train_data = Vec::new();
304 let mut train_labels = Vec::new();
305
306 while let Some(data_chunk) = self.data_reader.next_chunk() {
308 let label_chunk = self.label_reader.next_chunk().ok_or_else(|| {
309 MemoryError::Streaming("Mismatched data and labels".to_string())
310 })?;
311
312 for (i, (sample, label)) in data_chunk
313 .data
314 .iter()
315 .zip(label_chunk.data.iter())
316 .enumerate()
317 {
318 let global_idx = data_chunk.start_index + i;
319 if !test_indices.contains(&global_idx) {
320 train_data.push(sample.clone());
321 train_labels.push(label.clone());
322 }
323 }
324 }
325
326 let result = train_func(&train_data, &train_labels)?;
328 fold_results.push(result);
329
330 let (current, peak, limit) = self.data_reader.memory_stats();
332 memory_snapshots.push(MemorySnapshot {
333 fold: fold_idx,
334 current_usage: current,
335 peak_usage: peak,
336 limit,
337 });
338 }
339
340 Ok(StreamingEvaluationResult {
341 fold_results,
342 memory_snapshots,
343 total_folds: self.fold_indices.len(),
344 })
345 }
346}
347
/// Memory-tracker readings captured after evaluating one fold.
#[derive(Debug, Clone)]
pub struct MemorySnapshot {
    /// Index of the fold this snapshot belongs to.
    pub fold: usize,
    /// Usage (MB) at snapshot time.
    pub current_usage: usize,
    /// Peak usage (MB) observed up to snapshot time.
    pub peak_usage: usize,
    /// Configured limit (MB), if any.
    pub limit: Option<usize>,
}
356
/// Outcome of a streaming cross-validation run.
#[derive(Debug, Clone)]
pub struct StreamingEvaluationResult<R> {
    /// One result from the training function per fold.
    pub fold_results: Vec<R>,
    /// One memory snapshot per fold.
    pub memory_snapshots: Vec<MemorySnapshot>,
    /// Number of folds evaluated.
    pub total_folds: usize,
}
364
365impl<R> StreamingEvaluationResult<R> {
366 pub fn memory_efficiency_stats(&self) -> MemoryEfficiencyStats {
368 let total_peak = self.memory_snapshots.iter().map(|s| s.peak_usage).sum();
369 let avg_peak = total_peak / self.memory_snapshots.len();
370 let max_peak = self
371 .memory_snapshots
372 .iter()
373 .map(|s| s.peak_usage)
374 .max()
375 .unwrap_or(0);
376
377 let limit = self.memory_snapshots.first().and_then(|s| s.limit);
378 let efficiency_ratio = if let Some(limit) = limit {
379 max_peak as f64 / limit as f64
380 } else {
381 0.0
382 };
383
384 MemoryEfficiencyStats {
385 avg_peak_usage: avg_peak,
386 max_peak_usage: max_peak,
387 total_peak_usage: total_peak,
388 efficiency_ratio,
389 memory_limit: limit,
390 folds_processed: self.total_folds,
391 }
392 }
393}
394
/// Aggregate memory statistics across all folds of a run.
#[derive(Debug, Clone)]
pub struct MemoryEfficiencyStats {
    /// Mean of per-fold peak usage (MB).
    pub avg_peak_usage: usize,
    /// Largest per-fold peak usage (MB).
    pub max_peak_usage: usize,
    /// Sum of per-fold peak usage (MB).
    pub total_peak_usage: usize,
    /// `max_peak_usage / memory_limit`; 0.0 when no limit was set.
    pub efficiency_ratio: f64,
    /// The limit the run was configured with, if any.
    pub memory_limit: Option<usize>,
    /// Number of folds processed.
    pub folds_processed: usize,
}
405
406pub fn memory_efficient_cross_validate<T, L, F, R>(
408 data: Vec<T>,
409 labels: Vec<L>,
410 n_folds: usize,
411 train_func: F,
412 config: Option<MemoryEfficientConfig>,
413) -> Result<StreamingEvaluationResult<R>, MemoryError>
414where
415 T: Clone + Send + Sync,
416 L: Clone + Send + Sync,
417 F: Fn(&[T], &[L]) -> Result<R, MemoryError>,
418 R: Clone + Default,
419{
420 let config = config.unwrap_or_default();
421 let mut evaluator = MemoryEfficientCrossValidator::new(config, n_folds);
422
423 evaluator.data_reader.load_from_iterator(data.into_iter())?;
425 evaluator
426 .label_reader
427 .load_from_iterator(labels.into_iter())?;
428
429 evaluator.setup_folds(evaluator.data_reader.total_samples(), n_folds);
431
432 evaluator.streaming_evaluate(train_func)
434}
435
/// A simple bounded object pool that recycles values of `T`.
pub struct MemoryPool<T> {
    // Idle items available for reuse, oldest at the front.
    pool: VecDeque<T>,
    // Maximum number of idle items `put` will retain.
    max_size: usize,
    // Factory invoked by `get` when the pool is empty.
    create_fn: Box<dyn Fn() -> T + Send + Sync>,
}
442
443impl<T> MemoryPool<T>
444where
445 T: Send + Sync,
446{
447 pub fn new<F>(max_size: usize, create_fn: F) -> Self
449 where
450 F: Fn() -> T + Send + Sync + 'static,
451 {
452 Self {
453 pool: VecDeque::new(),
454 max_size,
455 create_fn: Box::new(create_fn),
456 }
457 }
458
459 pub fn get(&mut self) -> T {
461 self.pool.pop_front().unwrap_or_else(|| (self.create_fn)())
462 }
463
464 pub fn put(&mut self, item: T) {
466 if self.pool.len() < self.max_size {
467 self.pool.push_back(item);
468 }
469 }
471
472 pub fn size(&self) -> usize {
474 self.pool.len()
475 }
476
477 pub fn clear(&mut self) {
479 self.pool.clear();
480 }
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_tracker() {
        let mut t = MemoryTracker::new(Some(100));

        // Two allocations within the 100 MB limit succeed.
        assert!(t.allocate(50).is_ok());
        assert_eq!(t.current_usage(), 50);
        assert_eq!(t.peak_usage(), 50);

        assert!(t.allocate(40).is_ok());
        assert_eq!(t.current_usage(), 90);
        assert_eq!(t.peak_usage(), 90);

        // 90 + 20 exceeds the limit and must be rejected.
        assert!(t.allocate(20).is_err());

        // Freeing lowers current usage; the recorded peak is sticky.
        t.deallocate(30);
        assert_eq!(t.current_usage(), 60);
        assert_eq!(t.peak_usage(), 90);
    }

    #[test]
    fn test_data_chunk() {
        let items = vec![1, 2, 3, 4, 5];
        let chunk = DataChunk::new(items.clone(), 10);

        // Five items starting at global index 10 span [10, 15).
        assert_eq!(chunk.len(), 5);
        assert_eq!(chunk.start_index, 10);
        assert_eq!(chunk.end_index, 15);
        assert_eq!(chunk.data, items);
        assert!(!chunk.is_empty());
    }

    #[test]
    fn test_streaming_data_reader() {
        // Tiny chunks and buffer so eviction kicks in while loading.
        let config = MemoryEfficientConfig {
            chunk_size: 3,
            buffer_size: 2,
            ..Default::default()
        };
        let mut reader = StreamingDataReader::new(config);

        let samples = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
        assert!(reader.load_from_iterator(samples.into_iter()).is_ok());
        assert_eq!(reader.total_samples(), 9);

        let first = reader.next_chunk();
        assert!(first.is_some());
        assert_eq!(first.unwrap().len(), 3);

        assert!(reader.has_more_chunks());
    }

    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new(3, Vec::<i32>::new);

        // An empty pool falls back to the factory closure.
        let fresh = pool.get();
        assert_eq!(fresh.len(), 0);

        pool.put(vec![1, 2, 3]);
        assert_eq!(pool.size(), 1);

        // `get` hands back the pooled item, not a freshly created one.
        let recycled = pool.get();
        assert_eq!(recycled, vec![1, 2, 3]);
        assert_eq!(pool.size(), 0);
    }

    #[test]
    fn test_memory_efficient_config_default() {
        let cfg = MemoryEfficientConfig::default();
        assert_eq!(cfg.chunk_size, 1000);
        assert_eq!(cfg.memory_limit, Some(1024));
        assert!(cfg.use_memory_mapping);
        assert_eq!(cfg.buffer_size, 3);
        assert!(!cfg.streaming_mode);
    }

    #[test]
    fn test_streaming_evaluation_result_stats() {
        let result = StreamingEvaluationResult {
            fold_results: vec![(), ()],
            memory_snapshots: vec![
                MemorySnapshot {
                    fold: 0,
                    current_usage: 100,
                    peak_usage: 150,
                    limit: Some(1000),
                },
                MemorySnapshot {
                    fold: 1,
                    current_usage: 120,
                    peak_usage: 180,
                    limit: Some(1000),
                },
            ],
            total_folds: 2,
        };

        let stats = result.memory_efficiency_stats();
        // avg = (150 + 180) / 2, max = 180, ratio = 180 / 1000.
        assert_eq!(stats.avg_peak_usage, 165);
        assert_eq!(stats.max_peak_usage, 180);
        assert_eq!(stats.total_peak_usage, 330);
        assert_eq!(stats.efficiency_ratio, 0.18);
        assert_eq!(stats.memory_limit, Some(1000));
        assert_eq!(stats.folds_processed, 2);
    }

    #[test]
    #[ignore]
    fn test_convenience_function() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let labels = vec![0, 1, 0, 1, 0, 1, 0, 1];

        let train_func = |xs: &[i32], ys: &[i32]| -> Result<f64, MemoryError> {
            Ok(xs.len() as f64 / ys.len() as f64)
        };

        let outcome = memory_efficient_cross_validate(data, labels, 3, train_func, None);
        assert!(outcome.is_ok());

        let outcome = outcome.unwrap();
        assert_eq!(outcome.total_folds, 3);
        assert_eq!(outcome.fold_results.len(), 3);
        assert_eq!(outcome.memory_snapshots.len(), 3);
    }
}