train_station/gradtrack/no_grad_track.rs
1//! NoGradTrack for temporarily disabling gradient tracking
2//!
3//! This module provides functionality similar to PyTorch's `torch.no_grad()` context manager,
4//! allowing users to temporarily disable gradient computation for performance optimization
5//! during inference or validation phases.
6//!
7//! # Features
8//! - Thread-local gradient context management
9//! - RAII pattern for automatic restoration
10//! - Nested context support with proper stack management
11//! - Zero-cost abstraction when gradients are already disabled
12//! - Thread-safe design for concurrent usage
13
use std::cell::RefCell;
use std::marker::PhantomData;
15
16thread_local! {
17 /// Thread-local storage for gradient tracking state
18 ///
19 /// Uses a stack to support nested NoGradTrack contexts.
20 /// Each element represents whether gradients were enabled at that nesting level.
21 static GRAD_ENABLED_STACK: RefCell<Vec<bool>> = RefCell::new(vec![true]);
22}
23
24/// A RAII guard that temporarily disables gradient tracking
25///
26/// Similar to PyTorch's `torch.no_grad()`, this guard disables gradient computation
27/// within its scope and automatically restores the previous gradient tracking state
28/// when it goes out of scope.
29///
30/// # Performance Benefits
31/// - Prevents computation graph construction during inference
32/// - Reduces memory usage by not storing intermediate values for backpropagation
33/// - Improves computation speed by skipping gradient-related operations
34///
35/// # Examples
36///
37/// ```rust
38/// use train_station::gradtrack::{NoGradTrack};
39/// use train_station::Tensor;
40///
41/// let x = Tensor::ones(vec![3, 3]).with_requires_grad();
42/// let y = Tensor::ones(vec![3, 3]).with_requires_grad();
43///
44/// // Normal computation with gradients
45/// let z1 = x.add_tensor(&y);
46/// assert!(z1.requires_grad());
47///
48/// // Computation without gradients
49/// {
50/// let _guard = NoGradTrack::new();
51/// let z2 = x.add_tensor(&y);
52/// assert!(!z2.requires_grad()); // Gradients disabled
53/// } // Guard drops here, gradients restored
54///
55/// // Gradients are automatically restored
56/// let z3 = x.add_tensor(&y);
57/// assert!(z3.requires_grad());
58/// ```
59///
60/// # Nested Contexts
61///
62/// ```rust
63/// use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};
64///
65/// assert!(is_grad_enabled());
66///
67/// {
68/// let _guard1 = NoGradTrack::new();
69/// assert!(!is_grad_enabled());
70///
71/// {
72/// let _guard2 = NoGradTrack::new();
73/// assert!(!is_grad_enabled());
74/// } // guard2 drops
75///
76/// assert!(!is_grad_enabled()); // Still disabled
77/// } // guard1 drops
78///
79/// assert!(is_grad_enabled()); // Restored
80/// ```
81pub struct NoGradTrack {
82 // Marker to ensure the guard cannot be constructed outside this module
83 // without using the `new()` method
84 _private: (),
85}
86
87impl NoGradTrack {
88 /// Create a new NoGradTrack that disables gradient tracking
89 ///
90 /// This function pushes the current gradient state onto the stack and
91 /// disables gradient tracking. When the guard is dropped, the previous
92 /// state is automatically restored.
93 ///
94 /// # Returns
95 ///
96 /// A new `NoGradTrack` that will restore gradient state when dropped
97 #[track_caller]
98 pub fn new() -> Self {
99 GRAD_ENABLED_STACK.with(|stack| {
100 let mut stack = stack.borrow_mut();
101 stack.push(false); // Disable gradients
102 });
103
104 NoGradTrack { _private: () }
105 }
106}
107
108impl Default for NoGradTrack {
109 fn default() -> Self {
110 Self::new()
111 }
112}
113
114impl Drop for NoGradTrack {
115 /// Automatically restore the previous gradient tracking state
116 ///
117 /// This ensures that gradient tracking is properly restored even if
118 /// the guard goes out of scope due to early returns or panics.
119 fn drop(&mut self) {
120 GRAD_ENABLED_STACK.with(|stack| {
121 let mut stack = stack.borrow_mut();
122 if stack.len() > 1 {
123 stack.pop(); // Remove the current (disabled) state
124 // The previous state is now at the top of the stack
125 } else {
126 // This should not happen in normal usage, but we handle it gracefully
127 // by ensuring at least one state remains (the default true state)
128 *stack = vec![true];
129 }
130 });
131 }
132}
133
134/// Check if gradient computation is currently enabled
135///
136/// This function returns the current gradient tracking state for the current thread.
137/// It respects any active `NoGradTrack` contexts.
138///
139/// # Returns
140///
141/// `true` if gradient computation is enabled, `false` otherwise
142///
143/// # Examples
144///
145/// ```rust
146/// use train_station::gradtrack::{NoGradTrack, is_grad_enabled};
147///
148/// assert!(is_grad_enabled()); // Default state
149///
150/// {
151/// let _guard = NoGradTrack::new();
152/// assert!(!is_grad_enabled()); // Disabled by guard
153/// }
154///
155/// assert!(is_grad_enabled()); // Restored after guard drops
156/// ```
157#[track_caller]
158pub fn is_grad_enabled() -> bool {
159 GRAD_ENABLED_STACK.with(|stack| {
160 let stack = stack.borrow();
161 *stack.last().unwrap_or(&true)
162 })
163}
164
165/// Manually set the gradient tracking state
166///
167/// This function allows manual control over gradient tracking state.
168/// It's primarily intended for internal use and testing. In most cases,
169/// using `NoGradTrack` is preferred as it provides automatic restoration.
170///
171/// # Arguments
172///
173/// * `enabled` - Whether to enable or disable gradient tracking
174///
175/// # Warning
176///
177/// This function modifies the current gradient state without automatic restoration.
178/// Use `NoGradTrack` for RAII-style management in most cases.
179///
180/// # Examples
181///
182/// ```rust
183/// use train_station::gradtrack::{set_grad_enabled, is_grad_enabled};
184///
185/// assert!(is_grad_enabled());
186/// set_grad_enabled(false);
187/// assert!(!is_grad_enabled());
188/// set_grad_enabled(true);
189/// assert!(is_grad_enabled());
190/// ```
191#[track_caller]
192pub fn set_grad_enabled(enabled: bool) {
193 GRAD_ENABLED_STACK.with(|stack| {
194 let mut stack = stack.borrow_mut();
195 if let Some(last) = stack.last_mut() {
196 *last = enabled;
197 } else {
198 // Fallback: ensure stack has at least one element
199 *stack = vec![enabled];
200 }
201 });
202}
203
204/// Convenience function to execute a closure with gradients disabled
205///
206/// This function provides a convenient way to execute code with gradients
207/// disabled without explicitly managing a `NoGradTrack`.
208///
209/// # Arguments
210///
211/// * `f` - The closure to execute with gradients disabled
212///
213/// # Returns
214///
215/// The result of the closure
216///
217/// # Examples
218///
219/// ```rust
220/// use train_station::{Tensor, gradtrack::with_no_grad, gradtrack::is_grad_enabled};
221///
222/// let x = Tensor::ones(vec![2, 2]).with_requires_grad();
223/// let y = Tensor::ones(vec![2, 2]).with_requires_grad();
224///
225/// let result = with_no_grad(|| {
226/// assert!(!is_grad_enabled());
227/// x.add_tensor(&y)
228/// });
229///
230/// assert!(!result.requires_grad());
231/// assert!(is_grad_enabled()); // Restored after guard drops
232/// ```
233#[track_caller]
234pub fn with_no_grad<F, R>(f: F) -> R
235where
236 F: FnOnce() -> R,
237{
238 let _guard = NoGradTrack::new();
239 f()
240}
241
242/// Reset the gradient tracking state to the default (enabled)
243///
244/// This function is primarily for testing and debugging purposes.
245/// It clears the entire gradient state stack and resets to the default state.
246///
247/// # Warning
248///
249/// This function will disrupt any active `NoGradTrack` contexts and should
250/// only be used in test cleanup or exceptional circumstances.
251#[cfg(test)]
252#[track_caller]
253pub fn reset_grad_state() {
254 GRAD_ENABLED_STACK.with(|stack| {
255 *stack.borrow_mut() = vec![true];
256 });
257}
258
259#[cfg(test)]
260mod tests {
261 use super::*;
262
263 #[test]
264 fn test_default_grad_enabled() {
265 reset_grad_state();
266 assert!(is_grad_enabled());
267 }
268
269 #[test]
270 fn test_no_grad_guard_basic() {
271 reset_grad_state();
272 assert!(is_grad_enabled());
273
274 {
275 let _guard = NoGradTrack::new();
276 assert!(!is_grad_enabled());
277 }
278
279 assert!(is_grad_enabled());
280 }
281
282 #[test]
283 fn test_nested_no_grad_guards() {
284 reset_grad_state();
285 assert!(is_grad_enabled());
286
287 {
288 let _guard1 = NoGradTrack::new();
289 assert!(!is_grad_enabled());
290
291 {
292 let _guard2 = NoGradTrack::new();
293 assert!(!is_grad_enabled());
294
295 {
296 let _guard3 = NoGradTrack::new();
297 assert!(!is_grad_enabled());
298 }
299
300 assert!(!is_grad_enabled());
301 }
302
303 assert!(!is_grad_enabled());
304 }
305
306 assert!(is_grad_enabled());
307 }
308
309 #[test]
310 fn test_set_grad_enabled() {
311 reset_grad_state();
312 assert!(is_grad_enabled());
313
314 set_grad_enabled(false);
315 assert!(!is_grad_enabled());
316
317 set_grad_enabled(true);
318 assert!(is_grad_enabled());
319 }
320
321 #[test]
322 fn test_with_no_grad_function() {
323 reset_grad_state();
324 assert!(is_grad_enabled());
325
326 let result = with_no_grad(|| {
327 assert!(!is_grad_enabled());
328 42
329 });
330
331 assert_eq!(result, 42);
332 assert!(is_grad_enabled());
333 }
334
335 #[test]
336 fn test_with_no_grad_nested() {
337 reset_grad_state();
338 assert!(is_grad_enabled());
339
340 with_no_grad(|| {
341 assert!(!is_grad_enabled());
342
343 with_no_grad(|| {
344 assert!(!is_grad_enabled());
345 });
346
347 assert!(!is_grad_enabled());
348 });
349
350 assert!(is_grad_enabled());
351 }
352
353 #[test]
354 fn test_multiple_guards_same_scope() {
355 reset_grad_state();
356 assert!(is_grad_enabled());
357
358 let _guard1 = NoGradTrack::new();
359 assert!(!is_grad_enabled());
360
361 let _guard2 = NoGradTrack::new();
362 assert!(!is_grad_enabled());
363
364 drop(_guard1);
365 assert!(!is_grad_enabled()); // Still disabled due to guard2
366
367 drop(_guard2);
368 assert!(is_grad_enabled()); // Now restored
369 }
370
371 #[test]
372 fn test_early_return_with_guard() {
373 fn test_function() -> i32 {
374 reset_grad_state();
375 assert!(is_grad_enabled());
376
377 let _guard = NoGradTrack::new();
378 assert!(!is_grad_enabled());
379
380 if true {
381 return 42; // Early return should still restore state
382 }
383
384 unreachable!()
385 }
386
387 let result = test_function();
388 assert_eq!(result, 42);
389 assert!(is_grad_enabled()); // Should be restored even with early return
390 }
391
392 #[test]
393 fn test_thread_local_isolation() {
394 reset_grad_state();
395 assert!(is_grad_enabled());
396
397 let handle = std::thread::spawn(|| {
398 // Each thread should start with gradients enabled
399 assert!(is_grad_enabled());
400
401 let _guard = NoGradTrack::new();
402 assert!(!is_grad_enabled());
403
404 // Return the state that should be isolated to this thread
405 is_grad_enabled()
406 });
407
408 // Main thread should still have gradients enabled
409 assert!(is_grad_enabled());
410
411 let other_thread_state = handle.join().unwrap();
412 assert!(!other_thread_state); // Other thread had gradients disabled
413
414 // Main thread should still be unaffected
415 assert!(is_grad_enabled());
416 }
417
418 #[test]
419 fn test_panic_safety() {
420 reset_grad_state();
421 assert!(is_grad_enabled());
422
423 let result = std::panic::catch_unwind(|| {
424 let _guard = NoGradTrack::new();
425 assert!(!is_grad_enabled());
426 panic!("Test panic");
427 });
428
429 assert!(result.is_err());
430
431 // State should be restored even after panic
432 // Note: This might not work in all test runners due to panic handling
433 // but the RAII pattern should ensure cleanup in normal Rust programs
434 assert!(is_grad_enabled());
435 }
436
437 #[test]
438 fn test_grad_state_stack_integrity() {
439 reset_grad_state();
440
441 // Test that the stack maintains integrity through complex operations
442 assert!(is_grad_enabled());
443
444 {
445 let _g1 = NoGradTrack::new();
446 assert!(!is_grad_enabled());
447
448 set_grad_enabled(true); // Manual override
449 assert!(is_grad_enabled());
450
451 {
452 let _g2 = NoGradTrack::new();
453 assert!(!is_grad_enabled());
454 }
455
456 assert!(is_grad_enabled()); // Should restore to manually set state
457 }
458
459 assert!(is_grad_enabled()); // Should restore to original state
460 }
461}