tensorlogic_trustformers/
checkpointing.rs1use serde::{Deserialize, Serialize};
41
42use crate::error::{Result, TrustformerError};
43
44#[derive(Clone, Debug, Serialize, Deserialize)]
46pub struct CheckpointConfig {
47 pub strategy: CheckpointStrategy,
49 pub checkpoint_attention: bool,
51 pub checkpoint_ffn: bool,
53 pub min_checkpoint_interval: usize,
55}
56
57#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
59pub enum CheckpointStrategy {
60 None,
62 Uniform { interval: usize },
64 Selective { layers: Vec<usize> },
66 Dynamic {
68 num_layers: usize,
70 memory_fraction: f64,
72 },
73}
74
75impl CheckpointConfig {
76 pub fn uniform(interval: usize) -> Self {
81 Self {
82 strategy: CheckpointStrategy::Uniform { interval },
83 checkpoint_attention: true,
84 checkpoint_ffn: true,
85 min_checkpoint_interval: 1,
86 }
87 }
88
89 pub fn selective(layers: Vec<usize>) -> Self {
94 Self {
95 strategy: CheckpointStrategy::Selective { layers },
96 checkpoint_attention: true,
97 checkpoint_ffn: true,
98 min_checkpoint_interval: 1,
99 }
100 }
101
102 pub fn dynamic(num_layers: usize, memory_fraction: f64) -> Result<Self> {
108 if num_layers == 0 {
109 return Err(TrustformerError::InvalidDimension {
110 expected: 1,
111 got: 0,
112 context: "num_layers must be > 0".to_string(),
113 });
114 }
115
116 if !(0.0..=1.0).contains(&memory_fraction) {
117 return Err(TrustformerError::InvalidDimension {
118 expected: 1,
119 got: 0,
120 context: format!(
121 "memory_fraction must be in [0.0, 1.0], got {}",
122 memory_fraction
123 ),
124 });
125 }
126
127 Ok(Self {
128 strategy: CheckpointStrategy::Dynamic {
129 num_layers,
130 memory_fraction,
131 },
132 checkpoint_attention: true,
133 checkpoint_ffn: true,
134 min_checkpoint_interval: 1,
135 })
136 }
137
138 pub fn none() -> Self {
140 Self {
141 strategy: CheckpointStrategy::None,
142 checkpoint_attention: false,
143 checkpoint_ffn: false,
144 min_checkpoint_interval: 1,
145 }
146 }
147
148 pub fn with_checkpoint_attention(mut self, checkpoint: bool) -> Self {
150 self.checkpoint_attention = checkpoint;
151 self
152 }
153
154 pub fn with_checkpoint_ffn(mut self, checkpoint: bool) -> Self {
156 self.checkpoint_ffn = checkpoint;
157 self
158 }
159
160 pub fn with_min_interval(mut self, interval: usize) -> Self {
162 self.min_checkpoint_interval = interval;
163 self
164 }
165
166 pub fn should_checkpoint(&self, layer_idx: usize) -> bool {
168 match &self.strategy {
169 CheckpointStrategy::None => false,
170 CheckpointStrategy::Uniform { interval } => {
171 *interval > 0 && layer_idx.is_multiple_of(*interval)
172 }
173 CheckpointStrategy::Selective { layers } => layers.contains(&layer_idx),
174 CheckpointStrategy::Dynamic {
175 num_layers,
176 memory_fraction,
177 } => {
178 if *num_layers == 0 {
185 return false;
186 }
187
188 let target_interval = (*memory_fraction * *num_layers as f64).max(1.0) as usize;
189 let interval = target_interval.max(self.min_checkpoint_interval);
190 interval > 0 && layer_idx.is_multiple_of(interval)
191 }
192 }
193 }
194
195 pub fn memory_savings(&self, num_layers: usize) -> f64 {
199 if num_layers == 0 {
200 return 0.0;
201 }
202
203 match &self.strategy {
204 CheckpointStrategy::None => 0.0,
205 CheckpointStrategy::Uniform { interval } => {
206 let interval_val = *interval;
207 if interval_val == 0 || interval_val >= num_layers {
208 return 0.0;
209 }
210 let num_checkpoints = num_layers.div_ceil(interval_val);
212 1.0 - (num_checkpoints as f64 / num_layers as f64)
213 }
214 CheckpointStrategy::Selective { layers } => {
215 if layers.is_empty() {
216 return 0.0;
217 }
218 1.0 - (layers.len() as f64 / num_layers as f64)
219 }
220 CheckpointStrategy::Dynamic {
221 memory_fraction, ..
222 } => {
223 1.0 - memory_fraction
225 }
226 }
227 }
228
229 pub fn compute_overhead(&self, num_layers: usize) -> f64 {
233 if num_layers == 0 {
234 return 1.0;
235 }
236
237 match &self.strategy {
238 CheckpointStrategy::None => 1.0,
239 CheckpointStrategy::Uniform { interval } => {
240 if *interval == 0 || *interval >= num_layers {
241 return 1.0;
242 }
243 1.0 + (*interval as f64 / 2.0) / num_layers as f64
247 }
248 CheckpointStrategy::Selective { layers } => {
249 if layers.is_empty() {
250 return 1.0;
251 }
252 let avg_interval = num_layers as f64 / layers.len() as f64;
254 1.0 + (avg_interval / 2.0) / num_layers as f64
255 }
256 CheckpointStrategy::Dynamic {
257 memory_fraction, ..
258 } => {
259 1.0 + (1.0 - memory_fraction) * 0.3 }
262 }
263 }
264
265 pub fn validate(&self) -> Result<()> {
267 match &self.strategy {
268 CheckpointStrategy::None => Ok(()),
269 CheckpointStrategy::Uniform { interval } => {
270 if *interval == 0 {
271 return Err(TrustformerError::InvalidDimension {
272 expected: 1,
273 got: 0,
274 context: "checkpoint interval must be > 0".to_string(),
275 });
276 }
277 Ok(())
278 }
279 CheckpointStrategy::Selective { layers } => {
280 let mut sorted = layers.clone();
282 sorted.sort_unstable();
283 sorted.dedup();
284 if sorted.len() != layers.len() {
285 return Err(TrustformerError::InvalidDimension {
286 expected: sorted.len(),
287 got: layers.len(),
288 context: "duplicate layer indices in selective checkpointing".to_string(),
289 });
290 }
291 Ok(())
292 }
293 CheckpointStrategy::Dynamic {
294 num_layers,
295 memory_fraction,
296 } => {
297 if *num_layers == 0 {
298 return Err(TrustformerError::InvalidDimension {
299 expected: 1,
300 got: 0,
301 context: "num_layers must be > 0".to_string(),
302 });
303 }
304 if !(0.0..=1.0).contains(memory_fraction) {
305 return Err(TrustformerError::InvalidDimension {
306 expected: 1,
307 got: 0,
308 context: format!(
309 "memory_fraction must be in [0.0, 1.0], got {}",
310 memory_fraction
311 ),
312 });
313 }
314 Ok(())
315 }
316 }
317 }
318
319 pub fn summary(&self) -> String {
321 match &self.strategy {
322 CheckpointStrategy::None => "No checkpointing".to_string(),
323 CheckpointStrategy::Uniform { interval } => {
324 format!("Uniform checkpointing every {} layers", interval)
325 }
326 CheckpointStrategy::Selective { layers } => {
327 format!("Selective checkpointing at {} layers", layers.len())
328 }
329 CheckpointStrategy::Dynamic {
330 num_layers,
331 memory_fraction,
332 } => {
333 format!(
334 "Dynamic checkpointing ({} layers, {:.1}% memory target)",
335 num_layers,
336 memory_fraction * 100.0
337 )
338 }
339 }
340 }
341}
342
343impl Default for CheckpointConfig {
344 fn default() -> Self {
345 Self::none()
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `should_checkpoint` for each `(layer, expected)` pair.
    fn assert_layers(cfg: &CheckpointConfig, cases: &[(usize, bool)]) {
        for &(layer, expected) in cases {
            assert_eq!(cfg.should_checkpoint(layer), expected, "layer {}", layer);
        }
    }

    // Uniform strategy selects multiples of the interval, layer 0 included.
    #[test]
    fn test_uniform_checkpointing() {
        let cfg = CheckpointConfig::uniform(2);
        assert_layers(
            &cfg,
            &[(0, true), (1, false), (2, true), (3, false), (4, true)],
        );
    }

    // Selective strategy selects exactly the listed layers.
    #[test]
    fn test_selective_checkpointing() {
        let cfg = CheckpointConfig::selective(vec![0, 3, 7]);
        assert_layers(
            &cfg,
            &[
                (0, true),
                (1, false),
                (2, false),
                (3, true),
                (6, false),
                (7, true),
            ],
        );
    }

    // Dynamic strategy selects some, but not all, of the layers.
    #[test]
    fn test_dynamic_checkpointing() {
        let cfg = CheckpointConfig::dynamic(12, 0.3).unwrap();
        assert!(cfg.validate().is_ok());

        let mut selected = 0usize;
        for layer in 0..12 {
            if cfg.should_checkpoint(layer) {
                selected += 1;
            }
        }
        assert!(selected > 0);
        assert!(selected < 12);
    }

    // Disabled checkpointing never selects a layer.
    #[test]
    fn test_no_checkpointing() {
        let cfg = CheckpointConfig::none();
        assert_layers(&cfg, &[(0, false), (5, false), (10, false)]);
    }

    // 12 layers at interval 3 keep 4 checkpoints, saving two thirds.
    #[test]
    fn test_memory_savings_uniform() {
        let savings = CheckpointConfig::uniform(3).memory_savings(12);
        let expected = 2.0 / 3.0;
        assert!((savings - expected).abs() < 0.01);
    }

    // Two listed layers out of 12 save ten twelfths.
    #[test]
    fn test_memory_savings_selective() {
        let savings = CheckpointConfig::selective(vec![0, 6]).memory_savings(12);
        let expected = 10.0 / 12.0;
        assert!((savings - expected).abs() < 0.01);
    }

    // The overhead estimate stays within [1.0, 2.0).
    #[test]
    fn test_compute_overhead() {
        let overhead = CheckpointConfig::uniform(2).compute_overhead(12);
        assert!(overhead >= 1.0);
        assert!(overhead < 2.0);
    }

    // Fractions outside [0.0, 1.0] are rejected by the constructor.
    #[test]
    fn test_invalid_dynamic_memory_fraction() {
        for fraction in [1.5, -0.1] {
            assert!(CheckpointConfig::dynamic(12, fraction).is_err());
        }
    }

    // Builder methods overwrite the corresponding fields.
    #[test]
    fn test_builder_pattern() {
        let cfg = CheckpointConfig::uniform(2)
            .with_checkpoint_attention(false)
            .with_checkpoint_ffn(true)
            .with_min_interval(2);

        assert!(!cfg.checkpoint_attention);
        assert!(cfg.checkpoint_ffn);
        assert_eq!(cfg.min_checkpoint_interval, 2);
    }

    // A zero interval fails validation; a positive one passes.
    #[test]
    fn test_validate_uniform() {
        assert!(CheckpointConfig::uniform(2).validate().is_ok());
        assert!(CheckpointConfig::uniform(0).validate().is_err());
    }

    // Repeated layer indices fail validation.
    #[test]
    fn test_validate_selective_duplicates() {
        let cfg = CheckpointConfig::selective(vec![0, 3, 3, 7]);
        assert!(cfg.validate().is_err());
    }

    // Each strategy's summary mentions its key parameter.
    #[test]
    fn test_summary() {
        let uniform = CheckpointConfig::uniform(2).summary();
        assert!(uniform.contains("every 2 layers"));

        let selective = CheckpointConfig::selective(vec![0, 3, 7]).summary();
        assert!(selective.contains("3 layers"));

        let dynamic = CheckpointConfig::dynamic(12, 0.3).unwrap().summary();
        assert!(dynamic.contains("30.0%"));
    }

    // The default configuration disables checkpointing.
    #[test]
    fn test_default() {
        let cfg = CheckpointConfig::default();
        assert_eq!(cfg.strategy, CheckpointStrategy::None);
        assert!(!cfg.should_checkpoint(0));
    }

    // A zero interval never selects a layer, not even layer 0.
    #[test]
    fn test_zero_interval_uniform() {
        let cfg = CheckpointConfig::uniform(0);
        assert_layers(&cfg, &[(0, false), (1, false)]);
    }

    // Zero layers are rejected by the dynamic constructor.
    #[test]
    fn test_dynamic_zero_layers() {
        assert!(CheckpointConfig::dynamic(0, 0.5).is_err());
    }

    // Degenerate inputs all report zero savings.
    #[test]
    fn test_memory_savings_edge_cases() {
        // No layers at all.
        let cfg = CheckpointConfig::uniform(2);
        assert_eq!(cfg.memory_savings(0), 0.0);

        // Interval larger than the model.
        let cfg = CheckpointConfig::uniform(20);
        assert_eq!(cfg.memory_savings(10), 0.0);

        // Empty selective list.
        let cfg = CheckpointConfig::selective(vec![]);
        assert_eq!(cfg.memory_savings(10), 0.0);
    }
}