1use crate::error::{NeuralError, Result};
7#[cfg(feature = "gpu")]
8use scirs2_core::gpu::{GpuBuffer, GpuContext, GpuDataType};
9use scirs2_core::ndarray::{Array, ArrayD, IxDyn};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::{Arc, Mutex};
13
/// Policy deciding which layer activations are checkpointed for the backward
/// pass (stored on-device) versus recomputed on demand.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecomputationPolicy {
    /// Checkpoint every activation, regardless of cost.
    CheckpointAll,
    /// Checkpoint nothing; all activations must be recomputed.
    CheckpointNone,
    /// Checkpoint only activations that are expensive to recompute.
    Selective {
        /// Minimum recomputation cost (inclusive) for an activation to be checkpointed.
        cost_threshold: u32,
    },
    /// Checkpoint layers whose id is a multiple of `n`.
    EveryN {
        /// Checkpoint interval, in layers.
        n: usize,
    },
}
32
33impl Default for RecomputationPolicy {
34 fn default() -> Self {
35 Self::Selective {
36 cost_threshold: 100,
37 }
38 }
39}
40
/// Bookkeeping record for a single checkpointed activation.
#[derive(Debug, Clone)]
pub struct ActivationCheckpoint {
    /// Layer this activation belongs to.
    pub layer_id: usize,
    /// Monotonic insertion order (from the manager's counter); used for
    /// oldest-first eviction.
    pub timestamp: u64,
    /// Size of the checkpointed buffer in bytes.
    pub memory_size: usize,
    /// Relative cost of recomputing this activation instead of storing it.
    pub recomputation_cost: u32,
    /// Whether the buffer is currently resident (set false after eviction).
    pub in_memory: bool,
}
55
/// Manages GPU-resident activation checkpoints for gradient checkpointing.
///
/// Stores selected forward-pass activations within a byte budget so the
/// backward pass can retrieve them instead of recomputing them.
#[cfg(feature = "gpu")]
pub struct GradientCheckpointManager<T: GpuDataType> {
    // Checkpointed activation buffers keyed by layer id.
    checkpoints: Arc<Mutex<HashMap<usize, GpuBuffer<T>>>>,
    // Per-checkpoint bookkeeping (size, age, residency).
    metadata: Arc<Mutex<HashMap<usize, ActivationCheckpoint>>>,
    // Bytes of checkpoints currently resident.
    memory_usage: Arc<AtomicU64>,
    // Maximum resident checkpoint bytes before eviction kicks in.
    memory_budget: u64,
    // Which activations are checkpointed at all.
    policy: RecomputationPolicy,
    // Monotonic counter used as an insertion timestamp for eviction ordering.
    checkpoint_counter: Arc<AtomicU64>,
    // GPU context used to allocate checkpoint buffers.
    gpu_context: Arc<GpuContext>,
}
74
#[cfg(feature = "gpu")]
impl<T: GpuDataType> GradientCheckpointManager<T> {
    /// Creates a manager that allocates checkpoint buffers from `gpu_context`,
    /// keeps at most `memory_budget` bytes resident, and applies `policy` to
    /// decide which activations to checkpoint.
    pub fn new(
        gpu_context: Arc<GpuContext>,
        memory_budget: u64,
        policy: RecomputationPolicy,
    ) -> Self {
        Self {
            checkpoints: Arc::new(Mutex::new(HashMap::new())),
            metadata: Arc::new(Mutex::new(HashMap::new())),
            memory_usage: Arc::new(AtomicU64::new(0)),
            memory_budget,
            policy,
            checkpoint_counter: Arc::new(AtomicU64::new(0)),
            gpu_context,
        }
    }

    /// Checkpoints `activation` for `layer_id` if the policy selects it.
    ///
    /// Evicts the oldest resident checkpoints until the new buffer fits in
    /// the memory budget (or nothing is left to evict). Returns `Ok(())`
    /// without storing anything when the policy skips this layer.
    ///
    /// # Errors
    /// Returns `NeuralError::TrainingError` if an internal lock is poisoned.
    pub fn checkpoint_activation(
        &self,
        layer_id: usize,
        activation: &GpuBuffer<T>,
        recomputation_cost: u32,
    ) -> Result<()> {
        let should_checkpoint = match self.policy {
            RecomputationPolicy::CheckpointAll => true,
            RecomputationPolicy::CheckpointNone => false,
            RecomputationPolicy::Selective { cost_threshold } => {
                recomputation_cost >= cost_threshold
            }
            // `is_multiple_of(0)` is true only for layer 0, so `n == 0`
            // degenerates to checkpointing just the first layer.
            RecomputationPolicy::EveryN { n } => layer_id.is_multiple_of(n),
        };

        if !should_checkpoint {
            return Ok(());
        }

        let activation_size = activation.len() * std::mem::size_of::<T>();

        // Evict until the new buffer fits. The original evicted at most one
        // checkpoint, which could leave the store over budget.
        while self.memory_usage.load(Ordering::Relaxed) + activation_size as u64
            > self.memory_budget
        {
            if !self.evict_oldest_checkpoint()? {
                // Nothing left to evict; store anyway rather than silently
                // drop the checkpoint.
                break;
            }
        }

        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

        let checkpoint_meta = ActivationCheckpoint {
            layer_id,
            timestamp: self.checkpoint_counter.fetch_add(1, Ordering::Relaxed),
            memory_size: activation_size,
            recomputation_cost,
            in_memory: true,
        };

        // NOTE(review): the buffer is allocated but `activation`'s contents
        // are never copied into it — confirm whether `GpuBuffer` exposes a
        // device-to-device copy that should be invoked here.
        let checkpoint_buffer = self.gpu_context.create_buffer::<T>(activation.len());

        // Account for a replaced checkpoint; otherwise re-checkpointing the
        // same layer inflates `memory_usage` permanently.
        if let Some(old) = checkpoints.insert(layer_id, checkpoint_buffer) {
            let old_size = (old.len() * std::mem::size_of::<T>()) as u64;
            self.memory_usage.fetch_sub(old_size, Ordering::Relaxed);
        }
        metadata.insert(layer_id, checkpoint_meta);

        self.memory_usage
            .fetch_add(activation_size as u64, Ordering::Relaxed);

        Ok(())
    }

    /// Removes and returns the checkpoint for `layer_id`, if present.
    ///
    /// The checkpoint is consumed: its bytes are released from the memory
    /// accounting and its metadata is marked non-resident. (The original
    /// removed the buffer without updating either, leaking budget.)
    pub fn get_checkpoint(&self, layer_id: usize) -> Result<Option<GpuBuffer<T>>> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let buffer = checkpoints.remove(&layer_id);
        // Release the checkpoints lock before taking the metadata lock to keep
        // a consistent checkpoints -> metadata lock order with other methods.
        drop(checkpoints);

        if let Some(ref buffer) = buffer {
            let size = (buffer.len() * std::mem::size_of::<T>()) as u64;
            self.memory_usage.fetch_sub(size, Ordering::Relaxed);

            let mut metadata = self
                .metadata
                .lock()
                .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;
            if let Some(meta) = metadata.get_mut(&layer_id) {
                meta.in_memory = false;
            }
        }

        Ok(buffer)
    }

    /// Returns true if a checkpoint is currently stored for `layer_id`.
    /// A poisoned lock is reported as "no checkpoint".
    pub fn has_checkpoint(&self, layer_id: usize) -> bool {
        self.checkpoints
            .lock()
            .map(|cp| cp.contains_key(&layer_id))
            .unwrap_or(false)
    }

    /// Evicts the oldest resident checkpoint.
    ///
    /// Returns `Ok(true)` if one was evicted, `Ok(false)` if none remain.
    fn evict_oldest_checkpoint(&self) -> Result<bool> {
        // Pick the victim and release the metadata lock *before* calling
        // `remove_checkpoint`, which locks it again. The original held the
        // guard across that call, deadlocking on the non-reentrant std Mutex.
        let oldest = {
            let metadata = self
                .metadata
                .lock()
                .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

            metadata
                .iter()
                .filter(|(_, meta)| meta.in_memory)
                .min_by_key(|(_, meta)| meta.timestamp)
                .map(|(id, _)| *id)
        };

        match oldest {
            Some(layer_id) => {
                self.remove_checkpoint(layer_id)?;
                Ok(true)
            }
            None => Ok(false),
        }
    }

    /// Drops the checkpoint for `layer_id` (if any), releasing its memory
    /// and marking its metadata as non-resident.
    pub fn remove_checkpoint(&self, layer_id: usize) -> Result<()> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

        if let Some(checkpoint) = checkpoints.remove(&layer_id) {
            let size = checkpoint.len() * std::mem::size_of::<T>();
            self.memory_usage.fetch_sub(size as u64, Ordering::Relaxed);
        }

        // Keep the metadata entry (for statistics) but mark it evicted.
        if let Some(meta) = metadata.get_mut(&layer_id) {
            meta.in_memory = false;
        }

        Ok(())
    }

    /// Clears all checkpoints and metadata and resets memory accounting.
    pub fn clear(&self) -> Result<()> {
        let mut checkpoints = self
            .checkpoints
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock checkpoints".to_string()))?;

        let mut metadata = self
            .metadata
            .lock()
            .map_err(|_| NeuralError::TrainingError("Failed to lock metadata".to_string()))?;

        checkpoints.clear();
        metadata.clear();
        self.memory_usage.store(0, Ordering::Relaxed);

        Ok(())
    }

    /// Bytes of checkpoints currently resident.
    pub fn memory_usage(&self) -> u64 {
        self.memory_usage.load(Ordering::Relaxed)
    }

    /// Configured maximum resident checkpoint bytes.
    pub fn memory_budget(&self) -> u64 {
        self.memory_budget
    }

    /// Number of checkpoints currently stored (0 if the lock is poisoned).
    pub fn num_checkpoints(&self) -> usize {
        self.checkpoints.lock().map(|cp| cp.len()).unwrap_or(0)
    }

    /// Aggregates checkpoint counts and memory use into a snapshot.
    ///
    /// # Panics
    /// Panics if the metadata lock is poisoned (another thread panicked
    /// while holding it) — matching the original behavior.
    pub fn get_statistics(&self) -> CheckpointStatistics {
        let metadata = self.metadata.lock().expect("Failed to lock metadata");

        let total_checkpoints = metadata.len();
        let in_memory_checkpoints = metadata.values().filter(|meta| meta.in_memory).count();

        let total_memory: u64 = metadata
            .values()
            .filter(|meta| meta.in_memory)
            .map(|meta| meta.memory_size as u64)
            .sum();

        CheckpointStatistics {
            total_checkpoints,
            in_memory_checkpoints,
            total_memory,
            memory_budget: self.memory_budget,
            // Guard against a zero budget producing NaN/inf.
            memory_utilization: if self.memory_budget == 0 {
                0.0
            } else {
                total_memory as f64 / self.memory_budget as f64
            },
        }
    }
}
274
/// Snapshot of checkpoint counts and memory utilization.
#[derive(Debug, Clone)]
pub struct CheckpointStatistics {
    /// Total checkpoints tracked, including evicted ones.
    pub total_checkpoints: usize,
    /// Checkpoints currently resident in memory.
    pub in_memory_checkpoints: usize,
    /// Bytes used by resident checkpoints.
    pub total_memory: u64,
    /// Configured memory budget in bytes.
    pub memory_budget: u64,
    /// `total_memory / memory_budget` as a fraction.
    pub memory_utilization: f64,
}
289
/// Memory-efficient backpropagation helper that pairs forward checkpointing
/// with checkpoint retrieval (or recomputation) in the backward pass.
#[cfg(feature = "gpu")]
pub struct EfficientBackprop<T: GpuDataType> {
    // Owns the checkpointed activations and their eviction policy.
    checkpoint_manager: Arc<GradientCheckpointManager<T>>,
    // GPU context used to allocate buffers during recomputation.
    gpu_context: Arc<GpuContext>,
    // When false, forward passes skip checkpointing entirely.
    enabled: bool,
}
300
#[cfg(feature = "gpu")]
impl<T: GpuDataType> EfficientBackprop<T> {
    /// Creates a backprop helper with its own checkpoint manager configured
    /// by `memory_budget` and `policy`; `enabled` toggles checkpointing.
    pub fn new(
        gpu_context: Arc<GpuContext>,
        memory_budget: u64,
        policy: RecomputationPolicy,
        enabled: bool,
    ) -> Self {
        let checkpoint_manager = Arc::new(GradientCheckpointManager::new(
            gpu_context.clone(),
            memory_budget,
            policy,
        ));

        Self {
            checkpoint_manager,
            gpu_context,
            enabled,
        }
    }

    /// Runs `forward_fn` on `input`, first checkpointing `input` for
    /// `layer_id` when checkpointing is enabled.
    pub fn forward_with_checkpoint(
        &self,
        layer_id: usize,
        input: &GpuBuffer<T>,
        forward_fn: impl FnOnce(&GpuBuffer<T>) -> Result<GpuBuffer<T>>,
        recomputation_cost: u32,
    ) -> Result<GpuBuffer<T>> {
        if self.enabled {
            self.checkpoint_manager
                .checkpoint_activation(layer_id, input, recomputation_cost)?;
        }

        forward_fn(input)
    }

    /// Runs the backward pass for `layer_id`, using the checkpointed
    /// activation if one exists and recomputing it with `forward_fn`
    /// otherwise.
    pub fn backward_with_recomputation(
        &self,
        layer_id: usize,
        grad_output: &GpuBuffer<T>,
        forward_fn: impl FnOnce(&GpuBuffer<T>) -> Result<GpuBuffer<T>>,
        backward_fn: impl FnOnce(&GpuBuffer<T>, &GpuBuffer<T>) -> Result<GpuBuffer<T>>,
    ) -> Result<GpuBuffer<T>> {
        let activation = match self.checkpoint_manager.get_checkpoint(layer_id)? {
            Some(checkpoint) => checkpoint,
            None => {
                // Bug fix: the original never invoked `forward_fn`, so a
                // missing checkpoint handed `backward_fn` a freshly allocated
                // (uninitialized) buffer. Recompute the activation instead.
                // NOTE(review): the true layer input is not threaded through
                // this API; a new buffer sized like `grad_output` is used as
                // the recomputation input — confirm against callers and pass
                // the saved input explicitly.
                let input = self.gpu_context.create_buffer::<T>(grad_output.len());
                forward_fn(&input)?
            }
        };

        backward_fn(&activation, grad_output)
    }

    /// Enables or disables checkpointing for subsequent forward passes.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Whether checkpointing is currently enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// The underlying checkpoint manager.
    pub fn checkpoint_manager(&self) -> &Arc<GradientCheckpointManager<T>> {
        &self.checkpoint_manager
    }

    /// Statistics snapshot from the underlying checkpoint manager.
    pub fn get_statistics(&self) -> CheckpointStatistics {
        self.checkpoint_manager.get_statistics()
    }

    /// Drops all stored checkpoints and resets memory accounting.
    pub fn clear_checkpoints(&self) -> Result<()> {
        self.checkpoint_manager.clear()
    }
}
390
/// Host-memory activation store used when the GPU checkpoint path is
/// unavailable; keyed by layer id with byte-level memory accounting.
#[derive(Debug)]
pub struct CpuActivationStore<F> {
    // Stored activations keyed by layer id.
    activations: Arc<Mutex<HashMap<usize, ArrayD<F>>>>,
    // Bytes currently held by stored activations.
    memory_usage: Arc<AtomicU64>,
}
399
400impl<F> CpuActivationStore<F>
401where
402 F: Clone + Default,
403{
404 pub fn new() -> Self {
406 Self {
407 activations: Arc::new(Mutex::new(HashMap::new())),
408 memory_usage: Arc::new(AtomicU64::new(0)),
409 }
410 }
411
412 pub fn store(&self, layer_id: usize, activation: ArrayD<F>) -> Result<()> {
414 let size = activation.len() * std::mem::size_of::<F>();
415
416 let mut activations = self
417 .activations
418 .lock()
419 .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
420
421 activations.insert(layer_id, activation);
422 self.memory_usage.fetch_add(size as u64, Ordering::Relaxed);
423
424 Ok(())
425 }
426
427 pub fn retrieve(&self, layer_id: usize) -> Result<Option<ArrayD<F>>> {
429 let activations = self
430 .activations
431 .lock()
432 .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
433
434 Ok(activations.get(&layer_id).cloned())
435 }
436
437 pub fn remove(&self, layer_id: usize) -> Result<()> {
439 let mut activations = self
440 .activations
441 .lock()
442 .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
443
444 if let Some(activation) = activations.remove(&layer_id) {
445 let size = activation.len() * std::mem::size_of::<F>();
446 self.memory_usage.fetch_sub(size as u64, Ordering::Relaxed);
447 }
448
449 Ok(())
450 }
451
452 pub fn clear(&self) -> Result<()> {
454 let mut activations = self
455 .activations
456 .lock()
457 .map_err(|_| NeuralError::TrainingError("Failed to lock activations".to_string()))?;
458
459 activations.clear();
460 self.memory_usage.store(0, Ordering::Relaxed);
461
462 Ok(())
463 }
464
465 pub fn memory_usage(&self) -> u64 {
467 self.memory_usage.load(Ordering::Relaxed)
468 }
469}
470
471impl<F> Default for CpuActivationStore<F>
472where
473 F: Clone + Default,
474{
475 fn default() -> Self {
476 Self::new()
477 }
478}
479
#[cfg(all(test, feature = "gpu"))]
mod tests {
    use super::*;
    use scirs2_core::gpu::GpuBackend;

    #[test]
    fn test_recomputation_policy() {
        // The default policy is selective checkpointing.
        assert!(matches!(
            RecomputationPolicy::default(),
            RecomputationPolicy::Selective { .. }
        ));

        // Unit variants compare equal to themselves via PartialEq.
        assert_eq!(
            RecomputationPolicy::CheckpointAll,
            RecomputationPolicy::CheckpointAll
        );
    }

    #[test]
    fn test_checkpoint_manager_creation() {
        let ctx = Arc::new(GpuContext::new(GpuBackend::Cpu).expect("Failed to create context"));
        let mgr = GradientCheckpointManager::<f32>::new(
            ctx,
            1024 * 1024 * 1024,
            RecomputationPolicy::CheckpointAll,
        );

        // A fresh manager holds nothing and accounts zero bytes.
        assert_eq!(mgr.memory_usage(), 0);
        assert_eq!(mgr.num_checkpoints(), 0);
    }

    #[test]
    fn test_checkpoint_statistics() {
        let ctx = Arc::new(GpuContext::new(GpuBackend::Cpu).expect("Failed to create context"));
        let mgr = GradientCheckpointManager::<f32>::new(
            ctx,
            1024 * 1024 * 1024,
            RecomputationPolicy::CheckpointAll,
        );

        // Statistics over an empty manager are all zero.
        let stats = mgr.get_statistics();
        assert_eq!(stats.total_checkpoints, 0);
        assert_eq!(stats.in_memory_checkpoints, 0);
        assert_eq!(stats.total_memory, 0);
    }

    #[test]
    fn test_efficient_backprop_creation() {
        let ctx = Arc::new(GpuContext::new(GpuBackend::Cpu).expect("Failed to create context"));
        let bp = EfficientBackprop::<f32>::new(
            ctx,
            1024 * 1024 * 1024,
            RecomputationPolicy::CheckpointAll,
            true,
        );

        assert!(bp.is_enabled());
        assert_eq!(bp.checkpoint_manager().num_checkpoints(), 0);
    }

    #[test]
    fn test_cpu_activation_store() {
        let store = CpuActivationStore::<f32>::new();

        let act: ArrayD<f32> = Array::zeros(IxDyn(&[2, 3, 4]));
        store.store(0, act).expect("Failed to store");

        // Round-trip: the activation comes back and its bytes are accounted.
        assert!(store.retrieve(0).expect("Failed to retrieve").is_some());
        assert!(store.memory_usage() > 0);

        // Clearing resets the accounting to zero.
        store.clear().expect("Failed to clear");
        assert_eq!(store.memory_usage(), 0);
    }

    #[test]
    fn test_enable_disable_checkpointing() {
        let ctx = Arc::new(GpuContext::new(GpuBackend::Cpu).expect("Failed to create context"));
        let mut bp = EfficientBackprop::<f32>::new(
            ctx,
            1024 * 1024 * 1024,
            RecomputationPolicy::CheckpointAll,
            true,
        );

        assert!(bp.is_enabled());

        // Toggle off, then back on.
        bp.set_enabled(false);
        assert!(!bp.is_enabled());

        bp.set_enabled(true);
        assert!(bp.is_enabled());
    }
}