pub struct NoGradTrack { /* private fields */ }
RAII guard for temporarily disabling gradient tracking
NoGradTrack provides a scope-based mechanism for disabling gradient tracking, automatically restoring the previous state when the guard is dropped. It ensures correct gradient state management even in the presence of early returns or panics.
Similar to PyTorch’s torch.no_grad(), this guard disables gradient computation
within its scope and automatically restores the previous gradient tracking state
when it goes out of scope.
§Performance Benefits
- Prevents computation graph construction during inference
- Reduces memory usage by not storing intermediate values for backpropagation
- Improves computation speed by skipping gradient-related operations
§Examples
use train_station::gradtrack::NoGradTrack;
use train_station::Tensor;
let x = Tensor::ones(vec![3, 3]).with_requires_grad();
let y = Tensor::ones(vec![3, 3]).with_requires_grad();
// Normal computation with gradients
let z1 = x.add_tensor(&y);
assert!(z1.requires_grad());
// Computation without gradients
{
let _guard = NoGradTrack::new();
let z2 = x.add_tensor(&y);
assert!(!z2.requires_grad()); // Gradients disabled
} // Guard drops here, gradients restored
// Gradients are automatically restored
let z3 = x.add_tensor(&y);
assert!(z3.requires_grad());
§Nested Contexts
use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};
assert!(is_grad_enabled());
{
let _guard1 = NoGradTrack::new();
assert!(!is_grad_enabled());
{
let _guard2 = NoGradTrack::new();
assert!(!is_grad_enabled());
} // guard2 drops
assert!(!is_grad_enabled()); // Still disabled
} // guard1 drops
assert!(is_grad_enabled()); // Restored
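§Early Returns
The previous state is restored in the guard's Drop implementation, so tracking is re-enabled on every exit path, including early returns. A minimal sketch (the no_grad_step helper and its skip flag are illustrative, not part of the crate):
use train_station::{gradtrack::NoGradTrack, gradtrack::is_grad_enabled, Tensor};

fn no_grad_step(x: &Tensor, skip: bool) -> Option<Tensor> {
    let _guard = NoGradTrack::new();
    if skip {
        return None; // early return: the guard is still dropped and the previous state restored
    }
    Some(x.add_tensor(x))
}

let x = Tensor::ones(vec![2, 2]).with_requires_grad();
assert!(is_grad_enabled());
let y = no_grad_step(&x, false).unwrap();
assert!(!y.requires_grad()); // computed while tracking was disabled
assert!(is_grad_enabled()); // restored after the normal return
let _ = no_grad_step(&x, true);
assert!(is_grad_enabled()); // restored after the early return as well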
§Implementations
impl NoGradTrack
pub fn new() -> Self
Create a new NoGradTrack that disables gradient tracking
This function pushes the current gradient state onto the stack and disables gradient tracking. When the guard is dropped, the previous state is automatically restored.
§Returns
A new NoGradTrack that will restore gradient state when dropped
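§Examples
Tracking can also be re-enabled before the end of the enclosing scope by dropping the guard explicitly; a minimal sketch using only the guard and is_grad_enabled:
use train_station::gradtrack::{is_grad_enabled, NoGradTrack};

assert!(is_grad_enabled());
let guard = NoGradTrack::new();
assert!(!is_grad_enabled()); // tracking disabled while the guard is alive
drop(guard); // dropping the guard restores the previous state immediately
assert!(is_grad_enabled());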
Examples found in repository
60 pub fn forward_no_grad(input: &Tensor) -> Tensor {
61 let _guard = NoGradTrack::new();
62 Self::forward(input)
63 }
64}
65
66/// Configuration for feed-forward network
67#[derive(Debug, Clone)]
68pub struct FeedForwardConfig {
69 pub input_size: usize,
70 pub hidden_sizes: Vec<usize>,
71 pub output_size: usize,
72 #[allow(unused)]
73 pub use_bias: bool,
74}
75
76impl Default for FeedForwardConfig {
77 fn default() -> Self {
78 Self {
79 input_size: 4,
80 hidden_sizes: vec![8, 4],
81 output_size: 2,
82 use_bias: true,
83 }
84 }
85}
86
87/// A configurable feed-forward neural network
88pub struct FeedForwardNetwork {
89 layers: Vec<LinearLayer>,
90 config: FeedForwardConfig,
91}
92
93impl FeedForwardNetwork {
94 /// Create a new feed-forward network with the given configuration
95 pub fn new(config: FeedForwardConfig, seed: Option<u64>) -> Self {
96 let mut layers = Vec::new();
97 let mut current_size = config.input_size;
98 let mut current_seed = seed;
99
100 // Create hidden layers
101 for &hidden_size in &config.hidden_sizes {
102 layers.push(LinearLayer::new(current_size, hidden_size, current_seed));
103 current_size = hidden_size;
104 current_seed = current_seed.map(|s| s + 1);
105 }
106
107 // Create output layer
108 layers.push(LinearLayer::new(
109 current_size,
110 config.output_size,
111 current_seed,
112 ));
113
114 Self { layers, config }
115 }
116
117 /// Forward pass through the entire network
118 pub fn forward(&self, input: &Tensor) -> Tensor {
119 let mut x = input.clone();
120
121 // Pass through all layers except the last one with ReLU activation
122 for layer in &self.layers[..self.layers.len() - 1] {
123 x = layer.forward(&x);
124 x = ReLU::forward(&x);
125 }
126
127 // Final layer without activation (raw logits)
128 if let Some(final_layer) = self.layers.last() {
129 x = final_layer.forward(&x);
130 }
131
132 x
133 }
134
135 /// Forward pass without gradients (for inference)
136 #[allow(unused)]
137 pub fn forward_no_grad(&self, input: &Tensor) -> Tensor {
138 let _guard = NoGradTrack::new();
139 self.forward(input)
140 }
More examples
309 fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
310 let _ng = NoGradTrack::new();
311 let mut total_sq = 0.0f32;
312 for p in parameters.iter_mut() {
313 for &v in p.data() {
314 total_sq += v * v;
315 }
316 }
317 total_sq.sqrt()
318}
319
320// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
321fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
322 diff.pow_scalar(2.0)
323 .add_scalar(1.0)
324 .sqrt()
325 .sub_scalar(1.0)
326 .mean()
327}
328
329// -------------------------------
330// Main
331// -------------------------------
332
333pub fn main() -> Result<(), Box<dyn std::error::Error>> {
334 println!("=== DQN Example (YardEnv discrete) ===");
335
336 // Dims
337 let state_dim = 3usize;
338 let action_dim = 3usize;
339
340 // Hparams
341 let gamma = 0.99f32;
342 let batch_size = 64usize;
343 let start_steps = 200usize;
344 let target_update_interval = 200usize; // hard update cadence
345 let max_grad_norm = 1.0f32;
346 let mut epsilon = 1.0f32;
347 let eps_min = 0.05f32;
348 let eps_decay_steps = 2_000usize; // linear decay
349 let total_steps = std::env::var("DQN_STEPS")
350 .ok()
351 .and_then(|v| v.parse::<usize>().ok())
352 .unwrap_or(3000usize);
353
354 // Models
355 let mut q_net = QNet::new(state_dim, action_dim, Some(7));
356 let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
357 q_targ.net.copy_from(&q_net.net);
358 q_targ.set_requires_grad_all(false);
359
360 // Optimizer
361 let mut q_opt = Adam::with_learning_rate(3e-4);
362 for p in q_net.parameters() {
363 q_opt.add_parameter(p);
364 }
365
366 // Replay + env
367 let mut rb = ReplayBuffer::new(100_000, state_dim);
368 let mut env = YardEnv::new(12345);
369 let mut rng = SmallRng::new(999_111);
370
371 // Metrics
372 let mut state = env.reset();
373 let mut episode_return = 0.0f32;
374 let mut episode = 0usize;
375 let mut ema_return: Option<f32> = None;
376 let ema_alpha = 0.05f32;
377 let mut best_return = f32::NEG_INFINITY;
378
379 for t in 0..total_steps {
380 // Epsilon-greedy action
381 let action_index = if t < start_steps || rng.next_f32() < epsilon {
382 rng.sample_index(action_dim)
383 } else {
384 let _ng = NoGradTrack::new();
385 let q_vals = q_net.forward(&state);
386 let row = q_vals.data();
387 let mut best_i = 0usize;
388 let mut best_v = row[0];
389 for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
390 if r > best_v {
391 best_v = r;
392 best_i = i;
393 }
394 }
395 best_i
396 };
397
398 // Env step
399 let (next_state, reward, done) = env.step(action_index);
400 episode_return += reward;
401
402 // Store
403 let s_slice = state.data().to_vec();
404 let s2_slice = next_state.data().to_vec();
405 rb.push(
406 &s_slice,
407 action_index,
408 reward,
409 if done { 1.0 } else { 0.0 },
410 &s2_slice,
411 );
412
413 // Reset on done
414 state = if done {
415 let st = env.reset();
416 ema_return = Some(match ema_return {
417 None => episode_return,
418 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
419 });
420 if episode_return > best_return {
421 best_return = episode_return;
422 }
423 println!(
424 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
425 t,
426 episode,
427 episode_return,
428 ema_return.unwrap_or(episode_return),
429 best_return,
430 rb.size
431 );
432 episode_return = 0.0;
433 episode += 1;
434 st
435 } else {
436 next_state
437 };
438
439 // Epsilon linear decay
440 if t < eps_decay_steps {
441 epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
442 }
443
444 // Train
445 if rb.can_sample(batch_size) {
446 let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);
447
448 // Double DQN target: a* = argmax_a Q_online(s2,a); y = r + (1-d)*gamma*Q_target(s2, a*)
449 let target_q = {
450 let _ng = NoGradTrack::new();
451 let q_online_s2 = q_net.forward(&s2);
452 // argmax per row (manual on CPU)
453 let row_stride = action_dim;
454 let qd = q_online_s2.data();
455 let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
456 for i in 0..batch_size {
457 let base = i * row_stride;
458 let mut bi = 0usize;
459 let mut bv = qd[base];
460 for j in 1..action_dim {
461 let v = qd[base + j];
462 if v > bv {
463 bv = v;
464 bi = j;
465 }
466 }
467 next_actions.push(bi);
468 }
469 let q_targ_s2 = q_targ.forward(&s2);
470 let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
471 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
472 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
473 };
474
475 // Q(s,a) for current actions
476 // Zero grads first
477 {
478 let mut params = q_net.parameters();
479 q_opt.zero_grad(&mut params);
480 }
481
482 let q_all = q_net.forward(&s);
483 let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
484 let diff = q_sa.sub_tensor(&target_q);
485 let mut loss = pseudo_huber_mean(&diff);
486 loss.backward(None);
487
488 // Step (filter only params with grads)
489 {
490 let params = q_net.parameters();
491 let mut with_grads: Vec<&mut Tensor> = Vec::new();
492 for p in params {
493 if p.grad_owned().is_some() {
494 with_grads.push(p);
495 }
496 }
497 if !with_grads.is_empty() {
498 let gn = grad_global_norm(&mut with_grads);
499 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
500 q_opt.step(&mut with_grads);
501 q_opt.zero_grad(&mut with_grads);
502 if t % 100 == 0 {
503 let mut pn = q_net.parameters();
504 let pn_l2 = params_l2_norm(&mut pn);
505 let q_mean = q_all.mean().value();
506 println!(
507 "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
508 t, loss.value(), q_mean, gn, pn_l2, epsilon
509 );
510 }
511 }
512 }
513
514 // Target hard update
515 if t % target_update_interval == 0 {
516 q_targ.net.copy_from(&q_net.net);
517 }
518
519 // Clear graphs
520 clear_all_graphs_known();
521 }
522 }
523
524 println!("=== DQN training finished ===");
525 Ok(())
526}
131 fn soft_update_from(&mut self, source: &Self, tau: f32) {
132 let _ng = NoGradTrack::new();
133 for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
134 // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
135 let new_w = t
136 .weight
137 .mul_scalar(1.0 - tau)
138 .add_tensor(&s.weight.mul_scalar(tau));
139 let new_b = t
140 .bias
141 .mul_scalar(1.0 - tau)
142 .add_tensor(&s.bias.mul_scalar(tau));
143 {
144 let src = new_w.data();
145 let dst = t.weight.data_mut();
146 dst.copy_from_slice(src);
147 }
148 {
149 let src = new_b.data();
150 let dst = t.bias.data_mut();
151 dst.copy_from_slice(src);
152 }
153 t.weight.set_requires_grad(false);
154 t.bias.set_requires_grad(false);
155 }
156 }
157}
158
159// -------------------------------
160// Actor and Critic
161// -------------------------------
162
163struct Actor {
164 net: Mlp,
165}
166
167impl Actor {
168 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
169 // Smaller net for faster demo: sd -> 64 -> 64 -> ad, tanh output
170 let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
171 Self { net }
172 }
173 fn forward(&self, state: &Tensor) -> Tensor {
174 self.net.forward(state, Some(tanh_bounded))
175 }
176 fn parameters(&mut self) -> Vec<&mut Tensor> {
177 self.net.parameters()
178 }
179 fn set_requires_grad_all(&mut self, enable: bool) {
180 self.net.set_requires_grad_all(enable);
181 }
182}
183
184struct Critic {
185 net: Mlp,
186}
187
188impl Critic {
189 fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
190 let net = Mlp::new(&[state_dim + action_dim, 64, 64, 1], seed);
191 Self { net }
192 }
193 fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
194 // Concatenate along feature dim (dim=1) for batched inputs
195 // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
196 let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
197 let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
198 let sa = Tensor::cat(&[s_view, a_view], 1);
199 self.net.forward(&sa, None)
200 }
201 fn parameters(&mut self) -> Vec<&mut Tensor> {
202 self.net.parameters()
203 }
204 fn set_requires_grad_all(&mut self, enable: bool) {
205 self.net.set_requires_grad_all(enable);
206 }
207}
208
209// -------------------------------
210// Simple continuous control environment: YardEnv
211// State: normalized features [pos/3, clamp(vel/1, -1..1), bias(=0)] ; Action: scalar in [-1, 1]
212// Dynamics: vel += 0.1*act - 0.01*pos; pos += vel
213// Reward: -(pos^2) - 0.1*act^2 ; Episode ends if |pos| > 3 or step >= max_steps
214// -------------------------------
215
216struct YardEnv {
217 pos: f32,
218 vel: f32,
219 steps: usize,
220 max_steps: usize,
221 rng: SmallRng,
222}
223
224impl YardEnv {
225 fn new(seed: u64) -> Self {
226 let mut env = Self {
227 pos: 0.0,
228 vel: 0.0,
229 steps: 0,
230 max_steps: 200,
231 rng: SmallRng::new(seed),
232 };
233 env.reset();
234 env
235 }
236
237 fn reset(&mut self) -> Tensor {
238 self.pos = self.rng.uniform(-0.5, 0.5);
239 self.vel = self.rng.uniform(-0.1, 0.1);
240 self.steps = 0;
241 self.state_tensor()
242 }
243
244 fn state_tensor(&self) -> Tensor {
245 // Normalize to keep critic inputs bounded:
246 // - Position is bounded by termination at |pos|>3 → scale by 3 to [-1,1]
247 // - Velocity scaled by 1.0 and clamped to [-1,1]
248 let pos_n = self.pos / 3.0;
249 let vel_n = self.vel.clamp(-1.0, 1.0);
250 Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
251 }
252
253 fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
254 let a = action_value.clamp(-1.0, 1.0);
255 self.vel += 0.1 * a - 0.01 * self.pos;
256 self.pos += self.vel;
257 self.steps += 1;
258
259 let reward = -(self.pos * self.pos) - 0.1 * (a * a);
260 let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
261 (self.state_tensor(), reward, done)
262 }
263}
264
265// -------------------------------
266// Replay Buffer
267// -------------------------------
268
269struct ReplayBuffer {
270 capacity: usize,
271 size: usize,
272 pos: usize,
273 state_dim: usize,
274 action_dim: usize,
275 states: Vec<f32>,
276 actions: Vec<f32>,
277 rewards: Vec<f32>,
278 dones: Vec<f32>,
279 next_states: Vec<f32>,
280}
281
282impl ReplayBuffer {
283 fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
284 Self {
285 capacity,
286 size: 0,
287 pos: 0,
288 state_dim,
289 action_dim,
290 states: vec![0.0; capacity * state_dim],
291 actions: vec![0.0; capacity * action_dim],
292 rewards: vec![0.0; capacity],
293 dones: vec![0.0; capacity],
294 next_states: vec![0.0; capacity * state_dim],
295 }
296 }
297
298 fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
299 let i = self.pos;
300 let so = i * self.state_dim;
301 let ao = i * self.action_dim;
302 self.states[so..so + self.state_dim].copy_from_slice(s);
303 self.actions[ao..ao + self.action_dim].copy_from_slice(a);
304 self.rewards[i] = r;
305 self.dones[i] = d;
306 self.next_states[so..so + self.state_dim].copy_from_slice(s2);
307
308 self.pos = (self.pos + 1) % self.capacity;
309 self.size = self.size.saturating_add(1).min(self.capacity);
310 }
311
312 fn can_sample(&self, batch_size: usize) -> bool {
313 self.size >= batch_size
314 }
315
316 fn sample(
317 &self,
318 batch_size: usize,
319 rng: &mut SmallRng,
320 ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
321 let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
322 let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
323 let mut r_vec = Vec::with_capacity(batch_size);
324 let mut d_vec = Vec::with_capacity(batch_size);
325 let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
326
327 for _ in 0..batch_size {
328 let idx = rng.sample_index(self.size);
329 let so = idx * self.state_dim;
330 let ao = idx * self.action_dim;
331 s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
332 a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
333 r_vec.push(self.rewards[idx]);
334 d_vec.push(self.dones[idx]);
335 s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
336 }
337
338 let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
339 let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
340 let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
341 let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
342 let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
343 (s, a, r, d, s2)
344 }
345}
346
347// -------------------------------
348// Helper: gradient clipping by global norm
349// -------------------------------
350
351fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
352 // Compute global L2 norm of all grads
353 let mut total_sq = 0.0f32;
354 for p in parameters.iter() {
355 if let Some(g) = p.grad_owned() {
356 for &v in g.data() {
357 total_sq += v * v;
358 }
359 }
360 }
361 let norm = total_sq.sqrt();
362 if norm > max_norm {
363 let scale = max_norm / (norm + eps);
364 for p in parameters.iter_mut() {
365 if let Some(g) = p.grad_owned() {
366 let scaled = g.mul_scalar(scale);
367 p.set_grad(scaled);
368 }
369 }
370 }
371}
372
373// Compute global L2 norm of gradients across a parameter list (read-only)
374fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
375 let mut total_sq = 0.0f32;
376 for p in parameters.iter_mut() {
377 if let Some(g) = p.grad_owned() {
378 for &v in g.data() {
379 total_sq += v * v;
380 }
381 }
382 }
383 total_sq.sqrt()
384}
385
386// Compute L2 norm of parameters (weights/biases) across a parameter list
387fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
388 let _ng = NoGradTrack::new();
389 let mut total_sq = 0.0f32;
390 for p in parameters.iter_mut() {
391 for &v in p.data() {
392 total_sq += v * v;
393 }
394 }
395 total_sq.sqrt()
396}
397
398// -------------------------------
399// Main: TD3 training on YardEnv
400// -------------------------------
401
402pub fn main() -> Result<(), Box<dyn std::error::Error>> {
403 println!("=== TD3 Example (YardEnv) ===");
404
405 // Environment / problem dims
406 let state_dim = 3usize;
407 let action_dim = 1usize;
408
409 // Hyperparameters (small for demo)
410 let gamma = 0.99f32;
411 let tau = 0.005f32; // Polyak
412 let policy_noise = 0.2f32; // target smoothing noise stddev
413 let exploration_noise = 0.1f32; // behavior policy noise stddev
414 let policy_delay = 2usize;
415 let batch_size = 64usize;
416 let start_steps = 500usize; // random exploration steps
417 let total_steps = 1500usize;
418 let max_grad_norm = 1.0f32;
419
420 // Models
421 let mut actor = Actor::new(state_dim, action_dim, Some(11));
422 let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
423 actor_targ.net.copy_from(&actor.net);
424 actor_targ.set_requires_grad_all(false);
425
426 let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
427 let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
428 let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
429 let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
430 critic1_targ.net.copy_from(&critic1.net);
431 critic2_targ.net.copy_from(&critic2.net);
432 critic1_targ.set_requires_grad_all(false);
433 critic2_targ.set_requires_grad_all(false);
434
435 // Optimizers
436 let mut actor_opt = Adam::with_learning_rate(1e-3);
437 for p in actor.parameters() {
438 actor_opt.add_parameter(p);
439 }
440
441 let mut critic_opt = Adam::with_learning_rate(1e-4);
442 for p in critic1.parameters() {
443 critic_opt.add_parameter(p);
444 }
445 for p in critic2.parameters() {
446 critic_opt.add_parameter(p);
447 }
448
449 // Replay buffer and env
450 let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
451 let mut env = YardEnv::new(1234);
452 let mut rng = SmallRng::new(987654321);
453
454 // Reset & metric trackers
455 let mut state = env.reset(); // [1, state_dim]
456 let mut episode_return = 0.0f32;
457 let mut episode = 0usize;
458 let mut ema_return: Option<f32> = None;
459 let ema_alpha = 0.05f32; // smooth short-term
460 let mut best_return = f32::NEG_INFINITY;
461 let mut policy_updates: usize = 0;
462
463 for t in 0..total_steps {
464 // Select action
465 let action_tensor = if t < start_steps {
466 let a = rng.uniform(-1.0, 1.0);
467 Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
468 } else {
469 // Behavior policy with exploration noise
470 let _ng = NoGradTrack::new();
471 let det = actor.forward(&state);
472 let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
473 tanh_bounded(&det.add_tensor(&noise))
474 };
475 let action_value = action_tensor.data()[0];
476
477 // Environment step
478 let (next_state, reward, done) = env.step(action_value);
479 episode_return += reward;
480
481 // Store transition
482 let s_slice = state.data().to_vec();
483 let a_slice = action_tensor.data().to_vec();
484 let s2_slice = next_state.data().to_vec();
485 rb.push(
486 &s_slice,
487 &a_slice,
488 reward,
489 if done { 1.0 } else { 0.0 },
490 &s2_slice,
491 );
492
493 state = if done {
494 let st = env.reset();
495 // Metrics: update EMA and best
496 ema_return = Some(match ema_return {
497 None => episode_return,
498 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
499 });
500 if episode_return > best_return {
501 best_return = episode_return;
502 }
503 println!(
504 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
505 t,
506 episode,
507 episode_return,
508 ema_return.unwrap_or(episode_return),
509 best_return,
510 rb.size,
511 policy_updates
512 );
513 episode_return = 0.0;
514 episode += 1;
515 st
516 } else {
517 next_state
518 };
519
520 // Training
521 if rb.can_sample(batch_size) {
522 // Sample batch
523 let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);
524
525 // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
526 let target_q = {
527 let _ng = NoGradTrack::new();
528 // Target actions with smoothing noise (tanh bounds)
529 let noise =
530 Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
531 let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
532 let q1_t = critic1_targ.forward(&s2, &a_targ);
533 let q2_t = critic2_targ.forward(&s2, &a_targ);
534
535 // Elementwise min via data() since this path is no-grad
536 let q1d = q1_t.data();
537 let q2d = q2_t.data();
538 let mut min_vec = Vec::with_capacity(batch_size);
539 for i in 0..batch_size {
540 let v1 = q1d[i];
541 let v2 = q2d[i];
542 min_vec.push(v1.min(v2));
543 }
544 let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
545 let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
546 r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
547 };
548
549 // Critic update (both critics)
550 // Zero grads in a short scope, then drop borrows before forward
551 {
552 let mut params = {
553 let c_params = critic1.parameters();
554 let c2_params = critic2.parameters();
555 let mut tmp: Vec<&mut Tensor> = Vec::new();
556 tmp.extend(c_params);
557 tmp.extend(c2_params);
558 tmp
559 };
560 critic_opt.zero_grad(&mut params);
561 }
562
563 // Forward current Q estimates
564 let q1 = critic1.forward(&s, &a);
565 let q2 = critic2.forward(&s, &a);
566 let diff1 = q1.sub_tensor(&target_q);
567 let diff2 = q2.sub_tensor(&target_q);
568 let mut critic_loss = diff1
569 .pow_scalar(2.0)
570 .mean()
571 .add_tensor(&diff2.pow_scalar(2.0).mean());
572
573 // Backward
574 critic_loss.backward(None);
575
576 // Optional gradient clipping + step (only for params that received grads)
577 {
578 let params = {
579 let c_params = critic1.parameters();
580 let c2_params = critic2.parameters();
581 let mut tmp: Vec<&mut Tensor> = Vec::new();
582 tmp.extend(c_params);
583 tmp.extend(c2_params);
584 tmp
585 };
586 let mut with_grads: Vec<&mut Tensor> = Vec::new();
587 for p in params {
588 if p.grad_owned().is_some() {
589 with_grads.push(p);
590 }
591 }
592 if !with_grads.is_empty() {
593 // Pre-step metrics
594 let grad_norm_before = grad_global_norm(&mut with_grads);
595 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
596 critic_opt.step(&mut with_grads);
597 critic_opt.zero_grad(&mut with_grads);
598
599 // Post-step metrics (param norm)
600 let mut for_norm_params = {
601 let c_params = critic1.parameters();
602 let c2_params = critic2.parameters();
603 let mut tmp: Vec<&mut Tensor> = Vec::new();
604 tmp.extend(c_params);
605 tmp.extend(c2_params);
606 tmp
607 };
608 let param_norm = params_l2_norm(&mut for_norm_params);
609
610 // Print compact critic metrics occasionally
611 if t % 100 == 0 {
612 let q1_mean = q1.mean().value();
613 let q2_mean = q2.mean().value();
614 let tq_mean = target_q.mean().value();
615 println!(
616 "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
617 t,
618 critic_loss.value(),
619 q1_mean,
620 q2_mean,
621 tq_mean,
622 grad_norm_before,
623 param_norm
624 );
625 }
626 }
627 }
628
629 // Delayed policy update
630 if t % policy_delay == 0 {
631 // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
632 // Zero actor grads before backward
633 {
634 let mut a_params: Vec<&mut Tensor> = actor.parameters();
635 actor_opt.zero_grad(&mut a_params);
636 }
637
638 let a_pred = actor.forward(&s);
639 let q_for_actor = critic1.forward(&s, &a_pred);
640 let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
641 actor_loss.backward(None);
642
643 {
644 let a_params: Vec<&mut Tensor> = actor.parameters();
645 let mut with_grads: Vec<&mut Tensor> = Vec::new();
646 for p in a_params {
647 if p.grad_owned().is_some() {
648 with_grads.push(p);
649 }
650 }
651 if !with_grads.is_empty() {
652 let grad_norm_before = grad_global_norm(&mut with_grads);
653 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
654 actor_opt.step(&mut with_grads);
655 actor_opt.zero_grad(&mut with_grads);
656
657 // Post-step param norm
658 let mut for_norm_params = actor.parameters();
659 let param_norm = params_l2_norm(&mut for_norm_params);
660
661 policy_updates += 1;
662 if t % 200 == 0 {
663 println!(
664 "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
665 t,
666 actor_loss.value(),
667 grad_norm_before,
668 param_norm,
669 actor_opt.learning_rate(),
670 critic_opt.learning_rate(),
671 policy_updates
672 );
673 }
674 }
675 }
676
677 // Target updates (Polyak averaging, no grad)
678 actor_targ.net.soft_update_from(&actor.net, tau);
679 critic1_targ.net.soft_update_from(&critic1.net, tau);
680 critic2_targ.net.soft_update_from(&critic2.net, tau);
681 }
682
683 // Clear entire graphs to avoid stale accumulation across iterations
684 clear_all_graphs_known();
685 }
686 }
687
688 println!("=== TD3 training finished ===");
689 Ok(())
690}
319 pub fn main() -> Result<(), Box<dyn std::error::Error>> {
320 println!("=== PPO Discrete Example (YardEnv) ===");
321
322 let state_dim = 3usize;
323 let action_dim = 3usize;
324 let total_steps = std::env::var("PPOD_STEPS")
325 .ok()
326 .and_then(|v| v.parse::<usize>().ok())
327 .unwrap_or(3500usize);
328 let horizon = 128usize;
329 let epochs = 4usize;
330 let mini_batch_size = 64usize;
331 let gamma = 0.99f32;
332 let lam = 0.95f32;
333 let clip_eps = 0.2f32;
334 let vf_coef = 0.5f32;
335 let ent_coef = 0.0f32;
336 let max_grad_norm = 1.0f32;
337
338 let mut actor = Actor::new(state_dim, action_dim, Some(111));
339 let mut critic = Critic::new(state_dim, Some(222));
340 let mut actor_opt = Adam::with_learning_rate(3e-4);
341 for p in actor.parameters() {
342 actor_opt.add_parameter(p);
343 }
344 let mut critic_opt = Adam::with_learning_rate(3e-4);
345 for p in critic.parameters() {
346 critic_opt.add_parameter(p);
347 }
348
349 let mut env = YardEnv::new(1234);
350 let mut rng = SmallRng::new(98765);
351 let mut state = env.reset();
352 let mut episode_return = 0.0f32;
353 let mut episode = 0usize;
354 let mut ema_return: Option<f32> = None;
355 let ema_alpha = 0.05f32;
356 let mut best_return = f32::NEG_INFINITY;
357
358 let mut t = 0usize;
359 while t < total_steps {
360 let mut batch = RolloutBatch::new(horizon, state_dim);
361 for _ in 0..horizon {
362 // Actor logits and categorical sampling
363 let logits = actor.forward(&state); // [1, A]
364 let probs = logits.softmax(1); // [1, A]
365 // sample action from probs (CPU sampling)
366 let p = probs.data();
367 let (p0, p1, _p2) = (p[0], p[1], p[2]);
368 let u = rng.next_f32();
369 let a_idx = if u < p0 {
370 0
371 } else if u < p0 + p1 {
372 1
373 } else {
374 2
375 };
376
377 let old_logp = {
378 let _ng = NoGradTrack::new();
379 let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
380 lp.data()[0]
381 };
382
383 // Step env
384 let (next_state, reward, done) = env.step(a_idx);
385 episode_return += reward;
386
387 // Critic value
388 let value_t = critic.forward(&state);
389 let value_v = value_t.data()[0];
390
391 batch.push(
392 state.data(),
393 a_idx,
394 old_logp,
395 reward,
396 if done { 1.0 } else { 0.0 },
397 value_v,
398 next_state.data(),
399 );
400
401 state = if done {
402 let st = env.reset();
403 ema_return = Some(match ema_return {
404 None => episode_return,
405 Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
406 });
407 if episode_return > best_return {
408 best_return = episode_return;
409 }
410 println!(
411 "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
412 t,
413 episode,
414 episode_return,
415 ema_return.unwrap_or(episode_return),
416 best_return
417 );
418 episode_return = 0.0;
419 episode += 1;
420 st
421 } else {
422 next_state
423 };
424
425 t += 1;
426 if t >= total_steps {
427 break;
428 }
429 }
430
431 // Bootstrap values for GAE
432 let next_values: Vec<f32> = {
433 let mut out = Vec::with_capacity(batch.len());
434 for i in 0..batch.len() {
435 let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
436 let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
437 out.push(critic.forward(&s2_t).data()[0]);
438 }
439 out
440 };
441
442 let mut returns = vec![0.0f32; batch.len()];
443 let mut adv = vec![0.0f32; batch.len()];
444 compute_gae(
445 &mut returns,
446 &mut adv,
447 &batch.rewards,
448 &batch.dones,
449 &batch.values,
450 &next_values,
451 gamma,
452 lam,
453 );
454 normalize_in_place(&mut adv, 1e-8);
455
456 // Tensors for training
457 let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
458 let actions_vec = batch.actions.clone();
459 let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
460 let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
461 let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();
462
463 // PPO epochs
464 let num_minibatches = batch.len().div_ceil(mini_batch_size);
465 for e in 0..epochs {
466 for mb in 0..num_minibatches {
467 let start = mb * mini_batch_size;
468 let end = (start + mini_batch_size).min(batch.len());
469 if start >= end {
470 break;
471 }
472
473 // Views
474 let s_mb = states_t
475 .slice_view(start * state_dim, 1, (end - start) * state_dim)
476 .reshape(vec![(end - start) as i32, state_dim as i32]);
477 let oldlp_mb = old_logp_t
478 .slice_view(start, 1, end - start)
479 .reshape(vec![(end - start) as i32, 1]);
480 let ret_mb = returns_t
481 .slice_view(start, 1, end - start)
482 .reshape(vec![(end - start) as i32, 1]);
483 let adv_mb = adv_t
484 .slice_view(start, 1, end - start)
485 .reshape(vec![(end - start) as i32, 1]);
486 let a_slice = &actions_vec[start..end];
487
488 // Zero grads
489 {
490 let mut ps = actor.parameters();
491 actor_opt.zero_grad(&mut ps);
492 }
493 {
494 let mut ps = critic.parameters();
495 critic_opt.zero_grad(&mut ps);
496 }
497
498 // Forward
499 let logits_mb = actor.forward(&s_mb); // [B,A]
500 let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim); // [B,1]
501 let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
502 let ratio_clipped = clamp_ratio(&ratio, clip_eps);
503 let pg1 = ratio.mul_tensor(&adv_mb);
504 let pg2 = ratio_clipped.mul_tensor(&adv_mb);
505 // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
506 let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
507 let actor_loss = actor_min.mul_scalar(-1.0).mean();
508
509 let v_pred = critic.forward(&s_mb);
510 let v_loss = v_pred
511 .sub_tensor(&ret_mb)
512 .pow_scalar(2.0)
513 .mean()
514 .mul_scalar(vf_coef);
515
516 // Entropy bonus from logits (categorical entropy) ≈ -sum p*logp
517 let probs_mb = logits_mb.softmax(1);
518 let logp_all = probs_mb.add_scalar(1e-8).log();
519 let ent = probs_mb
520 .mul_tensor(&logp_all)
521 .sum_dims(&[1], true)
522 .mul_scalar(-1.0)
523 .mean()
524 .mul_scalar(ent_coef);
525
526 let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
527 loss.backward(None);
528
529 // Step actor
530 {
531 let params = actor.parameters();
532 let mut with_grads: Vec<&mut Tensor> = Vec::new();
533 for p in params {
534 if p.grad_owned().is_some() {
535 with_grads.push(p);
536 }
537 }
538 if !with_grads.is_empty() {
539 let _ = grad_global_norm(&mut with_grads);
540 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
541 actor_opt.step(&mut with_grads);
542 actor_opt.zero_grad(&mut with_grads);
543 }
544 }
545
546 // Step critic
547 {
548 let params = critic.parameters();
549 let mut with_grads: Vec<&mut Tensor> = Vec::new();
550 for p in params {
551 if p.grad_owned().is_some() {
552 with_grads.push(p);
553 }
554 }
555 if !with_grads.is_empty() {
556 let _ = grad_global_norm(&mut with_grads);
557 clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
558 critic_opt.step(&mut with_grads);
559 critic_opt.zero_grad(&mut with_grads);
560 }
561 }
562
563 if e == 0 && mb == 0 {
564 println!(
565 "update@t={} | actor_loss={:.4} v_loss={:.4}",
566 t,
567 actor_loss.value(),
568 v_loss.value()
569 );
570 }
571
572 clear_all_graphs_known();
573 }
574 }
575 }
576
577 println!("=== PPO discrete training finished ===");
578 Ok(())
579}