1use crate::StatefulOptimizer;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use trustformers_core::errors::Result;
22use trustformers_core::tensor::Tensor;
23use trustformers_core::traits::Optimizer;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct CPUOffloadConfig {
28 pub offload_optimizer_states: bool,
30 pub offload_gradients: bool,
32 pub offload_parameters: bool,
34 pub overlap_transfers: bool,
36 pub pin_memory: bool,
38 pub offload_threshold: usize,
40}
41
42impl Default for CPUOffloadConfig {
43 fn default() -> Self {
44 Self {
45 offload_optimizer_states: true,
46 offload_gradients: false,
47 offload_parameters: false,
48 overlap_transfers: true,
49 pin_memory: true,
50 offload_threshold: 1024 * 1024, }
52 }
53}
54
55pub struct CPUOffloadedOptimizer<T: Optimizer> {
57 base_optimizer: T,
58 config: CPUOffloadConfig,
59 cpu_states: HashMap<String, Tensor>,
60 gpu_states: HashMap<String, Tensor>,
61 #[allow(dead_code)]
62 transfer_stream: Option<usize>, memory_stats: CPUOffloadStats,
64}
65
/// Counters describing offloading activity and tracked memory residency.
#[derive(Debug, Default)]
pub struct CPUOffloadStats {
    /// Bytes of state currently accounted as resident in host memory.
    pub total_cpu_memory_bytes: usize,
    /// Bytes of state currently accounted as resident in device memory.
    pub total_gpu_memory_bytes: usize,
    /// Number of device-to-host transfers performed.
    pub transfers_to_cpu: usize,
    /// Number of host-to-device transfers performed.
    pub transfers_to_gpu: usize,
    /// Cumulative wall-clock time spent in transfers, in milliseconds.
    pub transfer_time_ms: f64,
}
74
75impl<T: Optimizer + StatefulOptimizer> CPUOffloadedOptimizer<T> {
76 pub fn new(base_optimizer: T, config: CPUOffloadConfig) -> Self {
78 Self {
79 base_optimizer,
80 config,
81 cpu_states: HashMap::new(),
82 gpu_states: HashMap::new(),
83 transfer_stream: None,
84 memory_stats: CPUOffloadStats::default(),
85 }
86 }
87
88 pub fn with_default_config(base_optimizer: T) -> Self {
90 Self::new(base_optimizer, CPUOffloadConfig::default())
91 }
92
93 pub fn get_memory_stats(&self) -> &CPUOffloadStats {
95 &self.memory_stats
96 }
97
98 pub fn get_memory_savings_bytes(&self) -> usize {
100 self.memory_stats.total_cpu_memory_bytes
101 }
102
103 pub fn get_memory_savings_percent(&self) -> f32 {
105 let total_memory =
106 self.memory_stats.total_cpu_memory_bytes + self.memory_stats.total_gpu_memory_bytes;
107 if total_memory == 0 {
108 0.0
109 } else {
110 (self.memory_stats.total_cpu_memory_bytes as f32 / total_memory as f32) * 100.0
111 }
112 }
113
114 #[allow(dead_code)]
116 fn offload_to_cpu(&mut self, key: &str, tensor: Tensor) -> Result<()> {
117 if tensor.size_bytes() >= self.config.offload_threshold {
118 let start_time = std::time::Instant::now();
119
120 let cpu_tensor = tensor.to_device("cpu")?;
122 self.cpu_states.insert(key.to_string(), cpu_tensor);
123
124 self.memory_stats.total_cpu_memory_bytes += tensor.size_bytes();
126 self.memory_stats.transfers_to_cpu += 1;
127 self.memory_stats.transfer_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
128
129 if let Some(gpu_tensor) = self.gpu_states.remove(key) {
131 self.memory_stats.total_gpu_memory_bytes -= gpu_tensor.size_bytes();
132 }
133 } else {
134 self.memory_stats.total_gpu_memory_bytes += tensor.size_bytes();
136 self.gpu_states.insert(key.to_string(), tensor);
137 }
138
139 Ok(())
140 }
141
142 fn retrieve_from_cpu(&mut self, key: &str, target_device: &str) -> Result<Option<Tensor>> {
144 if let Some(cpu_tensor) = self.cpu_states.get(key) {
145 let start_time = std::time::Instant::now();
146
147 let gpu_tensor = cpu_tensor.to_device(target_device)?;
149 let tensor_size = gpu_tensor.size_bytes();
150
151 self.gpu_states.insert(key.to_string(), gpu_tensor.clone());
153
154 self.memory_stats.total_gpu_memory_bytes += tensor_size;
156 self.memory_stats.transfers_to_gpu += 1;
157 self.memory_stats.transfer_time_ms += start_time.elapsed().as_secs_f64() * 1000.0;
158
159 Ok(Some(gpu_tensor))
160 } else {
161 Ok(self.gpu_states.get(key).cloned())
163 }
164 }
165
166 pub fn prefetch_states(&mut self, keys: &[String], device: &str) -> Result<()> {
168 if !self.config.overlap_transfers {
169 return Ok(());
170 }
171
172 for key in keys {
173 if self.cpu_states.contains_key(key) && !self.gpu_states.contains_key(key) {
174 self.retrieve_from_cpu(key, device)?;
176 }
177 }
178
179 Ok(())
180 }
181
182 pub fn evict_unused_states(&mut self, keep_keys: &[String]) -> Result<()> {
184 let mut to_remove = Vec::new();
185
186 for key in self.gpu_states.keys() {
187 if !keep_keys.contains(&key.to_string()) && self.cpu_states.contains_key(key) {
188 to_remove.push(key.clone());
189 }
190 }
191
192 for key in to_remove {
193 if let Some(tensor) = self.gpu_states.remove(&key) {
194 self.memory_stats.total_gpu_memory_bytes -= tensor.size_bytes();
195 }
196 }
197
198 Ok(())
199 }
200
201 pub fn get_config(&self) -> &CPUOffloadConfig {
203 &self.config
204 }
205
206 pub fn set_config(&mut self, config: CPUOffloadConfig) {
208 self.config = config;
209 }
210}
211
212impl<T: Optimizer> Optimizer for CPUOffloadedOptimizer<T> {
213 fn update(&mut self, parameter: &mut Tensor, grad: &Tensor) -> Result<()> {
214 self.base_optimizer.update(parameter, grad)
215 }
216
217 fn zero_grad(&mut self) {
218 self.base_optimizer.zero_grad()
219 }
220
221 fn step(&mut self) {
222 self.base_optimizer.step()
223 }
224
225 fn get_lr(&self) -> f32 {
226 self.base_optimizer.get_lr()
227 }
228
229 fn set_lr(&mut self, lr: f32) {
230 self.base_optimizer.set_lr(lr)
231 }
232}
233
234impl<T: Optimizer + StatefulOptimizer> CPUOffloadedOptimizer<T> {
235 #[allow(dead_code)]
236 fn state_dict(&self) -> Result<HashMap<String, Tensor>> {
237 let mut state = self.base_optimizer.state_dict()?;
239
240 for (key, tensor) in &self.cpu_states {
242 state.insert(format!("cpu_{}", key), tensor.clone());
243 }
244
245 Ok(state)
246 }
247
248 #[allow(dead_code)]
249 fn load_state_dict(&mut self, state: HashMap<String, Tensor>) -> Result<()> {
250 let mut base_state = HashMap::new();
251 let mut cpu_state = HashMap::new();
252
253 for (key, tensor) in state {
255 if let Some(cpu_key) = key.strip_prefix("cpu_") {
256 cpu_state.insert(cpu_key.to_string(), tensor);
257 } else {
258 base_state.insert(key, tensor);
259 }
260 }
261
262 self.base_optimizer.load_state_dict(base_state)?;
264
265 self.cpu_states = cpu_state;
267
268 Ok(())
269 }
270}
271
272impl<T: Optimizer + StatefulOptimizer> CPUOffloadedOptimizer<T> {
273 #[allow(dead_code)]
276 fn offload_states_after_step(&mut self, param_names: &[String]) -> Result<()> {
277 if !self.config.offload_optimizer_states {
278 return Ok(());
279 }
280
281 let current_states = self.base_optimizer.state_dict()?;
283
284 for param_name in param_names {
286 for (state_key, state_tensor) in ¤t_states {
288 if state_key.starts_with(param_name) || state_key.contains(param_name) {
291 if state_tensor.size_bytes() >= self.config.offload_threshold {
293 let device = state_tensor.device();
294
295 if device.starts_with("cuda") || device.starts_with("gpu") {
297 self.offload_to_cpu(state_key, state_tensor.clone())?;
298
299 }
301 }
302 }
303 }
304
305 let keys_to_offload: Vec<String> = self
307 .gpu_states
308 .keys()
309 .filter(|key| key.starts_with(param_name) || key.contains(param_name))
310 .cloned()
311 .collect();
312
313 for key in keys_to_offload {
314 if let Some(gpu_tensor) = self.gpu_states.get(&key).cloned() {
315 self.offload_to_cpu(&key, gpu_tensor)?;
316 }
317 }
318 }
319
320 Ok(())
321 }
322}
323
324pub fn create_cpu_offloaded_adam(
326 learning_rate: f32,
327 beta1: f32,
328 beta2: f32,
329 epsilon: f32,
330 weight_decay: f32,
331 config: Option<CPUOffloadConfig>,
332) -> CPUOffloadedOptimizer<crate::adam::Adam> {
333 let adam = crate::adam::Adam::new(learning_rate, (beta1, beta2), epsilon, weight_decay);
334 CPUOffloadedOptimizer::new(adam, config.unwrap_or_default())
335}
336
337pub fn create_cpu_offloaded_adamw(
339 learning_rate: f32,
340 beta1: f32,
341 beta2: f32,
342 epsilon: f32,
343 weight_decay: f32,
344 config: Option<CPUOffloadConfig>,
345) -> CPUOffloadedOptimizer<crate::adam::AdamW> {
346 let adamw = crate::adam::AdamW::new(learning_rate, (beta1, beta2), epsilon, weight_decay);
347 CPUOffloadedOptimizer::new(adamw, config.unwrap_or_default())
348}
349
350pub fn create_cpu_offloaded_sgd(
352 learning_rate: f32,
353 momentum: f32,
354 _dampening: f32,
355 weight_decay: f32,
356 nesterov: bool,
357 config: Option<CPUOffloadConfig>,
358) -> CPUOffloadedOptimizer<crate::sgd::SGD> {
359 let sgd = crate::sgd::SGD::new(learning_rate, momentum, weight_decay, nesterov);
360 CPUOffloadedOptimizer::new(sgd, config.unwrap_or_default())
361}
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults offload only optimizer states, with a 1 MiB threshold.
    #[test]
    fn test_cpu_offload_config_default() {
        let config = CPUOffloadConfig::default();
        assert!(config.offload_optimizer_states);
        assert!(!config.offload_gradients);
        assert!(!config.offload_parameters);
        assert!(config.overlap_transfers);
        assert!(config.pin_memory);
        assert_eq!(config.offload_threshold, 1024 * 1024);
    }

    /// A freshly-built wrapper reports all-zero counters.
    #[test]
    fn test_memory_stats() {
        let adam = crate::adam::Adam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let optimizer = CPUOffloadedOptimizer::new(adam, CPUOffloadConfig::default());

        let stats = optimizer.get_memory_stats();
        assert_eq!(stats.total_cpu_memory_bytes, 0);
        assert_eq!(stats.total_gpu_memory_bytes, 0);
        assert_eq!(stats.transfers_to_cpu, 0);
        assert_eq!(stats.transfers_to_gpu, 0);
    }

    /// With no tracked memory, savings are zero (and must not divide by zero).
    #[test]
    fn test_memory_savings_calculation() {
        let adam = crate::adam::Adam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let optimizer = CPUOffloadedOptimizer::new(adam, CPUOffloadConfig::default());

        assert_eq!(optimizer.get_memory_savings_percent(), 0.0);
        assert_eq!(optimizer.get_memory_savings_bytes(), 0);
    }

    /// The convenience constructors build without panicking.
    #[test]
    fn test_convenience_functions() {
        let _adam_offload = create_cpu_offloaded_adam(1e-3, 0.9, 0.999, 1e-8, 0.01, None);
        let _adamw_offload = create_cpu_offloaded_adamw(1e-3, 0.9, 0.999, 1e-8, 0.01, None);
        let _sgd_offload = create_cpu_offloaded_sgd(1e-2, 0.9, 0.0, 1e-4, false, None);
    }

    /// `set_config` replaces the active configuration.
    #[test]
    fn test_config_update() {
        let adam = crate::adam::Adam::new(1e-3, (0.9, 0.999), 1e-8, 0.01);
        let mut optimizer = CPUOffloadedOptimizer::new(adam, CPUOffloadConfig::default());

        // Struct-update syntax instead of mutating a default in place
        // (clippy::field_reassign_with_default); the clone the original
        // passed to set_config was never needed.
        let new_config = CPUOffloadConfig {
            offload_gradients: true,
            offload_threshold: 2048,
            ..CPUOffloadConfig::default()
        };

        optimizer.set_config(new_config);

        assert!(optimizer.get_config().offload_gradients);
        assert_eq!(optimizer.get_config().offload_threshold, 2048);
    }
}