// train_station/gradtrack/no_grad_track.rs

//! NoGradTrack for temporarily disabling gradient tracking
//!
//! This module provides functionality similar to PyTorch's `torch.no_grad()` context manager,
//! allowing users to temporarily disable gradient computation for performance optimization
//! during inference or validation phases.
//!
//! # Features
//! - Thread-local gradient context management
//! - RAII pattern for automatic restoration
//! - Nested context support with proper stack management
//! - Zero-cost abstraction when gradients are already disabled
//! - Thread-safe design for concurrent usage
14use std::cell::RefCell;
15
16thread_local! {
17    /// Thread-local storage for gradient tracking state
18    ///
19    /// Uses a stack to support nested NoGradTrack contexts.
20    /// Each element represents whether gradients were enabled at that nesting level.
21    static GRAD_ENABLED_STACK: RefCell<Vec<bool>> = RefCell::new(vec![true]);
22}
23
/// A RAII guard that temporarily disables gradient tracking
///
/// Similar to PyTorch's `torch.no_grad()`, this guard disables gradient computation
/// within its scope and automatically restores the previous gradient tracking state
/// when it goes out of scope.
///
/// # Performance Benefits
/// - Prevents computation graph construction during inference
/// - Reduces memory usage by not storing intermediate values for backpropagation
/// - Improves computation speed by skipping gradient-related operations
///
/// # Examples
///
/// ```rust
/// use train_station::gradtrack::NoGradTrack;
/// use train_station::Tensor;
///
/// let x = Tensor::ones(vec![3, 3]).with_requires_grad();
/// let y = Tensor::ones(vec![3, 3]).with_requires_grad();
///
/// // Normal computation with gradients
/// let z1 = x.add_tensor(&y);
/// assert!(z1.requires_grad());
///
/// // Computation without gradients
/// {
///     let _guard = NoGradTrack::new();
///     let z2 = x.add_tensor(&y);
///     assert!(!z2.requires_grad()); // Gradients disabled
/// } // Guard drops here, gradients restored
///
/// // Gradients are automatically restored
/// let z3 = x.add_tensor(&y);
/// assert!(z3.requires_grad());
/// ```
///
/// # Nested Contexts
///
/// ```rust
/// use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};
///
/// assert!(is_grad_enabled());
///
/// {
///     let _guard1 = NoGradTrack::new();
///     assert!(!is_grad_enabled());
///
///     {
///         let _guard2 = NoGradTrack::new();
///         assert!(!is_grad_enabled());
///     } // guard2 drops
///
///     assert!(!is_grad_enabled()); // Still disabled
/// } // guard1 drops
///
/// assert!(is_grad_enabled()); // Restored
/// ```
#[derive(Debug)]
#[must_use = "gradient tracking is restored as soon as the guard drops; bind it with `let _guard = ...` instead of discarding it"]
pub struct NoGradTrack {
    // Zero-sized private marker: prevents construction outside this module
    // (callers must go through `new()`) while keeping the type zero-cost.
    _private: (),
}
86
87impl NoGradTrack {
88    /// Create a new NoGradTrack that disables gradient tracking
89    ///
90    /// This function pushes the current gradient state onto the stack and
91    /// disables gradient tracking. When the guard is dropped, the previous
92    /// state is automatically restored.
93    ///
94    /// # Returns
95    ///
96    /// A new `NoGradTrack` that will restore gradient state when dropped
97    #[track_caller]
98    pub fn new() -> Self {
99        GRAD_ENABLED_STACK.with(|stack| {
100            let mut stack = stack.borrow_mut();
101            stack.push(false); // Disable gradients
102        });
103
104        NoGradTrack { _private: () }
105    }
106}
107
108impl Default for NoGradTrack {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl Drop for NoGradTrack {
115    /// Automatically restore the previous gradient tracking state
116    ///
117    /// This ensures that gradient tracking is properly restored even if
118    /// the guard goes out of scope due to early returns or panics.
119    fn drop(&mut self) {
120        GRAD_ENABLED_STACK.with(|stack| {
121            let mut stack = stack.borrow_mut();
122            if stack.len() > 1 {
123                stack.pop(); // Remove the current (disabled) state
124                             // The previous state is now at the top of the stack
125            } else {
126                // This should not happen in normal usage, but we handle it gracefully
127                // by ensuring at least one state remains (the default true state)
128                *stack = vec![true];
129            }
130        });
131    }
132}
133
134/// Check if gradient computation is currently enabled
135///
136/// This function returns the current gradient tracking state for the current thread.
137/// It respects any active `NoGradTrack` contexts.
138///
139/// # Returns
140///
141/// `true` if gradient computation is enabled, `false` otherwise
142///
143/// # Examples
144///
145/// ```rust
146/// use train_station::gradtrack::{NoGradTrack, is_grad_enabled};
147///
148/// assert!(is_grad_enabled()); // Default state
149///
150/// {
151///     let _guard = NoGradTrack::new();
152///     assert!(!is_grad_enabled()); // Disabled by guard
153/// }
154///
155/// assert!(is_grad_enabled()); // Restored after guard drops
156/// ```
157#[track_caller]
158pub fn is_grad_enabled() -> bool {
159    GRAD_ENABLED_STACK.with(|stack| {
160        let stack = stack.borrow();
161        *stack.last().unwrap_or(&true)
162    })
163}
164
165/// Manually set the gradient tracking state
166///
167/// This function allows manual control over gradient tracking state.
168/// It's primarily intended for internal use and testing. In most cases,
169/// using `NoGradTrack` is preferred as it provides automatic restoration.
170///
171/// # Arguments
172///
173/// * `enabled` - Whether to enable or disable gradient tracking
174///
175/// # Warning
176///
177/// This function modifies the current gradient state without automatic restoration.
178/// Use `NoGradTrack` for RAII-style management in most cases.
179///
180/// # Examples
181///
182/// ```rust
183/// use train_station::gradtrack::{set_grad_enabled, is_grad_enabled};
184///
185/// assert!(is_grad_enabled());
186/// set_grad_enabled(false);
187/// assert!(!is_grad_enabled());
188/// set_grad_enabled(true);
189/// assert!(is_grad_enabled());
190/// ```
191#[track_caller]
192pub fn set_grad_enabled(enabled: bool) {
193    GRAD_ENABLED_STACK.with(|stack| {
194        let mut stack = stack.borrow_mut();
195        if let Some(last) = stack.last_mut() {
196            *last = enabled;
197        } else {
198            // Fallback: ensure stack has at least one element
199            *stack = vec![enabled];
200        }
201    });
202}
203
204/// Convenience function to execute a closure with gradients disabled
205///
206/// This function provides a convenient way to execute code with gradients
207/// disabled without explicitly managing a `NoGradTrack`.
208///
209/// # Arguments
210///
211/// * `f` - The closure to execute with gradients disabled
212///
213/// # Returns
214///
215/// The result of the closure
216///
217/// # Examples
218///
219/// ```rust
220/// use train_station::{Tensor, gradtrack::with_no_grad, gradtrack::is_grad_enabled};
221///
222/// let x = Tensor::ones(vec![2, 2]).with_requires_grad();
223/// let y = Tensor::ones(vec![2, 2]).with_requires_grad();
224///
225/// let result = with_no_grad(|| {
226///     assert!(!is_grad_enabled());
227///     x.add_tensor(&y)
228/// });
229///
230/// assert!(!result.requires_grad());
231/// assert!(is_grad_enabled()); // Restored after guard drops
232/// ```
233#[track_caller]
234pub fn with_no_grad<F, R>(f: F) -> R
235where
236    F: FnOnce() -> R,
237{
238    let _guard = NoGradTrack::new();
239    f()
240}
241
/// Reset the gradient tracking state to the default (enabled)
///
/// Test-only helper: discards the entire gradient state stack for this
/// thread and re-seeds it with the default enabled state.
///
/// # Warning
///
/// This disrupts any active `NoGradTrack` contexts and should only be used
/// in test cleanup or exceptional circumstances.
#[cfg(test)]
#[track_caller]
pub fn reset_grad_state() {
    GRAD_ENABLED_STACK.with(|cell| {
        let mut states = cell.borrow_mut();
        states.clear();
        states.push(true);
    });
}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh thread-local state starts with gradients enabled.
    #[test]
    fn test_default_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());
    }

    /// A single guard disables gradients for its scope and restores on drop.
    #[test]
    fn test_no_grad_guard_basic() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    /// Three levels of nesting: gradients stay disabled until the outermost
    /// guard drops, then the original state is restored.
    #[test]
    fn test_nested_no_grad_guards() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                {
                    let _guard3 = NoGradTrack::new();
                    assert!(!is_grad_enabled());
                }

                assert!(!is_grad_enabled());
            }

            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    /// Manual toggling via `set_grad_enabled` takes effect immediately.
    #[test]
    fn test_set_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());

        set_grad_enabled(false);
        assert!(!is_grad_enabled());

        set_grad_enabled(true);
        assert!(is_grad_enabled());
    }

    /// `with_no_grad` disables gradients inside the closure, forwards the
    /// closure's return value, and restores state afterwards.
    #[test]
    fn test_with_no_grad_function() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let result = with_no_grad(|| {
            assert!(!is_grad_enabled());
            42
        });

        assert_eq!(result, 42);
        assert!(is_grad_enabled());
    }

    /// `with_no_grad` calls can nest just like explicit guards.
    #[test]
    fn test_with_no_grad_nested() {
        reset_grad_state();
        assert!(is_grad_enabled());

        with_no_grad(|| {
            assert!(!is_grad_enabled());

            with_no_grad(|| {
                assert!(!is_grad_enabled());
            });

            assert!(!is_grad_enabled());
        });

        assert!(is_grad_enabled());
    }

    /// Guards dropped out of creation order still leave the stack consistent,
    /// because each drop simply pops one "disabled" entry.
    #[test]
    fn test_multiple_guards_same_scope() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let _guard1 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        drop(_guard1);
        assert!(!is_grad_enabled()); // Still disabled due to guard2

        drop(_guard2);
        assert!(is_grad_enabled()); // Now restored
    }

    /// RAII restoration also fires on an early `return`.
    #[test]
    fn test_early_return_with_guard() {
        fn test_function() -> i32 {
            reset_grad_state();
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            if true {
                return 42; // Early return should still restore state
            }

            unreachable!()
        }

        let result = test_function();
        assert_eq!(result, 42);
        assert!(is_grad_enabled()); // Should be restored even with early return
    }

    /// The state stack is thread-local: a guard in a spawned thread never
    /// affects the main thread's state.
    #[test]
    fn test_thread_local_isolation() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let handle = std::thread::spawn(|| {
            // Each thread should start with gradients enabled
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            // Return the state that should be isolated to this thread
            is_grad_enabled()
        });

        // Main thread should still have gradients enabled
        assert!(is_grad_enabled());

        let other_thread_state = handle.join().unwrap();
        assert!(!other_thread_state); // Other thread had gradients disabled

        // Main thread should still be unaffected
        assert!(is_grad_enabled());
    }

    /// Drop runs during unwinding, so state is restored after a panic.
    #[test]
    fn test_panic_safety() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let result = std::panic::catch_unwind(|| {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
            panic!("Test panic");
        });

        assert!(result.is_err());

        // State should be restored even after panic
        // Note: This might not work in all test runners due to panic handling
        // but the RAII pattern should ensure cleanup in normal Rust programs
        assert!(is_grad_enabled());
    }

    /// Mixing guards with manual `set_grad_enabled` calls: a guard restores
    /// to whatever state was on top when it was created, including a state
    /// that was manually overridden.
    #[test]
    fn test_grad_state_stack_integrity() {
        reset_grad_state();

        // Test that the stack maintains integrity through complex operations
        assert!(is_grad_enabled());

        {
            let _g1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            set_grad_enabled(true); // Manual override
            assert!(is_grad_enabled());

            {
                let _g2 = NoGradTrack::new();
                assert!(!is_grad_enabled());
            }

            assert!(is_grad_enabled()); // Should restore to manually set state
        }

        assert!(is_grad_enabled()); // Should restore to original state
    }
}
461}