pub struct NoGradTrack { /* private fields */ }
RAII guard for temporarily disabling gradient tracking
NoGradTrack provides a scope-based mechanism for disabling gradient tracking, automatically restoring the previous state when the guard is dropped. It ensures proper gradient state management even in the presence of early returns or panics.

A RAII guard that temporarily disables gradient tracking.
Similar to PyTorch’s torch.no_grad(), this guard disables gradient computation
within its scope and automatically restores the previous gradient tracking state
when it goes out of scope.
§Performance Benefits
- Prevents computation graph construction during inference
- Reduces memory usage by not storing intermediate values for backpropagation
- Improves computation speed by skipping gradient-related operations
§Examples
use train_station::{NoGradTrack, Tensor};
let x = Tensor::ones(vec![3, 3]).with_requires_grad();
let y = Tensor::ones(vec![3, 3]).with_requires_grad();
// Normal computation with gradients
let z1 = x.add_tensor(&y);
assert!(z1.requires_grad());
// Computation without gradients
{
let _guard = NoGradTrack::new();
let z2 = x.add_tensor(&y);
assert!(!z2.requires_grad()); // Gradients disabled
} // Guard drops here, gradients restored
// Gradients are automatically restored
let z3 = x.add_tensor(&y);
assert!(z3.requires_grad());

§Nested Contexts
use train_station::{NoGradTrack, is_grad_enabled, Tensor};
assert!(is_grad_enabled());
{
let _guard1 = NoGradTrack::new();
assert!(!is_grad_enabled());
{
let _guard2 = NoGradTrack::new();
assert!(!is_grad_enabled());
} // guard2 drops
assert!(!is_grad_enabled()); // Still disabled
} // guard1 drops
assert!(is_grad_enabled()); // Restored

§Implementations
impl NoGradTrack

pub fn new() -> Self
Create a new NoGradTrack that disables gradient tracking
This function pushes the current gradient state onto the stack and disables gradient tracking. When the guard is dropped, the previous state is automatically restored.
§Returns
A new NoGradTrack that will restore gradient state when dropped
Examples found in repository:
52 pub fn forward_no_grad(input: &Tensor) -> Tensor {
53 let _guard = NoGradTrack::new();
54 Self::forward(input)
55 }
56}
57
58/// A basic linear layer implementation (reused from basic_linear_layer.rs)
59#[derive(Debug)]
60pub struct LinearLayer {
61 pub weight: Tensor,
62 pub bias: Tensor,
63 pub input_size: usize,
64 pub output_size: usize,
65}
66
67impl LinearLayer {
68 pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
69 let scale = (1.0 / input_size as f32).sqrt();
70
71 let weight = Tensor::randn(vec![input_size, output_size], seed)
72 .mul_scalar(scale)
73 .with_requires_grad();
74 let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
75
76 Self {
77 weight,
78 bias,
79 input_size,
80 output_size,
81 }
82 }
83
84 pub fn forward(&self, input: &Tensor) -> Tensor {
85 let output = input.matmul(&self.weight);
86 output.add_tensor(&self.bias)
87 }
88
89 pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
90 let _guard = NoGradTrack::new();
91 self.forward(input)
92 }
93
94 pub fn parameters(&mut self) -> Vec<&mut Tensor> {
95 vec![&mut self.weight, &mut self.bias]
96 }
97}
98
/// Configuration for a feed-forward network.
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    /// Width of the input features.
    pub input_size: usize,
    /// Width of each hidden layer, in forward order.
    pub hidden_sizes: Vec<usize>,
    /// Width of the final output (raw logits).
    pub output_size: usize,
    /// Whether layers should carry a bias term.
    /// NOTE(review): not consulted by the visible construction code —
    /// bias tensors are always created; confirm intended behavior.
    pub use_bias: bool,
}
107
108impl Default for FeedForwardConfig {
109 fn default() -> Self {
110 Self {
111 input_size: 4,
112 hidden_sizes: vec![8, 4],
113 output_size: 2,
114 use_bias: true,
115 }
116 }
117}
118
119/// A configurable feed-forward neural network
120pub struct FeedForwardNetwork {
121 layers: Vec<LinearLayer>,
122 config: FeedForwardConfig,
123}
124
125impl FeedForwardNetwork {
126 /// Create a new feed-forward network with the given configuration
127 pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
128 let mut layers = Vec::new();
129 let mut current_size = config.input_size;
130 let mut current_seed = seed;
131
132 // Create hidden layers
133 for &hidden_size in &config.hidden_sizes {
134 layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
135 current_size = hidden_size;
136 current_seed = current_seed.map(|s| s + 1);
137 }
138
139 // Create output layer
140 layers.push(LinearLayer::new(
141 current_size,
142 config.output_size,
143 current_seed,
144 ));
145
146 Self { layers, config }
147 }
148
149 /// Forward pass through the entire network
150 pub fn forward(&self, input: &Tensor) -> Tensor {
151 let mut x = input.clone();
152
153 // Pass through all layers except the last one with ReLU activation
154 for layer in &self.layers[..self.layers.len() - 1] {
155 x = layer.forward(&x);
156 x = ReLU::forward(&x);
157 }
158
159 // Final layer without activation (raw logits)
160 if let Some(final_layer) = self.layers.last() {
161 x = final_layer.forward(&x);
162 }
163
164 x
165 }
166
167 /// Forward pass without gradients (for inference)
168 pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
169 let _guard = NoGradTrack::new();
170 self.forward(input)
171 }