use crate::{Callback, TrainResult, TrainingState};
use std::collections::HashMap;
use std::time::Instant;

/// Memory usage statistics collected while training.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Most recently recorded memory usage, in bytes.
    pub current_allocated: usize,
    /// Highest memory usage recorded so far, in bytes.
    pub peak_allocated: usize,
    /// Number of measurements recorded.
    pub allocation_count: usize,
    /// History of `(epoch, bytes)` measurements.
    pub history: Vec<(usize, usize)>,
}

impl MemoryStats {
    /// Creates an empty set of statistics.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records a memory measurement for the given epoch.
    pub fn record(&mut self, epoch: usize, bytes: usize) {
        self.current_allocated = bytes;
        if bytes > self.peak_allocated {
            self.peak_allocated = bytes;
        }
        self.allocation_count += 1;
        self.history.push((epoch, bytes));
    }

    /// Formats a byte count as a human-readable string (GB, MB, KB, or bytes).
    pub fn format_bytes(bytes: usize) -> String {
        if bytes >= 1_073_741_824 {
            format!("{:.2} GB", bytes as f64 / 1_073_741_824.0)
        } else if bytes >= 1_048_576 {
            format!("{:.2} MB", bytes as f64 / 1_048_576.0)
        } else if bytes >= 1024 {
            format!("{:.2} KB", bytes as f64 / 1024.0)
        } else {
            format!("{} bytes", bytes)
        }
    }

    /// Returns a one-line summary of current usage, peak usage, and measurement count.
    pub fn summary(&self) -> String {
        format!(
            "Memory: current={}, peak={}, allocations={}",
            Self::format_bytes(self.current_allocated),
            Self::format_bytes(self.peak_allocated),
            self.allocation_count
        )
    }
}

/// Configuration for gradient checkpointing, which trades extra computation
/// for lower activation memory during the backward pass.
#[derive(Debug, Clone)]
pub struct GradientCheckpointConfig {
    /// Whether gradient checkpointing is enabled.
    pub enabled: bool,
    /// Strategy used to choose checkpoint locations.
    pub strategy: CheckpointStrategy,
    /// Explicit layer names to checkpoint.
    pub checkpoint_layers: Vec<String>,
    /// Optional memory threshold, in bytes.
    pub memory_threshold: Option<usize>,
}

impl Default for GradientCheckpointConfig {
    fn default() -> Self {
        Self {
            enabled: false,
            strategy: CheckpointStrategy::Uniform,
            checkpoint_layers: Vec::new(),
            memory_threshold: None,
        }
    }
}

impl GradientCheckpointConfig {
    /// Creates a configuration with checkpointing disabled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables gradient checkpointing.
    pub fn enabled(mut self) -> Self {
        self.enabled = true;
        self
    }

    /// Sets the checkpointing strategy.
    pub fn with_strategy(mut self, strategy: CheckpointStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the explicit list of layers to checkpoint.
    pub fn with_layers(mut self, layers: Vec<String>) -> Self {
        self.checkpoint_layers = layers;
        self
    }

    /// Sets the memory threshold, in bytes.
    pub fn with_memory_threshold(mut self, threshold: usize) -> Self {
        self.memory_threshold = Some(threshold);
        self
    }
}

/// Strategy for selecting which layers to checkpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckpointStrategy {
    /// Checkpoint layers at uniform intervals.
    Uniform,
    /// Checkpoint based on a memory threshold.
    MemoryBased,
    /// Checkpoint only the explicitly listed layers.
    Custom,
    /// Checkpoint roughly every sqrt(n) layers.
    SqrtStrategy,
}

/// Callback that tracks estimated memory usage over the course of training.
#[derive(Debug, Clone)]
pub struct MemoryProfilerCallback {
    /// Collected memory statistics.
    pub stats: MemoryStats,
    /// Whether to record an estimate at the end of each epoch.
    track_epoch: bool,
    /// Whether to record an estimate at the end of each batch.
    track_batch: bool,
    /// Log every `log_frequency` epochs (or batches).
    log_frequency: usize,
    /// Time at which training started.
    start_time: Option<Instant>,
    /// Per-batch memory estimates for the current epoch.
    batch_memory: Vec<usize>,
}

impl MemoryProfilerCallback {
    /// Creates a profiler with epoch tracking enabled and batch tracking disabled.
    pub fn new() -> Self {
        Self {
            stats: MemoryStats::new(),
            track_epoch: true,
            track_batch: false,
            log_frequency: 1,
            start_time: None,
            batch_memory: Vec::new(),
        }
    }

    /// Enables or disables per-epoch tracking.
    pub fn with_epoch_tracking(mut self, enabled: bool) -> Self {
        self.track_epoch = enabled;
        self
    }

    /// Enables or disables per-batch tracking.
    pub fn with_batch_tracking(mut self, enabled: bool) -> Self {
        self.track_batch = enabled;
        self
    }

    /// Sets how often (in epochs or batches) memory usage is logged.
    pub fn with_log_frequency(mut self, frequency: usize) -> Self {
        self.log_frequency = frequency.max(1);
        self
    }

    /// Returns the collected statistics.
    pub fn get_stats(&self) -> &MemoryStats {
        &self.stats
    }

    /// Estimates the memory occupied by a set of tensors, in bytes.
    pub fn estimate_tensor_memory(tensors: &[&[f64]]) -> usize {
        tensors.iter().map(|t| std::mem::size_of_val(*t)).sum()
    }

    /// Estimates the memory occupied by named parameter arrays, in bytes
    /// (8 bytes per `f64` element).
    pub fn estimate_parameter_memory(
        parameters: &HashMap<String, scirs2_core::ndarray::Array<f64, scirs2_core::ndarray::IxDyn>>,
    ) -> usize {
        parameters.values().map(|p| p.len() * 8).sum()
    }

    /// Builds a human-readable profiling report, including the memory history.
    pub fn report(&self) -> String {
        let mut report = String::new();
        report.push_str("=== Memory Profiling Report ===\n");
        report.push_str(&format!(
            "Current Memory: {}\n",
            MemoryStats::format_bytes(self.stats.current_allocated)
        ));
        report.push_str(&format!(
            "Peak Memory: {}\n",
            MemoryStats::format_bytes(self.stats.peak_allocated)
        ));
        report.push_str(&format!(
            "Total Allocations: {}\n",
            self.stats.allocation_count
        ));

        if !self.stats.history.is_empty() {
            report.push_str("\nMemory History:\n");
            for (epoch, bytes) in &self.stats.history {
                report.push_str(&format!(
                    " Epoch {}: {}\n",
                    epoch,
                    MemoryStats::format_bytes(*bytes)
                ));
            }
        }

        report
    }
}

impl Default for MemoryProfilerCallback {
    fn default() -> Self {
        Self::new()
    }
}

impl Callback for MemoryProfilerCallback {
    fn on_train_begin(&mut self, _state: &TrainingState) -> TrainResult<()> {
        self.start_time = Some(Instant::now());
        Ok(())
    }

    fn on_epoch_begin(&mut self, _epoch: usize, _state: &TrainingState) -> TrainResult<()> {
        self.batch_memory.clear();
        Ok(())
    }

    fn on_epoch_end(&mut self, epoch: usize, state: &TrainingState) -> TrainResult<()> {
        if !self.track_epoch {
            return Ok(());
        }

        let estimated_memory = estimate_training_memory(state);
        self.stats.record(epoch, estimated_memory);

        if epoch.is_multiple_of(self.log_frequency) {
            println!(
                "Epoch {}: Memory usage ~ {}",
                epoch,
                MemoryStats::format_bytes(estimated_memory)
            );
        }

        Ok(())
    }

    fn on_batch_end(&mut self, batch: usize, state: &TrainingState) -> TrainResult<()> {
        if !self.track_batch {
            return Ok(());
        }

        let estimated_memory = estimate_training_memory(state);
        self.batch_memory.push(estimated_memory);

        if batch.is_multiple_of(self.log_frequency) && self.log_frequency > 1 {
            println!(
                " Batch {}: Memory ~ {}",
                batch,
                MemoryStats::format_bytes(estimated_memory)
            );
        }

        Ok(())
    }

    fn on_train_end(&mut self, _state: &TrainingState) -> TrainResult<()> {
        if let Some(start) = self.start_time {
            let duration = start.elapsed();
            println!("\n{}", self.report());
            println!("Training duration: {:.2?}", duration);
        }
        Ok(())
    }
}

/// Roughly estimates memory used during training from the training state:
/// a fixed base overhead plus a small amount per tracked metric.
fn estimate_training_memory(state: &TrainingState) -> usize {
    // 1 MiB base overhead for the training loop itself.
    let base_overhead = 1024 * 1024;
    // Rough per-metric bookkeeping estimate.
    let metrics_memory = state.metrics.len() * 1024;

    base_overhead + metrics_memory
}

/// Helpers for planning memory-efficient training runs.
pub struct MemoryEfficientTraining;

impl MemoryEfficientTraining {
    /// Computes the largest batch size that fits in the available memory.
    ///
    /// `overhead_factor` scales the per-sample cost to account for
    /// activations and temporary buffers.
    pub fn optimal_batch_size(
        available_memory: usize,
        sample_size: usize,
        model_memory: usize,
        overhead_factor: f64,
    ) -> usize {
        let available_for_batch = available_memory.saturating_sub(model_memory);
        let sample_total = (sample_size as f64 * overhead_factor) as usize;

        if sample_total == 0 {
            return 1;
        }

        (available_for_batch / sample_total).max(1)
    }

    /// Estimates the memory needed for a model's parameters, optionally
    /// including gradients and optimizer state (e.g. Adam's two moments).
    pub fn estimate_model_memory(
        num_parameters: usize,
        with_gradients: bool,
        with_optimizer_state: bool,
    ) -> usize {
        let param_size = num_parameters * std::mem::size_of::<f64>();
        let mut total = param_size;

        if with_gradients {
            // Gradients have the same size as the parameters.
            total += param_size;
        }

        if with_optimizer_state {
            // Two extra buffers, e.g. Adam's first and second moments.
            total += param_size * 2;
        }

        total
    }

    /// Computes how many gradient accumulation steps are needed to reach the
    /// target effective batch size with the actual per-step batch size.
    pub fn gradient_accumulation_steps(
        target_batch_size: usize,
        actual_batch_size: usize,
    ) -> usize {
        if actual_batch_size == 0 {
            return 1;
        }
        target_batch_size.div_ceil(actual_batch_size).max(1)
    }

    /// Suggests memory-related settings based on available GPU memory.
    pub fn recommended_settings(gpu_memory_gb: f64) -> MemorySettings {
        let memory_bytes = (gpu_memory_gb * 1024.0 * 1024.0 * 1024.0) as usize;

        MemorySettings {
            // Heuristic: assume roughly 100 MB per batch.
            max_batch_size: (memory_bytes / (100 * 1024 * 1024)).max(1),
            use_gradient_checkpointing: gpu_memory_gb < 16.0,
            use_mixed_precision: gpu_memory_gb < 24.0,
            gradient_accumulation: if gpu_memory_gb < 8.0 { 4 } else { 1 },
        }
    }
}

/// Recommended memory-related training settings.
#[derive(Debug, Clone)]
pub struct MemorySettings {
    /// Suggested maximum batch size.
    pub max_batch_size: usize,
    /// Whether gradient checkpointing is recommended.
    pub use_gradient_checkpointing: bool,
    /// Whether mixed-precision training is recommended.
    pub use_mixed_precision: bool,
    /// Suggested number of gradient accumulation steps.
    pub gradient_accumulation: usize,
}

/// Tracks named allocations against a fixed memory budget.
#[derive(Debug, Clone)]
pub struct MemoryBudgetManager {
    /// Total budget, in bytes.
    budget: usize,
    /// Bytes currently allocated.
    allocated: usize,
    /// Allocated bytes per named consumer.
    allocations: HashMap<String, usize>,
}

impl MemoryBudgetManager {
    /// Creates a manager with the given budget, in bytes.
    pub fn new(budget_bytes: usize) -> Self {
        Self {
            budget: budget_bytes,
            allocated: 0,
            allocations: HashMap::new(),
        }
    }

    /// Creates a manager with a budget given in gigabytes.
    pub fn from_gb(gb: f64) -> Self {
        let bytes = (gb * 1024.0 * 1024.0 * 1024.0) as usize;
        Self::new(bytes)
    }

    /// Tries to reserve `bytes` under `name`; returns `false` if it would
    /// exceed the budget.
    pub fn try_allocate(&mut self, name: &str, bytes: usize) -> bool {
        if self.allocated + bytes > self.budget {
            return false;
        }

        self.allocated += bytes;
        *self.allocations.entry(name.to_string()).or_default() += bytes;
        true
    }

    /// Releases everything allocated under `name`.
    pub fn free(&mut self, name: &str) {
        if let Some(bytes) = self.allocations.remove(name) {
            self.allocated = self.allocated.saturating_sub(bytes);
        }
    }

    /// Returns the number of bytes still available.
    pub fn available(&self) -> usize {
        self.budget.saturating_sub(self.allocated)
    }

    /// Returns budget utilization as a percentage.
    pub fn utilization(&self) -> f64 {
        if self.budget == 0 {
            return 0.0;
        }
        (self.allocated as f64 / self.budget as f64) * 100.0
    }

    /// Returns a one-line summary of budget utilization.
    pub fn summary(&self) -> String {
        format!(
            "Memory Budget: {:.2}% used ({} / {})",
            self.utilization(),
            MemoryStats::format_bytes(self.allocated),
            MemoryStats::format_bytes(self.budget)
        )
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_stats() {
        let mut stats = MemoryStats::new();

        stats.record(0, 1024 * 1024);
        stats.record(1, 2 * 1024 * 1024);
        stats.record(2, 1024 * 1024);

        assert_eq!(stats.current_allocated, 1024 * 1024);
        assert_eq!(stats.peak_allocated, 2 * 1024 * 1024);
        assert_eq!(stats.allocation_count, 3);
        assert_eq!(stats.history.len(), 3);
    }
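
    // Illustrative usage sketch: shows how `MemoryStats::summary()` formats the
    // recorded values; the byte counts below are arbitrary example inputs.
    #[test]
    fn test_memory_stats_summary() {
        let mut stats = MemoryStats::new();
        stats.record(0, 512);
        stats.record(1, 2048);

        let summary = stats.summary();
        assert!(summary.contains("current=2.00 KB"));
        assert!(summary.contains("peak=2.00 KB"));
        assert!(summary.contains("allocations=2"));
    }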

    #[test]
    fn test_format_bytes() {
        assert_eq!(MemoryStats::format_bytes(500), "500 bytes");
        assert_eq!(MemoryStats::format_bytes(2048), "2.00 KB");
        assert_eq!(MemoryStats::format_bytes(2 * 1024 * 1024), "2.00 MB");
        assert_eq!(MemoryStats::format_bytes(3 * 1024 * 1024 * 1024), "3.00 GB");
    }

    #[test]
    fn test_gradient_checkpoint_config() {
        let config = GradientCheckpointConfig::new()
            .enabled()
            .with_strategy(CheckpointStrategy::SqrtStrategy)
            .with_layers(vec!["layer1".to_string(), "layer2".to_string()]);

        assert!(config.enabled);
        assert_eq!(config.strategy, CheckpointStrategy::SqrtStrategy);
        assert_eq!(config.checkpoint_layers.len(), 2);
    }
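
    // Illustrative usage sketch: exercises the `Default` configuration and the
    // `with_memory_threshold` builder; the 512 MB threshold is an arbitrary example.
    #[test]
    fn test_gradient_checkpoint_config_defaults_and_threshold() {
        let default_config = GradientCheckpointConfig::default();
        assert!(!default_config.enabled);
        assert_eq!(default_config.strategy, CheckpointStrategy::Uniform);
        assert!(default_config.memory_threshold.is_none());

        let config = GradientCheckpointConfig::new().with_memory_threshold(512 * 1024 * 1024);
        assert_eq!(config.memory_threshold, Some(512 * 1024 * 1024));
        assert!(!config.enabled);
    }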

    #[test]
    fn test_memory_profiler_callback() {
        let profiler = MemoryProfilerCallback::new()
            .with_epoch_tracking(true)
            .with_batch_tracking(false)
            .with_log_frequency(5);

        assert!(profiler.track_epoch);
        assert!(!profiler.track_batch);
        assert_eq!(profiler.log_frequency, 5);
    }
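
    // Illustrative usage sketch: `estimate_tensor_memory` counts 8 bytes per f64
    // element, and `report()` includes the recorded per-epoch history. The stats
    // are filled in directly here rather than through a real training loop.
    #[test]
    fn test_profiler_estimates_and_report() {
        let a = vec![0.0_f64; 100];
        let b = vec![0.0_f64; 50];
        let estimated = MemoryProfilerCallback::estimate_tensor_memory(&[a.as_slice(), b.as_slice()]);
        assert_eq!(estimated, 150 * std::mem::size_of::<f64>());

        let mut profiler = MemoryProfilerCallback::new();
        profiler.stats.record(0, 1024 * 1024);
        profiler.stats.record(1, 3 * 1024 * 1024);

        let report = profiler.report();
        assert!(report.contains("Peak Memory: 3.00 MB"));
        assert!(report.contains("Epoch 1: 3.00 MB"));
    }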

    #[test]
    fn test_optimal_batch_size() {
        let batch_size = MemoryEfficientTraining::optimal_batch_size(
            8 * 1024 * 1024 * 1024, // 8 GB available
            1024 * 1024,            // 1 MB per sample
            1024 * 1024 * 1024,     // 1 GB model
            3.0,                    // 3x overhead per sample
        );

        // (8 GB - 1 GB) / (1 MB * 3.0) = 2389
        assert!(batch_size > 2000);
        assert!(batch_size < 2500);
    }
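
    // Illustrative end-to-end planning sketch combining estimate_model_memory,
    // optimal_batch_size, and gradient_accumulation_steps for a hypothetical
    // 10M-parameter model on an 8 GB device; all sizes are example values.
    #[test]
    fn test_memory_planning_walkthrough() {
        let params = 10_000_000;

        // Parameters + gradients + Adam-style optimizer state: 4x parameter size.
        let model_memory = MemoryEfficientTraining::estimate_model_memory(params, true, true);
        assert_eq!(model_memory, params * 8 * 4);

        // Fit a batch into 8 GB given ~4 MB per sample with 2x overhead.
        let available = 8 * 1024 * 1024 * 1024;
        let per_sample = 4 * 1024 * 1024;
        let batch_size =
            MemoryEfficientTraining::optimal_batch_size(available, per_sample, model_memory, 2.0);
        assert!(batch_size > 900 && batch_size < 1000);

        // Accumulate gradients to reach an effective batch size of 1024.
        let steps = MemoryEfficientTraining::gradient_accumulation_steps(1024, batch_size);
        assert_eq!(steps, 2); // ceil(1024 / 985)
    }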

    #[test]
    fn test_estimate_model_memory() {
        let params = 1_000_000;

        let base = MemoryEfficientTraining::estimate_model_memory(params, false, false);
        assert_eq!(base, params * 8);

        let with_grads = MemoryEfficientTraining::estimate_model_memory(params, true, false);
        assert_eq!(with_grads, params * 8 * 2);

        let with_adam = MemoryEfficientTraining::estimate_model_memory(params, true, true);
        assert_eq!(with_adam, params * 8 * 4);
    }

    #[test]
    fn test_gradient_accumulation_steps() {
        assert_eq!(
            MemoryEfficientTraining::gradient_accumulation_steps(64, 16),
            4
        );
        assert_eq!(
            MemoryEfficientTraining::gradient_accumulation_steps(100, 32),
            4
        );
        assert_eq!(
            MemoryEfficientTraining::gradient_accumulation_steps(32, 32),
            1
        );
    }

    #[test]
    fn test_recommended_settings() {
        let small = MemoryEfficientTraining::recommended_settings(8.0);
        assert!(small.use_gradient_checkpointing);
        assert!(small.use_mixed_precision);

        let large = MemoryEfficientTraining::recommended_settings(32.0);
        assert!(!large.use_gradient_checkpointing);
        assert!(!large.use_mixed_precision);
    }
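
    // Illustrative usage sketch: below 8 GB the heuristic suggests gradient
    // accumulation, and `max_batch_size` never drops below 1 even for a tiny
    // (hypothetical) memory budget.
    #[test]
    fn test_recommended_settings_small_memory() {
        let tiny = MemoryEfficientTraining::recommended_settings(4.0);
        assert_eq!(tiny.gradient_accumulation, 4);
        assert!(tiny.max_batch_size >= 1);

        let minuscule = MemoryEfficientTraining::recommended_settings(0.05);
        assert_eq!(minuscule.max_batch_size, 1);
    }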

    #[test]
    fn test_memory_budget_manager() {
        let mut manager = MemoryBudgetManager::new(100 * 1024 * 1024); // 100 MB budget

        assert!(manager.try_allocate("model", 50 * 1024 * 1024));
        assert_eq!(manager.utilization(), 50.0);

        assert!(manager.try_allocate("gradients", 30 * 1024 * 1024));
        assert_eq!(manager.utilization(), 80.0);

        // Exceeds the remaining budget.
        assert!(!manager.try_allocate("overflow", 30 * 1024 * 1024));

        manager.free("gradients");
        assert_eq!(manager.utilization(), 50.0);

        assert!(manager.try_allocate("new", 30 * 1024 * 1024));
    }
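
    // Illustrative usage sketch: `available()` shrinks as allocations are made
    // and `summary()` reports utilization as a percentage; the allocation names
    // here are arbitrary examples.
    #[test]
    fn test_memory_budget_available_and_summary() {
        let mut manager = MemoryBudgetManager::from_gb(1.0);
        assert_eq!(manager.available(), 1024 * 1024 * 1024);

        assert!(manager.try_allocate("activations", 512 * 1024 * 1024));
        assert_eq!(manager.available(), 512 * 1024 * 1024);
        assert!(manager.summary().contains("50.00% used"));

        // Freeing an unknown name is a no-op.
        manager.free("does_not_exist");
        assert_eq!(manager.available(), 512 * 1024 * 1024);
    }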

    #[test]
    fn test_memory_budget_from_gb() {
        let manager = MemoryBudgetManager::from_gb(4.0);
        assert_eq!(manager.budget, 4 * 1024 * 1024 * 1024);
    }
}