NoGradTrack

Struct NoGradTrack 

Source
pub struct NoGradTrack { /* private fields */ }

RAII guard for temporarily disabling gradient tracking

NoGradTrack provides a scope-based mechanism for disabling gradient tracking, automatically restoring the previous state when the guard is dropped. It ensures proper gradient state management even in the presence of early returns or panics.

Similar to PyTorch’s torch.no_grad(), this guard disables gradient computation within its scope and automatically restores the previous gradient tracking state when it goes out of scope.

§Performance Benefits

  • Prevents computation graph construction during inference
  • Reduces memory usage by not storing intermediate values for backpropagation
  • Improves computation speed by skipping gradient-related operations (see the sketch below)
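
As a loose illustration, a whole evaluation pass can run under a single guard so that no computation graph is built for any batch. This is only a sketch; the scalar doubling stands in for a real model's forward pass:

use train_station::gradtrack::NoGradTrack;
use train_station::Tensor;

fn evaluate(batches: &[Tensor]) -> Vec<Tensor> {
    // Disable tracking once for the entire evaluation pass.
    let _guard = NoGradTrack::new();
    batches
        .iter()
        .map(|batch| batch.mul_scalar(2.0)) // stand-in for model.forward(batch)
        .collect()
} // guard drops here, previous tracking state restored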

§Examples

use train_station::gradtrack::NoGradTrack;
use train_station::Tensor;

let x = Tensor::ones(vec![3, 3]).with_requires_grad();
let y = Tensor::ones(vec![3, 3]).with_requires_grad();

// Normal computation with gradients
let z1 = x.add_tensor(&y);
assert!(z1.requires_grad());

// Computation without gradients
{
    let _guard = NoGradTrack::new();
    let z2 = x.add_tensor(&y);
    assert!(!z2.requires_grad()); // Gradients disabled
} // Guard drops here, gradients restored

// Gradients are automatically restored
let z3 = x.add_tensor(&y);
assert!(z3.requires_grad());

§Nested Contexts

use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};

assert!(is_grad_enabled());

{
    let _guard1 = NoGradTrack::new();
    assert!(!is_grad_enabled());

    {
        let _guard2 = NoGradTrack::new();
        assert!(!is_grad_enabled());
    } // guard2 drops

    assert!(!is_grad_enabled()); // Still disabled
} // guard1 drops

assert!(is_grad_enabled()); // Restored

Implementations§

Source§

impl NoGradTrack

Source

pub fn new() -> Self

Create a new NoGradTrack that disables gradient tracking

This function pushes the current gradient state onto the stack and disables gradient tracking. When the guard is dropped, the previous state is automatically restored.

§Returns

A new NoGradTrack that will restore gradient state when dropped
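
The save/restore behavior can be pictured with a small thread-local flag. This is only a rough sketch of the idea, not the crate's actual implementation (the real fields are private and the state may be tracked differently):

use std::cell::Cell;

thread_local! {
    // Hypothetical thread-local flag standing in for the crate's private gradient state.
    static GRAD_ENABLED: Cell<bool> = Cell::new(true);
}

struct NoGradTrackSketch {
    previous: bool,
}

impl NoGradTrackSketch {
    fn new() -> Self {
        // Save the current state, then disable tracking.
        let previous = GRAD_ENABLED.with(|flag| flag.replace(false));
        Self { previous }
    }
}

impl Drop for NoGradTrackSketch {
    fn drop(&mut self) {
        // Restore whatever state was active when the guard was created.
        GRAD_ENABLED.with(|flag| flag.set(self.previous));
    }
}

fn main() {
    assert!(GRAD_ENABLED.with(|flag| flag.get()));
    {
        let _guard = NoGradTrackSketch::new();
        assert!(!GRAD_ENABLED.with(|flag| flag.get()));
    }
    assert!(GRAD_ENABLED.with(|flag| flag.get())); // previous state restored
}

Because each guard saves the state it observed at construction, nesting guards naturally behaves like the stack described above.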

Examples found in repository
examples/supervised_training/../neural_networks/feedforward_network.rs (line 61)
60    pub fn forward_no_grad(input: &Tensor) -> Tensor {
61        let _guard = NoGradTrack::new();
62        Self::forward(input)
63    }
64}
65
66/// Configuration for feed-forward network
67#[derive(Debug, Clone)]
68pub struct FeedForwardConfig {
69    pub input_size: usize,
70    pub hidden_sizes: Vec<usize>,
71    pub output_size: usize,
72    #[allow(unused)]
73    pub use_bias: bool,
74}
75
76impl Default for FeedForwardConfig {
77    fn default() -> Self {
78        Self {
79            input_size: 4,
80            hidden_sizes: vec![8, 4],
81            output_size: 2,
82            use_bias: true,
83        }
84    }
85}
86
87/// A configurable feed-forward neural network
88pub struct FeedForwardNetwork {
89    layers: Vec<LinearLayer>,
90    config: FeedForwardConfig,
91}
92
93impl FeedForwardNetwork {
94    /// Create a new feed-forward network with the given configuration
95    pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
96        let mut layers = Vec::new();
97        let mut current_size = config.input_size;
98        let mut current_seed = seed;
99
100        // Create hidden layers
101        for &hidden_size in &config.hidden_sizes {
102            layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
103            current_size = hidden_size;
104            current_seed = current_seed.map(|s| s + 1);
105        }
106
107        // Create output layer
108        layers.push(LinearLayer::new(
109            current_size,
110            config.output_size,
111            current_seed,
112        ));
113
114        Self { layers, config }
115    }
116
117    /// Forward pass through the entire network
118    pub fn forward(&self, input: &Tensor) -> Tensor {
119        let mut x = input.clone();
120
121        // Pass through all layers except the last one with ReLU activation
122        for layer in &self.layers[..self.layers.len() - 1] {
123            x = layer.forward(&x);
124            x = ReLU::forward(&x);
125        }
126
127        // Final layer without activation (raw logits)
128        if let Some(final_layer) = self.layers.last() {
129            x = final_layer.forward(&x);
130        }
131
132        x
133    }
134
135    /// Forward pass without gradients (for inference)
136    #[allow(unused)]
137    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
138        let _guard = NoGradTrack::new();
139        self.forward(input)
140    }
More examples
examples/RL_training/../neural_networks/basic_linear_layer.rs (line 81)
80    pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
81        let _guard = NoGradTrack::new();
82        self.forward(input)
83    }
examples/RL_training/dqn.rs (line 310)
309fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310    let _ng = NoGradTrack::new();
311    let mut total_sq = 0.0f32;
312    for p in parameters.iter_mut() {
313        for &v in p.data() {
314            total_sq += v * v;
315        }
316    }
317    total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322    diff.pow_scalar(2.0)
323        .add_scalar(1.0)
324        .sqrt()
325        .sub_scalar(1.0)
326        .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334    println!("=== DQN Example (YardEnv discrete) ===");
335
336    // Dims
337    let state_dim = 3usize;
338    let action_dim = 3usize;
339
340    // Hparams
341    let gamma = 0.99f32;
342    let batch_size = 64usize;
343    let start_steps = 200usize;
344    let target_update_interval = 200usize; // hard update cadence
345    let max_grad_norm = 1.0f32;
346    let mut epsilon = 1.0f32;
347    let eps_min = 0.05f32;
348    let eps_decay_steps = 2_000usize; // linear decay
349    let total_steps = std::env::var("DQN_STEPS")
350        .ok()
351        .and_then(|v| v.parse::<usize>().ok())
352        .unwrap_or(3000usize);
353
354    // Models
355    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357    q_targ.net.copy_from(&q_net.net);
358    q_targ.set_requires_grad_all(false);
359
360    // Optimizer
361    let mut q_opt = Adam::with_learning_rate(3e-4);
362    for p in q_net.parameters() {
363        q_opt.add_parameter(p);
364    }
365
366    // Replay + env
367    let mut rb = ReplayBuffer::new(100_000, state_dim);
368    let mut env = YardEnv::new(12345);
369    let mut rng = SmallRng::new(999_111);
370
371    // Metrics
372    let mut state = env.reset();
373    let mut episode_return = 0.0f32;
374    let mut episode = 0usize;
375    let mut ema_return: Option<f32> = None;
376    let ema_alpha = 0.05f32;
377    let mut best_return = f32::NEG_INFINITY;
378
379    for t in 0..total_steps {
380        // Epsilon-greedy action
381        let action_index = if t < start_steps || rng.next_f32() < epsilon {
382            rng.sample_index(action_dim)
383        } else {
384            let _ng = NoGradTrack::new();
385            let q_vals = q_net.forward(&state);
386            let row = q_vals.data();
387            let mut best_i = 0usize;
388            let mut best_v = row[0];
389            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390                if r > best_v {
391                    best_v = r;
392                    best_i = i;
393                }
394            }
395            best_i
396        };
397
398        // Env step
399        let (next_state, reward, done) = env.step(action_index);
400        episode_return += reward;
401
402        // Store
403        let s_slice = state.data().to_vec();
404        let s2_slice = next_state.data().to_vec();
405        rb.push(
406            &s_slice,
407            action_index,
408            reward,
409            if done { 1.0 } else { 0.0 },
410            &s2_slice,
411        );
412
413        // Reset on done
414        state = if done {
415            let st = env.reset();
416            ema_return = Some(match ema_return {
417                None => episode_return,
418                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419            });
420            if episode_return > best_return {
421                best_return = episode_return;
422            }
423            println!(
424                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425                t,
426                episode,
427                episode_return,
428                ema_return.unwrap_or(episode_return),
429                best_return,
430                rb.size
431            );
432            episode_return = 0.0;
433            episode += 1;
434            st
435        } else {
436            next_state
437        };
438
439        // Epsilon linear decay
440        if t < eps_decay_steps {
441            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442        }
443
444        // Train
445        if rb.can_sample(batch_size) {
446            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448            // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449            let target_q = {
450                let _ng = NoGradTrack::new();
451                let q_online_s2 = q_net.forward(&s2);
452                // argmax per row (manual on CPU)
453                let row_stride = action_dim;
454                let qd = q_online_s2.data();
455                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456                for i in 0..batch_size {
457                    let base = i * row_stride;
458                    let mut bi = 0usize;
459                    let mut bv = qd[base];
460                    for j in 1..action_dim {
461                        let v = qd[base + j];
462                        if v > bv {
463                            bv = v;
464                            bi = j;
465                        }
466                    }
467                    next_actions.push(bi);
468                }
469                let q_targ_s2 = q_targ.forward(&s2);
470                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473            };
474
475            // Q(s,a) for current actions
476            // Zero grads first
477            {
478                let mut params = q_net.parameters();
479                q_opt.zero_grad(&mut params);
480            }
481
482            let q_all = q_net.forward(&s);
483            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484            let diff = q_sa.sub_tensor(&target_q);
485            let mut loss = pseudo_huber_mean(&diff);
486            loss.backward(None);
487
488            // Step (filter only params with grads)
489            {
490                let params = q_net.parameters();
491                let mut with_grads: Vec<&mut Tensor> = Vec::new();
492                for p in params {
493                    if p.grad_owned().is_some() {
494                        with_grads.push(p);
495                    }
496                }
497                if !with_grads.is_empty() {
498                    let gn = grad_global_norm(&mut with_grads);
499                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500                    q_opt.step(&mut with_grads);
501                    q_opt.zero_grad(&mut with_grads);
502                    if t % 100 == 0 {
503                        let mut pn = q_net.parameters();
504                        let pn_l2 = params_l2_norm(&mut pn);
505                        let q_mean = q_all.mean().value();
506                        println!(
507                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508                            t, loss.value(), q_mean, gn, pn_l2, epsilon
509                        );
510                    }
511                }
512            }
513
514            // Target hard update
515            if t % target_update_interval == 0 {
516                q_targ.net.copy_from(&q_net.net);
517            }
518
519            // Clear graphs
520            clear_all_graphs_known();
521        }
522    }
523
524    println!("=== DQN training finished ===");
525    Ok(())
526}
examples/RL_training/td3.rs (line 132)
131    fn soft_update_from(&mut self, source: &Self, tau: f32) {
132        let _ng = NoGradTrack::new();
133        for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134            // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135            let new_w = t
136                .weight
137                .mul_scalar(1.0 - tau)
138                .add_tensor(&s.weight.mul_scalar(tau));
139            let new_b = t
140                .bias
141                .mul_scalar(1.0 - tau)
142                .add_tensor(&s.bias.mul_scalar(tau));
143            {
144                let src = new_w.data();
145                let dst = t.weight.data_mut();
146                dst.copy_from_slice(src);
147            }
148            {
149                let src = new_b.data();
150                let dst = t.bias.data_mut();
151                dst.copy_from_slice(src);
152            }
153            t.weight.set_requires_grad(false);
154            t.bias.set_requires_grad(false);
155        }
156    }
157}
158
159// -------------------------------
160// Actor and Critic
161// -------------------------------
162
163struct Actor {
164    net: Mlp,
165}
166
167impl Actor {
168    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
169        // Smaller net for faster demo: sd -> 64 -> 64 -> ad, tanh output
170        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
171        Self { net }
172    }
173    fn forward(&self, state: &Tensor) -> Tensor {
174        self.net.forward(state, Some(tanh_bounded))
175    }
176    fn parameters(&mut self) -> Vec<&mut Tensor> {
177        self.net.parameters()
178    }
179    fn set_requires_grad_all(&mut self, enable: bool) {
180        self.net.set_requires_grad_all(enable);
181    }
182}
183
184struct Critic {
185    net: Mlp,
186}
187
188impl Critic {
189    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
190        let net = Mlp::new(&[state_dim + action_dim, 64, 64, 1], seed);
191        Self { net }
192    }
193    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194        // Concatenate along feature dim (dim=1) for batched inputs
195        // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198        let sa = Tensor::cat(&[s_view, a_view], 1);
199        self.net.forward(&sa, None)
200    }
201    fn parameters(&mut self) -> Vec<&mut Tensor> {
202        self.net.parameters()
203    }
204    fn set_requires_grad_all(&mut self, enable: bool) {
205        self.net.set_requires_grad_all(enable);
206    }
207}
208
209// -------------------------------
210// Simple continuous control environment: YardEnv
211// State: normalized features [pos/3, clamp(vel/1, -1..1), bias(=0)] ; Action: scalar in [-1, 1]
212// Dynamics: vel += 0.1*act - 0.01*pos; pos += vel
213// Reward: -(pos^2) - 0.1*act^2 ; Episode ends if |pos| > 3 or step >= max_steps
214// -------------------------------
215
216struct YardEnv {
217    pos: f32,
218    vel: f32,
219    steps: usize,
220    max_steps: usize,
221    rng: SmallRng,
222}
223
224impl YardEnv {
225    fn new(seed: u64) -> Self {
226        let mut env = Self {
227            pos: 0.0,
228            vel: 0.0,
229            steps: 0,
230            max_steps: 200,
231            rng: SmallRng::new(seed),
232        };
233        env.reset();
234        env
235    }
236
237    fn reset(&mut self) -> Tensor {
238        self.pos = self.rng.uniform(-0.5, 0.5);
239        self.vel = self.rng.uniform(-0.1, 0.1);
240        self.steps = 0;
241        self.state_tensor()
242    }
243
244    fn state_tensor(&self) -> Tensor {
245        // Normalize to keep critic inputs bounded:
246        // - Position is bounded by termination at |pos|>3 → scale by 3 to [-1,1]
247        // - Velocity scaled by 1.0 and clamped to [-1,1]
248        let pos_n = self.pos / 3.0;
249        let vel_n = self.vel.clamp(-1.0, 1.0);
250        Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
251    }
252
253    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
254        let a = action_value.clamp(-1.0, 1.0);
255        self.vel += 0.1 * a - 0.01 * self.pos;
256        self.pos += self.vel;
257        self.steps += 1;
258
259        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
260        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
261        (self.state_tensor(), reward, done)
262    }
263}
264
265// -------------------------------
266// Replay Buffer
267// -------------------------------
268
269struct ReplayBuffer {
270    capacity: usize,
271    size: usize,
272    pos: usize,
273    state_dim: usize,
274    action_dim: usize,
275    states: Vec<f32>,
276    actions: Vec<f32>,
277    rewards: Vec<f32>,
278    dones: Vec<f32>,
279    next_states: Vec<f32>,
280}
281
282impl ReplayBuffer {
283    fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
284        Self {
285            capacity,
286            size: 0,
287            pos: 0,
288            state_dim,
289            action_dim,
290            states: vec![0.0; capacity * state_dim],
291            actions: vec![0.0; capacity * action_dim],
292            rewards: vec![0.0; capacity],
293            dones: vec![0.0; capacity],
294            next_states: vec![0.0; capacity * state_dim],
295        }
296    }
297
298    fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
299        let i = self.pos;
300        let so = i * self.state_dim;
301        let ao = i * self.action_dim;
302        self.states[so..so + self.state_dim].copy_from_slice(s);
303        self.actions[ao..ao + self.action_dim].copy_from_slice(a);
304        self.rewards[i] = r;
305        self.dones[i] = d;
306        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
307
308        self.pos = (self.pos + 1) % self.capacity;
309        self.size = self.size.saturating_add(1).min(self.capacity);
310    }
311
312    fn can_sample(&self, batch_size: usize) -> bool {
313        self.size >= batch_size
314    }
315
316    fn sample(
317        &self,
318        batch_size: usize,
319        rng: &mut SmallRng,
320    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
321        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
322        let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
323        let mut r_vec = Vec::with_capacity(batch_size);
324        let mut d_vec = Vec::with_capacity(batch_size);
325        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
326
327        for _ in 0..batch_size {
328            let idx = rng.sample_index(self.size);
329            let so = idx * self.state_dim;
330            let ao = idx * self.action_dim;
331            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
332            a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
333            r_vec.push(self.rewards[idx]);
334            d_vec.push(self.dones[idx]);
335            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
336        }
337
338        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
339        let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
340        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
341        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
342        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
343        (s, a, r, d, s2)
344    }
345}
346
347// -------------------------------
348// Helper: gradient clipping by global norm
349// -------------------------------
350
351fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
352    // Compute global L2 norm of all grads
353    let mut total_sq = 0.0f32;
354    for p in parameters.iter() {
355        if let Some(g) = p.grad_owned() {
356            for &v in g.data() {
357                total_sq += v * v;
358            }
359        }
360    }
361    let norm = total_sq.sqrt();
362    if norm > max_norm {
363        let scale = max_norm / (norm + eps);
364        for p in parameters.iter_mut() {
365            if let Some(g) = p.grad_owned() {
366                let scaled = g.mul_scalar(scale);
367                p.set_grad(scaled);
368            }
369        }
370    }
371}
372
373// Compute global L2 norm of gradients across a parameter list (read-only)
374fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
375    let mut total_sq = 0.0f32;
376    for p in parameters.iter_mut() {
377        if let Some(g) = p.grad_owned() {
378            for &v in g.data() {
379                total_sq += v * v;
380            }
381        }
382    }
383    total_sq.sqrt()
384}
385
386// Compute L2 norm of parameters (weights/biases) across a parameter list
387fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
388    let _ng = NoGradTrack::new();
389    let mut total_sq = 0.0f32;
390    for p in parameters.iter_mut() {
391        for &v in p.data() {
392            total_sq += v * v;
393        }
394    }
395    total_sq.sqrt()
396}
397
398// -------------------------------
399// Main: TD3 training on YardEnv
400// -------------------------------
401
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403    println!("=== TD3 Example (YardEnv) ===");
404
405    // Environment / problem dims
406    let state_dim = 3usize;
407    let action_dim = 1usize;
408
409    // Hyperparameters (small for demo)
410    let gamma = 0.99f32;
411    let tau = 0.005f32; // Polyak
412    let policy_noise = 0.2f32; // target smoothing noise stddev
413    let exploration_noise = 0.1f32; // behavior policy noise stddev
414    let policy_delay = 2usize;
415    let batch_size = 64usize;
416    let start_steps = 500usize; // random exploration steps
417    let total_steps = 1500usize;
418    let max_grad_norm = 1.0f32;
419
420    // Models
421    let mut actor = Actor::new(state_dim, action_dim, Some(11));
422    let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423    actor_targ.net.copy_from(&actor.net);
424    actor_targ.set_requires_grad_all(false);
425
426    let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427    let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428    let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429    let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430    critic1_targ.net.copy_from(&critic1.net);
431    critic2_targ.net.copy_from(&critic2.net);
432    critic1_targ.set_requires_grad_all(false);
433    critic2_targ.set_requires_grad_all(false);
434
435    // Optimizers
436    let mut actor_opt = Adam::with_learning_rate(1e-3);
437    for p in actor.parameters() {
438        actor_opt.add_parameter(p);
439    }
440
441    let mut critic_opt = Adam::with_learning_rate(1e-4);
442    for p in critic1.parameters() {
443        critic_opt.add_parameter(p);
444    }
445    for p in critic2.parameters() {
446        critic_opt.add_parameter(p);
447    }
448
449    // Replay buffer and env
450    let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451    let mut env = YardEnv::new(1234);
452    let mut rng = SmallRng::new(987654321);
453
454    // Reset & metric trackers
455    let mut state = env.reset(); // [1, state_dim]
456    let mut episode_return = 0.0f32;
457    let mut episode = 0usize;
458    let mut ema_return: Option<f32> = None;
459    let ema_alpha = 0.05f32; // smooth short-term
460    let mut best_return = f32::NEG_INFINITY;
461    let mut policy_updates: usize = 0;
462
463    for t in 0..total_steps {
464        // Select action
465        let action_tensor = if t < start_steps {
466            let a = rng.uniform(-1.0, 1.0);
467            Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468        } else {
469            // Behavior policy with exploration noise
470            let _ng = NoGradTrack::new();
471            let det = actor.forward(&state);
472            let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473            tanh_bounded(&det.add_tensor(&noise))
474        };
475        let action_value = action_tensor.data()[0];
476
477        // Environment step
478        let (next_state, reward, done) = env.step(action_value);
479        episode_return += reward;
480
481        // Store transition
482        let s_slice = state.data().to_vec();
483        let a_slice = action_tensor.data().to_vec();
484        let s2_slice = next_state.data().to_vec();
485        rb.push(
486            &s_slice,
487            &a_slice,
488            reward,
489            if done { 1.0 } else { 0.0 },
490            &s2_slice,
491        );
492
493        state = if done {
494            let st = env.reset();
495            // Metrics: update EMA and best
496            ema_return = Some(match ema_return {
497                None => episode_return,
498                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499            });
500            if episode_return > best_return {
501                best_return = episode_return;
502            }
503            println!(
504                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505                t,
506                episode,
507                episode_return,
508                ema_return.unwrap_or(episode_return),
509                best_return,
510                rb.size,
511                policy_updates
512            );
513            episode_return = 0.0;
514            episode += 1;
515            st
516        } else {
517            next_state
518        };
519
520        // Training
521        if rb.can_sample(batch_size) {
522            // Sample batch
523            let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525            // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526            let target_q = {
527                let _ng = NoGradTrack::new();
528                // Target actions with smoothing noise (tanh bounds)
529                let noise =
530                    Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531                let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532                let q1_t = critic1_targ.forward(&s2, &a_targ);
533                let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535                // Elementwise min via data() since this path is no-grad
536                let q1d = q1_t.data();
537                let q2d = q2_t.data();
538                let mut min_vec = Vec::with_capacity(batch_size);
539                for i in 0..batch_size {
540                    let v1 = q1d[i];
541                    let v2 = q2d[i];
542                    min_vec.push(v1.min(v2));
543                }
544                let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547            };
548
549            // Critic update (both critics)
550            // Zero grads in a short scope, then drop borrows before forward
551            {
552                let mut params = {
553                    let c_params = critic1.parameters();
554                    let c2_params = critic2.parameters();
555                    let mut tmp: Vec<&mut Tensor> = Vec::new();
556                    tmp.extend(c_params);
557                    tmp.extend(c2_params);
558                    tmp
559                };
560                critic_opt.zero_grad(&mut params);
561            }
562
563            // Forward current Q estimates
564            let q1 = critic1.forward(&s, &a);
565            let q2 = critic2.forward(&s, &a);
566            let diff1 = q1.sub_tensor(&target_q);
567            let diff2 = q2.sub_tensor(&target_q);
568            let mut critic_loss = diff1
569                .pow_scalar(2.0)
570                .mean()
571                .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573            // Backward
574            critic_loss.backward(None);
575
576            // Optional gradient clipping + step (only for params that received grads)
577            {
578                let params = {
579                    let c_params = critic1.parameters();
580                    let c2_params = critic2.parameters();
581                    let mut tmp: Vec<&mut Tensor> = Vec::new();
582                    tmp.extend(c_params);
583                    tmp.extend(c2_params);
584                    tmp
585                };
586                let mut with_grads: Vec<&mut Tensor> = Vec::new();
587                for p in params {
588                    if p.grad_owned().is_some() {
589                        with_grads.push(p);
590                    }
591                }
592                if !with_grads.is_empty() {
593                    // Pre-step metrics
594                    let grad_norm_before = grad_global_norm(&mut with_grads);
595                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596                    critic_opt.step(&mut with_grads);
597                    critic_opt.zero_grad(&mut with_grads);
598
599                    // Post-step metrics (param norm)
600                    let mut for_norm_params = {
601                        let c_params = critic1.parameters();
602                        let c2_params = critic2.parameters();
603                        let mut tmp: Vec<&mut Tensor> = Vec::new();
604                        tmp.extend(c_params);
605                        tmp.extend(c2_params);
606                        tmp
607                    };
608                    let param_norm = params_l2_norm(&mut for_norm_params);
609
610                    // Print compact critic metrics occasionally
611                    if t % 100 == 0 {
612                        let q1_mean = q1.mean().value();
613                        let q2_mean = q2.mean().value();
614                        let tq_mean = target_q.mean().value();
615                        println!(
616                            "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617                            t,
618                            critic_loss.value(),
619                            q1_mean,
620                            q2_mean,
621                            tq_mean,
622                            grad_norm_before,
623                            param_norm
624                        );
625                    }
626                }
627            }
628
629            // Delayed policy update
630            if t % policy_delay == 0 {
631                // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632                // Zero actor grads before backward
633                {
634                    let mut a_params: Vec<&mut Tensor> = actor.parameters();
635                    actor_opt.zero_grad(&mut a_params);
636                }
637
638                let a_pred = actor.forward(&s);
639                let q_for_actor = critic1.forward(&s, &a_pred);
640                let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641                actor_loss.backward(None);
642
643                {
644                    let a_params: Vec<&mut Tensor> = actor.parameters();
645                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
646                    for p in a_params {
647                        if p.grad_owned().is_some() {
648                            with_grads.push(p);
649                        }
650                    }
651                    if !with_grads.is_empty() {
652                        let grad_norm_before = grad_global_norm(&mut with_grads);
653                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654                        actor_opt.step(&mut with_grads);
655                        actor_opt.zero_grad(&mut with_grads);
656
657                        // Post-step param norm
658                        let mut for_norm_params = actor.parameters();
659                        let param_norm = params_l2_norm(&mut for_norm_params);
660
661                        policy_updates += 1;
662                        if t % 200 == 0 {
663                            println!(
664                                "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665                                t,
666                                actor_loss.value(),
667                                grad_norm_before,
668                                param_norm,
669                                actor_opt.learning_rate(),
670                                critic_opt.learning_rate(),
671                                policy_updates
672                            );
673                        }
674                    }
675                }
676
677                // Target updates (Polyak averaging, no grad)
678                actor_targ.net.soft_update_from(&actor.net, tau);
679                critic1_targ.net.soft_update_from(&critic1.net, tau);
680                critic2_targ.net.soft_update_from(&critic2.net, tau);
681            }
682
683            // Clear entire graphs to avoid stale accumulation across iterations
684            clear_all_graphs_known();
685        }
686    }
687
688    println!("=== TD3 training finished ===");
689    Ok(())
690}
examples/RL_training/ppo_discrete.rs (line 378)
319pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320    println!("=== PPO Discrete Example (YardEnv) ===");
321
322    let state_dim = 3usize;
323    let action_dim = 3usize;
324    let total_steps = std::env::var("PPOD_STEPS")
325        .ok()
326        .and_then(|v| v.parse::<usize>().ok())
327        .unwrap_or(3500usize);
328    let horizon = 128usize;
329    let epochs = 4usize;
330    let mini_batch_size = 64usize;
331    let gamma = 0.99f32;
332    let lam = 0.95f32;
333    let clip_eps = 0.2f32;
334    let vf_coef = 0.5f32;
335    let ent_coef = 0.0f32;
336    let max_grad_norm = 1.0f32;
337
338    let mut actor = Actor::new(state_dim, action_dim, Some(111));
339    let mut critic = Critic::new(state_dim, Some(222));
340    let mut actor_opt = Adam::with_learning_rate(3e-4);
341    for p in actor.parameters() {
342        actor_opt.add_parameter(p);
343    }
344    let mut critic_opt = Adam::with_learning_rate(3e-4);
345    for p in critic.parameters() {
346        critic_opt.add_parameter(p);
347    }
348
349    let mut env = YardEnv::new(1234);
350    let mut rng = SmallRng::new(98765);
351    let mut state = env.reset();
352    let mut episode_return = 0.0f32;
353    let mut episode = 0usize;
354    let mut ema_return: Option<f32> = None;
355    let ema_alpha = 0.05f32;
356    let mut best_return = f32::NEG_INFINITY;
357
358    let mut t = 0usize;
359    while t < total_steps {
360        let mut batch = RolloutBatch::new(horizon, state_dim);
361        for _ in 0..horizon {
362            // Actor logits and categorical sampling
363            let logits = actor.forward(&state); // [1, A]
364            let probs = logits.softmax(1); // [1, A]
365                                           // sample action from probs (CPU sampling)
366            let p = probs.data();
367            let (p0, p1, _p2) = (p[0], p[1], p[2]);
368            let u = rng.next_f32();
369            let a_idx = if u < p0 {
370                0
371            } else if u < p0 + p1 {
372                1
373            } else {
374                2
375            };
376
377            let old_logp = {
378                let _ng = NoGradTrack::new();
379                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380                lp.data()[0]
381            };
382
383            // Step env
384            let (next_state, reward, done) = env.step(a_idx);
385            episode_return += reward;
386
387            // Critic value
388            let value_t = critic.forward(&state);
389            let value_v = value_t.data()[0];
390
391            batch.push(
392                state.data(),
393                a_idx,
394                old_logp,
395                reward,
396                if done { 1.0 } else { 0.0 },
397                value_v,
398                next_state.data(),
399            );
400
401            state = if done {
402                let st = env.reset();
403                ema_return = Some(match ema_return {
404                    None => episode_return,
405                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406                });
407                if episode_return > best_return {
408                    best_return = episode_return;
409                }
410                println!(
411                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412                    t,
413                    episode,
414                    episode_return,
415                    ema_return.unwrap_or(episode_return),
416                    best_return
417                );
418                episode_return = 0.0;
419                episode += 1;
420                st
421            } else {
422                next_state
423            };
424
425            t += 1;
426            if t >= total_steps {
427                break;
428            }
429        }
430
431        // Bootstrap values for GAE
432        let next_values: Vec<f32> = {
433            let mut out = Vec::with_capacity(batch.len());
434            for i in 0..batch.len() {
435                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437                out.push(critic.forward(&s2_t).data()[0]);
438            }
439            out
440        };
441
442        let mut returns = vec![0.0f32; batch.len()];
443        let mut adv = vec![0.0f32; batch.len()];
444        compute_gae(
445            &mut returns,
446            &mut adv,
447            &batch.rewards,
448            &batch.dones,
449            &batch.values,
450            &next_values,
451            gamma,
452            lam,
453        );
454        normalize_in_place(&mut adv, 1e-8);
455
456        // Tensors for training
457        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458        let actions_vec = batch.actions.clone();
459        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463        // PPO epochs
464        let num_minibatches = batch.len().div_ceil(mini_batch_size);
465        for e in 0..epochs {
466            for mb in 0..num_minibatches {
467                let start = mb * mini_batch_size;
468                let end = (start + mini_batch_size).min(batch.len());
469                if start >= end {
470                    break;
471                }
472
473                // Views
474                let s_mb = states_t
475                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
476                    .reshape(vec![(end - start) as i32, state_dim as i32]);
477                let oldlp_mb = old_logp_t
478                    .slice_view(start, 1, end - start)
479                    .reshape(vec![(end - start) as i32, 1]);
480                let ret_mb = returns_t
481                    .slice_view(start, 1, end - start)
482                    .reshape(vec![(end - start) as i32, 1]);
483                let adv_mb = adv_t
484                    .slice_view(start, 1, end - start)
485                    .reshape(vec![(end - start) as i32, 1]);
486                let a_slice = &actions_vec[start..end];
487
488                // Zero grads
489                {
490                    let mut ps = actor.parameters();
491                    actor_opt.zero_grad(&mut ps);
492                }
493                {
494                    let mut ps = critic.parameters();
495                    critic_opt.zero_grad(&mut ps);
496                }
497
498                // Forward
499                let logits_mb = actor.forward(&s_mb); // [B,A]
500                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503                let pg1 = ratio.mul_tensor(&adv_mb);
504                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507                let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509                let v_pred = critic.forward(&s_mb);
510                let v_loss = v_pred
511                    .sub_tensor(&ret_mb)
512                    .pow_scalar(2.0)
513                    .mean()
514                    .mul_scalar(vf_coef);
515
516                // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517                let probs_mb = logits_mb.softmax(1);
518                let logp_all = probs_mb.add_scalar(1e-8).log();
519                let ent = probs_mb
520                    .mul_tensor(&logp_all)
521                    .sum_dims(&[1], true)
522                    .mul_scalar(-1.0)
523                    .mean()
524                    .mul_scalar(ent_coef);
525
526                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527                loss.backward(None);
528
529                // Step actor
530                {
531                    let params = actor.parameters();
532                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
533                    for p in params {
534                        if p.grad_owned().is_some() {
535                            with_grads.push(p);
536                        }
537                    }
538                    if !with_grads.is_empty() {
539                        let _ = grad_global_norm(&mut with_grads);
540                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541                        actor_opt.step(&mut with_grads);
542                        actor_opt.zero_grad(&mut with_grads);
543                    }
544                }
545
546                // Step critic
547                {
548                    let params = critic.parameters();
549                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
550                    for p in params {
551                        if p.grad_owned().is_some() {
552                            with_grads.push(p);
553                        }
554                    }
555                    if !with_grads.is_empty() {
556                        let _ = grad_global_norm(&mut with_grads);
557                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558                        critic_opt.step(&mut with_grads);
559                        critic_opt.zero_grad(&mut with_grads);
560                    }
561                }
562
563                if e == 0 && mb == 0 {
564                    println!(
565                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
566                        t,
567                        actor_loss.value(),
568                        v_loss.value()
569                    );
570                }
571
572                clear_all_graphs_known();
573            }
574        }
575    }
576
577    println!("=== PPO discrete training finished ===");
578    Ok(())
579}

Trait Implementations§

Source§

impl Default for NoGradTrack

Source§

fn default() -> Self

Returns the “default value” for a type. Read more
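
For this guard, the default value presumably behaves like NoGradTrack::new(), so either form can open a no-grad scope (an assumption, not confirmed by the source shown here):

use train_station::gradtrack::NoGradTrack;

let _guard = NoGradTrack::default(); // assumed equivalent to NoGradTrack::new()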
Source§

impl Drop for NoGradTrack

Source§

fn drop(&mut self)

Automatically restore the previous gradient tracking state

This ensures that gradient tracking is properly restored even if the guard goes out of scope due to early returns or panics.
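
For example, an early return inside a helper does not leak the disabled state to the caller. A usage sketch, where the helper function and its emptiness check are hypothetical:

use train_station::gradtrack::{is_grad_enabled, NoGradTrack};
use train_station::Tensor;

fn quick_inference(input: &Tensor) -> Option<Tensor> {
    let _guard = NoGradTrack::new();
    if input.shape().dims().is_empty() {
        return None; // early return: the guard is still dropped here
    }
    Some(input.add_scalar(1.0)) // stand-in for a real forward pass
}

let x = Tensor::ones(vec![2, 2]);
let _ = quick_inference(&x);
assert!(is_grad_enabled()); // previous state restored after the call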

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self. Read more
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value. Read more
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value. Read more
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.