// sklears_compose/memory_optimization.rs

1//! Memory optimization and monitoring for pipeline execution
2//!
3//! This module provides memory-efficient pipeline execution, memory monitoring,
4//! garbage collection optimization, and memory pool management.
5
6use scirs2_core::ndarray::{s, Array2, ArrayView2};
7use sklears_core::error::{Result as SklResult, SklearsError};
8use std::alloc::{self, Layout};
9use std::collections::{BTreeMap, HashMap, VecDeque};
10use std::mem;
11use std::sync::{Arc, Mutex, RwLock};
12use std::thread::{self, JoinHandle};
13use std::time::{Duration, SystemTime};
14
/// Snapshot of memory usage statistics maintained by [`MemoryMonitor`].
#[derive(Debug, Clone)]
pub struct MemoryUsage {
    /// Current allocated memory in bytes
    pub allocated: u64,
    /// Peak memory usage in bytes (high-water mark of `allocated`)
    pub peak: u64,
    /// Cumulative number of allocations
    pub allocations: u64,
    /// Cumulative number of deallocations
    pub deallocations: u64,
    /// Memory fragmentation ratio: fraction of allocations not yet released
    pub fragmentation: f64,
    /// Timestamp of the last call to `update`
    pub updated_at: SystemTime,
}
31
32impl Default for MemoryUsage {
33    fn default() -> Self {
34        Self {
35            allocated: 0,
36            peak: 0,
37            allocations: 0,
38            deallocations: 0,
39            fragmentation: 0.0,
40            updated_at: SystemTime::now(),
41        }
42    }
43}
44
45impl MemoryUsage {
46    /// Update memory statistics
47    pub fn update(&mut self, allocated: u64, allocations: u64, deallocations: u64) {
48        self.allocated = allocated;
49        self.allocations = allocations;
50        self.deallocations = deallocations;
51
52        if allocated > self.peak {
53            self.peak = allocated;
54        }
55
56        // Simple fragmentation calculation
57        if allocations > 0 {
58            self.fragmentation = (allocations - deallocations) as f64 / allocations as f64;
59        }
60
61        self.updated_at = SystemTime::now();
62    }
63
64    /// Get current utilization ratio (0.0 - 1.0)
65    #[must_use]
66    pub fn utilization(&self, total_available: u64) -> f64 {
67        if total_available == 0 {
68            0.0
69        } else {
70            self.allocated as f64 / total_available as f64
71        }
72    }
73
74    /// Check if memory usage is critical
75    #[must_use]
76    pub fn is_critical(&self, threshold: f64, total_available: u64) -> bool {
77        self.utilization(total_available) > threshold
78    }
79}
80
/// Memory monitor that samples system memory usage on a background thread
/// and notifies registered callbacks on each sample.
pub struct MemoryMonitor {
    /// Current memory usage (shared with the sampling thread)
    usage: Arc<RwLock<MemoryUsage>>,
    /// Monitoring configuration (interval, thresholds, history bound)
    config: MemoryMonitorConfig,
    /// Usage history, bounded by `config.max_history`
    history: Arc<RwLock<VecDeque<MemoryUsage>>>,
    /// Handle of the sampling thread while running
    monitor_thread: Option<JoinHandle<()>>,
    /// Running flag; cleared by `stop` to end the sampling loop
    is_running: Arc<Mutex<bool>>,
    /// Callbacks invoked with each new usage sample
    callbacks: Arc<RwLock<Vec<Box<dyn Fn(&MemoryUsage) + Send + Sync>>>>,
}
96
97impl std::fmt::Debug for MemoryMonitor {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        f.debug_struct("MemoryMonitor")
100            .field("usage", &self.usage)
101            .field("config", &self.config)
102            .field("history", &self.history)
103            .field("monitor_thread", &self.monitor_thread.is_some())
104            .field("is_running", &self.is_running)
105            .field(
106                "callbacks",
107                &format!(
108                    "{} callbacks",
109                    self.callbacks
110                        .read()
111                        .unwrap_or_else(|e| e.into_inner())
112                        .len()
113                ),
114            )
115            .finish()
116    }
117}
118
/// Memory monitoring configuration
#[derive(Debug, Clone)]
pub struct MemoryMonitorConfig {
    /// Interval between usage samples
    pub interval: Duration,
    /// Warning threshold as a fraction of total memory (0.0 - 1.0)
    pub warning_threshold: f64,
    /// Critical threshold as a fraction of total memory (0.0 - 1.0)
    pub critical_threshold: f64,
    /// Maximum number of history entries retained
    pub max_history: usize,
    /// Enable automatic garbage collection when usage is high
    pub auto_gc: bool,
    /// Utilization fraction above which GC is triggered (when `auto_gc`)
    pub gc_threshold: f64,
}
135
136impl Default for MemoryMonitorConfig {
137    fn default() -> Self {
138        Self {
139            interval: Duration::from_secs(1),
140            warning_threshold: 0.7,
141            critical_threshold: 0.9,
142            max_history: 3600, // 1 hour at 1-second intervals
143            auto_gc: true,
144            gc_threshold: 0.8,
145        }
146    }
147}
148
149impl MemoryMonitor {
150    /// Create a new memory monitor
151    #[must_use]
152    pub fn new(config: MemoryMonitorConfig) -> Self {
153        Self {
154            usage: Arc::new(RwLock::new(MemoryUsage::default())),
155            config,
156            history: Arc::new(RwLock::new(VecDeque::new())),
157            monitor_thread: None,
158            is_running: Arc::new(Mutex::new(false)),
159            callbacks: Arc::new(RwLock::new(Vec::new())),
160        }
161    }
162
163    /// Start monitoring
164    pub fn start(&mut self) -> SklResult<()> {
165        {
166            let mut running = self.is_running.lock().unwrap_or_else(|e| e.into_inner());
167            if *running {
168                return Ok(());
169            }
170            *running = true;
171        }
172
173        let usage = Arc::clone(&self.usage);
174        let history = Arc::clone(&self.history);
175        let callbacks = Arc::clone(&self.callbacks);
176        let is_running = Arc::clone(&self.is_running);
177        let config = self.config.clone();
178
179        let handle = thread::spawn(move || {
180            Self::monitor_loop(usage, history, callbacks, is_running, config);
181        });
182
183        self.monitor_thread = Some(handle);
184        Ok(())
185    }
186
187    /// Stop monitoring
188    pub fn stop(&mut self) -> SklResult<()> {
189        {
190            let mut running = self.is_running.lock().unwrap_or_else(|e| e.into_inner());
191            *running = false;
192        }
193
194        if let Some(handle) = self.monitor_thread.take() {
195            handle.join().map_err(|_| SklearsError::InvalidData {
196                reason: "Failed to join monitor thread".to_string(),
197            })?;
198        }
199
200        Ok(())
201    }
202
203    /// Main monitoring loop
204    fn monitor_loop(
205        usage: Arc<RwLock<MemoryUsage>>,
206        history: Arc<RwLock<VecDeque<MemoryUsage>>>,
207        callbacks: Arc<RwLock<Vec<Box<dyn Fn(&MemoryUsage) + Send + Sync>>>>,
208        is_running: Arc<Mutex<bool>>,
209        config: MemoryMonitorConfig,
210    ) {
211        while *is_running.lock().unwrap_or_else(|e| e.into_inner()) {
212            // Get current system memory usage
213            let (allocated, allocations, deallocations) = Self::get_system_memory_info();
214
215            // Update usage statistics
216            {
217                let mut current_usage = usage.write().unwrap_or_else(|e| e.into_inner());
218                current_usage.update(allocated, allocations, deallocations);
219
220                // Add to history
221                {
222                    let mut hist = history.write().unwrap_or_else(|e| e.into_inner());
223                    hist.push_back(current_usage.clone());
224
225                    // Limit history size
226                    while hist.len() > config.max_history {
227                        hist.pop_front();
228                    }
229                }
230
231                // Check thresholds and trigger callbacks
232                let total_memory = Self::get_total_system_memory();
233                let utilization = current_usage.utilization(total_memory);
234
235                if config.auto_gc && utilization > config.gc_threshold {
236                    Self::trigger_garbage_collection();
237                }
238
239                // Notify callbacks
240                let cb_list = callbacks.read().unwrap_or_else(|e| e.into_inner());
241                for callback in cb_list.iter() {
242                    callback(&current_usage);
243                }
244            }
245
246            thread::sleep(config.interval);
247        }
248    }
249
250    /// Get system memory information (simplified implementation)
251    fn get_system_memory_info() -> (u64, u64, u64) {
252        // In a real implementation, this would use platform-specific APIs
253        // For now, return dummy values
254        (1024 * 1024 * 100, 1000, 900) // 100MB allocated, 1000 allocs, 900 deallocs
255    }
256
257    /// Get total system memory (simplified implementation)
258    fn get_total_system_memory() -> u64 {
259        // In a real implementation, this would query system memory
260        1024 * 1024 * 1024 * 8 // 8GB
261    }
262
263    /// Trigger garbage collection
264    fn trigger_garbage_collection() {
265        // Force garbage collection (simplified)
266        // In Rust, this might involve dropping unused data structures
267        // or calling custom cleanup functions
268    }
269
270    /// Get current memory usage
271    #[must_use]
272    pub fn current_usage(&self) -> MemoryUsage {
273        let usage = self.usage.read().unwrap_or_else(|e| e.into_inner());
274        usage.clone()
275    }
276
277    /// Get memory usage history
278    #[must_use]
279    pub fn usage_history(&self) -> Vec<MemoryUsage> {
280        let history = self.history.read().unwrap_or_else(|e| e.into_inner());
281        history.iter().cloned().collect()
282    }
283
284    /// Add memory event callback
285    pub fn add_callback(&self, callback: Box<dyn Fn(&MemoryUsage) + Send + Sync>) {
286        let mut callbacks = self.callbacks.write().unwrap_or_else(|e| e.into_inner());
287        callbacks.push(callback);
288    }
289
290    /// Check if memory usage is above threshold
291    #[must_use]
292    pub fn is_above_threshold(&self, threshold: f64) -> bool {
293        let usage = self.usage.read().unwrap_or_else(|e| e.into_inner());
294        let total = Self::get_total_system_memory();
295        usage.utilization(total) > threshold
296    }
297
298    /// Get memory statistics summary
299    #[must_use]
300    pub fn get_statistics(&self) -> MemoryStatistics {
301        let usage = self.usage.read().unwrap_or_else(|e| e.into_inner());
302        let history = self.history.read().unwrap_or_else(|e| e.into_inner());
303
304        let avg_allocated = if history.is_empty() {
305            usage.allocated
306        } else {
307            history.iter().map(|u| u.allocated).sum::<u64>() / history.len() as u64
308        };
309
310        let max_allocated = history
311            .iter()
312            .map(|u| u.allocated)
313            .max()
314            .unwrap_or(usage.allocated);
315        let min_allocated = history
316            .iter()
317            .map(|u| u.allocated)
318            .min()
319            .unwrap_or(usage.allocated);
320
321        MemoryStatistics {
322            current: usage.clone(),
323            average_allocated: avg_allocated,
324            max_allocated,
325            min_allocated,
326            total_system_memory: Self::get_total_system_memory(),
327            samples_count: history.len(),
328        }
329    }
330}
331
/// Memory statistics summary returned by [`MemoryMonitor::get_statistics`].
#[derive(Debug, Clone)]
pub struct MemoryStatistics {
    /// Current memory usage snapshot
    pub current: MemoryUsage,
    /// Average allocated bytes over the recorded history
    pub average_allocated: u64,
    /// Maximum allocated bytes over the recorded history
    pub max_allocated: u64,
    /// Minimum allocated bytes over the recorded history
    pub min_allocated: u64,
    /// Total system memory in bytes
    pub total_system_memory: u64,
    /// Number of history samples the aggregates were computed from
    pub samples_count: usize,
}
348
/// Memory pool for efficient allocation and reuse
#[derive(Debug)]
pub struct MemoryPool {
    /// Pool sizing, expansion, and compaction settings
    config: MemoryPoolConfig,
    /// Free blocks keyed by size class (BTreeMap keeps classes ordered)
    available_blocks: Arc<RwLock<BTreeMap<usize, Vec<MemoryBlock>>>>,
    /// Blocks currently handed out, keyed by their raw pointer
    allocated_blocks: Arc<RwLock<HashMap<*mut u8, MemoryBlock>>>,
    /// Running allocation/deallocation statistics
    statistics: Arc<RwLock<PoolStatistics>>,
    /// Optional background memory monitor (see `enable_monitoring`)
    monitor: Option<MemoryMonitor>,
}
358
/// Memory pool configuration
#[derive(Debug, Clone)]
pub struct MemoryPoolConfig {
    /// Initial pool size in bytes, split evenly across size classes
    pub initial_size: usize,
    /// Maximum pool size in bytes
    pub max_size: usize,
    /// Block size classes in bytes; requests round up to the next class
    pub size_classes: Vec<usize>,
    /// Allocate fresh blocks when a size class runs out
    pub auto_expand: bool,
    /// Growth factor applied when expanding
    pub expansion_factor: f64,
    /// Enable memory compaction
    pub compaction_enabled: bool,
    /// Fragmentation ratio above which compaction runs
    pub compaction_threshold: f64,
}
377
378impl Default for MemoryPoolConfig {
379    fn default() -> Self {
380        Self {
381            initial_size: 1024 * 1024 * 10, // 10MB
382            max_size: 1024 * 1024 * 100,    // 100MB
383            size_classes: vec![16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
384            auto_expand: true,
385            expansion_factor: 1.5,
386            compaction_enabled: true,
387            compaction_threshold: 0.7,
388        }
389    }
390}
391
/// Memory block in the pool
///
/// NOTE(review): `Clone` duplicates the raw `ptr`; a cloned block whose
/// original is later deallocated would alias freed memory. The pool itself
/// only moves blocks between its maps and never clones them — verify any
/// external use.
#[derive(Debug, Clone)]
pub struct MemoryBlock {
    /// Raw pointer to the block's memory
    pub ptr: *mut u8,
    /// Block size in bytes (the size class it was allocated under)
    pub size: usize,
    /// Allocation timestamp
    pub allocated_at: SystemTime,
    /// Last access timestamp (used by garbage collection)
    pub last_accessed: SystemTime,
    /// Reference count (1 while handed out, 0 while pooled)
    pub ref_count: usize,
}
406
/// Pool statistics
#[derive(Debug, Clone)]
pub struct PoolStatistics {
    /// Total bytes currently handed out to callers
    pub total_allocated: usize,
    /// Total bytes sitting free in the pool
    pub total_available: usize,
    /// Cumulative number of allocations
    pub allocations: u64,
    /// Cumulative number of deallocations
    pub deallocations: u64,
    /// Fraction of allocations served from pooled blocks (0.0 - 1.0)
    pub hit_rate: f64,
    /// Fragmentation ratio
    pub fragmentation: f64,
    /// Pool utilization
    pub utilization: f64,
}
425
426impl Default for PoolStatistics {
427    fn default() -> Self {
428        Self {
429            total_allocated: 0,
430            total_available: 0,
431            allocations: 0,
432            deallocations: 0,
433            hit_rate: 0.0,
434            fragmentation: 0.0,
435            utilization: 0.0,
436        }
437    }
438}
439
440impl MemoryPool {
441    /// Create a new memory pool
442    pub fn new(config: MemoryPoolConfig) -> SklResult<Self> {
443        let mut pool = Self {
444            config,
445            available_blocks: Arc::new(RwLock::new(BTreeMap::new())),
446            allocated_blocks: Arc::new(RwLock::new(HashMap::new())),
447            statistics: Arc::new(RwLock::new(PoolStatistics::default())),
448            monitor: None,
449        };
450
451        // Initialize pool with initial blocks
452        pool.initialize_pool()?;
453
454        Ok(pool)
455    }
456
457    /// Initialize the memory pool
458    fn initialize_pool(&mut self) -> SklResult<()> {
459        let mut available = self
460            .available_blocks
461            .write()
462            .unwrap_or_else(|e| e.into_inner());
463
464        for &size_class in &self.config.size_classes {
465            let blocks_per_class =
466                self.config.initial_size / (size_class * self.config.size_classes.len());
467            let mut blocks = Vec::with_capacity(blocks_per_class);
468
469            for _ in 0..blocks_per_class {
470                let layout = Layout::from_size_align(size_class, std::mem::align_of::<u8>())
471                    .map_err(|_| SklearsError::InvalidData {
472                        reason: "Invalid memory layout".to_string(),
473                    })?;
474
475                unsafe {
476                    let ptr = alloc::alloc(layout);
477                    if ptr.is_null() {
478                        return Err(SklearsError::InvalidData {
479                            reason: "Memory allocation failed".to_string(),
480                        });
481                    }
482
483                    blocks.push(MemoryBlock {
484                        ptr,
485                        size: size_class,
486                        allocated_at: SystemTime::now(),
487                        last_accessed: SystemTime::now(),
488                        ref_count: 0,
489                    });
490                }
491            }
492
493            available.insert(size_class, blocks);
494        }
495
496        Ok(())
497    }
498
499    /// Allocate memory from the pool
500    pub fn allocate(&self, size: usize) -> SklResult<*mut u8> {
501        let size_class = self.find_size_class(size);
502        let mut available = self
503            .available_blocks
504            .write()
505            .unwrap_or_else(|e| e.into_inner());
506        let mut allocated = self
507            .allocated_blocks
508            .write()
509            .unwrap_or_else(|e| e.into_inner());
510        let mut stats = self.statistics.write().unwrap_or_else(|e| e.into_inner());
511
512        if let Some(blocks) = available.get_mut(&size_class) {
513            if let Some(mut block) = blocks.pop() {
514                // Found available block
515                block.allocated_at = SystemTime::now();
516                block.last_accessed = SystemTime::now();
517                block.ref_count = 1;
518
519                let ptr = block.ptr;
520                allocated.insert(ptr, block);
521
522                stats.allocations += 1;
523                stats.total_allocated += size_class;
524                stats.hit_rate = stats.allocations as f64 / (stats.allocations + 1) as f64;
525
526                return Ok(ptr);
527            }
528        }
529
530        // No available block, allocate new one if auto-expand is enabled
531        if self.config.auto_expand {
532            let layout =
533                Layout::from_size_align(size_class, std::mem::align_of::<u8>()).map_err(|_| {
534                    SklearsError::InvalidData {
535                        reason: "Invalid memory layout".to_string(),
536                    }
537                })?;
538
539            unsafe {
540                let ptr = alloc::alloc(layout);
541                if ptr.is_null() {
542                    return Err(SklearsError::InvalidData {
543                        reason: "Memory allocation failed".to_string(),
544                    });
545                }
546
547                let block = MemoryBlock {
548                    ptr,
549                    size: size_class,
550                    allocated_at: SystemTime::now(),
551                    last_accessed: SystemTime::now(),
552                    ref_count: 1,
553                };
554
555                allocated.insert(ptr, block);
556                stats.allocations += 1;
557                stats.total_allocated += size_class;
558
559                Ok(ptr)
560            }
561        } else {
562            Err(SklearsError::InvalidData {
563                reason: "Memory pool exhausted".to_string(),
564            })
565        }
566    }
567
568    /// Deallocate memory back to the pool
569    pub fn deallocate(&self, ptr: *mut u8) -> SklResult<()> {
570        let mut available = self
571            .available_blocks
572            .write()
573            .unwrap_or_else(|e| e.into_inner());
574        let mut allocated = self
575            .allocated_blocks
576            .write()
577            .unwrap_or_else(|e| e.into_inner());
578        let mut stats = self.statistics.write().unwrap_or_else(|e| e.into_inner());
579
580        if let Some(mut block) = allocated.remove(&ptr) {
581            block.ref_count = 0;
582            block.last_accessed = SystemTime::now();
583
584            let size_class = block.size;
585            available.entry(size_class).or_default().push(block);
586
587            stats.deallocations += 1;
588            stats.total_allocated = stats.total_allocated.saturating_sub(size_class);
589
590            Ok(())
591        } else {
592            Err(SklearsError::InvalidData {
593                reason: "Invalid pointer for deallocation".to_string(),
594            })
595        }
596    }
597
598    /// Find appropriate size class for allocation
599    fn find_size_class(&self, size: usize) -> usize {
600        self.config
601            .size_classes
602            .iter()
603            .find(|&&class_size| class_size >= size)
604            .copied()
605            .unwrap_or_else(|| {
606                // Round up to next power of 2
607                let mut class_size = 1;
608                while class_size < size {
609                    class_size <<= 1;
610                }
611                class_size
612            })
613    }
614
615    /// Compact the memory pool
616    pub fn compact(&self) -> SklResult<()> {
617        let available = self
618            .available_blocks
619            .write()
620            .unwrap_or_else(|e| e.into_inner());
621        let mut stats = self.statistics.write().unwrap_or_else(|e| e.into_inner());
622
623        let total_blocks: usize = available.values().map(std::vec::Vec::len).sum();
624        let fragmentation = if total_blocks > 0 {
625            1.0 - (available.len() as f64 / total_blocks as f64)
626        } else {
627            0.0
628        };
629
630        if fragmentation > self.config.compaction_threshold {
631            // Perform compaction (simplified)
632            // In a real implementation, this would reorganize memory blocks
633            stats.fragmentation = fragmentation;
634        }
635
636        Ok(())
637    }
638
639    /// Get pool statistics
640    #[must_use]
641    pub fn statistics(&self) -> PoolStatistics {
642        let stats = self.statistics.read().unwrap_or_else(|e| e.into_inner());
643        stats.clone()
644    }
645
646    /// Enable memory monitoring
647    pub fn enable_monitoring(&mut self, config: MemoryMonitorConfig) -> SklResult<()> {
648        let mut monitor = MemoryMonitor::new(config);
649        monitor.start()?;
650        self.monitor = Some(monitor);
651        Ok(())
652    }
653
654    /// Clear unused blocks (garbage collection)
655    pub fn garbage_collect(&self) -> SklResult<usize> {
656        let mut available = self
657            .available_blocks
658            .write()
659            .unwrap_or_else(|e| e.into_inner());
660        let mut freed_blocks = 0;
661
662        for (_, blocks) in available.iter_mut() {
663            let old_len = blocks.len();
664
665            // Keep only recently accessed blocks
666            let cutoff = SystemTime::now() - Duration::from_secs(300); // 5 minutes
667            blocks.retain(|block| block.last_accessed > cutoff);
668
669            freed_blocks += old_len - blocks.len();
670        }
671
672        Ok(freed_blocks)
673    }
674}
675
/// Memory-efficient ring buffer for streaming data with fixed capacity;
/// pushing into a full buffer evicts the oldest element.
#[derive(Debug)]
pub struct StreamingBuffer<T> {
    /// Ring storage; `None` marks an empty slot
    buffer: Vec<Option<T>>,
    /// Fixed capacity (length of `buffer`)
    capacity: usize,
    /// Next slot to write (wraps modulo `capacity`)
    write_pos: usize,
    /// Oldest element's slot (wraps modulo `capacity`)
    read_pos: usize,
    /// Number of elements currently stored
    count: usize,
    /// Optional shared memory pool
    /// NOTE(review): stored but never used by any buffer operation
    memory_pool: Option<Arc<MemoryPool>>,
}
692
693impl<T> StreamingBuffer<T> {
694    /// Create a new streaming buffer
695    pub fn new(capacity: usize) -> Self {
696        let mut buffer = Vec::with_capacity(capacity);
697        for _ in 0..capacity {
698            buffer.push(None);
699        }
700        Self {
701            buffer,
702            capacity,
703            write_pos: 0,
704            read_pos: 0,
705            count: 0,
706            memory_pool: None,
707        }
708    }
709
710    /// Create streaming buffer with memory pool
711    #[must_use]
712    pub fn with_memory_pool(capacity: usize, memory_pool: Arc<MemoryPool>) -> Self {
713        let mut buffer = Self::new(capacity);
714        buffer.memory_pool = Some(memory_pool);
715        buffer
716    }
717
718    /// Push an element to the buffer
719    pub fn push(&mut self, item: T) -> Option<T> {
720        let old_item = self.buffer[self.write_pos].take();
721        self.buffer[self.write_pos] = Some(item);
722
723        self.write_pos = (self.write_pos + 1) % self.capacity;
724
725        if self.count < self.capacity {
726            self.count += 1;
727        } else {
728            self.read_pos = (self.read_pos + 1) % self.capacity;
729        }
730
731        old_item
732    }
733
734    /// Pop an element from the buffer
735    pub fn pop(&mut self) -> Option<T> {
736        if self.count == 0 {
737            return None;
738        }
739
740        let item = self.buffer[self.read_pos].take();
741        self.read_pos = (self.read_pos + 1) % self.capacity;
742        self.count -= 1;
743
744        item
745    }
746
747    /// Get current buffer size
748    #[must_use]
749    pub fn len(&self) -> usize {
750        self.count
751    }
752
753    /// Check if buffer is empty
754    #[must_use]
755    pub fn is_empty(&self) -> bool {
756        self.count == 0
757    }
758
759    /// Check if buffer is full
760    #[must_use]
761    pub fn is_full(&self) -> bool {
762        self.count == self.capacity
763    }
764
765    /// Clear the buffer
766    pub fn clear(&mut self) {
767        for slot in &mut self.buffer {
768            *slot = None;
769        }
770        self.write_pos = 0;
771        self.read_pos = 0;
772        self.count = 0;
773    }
774
775    /// Get memory usage of the buffer
776    #[must_use]
777    pub fn memory_usage(&self) -> usize {
778        self.capacity * mem::size_of::<Option<T>>()
779    }
780}
781
/// Namespace for memory-efficient array operations; stateless, all methods
/// are associated functions.
pub struct MemoryEfficientOps;
784
785impl MemoryEfficientOps {
786    /// In-place array transformation to reduce memory allocations
787    pub fn transform_inplace<F>(array: &mut Array2<f64>, transform_fn: F)
788    where
789        F: Fn(f64) -> f64,
790    {
791        array.mapv_inplace(transform_fn);
792    }
793
794    /// Batch processing with controlled memory usage
795    pub fn batch_process<F, R>(
796        data: &Array2<f64>,
797        batch_size: usize,
798        process_fn: F,
799    ) -> SklResult<Vec<R>>
800    where
801        F: Fn(ArrayView2<f64>) -> SklResult<R>,
802    {
803        let mut results = Vec::new();
804        let n_rows = data.nrows();
805
806        for chunk_start in (0..n_rows).step_by(batch_size) {
807            let chunk_end = std::cmp::min(chunk_start + batch_size, n_rows);
808            let batch = data.slice(s![chunk_start..chunk_end, ..]);
809
810            let result = process_fn(batch)?;
811            results.push(result);
812        }
813
814        Ok(results)
815    }
816
817    /// Memory-efficient matrix multiplication using chunking
818    pub fn chunked_matmul(
819        a: &Array2<f64>,
820        b: &Array2<f64>,
821        chunk_size: usize,
822    ) -> SklResult<Array2<f64>> {
823        if a.ncols() != b.nrows() {
824            return Err(SklearsError::InvalidData {
825                reason: "Matrix dimensions don't match for multiplication".to_string(),
826            });
827        }
828
829        let mut result = Array2::zeros((a.nrows(), b.ncols()));
830
831        for i_chunk in (0..a.nrows()).step_by(chunk_size) {
832            let i_end = std::cmp::min(i_chunk + chunk_size, a.nrows());
833
834            for j_chunk in (0..b.ncols()).step_by(chunk_size) {
835                let j_end = std::cmp::min(j_chunk + chunk_size, b.ncols());
836
837                for k_chunk in (0..a.ncols()).step_by(chunk_size) {
838                    let k_end = std::cmp::min(k_chunk + chunk_size, a.ncols());
839
840                    let a_chunk = a.slice(s![i_chunk..i_end, k_chunk..k_end]);
841                    let b_chunk = b.slice(s![k_chunk..k_end, j_chunk..j_end]);
842
843                    let mut result_chunk = result.slice_mut(s![i_chunk..i_end, j_chunk..j_end]);
844
845                    // Perform chunk multiplication
846                    for (i, a_row) in a_chunk.rows().into_iter().enumerate() {
847                        for (j, b_col) in b_chunk.columns().into_iter().enumerate() {
848                            result_chunk[[i, j]] += a_row.dot(&b_col);
849                        }
850                    }
851                }
852            }
853        }
854
855        Ok(result)
856    }
857
858    /// Reduce memory footprint by using lower precision when possible
859    #[must_use]
860    pub fn optimize_precision(array: &Array2<f64>, tolerance: f64) -> Array2<f32> {
861        array.mapv(|x| {
862            if x.abs() < tolerance {
863                0.0f32
864            } else {
865                x as f32
866            }
867        })
868    }
869}
870
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture construction uses `expect` rather than `unwrap_or_default`:
    // a shape error previously degraded silently to an empty array, making
    // the subsequent assertions fail with a misleading message.

    #[test]
    fn test_memory_usage() {
        let mut usage = MemoryUsage::default();
        usage.update(1024, 10, 5);

        assert_eq!(usage.allocated, 1024);
        assert_eq!(usage.allocations, 10);
        assert_eq!(usage.deallocations, 5);
        assert_eq!(usage.peak, 1024);
        // (10 - 5) / 10 live-block ratio.
        assert!((usage.fragmentation - 0.5).abs() < 1e-12);
    }

    #[test]
    fn test_memory_monitor_creation() {
        let config = MemoryMonitorConfig::default();
        let monitor = MemoryMonitor::new(config);

        let usage = monitor.current_usage();
        assert_eq!(usage.allocated, 0);
    }

    #[test]
    fn test_memory_pool_creation() {
        let config = MemoryPoolConfig::default();
        let pool = MemoryPool::new(config).expect("pool creation should succeed");

        let stats = pool.statistics();
        assert_eq!(stats.allocations, 0);
        assert_eq!(stats.deallocations, 0);
    }

    #[test]
    fn test_streaming_buffer() {
        let mut buffer = StreamingBuffer::new(3);

        assert!(buffer.is_empty());
        assert_eq!(buffer.len(), 0);

        buffer.push(1);
        buffer.push(2);
        buffer.push(3);

        assert!(buffer.is_full());
        assert_eq!(buffer.len(), 3);

        let old_item = buffer.push(4); // Should evict 1
        assert_eq!(old_item, Some(1));

        let popped = buffer.pop();
        assert_eq!(popped, Some(2));
    }

    #[test]
    fn test_memory_efficient_ops() {
        let mut array =
            Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid 2x2 shape");

        // Test in-place transformation
        MemoryEfficientOps::transform_inplace(&mut array, |x| x * 2.0);
        assert_eq!(array[[0, 0]], 2.0);
        assert_eq!(array[[1, 1]], 8.0);

        // Test precision optimization
        let array_f64 = Array2::from_shape_vec((2, 2), vec![1.0, 0.000001, 3.0, 0.000002])
            .expect("valid 2x2 shape");
        let array_f32 = MemoryEfficientOps::optimize_precision(&array_f64, 0.00001);
        assert_eq!(array_f32[[0, 1]], 0.0f32); // Small value should be zeroed
        assert_eq!(array_f32[[1, 0]], 3.0f32); // Large value should be preserved
    }

    #[test]
    fn test_batch_processing() {
        let data = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("valid 4x2 shape");

        let results = MemoryEfficientOps::batch_process(&data, 2, |batch| Ok(batch.sum()))
            .expect("batch processing should succeed");

        assert_eq!(results.len(), 2); // 4 rows / 2 batch_size = 2 batches
        assert_eq!(results[0], 10.0); // Sum of first batch: 1+2+3+4
        assert_eq!(results[1], 26.0); // Sum of second batch: 5+6+7+8
    }

    #[test]
    fn test_chunked_matmul() {
        let a = Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid 2x3 shape");
        let b = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("valid 3x2 shape");

        let result =
            MemoryEfficientOps::chunked_matmul(&a, &b, 2).expect("matmul should succeed");

        // Expected result of matrix multiplication
        assert_eq!(result.shape(), &[2, 2]);
        assert_eq!(result[[0, 0]], 22.0); // 1*1 + 2*3 + 3*5
        assert_eq!(result[[0, 1]], 28.0); // 1*2 + 2*4 + 3*6
        assert_eq!(result[[1, 0]], 49.0); // 4*1 + 5*3 + 6*5
        assert_eq!(result[[1, 1]], 64.0); // 4*2 + 5*4 + 6*6
    }
}
974}