td3/td3.rs

//! TD3 (Twin Delayed DDPG) - Minimal, self-contained example using Train Station public API
//!
//! Goals:
//! - Keep it small and easy to follow
//! - Reuse `basic_linear_layer.rs` building block (no duplication)
//! - Link optimizer parameters correctly (no cloning of params)
//! - Zero gradients and clear all graphs between iterations
//! - Use only public Train Station APIs + standard Rust
//!
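//! TD3 (Fujimoto et al., 2018) stabilizes DDPG with three additions, all used below:
//! - Clipped double-Q: bootstrap targets from min(Q1', Q2') to curb overestimation
//! - Delayed policy updates: the actor and targets update every `policy_delay` critic steps
//! - Target policy smoothing: noise is added to target actions before bootstrapping
//!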
//! Run:
//!   cargo run --release --example td3

use train_station::{
    gradtrack::{clear_all_graphs_known, NoGradTrack},
    optimizers::{Adam, Optimizer},
    Tensor,
};

// Reuse simple LinearLayer to build tiny MLPs (actor/critic)
#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

// -------------------------------
// Utilities
// -------------------------------

// Simple LCG RNG (no external deps)
struct SmallRng {
    state: u64,
}

impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
    fn next_u32(&mut self) -> u32 {
        // Numerical Recipes LCG
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
    fn uniform(&mut self, low: f32, high: f32) -> f32 {
        low + (high - low) * self.next_f32()
    }
    fn sample_index(&mut self, upper_exclusive: usize) -> usize {
        (self.next_u32() as usize) % upper_exclusive.max(1)
    }
}
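
// Note: `sample_index` uses a plain modulo reduction, which carries a slight
// bias when the range does not evenly divide 2^32; acceptable for a demo.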

fn tanh_bounded(x: &Tensor) -> Tensor {
    x.tanh()
}

// -------------------------------
// Tiny MLP builder on LinearLayer
// -------------------------------

struct Mlp {
    layers: Vec<LinearLayer>,
}

impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        assert!(sizes.len() >= 2);
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }

    fn forward(&self, input: &Tensor, final_activation: Option<fn(&Tensor) -> Tensor>) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last {
                out.relu()
            } else if let Some(act) = final_activation {
                act(&out)
            } else {
                out
            };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for l in &mut self.layers {
            params.extend(l.parameters());
        }
        params
    }

    fn set_requires_grad_all(&mut self, enable: bool) {
        for l in &mut self.layers {
            l.weight.set_requires_grad(enable);
            l.bias.set_requires_grad(enable);
        }
    }

    fn copy_from(&mut self, other: &Self) {
        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
            {
                let src = s.weight.data();
                let dst = t.weight.data_mut();
                dst.copy_from_slice(src);
            }
            {
                let src = s.bias.data();
                let dst = t.bias.data_mut();
                dst.copy_from_slice(src);
            }
            t.weight.set_requires_grad(false);
            t.bias.set_requires_grad(false);
        }
    }

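    // Polyak averaging: θ_targ ← (1 - τ) * θ_targ + τ * θ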
    fn soft_update_from(&mut self, source: &Self, tau: f32) {
        let _ng = NoGradTrack::new();
        for (t, s) in self.layers.iter_mut().zip(source.layers.iter()) {
            // In-place Polyak update to preserve tensor IDs (no optimizer relink needed)
            let new_w = t
                .weight
                .mul_scalar(1.0 - tau)
                .add_tensor(&s.weight.mul_scalar(tau));
            let new_b = t
                .bias
                .mul_scalar(1.0 - tau)
                .add_tensor(&s.bias.mul_scalar(tau));
            {
                let src = new_w.data();
                let dst = t.weight.data_mut();
                dst.copy_from_slice(src);
            }
            {
                let src = new_b.data();
                let dst = t.bias.data_mut();
                dst.copy_from_slice(src);
            }
            t.weight.set_requires_grad(false);
            t.bias.set_requires_grad(false);
        }
    }
}

// -------------------------------
// Actor and Critic
// -------------------------------

struct Actor {
    net: Mlp,
}

impl Actor {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        // Smaller net for faster demo: sd -> 64 -> 64 -> ad, tanh output
        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
        Self { net }
    }
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state, Some(tanh_bounded))
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
    fn set_requires_grad_all(&mut self, enable: bool) {
        self.net.set_requires_grad_all(enable);
    }
}

struct Critic {
    net: Mlp,
}

impl Critic {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim + action_dim, 64, 64, 1], seed);
        Self { net }
    }
    fn forward(&self, state: &Tensor, action: &Tensor) -> Tensor {
        // Concatenate along feature dim (dim=1) for batched inputs
        // IMPORTANT: use views to preserve gradient graph; cloning would detach autograd
        let s_view = state.view(state.shape().dims().iter().map(|&d| d as i32).collect());
        let a_view = action.view(action.shape().dims().iter().map(|&d| d as i32).collect());
        let sa = Tensor::cat(&[s_view, a_view], 1);
        self.net.forward(&sa, None)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
    fn set_requires_grad_all(&mut self, enable: bool) {
        self.net.set_requires_grad_all(enable);
    }
}
// -------------------------------
// Simple continuous control environment: YardEnv
// State: normalized features [pos/3, clamp(vel, -1..1), constant 0] ; Action: scalar in [-1, 1]
// Dynamics: vel += 0.1*act - 0.01*pos; pos += vel
// Reward: -(pos^2) - 0.1*act^2 ; Episode ends if |pos| > 3 or step >= max_steps
// -------------------------------

struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}

impl YardEnv {
    fn new(seed: u64) -> Self {
        let mut env = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        env.reset();
        env
    }

    fn reset(&mut self) -> Tensor {
        self.pos = self.rng.uniform(-0.5, 0.5);
        self.vel = self.rng.uniform(-0.1, 0.1);
        self.steps = 0;
        self.state_tensor()
    }

    fn state_tensor(&self) -> Tensor {
        // Normalize to keep critic inputs bounded:
        // - Position terminates at |pos| > 3, so divide by 3 to land in [-1, 1]
        // - Velocity is clamped to [-1, 1]
        // - Third feature is a constant 0 placeholder
        let pos_n = self.pos / 3.0;
        let vel_n = self.vel.clamp(-1.0, 1.0);
        Tensor::from_slice(&[pos_n, vel_n, 0.0], vec![1, 3]).unwrap()
    }

    fn step(&mut self, action_value: f32) -> (Tensor, f32, bool) {
        let a = action_value.clamp(-1.0, 1.0);
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;

        let reward = -(self.pos * self.pos) - 0.1 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}
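
// With reward = -(pos^2) - 0.1*(a^2), each step scores at most 0, so episode
// returns are non-positive and a learning agent should push them toward 0.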

// -------------------------------
// Replay Buffer
// -------------------------------

struct ReplayBuffer {
    capacity: usize,
    size: usize,
    pos: usize,
    state_dim: usize,
    action_dim: usize,
    states: Vec<f32>,
    actions: Vec<f32>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    next_states: Vec<f32>,
}

impl ReplayBuffer {
    fn new(capacity: usize, state_dim: usize, action_dim: usize) -> Self {
        Self {
            capacity,
            size: 0,
            pos: 0,
            state_dim,
            action_dim,
            states: vec![0.0; capacity * state_dim],
            actions: vec![0.0; capacity * action_dim],
            rewards: vec![0.0; capacity],
            dones: vec![0.0; capacity],
            next_states: vec![0.0; capacity * state_dim],
        }
    }

    fn push(&mut self, s: &[f32], a: &[f32], r: f32, d: f32, s2: &[f32]) {
        let i = self.pos;
        let so = i * self.state_dim;
        let ao = i * self.action_dim;
        self.states[so..so + self.state_dim].copy_from_slice(s);
        self.actions[ao..ao + self.action_dim].copy_from_slice(a);
        self.rewards[i] = r;
        self.dones[i] = d;
        self.next_states[so..so + self.state_dim].copy_from_slice(s2);

        self.pos = (self.pos + 1) % self.capacity;
        self.size = self.size.saturating_add(1).min(self.capacity);
    }

    fn can_sample(&self, batch_size: usize) -> bool {
        self.size >= batch_size
    }

    fn sample(
        &self,
        batch_size: usize,
        rng: &mut SmallRng,
    ) -> (Tensor, Tensor, Tensor, Tensor, Tensor) {
        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
        let mut a_vec = Vec::with_capacity(batch_size * self.action_dim);
        let mut r_vec = Vec::with_capacity(batch_size);
        let mut d_vec = Vec::with_capacity(batch_size);
        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);

        for _ in 0..batch_size {
            let idx = rng.sample_index(self.size);
            let so = idx * self.state_dim;
            let ao = idx * self.action_dim;
            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
            a_vec.extend_from_slice(&self.actions[ao..ao + self.action_dim]);
            r_vec.push(self.rewards[idx]);
            d_vec.push(self.dones[idx]);
            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
        }

        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
        let a = Tensor::from_slice(&a_vec, vec![batch_size, self.action_dim]).unwrap();
        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
        (s, a, r, d, s2)
    }
}
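
// `sample` draws indices independently (with replacement); duplicates within a
// batch are possible, which is harmless for stochastic updates on a large buffer.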

// -------------------------------
// Helper: gradient clipping by global norm
// -------------------------------

fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    // Compute global L2 norm of all grads
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                let scaled = g.mul_scalar(scale);
                p.set_grad(scaled);
            }
        }
    }
}

// Compute global L2 norm of gradients across a parameter list
// (reads gradients only; takes `&mut` to match the call sites)
fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

// Compute L2 norm of parameters (weights/biases) across a parameter list
fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let _ng = NoGradTrack::new();
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        for &v in p.data() {
            total_sq += v * v;
        }
    }
    total_sq.sqrt()
}

// -------------------------------
// Main: TD3 training on YardEnv
// -------------------------------

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== TD3 Example (YardEnv) ===");

    // Environment / problem dims
    let state_dim = 3usize;
    let action_dim = 1usize;

    // Hyperparameters (small for demo)
    let gamma = 0.99f32;
    let tau = 0.005f32; // Polyak
    let policy_noise = 0.2f32; // target smoothing noise stddev
    let exploration_noise = 0.1f32; // behavior policy noise stddev
    let policy_delay = 2usize;
    let batch_size = 64usize;
    let start_steps = 500usize; // random exploration steps
    let total_steps = 1500usize;
    let max_grad_norm = 1.0f32;
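
    // gamma, tau, policy_noise, and policy_delay follow common TD3 defaults;
    // step counts and batch size are shrunk so the demo finishes quickly.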

    // Models
    let mut actor = Actor::new(state_dim, action_dim, Some(11));
    let mut actor_targ = Actor::new(state_dim, action_dim, Some(12));
    actor_targ.net.copy_from(&actor.net);
    actor_targ.set_requires_grad_all(false);

    let mut critic1 = Critic::new(state_dim, action_dim, Some(21));
    let mut critic2 = Critic::new(state_dim, action_dim, Some(22));
    let mut critic1_targ = Critic::new(state_dim, action_dim, Some(23));
    let mut critic2_targ = Critic::new(state_dim, action_dim, Some(24));
    critic1_targ.net.copy_from(&critic1.net);
    critic2_targ.net.copy_from(&critic2.net);
    critic1_targ.set_requires_grad_all(false);
    critic2_targ.set_requires_grad_all(false);

    // Optimizers
    let mut actor_opt = Adam::with_learning_rate(1e-3);
    for p in actor.parameters() {
        actor_opt.add_parameter(p);
    }

    let mut critic_opt = Adam::with_learning_rate(1e-4);
    for p in critic1.parameters() {
        critic_opt.add_parameter(p);
    }
    for p in critic2.parameters() {
        critic_opt.add_parameter(p);
    }

    // Replay buffer and env
    let mut rb = ReplayBuffer::new(100_000, state_dim, action_dim);
    let mut env = YardEnv::new(1234);
    let mut rng = SmallRng::new(987654321);

    // Reset & metric trackers
    let mut state = env.reset(); // [1, state_dim]
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32; // EMA smoothing factor for recent returns
    let mut best_return = f32::NEG_INFINITY;
    let mut policy_updates: usize = 0;

    for t in 0..total_steps {
        // Select action
        let action_tensor = if t < start_steps {
            let a = rng.uniform(-1.0, 1.0);
            Tensor::from_slice(&[a], vec![1, action_dim]).unwrap()
        } else {
            // Behavior policy with exploration noise
            let _ng = NoGradTrack::new();
            let det = actor.forward(&state);
            let noise = Tensor::randn(vec![1, action_dim], None).mul_scalar(exploration_noise);
            tanh_bounded(&det.add_tensor(&noise))
        };
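        // (Canonical TD3 clips the noisy action to the action bounds; re-applying
        // tanh here keeps it in [-1, 1] as well, at the cost of extra squashing.)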
        let action_value = action_tensor.data()[0];

        // Environment step
        let (next_state, reward, done) = env.step(action_value);
        episode_return += reward;

        // Store transition
        let s_slice = state.data().to_vec();
        let a_slice = action_tensor.data().to_vec();
        let s2_slice = next_state.data().to_vec();
        rb.push(
            &s_slice,
            &a_slice,
            reward,
            if done { 1.0 } else { 0.0 },
            &s2_slice,
        );

        state = if done {
            let st = env.reset();
            // Metrics: update EMA and best
            ema_return = Some(match ema_return {
                None => episode_return,
                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
            });
            if episode_return > best_return {
                best_return = episode_return;
            }
            println!(
                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={} | policy_updates={}",
                t,
                episode,
                episode_return,
                ema_return.unwrap_or(episode_return),
                best_return,
                rb.size,
                policy_updates
            );
            episode_return = 0.0;
            episode += 1;
            st
        } else {
            next_state
        };

        // Training
        if rb.can_sample(batch_size) {
            // Sample batch
            let (s, a, r, d, s2) = rb.sample(batch_size, &mut rng);

            // Compute target values y = r + (1-d)*gamma*min(Q1', Q2') using target networks (no grad)
            let target_q = {
                let _ng = NoGradTrack::new();
                // Target actions with smoothing noise (tanh bounds)
                let noise =
                    Tensor::randn(vec![batch_size, action_dim], None).mul_scalar(policy_noise);
                let a_targ = tanh_bounded(&actor_targ.forward(&s2).add_tensor(&noise));
                let q1_t = critic1_targ.forward(&s2, &a_targ);
                let q2_t = critic2_targ.forward(&s2, &a_targ);

                // Elementwise min via data() since this path is no-grad
                let q1d = q1_t.data();
                let q2d = q2_t.data();
                let mut min_vec = Vec::with_capacity(batch_size);
                for i in 0..batch_size {
                    let v1 = q1d[i];
                    let v2 = q2d[i];
                    min_vec.push(v1.min(v2));
                }
                let min_q = Tensor::from_slice(&min_vec, vec![batch_size, 1]).unwrap();
                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&min_q))
            };
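            // (Canonical TD3 also clips the smoothing noise to a fixed ±noise_clip
            // before bounding the action; this demo relies on the tanh bound alone.)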

            // Critic update (both critics)
            // Zero grads in a short scope, then drop borrows before forward
            {
                let mut params = {
                    let c_params = critic1.parameters();
                    let c2_params = critic2.parameters();
                    let mut tmp: Vec<&mut Tensor> = Vec::new();
                    tmp.extend(c_params);
                    tmp.extend(c2_params);
                    tmp
                };
                critic_opt.zero_grad(&mut params);
            }

            // Forward current Q estimates
            let q1 = critic1.forward(&s, &a);
            let q2 = critic2.forward(&s, &a);
            let diff1 = q1.sub_tensor(&target_q);
            let diff2 = q2.sub_tensor(&target_q);
            let mut critic_loss = diff1
                .pow_scalar(2.0)
                .mean()
                .add_tensor(&diff2.pow_scalar(2.0).mean());

            // Backward
            critic_loss.backward(None);

            // Optional gradient clipping + step (only for params that received grads)
            {
                let params = {
                    let c_params = critic1.parameters();
                    let c2_params = critic2.parameters();
                    let mut tmp: Vec<&mut Tensor> = Vec::new();
                    tmp.extend(c_params);
                    tmp.extend(c2_params);
                    tmp
                };
                let mut with_grads: Vec<&mut Tensor> = Vec::new();
                for p in params {
                    if p.grad_owned().is_some() {
                        with_grads.push(p);
                    }
                }
                if !with_grads.is_empty() {
                    // Pre-step metrics
                    let grad_norm_before = grad_global_norm(&mut with_grads);
                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                    critic_opt.step(&mut with_grads);
                    critic_opt.zero_grad(&mut with_grads);

                    // Post-step metrics (param norm)
                    let mut for_norm_params = {
                        let c_params = critic1.parameters();
                        let c2_params = critic2.parameters();
                        let mut tmp: Vec<&mut Tensor> = Vec::new();
                        tmp.extend(c_params);
                        tmp.extend(c2_params);
                        tmp
                    };
                    let param_norm = params_l2_norm(&mut for_norm_params);

                    // Print compact critic metrics occasionally
                    if t % 100 == 0 {
                        let q1_mean = q1.mean().value();
                        let q2_mean = q2.mean().value();
                        let tq_mean = target_q.mean().value();
                        println!(
                            "t={:5} | critic_loss={:.4} | q1_mean={:.3} q2_mean={:.3} tq_mean={:.3} | grad_norm={:.3} | crit_param_norm={:.3}",
                            t,
                            critic_loss.value(),
                            q1_mean,
                            q2_mean,
                            tq_mean,
                            grad_norm_before,
                            param_norm
                        );
                    }
                }
            }

            // Delayed policy update
            if t % policy_delay == 0 {
                // Actor update: maximize Q1(s, actor(s)) -> minimize -Q1
                // Zero actor grads before backward
                {
                    let mut a_params: Vec<&mut Tensor> = actor.parameters();
                    actor_opt.zero_grad(&mut a_params);
                }

                let a_pred = actor.forward(&s);
                let q_for_actor = critic1.forward(&s, &a_pred);
                let mut actor_loss = q_for_actor.mul_scalar(-1.0).mean();
                actor_loss.backward(None);

                {
                    let a_params: Vec<&mut Tensor> = actor.parameters();
                    let mut with_grads: Vec<&mut Tensor> = Vec::new();
                    for p in a_params {
                        if p.grad_owned().is_some() {
                            with_grads.push(p);
                        }
                    }
                    if !with_grads.is_empty() {
                        let grad_norm_before = grad_global_norm(&mut with_grads);
                        clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                        actor_opt.step(&mut with_grads);
                        actor_opt.zero_grad(&mut with_grads);

                        // Post-step param norm
                        let mut for_norm_params = actor.parameters();
                        let param_norm = params_l2_norm(&mut for_norm_params);

                        policy_updates += 1;
                        if t % 200 == 0 {
                            println!(
                                "t={:5} | actor_loss={:.4} | act_grad_norm={:.3} | act_param_norm={:.3} | lr_a={:.4e} lr_c={:.4e} | policy_updates={}",
                                t,
                                actor_loss.value(),
                                grad_norm_before,
                                param_norm,
                                actor_opt.learning_rate(),
                                critic_opt.learning_rate(),
                                policy_updates
                            );
                        }
                    }
                }

                // Target updates (Polyak averaging, no grad)
                actor_targ.net.soft_update_from(&actor.net, tau);
                critic1_targ.net.soft_update_from(&critic1.net, tau);
                critic2_targ.net.soft_update_from(&critic2.net, tau);
            }

            // Clear entire graphs to avoid stale accumulation across iterations
            clear_all_graphs_known();
        }
    }

    println!("=== TD3 training finished ===");
    Ok(())
}
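
// A minimal sanity-test sketch for the pure-Rust pieces above (RNG, replay
// buffer, environment). It assumes only the Tensor calls already used in this
// file (`data`, `from_slice`); run with `cargo test --example td3` if the
// project is configured to build example tests.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rng_uniform_stays_in_range() {
        let mut rng = SmallRng::new(42);
        for _ in 0..1_000 {
            let v = rng.uniform(-1.0, 1.0);
            assert!((-1.0..=1.0).contains(&v));
        }
    }

    #[test]
    fn replay_buffer_wraps_and_samples() {
        let mut rb = ReplayBuffer::new(4, 3, 1);
        for i in 0..6 {
            let x = i as f32;
            rb.push(&[x, x, x], &[x], x, 0.0, &[x, x, x]);
        }
        assert_eq!(rb.size, 4); // size is capped at capacity
        assert!(rb.can_sample(4));
        let mut rng = SmallRng::new(7);
        let (s, a, _r, _d, _s2) = rb.sample(2, &mut rng);
        assert_eq!(s.data().len(), 2 * 3); // [2, state_dim]
        assert_eq!(a.data().len(), 2); // [2, action_dim]
    }

    #[test]
    fn yard_env_terminates_within_max_steps() {
        let mut env = YardEnv::new(1);
        for _ in 0..200 {
            let (_s, _r, done) = env.step(1.0);
            if done {
                return;
            }
        }
        panic!("episode did not terminate within max_steps");
    }
}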