/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

//! Rurel is a flexible, reusable reinforcement learning (Q-learning) implementation in Rust.
//!
//! Implement the [Agent](mdp/trait.Agent.html) and [State](mdp/trait.State.html) traits for your
//! process, then create an [AgentTrainer](struct.AgentTrainer.html) and train it for your process.
//!
//! # Basic Example
//!
//! The following example defines the `State` as a position on a 21x21 2D matrix. The `Action`s
//! that can be taken are: go up, go down, go left and go right. Positions closer to (10, 10) are
//! assigned a higher reward.
//!
//! After training, the AgentTrainer will have assigned higher values to actions which move closer
//! to (10, 10).
//!
//! ```
//! use rurel::mdp::{State, Agent};
//!
//! #[derive(PartialEq, Eq, Hash, Clone)]
//! struct MyState { x: i32, y: i32 }
//! #[derive(PartialEq, Eq, Hash, Clone)]
//! struct MyAction { dx: i32, dy: i32 }
//!
//! impl State for MyState {
//!     type A = MyAction;
//!     fn reward(&self) -> f64 {
//!         // Negative Euclidean distance
//!         -((((10 - self.x).pow(2) + (10 - self.y).pow(2)) as f64).sqrt())
//!     }
//!     fn actions(&self) -> Vec<MyAction> {
//!         vec![
//!             MyAction { dx: 0, dy: -1 }, // up
//!             MyAction { dx: 0, dy: 1 },  // down
//!             MyAction { dx: -1, dy: 0 }, // left
//!             MyAction { dx: 1, dy: 0 },  // right
//!         ]
//!     }
//! }
//!
//! struct MyAgent { state: MyState }
//! impl Agent<MyState> for MyAgent {
//!     fn current_state(&self) -> &MyState {
//!         &self.state
//!     }
//!     fn take_action(&mut self, action: &MyAction) {
//!         self.state = MyState {
//!             x: (((self.state.x + action.dx) % 21) + 21) % 21, // (x + dx) mod 21
//!             y: (((self.state.y + action.dy) % 21) + 21) % 21, // (y + dy) mod 21
//!         };
//!     }
//! }
//!
//! use rurel::AgentTrainer;
//! use rurel::strategy::learn::QLearning;
//! use rurel::strategy::explore::RandomExploration;
//! use rurel::strategy::terminate::FixedIterations;
//!
//! let mut trainer = AgentTrainer::new();
//! let mut agent = MyAgent { state: MyState { x: 0, y: 0 }};
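//! // Train the agent with Q-learning, picking actions at random and stopping after
//! // 100,000 iterations. The three QLearning arguments are presumably the learning rate,
//! // the discount factor and the initial value for not-yet-seen state-action pairs.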
//! trainer.train(&mut agent,
//!               &QLearning::new(0.2, 0.01, 2.),
//!               &mut FixedIterations::new(100000),
//!               &RandomExploration::new());
//!
//! // Test to see if it worked
//! let test_state = MyState { x: 10, y: 9 };
//! let go_up = MyAction { dx: 0, dy: -1 };
//! let go_down = MyAction { dx: 0, dy: 1};
//! // Going down is better than going up
//! assert!(trainer.expected_value(&test_state, &go_down)
//!     > trainer.expected_value(&test_state, &go_up));
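//!
//! // The trainer can also report all learned action values for a state at once.
//! assert!(trainer.expected_values(&test_state).is_some());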
//! ```

pub mod mdp;
pub mod strategy;

use std::collections::HashMap;
use mdp::{Agent, State};
use strategy::explore::ExplorationStrategy;
use strategy::learn::LearningStrategy;
use strategy::terminate::TerminationStrategy;

/// An `AgentTrainer` can be trained using a certain [Agent](mdp/trait.Agent.html). After
/// training, the `AgentTrainer` contains learned knowledge about the process and can be queried
/// for it. For example, you can ask the `AgentTrainer` for the expected values of all possible
/// actions in a given state.
pub struct AgentTrainer<S>
    where S: State
{
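    /// The learned values: for each visited state, the estimated value of every action
    /// that has been tried from it.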
    q: HashMap<S, HashMap<S::A, f64>>,
}

impl<S> AgentTrainer<S>
    where S: State
{
    pub fn new() -> AgentTrainer<S> {
        AgentTrainer { q: HashMap::new() }
    }
    /// Fetches the learned values for the given state, by `Action`, or `None` if no value was
    /// learned.
    pub fn expected_values(&self, state: &S) -> Option<&HashMap<S::A, f64>> {
        // XXX: make associated const with empty map and remove Option?
        self.q.get(state)
    }
    /// Fetches the learned value for the given `Action` in the given `State`, or `None` if no
    /// value was learned.
    pub fn expected_value(&self, state: &S, action: &S::A) -> Option<f64> {
        self.q
            .get(state)
            .and_then(|m| m.get(action))
            .map(|&v| v)
    }
    /// Trains this `AgentTrainer` using the given `ExplorationStrategy`, `LearningStrategy` and
    /// `Agent`, until the given `TerminationStrategy` decides that training should stop.
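    ///
    /// Each iteration observes the current state, lets the `ExplorationStrategy` pick and apply
    /// an action, and then asks the `LearningStrategy` for an updated value estimate of that
    /// state-action pair. With the `QLearning` strategy this should correspond to the usual
    /// update `Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))`.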
    pub fn train(&mut self,
                 agent: &mut Agent<S>,
                 learning_strategy: &LearningStrategy<S>,
                 termination_strategy: &mut TerminationStrategy<S>,
                 exploration_strategy: &ExplorationStrategy<S>) {
        loop {
            let s_t = agent.current_state().clone();
            let action = exploration_strategy.pick_action(agent);

            // pick_action has already applied the chosen action, so the agent is now in the
            // successor state; observe that state and its reward.
            let s_t_next = agent.current_state();
            let r_t_next = s_t_next.reward();

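            // Ask the learning strategy for an updated estimate of the value of (s_t, action),
            // based on the values learned so far for the successor state, the previous
            // estimate for this state-action pair (if any) and the observed reward.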
            let v = {
                let old_value = self.q.get(&s_t).and_then(|m| m.get(&action));
                learning_strategy.value(&self.q.get(s_t_next), &old_value, r_t_next)
            };

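            // Store the updated estimate back into the Q table.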
            self.q.entry(s_t).or_insert_with(HashMap::new).insert(action, v);

            if termination_strategy.should_stop(&s_t_next) {
                break;
            }
        }
    }
}