#[cfg(test)]
#[path = "../../../tests/unit/algorithms/mdp/simulator_test.rs"]
mod simulator_test;

use super::*;
use crate::utils::{parallel_into_collect, CollectGroupBy, Parallelism};

/// A type which keeps track of all state-action estimates.
pub type StateEstimates<S> = HashMap<S, ActionEstimates<S>>;

/// A simulator to train an agent with multiple episodes.
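///
/// A rough usage sketch (kept out of doctests via `ignore`): the strategy, agent and
/// `Parallelism` values below are placeholders used only to illustrate the call order,
/// not a verified example of concrete types from this crate.
///
/// ```ignore
/// // assumed to exist elsewhere: `MyLearning: LearningStrategy<MyState>`,
/// // `MyPolicy: PolicyStrategy<MyState>`, agents of a type implementing `Agent<MyState>`,
/// // and a configured `parallelism: Parallelism`.
/// let mut simulator = Simulator::new(Box::new(MyLearning), Box::new(MyPolicy));
///
/// // run one episode per agent; estimates for the same state-action pair are averaged
/// let agents = simulator.run_episodes(agents, parallelism, |_state, values| {
///     values.iter().sum::<f64>() / values.len() as f64
/// });
///
/// // query the learned (greedy) policy for a state of interest
/// if let Some((action, estimate)) = simulator.get_optimal_policy(&some_state) {
///     // act on the best known action and its estimate
/// }
/// ```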
pub struct Simulator<S: State> {
    q: StateEstimates<S>,
    learning_strategy: Box<dyn LearningStrategy<S> + Send + Sync>,
    policy_strategy: Box<dyn PolicyStrategy<S> + Send + Sync>,
}

impl<S: State> Simulator<S> {
    /// Creates a new instance of the MDP simulator.
    pub fn new(
        learning_strategy: Box<dyn LearningStrategy<S> + Send + Sync>,
        policy_strategy: Box<dyn PolicyStrategy<S> + Send + Sync>,
    ) -> Self {
        Self { q: Default::default(), learning_strategy, policy_strategy }
    }

    /// Returns a learned optimal policy for the given state.
    pub fn get_optimal_policy(&self, state: &S) -> Option<(<S as State>::Action, f64)> {
        self.q.get(state).and_then(|estimates| {
            let strategy: Box<dyn PolicyStrategy<S>> = Box::new(Greedy::default());
            strategy
                .select(estimates)
                .and_then(|action| estimates.data().get(&action).map(|estimate| (action, *estimate)))
        })
    }

    /// Gets state estimates.
    pub fn get_state_estimates(&self) -> &StateEstimates<S> {
        &self.q
    }

    /// Sets action estimates for the given state.
    pub fn set_action_estimates(&mut self, state: S, estimates: ActionEstimates<S>) {
        self.q.insert(state, estimates);
    }

    /// Runs a single episode for each of the given agents in parallel.
    pub fn run_episodes<A>(
        &mut self,
        agents: Vec<Box<A>>,
        parallelism: Parallelism,
        reducer: impl Fn(&S, &[f64]) -> f64,
    ) -> Vec<Box<A>>
    where
        A: Agent<S> + Send + Sync,
    {
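        // play one episode per agent in parallel: every worker reads the shared estimates
        // and returns its own locally updated copy together with the agent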
        let (agents, qs): (Vec<_>, Vec<_>) =
            parallel_into_collect(agents.into_iter().enumerate().collect(), |(idx, mut agent)| {
                parallelism.thread_pool_execute(idx, || {
                    let qs = Self::run_episode(
                        agent.as_mut(),
                        self.learning_strategy.as_ref(),
                        self.policy_strategy.as_ref(),
                        &self.q,
                    );
                    (agent, qs)
                })
            })
            .into_iter()
            .unzip();

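        // fold the per-agent results back into the shared estimates: all values collected
        // for the same state-action pair are combined into one estimate by the reducer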
        merge_vec_maps(qs, |(state, values)| {
            let action_values = self.q.entry(state.clone()).or_insert_with(ActionEstimates::default);
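            // flatten each agent's `ActionEstimates` into a plain action-value map so the
            // values can be grouped per action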
            let vec_map = values.into_iter().map(|estimates| estimates.into()).collect();
            merge_vec_maps(vec_map, |(action, values)| {
                action_values.insert(action, reducer(&state, values.as_slice()));
            });
            action_values.recalculate_min_max();
        });

        agents
    }

    fn run_episode(
        agent: &mut dyn Agent<S>,
        learning_strategy: &(dyn LearningStrategy<S> + Send + Sync),
        policy_strategy: &(dyn PolicyStrategy<S> + Send + Sync),
        q: &StateEstimates<S>,
    ) -> StateEstimates<S> {
        let mut q_new = StateEstimates::new();

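        // interact with the environment until the policy yields no action for the current state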
        loop {
            let old_state = agent.get_state().clone();
            Self::ensure_actions(&mut q_new, q, &old_state, agent);
            let old_estimates = q_new.get(&old_state).unwrap();

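            // let the exploration policy pick the next action; `None` terminates the episode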
            let action = policy_strategy.select(old_estimates);
            if action.is_none() {
                break;
            }

            let action = action.unwrap();
            agent.take_action(&action);
            let old_value = *old_estimates.data().get(&action).unwrap();

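            // observe the state reached by the action and the reward it yields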
            let next_state = agent.get_state().clone();
            let reward_value = next_state.reward();

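            // make sure estimates exist for the next state, then ask the learning strategy for
            // an updated value based on the reward, the old estimate and the next state's estimates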
            Self::ensure_actions(&mut q_new, q, &next_state, agent);
            let new_estimates = q_new.get(&next_state).unwrap();
            let new_value = learning_strategy.value(reward_value, old_value, new_estimates);

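            // store the updated estimate for the state-action pair just taken and refresh
            // the cached min/max values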
            q_new.entry(old_state.clone()).and_modify(|estimates| {
                estimates.insert(action.clone(), new_value);
                estimates.recalculate_min_max();
            });
        }

        q_new
    }

    fn ensure_actions(q_new: &mut StateEstimates<S>, q: &StateEstimates<S>, state: &S, agent: &dyn Agent<S>) {
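        // seed the local estimates for the state: prefer a copy of the shared estimates,
        // otherwise ask the agent which actions are available from this state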
        match (q_new.get(state), q.get(state)) {
            (None, Some(estimates)) => q_new.insert(state.clone(), estimates.clone()),
            (None, None) => q_new.insert(state.clone(), agent.get_actions(state)),
            (Some(_), _) => None,
        };
    }
}

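/// Flattens multiple maps into groups sharing the same key and passes each `(key, values)`
/// group to `merge_func`.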
fn merge_vec_maps<K: Eq + Hash, V, F: FnMut((K, Vec<V>))>(vec_map: Vec<HashMap<K, V>>, merge_func: F) {
    vec_map.into_iter().flat_map(|q| q.into_iter()).collect_group_by().into_iter().for_each(merge_func)
}