Skip to main content

tensorlogic_trustformers/
checkpointing.rs

1//! Gradient checkpointing for memory-efficient training.
2//!
3//! Gradient checkpointing trades compute for memory by recomputing activations
4//! during the backward pass instead of storing them. This allows training much
5//! larger models or using larger batch sizes.
6//!
7//! ## How It Works
8//!
9//! Without checkpointing:
10//! ```text
11//! Forward: x → layer1 → layer2 → layer3 → loss
12//!          ↓     ↓       ↓       ↓
13//!        store  store   store   store (memory)
14//! ```
15//!
16//! With checkpointing:
17//! ```text
18//! Forward: x → layer1 → [checkpoint] → layer2 → [checkpoint] → layer3 → loss
19//!          ↓                             ↓                       ↓
20//!        store                         store                   store
21//!
22//! Backward: Recompute layer1 and layer2 activations as needed
23//! ```
24//!
25//! ## Usage
26//!
27//! ```rust
28//! use tensorlogic_trustformers::{CheckpointConfig, CheckpointStrategy};
29//!
30//! // Checkpoint every 2 layers
31//! let config = CheckpointConfig::uniform(2);
32//!
33//! // Checkpoint specific layers
34//! let config = CheckpointConfig::selective(vec![0, 3, 6, 9]);
35//!
//! // Dynamic checkpointing: uniform spacing derived from a target memory fraction
//! let config = CheckpointConfig::dynamic(12, 0.3).unwrap(); // 30% memory target
38//! ```
39
40use serde::{Deserialize, Serialize};
41
42use crate::error::{Result, TrustformerError};
43
/// Gradient checkpointing configuration.
///
/// Combines a [`CheckpointStrategy`] (which layers are checkpointed) with per-sublayer
/// flags. NOTE(review): the `checkpoint_attention` / `checkpoint_ffn` flags are not read
/// anywhere in this module; presumably the layer implementations consult them — confirm.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CheckpointConfig {
    /// Checkpointing strategy
    pub strategy: CheckpointStrategy,
    /// Whether to checkpoint attention sublayers
    pub checkpoint_attention: bool,
    /// Whether to checkpoint feed-forward sublayers
    pub checkpoint_ffn: bool,
    /// Minimum layers between checkpoints.
    ///
    /// Within this module only the `Dynamic` strategy reads this value, as a lower bound
    /// on its computed checkpoint interval; `Uniform` and `Selective` ignore it.
    pub min_checkpoint_interval: usize,
}
56
/// Strategy for placing gradient checkpoints
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum CheckpointStrategy {
    /// No checkpointing
    None,
    /// Checkpoint every N layers uniformly (layers 0, N, 2N, ...).
    /// An interval of 0 checkpoints nothing and fails `CheckpointConfig::validate`.
    Uniform { interval: usize },
    /// Checkpoint specific layer indices.
    /// Duplicate indices fail `CheckpointConfig::validate`.
    Selective { layers: Vec<usize> },
    /// Uniformly spaced checkpoints whose spacing is derived from a memory budget
    Dynamic {
        /// Total number of layers (validation requires > 0)
        num_layers: usize,
        /// Target memory fraction (0.0 - 1.0)
        memory_fraction: f64,
    },
}
74
75impl CheckpointConfig {
76    /// Create a uniform checkpointing configuration
77    ///
78    /// # Arguments
79    /// * `interval` - Checkpoint every N layers (e.g., 2 means checkpoint layers 0, 2, 4, ...)
80    pub fn uniform(interval: usize) -> Self {
81        Self {
82            strategy: CheckpointStrategy::Uniform { interval },
83            checkpoint_attention: true,
84            checkpoint_ffn: true,
85            min_checkpoint_interval: 1,
86        }
87    }
88
89    /// Create a selective checkpointing configuration
90    ///
91    /// # Arguments
92    /// * `layers` - Specific layer indices to checkpoint
93    pub fn selective(layers: Vec<usize>) -> Self {
94        Self {
95            strategy: CheckpointStrategy::Selective { layers },
96            checkpoint_attention: true,
97            checkpoint_ffn: true,
98            min_checkpoint_interval: 1,
99        }
100    }
101
102    /// Create a dynamic checkpointing configuration
103    ///
104    /// # Arguments
105    /// * `num_layers` - Total number of layers in the model
106    /// * `memory_fraction` - Target memory usage as fraction of full storage (0.0 - 1.0)
107    pub fn dynamic(num_layers: usize, memory_fraction: f64) -> Result<Self> {
108        if num_layers == 0 {
109            return Err(TrustformerError::InvalidDimension {
110                expected: 1,
111                got: 0,
112                context: "num_layers must be > 0".to_string(),
113            });
114        }
115
116        if !(0.0..=1.0).contains(&memory_fraction) {
117            return Err(TrustformerError::InvalidDimension {
118                expected: 1,
119                got: 0,
120                context: format!(
121                    "memory_fraction must be in [0.0, 1.0], got {}",
122                    memory_fraction
123                ),
124            });
125        }
126
127        Ok(Self {
128            strategy: CheckpointStrategy::Dynamic {
129                num_layers,
130                memory_fraction,
131            },
132            checkpoint_attention: true,
133            checkpoint_ffn: true,
134            min_checkpoint_interval: 1,
135        })
136    }
137
138    /// Disable checkpointing
139    pub fn none() -> Self {
140        Self {
141            strategy: CheckpointStrategy::None,
142            checkpoint_attention: false,
143            checkpoint_ffn: false,
144            min_checkpoint_interval: 1,
145        }
146    }
147
148    /// Set whether to checkpoint attention sublayers
149    pub fn with_checkpoint_attention(mut self, checkpoint: bool) -> Self {
150        self.checkpoint_attention = checkpoint;
151        self
152    }
153
154    /// Set whether to checkpoint feed-forward sublayers
155    pub fn with_checkpoint_ffn(mut self, checkpoint: bool) -> Self {
156        self.checkpoint_ffn = checkpoint;
157        self
158    }
159
160    /// Set minimum interval between checkpoints
161    pub fn with_min_interval(mut self, interval: usize) -> Self {
162        self.min_checkpoint_interval = interval;
163        self
164    }
165
166    /// Check if a specific layer should be checkpointed
167    pub fn should_checkpoint(&self, layer_idx: usize) -> bool {
168        match &self.strategy {
169            CheckpointStrategy::None => false,
170            CheckpointStrategy::Uniform { interval } => {
171                *interval > 0 && layer_idx.is_multiple_of(*interval)
172            }
173            CheckpointStrategy::Selective { layers } => layers.contains(&layer_idx),
174            CheckpointStrategy::Dynamic {
175                num_layers,
176                memory_fraction,
177            } => {
178                // Calculate optimal checkpoint interval for target memory fraction
179                // Memory without checkpointing: O(n * d^2) for n layers
180                // Memory with checkpointing every k layers: O(k * d^2)
181                // Target: k * d^2 = memory_fraction * n * d^2
182                // Therefore: k = memory_fraction * n
183
184                if *num_layers == 0 {
185                    return false;
186                }
187
188                let target_interval = (*memory_fraction * *num_layers as f64).max(1.0) as usize;
189                let interval = target_interval.max(self.min_checkpoint_interval);
190                interval > 0 && layer_idx.is_multiple_of(interval)
191            }
192        }
193    }
194
195    /// Calculate expected memory savings
196    ///
197    /// Returns the fraction of activation memory saved (0.0 - 1.0)
198    pub fn memory_savings(&self, num_layers: usize) -> f64 {
199        if num_layers == 0 {
200            return 0.0;
201        }
202
203        match &self.strategy {
204            CheckpointStrategy::None => 0.0,
205            CheckpointStrategy::Uniform { interval } => {
206                let interval_val = *interval;
207                if interval_val == 0 || interval_val >= num_layers {
208                    return 0.0;
209                }
210                // We store activations at checkpoint boundaries only
211                let num_checkpoints = num_layers.div_ceil(interval_val);
212                1.0 - (num_checkpoints as f64 / num_layers as f64)
213            }
214            CheckpointStrategy::Selective { layers } => {
215                if layers.is_empty() {
216                    return 0.0;
217                }
218                1.0 - (layers.len() as f64 / num_layers as f64)
219            }
220            CheckpointStrategy::Dynamic {
221                memory_fraction, ..
222            } => {
223                // Dynamic strategy aims to use memory_fraction of full storage
224                1.0 - memory_fraction
225            }
226        }
227    }
228
229    /// Calculate expected compute overhead
230    ///
231    /// Returns the multiplicative factor for compute (1.0 = no overhead)
232    pub fn compute_overhead(&self, num_layers: usize) -> f64 {
233        if num_layers == 0 {
234            return 1.0;
235        }
236
237        match &self.strategy {
238            CheckpointStrategy::None => 1.0,
239            CheckpointStrategy::Uniform { interval } => {
240                if *interval == 0 || *interval >= num_layers {
241                    return 1.0;
242                }
243                // We recompute layers between checkpoints during backward pass
244                // Each layer is computed once in forward, and segments are recomputed in backward
245                // Overhead ≈ 1 + (average segment length / 2)
246                1.0 + (*interval as f64 / 2.0) / num_layers as f64
247            }
248            CheckpointStrategy::Selective { layers } => {
249                if layers.is_empty() {
250                    return 1.0;
251                }
252                // Average interval between checkpoints
253                let avg_interval = num_layers as f64 / layers.len() as f64;
254                1.0 + (avg_interval / 2.0) / num_layers as f64
255            }
256            CheckpointStrategy::Dynamic {
257                memory_fraction, ..
258            } => {
259                // Compute overhead scales with memory savings
260                1.0 + (1.0 - memory_fraction) * 0.3 // ~30% overhead for full checkpointing
261            }
262        }
263    }
264
265    /// Validate configuration
266    pub fn validate(&self) -> Result<()> {
267        match &self.strategy {
268            CheckpointStrategy::None => Ok(()),
269            CheckpointStrategy::Uniform { interval } => {
270                if *interval == 0 {
271                    return Err(TrustformerError::InvalidDimension {
272                        expected: 1,
273                        got: 0,
274                        context: "checkpoint interval must be > 0".to_string(),
275                    });
276                }
277                Ok(())
278            }
279            CheckpointStrategy::Selective { layers } => {
280                // Check for duplicates
281                let mut sorted = layers.clone();
282                sorted.sort_unstable();
283                sorted.dedup();
284                if sorted.len() != layers.len() {
285                    return Err(TrustformerError::InvalidDimension {
286                        expected: sorted.len(),
287                        got: layers.len(),
288                        context: "duplicate layer indices in selective checkpointing".to_string(),
289                    });
290                }
291                Ok(())
292            }
293            CheckpointStrategy::Dynamic {
294                num_layers,
295                memory_fraction,
296            } => {
297                if *num_layers == 0 {
298                    return Err(TrustformerError::InvalidDimension {
299                        expected: 1,
300                        got: 0,
301                        context: "num_layers must be > 0".to_string(),
302                    });
303                }
304                if !(0.0..=1.0).contains(memory_fraction) {
305                    return Err(TrustformerError::InvalidDimension {
306                        expected: 1,
307                        got: 0,
308                        context: format!(
309                            "memory_fraction must be in [0.0, 1.0], got {}",
310                            memory_fraction
311                        ),
312                    });
313                }
314                Ok(())
315            }
316        }
317    }
318
319    /// Get human-readable summary
320    pub fn summary(&self) -> String {
321        match &self.strategy {
322            CheckpointStrategy::None => "No checkpointing".to_string(),
323            CheckpointStrategy::Uniform { interval } => {
324                format!("Uniform checkpointing every {} layers", interval)
325            }
326            CheckpointStrategy::Selective { layers } => {
327                format!("Selective checkpointing at {} layers", layers.len())
328            }
329            CheckpointStrategy::Dynamic {
330                num_layers,
331                memory_fraction,
332            } => {
333                format!(
334                    "Dynamic checkpointing ({} layers, {:.1}% memory target)",
335                    num_layers,
336                    memory_fraction * 100.0
337                )
338            }
339        }
340    }
341}
342
343impl Default for CheckpointConfig {
344    fn default() -> Self {
345        Self::none()
346    }
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_uniform_checkpointing() {
        let config = CheckpointConfig::uniform(2);
        let expectations = [(0, true), (1, false), (2, true), (3, false), (4, true)];
        for (layer, expected) in expectations {
            assert_eq!(config.should_checkpoint(layer), expected, "layer {layer}");
        }
    }

    #[test]
    fn test_selective_checkpointing() {
        let config = CheckpointConfig::selective(vec![0, 3, 7]);
        let expectations = [
            (0, true),
            (1, false),
            (2, false),
            (3, true),
            (6, false),
            (7, true),
        ];
        for (layer, expected) in expectations {
            assert_eq!(config.should_checkpoint(layer), expected, "layer {layer}");
        }
    }

    #[test]
    fn test_dynamic_checkpointing() {
        let config = CheckpointConfig::dynamic(12, 0.3).unwrap();
        assert!(config.validate().is_ok());

        // Some, but not all, of the 12 layers should be checkpointed.
        let checkpointed_count = (0..12)
            .filter(|&layer| config.should_checkpoint(layer))
            .count();
        assert!(checkpointed_count > 0 && checkpointed_count < 12);
    }

    #[test]
    fn test_no_checkpointing() {
        let config = CheckpointConfig::none();
        for layer in [0, 5, 10] {
            assert!(!config.should_checkpoint(layer), "layer {layer}");
        }
    }

    #[test]
    fn test_memory_savings_uniform() {
        // Interval 3 over 12 layers stores layers 0, 3, 6, 9: 4 of the 12.
        let savings = CheckpointConfig::uniform(3).memory_savings(12);
        assert!((savings - 2.0 / 3.0).abs() < 0.01);
    }

    #[test]
    fn test_memory_savings_selective() {
        // Two stored layers out of twelve.
        let savings = CheckpointConfig::selective(vec![0, 6]).memory_savings(12);
        assert!((savings - 10.0 / 12.0).abs() < 0.01);
    }

    #[test]
    fn test_compute_overhead() {
        let overhead = CheckpointConfig::uniform(2).compute_overhead(12);
        // Expect a modest overhead: at least 1x and below 2x.
        assert!((1.0..2.0).contains(&overhead));
    }

    #[test]
    fn test_invalid_dynamic_memory_fraction() {
        for fraction in [1.5, -0.1] {
            assert!(
                CheckpointConfig::dynamic(12, fraction).is_err(),
                "fraction {fraction}"
            );
        }
    }

    #[test]
    fn test_builder_pattern() {
        let built = CheckpointConfig::uniform(2)
            .with_checkpoint_attention(false)
            .with_checkpoint_ffn(true)
            .with_min_interval(2);

        assert!(!built.checkpoint_attention);
        assert!(built.checkpoint_ffn);
        assert_eq!(built.min_checkpoint_interval, 2);
    }

    #[test]
    fn test_validate_uniform() {
        assert!(CheckpointConfig::uniform(2).validate().is_ok());
        assert!(CheckpointConfig::uniform(0).validate().is_err());
    }

    #[test]
    fn test_validate_selective_duplicates() {
        let duplicated = CheckpointConfig::selective(vec![0, 3, 3, 7]);
        assert!(duplicated.validate().is_err());
    }

    #[test]
    fn test_summary() {
        let cases = [
            (CheckpointConfig::uniform(2), "every 2 layers"),
            (CheckpointConfig::selective(vec![0, 3, 7]), "3 layers"),
            (CheckpointConfig::dynamic(12, 0.3).unwrap(), "30.0%"),
        ];
        for (config, needle) in cases {
            assert!(config.summary().contains(needle), "missing {needle:?}");
        }
    }

    #[test]
    fn test_default() {
        let config = CheckpointConfig::default();
        assert_eq!(config.strategy, CheckpointStrategy::None);
        assert!(!config.should_checkpoint(0));
    }

    #[test]
    fn test_zero_interval_uniform() {
        let config = CheckpointConfig::uniform(0);
        for layer in [0, 1] {
            assert!(!config.should_checkpoint(layer), "layer {layer}");
        }
    }

    #[test]
    fn test_dynamic_zero_layers() {
        assert!(CheckpointConfig::dynamic(0, 0.5).is_err());
    }

    #[test]
    fn test_memory_savings_edge_cases() {
        // A model with no layers saves nothing.
        assert_eq!(CheckpointConfig::uniform(2).memory_savings(0), 0.0);

        // An interval at or beyond the model depth saves nothing.
        assert_eq!(CheckpointConfig::uniform(20).memory_savings(10), 0.0);

        // An empty selective list saves nothing.
        assert_eq!(CheckpointConfig::selective(vec![]).memory_savings(10), 0.0);
    }
}