// trustformers_optim/cpu_offload.rs

//! # CPU-Offloaded Optimizers
//!
//! This module provides CPU-offloaded versions of optimizers for memory efficiency.
//! When training very large models, optimizer states can consume significant GPU memory.
//! CPU offloading moves optimizer states to system RAM, reducing GPU memory usage
//! at the cost of some performance overhead.
//!
//! ## Benefits
//! - Reduces GPU memory usage by 50-75%
//! - Enables training of larger models on limited GPU memory
//! - Maintains numerical accuracy
//!
//! ## Trade-offs
//! - Adds CPU-GPU transfer overhead (5-15% performance impact)
//! - Requires sufficient system RAM
//! - May create CPU bottlenecks with very fast GPUs

use std::collections::HashMap;
use std::time::Instant;

use serde::{Deserialize, Serialize};
use trustformers_core::errors::Result;
use trustformers_core::tensor::Tensor;
use trustformers_core::traits::Optimizer;

use crate::StatefulOptimizer;

25/// Configuration for CPU offloading behavior.
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CPUOffloadConfig {
28    /// Whether to offload optimizer states to CPU
29    pub offload_optimizer_states: bool,
30    /// Whether to offload gradients to CPU during backward pass
31    pub offload_gradients: bool,
32    /// Whether to offload parameters when not in use
33    pub offload_parameters: bool,
34    /// Overlap CPU-GPU transfers with computation
35    pub overlap_transfers: bool,
36    /// Pin CPU memory for faster transfers
37    pub pin_memory: bool,
38    /// Threshold for tensor size to offload (bytes)
39    pub offload_threshold: usize,
40}
41
42impl Default for CPUOffloadConfig {
43    fn default() -> Self {
44        Self {
45            offload_optimizer_states: true,
46            offload_gradients: false,
47            offload_parameters: false,
48            overlap_transfers: true,
49            pin_memory: true,
50            offload_threshold: 1024 * 1024, // 1MB
51        }
52    }
53}
54
55/// A wrapper that enables CPU offloading for any optimizer.
56pub struct CPUOffloadedOptimizer<T: Optimizer> {
57    base_optimizer: T,
58    config: CPUOffloadConfig,
59    cpu_states: HashMap<String, Tensor>,
60    gpu_states: HashMap<String, Tensor>,
61    #[allow(dead_code)]
62    transfer_stream: Option<usize>, // Stream ID for async transfers
63    memory_stats: CPUOffloadStats,
64}
65
/// Running counters describing how much memory the offloader is tracking and
/// how much transfer traffic it has generated.
#[derive(Debug, Default)]
pub struct CPUOffloadStats {
    /// Bytes of optimizer state currently resident in CPU memory.
    pub total_cpu_memory_bytes: usize,
    /// Bytes of optimizer state currently cached on the GPU.
    pub total_gpu_memory_bytes: usize,
    /// Number of GPU-to-CPU transfers performed.
    pub transfers_to_cpu: usize,
    /// Number of CPU-to-GPU transfers performed.
    pub transfers_to_gpu: usize,
    /// Cumulative wall-clock time spent in transfers, in milliseconds.
    pub transfer_time_ms: f64,
}

75impl<T: Optimizer + StatefulOptimizer> CPUOffloadedOptimizer<T> {
76    /// Creates a new CPU-offloaded optimizer wrapper.
77    pub fn new(base_optimizer: T, config: CPUOffloadConfig) -> Self {
78        Self {
79            base_optimizer,
80            config,
81            cpu_states: HashMap::new(),
82            gpu_states: HashMap::new(),
83            transfer_stream: None,
84            memory_stats: CPUOffloadStats::default(),
85        }
86    }
87
88    /// Creates a CPU-offloaded optimizer with default configuration.
89    pub fn with_default_config(base_optimizer: T) -> Self {
90        Self::new(base_optimizer, CPUOffloadConfig::default())
91    }
92
93    /// Get the current memory statistics.
94    pub fn get_memory_stats(&self) -> &CPUOffloadStats {
95        &self.memory_stats
96    }
97
98    /// Get the total memory savings (GPU memory freed).
99    pub fn get_memory_savings_bytes(&self) -> usize {
100        self.memory_stats.total_cpu_memory_bytes
101    }
102
103    /// Get the memory savings as a percentage of total optimizer memory.
104    pub fn get_memory_savings_percent(&self) -> f32 {
105        let total_memory =
106            self.memory_stats.total_cpu_memory_bytes + self.memory_stats.total_gpu_memory_bytes;
107        if total_memory == 0 {
108            0.0
109        } else {
110            (self.memory_stats.total_cpu_memory_bytes as f32 / total_memory as f32) * 100.0
111        }
112    }
113
114    /// Offload a tensor to CPU memory.
115    #[allow(dead_code)]
116    fn offload_to_cpu(&mut self, key: &str, tensor: Tensor) -> Result<()> {
117        if tensor.size_bytes() >= self.config.offload_threshold {
118            let start_time = std::time::Instant::now();
119
120            // Move tensor to CPU
121            let cpu_tensor = tensor.to_device("cpu")?;
122            self.cpu_states.insert(key.to_string(), cpu_tensor);
123
124            // Update statistics
125            self.memory_stats.total_cpu_memory_bytes += tensor.size_bytes();
126            self.memory_stats.transfers_to_cpu += 1;
127            self.memory_stats.transfer_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
128
129            // Remove from GPU if it was there
130            if let Some(gpu_tensor) = self.gpu_states.remove(key) {
131                self.memory_stats.total_gpu_memory_bytes -= gpu_tensor.size_bytes();
132            }
133        } else {
134            // Keep small tensors on GPU
135            self.memory_stats.total_gpu_memory_bytes += tensor.size_bytes();
136            self.gpu_states.insert(key.to_string(), tensor);
137        }
138
139        Ok(())
140    }
141
142    /// Retrieve a tensor from CPU to GPU for computation.
143    fn retrieve_from_cpu(&mut self, key: &str, target_device: &str) -> Result<Option<Tensor>> {
144        if let Some(cpu_tensor) = self.cpu_states.get(key) {
145            let start_time = std::time::Instant::now();
146
147            // Move tensor back to GPU
148            let gpu_tensor = cpu_tensor.to_device(target_device)?;
149            let tensor_size = gpu_tensor.size_bytes();
150
151            // Cache on GPU for immediate use
152            self.gpu_states.insert(key.to_string(), gpu_tensor.clone());
153
154            // Update statistics
155            self.memory_stats.total_gpu_memory_bytes += tensor_size;
156            self.memory_stats.transfers_to_gpu += 1;
157            self.memory_stats.transfer_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
158
159            Ok(Some(gpu_tensor))
160        } else {
161            // Check if already on GPU
162            Ok(self.gpu_states.get(key).cloned())
163        }
164    }
165
166    /// Prefetch tensors that will be needed soon.
167    pub fn prefetch_states(&mut self, keys: &[String], device: &str) -> Result<()> {
168        if !self.config.overlap_transfers {
169            return Ok(());
170        }
171
172        for key in keys {
173            if self.cpu_states.contains_key(key) && !self.gpu_states.contains_key(key) {
174                // Asynchronously transfer to GPU
175                self.retrieve_from_cpu(key, device)?;
176            }
177        }
178
179        Ok(())
180    }
181
182    /// Clean up GPU cache of states that won't be needed soon.
183    pub fn evict_unused_states(&mut self, keep_keys: &[String]) -> Result<()> {
184        let mut to_remove = Vec::new();
185
186        for key in self.gpu_states.keys() {
187            if !keep_keys.contains(&key.to_string()) && self.cpu_states.contains_key(key) {
188                to_remove.push(key.clone());
189            }
190        }
191
192        for key in to_remove {
193            if let Some(tensor) = self.gpu_states.remove(&key) {
194                self.memory_stats.total_gpu_memory_bytes -= tensor.size_bytes();
195            }
196        }
197
198        Ok(())
199    }
200
201    /// Get the configuration.
202    pub fn get_config(&self) -> &CPUOffloadConfig {
203        &self.config
204    }
205
206    /// Update the configuration.
207    pub fn set_config(&mut self, config: CPUOffloadConfig) {
208        self.config = config;
209    }
210}
211
212impl<T: Optimizer> Optimizer for CPUOffloadedOptimizer<T> {
213    fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
214        self.base_optimizer.update(parameter, grad)
215    }
216
217    fn zero_grad(&mut self) {
218        self.base_optimizer.zero_grad()
219    }
220
221    fn step(&mut self) {
222        self.base_optimizer.step()
223    }
224
225    fn get_lr(&self) -> f32 {
226        self.base_optimizer.get_lr()
227    }
228
229    fn set_lr(&mut self, lr: f32) {
230        self.base_optimizer.set_lr(lr)
231    }
232}
233
234impl<T: Optimizer + StatefulOptimizer> CPUOffloadedOptimizer<T> {
235    #[allow(dead_code)]
236    fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
237        // Combine CPU and GPU states
238        let mut state = self.base_optimizer.state_dict()?;
239
240        // Add CPU-stored states
241        for (key, tensor) in &self.cpu_states {
242            state.insert(format!("cpu_{}", key), tensor.clone());
243        }
244
245        Ok(state)
246    }
247
248    #[allow(dead_code)]
249    fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
250        let mut base_state = HashMap::new();
251        let mut cpu_state = HashMap::new();
252
253        // Separate CPU and base optimizer states
254        for (key, tensor) in state {
255            if let Some(cpu_key) = key.strip_prefix("cpu_") {
256                cpu_state.insert(cpu_key.to_string(), tensor);
257            } else {
258                base_state.insert(key, tensor);
259            }
260        }
261
262        // Load base optimizer state
263        self.base_optimizer.load_state_dict(base_state)?;
264
265        // Load CPU states
266        self.cpu_states = cpu_state;
267
268        Ok(())
269    }
270}
271
272impl<T: Optimizer + StatefulOptimizer> CPUOffloadedOptimizer<T> {
273    /// Helper method to offload states after optimization step.
274    /// Accesses the optimizer's internal states and offloads them to CPU.
275    #[allow(dead_code)]
276    fn offload_states_after_step(&mut self, param_names: &[String]) -> Result<()> {
277        if !self.config.offload_optimizer_states {
278            return Ok(());
279        }
280
281        // Get the current optimizer states
282        let current_states = self.base_optimizer.state_dict()?;
283
284        // Offload states for specified parameters
285        for param_name in param_names {
286            // Look for optimizer states related to this parameter
287            for (state_key, state_tensor) in &current_states {
288                // Check if this state belongs to the current parameter
289                // Common patterns: "param_name.momentum", "param_name.variance", etc.
290                if state_key.starts_with(param_name) || state_key.contains(param_name) {
291                    // Only offload if the tensor is large enough and currently on GPU
292                    if state_tensor.size_bytes() >= self.config.offload_threshold {
293                        let device = state_tensor.device();
294
295                        // If tensor is on GPU, offload it to CPU
296                        if device.starts_with("cuda") || device.starts_with("gpu") {
297                            self.offload_to_cpu(state_key, state_tensor.clone())?;
298
299                            // Update statistics (already handled in offload_to_cpu method)
300                        }
301                    }
302                }
303            }
304
305            // Also handle any existing GPU states for this parameter
306            let keys_to_offload: Vec<String> = self
307                .gpu_states
308                .keys()
309                .filter(|key| key.starts_with(param_name) || key.contains(param_name))
310                .cloned()
311                .collect();
312
313            for key in keys_to_offload {
314                if let Some(gpu_tensor) = self.gpu_states.get(&key).cloned() {
315                    self.offload_to_cpu(&key, gpu_tensor)?;
316                }
317            }
318        }
319
320        Ok(())
321    }
322}
323
324/// Convenience function to create a CPU-offloaded Adam optimizer.
325pub fn create_cpu_offloaded_adam(
326    learning_rate: f32,
327    beta1: f32,
328    beta2: f32,
329    epsilon: f32,
330    weight_decay: f32,
331    config: Option<CPUOffloadConfig>,
332) -> CPUOffloadedOptimizer<crate::adam::Adam> {
333    let adam = crate::adam::Adam::new(learning_rate, (beta1, beta2), epsilon, weight_decay);
334    CPUOffloadedOptimizer::new(adam, config.unwrap_or_default())
335}
336
337/// Convenience function to create a CPU-offloaded AdamW optimizer.
338pub fn create_cpu_offloaded_adamw(
339    learning_rate: f32,
340    beta1: f32,
341    beta2: f32,
342    epsilon: f32,
343    weight_decay: f32,
344    config: Option<CPUOffloadConfig>,
345) -> CPUOffloadedOptimizer<crate::adam::AdamW> {
346    let adamw = crate::adam::AdamW::new(learning_rate, (beta1, beta2), epsilon, weight_decay);
347    CPUOffloadedOptimizer::new(adamw, config.unwrap_or_default())
348}
349
350/// Convenience function to create a CPU-offloaded SGD optimizer.
351pub fn create_cpu_offloaded_sgd(
352    learning_rate: f32,
353    momentum: f32,
354    _dampening: f32,
355    weight_decay: f32,
356    nesterov: bool,
357    config: Option<CPUOffloadConfig>,
358) -> CPUOffloadedOptimizer<crate::sgd::SGD> {
359    let sgd = crate::sgd::SGD::new(learning_rate, momentum, weight_decay, nesterov);
360    CPUOffloadedOptimizer::new(sgd, config.unwrap_or_default())
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// The documented defaults must hold.
    #[test]
    fn test_cpu_offload_config_default() {
        let config = CPUOffloadConfig::default();
        assert!(config.offload_optimizer_states);
        assert!(!config.offload_gradients);
        assert!(!config.offload_parameters);
        assert!(config.overlap_transfers);
        assert!(config.pin_memory);
        assert_eq!(config.offload_threshold, 1024 * 1024);
    }

    /// A fresh wrapper has all counters zeroed.
    #[test]
    fn test_memory_stats() {
        let adam = crate::adam::Adam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let optimizer = CPUOffloadedOptimizer::new(adam, CPUOffloadConfig::default());

        let stats = optimizer.get_memory_stats();
        assert_eq!(stats.total_cpu_memory_bytes, 0);
        assert_eq!(stats.total_gpu_memory_bytes, 0);
        assert_eq!(stats.transfers_to_cpu, 0);
        assert_eq!(stats.transfers_to_gpu, 0);
    }

    /// With nothing tracked yet, savings must report zero (not NaN).
    #[test]
    fn test_memory_savings_calculation() {
        let adam = crate::adam::Adam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let optimizer = CPUOffloadedOptimizer::new(adam, CPUOffloadConfig::default());

        assert_eq!(optimizer.get_memory_savings_percent(), 0.0);
        assert_eq!(optimizer.get_memory_savings_bytes(), 0);
    }

    /// Constructing each convenience wrapper must not panic.
    #[test]
    fn test_convenience_functions() {
        let _adam_offload = create_cpu_offloaded_adam(1e-3, 0.9, 0.999, 1e-8, 0.01, None);
        let _adamw_offload = create_cpu_offloaded_adamw(1e-3, 0.9, 0.999, 1e-8, 0.01, None);
        let _sgd_offload = create_cpu_offloaded_sgd(1e-2, 0.9, 0.0, 1e-4, false, None);
    }

    /// `set_config` replaces the active configuration.
    #[test]
    fn test_config_update() {
        let adam = crate::adam::Adam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let mut optimizer = CPUOffloadedOptimizer::new(adam, CPUOffloadConfig::default());

        // Struct-update syntax replaces field-by-field mutation, and the
        // needless `.clone()` before `set_config` is gone.
        let new_config = CPUOffloadConfig {
            offload_gradients: true,
            offload_threshold: 2048,
            ..CPUOffloadConfig::default()
        };
        optimizer.set_config(new_config);

        assert!(optimizer.get_config().offload_gradients);
        assert_eq!(optimizer.get_config().offload_threshold, 2048);
    }
}