use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use trustformers_core::tensor::Tensor;

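/// Configuration for the memory optimizations in this module. All strategies
/// are off by default; enable them selectively with struct update syntax, as
/// the tests below do (illustrative sketch):
///
/// ```ignore
/// let config = MemoryOptimizationConfig {
///     gradient_checkpointing: true,
///     cpu_offloading: true,
///     ..Default::default()
/// };
/// let optimizer = MemoryOptimizer::new(config);
/// ```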
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryOptimizationConfig {
    /// Recompute activations in the backward pass instead of storing them all.
    pub gradient_checkpointing: bool,
    /// Park large, rarely used tensors in host (CPU) memory.
    pub cpu_offloading: bool,
    /// Periodically evict stale checkpoints and offloaded tensors.
    pub dynamic_memory: bool,
    /// Cache small intermediates so they can be rematerialized cheaply.
    pub tensor_rematerialization: bool,
    /// Usage (bytes) above which `should_cleanup` reports true.
    pub memory_threshold: usize,
    /// Hard budget (bytes); exceeding it lets `cleanup` drop the recompute cache.
    pub max_memory_usage: usize,
    /// At most `checkpoint_interval * 2` checkpoints are retained.
    pub checkpoint_interval: usize,
    /// Minimum tensor size (bytes) eligible for CPU offloading; also the upper
    /// bound for entries in the rematerialization cache.
    pub offload_threshold: usize,
}

impl Default for MemoryOptimizationConfig {
    fn default() -> Self {
        Self {
            gradient_checkpointing: false,
            cpu_offloading: false,
            dynamic_memory: false,
            tensor_rematerialization: false,
            memory_threshold: 1_000_000_000, // 1 GB
            max_memory_usage: 8_000_000_000, // 8 GB
            checkpoint_interval: 4,
            offload_threshold: 100_000_000, // 100 MB
        }
    }
}

/// A snapshot of one layer's activations, retained for the backward pass.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    pub layer_index: usize,
    pub activations: Vec<Tensor>,
    pub timestamp: std::time::Instant,
}

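/// Coordinator for the strategies enabled in [`MemoryOptimizationConfig`]:
/// gradient checkpointing, CPU offloading, dynamic cleanup, and tensor
/// rematerialization.
///
/// A minimal checkpointing round trip (sketch only; `Tensor::zeros` is the
/// same fallible constructor the tests below use):
///
/// ```ignore
/// let mut optimizer = MemoryOptimizer::new(MemoryOptimizationConfig {
///     gradient_checkpointing: true,
///     ..Default::default()
/// });
///
/// // Forward pass: record the layer's activations.
/// let activations = vec![Tensor::zeros(&[2, 3])?];
/// optimizer.create_checkpoint(0, activations)?;
///
/// // Backward pass: fetch them back by layer index.
/// assert!(optimizer.get_checkpoint_activations(0).is_some());
/// ```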
#[allow(dead_code)]
pub struct MemoryOptimizer {
    config: MemoryOptimizationConfig,
    checkpoints: VecDeque<Checkpoint>,
    offloaded_tensors: HashMap<String, (Tensor, std::time::Instant)>,
    memory_usage: Arc<Mutex<usize>>,
    recompute_cache: HashMap<String, Vec<Tensor>>,
}

impl MemoryOptimizer {
    pub fn new(config: MemoryOptimizationConfig) -> Self {
        Self {
            config,
            checkpoints: VecDeque::new(),
            offloaded_tensors: HashMap::new(),
            memory_usage: Arc::new(Mutex::new(0)),
            recompute_cache: HashMap::new(),
        }
    }

    /// Save a layer's activations so the backward pass can reuse them.
    /// No-op unless gradient checkpointing is enabled.
    pub fn create_checkpoint(
        &mut self,
        layer_index: usize,
        activations: Vec<Tensor>,
    ) -> Result<()> {
        if !self.config.gradient_checkpointing {
            return Ok(());
        }

        let checkpoint = Checkpoint {
            layer_index,
            activations,
            timestamp: std::time::Instant::now(),
        };

        self.checkpoints.push_back(checkpoint);

        // Bound the number of retained checkpoints; evict the oldest first.
        while self.checkpoints.len() > self.config.checkpoint_interval * 2 {
            self.checkpoints.pop_front();
        }

        Ok(())
    }

    /// Return the most recent checkpoint recorded for `layer_index`, if any.
    pub fn get_checkpoint_activations(&self, layer_index: usize) -> Option<Vec<Tensor>> {
        if !self.config.gradient_checkpointing {
            return None;
        }

        // Search newest-first so the latest checkpoint for the layer wins.
        for checkpoint in self.checkpoints.iter().rev() {
            if checkpoint.layer_index == layer_index {
                return Some(checkpoint.activations.clone());
            }
        }

        None
    }

    /// Move `tensor` to host memory if it is large enough to be worth
    /// offloading. Tensors below `offload_threshold` are dropped untracked.
    pub fn offload_to_cpu(&mut self, name: String, tensor: Tensor) -> Result<()> {
        if !self.config.cpu_offloading {
            return Ok(());
        }

        let tensor_size = self.estimate_tensor_size(&tensor)?;

        if tensor_size >= self.config.offload_threshold {
            // Record when the tensor was offloaded so cleanup can age it out.
            self.offloaded_tensors
                .insert(name, (tensor, std::time::Instant::now()));
        }

        Ok(())
    }

    /// Bring a previously offloaded tensor back, removing it from the
    /// offload table.
    pub fn retrieve_from_cpu(&mut self, name: &str) -> Option<Tensor> {
        if !self.config.cpu_offloading {
            return None;
        }

        self.offloaded_tensors.remove(name).map(|(tensor, _)| tensor)
    }

    /// Rough size estimate: element count times an assumed 4-byte (f32)
    /// element width, e.g. a [1024, 1024] tensor is estimated at 4 MiB.
    fn estimate_tensor_size(&self, tensor: &Tensor) -> Result<usize> {
        let shape = tensor.shape();
        let element_size = 4; // assume f32 elements
        let total_elements: usize = shape.iter().product();
        Ok(total_elements * element_size)
    }

    /// Adjust the tracked memory usage by `delta` bytes, saturating at the
    /// bounds instead of under- or overflowing.
    pub fn update_memory_usage(&self, delta: isize) {
        let mut usage = self.memory_usage.lock().expect("lock should not be poisoned");
        if delta < 0 {
            *usage = usage.saturating_sub(delta.unsigned_abs());
        } else {
            *usage = usage.saturating_add(delta as usize);
        }
    }

    /// Current tracked memory usage in bytes.
    pub fn get_memory_usage(&self) -> usize {
        *self.memory_usage.lock().expect("lock should not be poisoned")
    }

    /// True when tracked usage exceeds the configured cleanup threshold.
    pub fn should_cleanup(&self) -> bool {
        self.get_memory_usage() > self.config.memory_threshold
    }

    /// Free stale cached data: checkpoints older than 30s, offloaded tensors
    /// older than 60s, and (when over `max_memory_usage`) the entire
    /// rematerialization cache. Returns the estimated number of bytes freed.
    pub fn cleanup(&mut self) -> Result<usize> {
        let mut freed_bytes = 0;

        if self.config.dynamic_memory {
            let now = std::time::Instant::now();

            // Drop checkpoints that have not been touched for 30 seconds.
            let old_checkpoints: Vec<_> = self
                .checkpoints
                .iter()
                .enumerate()
                .filter(|(_, checkpoint)| now.duration_since(checkpoint.timestamp).as_secs() > 30)
                .map(|(i, _)| i)
                .collect();

            // Remove back-to-front so earlier indices stay valid.
            for i in old_checkpoints.into_iter().rev() {
                if let Some(checkpoint) = self.checkpoints.remove(i) {
                    for tensor in &checkpoint.activations {
                        freed_bytes += self.estimate_tensor_size(tensor)?;
                    }
                }
            }

            // Drop offloaded tensors that have been parked for over a minute.
            let old_tensors: Vec<_> = self
                .offloaded_tensors
                .iter()
                .filter(|(_, (_, timestamp))| now.duration_since(*timestamp).as_secs() > 60)
                .map(|(name, _)| name.clone())
                .collect();

            for name in old_tensors {
                if let Some((tensor, _)) = self.offloaded_tensors.remove(&name) {
                    freed_bytes += self.estimate_tensor_size(&tensor)?;
                }
            }

            if self.get_memory_usage() > self.config.max_memory_usage {
                // Under hard memory pressure, drop the whole recompute cache.
                for tensors in self.recompute_cache.values() {
                    for tensor in tensors {
                        freed_bytes += self.estimate_tensor_size(tensor)?;
                    }
                }
                self.recompute_cache.clear();
            }
        }

        self.update_memory_usage(-(freed_bytes as isize));
        Ok(freed_bytes)
    }

    /// Cache tensors for cheap rematerialization later. Only tensor sets
    /// smaller than `offload_threshold` are kept.
    pub fn store_for_rematerialization(&mut self, key: String, tensors: Vec<Tensor>) -> Result<()> {
        if !self.config.tensor_rematerialization {
            return Ok(());
        }

        let mut total_size = 0;
        for tensor in &tensors {
            total_size += self.estimate_tensor_size(tensor)?;
        }

        // Large activation sets are cheaper to recompute than to keep around.
        if total_size < self.config.offload_threshold {
            self.recompute_cache.insert(key, tensors);
        }

        Ok(())
    }

    /// Take cached tensors out of the rematerialization cache.
    pub fn retrieve_for_rematerialization(&mut self, key: &str) -> Option<Vec<Tensor>> {
        if !self.config.tensor_rematerialization {
            return None;
        }

        self.recompute_cache.remove(key)
    }

    /// Snapshot the optimizer's counters for logging or monitoring.
    pub fn get_stats(&self) -> MemoryOptimizationStats {
        MemoryOptimizationStats {
            current_memory_usage: self.get_memory_usage(),
            checkpoints_count: self.checkpoints.len(),
            offloaded_tensors_count: self.offloaded_tensors.len(),
            recompute_cache_size: self.recompute_cache.len(),
            memory_threshold: self.config.memory_threshold,
            max_memory_usage: self.config.max_memory_usage,
        }
    }
}

/// Point-in-time statistics reported by [`MemoryOptimizer::get_stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryOptimizationStats {
    pub current_memory_usage: usize,
    pub checkpoints_count: usize,
    pub offloaded_tensors_count: usize,
    pub recompute_cache_size: usize,
    pub memory_threshold: usize,
    pub max_memory_usage: usize,
}

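/// Per-layer convenience wrapper: checkpoints the inputs on the forward pass
/// so `backward_with_checkpoint` can find them again.
///
/// Intended call pattern (sketch only; the forward/backward bodies below are
/// placeholders):
///
/// ```ignore
/// let mut wrapper = GradientCheckpointWrapper::new(optimizer, 0);
/// let outputs = wrapper.forward_with_checkpoint(inputs)?;
/// // ...later, during backprop:
/// let grads = wrapper.backward_with_checkpoint(grad_outputs)?;
/// ```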
pub struct GradientCheckpointWrapper {
    optimizer: MemoryOptimizer,
    layer_index: usize,
}

impl GradientCheckpointWrapper {
    pub fn new(optimizer: MemoryOptimizer, layer_index: usize) -> Self {
        Self {
            optimizer,
            layer_index,
        }
    }

    pub fn forward_with_checkpoint(&mut self, inputs: Vec<Tensor>) -> Result<Vec<Tensor>> {
        self.optimizer.create_checkpoint(self.layer_index, inputs.clone())?;

        // Placeholder forward pass: a real layer would compute its outputs
        // here; this wrapper just passes the inputs through.
        let outputs = inputs;
        Ok(outputs)
    }

    pub fn backward_with_checkpoint(&mut self, grad_outputs: Vec<Tensor>) -> Result<Vec<Tensor>> {
        if let Some(_activations) = self.optimizer.get_checkpoint_activations(self.layer_index) {
            // Placeholder backward pass: a real implementation would recompute
            // the forward pass from the checkpointed activations and propagate
            // gradients through it; here the gradients pass straight through.
            Ok(grad_outputs)
        } else {
            Err(anyhow::anyhow!(
                "No checkpoint found for layer {}",
                self.layer_index
            ))
        }
    }
}

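/// Queue-based front end over [`MemoryOptimizer::offload_to_cpu`] and
/// [`MemoryOptimizer::retrieve_from_cpu`].
///
/// Usage sketch (illustrative only; `tensor` stands in for any value above
/// the offload threshold):
///
/// ```ignore
/// let mut manager = CPUOffloadManager::new(optimizer);
/// manager.schedule_offload("attn_cache".to_string(), tensor)?;
/// manager.process_offload_queue()?;
/// let restored = manager.retrieve_tensor("attn_cache");
/// ```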
pub struct CPUOffloadManager {
    optimizer: MemoryOptimizer,
    offload_queue: VecDeque<String>,
}

impl CPUOffloadManager {
    pub fn new(optimizer: MemoryOptimizer) -> Self {
        Self {
            optimizer,
            offload_queue: VecDeque::new(),
        }
    }

    /// Queue a tensor for offloading and hand it to the optimizer.
    pub fn schedule_offload(&mut self, name: String, tensor: Tensor) -> Result<()> {
        self.optimizer.offload_to_cpu(name.clone(), tensor)?;
        self.offload_queue.push_back(name);
        Ok(())
    }

    /// Retire up to one batch of pending offload requests per call.
    pub fn process_offload_queue(&mut self) -> Result<()> {
        // Bound the work done in a single call.
        let batch_size = 10;
        for _ in 0..batch_size {
            if let Some(_name) = self.offload_queue.pop_front() {
                // Placeholder: the actual transfer already happened in
                // `offload_to_cpu`; this loop only drains bookkeeping entries.
            } else {
                break;
            }
        }
        Ok(())
    }

    pub fn retrieve_tensor(&mut self, name: &str) -> Option<Tensor> {
        self.optimizer.retrieve_from_cpu(name)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_optimizer_creation() {
        let config = MemoryOptimizationConfig::default();
        let optimizer = MemoryOptimizer::new(config);

        assert_eq!(optimizer.get_memory_usage(), 0);
        assert_eq!(optimizer.checkpoints.len(), 0);
        assert_eq!(optimizer.offloaded_tensors.len(), 0);
    }

    #[test]
    fn test_checkpoint_creation() {
        let config = MemoryOptimizationConfig {
            gradient_checkpointing: true,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
        let result = optimizer.create_checkpoint(0, vec![tensor]);

        assert!(result.is_ok());
        assert_eq!(optimizer.checkpoints.len(), 1);
    }

    #[test]
    fn test_memory_cleanup() {
        let config = MemoryOptimizationConfig {
            dynamic_memory: true,
            memory_threshold: 1000,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        optimizer.update_memory_usage(2000);
        assert!(optimizer.should_cleanup());

        let freed = optimizer.cleanup().expect("operation failed in test");
        assert_eq!(freed, 0); // nothing was cached, so nothing is freed
    }

    #[test]
    fn test_cpu_offloading() {
        let config = MemoryOptimizationConfig {
            cpu_offloading: true,
            offload_threshold: 100,
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        // A 1000x1000 f32 tensor (~4 MB) is comfortably above the threshold.
        let tensor = Tensor::zeros(&[1000, 1000]).expect("tensor operation failed");
        let result = optimizer.offload_to_cpu("test_tensor".to_string(), tensor);

        assert!(result.is_ok());
        assert_eq!(optimizer.offloaded_tensors.len(), 1);

        let retrieved = optimizer.retrieve_from_cpu("test_tensor");
        assert!(retrieved.is_some());
        assert_eq!(optimizer.offloaded_tensors.len(), 0);
    }

    #[test]
    fn test_gradient_checkpoint_wrapper() {
        let config = MemoryOptimizationConfig {
            gradient_checkpointing: true,
            ..Default::default()
        };
        let optimizer = MemoryOptimizer::new(config);
        let mut wrapper = GradientCheckpointWrapper::new(optimizer, 0);

        let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
        let result = wrapper.forward_with_checkpoint(vec![tensor]);

        assert!(result.is_ok());
    }

    #[test]
    fn test_memory_stats() {
        let config = MemoryOptimizationConfig::default();
        let optimizer = MemoryOptimizer::new(config);

        let stats = optimizer.get_stats();
        assert_eq!(stats.current_memory_usage, 0);
        assert_eq!(stats.checkpoints_count, 0);
        assert_eq!(stats.offloaded_tensors_count, 0);
    }
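
    // Two added checks, written against the same helpers the tests above use:
    // the rematerialization round trip and the saturating memory accounting
    // are otherwise untested.
    #[test]
    fn test_rematerialization_cache() {
        let config = MemoryOptimizationConfig {
            tensor_rematerialization: true,
            offload_threshold: 100, // a 2x3 f32 tensor (24 bytes) stays below this
            ..Default::default()
        };
        let mut optimizer = MemoryOptimizer::new(config);

        let tensor = Tensor::zeros(&[2, 3]).expect("tensor operation failed");
        optimizer
            .store_for_rematerialization("layer_0".to_string(), vec![tensor])
            .expect("store should succeed");

        // Retrieval consumes the cache entry.
        assert!(optimizer.retrieve_for_rematerialization("layer_0").is_some());
        assert!(optimizer.retrieve_for_rematerialization("layer_0").is_none());
    }

    #[test]
    fn test_memory_usage_saturates_at_zero() {
        let config = MemoryOptimizationConfig::default();
        let optimizer = MemoryOptimizer::new(config);

        optimizer.update_memory_usage(100);
        optimizer.update_memory_usage(-500); // over-release must not underflow
        assert_eq!(optimizer.get_memory_usage(), 0);
    }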
}