trustformers_training/memory_optimization.rs

//! Memory optimization techniques for training large models
//!
//! This module provides advanced memory management strategies, including:
//! - Gradient checkpointing for reducing memory usage
//! - CPU offloading for large tensors
//! - Dynamic memory management with automatic cleanup
//! - Tensor rematerialization strategies

use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use trustformers_core::tensor::Tensor;

/// Configuration for memory optimization
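///
/// All strategies are disabled by default; callers enable them selectively
/// via struct-update syntax. A minimal sketch:
///
/// ```ignore
/// let config = MemoryOptimizationConfig {
///     gradient_checkpointing: true,
///     cpu_offloading: true,
///     ..Default::default()
/// };
/// ```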
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryOptimizationConfig {
    /// Enable gradient checkpointing
    pub gradient_checkpointing: bool,
    /// Enable CPU offloading for large tensors
    pub cpu_offloading: bool,
    /// Enable dynamic memory management
    pub dynamic_memory: bool,
    /// Enable tensor rematerialization
    pub tensor_rematerialization: bool,
    /// Memory threshold for triggering optimizations (in bytes)
    pub memory_threshold: usize,
    /// Maximum memory usage before aggressive cleanup (in bytes)
    pub max_memory_usage: usize,
    /// Checkpoint interval (number of layers)
    pub checkpoint_interval: usize,
    /// CPU offloading threshold (tensor size in bytes)
    pub offload_threshold: usize,
}

impl Default for MemoryOptimizationConfig {
    fn default() -> Self {
        Self {
            gradient_checkpointing: false,
            cpu_offloading: false,
            dynamic_memory: false,
            tensor_rematerialization: false,
            memory_threshold: 1_000_000_000, // 1GB
            max_memory_usage: 8_000_000_000, // 8GB
            checkpoint_interval: 4,
            offload_threshold: 100_000_000, // 100MB
        }
    }
}

/// Checkpoint information for gradient checkpointing
#[derive(Debug, Clone)]
pub struct Checkpoint {
    pub layer_index: usize,
    pub activations: Vec<Tensor>,
    pub timestamp: std::time::Instant,
}

/// Memory optimization manager
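///
/// A minimal training-loop sketch (illustrative only: `run_layer` and
/// `layer_inputs` are hypothetical stand-ins for real model code):
///
/// ```ignore
/// let mut opt = MemoryOptimizer::new(MemoryOptimizationConfig {
///     gradient_checkpointing: true,
///     dynamic_memory: true,
///     ..Default::default()
/// });
/// for (i, inputs) in layer_inputs.into_iter().enumerate() {
///     opt.create_checkpoint(i, inputs.clone())?;
///     let _outputs = run_layer(i, inputs)?;
///     if opt.should_cleanup() {
///         let _freed = opt.cleanup()?;
///     }
/// }
/// ```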
#[allow(dead_code)]
pub struct MemoryOptimizer {
    config: MemoryOptimizationConfig,
    checkpoints: VecDeque<Checkpoint>,
    offloaded_tensors: HashMap<String, (Tensor, std::time::Instant)>,
    memory_usage: Arc<Mutex<usize>>,
    recompute_cache: HashMap<String, Vec<Tensor>>,
}

impl MemoryOptimizer {
    pub fn new(config: MemoryOptimizationConfig) -> Self {
        Self {
            config,
            checkpoints: VecDeque::new(),
            offloaded_tensors: HashMap::new(),
            memory_usage: Arc::new(Mutex::new(0)),
            recompute_cache: HashMap::new(),
        }
    }

    /// Create a checkpoint for gradient checkpointing
    pub fn create_checkpoint(
        &mut self,
        layer_index: usize,
        activations: Vec<Tensor>,
    ) -> Result<()> {
        if !self.config.gradient_checkpointing {
            return Ok(());
        }

        let checkpoint = Checkpoint {
            layer_index,
            activations,
            timestamp: std::time::Instant::now(),
        };

        self.checkpoints.push_back(checkpoint);

        // Limit the checkpoint buffer size
        while self.checkpoints.len() > self.config.checkpoint_interval * 2 {
            self.checkpoints.pop_front();
        }

        Ok(())
    }

    /// Retrieve activations from the most recent checkpoint for a layer
    pub fn get_checkpoint_activations(&self, layer_index: usize) -> Option<Vec<Tensor>> {
        if !self.config.gradient_checkpointing {
            return None;
        }

        // Find the most recent checkpoint for this layer
        for checkpoint in self.checkpoints.iter().rev() {
            if checkpoint.layer_index == layer_index {
                return Some(checkpoint.activations.clone());
            }
        }

        None
    }

    /// Offload a tensor to CPU to free GPU memory
    ///
    /// Tensors smaller than the configured offload threshold are not stored
    /// and are simply dropped.
    pub fn offload_to_cpu(&mut self, name: String, tensor: Tensor) -> Result<()> {
        if !self.config.cpu_offloading {
            return Ok(());
        }

        let tensor_size = self.estimate_tensor_size(&tensor)?;

        if tensor_size >= self.config.offload_threshold {
            // Move the tensor to CPU (simplified: a real implementation would
            // perform an actual device-to-host transfer here)
            self.offloaded_tensors.insert(name, (tensor, std::time::Instant::now()));
        }

        Ok(())
    }

    /// Retrieve a previously offloaded tensor, removing it from the offload store
    pub fn retrieve_from_cpu(&mut self, name: &str) -> Option<Tensor> {
        if !self.config.cpu_offloading {
            return None;
        }

        self.offloaded_tensors.remove(name).map(|(tensor, _)| tensor)
    }

    /// Estimate tensor memory usage
    fn estimate_tensor_size(&self, tensor: &Tensor) -> Result<usize> {
        // Simplified estimation: assumes densely packed f32 elements
        let shape = tensor.shape();
        let element_size = std::mem::size_of::<f32>();
        let total_elements: usize = shape.iter().product();
        Ok(total_elements * element_size)
    }

    /// Update memory usage tracking; negative deltas saturate at zero
    pub fn update_memory_usage(&self, delta: isize) {
        let mut usage = self.memory_usage.lock().expect("lock should not be poisoned");
        if delta < 0 {
            *usage = usage.saturating_sub(delta.unsigned_abs());
        } else {
            *usage += delta as usize;
        }
    }

    /// Get current memory usage
    pub fn get_memory_usage(&self) -> usize {
        *self.memory_usage.lock().expect("lock should not be poisoned")
    }

    /// Check if memory cleanup is needed
    pub fn should_cleanup(&self) -> bool {
        let usage = self.get_memory_usage();
        usage > self.config.memory_threshold
    }

    /// Perform memory cleanup, returning the estimated number of bytes freed
    pub fn cleanup(&mut self) -> Result<usize> {
        let mut freed_bytes = 0;

        if self.config.dynamic_memory {
            // Drop checkpoints older than 30 seconds
            let now = std::time::Instant::now();
            let old_checkpoints: Vec<_> = self
                .checkpoints
                .iter()
                .enumerate()
                .filter(|(_, checkpoint)| now.duration_since(checkpoint.timestamp).as_secs() > 30)
                .map(|(i, _)| i)
                .collect();

            for i in old_checkpoints.into_iter().rev() {
                if let Some(checkpoint) = self.checkpoints.remove(i) {
                    for tensor in &checkpoint.activations {
                        freed_bytes += self.estimate_tensor_size(tensor)?;
                    }
                }
            }

            // Drop offloaded tensors older than 60 seconds
            let old_tensors: Vec<_> = self
                .offloaded_tensors
                .iter()
                .filter(|(_, (_, timestamp))| now.duration_since(*timestamp).as_secs() > 60)
                .map(|(name, _)| name.clone())
                .collect();

            for name in old_tensors {
                if let Some((tensor, _)) = self.offloaded_tensors.remove(&name) {
                    freed_bytes += self.estimate_tensor_size(&tensor)?;
                }
            }

            // Clear the recompute cache entirely if memory pressure is high
            if self.get_memory_usage() > self.config.max_memory_usage {
                for tensors in self.recompute_cache.values() {
                    for tensor in tensors {
                        freed_bytes += self.estimate_tensor_size(tensor)?;
                    }
                }
                self.recompute_cache.clear();
            }
        }

        self.update_memory_usage(-(freed_bytes as isize));
        Ok(freed_bytes)
    }

    /// Store tensors for later rematerialization
    pub fn store_for_rematerialization(&mut self, key: String, tensors: Vec<Tensor>) -> Result<()> {
        if !self.config.tensor_rematerialization {
            return Ok(());
        }

        let mut total_size = 0;
        for tensor in &tensors {
            total_size += self.estimate_tensor_size(tensor)?;
        }

        // Only cache if under the offload threshold
        if total_size < self.config.offload_threshold {
            self.recompute_cache.insert(key, tensors);
        }

        Ok(())
    }

    /// Retrieve tensors for rematerialization, removing them from the cache
    pub fn retrieve_for_rematerialization(&mut self, key: &str) -> Option<Vec<Tensor>> {
        if !self.config.tensor_rematerialization {
            return None;
        }

        self.recompute_cache.remove(key)
    }

    /// Get memory optimization statistics
    pub fn get_stats(&self) -> MemoryOptimizationStats {
        MemoryOptimizationStats {
            current_memory_usage: self.get_memory_usage(),
            checkpoints_count: self.checkpoints.len(),
            offloaded_tensors_count: self.offloaded_tensors.len(),
            recompute_cache_size: self.recompute_cache.len(),
            memory_threshold: self.config.memory_threshold,
            max_memory_usage: self.config.max_memory_usage,
        }
    }
}

/// Statistics for memory optimization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryOptimizationStats {
    pub current_memory_usage: usize,
    pub checkpoints_count: usize,
    pub offloaded_tensors_count: usize,
    pub recompute_cache_size: usize,
    pub memory_threshold: usize,
    pub max_memory_usage: usize,
}

/// Gradient checkpointing wrapper for model layers
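///
/// The wrapper records inputs as checkpoints on the forward pass and replays
/// them on the backward pass. A minimal sketch (`inputs` and `grad_outputs`
/// are hypothetical tensors):
///
/// ```ignore
/// let mut wrapper = GradientCheckpointWrapper::new(optimizer, 0);
/// let outputs = wrapper.forward_with_checkpoint(inputs)?;
/// // ... later, during backprop:
/// let grads = wrapper.backward_with_checkpoint(grad_outputs)?;
/// ```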
pub struct GradientCheckpointWrapper {
    optimizer: MemoryOptimizer,
    layer_index: usize,
}

impl GradientCheckpointWrapper {
    pub fn new(optimizer: MemoryOptimizer, layer_index: usize) -> Self {
        Self {
            optimizer,
            layer_index,
        }
    }

    /// Forward pass with checkpointing
    pub fn forward_with_checkpoint(&mut self, inputs: Vec<Tensor>) -> Result<Vec<Tensor>> {
        // Create a checkpoint before the forward pass
        self.optimizer.create_checkpoint(self.layer_index, inputs.clone())?;

        // Perform the forward pass (simplified)
        let outputs = inputs; // In a real implementation, this would be the actual layer forward pass

        Ok(outputs)
    }

    /// Backward pass with checkpointing
    pub fn backward_with_checkpoint(&mut self, grad_outputs: Vec<Tensor>) -> Result<Vec<Tensor>> {
        // Retrieve activations from the checkpoint
        if let Some(_activations) = self.optimizer.get_checkpoint_activations(self.layer_index) {
            // Recompute the forward pass from the checkpointed activations,
            // then compute gradients (simplified)
            Ok(grad_outputs) // In a real implementation, this would be the actual gradient computation
        } else {
            Err(anyhow::anyhow!(
                "No checkpoint found for layer {}",
                self.layer_index
            ))
        }
    }
}

/// CPU offloading manager for large tensors
pub struct CPUOffloadManager {
    optimizer: MemoryOptimizer,
    offload_queue: VecDeque<String>,
}

impl CPUOffloadManager {
    pub fn new(optimizer: MemoryOptimizer) -> Self {
        Self {
            optimizer,
            offload_queue: VecDeque::new(),
        }
    }

    /// Schedule a tensor for offloading
    pub fn schedule_offload(&mut self, name: String, tensor: Tensor) -> Result<()> {
        self.optimizer.offload_to_cpu(name.clone(), tensor)?;
        self.offload_queue.push_back(name);
        Ok(())
    }

    /// Process the offloading queue
    pub fn process_offload_queue(&mut self) -> Result<()> {
        // Process a batch of offloads to avoid blocking
        let batch_size = 10;
        for _ in 0..batch_size {
            if let Some(_name) = self.offload_queue.pop_front() {
                // Offloading itself happens in `schedule_offload`; additional
                // processing (e.g. asynchronous transfers) could be done here
            } else {
                break;
            }
        }
        Ok(())
    }

    /// Retrieve a tensor from CPU
    pub fn retrieve_tensor(&mut self, name: &str) -> Option<Tensor> {
        self.optimizer.retrieve_from_cpu(name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_optimizer_creation() {
        let config = MemoryOptimizationConfig::default();
        let optimizer = MemoryOptimizer::new(config);

        assert_eq!(optimizer.get_memory_usage(), 0);
        assert_eq!(optimizer.checkpoints.len(), 0);
        assert_eq!(optimizer.offloaded_tensors.len(), 0);
    }

    #[test]
    fn test_checkpoint_creation() {
        let config = MemoryOptimizationConfig {
            gradient_checkpointing: true,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        // Create a mock tensor
        let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
        let result = optimizer.create_checkpoint(0, vec![tensor]);

        assert!(result.is_ok());
        assert_eq!(optimizer.checkpoints.len(), 1);
    }
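
    // An added sketch: the checkpoint buffer is trimmed to at most
    // `checkpoint_interval * 2` entries, so it stays bounded even when
    // many layers are checkpointed.
    #[test]
    fn test_checkpoint_buffer_is_bounded() {
        let config = MemoryOptimizationConfig {
            gradient_checkpointing: true,
            checkpoint_interval: 2,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        for layer in 0..10 {
            let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
            optimizer
                .create_checkpoint(layer, vec![tensor])
                .expect("checkpoint creation failed");
        }

        assert!(optimizer.checkpoints.len() <= 4);
    }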

    #[test]
    fn test_memory_cleanup() {
        let config = MemoryOptimizationConfig {
            dynamic_memory: true,
            memory_threshold: 1000,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        // Simulate memory usage
        optimizer.update_memory_usage(2000);
        assert!(optimizer.should_cleanup());

        let freed = optimizer.cleanup().expect("operation failed in test");
        assert_eq!(freed, 0); // No actual tensors to free in this test
    }
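
    // An added sketch: freeing more bytes than were recorded must not
    // underflow the usage counter (it saturates at zero).
    #[test]
    fn test_memory_usage_saturates_at_zero() {
        let config = MemoryOptimizationConfig::default();
        let optimizer = MemoryOptimizer::new(config);

        optimizer.update_memory_usage(100);
        optimizer.update_memory_usage(-500);
        assert_eq!(optimizer.get_memory_usage(), 0);
    }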

    #[test]
    fn test_cpu_offloading() {
        let config = MemoryOptimizationConfig {
            cpu_offloading: true,
            offload_threshold: 100,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        let tensor = Tensor::zeros(&[1000, 1000]).expect("tensor operation failed"); // Large tensor
        let result = optimizer.offload_to_cpu("test_tensor".to_string(), tensor);

        assert!(result.is_ok());
        assert_eq!(optimizer.offloaded_tensors.len(), 1);

        let retrieved = optimizer.retrieve_from_cpu("test_tensor");
        assert!(retrieved.is_some());
        assert_eq!(optimizer.offloaded_tensors.len(), 0);
    }
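
    // An added sketch of the rematerialization round trip: cached tensors
    // are handed back exactly once.
    #[test]
    fn test_rematerialization_round_trip() {
        let config = MemoryOptimizationConfig {
            tensor_rematerialization: true,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
        optimizer
            .store_for_rematerialization("layer_0".to_string(), vec![tensor])
            .expect("store failed");

        assert!(optimizer.retrieve_for_rematerialization("layer_0").is_some());
        assert!(optimizer.retrieve_for_rematerialization("layer_0").is_none());
    }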

    #[test]
    fn test_gradient_checkpoint_wrapper() {
        let config = MemoryOptimizationConfig {
            gradient_checkpointing: true,
            ..Default::default()
        };
        let optimizer = MemoryOptimizer::new(config);
        let mut wrapper = GradientCheckpointWrapper::new(optimizer, 0);

        let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
        let result = wrapper.forward_with_checkpoint(vec![tensor]);

        assert!(result.is_ok());
    }
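
    // An added sketch of the CPUOffloadManager round trip: schedule an
    // offload, drain the queue, then retrieve the tensor.
    #[test]
    fn test_cpu_offload_manager() {
        let config = MemoryOptimizationConfig {
            cpu_offloading: true,
            offload_threshold: 100,
            ..Default::default()
        };
        let optimizer = MemoryOptimizer::new(config);
        let mut manager = CPUOffloadManager::new(optimizer);

        let tensor = Tensor::zeros(&[1000, 1000]).expect("tensor operation failed");
        manager
            .schedule_offload("weights".to_string(), tensor)
            .expect("offload failed");
        manager.process_offload_queue().expect("queue processing failed");

        assert!(manager.retrieve_tensor("weights").is_some());
    }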

    #[test]
    fn test_memory_stats() {
        let config = MemoryOptimizationConfig::default();
        let optimizer = MemoryOptimizer::new(config);

        let stats = optimizer.get_stats();
        assert_eq!(stats.current_memory_usage, 0);
        assert_eq!(stats.checkpoints_count, 0);
        assert_eq!(stats.offloaded_tensors_count, 0);
    }
}
455}