Skip to main content

scirs2_optimize/mdp/
tabular.rs

1//! Tabular MDP algorithms: value/policy iteration, Q-learning, SARSA.
2//!
3//! All algorithms operate on finite MDPs represented by explicit transition and reward matrices.
4
5use crate::error::OptimizeError;
6use scirs2_core::ndarray::{Array2, Array3};
7
8// ─────────────────────────────────────────────────────────────────────────────
9// MDP definition
10// ─────────────────────────────────────────────────────────────────────────────
11
12/// A finite Markov Decision Process.
13///
14/// Transition probabilities: `T[s, a, s'] = P(s' | s, a)`.  
15/// Rewards: `R[s, a, s']` (triple-index form; use [`Mdp::with_state_action_reward`] for 2-D rewards).
16#[derive(Debug, Clone)]
17pub struct Mdp {
18    /// Number of states.
19    pub n_states: usize,
20    /// Number of actions.
21    pub n_actions: usize,
22    /// Transition tensor `(n_states × n_actions × n_states)`.
23    pub transition: Array3<f64>,
24    /// Reward tensor `(n_states × n_actions × n_states)`.
25    pub reward: Array3<f64>,
26    /// Discount factor γ ∈ [0, 1).
27    pub gamma: f64,
28    /// Optional absorbing / terminal states (no transitions away).
29    pub terminal_states: Vec<usize>,
30}
31
32impl Mdp {
33    /// Create a new MDP and validate it.
34    pub fn new(
35        n_states: usize,
36        n_actions: usize,
37        transition: Array3<f64>,
38        reward: Array3<f64>,
39        gamma: f64,
40    ) -> Result<Self, OptimizeError> {
41        if n_states == 0 {
42            return Err(OptimizeError::ValueError(
43                "n_states must be > 0".to_string(),
44            ));
45        }
46        if n_actions == 0 {
47            return Err(OptimizeError::ValueError(
48                "n_actions must be > 0".to_string(),
49            ));
50        }
51        if transition.shape() != [n_states, n_actions, n_states] {
52            return Err(OptimizeError::ValueError(format!(
53                "transition shape {:?} != [{}, {}, {}]",
54                transition.shape(),
55                n_states,
56                n_actions,
57                n_states
58            )));
59        }
60        if reward.shape() != [n_states, n_actions, n_states] {
61            return Err(OptimizeError::ValueError(format!(
62                "reward shape {:?} != [{}, {}, {}]",
63                reward.shape(),
64                n_states,
65                n_actions,
66                n_states
67            )));
68        }
69        if !(0.0..=1.0).contains(&gamma) {
70            return Err(OptimizeError::ValueError(format!(
71                "gamma {} must be in [0, 1]",
72                gamma
73            )));
74        }
75        let mdp = Self {
76            n_states,
77            n_actions,
78            transition,
79            reward,
80            gamma,
81            terminal_states: Vec::new(),
82        };
83        mdp.validate()?;
84        Ok(mdp)
85    }
86
87    /// Validate that transition probabilities sum to 1 for each (s, a).
88    pub fn validate(&self) -> Result<(), OptimizeError> {
89        for s in 0..self.n_states {
90            for a in 0..self.n_actions {
91                let sum: f64 = (0..self.n_states)
92                    .map(|sp| self.transition[[s, a, sp]])
93                    .sum();
94                if (sum - 1.0).abs() > 1e-6 {
95                    return Err(OptimizeError::ValueError(format!(
96                        "Transition probabilities for state {} action {} sum to {} (expected 1)",
97                        s, a, sum
98                    )));
99                }
100                // Ensure non-negative
101                for sp in 0..self.n_states {
102                    let p = self.transition[[s, a, sp]];
103                    if p < -1e-10 {
104                        return Err(OptimizeError::ValueError(format!(
105                            "Negative transition probability T[{},{},{}] = {}",
106                            s, a, sp, p
107                        )));
108                    }
109                }
110            }
111        }
112        Ok(())
113    }
114
115    /// Expected reward R(s,a) = Σ_{s'} T(s,a,s') · R(s,a,s').
116    pub fn expected_reward(&self) -> Array2<f64> {
117        let mut r = Array2::<f64>::zeros((self.n_states, self.n_actions));
118        for s in 0..self.n_states {
119            for a in 0..self.n_actions {
120                let val: f64 = (0..self.n_states)
121                    .map(|sp| self.transition[[s, a, sp]] * self.reward[[s, a, sp]])
122                    .sum();
123                r[[s, a]] = val;
124            }
125        }
126        r
127    }
128
129    /// Build an MDP from a 2-D reward matrix (state × action) by broadcasting to 3-D.
130    pub fn with_state_action_reward(
131        n_states: usize,
132        n_actions: usize,
133        transition: Array3<f64>,
134        reward: Array2<f64>,
135        gamma: f64,
136    ) -> Result<Self, OptimizeError> {
137        if reward.shape() != [n_states, n_actions] {
138            return Err(OptimizeError::ValueError(format!(
139                "reward shape {:?} != [{}, {}]",
140                reward.shape(),
141                n_states,
142                n_actions
143            )));
144        }
145        // Broadcast: R[s, a, s'] = reward[s, a] for all s'
146        let mut r3 = Array3::<f64>::zeros((n_states, n_actions, n_states));
147        for s in 0..n_states {
148            for a in 0..n_actions {
149                for sp in 0..n_states {
150                    r3[[s, a, sp]] = reward[[s, a]];
151                }
152            }
153        }
154        Self::new(n_states, n_actions, transition, r3, gamma)
155    }
156
157    /// Compute the Bellman backup Q(s,a) = R(s,a) + γ Σ_{s'} T(s,a,s') V(s').
158    fn q_values(&self, v: &[f64], r: &Array2<f64>) -> Array2<f64> {
159        let mut q = Array2::<f64>::zeros((self.n_states, self.n_actions));
160        for s in 0..self.n_states {
161            for a in 0..self.n_actions {
162                let future: f64 = (0..self.n_states)
163                    .map(|sp| self.transition[[s, a, sp]] * v[sp])
164                    .sum();
165                q[[s, a]] = r[[s, a]] + self.gamma * future;
166            }
167        }
168        q
169    }
170}
171
172// ─────────────────────────────────────────────────────────────────────────────
173// Solution container
174// ─────────────────────────────────────────────────────────────────────────────
175
/// Result bundle produced by the MDP solvers in this module.
#[derive(Debug, Clone)]
pub struct MdpSolution {
    /// Value estimate per state (V*(s) once the solver has converged).
    pub value_function: Vec<f64>,
    /// Greedy policy: entry `s` is the action index chosen in state `s`.
    pub policy: Vec<usize>,
    /// How many outer iterations the solver ran.
    pub n_iterations: usize,
    /// `true` when the solver met its stopping criterion before the
    /// iteration cap.
    pub converged: bool,
    /// Last measured maximum Bellman residual |V_{k+1} - V_k|_∞.
    pub max_diff: f64,
}
190
191// ─────────────────────────────────────────────────────────────────────────────
192// Value Iteration
193// ─────────────────────────────────────────────────────────────────────────────
194
195/// Value Iteration.
196///
197/// Performs Bellman optimality backups until convergence:
198/// `V_{k+1}(s) = max_a [ R(s,a) + γ Σ_{s'} T(s,a,s') V_k(s') ]`
199///
200/// Convergence guarantee: terminates when `‖V_{k+1} − V_k‖_∞ < tol`.
201pub fn value_iteration(mdp: &Mdp, tol: f64, max_iter: usize) -> Result<MdpSolution, OptimizeError> {
202    if tol <= 0.0 {
203        return Err(OptimizeError::ValueError(
204            "tol must be positive".to_string(),
205        ));
206    }
207    let r = mdp.expected_reward();
208    let mut v = vec![0.0_f64; mdp.n_states];
209    let mut policy = vec![0usize; mdp.n_states];
210    let mut max_diff = f64::INFINITY;
211
212    for iter in 0..max_iter {
213        let q = mdp.q_values(&v, &r);
214        max_diff = 0.0_f64;
215        for s in 0..mdp.n_states {
216            let best_a = (0..mdp.n_actions)
217                .max_by(|&a1, &a2| {
218                    q[[s, a1]]
219                        .partial_cmp(&q[[s, a2]])
220                        .unwrap_or(std::cmp::Ordering::Equal)
221                })
222                .unwrap_or(0);
223            let new_v = q[[s, best_a]];
224            let diff = (new_v - v[s]).abs();
225            if diff > max_diff {
226                max_diff = diff;
227            }
228            v[s] = new_v;
229            policy[s] = best_a;
230        }
231        // Apply terminal state overrides: V(terminal) = 0, policy unchanged
232        for &ts in &mdp.terminal_states {
233            if ts < mdp.n_states {
234                v[ts] = 0.0;
235            }
236        }
237        if max_diff < tol {
238            return Ok(MdpSolution {
239                value_function: v,
240                policy,
241                n_iterations: iter + 1,
242                converged: true,
243                max_diff,
244            });
245        }
246    }
247
248    Ok(MdpSolution {
249        value_function: v,
250        policy,
251        n_iterations: max_iter,
252        converged: false,
253        max_diff,
254    })
255}
256
257// ─────────────────────────────────────────────────────────────────────────────
258// Policy Evaluation (iterative)
259// ─────────────────────────────────────────────────────────────────────────────
260
261/// Evaluate a fixed deterministic policy iteratively.
262///
263/// Solves `V^π(s) = R(s,π(s)) + γ Σ_{s'} T(s,π(s),s') V^π(s')` by repeated substitution.
264pub fn evaluate_policy(
265    mdp: &Mdp,
266    policy: &[usize],
267    tol: f64,
268    max_iter: usize,
269) -> Result<Vec<f64>, OptimizeError> {
270    if policy.len() != mdp.n_states {
271        return Err(OptimizeError::ValueError(format!(
272            "policy length {} != n_states {}",
273            policy.len(),
274            mdp.n_states
275        )));
276    }
277    for (s, &a) in policy.iter().enumerate() {
278        if a >= mdp.n_actions {
279            return Err(OptimizeError::ValueError(format!(
280                "policy[{}] = {} >= n_actions {}",
281                s, a, mdp.n_actions
282            )));
283        }
284    }
285    let r = mdp.expected_reward();
286    let mut v = vec![0.0_f64; mdp.n_states];
287
288    for _ in 0..max_iter {
289        let mut max_diff = 0.0_f64;
290        for s in 0..mdp.n_states {
291            let a = policy[s];
292            let future: f64 = (0..mdp.n_states)
293                .map(|sp| mdp.transition[[s, a, sp]] * v[sp])
294                .sum();
295            let new_v = r[[s, a]] + mdp.gamma * future;
296            let diff = (new_v - v[s]).abs();
297            if diff > max_diff {
298                max_diff = diff;
299            }
300            v[s] = new_v;
301        }
302        // Zero out terminal states
303        for &ts in &mdp.terminal_states {
304            if ts < mdp.n_states {
305                v[ts] = 0.0;
306            }
307        }
308        if max_diff < tol {
309            return Ok(v);
310        }
311    }
312    Ok(v)
313}
314
315// ─────────────────────────────────────────────────────────────────────────────
316// Policy Iteration
317// ─────────────────────────────────────────────────────────────────────────────
318
319/// Policy Iteration.
320///
321/// Alternates between full policy evaluation and greedy policy improvement
322/// until the policy is stable.
323pub fn policy_iteration(
324    mdp: &Mdp,
325    tol: f64,
326    max_iter: usize,
327) -> Result<MdpSolution, OptimizeError> {
328    if tol <= 0.0 {
329        return Err(OptimizeError::ValueError(
330            "tol must be positive".to_string(),
331        ));
332    }
333    let r = mdp.expected_reward();
334    let mut policy: Vec<usize> = vec![0; mdp.n_states];
335    let mut v = vec![0.0_f64; mdp.n_states];
336
337    for iter in 0..max_iter {
338        // Policy evaluation
339        v = evaluate_policy(mdp, &policy, tol * 1e-3, max_iter)?;
340
341        // Policy improvement
342        let q = mdp.q_values(&v, &r);
343        let mut stable = true;
344        for s in 0..mdp.n_states {
345            let best_a = (0..mdp.n_actions)
346                .max_by(|&a1, &a2| {
347                    q[[s, a1]]
348                        .partial_cmp(&q[[s, a2]])
349                        .unwrap_or(std::cmp::Ordering::Equal)
350                })
351                .unwrap_or(0);
352            if best_a != policy[s] {
353                stable = false;
354                policy[s] = best_a;
355            }
356        }
357
358        if stable {
359            // Final max_diff: Bellman residual of converged value function
360            let q_final = mdp.q_values(&v, &r);
361            let max_diff = (0..mdp.n_states)
362                .map(|s| {
363                    let best = (0..mdp.n_actions)
364                        .map(|a| q_final[[s, a]])
365                        .fold(f64::NEG_INFINITY, f64::max);
366                    (best - v[s]).abs()
367                })
368                .fold(0.0_f64, f64::max);
369            return Ok(MdpSolution {
370                value_function: v,
371                policy,
372                n_iterations: iter + 1,
373                converged: true,
374                max_diff,
375            });
376        }
377    }
378
379    let max_diff = compute_bellman_residual(mdp, &v, &r);
380    Ok(MdpSolution {
381        value_function: v,
382        policy,
383        n_iterations: max_iter,
384        converged: false,
385        max_diff,
386    })
387}
388
389// ─────────────────────────────────────────────────────────────────────────────
390// Modified Policy Iteration
391// ─────────────────────────────────────────────────────────────────────────────
392
393/// Modified Policy Iteration (k-step partial evaluation).
394///
395/// Each iteration applies `k` Bellman updates under the current policy instead
396/// of running full evaluation to convergence.  k=1 recovers value iteration;
397/// k→∞ recovers standard policy iteration.
398pub fn modified_policy_iteration(
399    mdp: &Mdp,
400    k: usize,
401    tol: f64,
402    max_iter: usize,
403) -> Result<MdpSolution, OptimizeError> {
404    if tol <= 0.0 {
405        return Err(OptimizeError::ValueError(
406            "tol must be positive".to_string(),
407        ));
408    }
409    if k == 0 {
410        return Err(OptimizeError::ValueError("k must be >= 1".to_string()));
411    }
412    let r = mdp.expected_reward();
413    let mut v = vec![0.0_f64; mdp.n_states];
414    let mut policy = vec![0usize; mdp.n_states];
415    let mut max_diff = f64::INFINITY;
416
417    for iter in 0..max_iter {
418        // Policy improvement step
419        let q = mdp.q_values(&v, &r);
420        for s in 0..mdp.n_states {
421            policy[s] = (0..mdp.n_actions)
422                .max_by(|&a1, &a2| {
423                    q[[s, a1]]
424                        .partial_cmp(&q[[s, a2]])
425                        .unwrap_or(std::cmp::Ordering::Equal)
426                })
427                .unwrap_or(0);
428        }
429
430        // k partial evaluation steps under current policy
431        max_diff = 0.0_f64;
432        for _ in 0..k {
433            let mut iter_diff = 0.0_f64;
434            for s in 0..mdp.n_states {
435                let a = policy[s];
436                let future: f64 = (0..mdp.n_states)
437                    .map(|sp| mdp.transition[[s, a, sp]] * v[sp])
438                    .sum();
439                let new_v = r[[s, a]] + mdp.gamma * future;
440                let diff = (new_v - v[s]).abs();
441                if diff > iter_diff {
442                    iter_diff = diff;
443                }
444                v[s] = new_v;
445            }
446            for &ts in &mdp.terminal_states {
447                if ts < mdp.n_states {
448                    v[ts] = 0.0;
449                }
450            }
451            if iter_diff > max_diff {
452                max_diff = iter_diff;
453            }
454        }
455
456        if max_diff < tol {
457            return Ok(MdpSolution {
458                value_function: v,
459                policy,
460                n_iterations: iter + 1,
461                converged: true,
462                max_diff,
463            });
464        }
465    }
466
467    Ok(MdpSolution {
468        value_function: v,
469        policy,
470        n_iterations: max_iter,
471        converged: false,
472        max_diff,
473    })
474}
475
476// ─────────────────────────────────────────────────────────────────────────────
477// LP-based MDP solver
478// ─────────────────────────────────────────────────────────────────────────────
479
480/// Solve an MDP via its Linear Programming formulation.
481///
482/// The LP dual minimises `Σ_s V(s)` subject to
483/// `V(s) ≥ R(s,a) + γ Σ_{s'} T(s,a,s') V(s')  ∀ s, a`.
484///
485/// We solve this via a projected value-iteration initialised from above, which
486/// is equivalent to the LP optimum for discounted MDPs (see Puterman 1994 §6.9).
487/// For exact LP we iterate with tighter convergence to emulate LP precision.
488pub fn lp_solve_mdp(mdp: &Mdp) -> Result<MdpSolution, OptimizeError> {
489    // Use high-precision value iteration as LP equivalent for discounted MDPs.
490    // The LP and VI have the same unique fixed point for γ < 1.
491    value_iteration(mdp, 1e-12, 100_000)
492}
493
494// ─────────────────────────────────────────────────────────────────────────────
495// Q-Learning
496// ─────────────────────────────────────────────────────────────────────────────
497
498/// Tabular Q-learning agent (model-free, off-policy TD).
499#[derive(Debug, Clone)]
500pub struct QLearning {
501    /// Q-value table `(n_states × n_actions)`.
502    pub q_table: Array2<f64>,
503    /// Learning rate α ∈ (0, 1].
504    pub alpha: f64,
505    /// ε-greedy exploration probability.
506    pub epsilon: f64,
507    /// Discount factor γ.
508    pub gamma: f64,
509}
510
511impl QLearning {
512    /// Create a new Q-learning agent with zero-initialised Q-table.
513    pub fn new(n_states: usize, n_actions: usize, alpha: f64, epsilon: f64, gamma: f64) -> Self {
514        Self {
515            q_table: Array2::<f64>::zeros((n_states, n_actions)),
516            alpha,
517            epsilon,
518            gamma,
519        }
520    }
521
522    /// Apply a single Q-learning update.
523    ///
524    /// `Q(s,a) ← Q(s,a) + α [ r + γ max_{a'} Q(s',a') − Q(s,a) ]`
525    pub fn update(&mut self, state: usize, action: usize, reward: f64, next_state: usize) {
526        let n_actions = self.q_table.ncols();
527        let max_next = (0..n_actions)
528            .map(|a| self.q_table[[next_state, a]])
529            .fold(f64::NEG_INFINITY, f64::max);
530        let td_error = reward + self.gamma * max_next - self.q_table[[state, action]];
531        self.q_table[[state, action]] += self.alpha * td_error;
532    }
533
534    /// Select an action via ε-greedy policy (deterministic given `rng_seed`).
535    pub fn epsilon_greedy(&self, state: usize, rng_seed: u64) -> usize {
536        let rng_val = lcg_uniform(rng_seed);
537        if rng_val < self.epsilon {
538            // Random action
539            let n_actions = self.q_table.ncols();
540            lcg_index(rng_seed.wrapping_add(1), n_actions)
541        } else {
542            self.greedy(state)
543        }
544    }
545
546    /// Select the greedy action (no exploration).
547    pub fn greedy(&self, state: usize) -> usize {
548        let n_actions = self.q_table.ncols();
549        (0..n_actions)
550            .max_by(|&a1, &a2| {
551                self.q_table[[state, a1]]
552                    .partial_cmp(&self.q_table[[state, a2]])
553                    .unwrap_or(std::cmp::Ordering::Equal)
554            })
555            .unwrap_or(0)
556    }
557
558    /// Train Q-learning on a known MDP for `n_episodes` episodes.
559    ///
560    /// Returns episode discounted returns.
561    pub fn train(
562        &mut self,
563        mdp: &Mdp,
564        n_episodes: usize,
565        max_steps_per_episode: usize,
566        seed: u64,
567    ) -> Result<Vec<f64>, OptimizeError> {
568        let n_states = self.q_table.nrows();
569        if n_states != mdp.n_states {
570            return Err(OptimizeError::ValueError(format!(
571                "Q-table n_states {} != mdp.n_states {}",
572                n_states, mdp.n_states
573            )));
574        }
575        let r = mdp.expected_reward();
576        let mut returns = Vec::with_capacity(n_episodes);
577        let mut rng = seed;
578
579        for ep in 0..n_episodes {
580            // Start from random non-terminal state
581            let mut state = lcg_index(rng, mdp.n_states);
582            rng = lcg_next(rng);
583            // Avoid starting in terminal states
584            let terminal_set: std::collections::HashSet<usize> =
585                mdp.terminal_states.iter().copied().collect();
586            if !terminal_set.is_empty() {
587                let non_terminal: Vec<usize> = (0..mdp.n_states)
588                    .filter(|s| !terminal_set.contains(s))
589                    .collect();
590                if !non_terminal.is_empty() {
591                    state = non_terminal[lcg_index(rng, non_terminal.len())];
592                    rng = lcg_next(rng);
593                }
594            }
595
596            let mut episode_return = 0.0_f64;
597            let mut discount = 1.0_f64;
598
599            for _ in 0..max_steps_per_episode {
600                let action = self.epsilon_greedy(state, rng);
601                rng = lcg_next(rng);
602
603                // Sample next state from transition distribution
604                let next_state = sample_next_state(mdp, state, action, rng);
605                rng = lcg_next(rng);
606
607                let reward = r[[state, action]];
608                episode_return += discount * reward;
609                discount *= self.gamma;
610
611                self.update(state, action, reward, next_state);
612
613                if terminal_set.contains(&next_state) {
614                    break;
615                }
616                state = next_state;
617            }
618            let _ = ep; // suppress lint
619            returns.push(episode_return);
620        }
621        Ok(returns)
622    }
623
624    /// Extract the greedy policy from Q-table.
625    pub fn policy(&self) -> Vec<usize> {
626        let n_states = self.q_table.nrows();
627        (0..n_states).map(|s| self.greedy(s)).collect()
628    }
629
630    /// Estimate the value function: `V(s) = max_a Q(s,a)`.
631    pub fn value_function(&self) -> Vec<f64> {
632        let n_states = self.q_table.nrows();
633        let n_actions = self.q_table.ncols();
634        (0..n_states)
635            .map(|s| {
636                (0..n_actions)
637                    .map(|a| self.q_table[[s, a]])
638                    .fold(f64::NEG_INFINITY, f64::max)
639            })
640            .collect()
641    }
642}
643
644// ─────────────────────────────────────────────────────────────────────────────
645// SARSA
646// ─────────────────────────────────────────────────────────────────────────────
647
648/// Tabular SARSA agent (on-policy TD learning).
649#[derive(Debug, Clone)]
650pub struct Sarsa {
651    /// Q-value table `(n_states × n_actions)`.
652    pub q_table: Array2<f64>,
653    /// Learning rate.
654    pub alpha: f64,
655    /// ε-greedy exploration rate.
656    pub epsilon: f64,
657    /// Discount factor.
658    pub gamma: f64,
659}
660
661impl Sarsa {
662    /// Create a new SARSA agent.
663    pub fn new(n_states: usize, n_actions: usize, alpha: f64, epsilon: f64, gamma: f64) -> Self {
664        Self {
665            q_table: Array2::<f64>::zeros((n_states, n_actions)),
666            alpha,
667            epsilon,
668            gamma,
669        }
670    }
671
672    /// Apply one SARSA TD update.
673    ///
674    /// `Q(s,a) ← Q(s,a) + α [ r + γ Q(s',a') − Q(s,a) ]`
675    pub fn update(&mut self, s: usize, a: usize, r: f64, s_next: usize, a_next: usize) {
676        let td_error = r + self.gamma * self.q_table[[s_next, a_next]] - self.q_table[[s, a]];
677        self.q_table[[s, a]] += self.alpha * td_error;
678    }
679
680    /// ε-greedy action selection.
681    fn epsilon_greedy_action(&self, state: usize, rng: u64) -> usize {
682        let rng_val = lcg_uniform(rng);
683        if rng_val < self.epsilon {
684            let n_actions = self.q_table.ncols();
685            lcg_index(rng.wrapping_add(1), n_actions)
686        } else {
687            let n_actions = self.q_table.ncols();
688            (0..n_actions)
689                .max_by(|&a1, &a2| {
690                    self.q_table[[state, a1]]
691                        .partial_cmp(&self.q_table[[state, a2]])
692                        .unwrap_or(std::cmp::Ordering::Equal)
693                })
694                .unwrap_or(0)
695        }
696    }
697
698    /// Train SARSA on an MDP for `n_episodes` episodes.
699    pub fn train(
700        &mut self,
701        mdp: &Mdp,
702        n_episodes: usize,
703        max_steps: usize,
704        seed: u64,
705    ) -> Result<Vec<f64>, OptimizeError> {
706        let n_states = self.q_table.nrows();
707        if n_states != mdp.n_states {
708            return Err(OptimizeError::ValueError(format!(
709                "SARSA Q-table n_states {} != mdp.n_states {}",
710                n_states, mdp.n_states
711            )));
712        }
713        let r = mdp.expected_reward();
714        let mut returns = Vec::with_capacity(n_episodes);
715        let mut rng = seed;
716        let terminal_set: std::collections::HashSet<usize> =
717            mdp.terminal_states.iter().copied().collect();
718
719        for _ in 0..n_episodes {
720            let mut state = lcg_index(rng, mdp.n_states);
721            rng = lcg_next(rng);
722
723            let mut action = self.epsilon_greedy_action(state, rng);
724            rng = lcg_next(rng);
725
726            let mut episode_return = 0.0_f64;
727            let mut discount = 1.0_f64;
728
729            for _ in 0..max_steps {
730                let next_state = sample_next_state(mdp, state, action, rng);
731                rng = lcg_next(rng);
732                let reward = r[[state, action]];
733                episode_return += discount * reward;
734                discount *= self.gamma;
735
736                let next_action = self.epsilon_greedy_action(next_state, rng);
737                rng = lcg_next(rng);
738
739                self.update(state, action, reward, next_state, next_action);
740
741                if terminal_set.contains(&next_state) {
742                    break;
743                }
744                state = next_state;
745                action = next_action;
746            }
747            returns.push(episode_return);
748        }
749        Ok(returns)
750    }
751
752    /// Extract greedy policy.
753    pub fn policy(&self) -> Vec<usize> {
754        let n_states = self.q_table.nrows();
755        let n_actions = self.q_table.ncols();
756        (0..n_states)
757            .map(|s| {
758                (0..n_actions)
759                    .max_by(|&a1, &a2| {
760                        self.q_table[[s, a1]]
761                            .partial_cmp(&self.q_table[[s, a2]])
762                            .unwrap_or(std::cmp::Ordering::Equal)
763                    })
764                    .unwrap_or(0)
765            })
766            .collect()
767    }
768}
769
770// ─────────────────────────────────────────────────────────────────────────────
771// Simulation
772// ─────────────────────────────────────────────────────────────────────────────
773
774/// Simulate an MDP with a fixed deterministic policy.
775///
776/// Returns `(states, actions, rewards)` trajectories of length `n_steps`.
777pub fn simulate(
778    mdp: &Mdp,
779    policy: &[usize],
780    initial_state: usize,
781    n_steps: usize,
782    seed: u64,
783) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
784    let r = mdp.expected_reward();
785    let mut states = Vec::with_capacity(n_steps + 1);
786    let mut actions = Vec::with_capacity(n_steps);
787    let mut rewards = Vec::with_capacity(n_steps);
788    let terminal_set: std::collections::HashSet<usize> =
789        mdp.terminal_states.iter().copied().collect();
790
791    let mut state = initial_state;
792    let mut rng = seed;
793    states.push(state);
794
795    for _ in 0..n_steps {
796        if terminal_set.contains(&state) {
797            break;
798        }
799        let action = if state < policy.len() {
800            policy[state]
801        } else {
802            0
803        };
804        let next_state = sample_next_state(mdp, state, action, rng);
805        rng = lcg_next(rng);
806        let reward = r[[state, action]];
807        actions.push(action);
808        rewards.push(reward);
809        states.push(next_state);
810        state = next_state;
811    }
812    (states, actions, rewards)
813}
814
815// ─────────────────────────────────────────────────────────────────────────────
816// Internal helpers
817// ─────────────────────────────────────────────────────────────────────────────
818
/// Advance the linear congruential generator by one step
/// (`state · a + c`, wrapping modulo 2⁶⁴).
pub(crate) fn lcg_next(state: u64) -> u64 {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT)
}
825
826/// Map LCG state to uniform f64 in [0,1).
827pub(crate) fn lcg_uniform(state: u64) -> f64 {
828    (lcg_next(state) >> 11) as f64 / (1u64 << 53) as f64
829}
830
831/// Map LCG state to index in [0, n).
832pub(crate) fn lcg_index(state: u64, n: usize) -> usize {
833    if n == 0 {
834        return 0;
835    }
836    (lcg_next(state) as usize) % n
837}
838
839/// Sample a next state by CDF inversion on the transition row.
840pub(crate) fn sample_next_state(mdp: &Mdp, state: usize, action: usize, rng: u64) -> usize {
841    let u = lcg_uniform(rng);
842    let mut cumsum = 0.0_f64;
843    for sp in 0..mdp.n_states {
844        cumsum += mdp.transition[[state, action, sp]];
845        if u < cumsum {
846            return sp;
847        }
848    }
849    // Numerical safety: return last state
850    mdp.n_states - 1
851}
852
853/// Compute the Bellman residual ‖TV − V‖_∞.
854pub(crate) fn compute_bellman_residual(mdp: &Mdp, v: &[f64], r: &Array2<f64>) -> f64 {
855    let q = mdp.q_values(v, r);
856    (0..mdp.n_states)
857        .map(|s| {
858            let best = (0..mdp.n_actions)
859                .map(|a| q[[s, a]])
860                .fold(f64::NEG_INFINITY, f64::max);
861            (best - v[s]).abs()
862        })
863        .fold(0.0_f64, f64::max)
864}
865
866// ─────────────────────────────────────────────────────────────────────────────
867// Tests
868// ─────────────────────────────────────────────────────────────────────────────
869
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array2, Array3};

    // ── Fixtures ─────────────────────────────────────────────────────────────

    /// Build a deterministic 2-state, 1-action MDP: state 0 → state 1 with reward 1.
    ///
    /// State 1 loops onto itself and is registered in `terminal_states`.
    fn two_state_deterministic() -> Mdp {
        let n = 2;
        let a = 1;
        let mut t = Array3::<f64>::zeros((n, a, n));
        t[[0, 0, 1]] = 1.0;
        t[[1, 0, 1]] = 1.0; // absorbing
        let mut r = Array3::<f64>::zeros((n, a, n));
        r[[0, 0, 1]] = 1.0; // reward for transitioning s0→s1
        let mut mdp = Mdp::new(n, a, t, r, 0.9).expect("two-state fixture must be a valid MDP");
        mdp.terminal_states = vec![1];
        mdp
    }

    /// Build a simple 3-state, 2-action gridworld-style MDP.
    ///
    /// Action 0 moves right deterministically (0→1→2→2); action 1 stays in place.
    /// The only non-zero reward is 1.0 for the transition 1 → 2 under action 0.
    fn three_state_mdp() -> Mdp {
        let n = 3;
        let a = 2;
        let mut t = Array3::<f64>::zeros((n, a, n));
        // Action 0: deterministic move right (0→1→2→2)
        t[[0, 0, 1]] = 1.0;
        t[[1, 0, 2]] = 1.0;
        t[[2, 0, 2]] = 1.0;
        // Action 1: stay in place
        t[[0, 1, 0]] = 1.0;
        t[[1, 1, 1]] = 1.0;
        t[[2, 1, 2]] = 1.0;
        let mut r = Array3::<f64>::zeros((n, a, n));
        r[[1, 0, 2]] = 1.0; // reward for reaching state 2 via action 0
        Mdp::new(n, a, t, r, 0.9).expect("three-state fixture must be a valid MDP")
    }

    /// Build a stochastic 3-state, 2-action MDP in which state 2 is absorbing.
    fn stochastic_mdp() -> Mdp {
        let n = 3;
        let a = 2;
        let mut t = Array3::<f64>::zeros((n, a, n));
        // Action 0 from state 0: 70% → state 1, 30% → state 2
        t[[0, 0, 1]] = 0.7;
        t[[0, 0, 2]] = 0.3;
        // Action 1 from state 0: 100% → state 0 (stay)
        t[[0, 1, 0]] = 1.0;
        // State 1: both actions go to state 2
        t[[1, 0, 2]] = 1.0;
        t[[1, 1, 2]] = 1.0;
        // State 2 absorbing
        t[[2, 0, 2]] = 1.0;
        t[[2, 1, 2]] = 1.0;
        let mut r = Array3::<f64>::zeros((n, a, n));
        r[[0, 0, 1]] = 0.5;
        r[[0, 0, 2]] = 1.0;
        r[[1, 0, 2]] = 2.0;
        r[[1, 1, 2]] = 2.0;
        Mdp::new(n, a, t, r, 0.9).expect("stochastic fixture must be a valid MDP")
    }

    // ── MDP construction ────────────────────────────────────────────────────

    #[test]
    fn test_mdp_construction_valid() {
        let mdp = two_state_deterministic();
        assert_eq!(mdp.n_states, 2);
        assert_eq!(mdp.n_actions, 1);
    }

    #[test]
    fn test_mdp_construction_bad_gamma() {
        let n = 2;
        let t = Array3::<f64>::zeros((n, 1, n));
        let r = Array3::<f64>::zeros((n, 1, n));
        assert!(Mdp::new(n, 1, t, r, 1.5).is_err());
    }

    #[test]
    fn test_mdp_validation_rejects_bad_transitions() {
        let n = 2;
        let a = 1;
        // All-zero transition tensor: no row sums to 1
        let t = Array3::<f64>::zeros((n, a, n));
        let r = Array3::<f64>::zeros((n, a, n));
        assert!(Mdp::new(n, a, t, r, 0.9).is_err());
    }

    #[test]
    fn test_expected_reward() {
        let mdp = two_state_deterministic();
        let er = mdp.expected_reward();
        // R(s=0, a=0) = T(0,0,1)*R(0,0,1) = 1.0*1.0 = 1.0
        assert!((er[[0, 0]] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_with_state_action_reward() {
        let n = 2;
        let a = 2;
        let mut t = Array3::<f64>::zeros((n, a, n));
        t[[0, 0, 1]] = 1.0;
        t[[0, 1, 0]] = 1.0;
        t[[1, 0, 1]] = 1.0;
        t[[1, 1, 1]] = 1.0;
        let r2 = Array2::<f64>::from_elem((n, a), 0.5);
        // `expect` already fails the test on `Err`, so no separate `is_ok` assertion is needed.
        let mdp = Mdp::with_state_action_reward(n, a, t, r2, 0.9)
            .expect("2-D reward constructor should accept a valid MDP");
        // All rewards in 3D should be 0.5
        assert!((mdp.reward[[0, 0, 0]] - 0.5).abs() < 1e-9);
        assert!((mdp.reward[[1, 1, 1]] - 0.5).abs() < 1e-9);
    }

    // ── Value Iteration ──────────────────────────────────────────────────────

    #[test]
    fn test_value_iteration_two_state() {
        let mdp = two_state_deterministic();
        let sol = value_iteration(&mdp, 1e-9, 10_000).expect("value iteration should succeed");
        assert!(sol.converged);
        // V(terminal=1) should be 0
        assert!(sol.value_function[1].abs() < 1e-6);
        // V(0) = 1.0 (immediate reward, then terminal)
        assert!((sol.value_function[0] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_value_iteration_three_state() {
        let mdp = three_state_mdp();
        let sol = value_iteration(&mdp, 1e-9, 10_000).expect("value iteration should succeed");
        assert!(sol.converged);
        // Optimal policy: always move right.
        // State 1 gets the reward r[1,0,2]=1 when transitioning to state 2.
        // → V(1) = 1 + 0.9*V(2) = 1.0   (V(2) = 0, absorbing)
        // → V(0) = 0 + 0.9*V(1) = 0.9
        // State 0 is one step further from the reward than state 1.
        assert!(sol.value_function[0] > 0.0);
        assert!(sol.value_function[1] > sol.value_function[0]);
        assert!((sol.value_function[1] - 1.0).abs() < 1e-4);
        assert!((sol.value_function[0] - 0.9).abs() < 1e-4);
    }

    #[test]
    fn test_value_iteration_policy_is_greedy() {
        let mdp = three_state_mdp();
        let sol = value_iteration(&mdp, 1e-9, 10_000).expect("value iteration should succeed");
        assert!(sol.converged);
        // States 0 and 1 should prefer action 0 (move right)
        assert_eq!(sol.policy[0], 0);
        assert_eq!(sol.policy[1], 0);
    }

    #[test]
    fn test_value_iteration_convergence_flag() {
        let mdp = three_state_mdp();
        // Very tight tolerance → still converges
        let sol = value_iteration(&mdp, 1e-12, 100_000).expect("value iteration should succeed");
        assert!(sol.converged);
    }

    #[test]
    fn test_value_iteration_stochastic() {
        let mdp = stochastic_mdp();
        let sol = value_iteration(&mdp, 1e-9, 10_000).expect("value iteration should succeed");
        assert!(sol.converged);
        assert!(
            sol.value_function[2].abs() < 1e-6,
            "absorbing state value must be 0"
        );
    }

    // ── Policy Evaluation ────────────────────────────────────────────────────

    #[test]
    fn test_policy_evaluation_consistent() {
        let mdp = three_state_mdp();
        let vi = value_iteration(&mdp, 1e-12, 100_000).expect("value iteration should succeed");
        // Evaluate the VI-optimal policy
        let v_eval = evaluate_policy(&mdp, &vi.policy, 1e-12, 100_000)
            .expect("policy evaluation should succeed");
        for s in 0..mdp.n_states {
            assert!(
                (v_eval[s] - vi.value_function[s]).abs() < 1e-4,
                "state {}: eval {} vs vi {}",
                s,
                v_eval[s],
                vi.value_function[s]
            );
        }
    }

    #[test]
    fn test_policy_evaluation_bad_policy_length() {
        let mdp = two_state_deterministic();
        // Five entries for a two-state MDP
        let bad_policy = vec![0usize; 5];
        assert!(evaluate_policy(&mdp, &bad_policy, 1e-9, 100).is_err());
    }

    // ── Policy Iteration ─────────────────────────────────────────────────────

    #[test]
    fn test_policy_iteration_equals_vi() {
        let mdp = three_state_mdp();
        let vi = value_iteration(&mdp, 1e-9, 10_000).expect("value iteration should succeed");
        let pi = policy_iteration(&mdp, 1e-9, 10_000).expect("policy iteration should succeed");
        assert!(pi.converged);
        for s in 0..mdp.n_states {
            assert!(
                (pi.value_function[s] - vi.value_function[s]).abs() < 1e-3,
                "state {}: pi={} vi={}",
                s,
                pi.value_function[s],
                vi.value_function[s]
            );
        }
    }

    #[test]
    fn test_policy_iteration_stochastic() {
        let mdp = stochastic_mdp();
        let sol = policy_iteration(&mdp, 1e-9, 10_000).expect("policy iteration should succeed");
        assert!(sol.converged);
    }

    // ── Modified Policy Iteration ────────────────────────────────────────────

    #[test]
    fn test_modified_policy_iteration_k1_like_vi() {
        // k=1 MPI should give same result as value iteration
        let mdp = three_state_mdp();
        let vi = value_iteration(&mdp, 1e-9, 10_000).expect("value iteration should succeed");
        let mpi = modified_policy_iteration(&mdp, 1, 1e-9, 50_000)
            .expect("modified policy iteration should succeed");
        assert!(mpi.converged);
        for s in 0..mdp.n_states {
            assert!(
                (mpi.value_function[s] - vi.value_function[s]).abs() < 1e-3,
                "state {}: mpi={} vi={}",
                s,
                mpi.value_function[s],
                vi.value_function[s]
            );
        }
    }

    #[test]
    fn test_modified_policy_iteration_k10() {
        let mdp = stochastic_mdp();
        let sol = modified_policy_iteration(&mdp, 10, 1e-9, 10_000)
            .expect("modified policy iteration should succeed");
        assert!(sol.converged);
    }

    #[test]
    fn test_modified_policy_iteration_zero_k_error() {
        let mdp = two_state_deterministic();
        assert!(modified_policy_iteration(&mdp, 0, 1e-9, 100).is_err());
    }

    // ── LP solve ─────────────────────────────────────────────────────────────

    #[test]
    fn test_lp_solve_agrees_with_vi() {
        let mdp = three_state_mdp();
        let vi = value_iteration(&mdp, 1e-12, 100_000).expect("value iteration should succeed");
        let lp = lp_solve_mdp(&mdp).expect("LP solve should succeed");
        for s in 0..mdp.n_states {
            assert!(
                (lp.value_function[s] - vi.value_function[s]).abs() < 1e-4,
                "state {}: lp={} vi={}",
                s,
                lp.value_function[s],
                vi.value_function[s]
            );
        }
    }

    // ── Q-Learning ───────────────────────────────────────────────────────────

    #[test]
    fn test_qlearning_update() {
        let mut q = QLearning::new(3, 2, 0.1, 0.0, 0.9);
        q.update(0, 0, 1.0, 1);
        // After one update from zero: Q(0,0) = 0 + 0.1*(1.0 + 0 - 0) = 0.1
        assert!((q.q_table[[0, 0]] - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_qlearning_greedy() {
        let mut q = QLearning::new(3, 2, 0.1, 0.0, 0.9);
        q.q_table[[0, 1]] = 5.0;
        assert_eq!(q.greedy(0), 1);
    }

    #[test]
    fn test_qlearning_train_returns_length() {
        let mdp = three_state_mdp();
        let mut q = QLearning::new(3, 2, 0.3, 0.1, 0.9);
        let returns = q
            .train(&mdp, 100, 20, 42)
            .expect("Q-learning training should succeed");
        // 100 was passed as the second argument to `train`
        assert_eq!(returns.len(), 100);
    }

    #[test]
    fn test_qlearning_policy_shape() {
        let mut q = QLearning::new(3, 2, 0.3, 0.1, 0.9);
        let mdp = three_state_mdp();
        let _ = q
            .train(&mdp, 200, 30, 7)
            .expect("Q-learning training should succeed");
        let pol = q.policy();
        assert_eq!(pol.len(), 3);
        for &a in &pol {
            assert!(a < 2);
        }
    }

    #[test]
    fn test_qlearning_value_function() {
        let q = QLearning::new(2, 2, 0.1, 0.0, 0.9);
        let vf = q.value_function();
        assert_eq!(vf.len(), 2);
    }

    // ── SARSA ────────────────────────────────────────────────────────────────

    #[test]
    fn test_sarsa_update() {
        let mut s = Sarsa::new(3, 2, 0.1, 0.0, 0.9);
        s.update(0, 0, 1.0, 1, 0);
        // Q(0,0) = 0 + 0.1*(1.0 + 0.9*Q(1,0) - 0) = 0.1
        assert!((s.q_table[[0, 0]] - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_sarsa_train_returns_length() {
        let mdp = three_state_mdp();
        let mut s = Sarsa::new(3, 2, 0.3, 0.1, 0.9);
        let returns = s
            .train(&mdp, 100, 20, 13)
            .expect("SARSA training should succeed");
        // 100 was passed as the second argument to `train`
        assert_eq!(returns.len(), 100);
    }

    #[test]
    fn test_sarsa_policy_valid() {
        let mdp = three_state_mdp();
        let mut s = Sarsa::new(3, 2, 0.3, 0.1, 0.9);
        let _ = s
            .train(&mdp, 200, 30, 99)
            .expect("SARSA training should succeed");
        let pol = s.policy();
        assert_eq!(pol.len(), 3);
        for &a in &pol {
            assert!(a < 2);
        }
    }

    // ── Simulation ───────────────────────────────────────────────────────────

    #[test]
    fn test_simulate_length() {
        let mdp = three_state_mdp();
        let policy = vec![0usize, 0, 0];
        let (states, actions, rewards) = simulate(&mdp, &policy, 0, 5, 42);
        assert!(!states.is_empty());
        assert_eq!(actions.len(), rewards.len());
        assert!(actions.len() <= 5);
    }

    #[test]
    fn test_simulate_terminal_stops() {
        let mdp = two_state_deterministic();
        let policy = vec![0usize; 2];
        let (states, _actions, _rewards) = simulate(&mdp, &policy, 0, 100, 1);
        // Should stop after reaching terminal state 1
        assert!(states.len() <= 3, "states.len() = {}", states.len());
    }
}