// tensorlogic_train/memory.rs
1//! Memory profiling and management utilities for training.
2//!
3//! This module provides tools to monitor and manage memory usage during training:
4//! - Memory profiling with allocation tracking
5//! - Gradient checkpointing for memory-efficient training
6//! - Memory estimation utilities
7
8use crate::{Callback, TrainResult, TrainingState};
9use std::collections::HashMap;
10use std::time::Instant;
11
/// Memory statistics for a training session.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Current allocated memory in bytes (estimated).
    pub current_allocated: usize,
    /// Peak allocated memory in bytes.
    pub peak_allocated: usize,
    /// Number of allocations tracked.
    pub allocation_count: usize,
    /// Memory usage history (epoch -> bytes).
    pub history: Vec<(usize, usize)>,
}

impl MemoryStats {
    /// Create an empty set of memory statistics.
    pub fn new() -> Self {
        Self::default()
    }

    /// Record a memory measurement for `epoch`, updating the running
    /// current/peak figures and appending to the history.
    pub fn record(&mut self, epoch: usize, bytes: usize) {
        self.current_allocated = bytes;
        self.peak_allocated = self.peak_allocated.max(bytes);
        self.allocation_count += 1;
        self.history.push((epoch, bytes));
    }

    /// Render a byte count as a human-readable string (GB/MB/KB/bytes).
    pub fn format_bytes(bytes: usize) -> String {
        const KB: usize = 1024;
        const MB: usize = 1024 * KB;
        const GB: usize = 1024 * MB;

        match bytes {
            b if b >= GB => format!("{:.2} GB", b as f64 / GB as f64),
            b if b >= MB => format!("{:.2} MB", b as f64 / MB as f64),
            b if b >= KB => format!("{:.2} KB", b as f64 / KB as f64),
            b => format!("{} bytes", b),
        }
    }

    /// One-line summary of current usage, peak usage, and allocation count.
    pub fn summary(&self) -> String {
        format!(
            "Memory: current={}, peak={}, allocations={}",
            Self::format_bytes(self.current_allocated),
            Self::format_bytes(self.peak_allocated),
            self.allocation_count
        )
    }
}
64
/// Gradient checkpointing configuration.
///
/// Gradient checkpointing reduces memory usage by recomputing activations
/// during the backward pass instead of storing them. This trades compute
/// for memory.
#[derive(Debug, Clone)]
pub struct GradientCheckpointConfig {
    /// Whether gradient checkpointing is enabled.
    pub enabled: bool,
    /// Checkpointing strategy.
    pub strategy: CheckpointStrategy,
    /// Layers to checkpoint (by name pattern).
    pub checkpoint_layers: Vec<String>,
    /// Memory threshold to trigger checkpointing (bytes).
    pub memory_threshold: Option<usize>,
}

impl Default for GradientCheckpointConfig {
    /// Checkpointing disabled, uniform strategy, no layer patterns, no threshold.
    fn default() -> Self {
        Self {
            enabled: false,
            strategy: CheckpointStrategy::Uniform,
            checkpoint_layers: Vec::new(),
            memory_threshold: None,
        }
    }
}

impl GradientCheckpointConfig {
    /// Create a new gradient checkpoint config with default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable gradient checkpointing (builder-style).
    pub fn enabled(self) -> Self {
        Self {
            enabled: true,
            ..self
        }
    }

    /// Set the checkpointing strategy (builder-style).
    pub fn with_strategy(self, strategy: CheckpointStrategy) -> Self {
        Self { strategy, ..self }
    }

    /// Set the layers to checkpoint (builder-style).
    pub fn with_layers(self, layers: Vec<String>) -> Self {
        Self {
            checkpoint_layers: layers,
            ..self
        }
    }

    /// Set the memory threshold for automatic checkpointing (builder-style).
    pub fn with_memory_threshold(self, threshold: usize) -> Self {
        Self {
            memory_threshold: Some(threshold),
            ..self
        }
    }
}

/// Gradient checkpointing strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointStrategy {
    /// Checkpoint every N layers uniformly.
    Uniform,
    /// Checkpoint based on memory estimates.
    MemoryBased,
    /// Custom checkpointing (user-specified layers).
    Custom,
    /// Square root strategy (checkpoint sqrt(L) layers for L total).
    SqrtStrategy,
}
136
/// Memory profiler callback for tracking memory usage during training.
///
/// Records an estimated memory figure at epoch and/or batch boundaries
/// (see the `Callback` impl) and prints a final report when training ends.
///
/// # Example
/// ```no_run
/// use tensorlogic_train::{MemoryProfilerCallback, Callback};
///
/// let mut profiler = MemoryProfilerCallback::new()
///     .with_epoch_tracking(true)
///     .with_batch_tracking(false);
///
/// // Use in training loop
/// ```
#[derive(Debug, Clone)]
pub struct MemoryProfilerCallback {
    /// Memory statistics.
    pub stats: MemoryStats,
    /// Whether to track at epoch level.
    track_epoch: bool,
    /// Whether to track at batch level.
    track_batch: bool,
    /// Logging frequency (every N epochs/batches); clamped to >= 1 by the setter.
    log_frequency: usize,
    /// Start time for duration tracking; set in `on_train_begin`.
    start_time: Option<Instant>,
    /// Memory usage per batch in current epoch; cleared in `on_epoch_begin`.
    batch_memory: Vec<usize>,
}
164
165impl MemoryProfilerCallback {
166    /// Create a new memory profiler callback.
167    pub fn new() -> Self {
168        Self {
169            stats: MemoryStats::new(),
170            track_epoch: true,
171            track_batch: false,
172            log_frequency: 1,
173            start_time: None,
174            batch_memory: Vec::new(),
175        }
176    }
177
178    /// Enable epoch-level tracking.
179    pub fn with_epoch_tracking(mut self, enabled: bool) -> Self {
180        self.track_epoch = enabled;
181        self
182    }
183
184    /// Enable batch-level tracking.
185    pub fn with_batch_tracking(mut self, enabled: bool) -> Self {
186        self.track_batch = enabled;
187        self
188    }
189
190    /// Set logging frequency.
191    pub fn with_log_frequency(mut self, frequency: usize) -> Self {
192        self.log_frequency = frequency.max(1);
193        self
194    }
195
196    /// Get memory statistics.
197    pub fn get_stats(&self) -> &MemoryStats {
198        &self.stats
199    }
200
201    /// Estimate current memory usage based on tensors.
202    ///
203    /// This is a simplified estimation - actual memory usage depends
204    /// on allocator behavior, alignment, and fragmentation.
205    pub fn estimate_tensor_memory(tensors: &[&[f64]]) -> usize {
206        tensors.iter().map(|t| std::mem::size_of_val(*t)).sum()
207    }
208
209    /// Estimate memory for parameters.
210    pub fn estimate_parameter_memory(
211        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::IxDyn>>,
212    ) -> usize {
213        parameters.values().map(|p| p.len() * 8).sum()
214    }
215
216    /// Get memory usage report.
217    pub fn report(&self) -> String {
218        let mut report = String::new();
219        report.push_str("=== Memory Profiling Report ===\n");
220        report.push_str(&format!(
221            "Current Memory: {}\n",
222            MemoryStats::format_bytes(self.stats.current_allocated)
223        ));
224        report.push_str(&format!(
225            "Peak Memory: {}\n",
226            MemoryStats::format_bytes(self.stats.peak_allocated)
227        ));
228        report.push_str(&format!(
229            "Total Allocations: {}\n",
230            self.stats.allocation_count
231        ));
232
233        if !self.stats.history.is_empty() {
234            report.push_str("\nMemory History:\n");
235            for (epoch, bytes) in &self.stats.history {
236                report.push_str(&format!(
237                    "  Epoch {}: {}\n",
238                    epoch,
239                    MemoryStats::format_bytes(*bytes)
240                ));
241            }
242        }
243
244        report
245    }
246}
247
impl Default for MemoryProfilerCallback {
    /// Equivalent to [`MemoryProfilerCallback::new`].
    fn default() -> Self {
        Self::new()
    }
}
253
254impl Callback for MemoryProfilerCallback {
255    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
256        self.start_time = Some(Instant::now());
257        Ok(())
258    }
259
260    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
261        self.batch_memory.clear();
262        Ok(())
263    }
264
265    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
266        if !self.track_epoch {
267            return Ok(());
268        }
269
270        // Estimate memory from state
271        // In a real implementation, this would use system memory APIs
272        let estimated_memory = estimate_training_memory(state);
273        self.stats.record(epoch, estimated_memory);
274
275        if epoch.is_multiple_of(self.log_frequency) {
276            println!(
277                "Epoch {}: Memory usage ~ {}",
278                epoch,
279                MemoryStats::format_bytes(estimated_memory)
280            );
281        }
282
283        Ok(())
284    }
285
286    fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
287        if !self.track_batch {
288            return Ok(());
289        }
290
291        let estimated_memory = estimate_training_memory(state);
292        self.batch_memory.push(estimated_memory);
293
294        if batch.is_multiple_of(self.log_frequency) && self.log_frequency > 1 {
295            println!(
296                "  Batch {}: Memory ~ {}",
297                batch,
298                MemoryStats::format_bytes(estimated_memory)
299            );
300        }
301
302        Ok(())
303    }
304
305    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
306        if let Some(start) = self.start_time {
307            let duration = start.elapsed();
308            println!("\n{}", self.report());
309            println!("Training duration: {:.2?}", duration);
310        }
311        Ok(())
312    }
313}
314
315/// Estimate training memory from state.
316///
317/// This is a rough estimate based on the training state.
318/// Actual memory usage may differ significantly.
319fn estimate_training_memory(state: &TrainingState) -> usize {
320    // Base overhead for training state
321    let base_overhead = 1024 * 1024; // 1 MB base
322
323    // Estimate based on metrics count
324    let metrics_memory = state.metrics.len() * 1024;
325
326    // Total estimate
327    base_overhead + metrics_memory
328}
329
/// Memory-efficient training utilities.
pub struct MemoryEfficientTraining;

impl MemoryEfficientTraining {
    /// Calculate optimal batch size for available memory.
    ///
    /// # Arguments
    /// * `available_memory` - Available memory in bytes
    /// * `sample_size` - Size of a single sample in bytes
    /// * `model_memory` - Model memory footprint in bytes
    /// * `overhead_factor` - Overhead multiplier (typically 2-4 for gradients)
    pub fn optimal_batch_size(
        available_memory: usize,
        sample_size: usize,
        model_memory: usize,
        overhead_factor: f64,
    ) -> usize {
        // Memory left once the model itself is resident.
        let headroom = available_memory.saturating_sub(model_memory);
        // Effective footprint of one sample, including gradient/activation overhead.
        let per_sample = (sample_size as f64 * overhead_factor) as usize;

        match per_sample {
            0 => 1,
            n => (headroom / n).max(1),
        }
    }

    /// Estimate memory for a model with given parameter count.
    ///
    /// # Arguments
    /// * `num_parameters` - Number of model parameters
    /// * `with_gradients` - Whether to include gradient storage
    /// * `with_optimizer_state` - Whether to include optimizer state (e.g., Adam moments)
    pub fn estimate_model_memory(
        num_parameters: usize,
        with_gradients: bool,
        with_optimizer_state: bool,
    ) -> usize {
        // Count how many parameter-sized tensors are resident:
        // the parameters themselves, optionally gradients, and
        // optionally Adam's two moment tensors.
        let mut copies = 1;
        if with_gradients {
            copies += 1;
        }
        if with_optimizer_state {
            copies += 2;
        }

        num_parameters * std::mem::size_of::<f64>() * copies
    }

    /// Calculate gradient accumulation steps for target batch size.
    ///
    /// # Arguments
    /// * `target_batch_size` - Desired effective batch size
    /// * `actual_batch_size` - Batch size that fits in memory
    pub fn gradient_accumulation_steps(
        target_batch_size: usize,
        actual_batch_size: usize,
    ) -> usize {
        match actual_batch_size {
            0 => 1,
            n => target_batch_size.div_ceil(n).max(1),
        }
    }

    /// Get recommended memory settings for a given GPU memory.
    pub fn recommended_settings(gpu_memory_gb: f64) -> MemorySettings {
        const GIB: f64 = 1024.0 * 1024.0 * 1024.0;
        let memory_bytes = (gpu_memory_gb * GIB) as usize;

        MemorySettings {
            max_batch_size: (memory_bytes / (100 * 1024 * 1024)).max(1), // ~100MB per batch
            use_gradient_checkpointing: gpu_memory_gb < 16.0,
            use_mixed_precision: gpu_memory_gb < 24.0,
            gradient_accumulation: if gpu_memory_gb < 8.0 { 4 } else { 1 },
        }
    }
}

/// Recommended memory settings.
#[derive(Debug, Clone)]
pub struct MemorySettings {
    /// Maximum recommended batch size.
    pub max_batch_size: usize,
    /// Whether to use gradient checkpointing.
    pub use_gradient_checkpointing: bool,
    /// Whether to use mixed precision.
    pub use_mixed_precision: bool,
    /// Gradient accumulation steps.
    pub gradient_accumulation: usize,
}
423
/// Memory budget manager for training.
///
/// Tracks named allocations against a fixed byte budget; allocations are
/// bookkeeping only (no real memory is reserved).
#[derive(Debug, Clone)]
pub struct MemoryBudgetManager {
    /// Total memory budget in bytes.
    budget: usize,
    /// Current allocated memory (sum of all tracked allocations).
    allocated: usize,
    /// Allocation tracking: name -> total bytes recorded under that name.
    allocations: HashMap<String, usize>,
}
434
435impl MemoryBudgetManager {
436    /// Create a new memory budget manager.
437    ///
438    /// # Arguments
439    /// * `budget_bytes` - Total memory budget in bytes
440    pub fn new(budget_bytes: usize) -> Self {
441        Self {
442            budget: budget_bytes,
443            allocated: 0,
444            allocations: HashMap::new(),
445        }
446    }
447
448    /// Create from GB specification.
449    pub fn from_gb(gb: f64) -> Self {
450        let bytes = (gb * 1024.0 * 1024.0 * 1024.0) as usize;
451        Self::new(bytes)
452    }
453
454    /// Try to allocate memory.
455    ///
456    /// Returns true if allocation succeeded, false if would exceed budget.
457    pub fn try_allocate(&mut self, name: &str, bytes: usize) -> bool {
458        if self.allocated + bytes > self.budget {
459            return false;
460        }
461
462        self.allocated += bytes;
463        *self.allocations.entry(name.to_string()).or_default() += bytes;
464        true
465    }
466
467    /// Free allocated memory.
468    pub fn free(&mut self, name: &str) {
469        if let Some(bytes) = self.allocations.remove(name) {
470            self.allocated = self.allocated.saturating_sub(bytes);
471        }
472    }
473
474    /// Get available memory.
475    pub fn available(&self) -> usize {
476        self.budget.saturating_sub(self.allocated)
477    }
478
479    /// Get utilization percentage.
480    pub fn utilization(&self) -> f64 {
481        if self.budget == 0 {
482            return 0.0;
483        }
484        (self.allocated as f64 / self.budget as f64) * 100.0
485    }
486
487    /// Get allocation summary.
488    pub fn summary(&self) -> String {
489        format!(
490            "Memory Budget: {:.2}% used ({} / {})",
491            self.utilization(),
492            MemoryStats::format_bytes(self.allocated),
493            MemoryStats::format_bytes(self.budget)
494        )
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    // Peak must retain the high-water mark even after usage drops back down.
    #[test]
    fn test_memory_stats() {
        let mut stats = MemoryStats::new();

        stats.record(0, 1024 * 1024);
        stats.record(1, 2 * 1024 * 1024);
        stats.record(2, 1024 * 1024);

        assert_eq!(stats.current_allocated, 1024 * 1024);
        assert_eq!(stats.peak_allocated, 2 * 1024 * 1024);
        assert_eq!(stats.allocation_count, 3);
        assert_eq!(stats.history.len(), 3);
    }

    // One representative value per unit (bytes, KB, MB, GB).
    #[test]
    fn test_format_bytes() {
        assert_eq!(MemoryStats::format_bytes(500), "500 bytes");
        assert_eq!(MemoryStats::format_bytes(2048), "2.00 KB");
        assert_eq!(MemoryStats::format_bytes(2 * 1024 * 1024), "2.00 MB");
        assert_eq!(MemoryStats::format_bytes(3 * 1024 * 1024 * 1024), "3.00 GB");
    }

    // Builder methods should compose and each set their own field.
    #[test]
    fn test_gradient_checkpoint_config() {
        let config = GradientCheckpointConfig::new()
            .enabled()
            .with_strategy(CheckpointStrategy::SqrtStrategy)
            .with_layers(vec!["layer1".to_string(), "layer2".to_string()]);

        assert!(config.enabled);
        assert_eq!(config.strategy, CheckpointStrategy::SqrtStrategy);
        assert_eq!(config.checkpoint_layers.len(), 2);
    }

    #[test]
    fn test_memory_profiler_callback() {
        let profiler = MemoryProfilerCallback::new()
            .with_epoch_tracking(true)
            .with_batch_tracking(false)
            .with_log_frequency(5);

        assert!(profiler.track_epoch);
        assert!(!profiler.track_batch);
        assert_eq!(profiler.log_frequency, 5);
    }

    #[test]
    fn test_optimal_batch_size() {
        // 8 GB available, 1 MB per sample, 1 GB model, 3x overhead
        let batch_size = MemoryEfficientTraining::optimal_batch_size(
            8 * 1024 * 1024 * 1024, // 8 GB
            1024 * 1024,            // 1 MB
            1024 * 1024 * 1024,     // 1 GB model
            3.0,                    // 3x overhead
        );

        // (8GB - 1GB) / (1MB * 3) = ~2333
        assert!(batch_size > 2000);
        assert!(batch_size < 2500);
    }

    // Gradients double the footprint; Adam's two moments double it again.
    #[test]
    fn test_estimate_model_memory() {
        let params = 1_000_000;

        // Just parameters
        let base = MemoryEfficientTraining::estimate_model_memory(params, false, false);
        assert_eq!(base, params * 8);

        // With gradients
        let with_grads = MemoryEfficientTraining::estimate_model_memory(params, true, false);
        assert_eq!(with_grads, params * 8 * 2);

        // With optimizer state (Adam)
        let with_adam = MemoryEfficientTraining::estimate_model_memory(params, true, true);
        assert_eq!(with_adam, params * 8 * 4);
    }

    #[test]
    fn test_gradient_accumulation_steps() {
        assert_eq!(
            MemoryEfficientTraining::gradient_accumulation_steps(64, 16),
            4
        );
        assert_eq!(
            MemoryEfficientTraining::gradient_accumulation_steps(100, 32),
            4 // ceil(100/32) = 4
        );
        assert_eq!(
            MemoryEfficientTraining::gradient_accumulation_steps(32, 32),
            1
        );
    }

    // Thresholds: checkpointing below 16 GB, mixed precision below 24 GB.
    #[test]
    fn test_recommended_settings() {
        let small = MemoryEfficientTraining::recommended_settings(8.0);
        assert!(small.use_gradient_checkpointing);
        assert!(small.use_mixed_precision);

        let large = MemoryEfficientTraining::recommended_settings(32.0);
        assert!(!large.use_gradient_checkpointing);
        assert!(!large.use_mixed_precision);
    }

    // Full allocate/reject/free cycle against a 100 MB budget.
    #[test]
    fn test_memory_budget_manager() {
        let mut manager = MemoryBudgetManager::new(100 * 1024 * 1024); // 100 MB

        // Allocate 50 MB
        assert!(manager.try_allocate("model", 50 * 1024 * 1024));
        assert_eq!(manager.utilization(), 50.0);

        // Allocate another 30 MB
        assert!(manager.try_allocate("gradients", 30 * 1024 * 1024));
        assert_eq!(manager.utilization(), 80.0);

        // Try to allocate 30 MB more (should fail)
        assert!(!manager.try_allocate("overflow", 30 * 1024 * 1024));

        // Free gradients
        manager.free("gradients");
        assert_eq!(manager.utilization(), 50.0);

        // Now can allocate
        assert!(manager.try_allocate("new", 30 * 1024 * 1024));
    }

    #[test]
    fn test_memory_budget_from_gb() {
        let manager = MemoryBudgetManager::from_gb(4.0);
        assert_eq!(manager.budget, 4 * 1024 * 1024 * 1024);
    }
}