NoGradTrack

Struct NoGradTrack 

Source
pub struct NoGradTrack { /* private fields */ }
Expand description

RAII guard for temporarily disabling gradient tracking

NoGradTrack provides a scope-based mechanism for disabling gradient tracking, automatically restoring the previous state when the guard is dropped. It ensures proper gradient state management even in the presence of early returns or panics.

Similar to PyTorch’s torch.no_grad(), this guard disables gradient computation within its scope and automatically restores the previous gradient tracking state when it goes out of scope.

§Performance Benefits

  • Prevents computation graph construction during inference
  • Reduces memory usage by not storing intermediate values for backpropagation
  • Improves computation speed by skipping gradient-related operations

§Examples

use train_station::gradtrack::NoGradTrack;
use train_station::Tensor;

let x = Tensor::ones(vec![3, 3]).with_requires_grad();
let y = Tensor::ones(vec![3, 3]).with_requires_grad();

// Normal computation with gradients
let z1 = x.add_tensor(&y);
assert!(z1.requires_grad());

// Computation without gradients
{
    let _guard = NoGradTrack::new();
    let z2 = x.add_tensor(&y);
    assert!(!z2.requires_grad()); // Gradients disabled
} // Guard drops here, gradients restored

// Gradients are automatically restored
let z3 = x.add_tensor(&y);
assert!(z3.requires_grad());

§Nested Contexts

use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};

assert!(is_grad_enabled());

{
    let _guard1 = NoGradTrack::new();
    assert!(!is_grad_enabled());

    {
        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());
    } // guard2 drops

    assert!(!is_grad_enabled()); // Still disabled
} // guard1 drops

assert!(is_grad_enabled()); // Restored

Implementations§

Source§

impl NoGradTrack

Source

pub fn new() -> Self

Create a new NoGradTrack that disables gradient tracking

This function pushes the current gradient state onto the stack and disables gradient tracking. When the guard is dropped, the previous state is automatically restored.

§Returns

A new NoGradTrack that will restore gradient state when dropped

Examples found in repository?
examples/neural_networks/feedforward_network.rs (line 54)
53    pub fn forward_no_grad(input: &Tensor) -> Tensor {
54        let _guard = NoGradTrack::new();
55        Self::forward(input)
56    }
57}
58
59/// A basic linear layer implementation (reused from basic_linear_layer.rs)
60#[derive(Debug)]
61pub struct LinearLayer {
62    pub weight: Tensor,
63    pub bias: Tensor,
64    pub input_size: usize,
65    pub output_size: usize,
66}
67
68impl LinearLayer {
69    pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
70        let scale = (1.0 / input_size as f32).sqrt();
71
72        let weight = Tensor::randn(vec![input_size, output_size], seed)
73            .mul_scalar(scale)
74            .with_requires_grad();
75        let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
76
77        Self {
78            weight,
79            bias,
80            input_size,
81            output_size,
82        }
83    }
84
85    pub fn forward(&self, input: &Tensor) -> Tensor {
86        let output = input.matmul(&self.weight);
87        output.add_tensor(&self.bias)
88    }
89
90    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
91        let _guard = NoGradTrack::new();
92        self.forward(input)
93    }
94
95    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
96        vec![&mut self.weight, &mut self.bias]
97    }
98}
99
100/// Configuration for feed-forward network
101#[derive(Debug, Clone)]
102pub struct FeedForwardConfig {
103    pub input_size: usize,
104    pub hidden_sizes: Vec<usize>,
105    pub output_size: usize,
106    pub use_bias: bool,
107}
108
109impl Default for FeedForwardConfig {
110    fn default() -> Self {
111        Self {
112            input_size: 4,
113            hidden_sizes: vec![8, 4],
114            output_size: 2,
115            use_bias: true,
116        }
117    }
118}
119
120/// A configurable feed-forward neural network
121pub struct FeedForwardNetwork {
122    layers: Vec<LinearLayer>,
123    config: FeedForwardConfig,
124}
125
126impl FeedForwardNetwork {
127    /// Create a new feed-forward network with the given configuration
128    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
129        let mut layers = Vec::new();
130        let mut current_size = config.input_size;
131        let mut current_seed = seed;
132
133        // Create hidden layers
134        for &hidden_size in &config.hidden_sizes {
135            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
136            current_size = hidden_size;
137            current_seed = current_seed.map(|s| s + 1);
138        }
139
140        // Create output layer
141        layers.push(LinearLayer::new(
142            current_size,
143            config.output_size,
144            current_seed,
145        ));
146
147        Self { layers, config }
148    }
149
150    /// Forward pass through the entire network
151    pub fn forward(&self, input: &Tensor) -> Tensor {
152        let mut x = input.clone();
153
154        // Pass through all layers except the last one with ReLU activation
155        for layer in &self.layers[..self.layers.len() - 1] {
156            x = layer.forward(&x);
157            x = ReLU::forward(&x);
158        }
159
160        // Final layer without activation (raw logits)
161        if let Some(final_layer) = self.layers.last() {
162            x = final_layer.forward(&x);
163        }
164
165        x
166    }
167
168    /// Forward pass without gradients (for inference)
169    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
170        let _guard = NoGradTrack::new();
171        self.forward(input)
172    }
More examples
Hide additional examples
examples/neural_networks/basic_linear_layer.rs (line 80)
79    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
80        let _guard = NoGradTrack::new();
81        self.forward(input)
82    }

Trait Implementations§

Source§

impl Default for NoGradTrack

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
Source§

impl Drop for NoGradTrack

Source§

fn drop(&mut self)

Automatically restore the previous gradient tracking state

This ensures that gradient tracking is properly restored even if the guard goes out of scope due to early returns or panics.

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.