dqn/dqn.rs

//! DQN (Deep Q-Network) - Minimal example using the Train Station public API
//!
//! - Discrete `YardEnv` (3 actions: -1, 0, +1)
//! - Experience replay + Double DQN targets + target network hard updates
//! - Gradient clipping, zero_grad, clear_all_graphs_known between steps
//! - Reuses `basic_linear_layer.rs` for a small MLP
//!
//! Run:
//!   cargo run --release --example dqn
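//!   DQN_STEPS=5000 cargo run --release --example dqn   (override total training steps)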

use train_station::{
    gradtrack::{clear_all_graphs_known, NoGradTrack},
    optimizers::{Adam, Optimizer},
    Tensor,
};

// Reuse the simple LinearLayer to build a tiny MLP
#[allow(clippy::duplicate_mod)]
#[path = "../neural_networks/basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

// -------------------------------
// Utilities
// -------------------------------

// Simple LCG RNG (no external deps)
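// The multiplier/increment pair (1664525, 1013904223) is the classic Numerical Recipes
// LCG; fine for a reproducible example, not for statistical work.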
struct SmallRng {
    state: u64,
}

impl SmallRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }
    fn next_u32(&mut self) -> u32 {
        self.state = self.state.wrapping_mul(1664525).wrapping_add(1013904223);
        (self.state >> 16) as u32
    }
    fn next_f32(&mut self) -> f32 {
        (self.next_u32() as f32) / (u32::MAX as f32)
    }
    fn uniform(&mut self, low: f32, high: f32) -> f32 {
        low + (high - low) * self.next_f32()
    }
    fn sample_index(&mut self, upper_exclusive: usize) -> usize {
        (self.next_u32() as usize) % upper_exclusive.max(1)
    }
}

// -------------------------------
// Tiny MLP builder on LinearLayer
// -------------------------------

struct Mlp {
    layers: Vec<LinearLayer>,
}

impl Mlp {
    fn new(sizes: &[usize], seed: Option<u64>) -> Self {
        assert!(sizes.len() >= 2);
        let mut layers = Vec::new();
        let mut s = seed;
        for w in sizes.windows(2) {
            layers.push(LinearLayer::new(w[0], w[1], s));
            s = s.map(|v| v + 1);
        }
        Self { layers }
    }

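    // Hidden layers use ReLU; the final layer stays linear unless `final_activation` is given.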
    fn forward(&self, input: &Tensor, final_activation: Option<fn(&Tensor) -> Tensor>) -> Tensor {
        let mut current: Option<Tensor> = None;
        for (i, layer) in self.layers.iter().enumerate() {
            let out = if i == 0 {
                layer.forward(input)
            } else {
                layer.forward(current.as_ref().unwrap())
            };
            let is_last = i + 1 == self.layers.len();
            let out = if !is_last {
                out.relu()
            } else if let Some(act) = final_activation {
                act(&out)
            } else {
                out
            };
            current = Some(out);
        }
        current.expect("MLP has at least one layer")
    }

    fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        for l in &mut self.layers {
            params.extend(l.parameters());
        }
        params
    }

    fn set_requires_grad_all(&mut self, enable: bool) {
        for l in &mut self.layers {
            l.weight.set_requires_grad(enable);
            l.bias.set_requires_grad(enable);
        }
    }

    // In-place copy (preserve tensor IDs and optimizer links on targets)
    fn copy_from(&mut self, other: &Self) {
        for (t, s) in self.layers.iter_mut().zip(other.layers.iter()) {
            {
                let src = s.weight.data();
                let dst = t.weight.data_mut();
                dst.copy_from_slice(src);
            }
            {
                let src = s.bias.data();
                let dst = t.bias.data_mut();
                dst.copy_from_slice(src);
            }
            t.weight.set_requires_grad(false);
            t.bias.set_requires_grad(false);
        }
    }
}

// -------------------------------
// Q-Network (state -> Q-values over actions)
// -------------------------------

struct QNet {
    net: Mlp,
}

impl QNet {
    fn new(state_dim: usize, action_dim: usize, seed: Option<u64>) -> Self {
        let net = Mlp::new(&[state_dim, 64, 64, action_dim], seed);
        Self { net }
    }
    fn forward(&self, state: &Tensor) -> Tensor {
        self.net.forward(state, None)
    }
    fn parameters(&mut self) -> Vec<&mut Tensor> {
        self.net.parameters()
    }
    fn set_requires_grad_all(&mut self, enable: bool) {
        self.net.set_requires_grad_all(enable);
    }
}

// -------------------------------
// Discrete YardEnv (3 actions: -1, 0, +1)
// -------------------------------

struct YardEnv {
    pos: f32,
    vel: f32,
    steps: usize,
    max_steps: usize,
    rng: SmallRng,
}

impl YardEnv {
    const ACTIONS: [f32; 3] = [-1.0, 0.0, 1.0];

    fn new(seed: u64) -> Self {
        let mut env = Self {
            pos: 0.0,
            vel: 0.0,
            steps: 0,
            max_steps: 200,
            rng: SmallRng::new(seed),
        };
        env.reset();
        env
    }

    fn reset(&mut self) -> Tensor {
        self.pos = self.rng.uniform(-0.5, 0.5);
        self.vel = self.rng.uniform(-0.1, 0.1);
        self.steps = 0;
        self.state_tensor()
    }

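    // State layout: [pos, vel, 0.0]; the third slot is a constant placeholder feature.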
    fn state_tensor(&self) -> Tensor {
        Tensor::from_slice(&[self.pos, self.vel, 0.0], vec![1, 3]).unwrap()
    }

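    // Dynamics: the action nudges velocity while -0.01 * pos acts as a weak spring toward
    // the origin; reward penalizes squared distance from the origin plus a small action cost.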
    fn step(&mut self, action_index: usize) -> (Tensor, f32, bool) {
        let a = Self::ACTIONS[action_index.min(2)];
        self.vel += 0.1 * a - 0.01 * self.pos;
        self.pos += self.vel;
        self.steps += 1;
        let reward = -(self.pos * self.pos) - 0.05 * (a * a);
        let done = self.pos.abs() > 3.0 || self.steps >= self.max_steps;
        (self.state_tensor(), reward, done)
    }
}

// -------------------------------
// Replay Buffer
// -------------------------------

struct ReplayBuffer {
    capacity: usize,
    size: usize,
    pos: usize,
    state_dim: usize,
    states: Vec<f32>,
    actions: Vec<usize>,
    rewards: Vec<f32>,
    dones: Vec<f32>,
    next_states: Vec<f32>,
}

impl ReplayBuffer {
    fn new(capacity: usize, state_dim: usize) -> Self {
        Self {
            capacity,
            size: 0,
            pos: 0,
            state_dim,
            states: vec![0.0; capacity * state_dim],
            actions: vec![0usize; capacity],
            rewards: vec![0.0; capacity],
            dones: vec![0.0; capacity],
            next_states: vec![0.0; capacity * state_dim],
        }
    }

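    // Ring-buffer write: once full, the oldest transition is overwritten in place.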
    fn push(&mut self, s: &[f32], a_idx: usize, r: f32, d: f32, s2: &[f32]) {
        let i = self.pos;
        let so = i * self.state_dim;
        self.states[so..so + self.state_dim].copy_from_slice(s);
        self.actions[i] = a_idx;
        self.rewards[i] = r;
        self.dones[i] = d;
        self.next_states[so..so + self.state_dim].copy_from_slice(s2);
        self.pos = (self.pos + 1) % self.capacity;
        self.size = self.size.saturating_add(1).min(self.capacity);
    }

    fn can_sample(&self, batch_size: usize) -> bool {
        self.size >= batch_size
    }

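    // Uniform sampling with replacement over the filled portion of the buffer.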
    fn sample(
        &self,
        batch_size: usize,
        rng: &mut SmallRng,
    ) -> (Tensor, Vec<usize>, Tensor, Tensor, Tensor) {
        let mut s_vec = Vec::with_capacity(batch_size * self.state_dim);
        let mut a_idx = Vec::with_capacity(batch_size);
        let mut r_vec = Vec::with_capacity(batch_size);
        let mut d_vec = Vec::with_capacity(batch_size);
        let mut s2_vec = Vec::with_capacity(batch_size * self.state_dim);
        for _ in 0..batch_size {
            let idx = rng.sample_index(self.size);
            let so = idx * self.state_dim;
            s_vec.extend_from_slice(&self.states[so..so + self.state_dim]);
            a_idx.push(self.actions[idx]);
            r_vec.push(self.rewards[idx]);
            d_vec.push(self.dones[idx]);
            s2_vec.extend_from_slice(&self.next_states[so..so + self.state_dim]);
        }
        let s = Tensor::from_slice(&s_vec, vec![batch_size, self.state_dim]).unwrap();
        let r = Tensor::from_slice(&r_vec, vec![batch_size, 1]).unwrap();
        let d = Tensor::from_slice(&d_vec, vec![batch_size, 1]).unwrap();
        let s2 = Tensor::from_slice(&s2_vec, vec![batch_size, self.state_dim]).unwrap();
        (s, a_idx, r, d, s2)
    }
}
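
// Minimal sanity check for the ring-buffer bookkeeping (a sketch; examples are not
// tested by default, but `cargo test --example dqn` can opt in).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn replay_buffer_wraps_and_caps_size() {
        let mut rb = ReplayBuffer::new(2, 3);
        for i in 0..5 {
            let s = [i as f32; 3];
            rb.push(&s, 0, 0.0, 0.0, &s);
        }
        assert_eq!(rb.size, 2); // size is capped at capacity
        assert_eq!(rb.pos, 1); // write head wrapped: 5 % 2 == 1
    }
}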

// -------------------------------
// Helpers
// -------------------------------

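// Global-norm clipping: if the combined L2 norm of all grads exceeds max_norm,
// every grad is rescaled by max_norm / (norm + eps).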
fn clip_gradients(parameters: &mut [&mut Tensor], max_norm: f32, eps: f32) {
    let mut total_sq = 0.0f32;
    for p in parameters.iter() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    let norm = total_sq.sqrt();
    if norm > max_norm {
        let scale = max_norm / (norm + eps);
        for p in parameters.iter_mut() {
            if let Some(g) = p.grad_owned() {
                p.set_grad(g.mul_scalar(scale));
            }
        }
    }
}

fn grad_global_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        if let Some(g) = p.grad_owned() {
            for &v in g.data() {
                total_sq += v * v;
            }
        }
    }
    total_sq.sqrt()
}

fn params_l2_norm(parameters: &mut [&mut Tensor]) -> f32 {
    let _ng = NoGradTrack::new();
    let mut total_sq = 0.0f32;
    for p in parameters.iter_mut() {
        for &v in p.data() {
            total_sq += v * v;
        }
    }
    total_sq.sqrt()
}

// Pseudo-Huber loss: sqrt(1 + diff^2) - 1 (smooth, robust)
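// Quadratic (≈ 0.5 * diff^2) near zero, linear (≈ |diff|) for large errors.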
fn pseudo_huber_mean(diff: &Tensor) -> Tensor {
    diff.pow_scalar(2.0)
        .add_scalar(1.0)
        .sqrt()
        .sub_scalar(1.0)
        .mean()
}

// -------------------------------
// Main
// -------------------------------

pub fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== DQN Example (YardEnv discrete) ===");

    // Dims
    let state_dim = 3usize;
    let action_dim = 3usize;

    // Hparams
    let gamma = 0.99f32;
    let batch_size = 64usize;
    let start_steps = 200usize;
    let target_update_interval = 200usize; // hard update cadence
    let max_grad_norm = 1.0f32;
    let mut epsilon = 1.0f32;
    let eps_min = 0.05f32;
    let eps_decay_steps = 2_000usize; // linear decay
    let total_steps = std::env::var("DQN_STEPS")
        .ok()
        .and_then(|v| v.parse::<usize>().ok())
        .unwrap_or(3000usize);

    // Models
    let mut q_net = QNet::new(state_dim, action_dim, Some(7));
    let mut q_targ = QNet::new(state_dim, action_dim, Some(8));
    q_targ.net.copy_from(&q_net.net);
    q_targ.set_requires_grad_all(false);

    // Optimizer
    let mut q_opt = Adam::with_learning_rate(3e-4);
    for p in q_net.parameters() {
        q_opt.add_parameter(p);
    }

    // Replay + env
    let mut rb = ReplayBuffer::new(100_000, state_dim);
    let mut env = YardEnv::new(12345);
    let mut rng = SmallRng::new(999_111);

    // Metrics
    let mut state = env.reset();
    let mut episode_return = 0.0f32;
    let mut episode = 0usize;
    let mut ema_return: Option<f32> = None;
    let ema_alpha = 0.05f32;
    let mut best_return = f32::NEG_INFINITY;

    for t in 0..total_steps {
        // Epsilon-greedy action
        let action_index = if t < start_steps || rng.next_f32() < epsilon {
            rng.sample_index(action_dim)
        } else {
            let _ng = NoGradTrack::new();
            let q_vals = q_net.forward(&state);
            let row = q_vals.data();
            let mut best_i = 0usize;
            let mut best_v = row[0];
            for (i, &r) in row.iter().enumerate().take(action_dim).skip(1) {
                if r > best_v {
                    best_v = r;
                    best_i = i;
                }
            }
            best_i
        };

        // Env step
        let (next_state, reward, done) = env.step(action_index);
        episode_return += reward;

        // Store
        let s_slice = state.data().to_vec();
        let s2_slice = next_state.data().to_vec();
        rb.push(
            &s_slice,
            action_index,
            reward,
            if done { 1.0 } else { 0.0 },
            &s2_slice,
        );

        // Reset on done
        state = if done {
            let st = env.reset();
            ema_return = Some(match ema_return {
                None => episode_return,
                Some(prev) => prev * (1.0 - ema_alpha) + ema_alpha * episode_return,
            });
            if episode_return > best_return {
                best_return = episode_return;
            }
            println!(
                "step {:5} | episode {:4} return={:.3} ema={:.3} best={:.3} | rb_size={}",
                t,
                episode,
                episode_return,
                ema_return.unwrap_or(episode_return),
                best_return,
                rb.size
            );
            episode_return = 0.0;
            episode += 1;
            st
        } else {
            next_state
        };

        // Epsilon linear decay
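        // Anneals linearly toward eps_min over the first eps_decay_steps steps, then stays flat.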
        if t < eps_decay_steps {
            epsilon = (1.0 - (t as f32) / (eps_decay_steps as f32)) * (1.0 - eps_min) + eps_min;
        }

        // Train
        if rb.can_sample(batch_size) {
            let (s, a_idx, r, d, s2) = rb.sample(batch_size, &mut rng);

            // Double DQN target: a* = argmax_a Q_online(s2, a); y = r + (1 - d) * gamma * Q_target(s2, a*)
            let target_q = {
                let _ng = NoGradTrack::new();
                let q_online_s2 = q_net.forward(&s2);
                // argmax per row (manual on CPU)
                let row_stride = action_dim;
                let qd = q_online_s2.data();
                let mut next_actions: Vec<usize> = Vec::with_capacity(batch_size);
                for i in 0..batch_size {
                    let base = i * row_stride;
                    let mut bi = 0usize;
                    let mut bv = qd[base];
                    for j in 1..action_dim {
                        let v = qd[base + j];
                        if v > bv {
                            bv = v;
                            bi = j;
                        }
                    }
                    next_actions.push(bi);
                }
                let q_targ_s2 = q_targ.forward(&s2);
                let q_targ_g = q_targ_s2.gather(1, &next_actions, &[batch_size, 1]);
                let not_done = Tensor::ones(vec![batch_size, 1]).sub_tensor(&d);
                r.add_tensor(&not_done.mul_scalar(gamma).mul_tensor(&q_targ_g))
            };

            // Q(s,a) for the current actions; zero grads before the fresh backward pass
            {
                let mut params = q_net.parameters();
                q_opt.zero_grad(&mut params);
            }

            let q_all = q_net.forward(&s);
            let q_sa = q_all.gather(1, &a_idx, &[batch_size, 1]);
            let diff = q_sa.sub_tensor(&target_q);
            let mut loss = pseudo_huber_mean(&diff);
            loss.backward(None);

            // Step (only on params that actually received grads)
            {
                let params = q_net.parameters();
                let mut with_grads: Vec<&mut Tensor> = Vec::new();
                for p in params {
                    if p.grad_owned().is_some() {
                        with_grads.push(p);
                    }
                }
                if !with_grads.is_empty() {
                    let gn = grad_global_norm(&mut with_grads);
                    clip_gradients(&mut with_grads, max_grad_norm, 1e-6);
                    q_opt.step(&mut with_grads);
                    q_opt.zero_grad(&mut with_grads);
                    if t % 100 == 0 {
                        let mut pn = q_net.parameters();
                        let pn_l2 = params_l2_norm(&mut pn);
                        let q_mean = q_all.mean().value();
                        println!(
                            "t={:5} | loss={:.4} | q_mean={:.3} | grad_norm={:.3} | param_norm={:.3} | eps={:.3}",
                            t, loss.value(), q_mean, gn, pn_l2, epsilon
                        );
                    }
                }
            }

            // Target network hard update
            if t % target_update_interval == 0 {
                q_targ.net.copy_from(&q_net.net);
            }

            // Clear autograd graphs accumulated this step
            clear_all_graphs_known();
        }
    }

    println!("=== DQN training finished ===");
    Ok(())
}