pub struct GradientCheckpointing<F: Float + Debug> { /* private fields */ }
Expand description
Gradient checkpointing implementation for memory-efficient training
Implementations§
Source§impl<F: Float + Debug + Clone + 'static + ScalarOperand> GradientCheckpointing<F>
impl<F: Float + Debug + Clone + 'static + ScalarOperand> GradientCheckpointing<F>
Sourcepub fn new(memory_threshold_mb: f64) -> Self
pub fn new(memory_threshold_mb: f64) -> Self
Create a new gradient checkpointing manager
Examples found in repository?
examples/memory_efficient_example.rs (line 88)
84 fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\n📊 Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90 // Set up checkpoint layers
91 checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 // Simulate storing activations during forward pass
98 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99 let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100 let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102 checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 // Simulate retrieving checkpoints during backward pass
111 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 // Clear checkpoints to free memory
117 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125 }
Sourcepub fn add_checkpoint_layer(&mut self, layer_name: String)
pub fn add_checkpoint_layer(&mut self, layer_name: String)
Add a layer as a checkpoint point
Examples found in repository?
examples/memory_efficient_example.rs (line 91)
84 fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\n📊 Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90 // Set up checkpoint layers
91 checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 // Simulate storing activations during forward pass
98 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99 let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100 let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102 checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 // Simulate retrieving checkpoints during backward pass
111 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 // Clear checkpoints to free memory
117 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125 }
Sourcepub fn store_checkpoint(
&mut self,
layer_name: &str,
activation: ArrayD<F>,
) -> Result<()>
pub fn store_checkpoint(&mut self, layer_name: &str, activation: ArrayD<F>) -> Result<()>
Store activation at a checkpoint
Examples found in repository?
examples/memory_efficient_example.rs (line 102)
84 fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\n📊 Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90 // Set up checkpoint layers
91 checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 // Simulate storing activations during forward pass
98 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99 let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100 let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102 checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 // Simulate retrieving checkpoints during backward pass
111 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 // Clear checkpoints to free memory
117 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125 }
Sourcepub fn get_checkpoint(&self, layer_name: &str) -> Option<&ArrayD<F>>
pub fn get_checkpoint(&self, layer_name: &str) -> Option<&ArrayD<F>>
Retrieve activation from checkpoint
Examples found in repository?
examples/memory_efficient_example.rs (line 112)
84 fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\n📊 Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90 // Set up checkpoint layers
91 checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 // Simulate storing activations during forward pass
98 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99 let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100 let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102 checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 // Simulate retrieving checkpoints during backward pass
111 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 // Clear checkpoints to free memory
117 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125 }
Sourcepub fn clear_checkpoints(&mut self)
pub fn clear_checkpoints(&mut self)
Clear checkpoints to free memory
Examples found in repository?
examples/memory_efficient_example.rs (line 118)
84 fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\n📊 Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90 // Set up checkpoint layers
91 checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 // Simulate storing activations during forward pass
98 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99 let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100 let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102 checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 // Simulate retrieving checkpoints during backward pass
111 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 // Clear checkpoints to free memory
117 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125 }
Sourcepub fn get_memory_usage(&self) -> MemoryUsage
pub fn get_memory_usage(&self) -> MemoryUsage
Get current memory usage
Examples found in repository?
examples/memory_efficient_example.rs (line 106)
84 fn demo_gradient_checkpointing() -> Result<()> {
85 println!("\n📊 Gradient Checkpointing Demo");
86 println!("------------------------------");
87
88 let mut checkpointing = GradientCheckpointing::<f64>::new(100.0); // 100MB threshold
89
90 // Set up checkpoint layers
91 checkpointing.add_checkpoint_layer("conv1".to_string());
92 checkpointing.add_checkpoint_layer("conv3".to_string());
93 checkpointing.add_checkpoint_layer("fc1".to_string());
94
95 println!("Storing activations at checkpoints...");
96
97 // Simulate storing activations during forward pass
98 let conv1_activation = Array3::from_elem((32, 64, 64), 0.5).into_dyn(); // Batch=32, 64x64 feature maps
99 let conv3_activation = Array3::from_elem((32, 128, 32), 0.3).into_dyn(); // Reduced spatial size
100 let fc1_activation = Array2::from_elem((32, 512), 0.2).into_dyn(); // Fully connected
101
102 checkpointing.store_checkpoint("conv1", conv1_activation)?;
103 checkpointing.store_checkpoint("conv3", conv3_activation)?;
104 checkpointing.store_checkpoint("fc1", fc1_activation)?;
105
106 let usage = checkpointing.get_memory_usage();
107 println!("Memory usage after checkpointing:");
108 print_memory_usage(&usage);
109
110 // Simulate retrieving checkpoints during backward pass
111 println!("Retrieving checkpoints for gradient computation...");
112 if let Some(checkpoint) = checkpointing.get_checkpoint("conv1") {
113 println!("Retrieved conv1 checkpoint: shape {:?}", checkpoint.shape());
114 }
115
116 // Clear checkpoints to free memory
117 println!("Clearing checkpoints...");
118 checkpointing.clear_checkpoints();
119
120 let usage = checkpointing.get_memory_usage();
121 println!("Memory usage after clearing:");
122 print_memory_usage(&usage);
123
124 Ok(())
125 }
Auto Trait Implementations§
impl<F> Freeze for GradientCheckpointing<F>
impl<F> RefUnwindSafe for GradientCheckpointing<F>
where
    F: RefUnwindSafe,
impl<F> Send for GradientCheckpointing<F>
where
    F: Send,
impl<F> Sync for GradientCheckpointing<F>
where
    F: Sync,
impl<F> Unpin for GradientCheckpointing<F>
impl<F> UnwindSafe for GradientCheckpointing<F>
where
    F: RefUnwindSafe,
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T
where
    T: ?Sized,
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more
Source§impl<T> IntoEither for T
impl<T> IntoEither for T
Source§fn into_either(self, into_left: bool) -> Either<Self, Self>
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left` is `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more
Source§fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts `self` into a `Left` variant of `Either<Self, Self>` if `into_left(&self)` returns `true`.
Converts `self` into a `Right` variant of `Either<Self, Self>` otherwise. Read more