// train_station/gradtrack/no_grad_track.rs

//! NoGradTrack for temporarily disabling gradient tracking
//!
//! This module provides functionality similar to PyTorch's `torch.no_grad()` context manager,
//! allowing users to temporarily disable gradient computation for performance optimization
//! during inference or validation phases.
//!
//! # Features
//! - Thread-local gradient context management
//! - RAII pattern for automatic restoration
//! - Nested context support with proper stack management
//! - Zero-cost abstraction when gradients are already disabled
//! - Thread-safe design for concurrent usage
13
14use std::cell::RefCell;
15
thread_local! {
    /// Thread-local storage for gradient tracking state
    ///
    /// Uses a stack to support nested NoGradTrack contexts.
    /// Each element represents whether gradients were enabled at that nesting level.
    /// The stack is seeded with a single `true`, so every thread starts with
    /// gradient tracking enabled by default.
    static GRAD_ENABLED_STACK: RefCell<Vec<bool>> = RefCell::new(vec![true]);
}
23
/// A RAII guard that temporarily disables gradient tracking
///
/// Similar to PyTorch's `torch.no_grad()`, this guard disables gradient computation
/// within its scope and automatically restores the previous gradient tracking state
/// when it goes out of scope.
///
/// # Performance Benefits
/// - Prevents computation graph construction during inference
/// - Reduces memory usage by not storing intermediate values for backpropagation
/// - Improves computation speed by skipping gradient-related operations
///
/// # Examples
///
/// ```rust
/// use train_station::{NoGradTrack, Tensor};
///
/// let x = Tensor::ones(vec![3, 3]).with_requires_grad();
/// let y = Tensor::ones(vec![3, 3]).with_requires_grad();
///
/// // Normal computation with gradients
/// let z1 = x.add_tensor(&y);
/// assert!(z1.requires_grad());
///
/// // Computation without gradients
/// {
///     let _guard = NoGradTrack::new();
///     let z2 = x.add_tensor(&y);
///     assert!(!z2.requires_grad()); // Gradients disabled
/// } // Guard drops here, gradients restored
///
/// // Gradients are automatically restored
/// let z3 = x.add_tensor(&y);
/// assert!(z3.requires_grad());
/// ```
///
/// # Nested Contexts
///
/// ```rust
/// use train_station::{NoGradTrack, is_grad_enabled, Tensor};
///
/// assert!(is_grad_enabled());
///
/// {
///     let _guard1 = NoGradTrack::new();
///     assert!(!is_grad_enabled());
///
///     {
///         let _guard2 = NoGradTrack::new();
///         assert!(!is_grad_enabled());
///     } // guard2 drops
///
///     assert!(!is_grad_enabled()); // Still disabled
/// } // guard1 drops
///
/// assert!(is_grad_enabled()); // Restored
/// ```
pub struct NoGradTrack {
    // Marker to ensure the guard cannot be constructed outside this module
    // without using the `new()` method
    _private: (),
}
85
86impl NoGradTrack {
87    /// Create a new NoGradTrack that disables gradient tracking
88    ///
89    /// This function pushes the current gradient state onto the stack and
90    /// disables gradient tracking. When the guard is dropped, the previous
91    /// state is automatically restored.
92    ///
93    /// # Returns
94    ///
95    /// A new `NoGradTrack` that will restore gradient state when dropped
96    #[track_caller]
97    pub fn new() -> Self {
98        GRAD_ENABLED_STACK.with(|stack| {
99            let mut stack = stack.borrow_mut();
100            stack.push(false); // Disable gradients
101        });
102
103        NoGradTrack { _private: () }
104    }
105}
106
107impl Default for NoGradTrack {
108    fn default() -> Self {
109        Self::new()
110    }
111}
112
113impl Drop for NoGradTrack {
114    /// Automatically restore the previous gradient tracking state
115    ///
116    /// This ensures that gradient tracking is properly restored even if
117    /// the guard goes out of scope due to early returns or panics.
118    fn drop(&mut self) {
119        GRAD_ENABLED_STACK.with(|stack| {
120            let mut stack = stack.borrow_mut();
121            if stack.len() > 1 {
122                stack.pop(); // Remove the current (disabled) state
123                             // The previous state is now at the top of the stack
124            } else {
125                // This should not happen in normal usage, but we handle it gracefully
126                // by ensuring at least one state remains (the default true state)
127                *stack = vec![true];
128            }
129        });
130    }
131}
132
133/// Check if gradient computation is currently enabled
134///
135/// This function returns the current gradient tracking state for the current thread.
136/// It respects any active `NoGradTrack` contexts.
137///
138/// # Returns
139///
140/// `true` if gradient computation is enabled, `false` otherwise
141///
142/// # Examples
143///
144/// ```rust
145/// use train_station::{NoGradTrack, is_grad_enabled};
146///
147/// assert!(is_grad_enabled()); // Default state
148///
149/// {
150///     let _guard = NoGradTrack::new();
151///     assert!(!is_grad_enabled()); // Disabled by guard
152/// }
153///
154/// assert!(is_grad_enabled()); // Restored after guard drops
155/// ```
156#[track_caller]
157pub fn is_grad_enabled() -> bool {
158    GRAD_ENABLED_STACK.with(|stack| {
159        let stack = stack.borrow();
160        *stack.last().unwrap_or(&true)
161    })
162}
163
164/// Manually set the gradient tracking state
165///
166/// This function allows manual control over gradient tracking state.
167/// It's primarily intended for internal use and testing. In most cases,
168/// using `NoGradTrack` is preferred as it provides automatic restoration.
169///
170/// # Arguments
171///
172/// * `enabled` - Whether to enable or disable gradient tracking
173///
174/// # Warning
175///
176/// This function modifies the current gradient state without automatic restoration.
177/// Use `NoGradTrack` for RAII-style management in most cases.
178///
179/// # Examples
180///
181/// ```rust
182/// use train_station::{set_grad_enabled, is_grad_enabled};
183///
184/// assert!(is_grad_enabled());
185/// set_grad_enabled(false);
186/// assert!(!is_grad_enabled());
187/// set_grad_enabled(true);
188/// assert!(is_grad_enabled());
189/// ```
190#[track_caller]
191pub fn set_grad_enabled(enabled: bool) {
192    GRAD_ENABLED_STACK.with(|stack| {
193        let mut stack = stack.borrow_mut();
194        if let Some(last) = stack.last_mut() {
195            *last = enabled;
196        } else {
197            // Fallback: ensure stack has at least one element
198            *stack = vec![enabled];
199        }
200    });
201}
202
203/// Convenience function to execute a closure with gradients disabled
204///
205/// This function provides a convenient way to execute code with gradients
206/// disabled without explicitly managing a `NoGradTrack`.
207///
208/// # Arguments
209///
210/// * `f` - The closure to execute with gradients disabled
211///
212/// # Returns
213///
214/// The result of the closure
215///
216/// # Examples
217///
218/// ```rust
219/// use train_station::{Tensor, with_no_grad, is_grad_enabled};
220///
221/// let x = Tensor::ones(vec![2, 2]).with_requires_grad();
222/// let y = Tensor::ones(vec![2, 2]).with_requires_grad();
223///
224/// let result = with_no_grad(|| {
225///     assert!(!is_grad_enabled());
226///     x.add_tensor(&y)
227/// });
228///
229/// assert!(!result.requires_grad());
230/// assert!(is_grad_enabled()); // Restored after guard drops
231/// ```
232#[track_caller]
233pub fn with_no_grad<F, R>(f: F) -> R
234where
235    F: FnOnce() -> R,
236{
237    let _guard = NoGradTrack::new();
238    f()
239}
240
/// Reset the gradient tracking state to the default (enabled)
///
/// Clears the entire gradient state stack for the calling thread and re-seeds
/// it with the default enabled state. Intended for test setup/cleanup only.
///
/// # Warning
///
/// This function will disrupt any active `NoGradTrack` contexts and should
/// only be used in test cleanup or exceptional circumstances.
#[cfg(test)]
#[track_caller]
pub fn reset_grad_state() {
    // Discard any nested contexts and restore the default single-entry stack.
    GRAD_ENABLED_STACK.with(|stack| {
        stack.replace(vec![true]);
    });
}
257
#[cfg(test)]
mod tests {
    use super::*;

    // Every test calls reset_grad_state() first: Rust runs tests on separate
    // threads, but a thread may be reused, so we normalize the thread-local
    // stack before asserting anything.

    #[test]
    fn test_default_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_no_grad_guard_basic() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
        }

        // Guard dropped at end of inner scope; state restored.
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_nested_no_grad_guards() {
        reset_grad_state();
        assert!(is_grad_enabled());

        // Three levels of nesting: gradients stay disabled at every level and
        // only re-enable once the outermost guard drops.
        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                {
                    let _guard3 = NoGradTrack::new();
                    assert!(!is_grad_enabled());
                }

                assert!(!is_grad_enabled());
            }

            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    #[test]
    fn test_set_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());

        set_grad_enabled(false);
        assert!(!is_grad_enabled());

        set_grad_enabled(true);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_with_no_grad_function() {
        reset_grad_state();
        assert!(is_grad_enabled());

        // The closure's return value passes through unchanged.
        let result = with_no_grad(|| {
            assert!(!is_grad_enabled());
            42
        });

        assert_eq!(result, 42);
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_with_no_grad_nested() {
        reset_grad_state();
        assert!(is_grad_enabled());

        with_no_grad(|| {
            assert!(!is_grad_enabled());

            with_no_grad(|| {
                assert!(!is_grad_enabled());
            });

            // Inner call's guard dropped; outer context still disabled.
            assert!(!is_grad_enabled());
        });

        assert!(is_grad_enabled());
    }

    #[test]
    fn test_multiple_guards_same_scope() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let _guard1 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        // Note: dropping guard1 before guard2 pops guard2's stack entry —
        // the stack tracks nesting depth, not guard identity. The observable
        // state is still correct: disabled until both guards are gone.
        drop(_guard1);
        assert!(!is_grad_enabled()); // Still disabled due to guard2

        drop(_guard2);
        assert!(is_grad_enabled()); // Now restored
    }

    #[test]
    fn test_early_return_with_guard() {
        fn test_function() -> i32 {
            reset_grad_state();
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            if true {
                return 42; // Early return should still restore state
            }

            unreachable!()
        }

        let result = test_function();
        assert_eq!(result, 42);
        assert!(is_grad_enabled()); // Should be restored even with early return
    }

    #[test]
    fn test_thread_local_isolation() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let handle = std::thread::spawn(|| {
            // Each thread should start with gradients enabled
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            // Return the state that should be isolated to this thread
            is_grad_enabled()
        });

        // Main thread should still have gradients enabled
        assert!(is_grad_enabled());

        let other_thread_state = handle.join().unwrap();
        assert!(!other_thread_state); // Other thread had gradients disabled

        // Main thread should still be unaffected
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_panic_safety() {
        reset_grad_state();
        assert!(is_grad_enabled());

        // The guard's Drop runs during unwinding, so the state must be
        // restored even though the closure panics.
        let result = std::panic::catch_unwind(|| {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
            panic!("Test panic");
        });

        assert!(result.is_err());

        // State should be restored even after panic
        // Note: This might not work in all test runners due to panic handling
        // but the RAII pattern should ensure cleanup in normal Rust programs
        assert!(is_grad_enabled());
    }

    #[test]
    fn test_grad_state_stack_integrity() {
        reset_grad_state();

        // Test that the stack maintains integrity through complex operations
        assert!(is_grad_enabled());

        {
            let _g1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            // set_grad_enabled overwrites the *current* stack level, so the
            // manual override survives nested guards but not _g1's drop.
            set_grad_enabled(true); // Manual override
            assert!(is_grad_enabled());

            {
                let _g2 = NoGradTrack::new();
                assert!(!is_grad_enabled());
            }

            assert!(is_grad_enabled()); // Should restore to manually set state
        }

        assert!(is_grad_enabled()); // Should restore to original state
    }
}
460}