NoGradTrack

Struct NoGradTrack 

Source
pub struct NoGradTrack { /* private fields */ }
Expand description

RAII guard for temporarily disabling gradient tracking

NoGradTrack provides a scope-based mechanism for disabling gradient tracking, automatically restoring the previous state when the guard is dropped. It ensures proper gradient state management even in the presence of early returns or panics.

Similar to PyTorch’s torch.no_grad(), this guard disables gradient computation within its scope and automatically restores the previous gradient tracking state when it goes out of scope.

§Performance Benefits

  • Prevents computation graph construction during inference
  • Reduces memory usage by not storing intermediate values for backpropagation
  • Improves computation speed by skipping gradient-related operations

§Examples

use train_station::{NoGradTrack, Tensor};

let x = Tensor::ones(vec![3, 3]).with_requires_grad();
let y = Tensor::ones(vec![3, 3]).with_requires_grad();

// Normal computation with gradients
let z1 = x.add_tensor(&y);
assert!(z1.requires_grad());

// Computation without gradients
{
    let _guard = NoGradTrack::new();
    let z2 = x.add_tensor(&y);
    assert!(!z2.requires_grad()); // Gradients disabled
} // Guard drops here, gradients restored

// Gradients are automatically restored
let z3 = x.add_tensor(&y);
assert!(z3.requires_grad());

§Nested Contexts

use train_station::{NoGradTrack, is_grad_enabled, Tensor};

assert!(is_grad_enabled());

{
    let _guard1 = NoGradTrack::new();
    assert!(!is_grad_enabled());

    {
        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());
    } // guard2 drops

    assert!(!is_grad_enabled()); // Still disabled
} // guard1 drops

assert!(is_grad_enabled()); // Restored

Implementations§

Source§

impl NoGradTrack

Source

pub fn new() -> Self

Creates a new NoGradTrack guard that disables gradient tracking

This function pushes the current gradient state onto the stack and disables gradient tracking. When the guard is dropped, the previous state is automatically restored.

§Returns

A new NoGradTrack that will restore gradient state when dropped

Examples found in repository?
examples/neural_networks/feedforward_network.rs (line 53)
52    pub fn forward_no_grad(input: &Tensor) -> Tensor {
53        let _guard = NoGradTrack::new();
54        Self::forward(input)
55    }
56}
57
58/// A basic linear layer implementation (reused from basic_linear_layer.rs)
59#[derive(Debug)]
60pub struct LinearLayer {
61    pub weight: Tensor,
62    pub bias: Tensor,
63    pub input_size: usize,
64    pub output_size: usize,
65}
66
67impl LinearLayer {
68    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
69        let scale = (1.0 / input_size as f32).sqrt();
70
71        let weight = Tensor::randn(vec![input_size, output_size], seed)
72            .mul_scalar(scale)
73            .with_requires_grad();
74        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
75
76        Self {
77            weight,
78            bias,
79            input_size,
80            output_size,
81        }
82    }
83
84    pub fn forward(&self, input: &Tensor) -> Tensor {
85        let output = input.matmul(&self.weight);
86        output.add_tensor(&self.bias)
87    }
88
89    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
90        let _guard = NoGradTrack::new();
91        self.forward(input)
92    }
93
94    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
95        vec![&mut self.weight, &mut self.bias]
96    }
97}
98
99/// Configuration for feed-forward network
100#[derive(Debug, Clone)]
101pub struct FeedForwardConfig {
102    pub input_size: usize,
103    pub hidden_sizes: Vec<usize>,
104    pub output_size: usize,
105    pub use_bias: bool,
106}
107
108impl Default for FeedForwardConfig {
109    fn default() -> Self {
110        Self {
111            input_size: 4,
112            hidden_sizes: vec![8, 4],
113            output_size: 2,
114            use_bias: true,
115        }
116    }
117}
118
119/// A configurable feed-forward neural network
120pub struct FeedForwardNetwork {
121    layers: Vec<LinearLayer>,
122    config: FeedForwardConfig,
123}
124
125impl FeedForwardNetwork {
126    /// Create a new feed-forward network with the given configuration
127    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
128        let mut layers = Vec::new();
129        let mut current_size = config.input_size;
130        let mut current_seed = seed;
131
132        // Create hidden layers
133        for &hidden_size in &config.hidden_sizes {
134            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
135            current_size = hidden_size;
136            current_seed = current_seed.map(|s| s + 1);
137        }
138
139        // Create output layer
140        layers.push(LinearLayer::new(
141            current_size,
142            config.output_size,
143            current_seed,
144        ));
145
146        Self { layers, config }
147    }
148
149    /// Forward pass through the entire network
150    pub fn forward(&self, input: &Tensor) -> Tensor {
151        let mut x = input.clone();
152
153        // Pass through all layers except the last one with ReLU activation
154        for layer in &self.layers[..self.layers.len() - 1] {
155            x = layer.forward(&x);
156            x = ReLU::forward(&x);
157        }
158
159        // Final layer without activation (raw logits)
160        if let Some(final_layer) = self.layers.last() {
161            x = final_layer.forward(&x);
162        }
163
164        x
165    }
166
167    /// Forward pass without gradients (for inference)
168    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
169        let _guard = NoGradTrack::new();
170        self.forward(input)
171    }
More examples
Hide additional examples
examples/neural_networks/basic_linear_layer.rs (line 79)
78    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
79        let _guard = NoGradTrack::new();
80        self.forward(input)
81    }

Trait Implementations§

Source§

impl Default for NoGradTrack

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
Source§

impl Drop for NoGradTrack

Source§

fn drop(&mut self)

Automatically restore the previous gradient tracking state

This ensures that gradient tracking is properly restored even if the guard goes out of scope due to early returns or panics.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.