// train_station/gradtrack/no_grad_track.rs

//! NoGradTrack for temporarily disabling gradient tracking
//!
//! This module provides functionality similar to PyTorch's `torch.no_grad()` context manager,
//! allowing users to temporarily disable gradient computation for performance optimization
//! during inference or validation phases.
//!
//! # Features
//! - Thread-local gradient context management
//! - RAII pattern for automatic restoration
//! - Nested context support with proper stack management
//! - Zero-cost abstraction when gradients are already disabled
//! - Thread-safe design for concurrent usage

use std::cell::RefCell;
15
16thread_local! {
17    /// Thread-local storage for gradient tracking state
18    ///
19    /// Uses a stack to support nested NoGradTrack contexts.
20    /// Each element represents whether gradients were enabled at that nesting level.
21    static GRAD_ENABLED_STACK: RefCell<Vec<bool>> = RefCell::new(vec![true]);
22}
23
24/// A RAII guard that temporarily disables gradient tracking
25///
26/// Similar to PyTorch's `torch.no_grad()`, this guard disables gradient computation
27/// within its scope and automatically restores the previous gradient tracking state
28/// when it goes out of scope.
29///
30/// # Performance Benefits
31/// - Prevents computation graph construction during inference
32/// - Reduces memory usage by not storing intermediate values for backpropagation
33/// - Improves computation speed by skipping gradient-related operations
34///
35/// # Examples
36///
37/// ```rust
38/// use train_station::{NoGradTrack, Tensor};
39///
40/// let x = Tensor::ones(vec![3, 3]).with_requires_grad();
41/// let y = Tensor::ones(vec![3, 3]).with_requires_grad();
42///
43/// // Normal computation with gradients
44/// let z1 = x.add_tensor(&y);
45/// assert!(z1.requires_grad());
46///
47/// // Computation without gradients
48/// {
49///     let _guard = NoGradTrack::new();
50///     let z2 = x.add_tensor(&y);
51///     assert!(!z2.requires_grad()); // Gradients disabled
52/// } // Guard drops here, gradients restored
53///
54/// // Gradients are automatically restored
55/// let z3 = x.add_tensor(&y);
56/// assert!(z3.requires_grad());
57/// ```
58///
59/// # Nested Contexts
60///
61/// ```rust
62/// use train_station::{NoGradTrack, is_grad_enabled, Tensor};
63///
64/// assert!(is_grad_enabled());
65///
66/// {
67///     let _guard1 = NoGradTrack::new();
68///     assert!(!is_grad_enabled());
69///
70///     {
71///         let _guard2 = NoGradTrack::new();
72///         assert!(!is_grad_enabled());
73///     } // guard2 drops
74///
75///     assert!(!is_grad_enabled()); // Still disabled
76/// } // guard1 drops
77///
78/// assert!(is_grad_enabled()); // Restored
79/// ```
80pub struct NoGradTrack {
81    // Marker to ensure the guard cannot be constructed outside this module
82    // without using the `new()` method
83    _private: (),
84}
85
86impl NoGradTrack {
87    /// Create a new NoGradTrack that disables gradient tracking
88    ///
89    /// This function pushes the current gradient state onto the stack and
90    /// disables gradient tracking. When the guard is dropped, the previous
91    /// state is automatically restored.
92    ///
93    /// # Returns
94    ///
95    /// A new `NoGradTrack` that will restore gradient state when dropped
96    pub fn new() -> Self {
97        GRAD_ENABLED_STACK.with(|stack| {
98            let mut stack = stack.borrow_mut();
99            stack.push(false); // Disable gradients
100        });
101
102        NoGradTrack { _private: () }
103    }
104}
105
106impl Default for NoGradTrack {
107    fn default() -> Self {
108        Self::new()
109    }
110}
111
112impl Drop for NoGradTrack {
113    /// Automatically restore the previous gradient tracking state
114    ///
115    /// This ensures that gradient tracking is properly restored even if
116    /// the guard goes out of scope due to early returns or panics.
117    fn drop(&mut self) {
118        GRAD_ENABLED_STACK.with(|stack| {
119            let mut stack = stack.borrow_mut();
120            if stack.len() > 1 {
121                stack.pop(); // Remove the current (disabled) state
122                             // The previous state is now at the top of the stack
123            } else {
124                // This should not happen in normal usage, but we handle it gracefully
125                // by ensuring at least one state remains (the default true state)
126                *stack = vec![true];
127            }
128        });
129    }
130}
131
132/// Check if gradient computation is currently enabled
133///
134/// This function returns the current gradient tracking state for the current thread.
135/// It respects any active `NoGradTrack` contexts.
136///
137/// # Returns
138///
139/// `true` if gradient computation is enabled, `false` otherwise
140///
141/// # Examples
142///
143/// ```rust
144/// use train_station::{NoGradTrack, is_grad_enabled};
145///
146/// assert!(is_grad_enabled()); // Default state
147///
148/// {
149///     let _guard = NoGradTrack::new();
150///     assert!(!is_grad_enabled()); // Disabled by guard
151/// }
152///
153/// assert!(is_grad_enabled()); // Restored after guard drops
154/// ```
155pub fn is_grad_enabled() -> bool {
156    GRAD_ENABLED_STACK.with(|stack| {
157        let stack = stack.borrow();
158        *stack.last().unwrap_or(&true)
159    })
160}
161
162/// Manually set the gradient tracking state
163///
164/// This function allows manual control over gradient tracking state.
165/// It's primarily intended for internal use and testing. In most cases,
166/// using `NoGradTrack` is preferred as it provides automatic restoration.
167///
168/// # Arguments
169///
170/// * `enabled` - Whether to enable or disable gradient tracking
171///
172/// # Warning
173///
174/// This function modifies the current gradient state without automatic restoration.
175/// Use `NoGradTrack` for RAII-style management in most cases.
176///
177/// # Examples
178///
179/// ```rust
180/// use train_station::{set_grad_enabled, is_grad_enabled};
181///
182/// assert!(is_grad_enabled());
183/// set_grad_enabled(false);
184/// assert!(!is_grad_enabled());
185/// set_grad_enabled(true);
186/// assert!(is_grad_enabled());
187/// ```
188pub fn set_grad_enabled(enabled: bool) {
189    GRAD_ENABLED_STACK.with(|stack| {
190        let mut stack = stack.borrow_mut();
191        if let Some(last) = stack.last_mut() {
192            *last = enabled;
193        } else {
194            // Fallback: ensure stack has at least one element
195            *stack = vec![enabled];
196        }
197    });
198}
199
200/// Convenience function to execute a closure with gradients disabled
201///
202/// This function provides a convenient way to execute code with gradients
203/// disabled without explicitly managing a `NoGradTrack`.
204///
205/// # Arguments
206///
207/// * `f` - The closure to execute with gradients disabled
208///
209/// # Returns
210///
211/// The result of the closure
212///
213/// # Examples
214///
215/// ```rust
216/// use train_station::{Tensor, with_no_grad, is_grad_enabled};
217///
218/// let x = Tensor::ones(vec![2, 2]).with_requires_grad();
219/// let y = Tensor::ones(vec![2, 2]).with_requires_grad();
220///
221/// let result = with_no_grad(|| {
222///     assert!(!is_grad_enabled());
223///     x.add_tensor(&y)
224/// });
225///
226/// assert!(!result.requires_grad());
227/// assert!(is_grad_enabled()); // Restored after closure
228/// ```
229pub fn with_no_grad<F, R>(f: F) -> R
230where
231    F: FnOnce() -> R,
232{
233    let _guard = NoGradTrack::new();
234    f()
235}
236
237/// Reset the gradient tracking state to the default (enabled)
238///
239/// This function is primarily for testing and debugging purposes.
240/// It clears the entire gradient state stack and resets to the default state.
241///
242/// # Warning
243///
244/// This function will disrupt any active `NoGradTrack` contexts and should
245/// only be used in test cleanup or exceptional circumstances.
246#[cfg(test)]
247pub fn reset_grad_state() {
248    GRAD_ENABLED_STACK.with(|stack| {
249        *stack.borrow_mut() = vec![true];
250    });
251}
252
253#[cfg(test)]
254mod tests {
255    use super::*;
256
257    #[test]
258    fn test_default_grad_enabled() {
259        reset_grad_state();
260        assert!(is_grad_enabled());
261    }
262
263    #[test]
264    fn test_no_grad_guard_basic() {
265        reset_grad_state();
266        assert!(is_grad_enabled());
267
268        {
269            let _guard = NoGradTrack::new();
270            assert!(!is_grad_enabled());
271        }
272
273        assert!(is_grad_enabled());
274    }
275
276    #[test]
277    fn test_nested_no_grad_guards() {
278        reset_grad_state();
279        assert!(is_grad_enabled());
280
281        {
282            let _guard1 = NoGradTrack::new();
283            assert!(!is_grad_enabled());
284
285            {
286                let _guard2 = NoGradTrack::new();
287                assert!(!is_grad_enabled());
288
289                {
290                    let _guard3 = NoGradTrack::new();
291                    assert!(!is_grad_enabled());
292                }
293
294                assert!(!is_grad_enabled());
295            }
296
297            assert!(!is_grad_enabled());
298        }
299
300        assert!(is_grad_enabled());
301    }
302
303    #[test]
304    fn test_set_grad_enabled() {
305        reset_grad_state();
306        assert!(is_grad_enabled());
307
308        set_grad_enabled(false);
309        assert!(!is_grad_enabled());
310
311        set_grad_enabled(true);
312        assert!(is_grad_enabled());
313    }
314
315    #[test]
316    fn test_with_no_grad_function() {
317        reset_grad_state();
318        assert!(is_grad_enabled());
319
320        let result = with_no_grad(|| {
321            assert!(!is_grad_enabled());
322            42
323        });
324
325        assert_eq!(result, 42);
326        assert!(is_grad_enabled());
327    }
328
329    #[test]
330    fn test_with_no_grad_nested() {
331        reset_grad_state();
332        assert!(is_grad_enabled());
333
334        with_no_grad(|| {
335            assert!(!is_grad_enabled());
336
337            with_no_grad(|| {
338                assert!(!is_grad_enabled());
339            });
340
341            assert!(!is_grad_enabled());
342        });
343
344        assert!(is_grad_enabled());
345    }
346
347    #[test]
348    fn test_multiple_guards_same_scope() {
349        reset_grad_state();
350        assert!(is_grad_enabled());
351
352        let _guard1 = NoGradTrack::new();
353        assert!(!is_grad_enabled());
354
355        let _guard2 = NoGradTrack::new();
356        assert!(!is_grad_enabled());
357
358        drop(_guard1);
359        assert!(!is_grad_enabled()); // Still disabled due to guard2
360
361        drop(_guard2);
362        assert!(is_grad_enabled()); // Now restored
363    }
364
365    #[test]
366    fn test_early_return_with_guard() {
367        fn test_function() -> i32 {
368            reset_grad_state();
369            assert!(is_grad_enabled());
370
371            let _guard = NoGradTrack::new();
372            assert!(!is_grad_enabled());
373
374            if true {
375                return 42; // Early return should still restore state
376            }
377
378            unreachable!()
379        }
380
381        let result = test_function();
382        assert_eq!(result, 42);
383        assert!(is_grad_enabled()); // Should be restored even with early return
384    }
385
386    #[test]
387    fn test_thread_local_isolation() {
388        reset_grad_state();
389        assert!(is_grad_enabled());
390
391        let handle = std::thread::spawn(|| {
392            // Each thread should start with gradients enabled
393            assert!(is_grad_enabled());
394
395            let _guard = NoGradTrack::new();
396            assert!(!is_grad_enabled());
397
398            // Return the state that should be isolated to this thread
399            is_grad_enabled()
400        });
401
402        // Main thread should still have gradients enabled
403        assert!(is_grad_enabled());
404
405        let other_thread_state = handle.join().unwrap();
406        assert!(!other_thread_state); // Other thread had gradients disabled
407
408        // Main thread should still be unaffected
409        assert!(is_grad_enabled());
410    }
411
412    #[test]
413    fn test_panic_safety() {
414        reset_grad_state();
415        assert!(is_grad_enabled());
416
417        let result = std::panic::catch_unwind(|| {
418            let _guard = NoGradTrack::new();
419            assert!(!is_grad_enabled());
420            panic!("Test panic");
421        });
422
423        assert!(result.is_err());
424
425        // State should be restored even after panic
426        // Note: This might not work in all test runners due to panic handling
427        // but the RAII pattern should ensure cleanup in normal Rust programs
428        assert!(is_grad_enabled());
429    }
430
431    #[test]
432    fn test_grad_state_stack_integrity() {
433        reset_grad_state();
434
435        // Test that the stack maintains integrity through complex operations
436        assert!(is_grad_enabled());
437
438        {
439            let _g1 = NoGradTrack::new();
440            assert!(!is_grad_enabled());
441
442            set_grad_enabled(true); // Manual override
443            assert!(is_grad_enabled());
444
445            {
446                let _g2 = NoGradTrack::new();
447                assert!(!is_grad_enabled());
448            }
449
450            assert!(is_grad_enabled()); // Should restore to manually set state
451        }
452
453        assert!(is_grad_enabled()); // Should restore to original state
454    }
455}