use train_station::{
    gradtrack::{clear_all_graphs_known, NoGradTrack},
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

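/// Minimal linear-congruential PRNG so the example stays dependency-free
/// and reproducible from a seed.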
struct SmallRng {
    state: u64,
}

impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    fn next_u32(&mut self) -> u32 {
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }

    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
}

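/// Multi-layer perceptron built from `LinearLayer`s, with ReLU between
/// hidden layers and a raw linear output on the final layer.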
struct Mlp {
    layers: Vec<LinearLayer>,
}

impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }

    fn forward(&self, input: &Tensor) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last { out.relu() } else { out };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.layers
            .iter_mut()
            .flat_map(|l| l.parameters())
            .collect()
    }
}

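/// Policy network: maps a state to unnormalized action logits.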
struct Actor {
    net: Mlp,
}

impl Actor {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        Self {
            net: Mlp::new(&[state_dim, 64, 64, action_dim], seed),
        }
    }

    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state)
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
}

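/// Value network: maps a state to a scalar state-value estimate V(s).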
struct Critic {
    net: Mlp,
}

impl Critic {
    fn new(state_dim: usize, seed: Option<u64>) -> Self {
        Self {
            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
        }
    }

    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state)
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
}

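/// Toy 1-D control environment: the agent applies a force in {-1, 0, +1}
/// to drive position and velocity toward zero. Reward penalizes squared
/// distance from the origin plus a small action cost; an episode ends when
/// |pos| > 3 or after `max_steps` steps. The observation is `[pos, vel, 0.0]`.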
struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}

impl YardEnv {
    const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];

    fn new(seed: u64) -> Self {
        let mut e = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        e.reset();
        e
    }

    fn reset(&mut self) -> Tensor {
        self.pos = (self.rng.next_f32() * 1.0) - 0.5;
        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
        self.steps = 0;
        self.state_tensor()
    }

    fn state_tensor(&self) -> Tensor {
        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
    }

    fn step(&mut self, action_idx: usize) -> (Tensor, f32, bool) {
        let a = Self::ACTIONS[action_idx.min(2)];
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;
        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

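/// Flat storage for one rollout of transitions: states, sampled actions,
/// behavior-policy log-probs, rewards, done flags, value estimates, and
/// successor states, kept as plain `Vec`s until the update phase.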
struct RolloutBatch {
    states: Vec<f32>,
    actions: Vec<usize>,
    old_logps: Vec<f32>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    values: Vec<f32>,
    next_states: Vec<f32>,
    _state_dim: usize,
}

impl RolloutBatch {
    fn new(cap: usize, sd: usize) -> Self {
        Self {
            states: Vec::with_capacity(cap * sd),
            actions: Vec::with_capacity(cap),
            old_logps: Vec::with_capacity(cap),
            rewards: Vec::with_capacity(cap),
            dones: Vec::with_capacity(cap),
            values: Vec::with_capacity(cap),
            next_states: Vec::with_capacity(cap * sd),
            _state_dim: sd,
        }
    }

    #[allow(clippy::too_many_arguments)]
    fn push(&mut self, s: &[f32], a: usize, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
        self.states.extend_from_slice(s);
        self.actions.push(a);
        self.old_logps.push(lp);
        self.rewards.push(r);
        self.dones.push(d);
        self.values.push(v);
        self.next_states.extend_from_slice(s2);
    }

    fn len(&self) -> usize {
        self.actions.len()
    }
}

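/// Generalized Advantage Estimation (GAE): walks the rollout backwards,
/// folding the TD residual
///   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
/// into gae = delta_t + gamma * lambda * (1 - done_t) * gae, and writes
/// both advantages and returns (advantage + value) in place.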
#[allow(clippy::too_many_arguments)]
fn compute_gae(
    returns_out: &mut [f32],
    adv_out: &mut [f32],
    rewards: &[f32],
    dones: &[f32],
    values: &[f32],
    next_values: &[f32],
    gamma: f32,
    lam: f32,
) {
    let n = rewards.len();
    let mut gae = 0.0f32;
    for t in (0..n).rev() {
        let not_done = 1.0 - dones[t];
        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
        gae = delta + gamma * lam * not_done * gae;
        adv_out[t] = gae;
        returns_out[t] = gae + values[t];
    }
}

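/// Standardize a slice to zero mean and unit variance (skipped for fewer
/// than two elements); `eps` guards the division.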
fn normalize_in_place(x: &mut [f32], eps: f32) {
    let n = x.len() as f32;
    if n <= 1.0 {
        return;
    }
    let mean = x.iter().copied().sum::<f32>() / n;
    let var = x
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum::<f32>()
        / n;
    let std = (var + eps).sqrt();
    for v in x.iter_mut() {
        *v = (*v - mean) / std;
    }
}

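/// Global-norm gradient clipping: if the L2 norm over all gradients exceeds
/// `max_norm`, rescale every gradient by `max_norm / (norm + eps)`.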
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

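/// Log-probabilities of the taken actions under the current policy:
/// a numerically stable log-softmax (shift by the row max, then subtract
/// log-sum-exp), followed by a gather of each action's entry.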
fn log_prob_actions(
    logits: &Tensor,
    actions: &[usize],
    batch: usize,
    _action_dim: usize,
) -> Tensor {
    let max_logits = logits.max_dims(&[1], true);
    let shifted = logits.sub_tensor(&max_logits);
    let exp = shifted.exp();
    let sum_exp = exp.sum_dims(&[1], true);
    let log_sum_exp = sum_exp.log();
    let log_softmax = shifted.sub_tensor(&log_sum_exp);
    log_softmax.gather(1, actions, &[batch, 1])
}

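/// PPO importance ratio r = exp(log pi_new(a|s) - log pi_old(a|s)).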
fn ratio_from_logps(new_logp: &Tensor, old_logp: &Tensor) -> Tensor {
    new_logp.sub_tensor(old_logp).exp()
}

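/// Clamp the ratio to [1 - eps, 1 + eps] using only ReLU arithmetic:
/// max(x, low) = relu(x - low) + low, then min(x, high) = high - relu(x - high).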
fn clamp_ratio(ratio: &Tensor, clip_eps: f32) -> Tensor {
    let b = ratio.shape().dims()[0];
    let low = Tensor::from_slice(&vec![1.0 - clip_eps; b], vec![b, 1]).unwrap();
    let high = Tensor::from_slice(&vec![1.0 + clip_eps; b], vec![b, 1]).unwrap();
    let ge_low = ratio.sub_tensor(&low).relu().add_tensor(&low);
    high.sub_tensor(&ge_low.sub_tensor(&high).relu())
}

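/// L2 norm of all gradients across the given parameters (diagnostic only).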
fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

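/// Entry point: alternates between collecting fixed-length rollouts from
/// `YardEnv` and running several epochs of clipped-surrogate PPO updates
/// over minibatches of each rollout.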
pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== PPO Discrete Example (YardEnv) ===");

    let state_dim = 3usize;
    let action_dim = 3usize;
    let total_steps = std::env::var("PPOD_STEPS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(3500usize);
    let horizon = 128usize;
    let epochs = 4usize;
    let mini_batch_size = 64usize;
    let gamma = 0.99f32;
    let lam = 0.95f32;
    let clip_eps = 0.2f32;
    let vf_coef = 0.5f32;
    let ent_coef = 0.0f32;
    let max_grad_norm = 1.0f32;

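    // Separate networks and Adam optimizers for the policy and the value
    // function; each optimizer registers its network's parameters once.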
    let mut actor = Actor::new(state_dim, action_dim, Some(111));
    let mut critic = Critic::new(state_dim, Some(222));
    let mut actor_opt = Adam::with_learning_rate(3e-4);
    for p in actor.parameters() {
        actor_opt.add_parameter(p);
    }
    let mut critic_opt = Adam::with_learning_rate(3e-4);
    for p in critic.parameters() {
        critic_opt.add_parameter(p);
    }

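    // Environment, action-sampling RNG, and episode-return bookkeeping
    // (exponential moving average plus best-so-far, for logging only).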
    let mut env = YardEnv::new(1234);
    let mut rng = SmallRng::new(98765);
    let mut state = env.reset();
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;

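    // Outer loop: collect a `horizon`-step rollout with the current policy,
    // then run PPO updates on it before collecting the next one.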
    let mut t = 0usize;
    while t < total_steps {
        let mut batch = RolloutBatch::new(horizon, state_dim);
        for _ in 0..horizon {
            let logits = actor.forward(&state);
            let probs = logits.softmax(1);
            let p = probs.data();
            let (p0, p1, _p2) = (p[0], p[1], p[2]);
            let u = rng.next_f32();
            let a_idx = if u < p0 {
                0
            } else if u < p0 + p1 {
                1
            } else {
                2
            };

            let old_logp = {
                let _ng = NoGradTrack::new();
                let lp = log_prob_actions(&logits, &[a_idx], 1, action_dim);
                lp.data()[0]
            };

            let (next_state, reward, done) = env.step(a_idx);
            episode_return += reward;

            let value_t = critic.forward(&state);
            let value_v = value_t.data()[0];

            batch.push(
                state.data(),
                a_idx,
                old_logp,
                reward,
                if done { 1.0 } else { 0.0 },
                value_v,
                next_state.data(),
            );

            state = if done {
                let st = env.reset();
                ema_return = Some(match ema_return {
                    None => episode_return,
                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
                });
                if episode_return > best_return {
                    best_return = episode_return;
                }
                println!(
                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
                    t,
                    episode,
                    episode_return,
                    ema_return.unwrap_or(episode_return),
                    best_return
                );
                episode_return = 0.0;
                episode += 1;
                st
            } else {
                next_state
            };

            t += 1;
            if t >= total_steps {
                break;
            }
        }

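        // Bootstrap V(s') for every stored transition so GAE can use a
        // one-step lookahead value.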
        let next_values: Vec<f32> = {
            let mut out = Vec::with_capacity(batch.len());
            for i in 0..batch.len() {
                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
                out.push(critic.forward(&s2_t).data()[0]);
            }
            out
        };

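        // Compute GAE advantages and returns, then standardize the
        // advantages across the rollout.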
        let mut returns = vec![0.0f32; batch.len()];
        let mut adv = vec![0.0f32; batch.len()];
        compute_gae(
            &mut returns,
            &mut adv,
            &batch.rewards,
            &batch.dones,
            &batch.values,
            &next_values,
            gamma,
            lam,
        );
        normalize_in_place(&mut adv, 1e-8);

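        // Pack the rollout into tensors once; the minibatches below are
        // views into these full-batch tensors.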
        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
        let actions_vec = batch.actions.clone();
        let old_logp_t = Tensor::from_slice(&batch.old_logps, vec![batch.len(), 1]).unwrap();
        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();

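        // Several epochs over the same rollout; minibatches are taken in
        // order (no shuffling) to keep the example simple.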
        let num_minibatches = batch.len().div_ceil(mini_batch_size);
        for e in 0..epochs {
            for mb in 0..num_minibatches {
                let start = mb * mini_batch_size;
                let end = (start + mini_batch_size).min(batch.len());
                if start >= end {
                    break;
                }

                let s_mb = states_t
                    .slice_view(start * state_dim, 1, (end - start) * state_dim)
                    .reshape(vec![(end - start) as i32, state_dim as i32]);
                let oldlp_mb = old_logp_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let ret_mb = returns_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let adv_mb = adv_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let a_slice = &actions_vec[start..end];

                {
                    let mut ps = actor.parameters();
                    actor_opt.zero_grad(&mut ps);
                }
                {
                    let mut ps = critic.parameters();
                    critic_opt.zero_grad(&mut ps);
                }

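                // Clipped surrogate objective: elementwise minimum of
                // ratio * A and clamp(ratio) * A, via min(a, b) = b - relu(b - a),
                // then negated and averaged to form the actor loss.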
                let logits_mb = actor.forward(&s_mb);
                let new_logp_mb = log_prob_actions(&logits_mb, a_slice, end - start, action_dim);
                let ratio = ratio_from_logps(&new_logp_mb, &oldlp_mb);
                let ratio_clipped = clamp_ratio(&ratio, clip_eps);
                let pg1 = ratio.mul_tensor(&adv_mb);
                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
                let actor_loss = actor_min.mul_scalar(-1.0).mean();

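                // Value loss: mean squared error between predicted values
                // and the GAE returns, scaled by vf_coef.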
                let v_pred = critic.forward(&s_mb);
                let v_loss = v_pred
                    .sub_tensor(&ret_mb)
                    .pow_scalar(2.0)
                    .mean()
                    .mul_scalar(vf_coef);

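                // Entropy bonus (inactive here since ent_coef = 0): mean
                // categorical entropy, subtracted from the total loss.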
                let probs_mb = logits_mb.softmax(1);
                let logp_all = probs_mb.add_scalar(1e-8).log();
                let ent = probs_mb
                    .mul_tensor(&logp_all)
                    .sum_dims(&[1], true)
                    .mul_scalar(-1.0)
                    .mean()
                    .mul_scalar(ent_coef);

                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&ent);
                loss.backward(None);

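                // Actor update: clip gradients by global norm, step the
                // optimizer, and clear gradients for the next minibatch.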
                {
                    let params = actor.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads);
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        actor_opt.step(&mut with_grads);
                        actor_opt.zero_grad(&mut with_grads);
                    }
                }

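                // Critic update, mirroring the actor update above.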
                {
                    let params = critic.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads);
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        critic_opt.step(&mut with_grads);
                        critic_opt.zero_grad(&mut with_grads);
                    }
                }

                if e == 0 && mb == 0 {
                    println!(
                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
                        t,
                        actor_loss.value(),
                        v_loss.value()
                    );
                }

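                // Release retained autograd graphs before the next minibatch.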
                clear_all_graphs_known();
            }
        }
    }

    println!("=== PPO discrete training finished ===");
    Ok(())
}