Function set_grad_enabled 

pub fn set_grad_enabled(enabled: bool)

Enable or disable gradient tracking for the current thread.

This function sets the gradient tracking state manually, giving fine-grained control over when gradients are computed. That can be useful in custom training loops or inference optimizations, but the function is primarily intended for internal use and testing. In most cases, NoGradTrack is preferred because it restores the previous state automatically.

§Arguments

  • enabled - Whether to enable or disable gradient tracking

§Warning

This function modifies the current gradient state without automatic restoration. Use NoGradTrack for RAII-style management in most cases.
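
To illustrate what "automatic restoration" means here, the following is a minimal, self-contained sketch of the RAII pattern. It does not use train_station: the thread-local flag, the local `is_grad_enabled`/`set_grad_enabled` functions, the `GradStateGuard` type, and the `without_grad` helper are all hypothetical stand-ins written for this example, and NoGradTrack's real implementation and API may differ.

```rust
use std::cell::Cell;

thread_local! {
    // Hypothetical stand-in for a per-thread gradient flag (not train_station's).
    static GRAD_ENABLED: Cell<bool> = Cell::new(true);
}

fn is_grad_enabled() -> bool {
    GRAD_ENABLED.with(|flag| flag.get())
}

fn set_grad_enabled(enabled: bool) {
    GRAD_ENABLED.with(|flag| flag.set(enabled));
}

/// Hypothetical guard: remembers the previous state and restores it on drop.
struct GradStateGuard {
    previous: bool,
}

impl GradStateGuard {
    fn new(enabled: bool) -> Self {
        let previous = is_grad_enabled();
        set_grad_enabled(enabled);
        GradStateGuard { previous }
    }
}

impl Drop for GradStateGuard {
    fn drop(&mut self) {
        // Runs on every exit path from the scope, including early returns.
        set_grad_enabled(self.previous);
    }
}

/// Runs `f` with tracking disabled, then restores whatever state was set before.
fn without_grad<T>(f: impl FnOnce() -> T) -> T {
    let _guard = GradStateGuard::new(false);
    f()
}
```

With a bare `set_grad_enabled(false)`, an early return or panic before the matching `set_grad_enabled(true)` leaves tracking off for the rest of the thread. A guard avoids that because the restore happens in `drop`.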

§Examples

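use train_station::{is_grad_enabled, set_grad_enabled};

// Gradient tracking is enabled at the start of this example.
assert!(is_grad_enabled());

// Disable tracking; it stays off until it is changed again.
set_grad_enabled(false);
assert!(!is_grad_enabled());

// Restore the original state by hand, since nothing restores it automatically.
set_grad_enabled(true);
assert!(is_grad_enabled());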
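
Because the state applies to the current thread, changing it on one thread should not affect another. The sketch below shows that behavior with a hypothetical thread-local stand-in rather than train_station itself. It assumes a fresh thread starts with tracking enabled, which is a property of this stand-in and is not confirmed for the library.

```rust
use std::cell::Cell;
use std::thread;

thread_local! {
    // Hypothetical per-thread flag; each thread gets its own copy, initially true.
    static GRAD_ENABLED: Cell<bool> = Cell::new(true);
}

fn is_grad_enabled() -> bool {
    GRAD_ENABLED.with(|flag| flag.get())
}

fn set_grad_enabled(enabled: bool) {
    GRAD_ENABLED.with(|flag| flag.set(enabled));
}

/// Reports the state seen by a newly spawned thread (hypothetical helper).
fn grad_state_in_new_thread() -> bool {
    thread::spawn(|| is_grad_enabled())
        .join()
        .expect("worker thread panicked")
}
```

In this sketch, calling `set_grad_enabled(false)` on the calling thread leaves `grad_state_in_new_thread()` returning `true`, because the worker reads its own copy of the flag.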