// train_station/gradtrack/no_grad_track.rs
1//! NoGradTrack for temporarily disabling gradient tracking
2//!
3//! This module provides functionality similar to PyTorch's `torch.no_grad()` context manager,
4//! allowing users to temporarily disable gradient computation for performance optimization
5//! during inference or validation phases.
6//!
7//! # Features
8//! - Thread-local gradient context management
9//! - RAII pattern for automatic restoration
10//! - Nested context support with proper stack management
11//! - Zero-cost abstraction when gradients are already disabled
12//! - Thread-safe design for concurrent usage
13
14use std::cell::RefCell;
15
thread_local! {
    /// Thread-local stack of gradient-tracking states.
    ///
    /// Each `NoGradTrack` guard pushes `false` on creation and pops on drop,
    /// which makes nested guards restore correctly. The bottom element
    /// (`true`) is the default "gradients enabled" state; `is_grad_enabled`
    /// reads the top of the stack, and `Drop` never removes the last element.
    /// Being thread-local, each thread starts with a fresh `vec![true]`.
    static GRAD_ENABLED_STACK: RefCell<Vec<bool>> = RefCell::new(vec![true]);
}
23
/// A RAII guard that temporarily disables gradient tracking
///
/// Similar to PyTorch's `torch.no_grad()`, this guard disables gradient computation
/// within its scope and automatically restores the previous gradient tracking state
/// when it goes out of scope.
///
/// # Performance Benefits
/// - Prevents computation graph construction during inference
/// - Reduces memory usage by not storing intermediate values for backpropagation
/// - Improves computation speed by skipping gradient-related operations
///
/// # Examples
///
/// ```rust
/// use train_station::{NoGradTrack, Tensor};
///
/// let x = Tensor::ones(vec![3, 3]).with_requires_grad();
/// let y = Tensor::ones(vec![3, 3]).with_requires_grad();
///
/// // Normal computation with gradients
/// let z1 = x.add_tensor(&y);
/// assert!(z1.requires_grad());
///
/// // Computation without gradients
/// {
///     let _guard = NoGradTrack::new();
///     let z2 = x.add_tensor(&y);
///     assert!(!z2.requires_grad()); // Gradients disabled
/// } // Guard drops here, gradients restored
///
/// // Gradients are automatically restored
/// let z3 = x.add_tensor(&y);
/// assert!(z3.requires_grad());
/// ```
///
/// # Nested Contexts
///
/// ```rust
/// use train_station::{NoGradTrack, is_grad_enabled, Tensor};
///
/// assert!(is_grad_enabled());
///
/// {
///     let _guard1 = NoGradTrack::new();
///     assert!(!is_grad_enabled());
///
///     {
///         let _guard2 = NoGradTrack::new();
///         assert!(!is_grad_enabled());
///     } // guard2 drops
///
///     assert!(!is_grad_enabled()); // Still disabled
/// } // guard1 drops
///
/// assert!(is_grad_enabled()); // Restored
/// ```
// `#[must_use]` catches `NoGradTrack::new();` without a binding, which would
// drop the guard immediately and silently re-enable gradient tracking.
#[must_use = "the guard re-enables gradient tracking as soon as it is dropped; bind it with `let _guard = ...`"]
pub struct NoGradTrack {
    // Private marker so the guard can only be constructed through `new()`,
    // which pushes the disabled state onto the thread-local stack.
    _private: (),
}
85
86impl NoGradTrack {
87 /// Create a new NoGradTrack that disables gradient tracking
88 ///
89 /// This function pushes the current gradient state onto the stack and
90 /// disables gradient tracking. When the guard is dropped, the previous
91 /// state is automatically restored.
92 ///
93 /// # Returns
94 ///
95 /// A new `NoGradTrack` that will restore gradient state when dropped
96 pub fn new() -> Self {
97 GRAD_ENABLED_STACK.with(|stack| {
98 let mut stack = stack.borrow_mut();
99 stack.push(false); // Disable gradients
100 });
101
102 NoGradTrack { _private: () }
103 }
104}
105
106impl Default for NoGradTrack {
107 fn default() -> Self {
108 Self::new()
109 }
110}
111
112impl Drop for NoGradTrack {
113 /// Automatically restore the previous gradient tracking state
114 ///
115 /// This ensures that gradient tracking is properly restored even if
116 /// the guard goes out of scope due to early returns or panics.
117 fn drop(&mut self) {
118 GRAD_ENABLED_STACK.with(|stack| {
119 let mut stack = stack.borrow_mut();
120 if stack.len() > 1 {
121 stack.pop(); // Remove the current (disabled) state
122 // The previous state is now at the top of the stack
123 } else {
124 // This should not happen in normal usage, but we handle it gracefully
125 // by ensuring at least one state remains (the default true state)
126 *stack = vec![true];
127 }
128 });
129 }
130}
131
132/// Check if gradient computation is currently enabled
133///
134/// This function returns the current gradient tracking state for the current thread.
135/// It respects any active `NoGradTrack` contexts.
136///
137/// # Returns
138///
139/// `true` if gradient computation is enabled, `false` otherwise
140///
141/// # Examples
142///
143/// ```rust
144/// use train_station::{NoGradTrack, is_grad_enabled};
145///
146/// assert!(is_grad_enabled()); // Default state
147///
148/// {
149/// let _guard = NoGradTrack::new();
150/// assert!(!is_grad_enabled()); // Disabled by guard
151/// }
152///
153/// assert!(is_grad_enabled()); // Restored after guard drops
154/// ```
155pub fn is_grad_enabled() -> bool {
156 GRAD_ENABLED_STACK.with(|stack| {
157 let stack = stack.borrow();
158 *stack.last().unwrap_or(&true)
159 })
160}
161
162/// Manually set the gradient tracking state
163///
164/// This function allows manual control over gradient tracking state.
165/// It's primarily intended for internal use and testing. In most cases,
166/// using `NoGradTrack` is preferred as it provides automatic restoration.
167///
168/// # Arguments
169///
170/// * `enabled` - Whether to enable or disable gradient tracking
171///
172/// # Warning
173///
174/// This function modifies the current gradient state without automatic restoration.
175/// Use `NoGradTrack` for RAII-style management in most cases.
176///
177/// # Examples
178///
179/// ```rust
180/// use train_station::{set_grad_enabled, is_grad_enabled};
181///
182/// assert!(is_grad_enabled());
183/// set_grad_enabled(false);
184/// assert!(!is_grad_enabled());
185/// set_grad_enabled(true);
186/// assert!(is_grad_enabled());
187/// ```
188pub fn set_grad_enabled(enabled: bool) {
189 GRAD_ENABLED_STACK.with(|stack| {
190 let mut stack = stack.borrow_mut();
191 if let Some(last) = stack.last_mut() {
192 *last = enabled;
193 } else {
194 // Fallback: ensure stack has at least one element
195 *stack = vec![enabled];
196 }
197 });
198}
199
200/// Convenience function to execute a closure with gradients disabled
201///
202/// This function provides a convenient way to execute code with gradients
203/// disabled without explicitly managing a `NoGradTrack`.
204///
205/// # Arguments
206///
207/// * `f` - The closure to execute with gradients disabled
208///
209/// # Returns
210///
211/// The result of the closure
212///
213/// # Examples
214///
215/// ```rust
216/// use train_station::{Tensor, with_no_grad, is_grad_enabled};
217///
218/// let x = Tensor::ones(vec![2, 2]).with_requires_grad();
219/// let y = Tensor::ones(vec![2, 2]).with_requires_grad();
220///
221/// let result = with_no_grad(|| {
222/// assert!(!is_grad_enabled());
223/// x.add_tensor(&y)
224/// });
225///
226/// assert!(!result.requires_grad());
227/// assert!(is_grad_enabled()); // Restored after closure
228/// ```
229pub fn with_no_grad<F, R>(f: F) -> R
230where
231 F: FnOnce() -> R,
232{
233 let _guard = NoGradTrack::new();
234 f()
235}
236
/// Reset the gradient tracking state to the default (enabled)
///
/// Test-only helper that discards the entire gradient state stack for the
/// calling thread and re-seeds it with the default enabled state.
///
/// # Warning
///
/// This disrupts any active `NoGradTrack` contexts and should only be used
/// in test cleanup or exceptional circumstances.
#[cfg(test)]
pub fn reset_grad_state() {
    GRAD_ENABLED_STACK.with(|stack| {
        let mut states = stack.borrow_mut();
        states.clear();
        states.push(true);
    });
}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh thread-local stack defaults to gradients enabled.
    #[test]
    fn test_default_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());
    }

    /// A single guard disables gradients for its scope and restores on drop.
    #[test]
    fn test_no_grad_guard_basic() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    /// Nested guards keep gradients disabled until the outermost guard drops.
    #[test]
    fn test_nested_no_grad_guards() {
        reset_grad_state();
        assert!(is_grad_enabled());

        {
            let _guard1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            {
                let _guard2 = NoGradTrack::new();
                assert!(!is_grad_enabled());

                {
                    let _guard3 = NoGradTrack::new();
                    assert!(!is_grad_enabled());
                }

                assert!(!is_grad_enabled());
            }

            assert!(!is_grad_enabled());
        }

        assert!(is_grad_enabled());
    }

    /// Manual toggling via `set_grad_enabled` overwrites the current state.
    #[test]
    fn test_set_grad_enabled() {
        reset_grad_state();
        assert!(is_grad_enabled());

        set_grad_enabled(false);
        assert!(!is_grad_enabled());

        set_grad_enabled(true);
        assert!(is_grad_enabled());
    }

    /// `with_no_grad` disables gradients inside the closure, returns its
    /// result, and restores the previous state afterwards.
    #[test]
    fn test_with_no_grad_function() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let result = with_no_grad(|| {
            assert!(!is_grad_enabled());
            42
        });

        assert_eq!(result, 42);
        assert!(is_grad_enabled());
    }

    /// Nested `with_no_grad` calls behave like nested guards.
    #[test]
    fn test_with_no_grad_nested() {
        reset_grad_state();
        assert!(is_grad_enabled());

        with_no_grad(|| {
            assert!(!is_grad_enabled());

            with_no_grad(|| {
                assert!(!is_grad_enabled());
            });

            assert!(!is_grad_enabled());
        });

        assert!(is_grad_enabled());
    }

    /// Guards dropped out of creation order still leave the stack consistent:
    /// dropping any one guard pops one disabled state, so gradients re-enable
    /// only after the last guard drops.
    #[test]
    fn test_multiple_guards_same_scope() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let _guard1 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());

        drop(_guard1);
        assert!(!is_grad_enabled()); // Still disabled due to guard2

        drop(_guard2);
        assert!(is_grad_enabled()); // Now restored
    }

    /// RAII cleanup runs on early return, not just normal scope exit.
    #[test]
    fn test_early_return_with_guard() {
        fn test_function() -> i32 {
            reset_grad_state();
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            if true {
                return 42; // Early return should still restore state
            }

            unreachable!()
        }

        let result = test_function();
        assert_eq!(result, 42);
        assert!(is_grad_enabled()); // Should be restored even with early return
    }

    /// Gradient state is thread-local: a guard in a spawned thread must not
    /// affect the main thread's state.
    #[test]
    fn test_thread_local_isolation() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let handle = std::thread::spawn(|| {
            // Each thread should start with gradients enabled
            assert!(is_grad_enabled());

            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());

            // Return the state that should be isolated to this thread
            is_grad_enabled()
        });

        // Main thread should still have gradients enabled
        assert!(is_grad_enabled());

        let other_thread_state = handle.join().unwrap();
        assert!(!other_thread_state); // Other thread had gradients disabled

        // Main thread should still be unaffected
        assert!(is_grad_enabled());
    }

    /// The guard's `Drop` runs during panic unwinding, restoring state.
    #[test]
    fn test_panic_safety() {
        reset_grad_state();
        assert!(is_grad_enabled());

        let result = std::panic::catch_unwind(|| {
            let _guard = NoGradTrack::new();
            assert!(!is_grad_enabled());
            panic!("Test panic");
        });

        assert!(result.is_err());

        // State should be restored even after panic
        // Note: This might not work in all test runners due to panic handling
        // but the RAII pattern should ensure cleanup in normal Rust programs
        assert!(is_grad_enabled());
    }

    /// Interleaving manual `set_grad_enabled` with guards: a guard restores
    /// whatever state was on the stack beneath it, including manual overrides.
    #[test]
    fn test_grad_state_stack_integrity() {
        reset_grad_state();

        // Test that the stack maintains integrity through complex operations
        assert!(is_grad_enabled());

        {
            let _g1 = NoGradTrack::new();
            assert!(!is_grad_enabled());

            set_grad_enabled(true); // Manual override
            assert!(is_grad_enabled());

            {
                let _g2 = NoGradTrack::new();
                assert!(!is_grad_enabled());
            }

            assert!(is_grad_enabled()); // Should restore to manually set state
        }

        assert!(is_grad_enabled()); // Should restore to original state
    }
}
455}