pub struct NoGradTrack { /* private fields */ }
RAII guard for temporarily disabling gradient tracking
NoGradTrack provides a scope-based mechanism for disabling gradient tracking, automatically restoring the previous state when the guard is dropped. It ensures proper gradient state management even in the presence of early returns or panics.
Similar to PyTorch’s torch.no_grad(), this guard disables gradient computation
within its scope and automatically restores the previous gradient tracking state
when it goes out of scope.
§Performance Benefits
- Prevents computation graph construction during inference
- Reduces memory usage by not storing intermediate values for backpropagation
- Improves computation speed by skipping gradient-related operations
§Examples
use train_station::gradtrack::{NoGradTrack};
use train_station::Tensor;
let x = Tensor::ones(vec![3, 3]).with_requires_grad();
let y = Tensor::ones(vec![3, 3]).with_requires_grad();
// Normal computation with gradients
let z1 = x.add_tensor(&y);
assert!(z1.requires_grad());
// Computation without gradients
{
let _guard = NoGradTrack::new();
let z2 = x.add_tensor(&y);
assert!(!z2.requires_grad()); // Gradients disabled
} // Guard drops here, gradients restored
// Gradients are automatically restored
let z3 = x.add_tensor(&y);
assert!(z3.requires_grad());

§Nested Contexts
use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};
assert!(is_grad_enabled());
{
let _guard1 = NoGradTrack::new();
assert!(!is_grad_enabled());
{
let _guard2 = NoGradTrack::new();
assert!(!is_grad_enabled());
} // guard2 drops
assert!(!is_grad_enabled()); // Still disabled
} // guard1 drops
assert!(is_grad_enabled()); // Restored

§Implementations

impl NoGradTrack

pub fn new() -> Self
Create a new NoGradTrack that disables gradient tracking
This function pushes the current gradient state onto the stack and disables gradient tracking. When the guard is dropped, the previous state is automatically restored.
§Returns
A new NoGradTrack that will restore gradient state when dropped
Examples found in repository:
53 pub fn forward_no_grad(input: &Tensor) -> Tensor {
54 let _guard = NoGradTrack::new();
55 Self::forward(input)
56 }
57}
58
/// A basic linear layer implementation (reused from basic_linear_layer.rs)
///
/// Holds a trainable weight matrix and bias vector together with the
/// layer's input/output dimensions.
#[derive(Debug)]
pub struct LinearLayer {
    /// Weight matrix of shape (input_size, output_size); created by `new`
    /// with gradient tracking enabled.
    pub weight: Tensor,
    /// Bias vector of length `output_size`, added after the matmul in `forward`.
    pub bias: Tensor,
    /// Number of input features.
    pub input_size: usize,
    /// Number of output features.
    pub output_size: usize,
}
67
68impl LinearLayer {
69 pub fn new(input_size: usize, output_size: usize, seed: Option<u64>) -> Self {
70 let scale = (1.0 / input_size as f32).sqrt();
71
72 let weight = Tensor::randn(vec![input_size, output_size], seed)
73 .mul_scalar(scale)
74 .with_requires_grad();
75 let bias = Tensor::zeros(vec![output_size]).with_requires_grad();
76
77 Self {
78 weight,
79 bias,
80 input_size,
81 output_size,
82 }
83 }
84
85 pub fn forward(&self, input: &Tensor) -> Tensor {
86 let output = input.matmul(&self.weight);
87 output.add_tensor(&self.bias)
88 }
89
90 pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
91 let _guard = NoGradTrack::new();
92 self.forward(input)
93 }
94
95 pub fn parameters(&mut self) -> Vec<&mut Tensor> {
96 vec![&mut self.weight, &mut self.bias]
97 }
98}
99
/// Configuration for feed-forward network
#[derive(Debug, Clone)]
pub struct FeedForwardConfig {
    /// Number of input features.
    pub input_size: usize,
    /// Width of each hidden layer, in order.
    pub hidden_sizes: Vec<usize>,
    /// Number of output features (raw logits).
    pub output_size: usize,
    /// Whether layers include a bias term.
    /// NOTE(review): not read by the construction code visible here
    /// (`LinearLayer::new` always creates a bias) — confirm it is honored elsewhere.
    pub use_bias: bool,
}
108
109impl Default for FeedForwardConfig {
110 fn default() -> Self {
111 Self {
112 input_size: 4,
113 hidden_sizes: vec![8, 4],
114 output_size: 2,
115 use_bias: true,
116 }
117 }
118}
119
/// A configurable feed-forward neural network
pub struct FeedForwardNetwork {
    // Hidden layers followed by one output layer (populated by `new`).
    layers: Vec<LinearLayer>,
    // Configuration the network was built from; read during construction.
    config: FeedForwardConfig,
}
125
126impl FeedForwardNetwork {
127 /// Create a new feed-forward network with the given configuration
128 pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
129 let mut layers = Vec::new();
130 let mut current_size = config.input_size;
131 let mut current_seed = seed;
132
133 // Create hidden layers
134 for &hidden_size in &config.hidden_sizes {
135 layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
136 current_size = hidden_size;
137 current_seed = current_seed.map(|s| s + 1);
138 }
139
140 // Create output layer
141 layers.push(LinearLayer::new(
142 current_size,
143 config.output_size,
144 current_seed,
145 ));
146
147 Self { layers, config }
148 }
149
150 /// Forward pass through the entire network
151 pub fn forward(&self, input: &Tensor) -> Tensor {
152 let mut x = input.clone();
153
154 // Pass through all layers except the last one with ReLU activation
155 for layer in &self.layers[..self.layers.len() - 1] {
156 x = layer.forward(&x);
157 x = ReLU::forward(&x);
158 }
159
160 // Final layer without activation (raw logits)
161 if let Some(final_layer) = self.layers.last() {
162 x = final_layer.forward(&x);
163 }
164
165 x
166 }
167
168 /// Forward pass without gradients (for inference)
169 pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
170 let _guard = NoGradTrack::new();
171 self.forward(input)
172 }