ppo_continuous/ppo_continuous.rs

//! PPO (Proximal Policy Optimization) - Continuous actions example using the Train Station public API
//!
//! - Continuous `YardEnv` (action in [-1, 1])
//! - Actor (Gaussian policy: mean from MLP, learnable log_std) + Critic (value function)
//! - Trajectory collection, GAE advantages, PPO clipped surrogate objective
//! - Gradient clipping, zero_grad, and clear_all_graphs_known between updates
//! - Reuses `basic_linear_layer.rs` for small MLPs; no unsafe code
//!
//! Run:
//!   cargo run --release --example ppo_continuous

use train_station::{
    gradtrack::clear_all_graphs_known,
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

// -------------------------------
// Small RNG
// -------------------------------

struct SmallRng {
    state: u64,
}
impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
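    // Linear congruential generator (the classic Numerical Recipes constants);
    // the upper bits are returned because they mix better than the low bits.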
    fn next_u32(&mut self) -> u32 {
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
    fn normal(&mut self) -> f32 {
        // Box-Muller transform: z = sqrt(-2 ln u1) * cos(2 pi u2).
        // u1 is clamped away from 0 and 1 so ln(u1) stays finite.
        let u1 = self.next_f32().clamp(1e-7, 1.0 - 1e-7);
        let u2 = self.next_f32();
        let r = (-2.0 * u1.ln()).sqrt();
        let theta = 2.0 * std::f32::consts::PI * u2;
        r * theta.cos()
    }
}

// -------------------------------
// MLP
// -------------------------------

struct Mlp {
    layers: Vec<LinearLayer>,
}
impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }
    fn forward(&self, input: &Tensor) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last { out.relu() } else { out };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.layers
            .iter_mut()
            .flat_map(|l| l.parameters())
            .collect()
    }
}
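// Usage sketch (assuming the `Tensor::from_slice` API used throughout this file):
//   let mlp = Mlp::new(&[3, 64, 64, 1], Some(0)); // 3 -> 64 -> 64 -> 1, ReLU between hidden layers
//   let y = mlp.forward(&Tensor::from_slice(&[0.0, 0.0, 0.0], vec![1, 3]).unwrap()); // shape [1, 1]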

// -------------------------------
// Actor: mean = MLP(state); log_std is a learnable parameter tensor
// -------------------------------

struct Actor {
    net: Mlp,
    log_std: Tensor, // shape [action_dim]
}
impl Actor {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
        let log_std = Tensor::from_slice(&vec![0.0; action_dim], vec![action_dim])
            .unwrap()
            .with_requires_grad();
        Self { net, log_std }
    }
    fn forward(&self, state: &Tensor) -> (Tensor, Tensor) {
        // Returns (mean [B, A], log_std viewed as [1, A] for broadcasting)
        let mean = self.net.forward(state);
        (
            mean,
            self.log_std
                .view(vec![1, self.log_std.shape().dims()[0] as i32]),
        )
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut ps = self.net.parameters();
        ps.push(&mut self.log_std);
        ps
    }
}
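// Note: log_std is state-independent (one learnable value per action dim), a
// common PPO choice; exploration noise is global rather than state-conditional.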

// -------------------------------
// Critic: value function V(s)
// -------------------------------

struct Critic {
    net: Mlp,
}
impl Critic {
    fn new(state_dim: usize, seed: Option<u64>) -> Self {
        Self {
            net: Mlp::new(&[state_dim, 64, 64, 1], seed),
        }
    }
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
}

// -------------------------------
// Continuous YardEnv (same dynamics as TD3 env)
// -------------------------------

struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}
impl YardEnv {
    fn new(seed: u64) -> Self {
        let mut e = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        e.reset();
        e
    }
    fn reset(&mut self) -> Tensor {
        self.pos = (self.rng.next_f32() * 1.0) - 0.5;
        self.vel = (self.rng.next_f32() * 0.2) - 0.1;
        self.steps = 0;
        self.state_tensor()
    }
    fn state_tensor(&self) -> Tensor {
        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
    }
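    // Dynamics: vel += 0.1 * a - 0.01 * pos, then pos += vel. Reward penalizes
    // squared distance from the origin plus a small squared-action cost, so the
    // optimum is to settle at pos = 0 with minimal control effort.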
    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
        let a = action_value.clamp(-1.0, 1.0);
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;
        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

// -------------------------------
// Trajectory storage
// -------------------------------

struct RolloutBatch {
    states: Vec<f32>,
    actions: Vec<f32>,
    log_probs: Vec<f32>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    values: Vec<f32>,
    next_states: Vec<f32>,
    _state_dim: usize,
}
impl RolloutBatch {
    fn new(capacity: usize, state_dim: usize) -> Self {
        Self {
            states: Vec::with_capacity(capacity * state_dim),
            actions: Vec::with_capacity(capacity),
            log_probs: Vec::with_capacity(capacity),
            rewards: Vec::with_capacity(capacity),
            dones: Vec::with_capacity(capacity),
            values: Vec::with_capacity(capacity),
            next_states: Vec::with_capacity(capacity * state_dim),
            _state_dim: state_dim,
        }
    }

    #[allow(clippy::too_many_arguments)]
    fn push(&mut self, s: &[f32], a: f32, lp: f32, r: f32, d: f32, v: f32, s2: &[f32]) {
        self.states.extend_from_slice(s);
        self.actions.push(a);
        self.log_probs.push(lp);
        self.rewards.push(r);
        self.dones.push(d);
        self.values.push(v);
        self.next_states.extend_from_slice(s2);
    }

    fn len(&self) -> usize {
        self.actions.len()
    }
}
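// States are stored flattened ([len * state_dim]) and turned into [len, state_dim]
// tensors at update time; everything else is one scalar per transition (the action
// buffer assumes action_dim == 1, matching this example).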

// -------------------------------
// Math helpers
// -------------------------------

fn gaussian_log_prob(action: &Tensor, mean: &Tensor, log_std: &Tensor) -> Tensor {
    // All tensors shaped [B, A] (log_std is broadcastable).
    // log N(a | mu, sigma^2) = -0.5 * [(a - mu)^2 / sigma^2 + 2 * log sigma + ln(2 pi)]
    let std = log_std.exp();
    let var = std.pow_scalar(2.0);
    let log_scale = log_std;
    let diff = action.sub_tensor(mean);
    let log_prob = diff
        .pow_scalar(2.0)
        .div_tensor(&var)
        .add_scalar((2.0 * std::f32::consts::PI).ln())
        .add_tensor(&log_scale.mul_scalar(2.0))
        .mul_scalar(0.5)
        .mul_scalar(-1.0);
    // Sum across action dim (dim=1) -> [B, 1]
    log_prob.sum_dims(&[1], true)
}

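// Hand check for gaussian_log_prob: with mean = 0, log_std = 0 (sigma = 1) and
// action = 0, it returns -0.5 * ln(2 pi) ~= -0.9189, the standard-normal
// log-density at 0.

// GAE recurrence implemented below (computed backwards over the rollout):
//   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
//   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
//   R_t     = A_t + V(s_t)   // regression target for the critic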
#[allow(clippy::too_many_arguments)]
fn compute_gae(
    returns_out: &mut [f32],
    adv_out: &mut [f32],
    rewards: &[f32],
    dones: &[f32],
    values: &[f32],
    next_values: &[f32],
    gamma: f32,
    lam: f32,
) {
    let n = rewards.len();
    let mut gae = 0.0f32;
    for t in (0..n).rev() {
        let not_done = 1.0 - dones[t];
        let delta = rewards[t] + gamma * next_values[t] * not_done - values[t];
        gae = delta + gamma * lam * not_done * gae;
        adv_out[t] = gae;
        returns_out[t] = gae + values[t];
    }
}

fn normalize_in_place(x: &mut [f32], eps: f32) {
    let n = x.len() as f32;
    if n <= 1.0 {
        return;
    }
    let mean = x.iter().copied().sum::<f32>() / n;
    let var = x
        .iter()
        .map(|v| {
            let d = v - mean;
            d * d
        })
        .sum::<f32>()
        / n;
    let std = (var + eps).sqrt();
    for v in x.iter_mut() {
        *v = (*v - mean) / std;
    }
}

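// Global-norm gradient clipping: if ||g||_2 over all parameters exceeds
// max_norm, every gradient is scaled by max_norm / (||g||_2 + eps).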
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

// -------------------------------
// Main
// -------------------------------

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== PPO Continuous Example (YardEnv) ===");

    let state_dim = 3usize;
    let action_dim = 1usize;

    // Hparams
    let total_steps = std::env::var("PPO_STEPS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(4000usize);
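    // Override the step budget from the shell, e.g.:
    //   PPO_STEPS=20000 cargo run --release --example ppo_continuous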
    let horizon = 128usize; // rollout length per update
    let epochs = 4usize; // PPO epochs per update
    let mini_batch_size = 64usize; // minibatch size within the horizon
    let gamma = 0.99f32;
    let lam = 0.95f32; // GAE lambda
    let clip_eps = 0.2f32;
    let vf_coef = 0.5f32;
    let ent_coef = 0.0f32;
    let max_grad_norm = 1.0f32;

    // Models
    let mut actor = Actor::new(state_dim, action_dim, Some(101));
    let mut critic = Critic::new(state_dim, Some(202));

    // Opts
    let mut actor_opt = Adam::with_learning_rate(3e-4);
    for p in actor.parameters() {
        actor_opt.add_parameter(p);
    }
    let mut critic_opt = Adam::with_learning_rate(3e-4);
    for p in critic.parameters() {
        critic_opt.add_parameter(p);
    }

    // Env and RNG
    let mut env = YardEnv::new(42);
    let mut rng = SmallRng::new(999);
    let mut state = env.reset();

    // Metrics
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;

    let mut t = 0usize;
    while t < total_steps {
        // Collect a rollout
        let mut batch = RolloutBatch::new(horizon, state_dim);
        for _ in 0..horizon {
            // Policy forward: the action is sampled from raw f32 values (detached),
            // and only scalar log-probs are stored for the later PPO ratio.
            let (mean, log_std_row) = actor.forward(&state);
            let mean_v = mean.data()[0];
            let log_std_v = log_std_row.data()[0];
            let std_v = log_std_v.exp();
            let noise = rng.normal();
            let action_v = (mean_v + std_v * noise).clamp(-1.0, 1.0);

            // Build action tensor [1, A] for log_prob calculation with autograd
            let action_t = Tensor::from_slice(&[action_v], vec![1, action_dim]).unwrap();
            let log_prob_t = gaussian_log_prob(&action_t, &mean, &log_std_row);
            let log_prob_v = log_prob_t.data()[0];

            // Step env
            let (next_state, reward, done) = env.step(action_v);
            episode_return += reward;

            // Value
            let value_t = critic.forward(&state);
            let value_v = value_t.data()[0];

            // Push
            batch.push(
                state.data(),
                action_v,
                log_prob_v,
                reward,
                if done { 1.0 } else { 0.0 },
                value_v,
                next_state.data(),
            );

            // On episode end: update EMA/best metrics, log, and reset the env
            state = if done {
                let st = env.reset();
                ema_return = Some(match ema_return {
                    None => episode_return,
                    Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
                });
                if episode_return > best_return {
                    best_return = episode_return;
                }
                println!(
                    "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3}",
                    t,
                    episode,
                    episode_return,
                    ema_return.unwrap_or(episode_return),
                    best_return
                );
                episode_return = 0.0;
                episode += 1;
                st
            } else {
                next_state
            };

            t += 1;
            if t >= total_steps {
                break;
            }
        }

        // Bootstrap next values for GAE
        let next_values: Vec<f32> = {
            let mut out = Vec::with_capacity(batch.len());
            for i in 0..batch.len() {
                let s2 = &batch.next_states[i * state_dim..(i + 1) * state_dim];
                let s2_t = Tensor::from_slice(s2, vec![1, state_dim]).unwrap();
                let v2 = critic.forward(&s2_t).data()[0];
                out.push(v2);
            }
            out
        };

        // Compute returns and advantages
        let mut returns = vec![0.0f32; batch.len()];
        let mut adv = vec![0.0f32; batch.len()];
        compute_gae(
            &mut returns,
            &mut adv,
            &batch.rewards,
            &batch.dones,
            &batch.values,
            &next_values,
            gamma,
            lam,
        );
        normalize_in_place(&mut adv, 1e-8);
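        // Normalizing advantages to zero mean / unit variance keeps the policy
        // gradient scale roughly constant from rollout to rollout.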

        // Prepare tensors for training
        let states_t = Tensor::from_slice(&batch.states, vec![batch.len(), state_dim]).unwrap();
        let actions_t = Tensor::from_slice(&batch.actions, vec![batch.len(), action_dim]).unwrap();
        let old_logp_t = Tensor::from_slice(&batch.log_probs, vec![batch.len(), 1]).unwrap();
        let returns_t = Tensor::from_slice(&returns, vec![batch.len(), 1]).unwrap();
        let adv_t = Tensor::from_slice(&adv, vec![batch.len(), 1]).unwrap();

        // PPO epochs over the rollout
        let num_minibatches = batch.len().div_ceil(mini_batch_size);
        for e in 0..epochs {
            for mb in 0..num_minibatches {
                let start = mb * mini_batch_size;
                let end = (start + mini_batch_size).min(batch.len());
                if start >= end {
                    break;
                }

                // Slice views
                let s_mb = states_t.slice_view(start * state_dim, 1, (end - start) * state_dim);
                let s_mb = s_mb.reshape(vec![(end - start) as i32, state_dim as i32]);
                let a_mb = actions_t
                    .slice_view(start * action_dim, 1, (end - start) * action_dim)
                    .reshape(vec![(end - start) as i32, action_dim as i32]);
                let oldlp_mb = old_logp_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let ret_mb = returns_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);
                let adv_mb = adv_t
                    .slice_view(start, 1, end - start)
                    .reshape(vec![(end - start) as i32, 1]);

                // Zero grads
                {
                    let mut ps = actor.parameters();
                    actor_opt.zero_grad(&mut ps);
                }
                {
                    let mut ps = critic.parameters();
                    critic_opt.zero_grad(&mut ps);
                }

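                // PPO clipped surrogate (maximized, so negated below):
                //   L_clip = E[ min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A) ]
                // where ratio = exp(logp_new - logp_old).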
                // Forward actor and critic
                let (mean_mb, log_std_row) = actor.forward(&s_mb);
                let logp_mb = gaussian_log_prob(&a_mb, &mean_mb, &log_std_row);
                let ratio = logp_mb.sub_tensor(&oldlp_mb).exp(); // exp(new - old)
                let clip_low =
                    Tensor::from_slice(&vec![1.0 - clip_eps; end - start], vec![end - start, 1])
                        .unwrap();
                let clip_high =
                    Tensor::from_slice(&vec![1.0 + clip_eps; end - start], vec![end - start, 1])
                        .unwrap();
                // ratio_clipped = min(max(ratio, low), high) using ReLU identities:
                // max(x, lo) = relu(x - lo) + lo; min(x, hi) = hi - relu(x - hi)
                let ratio_ge_low = ratio.sub_tensor(&clip_low).relu().add_tensor(&clip_low);
                let ratio_clipped =
                    clip_high.sub_tensor(&ratio_ge_low.sub_tensor(&clip_high).relu());
                let pg1 = ratio.mul_tensor(&adv_mb);
                let pg2 = ratio_clipped.mul_tensor(&adv_mb);
                // min(pg1, pg2) = pg2 - relu(pg2 - pg1)
                let actor_min = pg2.sub_tensor(&pg2.sub_tensor(&pg1).relu());
                let actor_loss = actor_min.mul_scalar(-1.0).mean();

                let v_pred = critic.forward(&s_mb);
                let v_loss = v_pred
                    .sub_tensor(&ret_mb)
                    .pow_scalar(2.0)
                    .mean()
                    .mul_scalar(vf_coef);

                // Entropy bonus: Gaussian differential entropy per action dim,
                // H = log_std + 0.5 * ln(2 * pi * e) (exact for the unclipped Gaussian)
                let entropy = log_std_row
                    .add_scalar(0.5 * (2.0 * std::f32::consts::PI * std::f32::consts::E).ln())
                    .sum_dims(&[1], true)
                    .mean()
                    .mul_scalar(ent_coef);

                let mut loss = actor_loss.add_tensor(&v_loss).sub_tensor(&entropy);
                loss.backward(None);

                // Step actor
                {
                    let params = actor.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads); // norm computed but unused here
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        actor_opt.step(&mut with_grads);
                        actor_opt.zero_grad(&mut with_grads);
                    }
                }

                // Step critic
                {
                    let params = critic.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let _ = grad_global_norm(&mut with_grads); // norm computed but unused here
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        critic_opt.step(&mut with_grads);
                        critic_opt.zero_grad(&mut with_grads);
                    }
                }

                // Log once at the start of each update
                if e == 0 && mb == 0 {
                    println!(
                        "update@t={} | actor_loss={:.4} v_loss={:.4}",
                        t,
                        actor_loss.value(),
                        v_loss.value()
                    );
                }

                // Free cached autograd graphs between minibatches to bound memory
                clear_all_graphs_known();
            }
        }
    }

    println!("=== PPO training finished ===");
    Ok(())
}