train_station/gradtrack/no_grad_track.rs
1//! NoGradTrack for temporarily disabling gradient tracking
2//!
3//! This module provides functionality similar to PyTorch's `torch.no_grad()` context manager,
4//! allowing users to temporarily disable gradient computation for performance optimization
5//! during inference or validation phases.
6//!
7//! # Features
8//! - Thread-local gradient context management
9//! - RAII pattern for automatic restoration
10//! - Nested context support with proper stack management
11//! - Zero-cost abstraction when gradients are already disabled
12//! - Thread-safe design for concurrent usage
13
14use std::cell::RefCell;
15
thread_local! {
    /// Thread-local storage for gradient tracking state
    ///
    /// Uses a stack to support nested NoGradTrack contexts.
    /// Each element represents whether gradients were enabled at that nesting level.
    ///
    /// Invariants (maintained by `NoGradTrack::new`/`Drop`, `set_grad_enabled`,
    /// and the fallback branches below):
    /// - The stack is never left empty; the bottom entry is the default `true`.
    /// - The top entry is the *current* state, queried by `is_grad_enabled`.
    /// - Each live `NoGradTrack` guard owns exactly one pushed `false` entry.
    static GRAD_ENABLED_STACK: RefCell<Vec<bool>> = RefCell::new(vec![true]);
}
23
/// A RAII guard that temporarily disables gradient tracking
///
/// Similar to PyTorch's `torch.no_grad()`, this guard disables gradient computation
/// within its scope and automatically restores the previous gradient tracking state
/// when it goes out of scope.
///
/// # Performance Benefits
/// - Prevents computation graph construction during inference
/// - Reduces memory usage by not storing intermediate values for backpropagation
/// - Improves computation speed by skipping gradient-related operations
///
/// # Examples
///
/// ```rust
/// use train_station::{NoGradTrack, Tensor};
///
/// let x = Tensor::ones(vec![3, 3]).with_requires_grad();
/// let y = Tensor::ones(vec![3, 3]).with_requires_grad();
///
/// // Normal computation with gradients
/// let z1 = x.add_tensor(&y);
/// assert!(z1.requires_grad());
///
/// // Computation without gradients
/// {
///     let _guard = NoGradTrack::new();
///     let z2 = x.add_tensor(&y);
///     assert!(!z2.requires_grad()); // Gradients disabled
/// } // Guard drops here, gradients restored
///
/// // Gradients are automatically restored
/// let z3 = x.add_tensor(&y);
/// assert!(z3.requires_grad());
/// ```
///
/// # Nested Contexts
///
/// ```rust
/// use train_station::{NoGradTrack, is_grad_enabled, Tensor};
///
/// assert!(is_grad_enabled());
///
/// {
///     let _guard1 = NoGradTrack::new();
///     assert!(!is_grad_enabled());
///
///     {
///         let _guard2 = NoGradTrack::new();
///         assert!(!is_grad_enabled());
///     } // guard2 drops
///
///     assert!(!is_grad_enabled()); // Still disabled
/// } // guard1 drops
///
/// assert!(is_grad_enabled()); // Restored
/// ```
// `#[must_use]` guards against `NoGradTrack::new();` as a bare statement: the
// temporary would drop immediately and re-enable gradients, silently doing
// nothing. Binding to `_guard` (not `_`, which also drops immediately) is the
// intended usage.
#[must_use = "NoGradTrack only disables gradients while it is alive; bind it with `let _guard = NoGradTrack::new();`"]
pub struct NoGradTrack {
    // Marker to ensure the guard cannot be constructed outside this module
    // without using the `new()` method
    _private: (),
}
85
86impl NoGradTrack {
87 /// Create a new NoGradTrack that disables gradient tracking
88 ///
89 /// This function pushes the current gradient state onto the stack and
90 /// disables gradient tracking. When the guard is dropped, the previous
91 /// state is automatically restored.
92 ///
93 /// # Returns
94 ///
95 /// A new `NoGradTrack` that will restore gradient state when dropped
96 #[track_caller]
97 pub fn new() -> Self {
98 GRAD_ENABLED_STACK.with(|stack| {
99 let mut stack = stack.borrow_mut();
100 stack.push(false); // Disable gradients
101 });
102
103 NoGradTrack { _private: () }
104 }
105}
106
107impl Default for NoGradTrack {
108 fn default() -> Self {
109 Self::new()
110 }
111}
112
113impl Drop for NoGradTrack {
114 /// Automatically restore the previous gradient tracking state
115 ///
116 /// This ensures that gradient tracking is properly restored even if
117 /// the guard goes out of scope due to early returns or panics.
118 fn drop(&mut self) {
119 GRAD_ENABLED_STACK.with(|stack| {
120 let mut stack = stack.borrow_mut();
121 if stack.len() > 1 {
122 stack.pop(); // Remove the current (disabled) state
123 // The previous state is now at the top of the stack
124 } else {
125 // This should not happen in normal usage, but we handle it gracefully
126 // by ensuring at least one state remains (the default true state)
127 *stack = vec![true];
128 }
129 });
130 }
131}
132
133/// Check if gradient computation is currently enabled
134///
135/// This function returns the current gradient tracking state for the current thread.
136/// It respects any active `NoGradTrack` contexts.
137///
138/// # Returns
139///
140/// `true` if gradient computation is enabled, `false` otherwise
141///
142/// # Examples
143///
144/// ```rust
145/// use train_station::{NoGradTrack, is_grad_enabled};
146///
147/// assert!(is_grad_enabled()); // Default state
148///
149/// {
150/// let _guard = NoGradTrack::new();
151/// assert!(!is_grad_enabled()); // Disabled by guard
152/// }
153///
154/// assert!(is_grad_enabled()); // Restored after guard drops
155/// ```
156#[track_caller]
157pub fn is_grad_enabled() -> bool {
158 GRAD_ENABLED_STACK.with(|stack| {
159 let stack = stack.borrow();
160 *stack.last().unwrap_or(&true)
161 })
162}
163
164/// Manually set the gradient tracking state
165///
166/// This function allows manual control over gradient tracking state.
167/// It's primarily intended for internal use and testing. In most cases,
168/// using `NoGradTrack` is preferred as it provides automatic restoration.
169///
170/// # Arguments
171///
172/// * `enabled` - Whether to enable or disable gradient tracking
173///
174/// # Warning
175///
176/// This function modifies the current gradient state without automatic restoration.
177/// Use `NoGradTrack` for RAII-style management in most cases.
178///
179/// # Examples
180///
181/// ```rust
182/// use train_station::{set_grad_enabled, is_grad_enabled};
183///
184/// assert!(is_grad_enabled());
185/// set_grad_enabled(false);
186/// assert!(!is_grad_enabled());
187/// set_grad_enabled(true);
188/// assert!(is_grad_enabled());
189/// ```
190#[track_caller]
191pub fn set_grad_enabled(enabled: bool) {
192 GRAD_ENABLED_STACK.with(|stack| {
193 let mut stack = stack.borrow_mut();
194 if let Some(last) = stack.last_mut() {
195 *last = enabled;
196 } else {
197 // Fallback: ensure stack has at least one element
198 *stack = vec![enabled];
199 }
200 });
201}
202
203/// Convenience function to execute a closure with gradients disabled
204///
205/// This function provides a convenient way to execute code with gradients
206/// disabled without explicitly managing a `NoGradTrack`.
207///
208/// # Arguments
209///
210/// * `f` - The closure to execute with gradients disabled
211///
212/// # Returns
213///
214/// The result of the closure
215///
216/// # Examples
217///
218/// ```rust
219/// use train_station::{Tensor, with_no_grad, is_grad_enabled};
220///
221/// let x = Tensor::ones(vec![2, 2]).with_requires_grad();
222/// let y = Tensor::ones(vec![2, 2]).with_requires_grad();
223///
224/// let result = with_no_grad(|| {
225/// assert!(!is_grad_enabled());
226/// x.add_tensor(&y)
227/// });
228///
229/// assert!(!result.requires_grad());
230/// assert!(is_grad_enabled()); // Restored after guard drops
231/// ```
232#[track_caller]
233pub fn with_no_grad<F, R>(f: F) -> R
234where
235 F: FnOnce() -> R,
236{
237 let _guard = NoGradTrack::new();
238 f()
239}
240
/// Reset the gradient tracking state to the default (enabled)
///
/// Test-only helper: discards the entire thread-local state stack and
/// re-seeds it with the single default `true` entry.
///
/// # Warning
///
/// This disrupts any `NoGradTrack` guard that is still alive on this thread
/// (its eventual drop will hit the defensive branch); use only for test
/// setup/cleanup or exceptional circumstances.
#[cfg(test)]
#[track_caller]
pub fn reset_grad_state() {
    GRAD_ENABLED_STACK.with(|stack| {
        let mut states = stack.borrow_mut();
        states.clear();
        states.push(true);
    });
}
257
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the test harness runs tests on separate threads, and all
    // state here is thread-local, so each test's `reset_grad_state()` only
    // affects its own thread; the tests are isolated from one another.

    #[test]
    fn test_default_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_guard_basic() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    #[test]
    fn test_nested_no_grad_guards() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                {
                    let _guard3 = NoGradTrack::new();
                    assert!(!is_grad_enabled());
                }

                assert!(!is_grad_enabled());
            }

            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    #[test]
    fn test_set_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());

        set_grad_enabled(false);
        assert!(!is_grad_enabled());

        set_grad_enabled(true);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_with_no_grad_function() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let result = with_no_grad(|| {
            assert!(!is_grad_enabled());
            42
        });

        assert_eq!(result, 42);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_with_no_grad_nested() {
        reset_grad_state();
        assert!(is_grad_enabled());

        with_no_grad(|| {
            assert!(!is_grad_enabled());

            with_no_grad(|| {
                assert!(!is_grad_enabled());
            });

            assert!(!is_grad_enabled());
        });

        assert!(is_grad_enabled());
    }

    // Guards dropped out of LIFO order: this works because every guard pushes
    // the same `false` value, so which entry gets popped is irrelevant — only
    // the stack depth matters.
    #[test]
    fn test_multiple_guards_same_scope() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let _guard1 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        drop(_guard1);
        assert!(!is_grad_enabled()); // Still disabled due to guard2

        drop(_guard2);
        assert!(is_grad_enabled()); // Now restored
    }

    #[test]
    fn test_early_return_with_guard() {
        fn test_function() -> i32 {
            reset_grad_state();
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            if true {
                return 42; // Early return should still restore state
            }

            unreachable!()
        }

        let result = test_function();
        assert_eq!(result, 42);
        assert!(is_grad_enabled()); // Should be restored even with early return
    }

    #[test]
    fn test_thread_local_isolation() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let handle = std::thread::spawn(|| {
            // Each thread should start with gradients enabled
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            // Return the state that should be isolated to this thread
            is_grad_enabled()
        });

        // Main thread should still have gradients enabled
        assert!(is_grad_enabled());

        let other_thread_state = handle.join().unwrap();
        assert!(!other_thread_state); // Other thread had gradients disabled

        // Main thread should still be unaffected
        assert!(is_grad_enabled());
    }

    // The guard's Drop runs during unwinding inside catch_unwind, so the
    // state observed after the panic is the restored (enabled) one.
    #[test]
    fn test_panic_safety() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let result = std::panic::catch_unwind(|| {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
            panic!("Test panic");
        });

        assert!(result.is_err());

        // State should be restored even after panic
        // Note: This might not work in all test runners due to panic handling
        // but the RAII pattern should ensure cleanup in normal Rust programs
        assert!(is_grad_enabled());
    }

    // `set_grad_enabled` rewrites the top-of-stack entry in place, so a nested
    // guard pushed afterwards restores to the *manually set* state, not the
    // original one — this test pins that interaction down.
    #[test]
    fn test_grad_state_stack_integrity() {
        reset_grad_state();

        // Test that the stack maintains integrity through complex operations
        assert!(is_grad_enabled());

        {
            let _g1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            set_grad_enabled(true); // Manual override
            assert!(is_grad_enabled());

            {
                let _g2 = NoGradTrack::new();
                assert!(!is_grad_enabled());
            }

            assert!(is_grad_enabled()); // Should restore to manually set state
        }

        assert!(is_grad_enabled()); // Should restore to original state
    }
}
460}