// scirs2_neural/training/backprop_efficient.rs

//! Memory-efficient backpropagation with gradient checkpointing
//!
//! This module provides memory-efficient backpropagation strategies that reduce
//! GPU memory usage by selectively checkpointing and recomputing activations.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

use crate::error::{NeuralError, Result};
#[cfg(feature = "gpu")]
use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType};
use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
13
/// Strategy deciding which activations are kept in memory during the forward
/// pass and which are recomputed during the backward pass.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecomputationPolicy {
    /// Keep every activation in memory (no recomputation, highest memory usage)
    CheckpointAll,
    /// Keep no activations (full recomputation, lowest memory usage)
    CheckpointNone,
    /// Selectively checkpoint based on a memory/computation tradeoff
    Selective {
        /// Checkpoint layers whose recomputation cost is at least this value
        cost_threshold: u32,
    },
    /// Checkpoint every N layers
    EveryN {
        /// Checkpoint frequency
        n: usize,
    },
}
32
33impl Default for RecomputationPolicy {
34    fn default() -> Self {
35        Self::Selective {
36            cost_threshold: 100,
37        }
38    }
39}
40
/// Bookkeeping record for a single checkpointed activation.
#[derive(Debug, Clone)]
pub struct ActivationCheckpoint {
    /// Layer this activation belongs to
    pub layer_id: usize,
    /// Monotonic creation order (used for oldest-first eviction)
    pub timestamp: u64,
    /// Size of the activation in bytes
    pub memory_size: usize,
    /// Estimated cost of recomputing this activation
    pub recomputation_cost: u32,
    /// Whether the checkpointed buffer is currently resident in memory
    pub in_memory: bool,
}
55
/// Gradient checkpoint manager (requires the `gpu` feature).
///
/// Owns the checkpointed GPU buffers plus their metadata, and tracks total
/// resident bytes against a fixed budget. All fields are wrapped in `Arc`
/// so the manager can be shared across threads.
#[cfg(feature = "gpu")]
pub struct GradientCheckpointManager<T: GpuDataType> {
    /// Checkpointed activations, keyed by layer id
    checkpoints: Arc<Mutex<HashMap<usize, GpuBuffer<T>>>>,
    /// Per-layer checkpoint metadata
    metadata: Arc<Mutex<HashMap<usize, ActivationCheckpoint>>>,
    /// Bytes currently held by in-memory checkpoints
    memory_usage: Arc<AtomicU64>,
    /// Maximum bytes of checkpoints allowed to be resident at once
    memory_budget: u64,
    /// Policy deciding which layers get checkpointed
    policy: RecomputationPolicy,
    /// Monotonic counter used to timestamp checkpoints
    checkpoint_counter: Arc<AtomicU64>,
    /// GPU context used to allocate checkpoint buffers
    gpu_context: Arc<GpuContext>,
}
74
#[cfg(feature = "gpu")]
impl<T: GpuDataType> GradientCheckpointManager<T> {
    /// Create a new gradient checkpoint manager.
    ///
    /// `memory_budget` is the maximum number of bytes of checkpointed
    /// activations kept resident at once; `policy` decides which layers are
    /// checkpointed at all.
    pub fn new(
        gpu_context: Arc<GpuContext>,
        memory_budget: u64,
        policy: RecomputationPolicy,
    ) -> Self {
        Self {
            checkpoints: Arc::new(Mutex::new(HashMap::new())),
            metadata: Arc::new(Mutex::new(HashMap::new())),
            memory_usage: Arc::new(AtomicU64::new(0)),
            memory_budget,
            policy,
            checkpoint_counter: Arc::new(AtomicU64::new(0)),
            gpu_context,
        }
    }

    /// Save an activation checkpoint for `layer_id`.
    ///
    /// Whether the activation is actually stored depends on the configured
    /// [`RecomputationPolicy`]. When storing would exceed the memory budget,
    /// the oldest in-memory checkpoints are evicted first; if nothing is left
    /// to evict, the new checkpoint is stored anyway (best effort).
    ///
    /// # Errors
    /// Returns an error if an internal lock is poisoned.
    pub fn checkpoint_activation(
        &self,
        layer_id: usize,
        activation: &GpuBuffer<T>,
        recomputation_cost: u32,
    ) -> Result<()> {
        let should_checkpoint = match self.policy {
            RecomputationPolicy::CheckpointAll => true,
            RecomputationPolicy::CheckpointNone => false,
            RecomputationPolicy::Selective { cost_threshold } => {
                recomputation_cost >= cost_threshold
            }
            RecomputationPolicy::EveryN { n } => layer_id.is_multiple_of(n),
        };

        if !should_checkpoint {
            return Ok(());
        }

        let activation_size = activation.len() * std::mem::size_of::<T>();

        // Evict oldest checkpoints until the new one fits within the budget.
        // A single eviction attempt (as before) could still leave usage above
        // the budget when one eviction was not enough.
        while self.memory_usage.load(Ordering::Relaxed) + activation_size as u64
            > self.memory_budget
        {
            if !self.evict_oldest_checkpoint()? {
                // Nothing left to evict; store the checkpoint anyway.
                break;
            }
        }

        // Lock order: checkpoints before metadata (same as remove_checkpoint)
        // to avoid lock-order inversion.
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

        let checkpoint_meta = ActivationCheckpoint {
            layer_id,
            timestamp: self.checkpoint_counter.fetch_add(1, Ordering::Relaxed),
            memory_size: activation_size,
            recomputation_cost,
            in_memory: true,
        };

        // Clone the buffer (in a real GPU implementation this would be a
        // device-to-device copy of `activation`'s contents).
        let checkpoint_buffer = self.gpu_context.create_buffer::<T>(activation.len());

        // If this layer was already checkpointed, release the replaced
        // buffer's accounting so the usage counter does not drift upwards.
        if let Some(old) = checkpoints.insert(layer_id, checkpoint_buffer) {
            let old_size = old.len() * std::mem::size_of::<T>();
            self.memory_usage
                .fetch_sub(old_size as u64, Ordering::Relaxed);
        }
        metadata.insert(layer_id, checkpoint_meta);

        self.memory_usage
            .fetch_add(activation_size as u64, Ordering::Relaxed);

        Ok(())
    }

    /// Retrieve a checkpointed activation, removing it from the manager and
    /// releasing its memory accounting.
    ///
    /// Returns `Ok(None)` when the layer has no checkpoint.
    pub fn get_checkpoint(&self, layer_id: usize) -> Result<Option<GpuBuffer<T>>> {
        let buffer = {
            let mut checkpoints = self.checkpoints.lock().map_err(|_| {
                NeuralError::TrainingError("Failed to lock checkpoints".to_string())
            })?;
            checkpoints.remove(&layer_id)
        };

        if let Some(ref buf) = buffer {
            // Keep the usage counter and metadata in sync with the map's
            // actual contents; previously removal here leaked the accounted
            // bytes and left `in_memory` set.
            let size = buf.len() * std::mem::size_of::<T>();
            self.memory_usage.fetch_sub(size as u64, Ordering::Relaxed);

            let mut metadata = self
                .metadata
                .lock()
                .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
            if let Some(meta) = metadata.get_mut(&layer_id) {
                meta.in_memory = false;
            }
        }

        Ok(buffer)
    }

    /// Check if a layer currently has an in-memory checkpoint.
    pub fn has_checkpoint(&self, layer_id: usize) -> bool {
        self.checkpoints
            .lock()
            .map(|cp| cp.contains_key(&layer_id))
            .unwrap_or(false)
    }

    /// Evict the oldest in-memory checkpoint to free memory.
    ///
    /// Returns `Ok(true)` if a checkpoint was evicted and `Ok(false)` if
    /// there was nothing to evict. The metadata guard is dropped before
    /// `remove_checkpoint` is called, because that method re-acquires the
    /// same mutex — holding the guard across the call self-deadlocks
    /// (std mutexes are not reentrant).
    fn evict_oldest_checkpoint(&self) -> Result<bool> {
        let oldest = {
            let metadata = self
                .metadata
                .lock()
                .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
            metadata
                .iter()
                .filter(|(_, meta)| meta.in_memory)
                .min_by_key(|(_, meta)| meta.timestamp)
                .map(|(id, _)| *id)
        }; // metadata guard dropped here

        match oldest {
            Some(layer_id) => {
                self.remove_checkpoint(layer_id)?;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Remove a checkpoint and free its accounted memory.
    ///
    /// The metadata entry is kept (with `in_memory = false`) so statistics
    /// still count the checkpoint as created.
    pub fn remove_checkpoint(&self, layer_id: usize) -> Result<()> {
        // Lock order: checkpoints before metadata, consistently with
        // checkpoint_activation.
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

        if let Some(checkpoint) = checkpoints.remove(&layer_id) {
            let size = checkpoint.len() * std::mem::size_of::<T>();
            self.memory_usage.fetch_sub(size as u64, Ordering::Relaxed);
        }

        if let Some(meta) = metadata.get_mut(&layer_id) {
            meta.in_memory = false;
        }

        Ok(())
    }

    /// Clear all checkpoints and reset the usage counter.
    pub fn clear(&self) -> Result<()> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

        checkpoints.clear();
        metadata.clear();
        self.memory_usage.store(0, Ordering::Relaxed);

        Ok(())
    }

    /// Get current memory usage in bytes.
    pub fn memory_usage(&self) -> u64 {
        self.memory_usage.load(Ordering::Relaxed)
    }

    /// Get the memory budget in bytes.
    pub fn memory_budget(&self) -> u64 {
        self.memory_budget
    }

    /// Get the number of in-memory checkpoints.
    pub fn num_checkpoints(&self) -> usize {
        self.checkpoints.lock().map(|cp| cp.len()).unwrap_or(0)
    }

    /// Get checkpoint statistics.
    ///
    /// A poisoned metadata lock is recovered (the data is still readable)
    /// rather than panicking, keeping this accessor panic-free.
    pub fn get_statistics(&self) -> CheckpointStatistics {
        let metadata = self
            .metadata
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let total_checkpoints = metadata.len();
        let in_memory_checkpoints = metadata.values().filter(|meta| meta.in_memory).count();

        let total_memory: u64 = metadata
            .values()
            .filter(|meta| meta.in_memory)
            .map(|meta| meta.memory_size as u64)
            .sum();

        // Guard against a zero budget so utilization never becomes NaN/inf.
        let memory_utilization = if self.memory_budget == 0 {
            0.0
        } else {
            total_memory as f64 / self.memory_budget as f64
        };

        CheckpointStatistics {
            total_checkpoints,
            in_memory_checkpoints,
            total_memory,
            memory_budget: self.memory_budget,
            memory_utilization,
        }
    }
}
274
/// Snapshot of checkpoint-manager state, produced by `get_statistics`.
#[derive(Debug, Clone)]
pub struct CheckpointStatistics {
    /// Total number of checkpoints ever created
    pub total_checkpoints: usize,
    /// Number of checkpoints currently resident in memory
    pub in_memory_checkpoints: usize,
    /// Bytes used by resident checkpoints
    pub total_memory: u64,
    /// Configured memory budget in bytes
    pub memory_budget: u64,
    /// Ratio of used memory to budget (0.0 to 1.0)
    pub memory_utilization: f64,
}
289
/// Memory-efficient backpropagation driver with gradient checkpointing
/// (requires the `gpu` feature).
#[cfg(feature = "gpu")]
pub struct EfficientBackprop<T: GpuDataType> {
    /// Manager holding checkpointed activations
    checkpoint_manager: Arc<GradientCheckpointManager<T>>,
    /// GPU context used for buffer allocation
    gpu_context: Arc<GpuContext>,
    /// Runtime toggle for checkpointing
    enabled: bool,
}
300
#[cfg(feature = "gpu")]
impl<T: GpuDataType> EfficientBackprop<T> {
    /// Create a new efficient backpropagation context.
    ///
    /// `memory_budget` and `policy` are forwarded to the internal
    /// [`GradientCheckpointManager`]; `enabled` toggles checkpointing.
    pub fn new(
        gpu_context: Arc<GpuContext>,
        memory_budget: u64,
        policy: RecomputationPolicy,
        enabled: bool,
    ) -> Self {
        let checkpoint_manager = Arc::new(GradientCheckpointManager::new(
            gpu_context.clone(),
            memory_budget,
            policy,
        ));

        Self {
            checkpoint_manager,
            gpu_context,
            enabled,
        }
    }

    /// Forward pass that optionally checkpoints the layer input before
    /// executing `forward_fn`.
    ///
    /// # Errors
    /// Propagates errors from checkpointing or from `forward_fn`.
    pub fn forward_with_checkpoint(
        &self,
        layer_id: usize,
        input: &GpuBuffer<T>,
        forward_fn: impl FnOnce(&GpuBuffer<T>) -> Result<GpuBuffer<T>>,
        recomputation_cost: u32,
    ) -> Result<GpuBuffer<T>> {
        // Checkpoint the input only when checkpointing is enabled.
        if self.enabled {
            self.checkpoint_manager
                .checkpoint_activation(layer_id, input, recomputation_cost)?;
        }

        forward_fn(input)
    }

    /// Backward pass that uses the stored checkpoint for `layer_id` when
    /// one exists.
    ///
    /// `_forward_fn` is reserved for true activation recomputation. It is
    /// intentionally unused for now (underscored to avoid a dead-parameter
    /// warning): recomputation would need the previous layer's output,
    /// which is not tracked yet, so a placeholder buffer is used when no
    /// checkpoint is available.
    pub fn backward_with_recomputation(
        &self,
        layer_id: usize,
        grad_output: &GpuBuffer<T>,
        _forward_fn: impl FnOnce(&GpuBuffer<T>) -> Result<GpuBuffer<T>>,
        backward_fn: impl FnOnce(&GpuBuffer<T>, &GpuBuffer<T>) -> Result<GpuBuffer<T>>,
    ) -> Result<GpuBuffer<T>> {
        let activation = match self.checkpoint_manager.get_checkpoint(layer_id)? {
            // Use the checkpointed activation when available.
            Some(checkpoint) => checkpoint,
            // TODO: recompute via `_forward_fn` once the previous layer's
            // output is retrievable; placeholder buffer for now.
            None => self.gpu_context.create_buffer::<T>(grad_output.len()),
        };

        backward_fn(&activation, grad_output)
    }

    /// Enable or disable gradient checkpointing.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Check whether checkpointing is enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Get the underlying checkpoint manager.
    pub fn checkpoint_manager(&self) -> &Arc<GradientCheckpointManager<T>> {
        &self.checkpoint_manager
    }

    /// Get checkpoint statistics from the manager.
    pub fn get_statistics(&self) -> CheckpointStatistics {
        self.checkpoint_manager.get_statistics()
    }

    /// Clear all checkpoints held by the manager.
    pub fn clear_checkpoints(&self) -> Result<()> {
        self.checkpoint_manager.clear()
    }
}
390
391/// CPU-based activation storage for fallback
392#[derive(Debug)]
393pub struct CpuActivationStore<F> {
394    /// Stored activations
395    activations: Arc<Mutex<HashMap<usize, ArrayD<F>>>>,
396    /// Memory usage
397    memory_usage: Arc<AtomicU64>,
398}
399
400impl<F> CpuActivationStore<F>
401where
402    F: Clone + Default,
403{
404    /// Create a new CPU activation store
405    pub fn new() -> Self {
406        Self {
407            activations: Arc::new(Mutex::new(HashMap::new())),
408            memory_usage: Arc::new(AtomicU64::new(0)),
409        }
410    }
411
412    /// Store an activation
413    pub fn store(&self, layer_id: usize, activation: ArrayD<F>) -> Result<()> {
414        let size = activation.len() * std::mem::size_of::<F>();
415
416        let mut activations = self
417            .activations
418            .lock()
419            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
420
421        activations.insert(layer_id, activation);
422        self.memory_usage.fetch_add(size as u64, Ordering::Relaxed);
423
424        Ok(())
425    }
426
427    /// Retrieve an activation
428    pub fn retrieve(&self, layer_id: usize) -> Result<Option<ArrayD<F>>> {
429        let activations = self
430            .activations
431            .lock()
432            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
433
434        Ok(activations.get(&layer_id).cloned())
435    }
436
437    /// Remove an activation
438    pub fn remove(&self, layer_id: usize) -> Result<()> {
439        let mut activations = self
440            .activations
441            .lock()
442            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
443
444        if let Some(activation) = activations.remove(&layer_id) {
445            let size = activation.len() * std::mem::size_of::<F>();
446            self.memory_usage.fetch_sub(size as u64, Ordering::Relaxed);
447        }
448
449        Ok(())
450    }
451
452    /// Clear all activations
453    pub fn clear(&self) -> Result<()> {
454        let mut activations = self
455            .activations
456            .lock()
457            .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
458
459        activations.clear();
460        self.memory_usage.store(0, Ordering::Relaxed);
461
462        Ok(())
463    }
464
465    /// Get current memory usage
466    pub fn memory_usage(&self) -> u64 {
467        self.memory_usage.load(Ordering::Relaxed)
468    }
469}
470
471impl<F> Default for CpuActivationStore<F>
472where
473    F: Clone + Default,
474{
475    fn default() -> Self {
476        Self::new()
477    }
478}
479
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use scirs2_core::gpu::GpuBackend;

    /// Build a manager backed by the CPU GPU backend with a 1 GiB budget.
    fn make_manager(policy: RecomputationPolicy) -> GradientCheckpointManager<f32> {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        GradientCheckpointManager::new(Arc::new(context), 1024 * 1024 * 1024, policy)
    }

    /// Build an EfficientBackprop with the CPU backend and a 1 GiB budget.
    fn make_backprop(enabled: bool) -> EfficientBackprop<f32> {
        let context = GpuContext::new(GpuBackend::Cpu).expect("Failed to create context");
        EfficientBackprop::new(
            Arc::new(context),
            1024 * 1024 * 1024,
            RecomputationPolicy::CheckpointAll,
            enabled,
        )
    }

    #[test]
    fn test_recomputation_policy() {
        assert!(matches!(
            RecomputationPolicy::default(),
            RecomputationPolicy::Selective { .. }
        ));
        assert_eq!(
            RecomputationPolicy::CheckpointAll,
            RecomputationPolicy::CheckpointAll
        );
    }

    #[test]
    fn test_checkpoint_manager_creation() {
        let manager = make_manager(RecomputationPolicy::CheckpointAll);
        assert_eq!(manager.memory_usage(), 0);
        assert_eq!(manager.num_checkpoints(), 0);
    }

    #[test]
    fn test_checkpoint_statistics() {
        let stats = make_manager(RecomputationPolicy::CheckpointAll).get_statistics();
        assert_eq!(stats.total_checkpoints, 0);
        assert_eq!(stats.in_memory_checkpoints, 0);
        assert_eq!(stats.total_memory, 0);
    }

    #[test]
    fn test_efficient_backprop_creation() {
        let backprop = make_backprop(true);
        assert!(backprop.is_enabled());
        assert_eq!(backprop.checkpoint_manager().num_checkpoints(), 0);
    }

    #[test]
    fn test_cpu_activation_store() {
        let store = CpuActivationStore::<f32>::new();

        store
            .store(0, Array::zeros(IxDyn(&[2, 3, 4])))
            .expect("Failed to store");

        assert!(store.retrieve(0).expect("Failed to retrieve").is_some());
        assert!(store.memory_usage() > 0);

        store.clear().expect("Failed to clear");
        assert_eq!(store.memory_usage(), 0);
    }

    #[test]
    fn test_enable_disable_checkpointing() {
        let mut backprop = make_backprop(true);
        assert!(backprop.is_enabled());

        backprop.set_enabled(false);
        assert!(!backprop.is_enabled());

        backprop.set_enabled(true);
        assert!(backprop.is_enabled());
    }
}