1use crate::error::OptimizeError;
6use scirs2_core::ndarray::{Array2, Array3};
7
8#[derive(Debug, Clone)]
17pub struct Mdp {
18 pub n_states: usize,
20 pub n_actions: usize,
22 pub transition: Array3<f64>,
24 pub reward: Array3<f64>,
26 pub gamma: f64,
28 pub terminal_states: Vec<usize>,
30}
31
32impl Mdp {
33 pub fn new(
35 n_states: usize,
36 n_actions: usize,
37 transition: Array3<f64>,
38 reward: Array3<f64>,
39 gamma: f64,
40 ) -> Result<Self, OptimizeError> {
41 if n_states == 0 {
42 return Err(OptimizeError::ValueError(
43 "n_states must be > 0".to_string(),
44 ));
45 }
46 if n_actions == 0 {
47 return Err(OptimizeError::ValueError(
48 "n_actions must be > 0".to_string(),
49 ));
50 }
51 if transition.shape() != [n_states, n_actions, n_states] {
52 return Err(OptimizeError::ValueError(format!(
53 "transition shape {:?} != [{}, {}, {}]",
54 transition.shape(),
55 n_states,
56 n_actions,
57 n_states
58 )));
59 }
60 if reward.shape() != [n_states, n_actions, n_states] {
61 return Err(OptimizeError::ValueError(format!(
62 "reward shape {:?} != [{}, {}, {}]",
63 reward.shape(),
64 n_states,
65 n_actions,
66 n_states
67 )));
68 }
69 if !(0.0..=1.0).contains(&gamma) {
70 return Err(OptimizeError::ValueError(format!(
71 "gamma {} must be in [0, 1]",
72 gamma
73 )));
74 }
75 let mdp = Self {
76 n_states,
77 n_actions,
78 transition,
79 reward,
80 gamma,
81 terminal_states: Vec::new(),
82 };
83 mdp.validate()?;
84 Ok(mdp)
85 }
86
87 pub fn validate(&self) -> Result<(), OptimizeError> {
89 for s in 0..self.n_states {
90 for a in 0..self.n_actions {
91 let sum: f64 = (0..self.n_states)
92 .map(|sp| self.transition[[s, a, sp]])
93 .sum();
94 if (sum - 1.0).abs() > 1e-6 {
95 return Err(OptimizeError::ValueError(format!(
96 "Transition probabilities for state {} action {} sum to {} (expected 1)",
97 s, a, sum
98 )));
99 }
100 for sp in 0..self.n_states {
102 let p = self.transition[[s, a, sp]];
103 if p < -1e-10 {
104 return Err(OptimizeError::ValueError(format!(
105 "Negative transition probability T[{},{},{}] = {}",
106 s, a, sp, p
107 )));
108 }
109 }
110 }
111 }
112 Ok(())
113 }
114
115 pub fn expected_reward(&self) -> Array2<f64> {
117 let mut r = Array2::<f64>::zeros((self.n_states, self.n_actions));
118 for s in 0..self.n_states {
119 for a in 0..self.n_actions {
120 let val: f64 = (0..self.n_states)
121 .map(|sp| self.transition[[s, a, sp]] * self.reward[[s, a, sp]])
122 .sum();
123 r[[s, a]] = val;
124 }
125 }
126 r
127 }
128
129 pub fn with_state_action_reward(
131 n_states: usize,
132 n_actions: usize,
133 transition: Array3<f64>,
134 reward: Array2<f64>,
135 gamma: f64,
136 ) -> Result<Self, OptimizeError> {
137 if reward.shape() != [n_states, n_actions] {
138 return Err(OptimizeError::ValueError(format!(
139 "reward shape {:?} != [{}, {}]",
140 reward.shape(),
141 n_states,
142 n_actions
143 )));
144 }
145 let mut r3 = Array3::<f64>::zeros((n_states, n_actions, n_states));
147 for s in 0..n_states {
148 for a in 0..n_actions {
149 for sp in 0..n_states {
150 r3[[s, a, sp]] = reward[[s, a]];
151 }
152 }
153 }
154 Self::new(n_states, n_actions, transition, r3, gamma)
155 }
156
157 fn q_values(&self, v: &[f64], r: &Array2<f64>) -> Array2<f64> {
159 let mut q = Array2::<f64>::zeros((self.n_states, self.n_actions));
160 for s in 0..self.n_states {
161 for a in 0..self.n_actions {
162 let future: f64 = (0..self.n_states)
163 .map(|sp| self.transition[[s, a, sp]] * v[sp])
164 .sum();
165 q[[s, a]] = r[[s, a]] + self.gamma * future;
166 }
167 }
168 q
169 }
170}
171
/// Result returned by the planning routines in this module.
#[derive(Debug, Clone)]
pub struct MdpSolution {
    /// Value estimate for each state, indexed by state.
    pub value_function: Vec<f64>,
    /// Action selected for each state, indexed by state.
    pub policy: Vec<usize>,
    /// Number of outer iterations that were run.
    pub n_iterations: usize,
    /// Whether the solver's stopping criterion was met before the iteration cap.
    pub converged: bool,
    /// Final convergence measure: the largest value change in the last sweep for the
    /// value-iteration style solvers, the Bellman residual for policy iteration.
    pub max_diff: f64,
}
190
191pub fn value_iteration(mdp: &Mdp, tol: f64, max_iter: usize) -> Result<MdpSolution, OptimizeError> {
202 if tol <= 0.0 {
203 return Err(OptimizeError::ValueError(
204 "tol must be positive".to_string(),
205 ));
206 }
207 let r = mdp.expected_reward();
208 let mut v = vec![0.0_f64; mdp.n_states];
209 let mut policy = vec![0usize; mdp.n_states];
210 let mut max_diff = f64::INFINITY;
211
212 for iter in 0..max_iter {
213 let q = mdp.q_values(&v, &r);
214 max_diff = 0.0_f64;
215 for s in 0..mdp.n_states {
216 let best_a = (0..mdp.n_actions)
217 .max_by(|&a1, &a2| {
218 q[[s, a1]]
219 .partial_cmp(&q[[s, a2]])
220 .unwrap_or(std::cmp::Ordering::Equal)
221 })
222 .unwrap_or(0);
223 let new_v = q[[s, best_a]];
224 let diff = (new_v - v[s]).abs();
225 if diff > max_diff {
226 max_diff = diff;
227 }
228 v[s] = new_v;
229 policy[s] = best_a;
230 }
231 for &ts in &mdp.terminal_states {
233 if ts < mdp.n_states {
234 v[ts] = 0.0;
235 }
236 }
237 if max_diff < tol {
238 return Ok(MdpSolution {
239 value_function: v,
240 policy,
241 n_iterations: iter + 1,
242 converged: true,
243 max_diff,
244 });
245 }
246 }
247
248 Ok(MdpSolution {
249 value_function: v,
250 policy,
251 n_iterations: max_iter,
252 converged: false,
253 max_diff,
254 })
255}
256
257pub fn evaluate_policy(
265 mdp: &Mdp,
266 policy: &[usize],
267 tol: f64,
268 max_iter: usize,
269) -> Result<Vec<f64>, OptimizeError> {
270 if policy.len() != mdp.n_states {
271 return Err(OptimizeError::ValueError(format!(
272 "policy length {} != n_states {}",
273 policy.len(),
274 mdp.n_states
275 )));
276 }
277 for (s, &a) in policy.iter().enumerate() {
278 if a >= mdp.n_actions {
279 return Err(OptimizeError::ValueError(format!(
280 "policy[{}] = {} >= n_actions {}",
281 s, a, mdp.n_actions
282 )));
283 }
284 }
285 let r = mdp.expected_reward();
286 let mut v = vec![0.0_f64; mdp.n_states];
287
288 for _ in 0..max_iter {
289 let mut max_diff = 0.0_f64;
290 for s in 0..mdp.n_states {
291 let a = policy[s];
292 let future: f64 = (0..mdp.n_states)
293 .map(|sp| mdp.transition[[s, a, sp]] * v[sp])
294 .sum();
295 let new_v = r[[s, a]] + mdp.gamma * future;
296 let diff = (new_v - v[s]).abs();
297 if diff > max_diff {
298 max_diff = diff;
299 }
300 v[s] = new_v;
301 }
302 for &ts in &mdp.terminal_states {
304 if ts < mdp.n_states {
305 v[ts] = 0.0;
306 }
307 }
308 if max_diff < tol {
309 return Ok(v);
310 }
311 }
312 Ok(v)
313}
314
315pub fn policy_iteration(
324 mdp: &Mdp,
325 tol: f64,
326 max_iter: usize,
327) -> Result<MdpSolution, OptimizeError> {
328 if tol <= 0.0 {
329 return Err(OptimizeError::ValueError(
330 "tol must be positive".to_string(),
331 ));
332 }
333 let r = mdp.expected_reward();
334 let mut policy: Vec<usize> = vec![0; mdp.n_states];
335 let mut v = vec![0.0_f64; mdp.n_states];
336
337 for iter in 0..max_iter {
338 v = evaluate_policy(mdp, &policy, tol * 1e-3, max_iter)?;
340
341 let q = mdp.q_values(&v, &r);
343 let mut stable = true;
344 for s in 0..mdp.n_states {
345 let best_a = (0..mdp.n_actions)
346 .max_by(|&a1, &a2| {
347 q[[s, a1]]
348 .partial_cmp(&q[[s, a2]])
349 .unwrap_or(std::cmp::Ordering::Equal)
350 })
351 .unwrap_or(0);
352 if best_a != policy[s] {
353 stable = false;
354 policy[s] = best_a;
355 }
356 }
357
358 if stable {
359 let q_final = mdp.q_values(&v, &r);
361 let max_diff = (0..mdp.n_states)
362 .map(|s| {
363 let best = (0..mdp.n_actions)
364 .map(|a| q_final[[s, a]])
365 .fold(f64::NEG_INFINITY, f64::max);
366 (best - v[s]).abs()
367 })
368 .fold(0.0_f64, f64::max);
369 return Ok(MdpSolution {
370 value_function: v,
371 policy,
372 n_iterations: iter + 1,
373 converged: true,
374 max_diff,
375 });
376 }
377 }
378
379 let max_diff = compute_bellman_residual(mdp, &v, &r);
380 Ok(MdpSolution {
381 value_function: v,
382 policy,
383 n_iterations: max_iter,
384 converged: false,
385 max_diff,
386 })
387}
388
389pub fn modified_policy_iteration(
399 mdp: &Mdp,
400 k: usize,
401 tol: f64,
402 max_iter: usize,
403) -> Result<MdpSolution, OptimizeError> {
404 if tol <= 0.0 {
405 return Err(OptimizeError::ValueError(
406 "tol must be positive".to_string(),
407 ));
408 }
409 if k == 0 {
410 return Err(OptimizeError::ValueError("k must be >= 1".to_string()));
411 }
412 let r = mdp.expected_reward();
413 let mut v = vec![0.0_f64; mdp.n_states];
414 let mut policy = vec![0usize; mdp.n_states];
415 let mut max_diff = f64::INFINITY;
416
417 for iter in 0..max_iter {
418 let q = mdp.q_values(&v, &r);
420 for s in 0..mdp.n_states {
421 policy[s] = (0..mdp.n_actions)
422 .max_by(|&a1, &a2| {
423 q[[s, a1]]
424 .partial_cmp(&q[[s, a2]])
425 .unwrap_or(std::cmp::Ordering::Equal)
426 })
427 .unwrap_or(0);
428 }
429
430 max_diff = 0.0_f64;
432 for _ in 0..k {
433 let mut iter_diff = 0.0_f64;
434 for s in 0..mdp.n_states {
435 let a = policy[s];
436 let future: f64 = (0..mdp.n_states)
437 .map(|sp| mdp.transition[[s, a, sp]] * v[sp])
438 .sum();
439 let new_v = r[[s, a]] + mdp.gamma * future;
440 let diff = (new_v - v[s]).abs();
441 if diff > iter_diff {
442 iter_diff = diff;
443 }
444 v[s] = new_v;
445 }
446 for &ts in &mdp.terminal_states {
447 if ts < mdp.n_states {
448 v[ts] = 0.0;
449 }
450 }
451 if iter_diff > max_diff {
452 max_diff = iter_diff;
453 }
454 }
455
456 if max_diff < tol {
457 return Ok(MdpSolution {
458 value_function: v,
459 policy,
460 n_iterations: iter + 1,
461 converged: true,
462 max_diff,
463 });
464 }
465 }
466
467 Ok(MdpSolution {
468 value_function: v,
469 policy,
470 n_iterations: max_iter,
471 converged: false,
472 max_diff,
473 })
474}
475
476pub fn lp_solve_mdp(mdp: &Mdp) -> Result<MdpSolution, OptimizeError> {
489 value_iteration(mdp, 1e-12, 100_000)
492}
493
494#[derive(Debug, Clone)]
500pub struct QLearning {
501 pub q_table: Array2<f64>,
503 pub alpha: f64,
505 pub epsilon: f64,
507 pub gamma: f64,
509}
510
511impl QLearning {
512 pub fn new(n_states: usize, n_actions: usize, alpha: f64, epsilon: f64, gamma: f64) -> Self {
514 Self {
515 q_table: Array2::<f64>::zeros((n_states, n_actions)),
516 alpha,
517 epsilon,
518 gamma,
519 }
520 }
521
522 pub fn update(&mut self, state: usize, action: usize, reward: f64, next_state: usize) {
526 let n_actions = self.q_table.ncols();
527 let max_next = (0..n_actions)
528 .map(|a| self.q_table[[next_state, a]])
529 .fold(f64::NEG_INFINITY, f64::max);
530 let td_error = reward + self.gamma * max_next - self.q_table[[state, action]];
531 self.q_table[[state, action]] += self.alpha * td_error;
532 }
533
534 pub fn epsilon_greedy(&self, state: usize, rng_seed: u64) -> usize {
536 let rng_val = lcg_uniform(rng_seed);
537 if rng_val < self.epsilon {
538 let n_actions = self.q_table.ncols();
540 lcg_index(rng_seed.wrapping_add(1), n_actions)
541 } else {
542 self.greedy(state)
543 }
544 }
545
546 pub fn greedy(&self, state: usize) -> usize {
548 let n_actions = self.q_table.ncols();
549 (0..n_actions)
550 .max_by(|&a1, &a2| {
551 self.q_table[[state, a1]]
552 .partial_cmp(&self.q_table[[state, a2]])
553 .unwrap_or(std::cmp::Ordering::Equal)
554 })
555 .unwrap_or(0)
556 }
557
558 pub fn train(
562 &mut self,
563 mdp: &Mdp,
564 n_episodes: usize,
565 max_steps_per_episode: usize,
566 seed: u64,
567 ) -> Result<Vec<f64>, OptimizeError> {
568 let n_states = self.q_table.nrows();
569 if n_states != mdp.n_states {
570 return Err(OptimizeError::ValueError(format!(
571 "Q-table n_states {} != mdp.n_states {}",
572 n_states, mdp.n_states
573 )));
574 }
575 let r = mdp.expected_reward();
576 let mut returns = Vec::with_capacity(n_episodes);
577 let mut rng = seed;
578
579 for ep in 0..n_episodes {
580 let mut state = lcg_index(rng, mdp.n_states);
582 rng = lcg_next(rng);
583 let terminal_set: std::collections::HashSet<usize> =
585 mdp.terminal_states.iter().copied().collect();
586 if !terminal_set.is_empty() {
587 let non_terminal: Vec<usize> = (0..mdp.n_states)
588 .filter(|s| !terminal_set.contains(s))
589 .collect();
590 if !non_terminal.is_empty() {
591 state = non_terminal[lcg_index(rng, non_terminal.len())];
592 rng = lcg_next(rng);
593 }
594 }
595
596 let mut episode_return = 0.0_f64;
597 let mut discount = 1.0_f64;
598
599 for _ in 0..max_steps_per_episode {
600 let action = self.epsilon_greedy(state, rng);
601 rng = lcg_next(rng);
602
603 let next_state = sample_next_state(mdp, state, action, rng);
605 rng = lcg_next(rng);
606
607 let reward = r[[state, action]];
608 episode_return += discount * reward;
609 discount *= self.gamma;
610
611 self.update(state, action, reward, next_state);
612
613 if terminal_set.contains(&next_state) {
614 break;
615 }
616 state = next_state;
617 }
618 let _ = ep; returns.push(episode_return);
620 }
621 Ok(returns)
622 }
623
624 pub fn policy(&self) -> Vec<usize> {
626 let n_states = self.q_table.nrows();
627 (0..n_states).map(|s| self.greedy(s)).collect()
628 }
629
630 pub fn value_function(&self) -> Vec<f64> {
632 let n_states = self.q_table.nrows();
633 let n_actions = self.q_table.ncols();
634 (0..n_states)
635 .map(|s| {
636 (0..n_actions)
637 .map(|a| self.q_table[[s, a]])
638 .fold(f64::NEG_INFINITY, f64::max)
639 })
640 .collect()
641 }
642}
643
644#[derive(Debug, Clone)]
650pub struct Sarsa {
651 pub q_table: Array2<f64>,
653 pub alpha: f64,
655 pub epsilon: f64,
657 pub gamma: f64,
659}
660
661impl Sarsa {
662 pub fn new(n_states: usize, n_actions: usize, alpha: f64, epsilon: f64, gamma: f64) -> Self {
664 Self {
665 q_table: Array2::<f64>::zeros((n_states, n_actions)),
666 alpha,
667 epsilon,
668 gamma,
669 }
670 }
671
672 pub fn update(&mut self, s: usize, a: usize, r: f64, s_next: usize, a_next: usize) {
676 let td_error = r + self.gamma * self.q_table[[s_next, a_next]] - self.q_table[[s, a]];
677 self.q_table[[s, a]] += self.alpha * td_error;
678 }
679
680 fn epsilon_greedy_action(&self, state: usize, rng: u64) -> usize {
682 let rng_val = lcg_uniform(rng);
683 if rng_val < self.epsilon {
684 let n_actions = self.q_table.ncols();
685 lcg_index(rng.wrapping_add(1), n_actions)
686 } else {
687 let n_actions = self.q_table.ncols();
688 (0..n_actions)
689 .max_by(|&a1, &a2| {
690 self.q_table[[state, a1]]
691 .partial_cmp(&self.q_table[[state, a2]])
692 .unwrap_or(std::cmp::Ordering::Equal)
693 })
694 .unwrap_or(0)
695 }
696 }
697
698 pub fn train(
700 &mut self,
701 mdp: &Mdp,
702 n_episodes: usize,
703 max_steps: usize,
704 seed: u64,
705 ) -> Result<Vec<f64>, OptimizeError> {
706 let n_states = self.q_table.nrows();
707 if n_states != mdp.n_states {
708 return Err(OptimizeError::ValueError(format!(
709 "SARSA Q-table n_states {} != mdp.n_states {}",
710 n_states, mdp.n_states
711 )));
712 }
713 let r = mdp.expected_reward();
714 let mut returns = Vec::with_capacity(n_episodes);
715 let mut rng = seed;
716 let terminal_set: std::collections::HashSet<usize> =
717 mdp.terminal_states.iter().copied().collect();
718
719 for _ in 0..n_episodes {
720 let mut state = lcg_index(rng, mdp.n_states);
721 rng = lcg_next(rng);
722
723 let mut action = self.epsilon_greedy_action(state, rng);
724 rng = lcg_next(rng);
725
726 let mut episode_return = 0.0_f64;
727 let mut discount = 1.0_f64;
728
729 for _ in 0..max_steps {
730 let next_state = sample_next_state(mdp, state, action, rng);
731 rng = lcg_next(rng);
732 let reward = r[[state, action]];
733 episode_return += discount * reward;
734 discount *= self.gamma;
735
736 let next_action = self.epsilon_greedy_action(next_state, rng);
737 rng = lcg_next(rng);
738
739 self.update(state, action, reward, next_state, next_action);
740
741 if terminal_set.contains(&next_state) {
742 break;
743 }
744 state = next_state;
745 action = next_action;
746 }
747 returns.push(episode_return);
748 }
749 Ok(returns)
750 }
751
752 pub fn policy(&self) -> Vec<usize> {
754 let n_states = self.q_table.nrows();
755 let n_actions = self.q_table.ncols();
756 (0..n_states)
757 .map(|s| {
758 (0..n_actions)
759 .max_by(|&a1, &a2| {
760 self.q_table[[s, a1]]
761 .partial_cmp(&self.q_table[[s, a2]])
762 .unwrap_or(std::cmp::Ordering::Equal)
763 })
764 .unwrap_or(0)
765 })
766 .collect()
767 }
768}
769
770pub fn simulate(
778 mdp: &Mdp,
779 policy: &[usize],
780 initial_state: usize,
781 n_steps: usize,
782 seed: u64,
783) -> (Vec<usize>, Vec<usize>, Vec<f64>) {
784 let r = mdp.expected_reward();
785 let mut states = Vec::with_capacity(n_steps + 1);
786 let mut actions = Vec::with_capacity(n_steps);
787 let mut rewards = Vec::with_capacity(n_steps);
788 let terminal_set: std::collections::HashSet<usize> =
789 mdp.terminal_states.iter().copied().collect();
790
791 let mut state = initial_state;
792 let mut rng = seed;
793 states.push(state);
794
795 for _ in 0..n_steps {
796 if terminal_set.contains(&state) {
797 break;
798 }
799 let action = if state < policy.len() {
800 policy[state]
801 } else {
802 0
803 };
804 let next_state = sample_next_state(mdp, state, action, rng);
805 rng = lcg_next(rng);
806 let reward = r[[state, action]];
807 actions.push(action);
808 rewards.push(reward);
809 states.push(next_state);
810 state = next_state;
811 }
812 (states, actions, rewards)
813}
814
/// Advances the 64-bit linear congruential generator by one step:
/// `state * MULTIPLIER + INCREMENT` with wrapping arithmetic.
pub(crate) fn lcg_next(state: u64) -> u64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    let scaled = state.wrapping_mul(MULTIPLIER);
    scaled.wrapping_add(INCREMENT)
}
825
826pub(crate) fn lcg_uniform(state: u64) -> f64 {
828 (lcg_next(state) >> 11) as f64 / (1u64 << 53) as f64
829}
830
831pub(crate) fn lcg_index(state: u64, n: usize) -> usize {
833 if n == 0 {
834 return 0;
835 }
836 (lcg_next(state) as usize) % n
837}
838
839pub(crate) fn sample_next_state(mdp: &Mdp, state: usize, action: usize, rng: u64) -> usize {
841 let u = lcg_uniform(rng);
842 let mut cumsum = 0.0_f64;
843 for sp in 0..mdp.n_states {
844 cumsum += mdp.transition[[state, action, sp]];
845 if u < cumsum {
846 return sp;
847 }
848 }
849 mdp.n_states - 1
851}
852
853pub(crate) fn compute_bellman_residual(mdp: &Mdp, v: &[f64], r: &Array2<f64>) -> f64 {
855 let q = mdp.q_values(v, r);
856 (0..mdp.n_states)
857 .map(|s| {
858 let best = (0..mdp.n_actions)
859 .map(|a| q[[s, a]])
860 .fold(f64::NEG_INFINITY, f64::max);
861 (best - v[s]).abs()
862 })
863 .fold(0.0_f64, f64::max)
864}
865
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{Array2, Array3};

    // ---- fixtures ----

    /// Two states, one action: state 0 moves to state 1 for reward 1; state 1 loops on
    /// itself and is marked terminal.
    fn two_state_deterministic() -> Mdp {
        let (n_states, n_actions) = (2, 1);
        let mut transition = Array3::<f64>::zeros((n_states, n_actions, n_states));
        transition[[0, 0, 1]] = 1.0;
        transition[[1, 0, 1]] = 1.0;
        let mut reward = Array3::<f64>::zeros((n_states, n_actions, n_states));
        reward[[0, 0, 1]] = 1.0;
        let base = Mdp::new(n_states, n_actions, transition, reward, 0.9)
            .expect("failed to create mdp");
        Mdp {
            terminal_states: vec![1],
            ..base
        }
    }

    /// Three-state chain: action 0 advances 0 -> 1 -> 2 (state 2 loops), action 1 stays
    /// put. Only the move 1 -> 2 pays a reward of 1.
    fn three_state_mdp() -> Mdp {
        let (n_states, n_actions) = (3, 2);
        let mut transition = Array3::<f64>::zeros((n_states, n_actions, n_states));
        let moves = [
            (0, 0, 1),
            (1, 0, 2),
            (2, 0, 2),
            (0, 1, 0),
            (1, 1, 1),
            (2, 1, 2),
        ];
        for (s, a, sp) in moves {
            transition[[s, a, sp]] = 1.0;
        }
        let mut reward = Array3::<f64>::zeros((n_states, n_actions, n_states));
        reward[[1, 0, 2]] = 1.0;
        Mdp::new(n_states, n_actions, transition, reward, 0.9).expect("unexpected None or Err")
    }

    /// Three states, two actions, with a stochastic move out of state 0 under action 0.
    /// State 2 only loops on itself with zero reward.
    fn stochastic_mdp() -> Mdp {
        let (n_states, n_actions) = (3, 2);
        let mut transition = Array3::<f64>::zeros((n_states, n_actions, n_states));
        let probabilities = [
            ((0, 0, 1), 0.7),
            ((0, 0, 2), 0.3),
            ((0, 1, 0), 1.0),
            ((1, 0, 2), 1.0),
            ((1, 1, 2), 1.0),
            ((2, 0, 2), 1.0),
            ((2, 1, 2), 1.0),
        ];
        for ((s, a, sp), p) in probabilities {
            transition[[s, a, sp]] = p;
        }
        let mut reward = Array3::<f64>::zeros((n_states, n_actions, n_states));
        let payoffs = [
            ((0, 0, 1), 0.5),
            ((0, 0, 2), 1.0),
            ((1, 0, 2), 2.0),
            ((1, 1, 2), 2.0),
        ];
        for ((s, a, sp), value) in payoffs {
            reward[[s, a, sp]] = value;
        }
        Mdp::new(n_states, n_actions, transition, reward, 0.9).expect("unexpected None or Err")
    }

    // ---- construction and validation ----

    #[test]
    fn test_mdp_construction_valid() {
        let model = two_state_deterministic();
        assert_eq!(model.n_states, 2);
        assert_eq!(model.n_actions, 1);
    }

    #[test]
    fn test_mdp_construction_bad_gamma() {
        let n_states = 2;
        let transition = Array3::<f64>::zeros((n_states, 1, n_states));
        let reward = Array3::<f64>::zeros((n_states, 1, n_states));
        let built = Mdp::new(n_states, 1, transition, reward, 1.5);
        assert!(built.is_err());
    }

    #[test]
    fn test_mdp_validation_rejects_bad_transitions() {
        // All-zero rows do not sum to one.
        let (n_states, n_actions) = (2, 1);
        let transition = Array3::<f64>::zeros((n_states, n_actions, n_states));
        let reward = Array3::<f64>::zeros((n_states, n_actions, n_states));
        let built = Mdp::new(n_states, n_actions, transition, reward, 0.9);
        assert!(built.is_err());
    }

    #[test]
    fn test_expected_reward() {
        let table = two_state_deterministic().expected_reward();
        assert!((table[[0, 0]] - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_with_state_action_reward() {
        let (n_states, n_actions) = (2, 2);
        let mut transition = Array3::<f64>::zeros((n_states, n_actions, n_states));
        for (s, a, sp) in [(0, 0, 1), (0, 1, 0), (1, 0, 1), (1, 1, 1)] {
            transition[[s, a, sp]] = 1.0;
        }
        let flat_reward = Array2::<f64>::from_elem((n_states, n_actions), 0.5);
        let built = Mdp::with_state_action_reward(n_states, n_actions, transition, flat_reward, 0.9);
        assert!(built.is_ok());
        let model = built.expect("failed to create mdp");
        assert!((model.reward[[0, 0, 0]] - 0.5).abs() < 1e-9);
        assert!((model.reward[[1, 1, 1]] - 0.5).abs() < 1e-9);
    }

    // ---- value iteration ----

    #[test]
    fn test_value_iteration_two_state() {
        let model = two_state_deterministic();
        let solution = value_iteration(&model, 1e-9, 10_000).expect("failed to create sol");
        assert!(solution.converged);
        assert!(solution.value_function[1].abs() < 1e-6);
        assert!((solution.value_function[0] - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_value_iteration_three_state() {
        let model = three_state_mdp();
        let solution = value_iteration(&model, 1e-9, 10_000).expect("failed to create sol");
        assert!(solution.converged);
        let values = &solution.value_function;
        assert!(values[0] > 0.0);
        assert!(values[1] > values[0]);
        assert!((values[1] - 1.0).abs() < 1e-4);
        assert!((values[0] - 0.9).abs() < 1e-4);
    }

    #[test]
    fn test_value_iteration_policy_is_greedy() {
        let model = three_state_mdp();
        let solution = value_iteration(&model, 1e-9, 10_000).expect("failed to create sol");
        assert!(solution.converged);
        assert_eq!(solution.policy[0], 0);
        assert_eq!(solution.policy[1], 0);
    }

    #[test]
    fn test_value_iteration_convergence_flag() {
        let model = three_state_mdp();
        let solution = value_iteration(&model, 1e-12, 100_000).expect("failed to create sol");
        assert!(solution.converged);
    }

    #[test]
    fn test_value_iteration_stochastic() {
        let model = stochastic_mdp();
        let solution = value_iteration(&model, 1e-9, 10_000).expect("failed to create sol");
        assert!(solution.converged);
        assert!(
            solution.value_function[2].abs() < 1e-6,
            "absorbing state value must be 0"
        );
    }

    // ---- policy evaluation ----

    #[test]
    fn test_policy_evaluation_consistent() {
        let model = three_state_mdp();
        let reference = value_iteration(&model, 1e-12, 100_000).expect("failed to create vi");
        let evaluated = evaluate_policy(&model, &reference.policy, 1e-12, 100_000)
            .expect("failed to create v_eval");
        for s in 0..model.n_states {
            let (got, want) = (evaluated[s], reference.value_function[s]);
            assert!(
                (got - want).abs() < 1e-4,
                "state {}: eval {} vs vi {}",
                s,
                got,
                want
            );
        }
    }

    #[test]
    fn test_policy_evaluation_bad_policy_length() {
        let model = two_state_deterministic();
        let too_long = vec![0usize; 5];
        let outcome = evaluate_policy(&model, &too_long, 1e-9, 100);
        assert!(outcome.is_err());
    }

    // ---- policy iteration ----

    #[test]
    fn test_policy_iteration_equals_vi() {
        let model = three_state_mdp();
        let by_vi = value_iteration(&model, 1e-9, 10_000).expect("failed to create vi");
        let by_pi = policy_iteration(&model, 1e-9, 10_000).expect("failed to create pi");
        assert!(by_pi.converged);
        for s in 0..model.n_states {
            let (got, want) = (by_pi.value_function[s], by_vi.value_function[s]);
            assert!(
                (got - want).abs() < 1e-3,
                "state {}: pi={} vi={}",
                s,
                got,
                want
            );
        }
    }

    #[test]
    fn test_policy_iteration_stochastic() {
        let model = stochastic_mdp();
        let solution = policy_iteration(&model, 1e-9, 10_000).expect("failed to create sol");
        assert!(solution.converged);
    }

    // ---- modified policy iteration ----

    #[test]
    fn test_modified_policy_iteration_k1_like_vi() {
        let model = three_state_mdp();
        let by_vi = value_iteration(&model, 1e-9, 10_000).expect("failed to create vi");
        let by_mpi =
            modified_policy_iteration(&model, 1, 1e-9, 50_000).expect("failed to create mpi");
        assert!(by_mpi.converged);
        for s in 0..model.n_states {
            let (got, want) = (by_mpi.value_function[s], by_vi.value_function[s]);
            assert!(
                (got - want).abs() < 1e-3,
                "state {}: mpi={} vi={}",
                s,
                got,
                want
            );
        }
    }

    #[test]
    fn test_modified_policy_iteration_k10() {
        let model = stochastic_mdp();
        let solution =
            modified_policy_iteration(&model, 10, 1e-9, 10_000).expect("failed to create sol");
        assert!(solution.converged);
    }

    #[test]
    fn test_modified_policy_iteration_zero_k_error() {
        let model = two_state_deterministic();
        let outcome = modified_policy_iteration(&model, 0, 1e-9, 100);
        assert!(outcome.is_err());
    }

    // ---- lp_solve_mdp ----

    #[test]
    fn test_lp_solve_agrees_with_vi() {
        let model = three_state_mdp();
        let by_vi = value_iteration(&model, 1e-12, 100_000).expect("failed to create vi");
        let by_lp = lp_solve_mdp(&model).expect("failed to create lp");
        for s in 0..model.n_states {
            let (got, want) = (by_lp.value_function[s], by_vi.value_function[s]);
            assert!(
                (got - want).abs() < 1e-4,
                "state {}: lp={} vi={}",
                s,
                got,
                want
            );
        }
    }

    // ---- Q-learning ----

    #[test]
    fn test_qlearning_update() {
        // From an all-zero table: Q[0, 0] = 0 + 0.1 * (1.0 + 0.9 * 0 - 0) = 0.1.
        let mut agent = QLearning::new(3, 2, 0.1, 0.0, 0.9);
        agent.update(0, 0, 1.0, 1);
        assert!((agent.q_table[[0, 0]] - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_qlearning_greedy() {
        let mut agent = QLearning::new(3, 2, 0.1, 0.0, 0.9);
        agent.q_table[[0, 1]] = 5.0;
        assert_eq!(agent.greedy(0), 1);
    }

    #[test]
    fn test_qlearning_train_returns_length() {
        let model = three_state_mdp();
        let mut agent = QLearning::new(3, 2, 0.3, 0.1, 0.9);
        let episode_returns = agent
            .train(&model, 100, 20, 42)
            .expect("failed to create returns");
        assert_eq!(episode_returns.len(), 100);
    }

    #[test]
    fn test_qlearning_policy_shape() {
        let mut agent = QLearning::new(3, 2, 0.3, 0.1, 0.9);
        let model = three_state_mdp();
        let _ = agent.train(&model, 200, 30, 7).expect("failed to create _");
        let learned = agent.policy();
        assert_eq!(learned.len(), 3);
        assert!(learned.iter().all(|&a| a < 2));
    }

    #[test]
    fn test_qlearning_value_function() {
        let agent = QLearning::new(2, 2, 0.1, 0.0, 0.9);
        assert_eq!(agent.value_function().len(), 2);
    }

    // ---- SARSA ----

    #[test]
    fn test_sarsa_update() {
        // From an all-zero table: Q[0, 0] = 0 + 0.1 * (1.0 + 0.9 * 0 - 0) = 0.1.
        let mut agent = Sarsa::new(3, 2, 0.1, 0.0, 0.9);
        agent.update(0, 0, 1.0, 1, 0);
        assert!((agent.q_table[[0, 0]] - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_sarsa_train_returns_length() {
        let model = three_state_mdp();
        let mut agent = Sarsa::new(3, 2, 0.3, 0.1, 0.9);
        let episode_returns = agent
            .train(&model, 100, 20, 13)
            .expect("failed to create returns");
        assert_eq!(episode_returns.len(), 100);
    }

    #[test]
    fn test_sarsa_policy_valid() {
        let model = three_state_mdp();
        let mut agent = Sarsa::new(3, 2, 0.3, 0.1, 0.9);
        let _ = agent.train(&model, 200, 30, 99).expect("failed to create _");
        let learned = agent.policy();
        assert_eq!(learned.len(), 3);
        assert!(learned.iter().all(|&a| a < 2));
    }

    // ---- simulation ----

    #[test]
    fn test_simulate_length() {
        let model = three_state_mdp();
        let always_advance = vec![0usize, 0, 0];
        let (visited, taken, earned) = simulate(&model, &always_advance, 0, 5, 42);
        assert!(!visited.is_empty());
        assert_eq!(taken.len(), earned.len());
        assert!(taken.len() <= 5);
    }

    #[test]
    fn test_simulate_terminal_stops() {
        let model = two_state_deterministic();
        let only_action = vec![0usize; 2];
        let (states, _actions, _rewards) = simulate(&model, &only_action, 0, 100, 1);
        assert!(states.len() <= 3, "states.len() = {}", states.len());
    }
}