/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

//! Rurel is a flexible, reusable reinforcement learning (Q learning) implementation in Rust.
//!
//! Implement the [Agent](mdp/trait.Agent.html) and [State](mdp/trait.State.html) traits for your
//! process, then create an [AgentTrainer](struct.AgentTrainer.html) and train it for your process.
//!
//! # Basic Example
//!
//! The following example defines the `State` as a position on a 21x21 2D matrix. The `Action`s
//! that can be taken are: go up, go down, go left and go right. Positions closer to (10, 10) are
//! assigned a higher reward.
//!
//! After training, the `AgentTrainer` will have assigned higher values to actions which move
//! closer to (10, 10).
//!
//! ```
//! use rurel::mdp::{State, Agent};
//!
//! #[derive(PartialEq, Eq, Hash, Clone)]
//! struct MyState { x: i32, y: i32 }
//! #[derive(PartialEq, Eq, Hash, Clone)]
//! struct MyAction { dx: i32, dy: i32 }
//!
//! impl State for MyState {
//!     type A = MyAction;
//!     fn reward(&self) -> f64 {
//!         // Negative Euclidean distance to (10, 10)
//!         -((((10 - self.x).pow(2) + (10 - self.y).pow(2)) as f64).sqrt())
//!     }
//!     fn actions(&self) -> Vec<MyAction> {
//!         vec![MyAction { dx: 0, dy: -1 }, // up
//!              MyAction { dx: 0, dy: 1 },  // down
//!              MyAction { dx: -1, dy: 0 }, // left
//!              MyAction { dx: 1, dy: 0 },  // right
//!         ]
//!     }
//! }
//!
//! struct MyAgent { state: MyState }
//! impl Agent<MyState> for MyAgent {
//!     fn current_state(&self) -> &MyState {
//!         &self.state
//!     }
//!     fn take_action(&mut self, action: &MyAction) -> () {
//!         match action {
//!             &MyAction { dx, dy } => {
//!                 self.state = MyState {
//!                     x: (((self.state.x + dx) % 21) + 21) % 21, // (x + dx) mod 21
//!                     y: (((self.state.y + dy) % 21) + 21) % 21, // (y + dy) mod 21
//!                 }
//!             }
//!         }
//!     }
//! }
//!
//! use rurel::AgentTrainer;
//! use rurel::strategy::learn::QLearning;
//! use rurel::strategy::explore::RandomExploration;
//! use rurel::strategy::terminate::FixedIterations;
//!
//! let mut trainer = AgentTrainer::new();
//! let mut agent = MyAgent { state: MyState { x: 0, y: 0 }};
//! trainer.train(&mut agent,
//!               &QLearning::new(0.2, 0.01, 2.),
//!               &mut FixedIterations::new(100000),
//!               &RandomExploration::new());
//!
//! // Test to see if it worked
//! let test_state = MyState { x: 10, y: 9 };
//! let go_up = MyAction { dx: 0, dy: -1 };
//! let go_down = MyAction { dx: 0, dy: 1 };
//! // Going down is better than going up
//! assert!(trainer.expected_value(&test_state, &go_down)
//!         > trainer.expected_value(&test_state, &go_up));
//! ```

pub mod mdp;
pub mod strategy;

use std::collections::HashMap;

use mdp::{Agent, State};
use strategy::explore::ExplorationStrategy;
use strategy::learn::LearningStrategy;
use strategy::terminate::TerminationStrategy;

/// An `AgentTrainer` can be trained using a certain [Agent](mdp/trait.Agent.html). After
/// training, the `AgentTrainer` contains learned knowledge about the process and can be queried
/// for it. For example, you can ask the `AgentTrainer` for the expected values of all possible
/// actions in a given state.
pub struct AgentTrainer<S>
    where S: State
{
    q: HashMap<S, HashMap<S::A, f64>>,
}

impl<S> AgentTrainer<S>
    where S: State
{
    pub fn new() -> AgentTrainer<S> {
        AgentTrainer { q: HashMap::new() }
    }

    /// Fetches the learned values for the given state, by `Action`, or `None` if no value was
    /// learned.
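    ///
    /// A minimal usage sketch (marked `ignore`, so it is not compiled as a doctest); it assumes
    /// a trained `trainer` and the `MyState`/`MyAction` types from the crate-level example:
    ///
    /// ```ignore
    /// if let Some(values) = trainer.expected_values(&MyState { x: 10, y: 9 }) {
    ///     for (action, value) in values {
    ///         println!("dx={} dy={} -> {}", action.dx, action.dy, value);
    ///     }
    /// }
    /// ```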
    pub fn expected_values(&self, state: &S) -> Option<&HashMap<S::A, f64>> {
        // XXX: make associated const with empty map and remove Option?
        self.q.get(state)
    }

    /// Fetches the learned value for the given `Action` in the given `State`, or `None` if no
    /// value was learned.
    pub fn expected_value(&self, state: &S, action: &S::A) -> Option<f64> {
        self.q
            .get(state)
            .and_then(|m| m.get(action).map(|&v| v))
    }

    /// Trains this `AgentTrainer` using the given `ExplorationStrategy`, `LearningStrategy` and
    /// `Agent`, until the given `TerminationStrategy` decides that training should stop.
    pub fn train(&mut self,
                 agent: &mut Agent<S>,
                 learning_strategy: &LearningStrategy<S>,
                 termination_strategy: &mut TerminationStrategy<S>,
                 exploration_strategy: &ExplorationStrategy<S>)
                 -> () {
        loop {
            // Remember the state we started from, then let the exploration strategy pick an
            // action and apply it to the agent.
            let s_t = agent.current_state().clone();
            let action = exploration_strategy.pick_action(agent);

            // Observe the state the action led to, and the reward obtained there.
            let s_t_next = agent.current_state();
            let r_t_next = s_t_next.reward();

            // Let the learning strategy combine the values of the next state's actions, the
            // previously learned value (if any) and the new reward into an updated value.
            let v = {
                let old_value = self.q.get(&s_t).and_then(|m| m.get(&action));
                learning_strategy.value(&self.q.get(s_t_next), &old_value, r_t_next)
            };

            self.q.entry(s_t).or_insert_with(HashMap::new).insert(action, v);

            if termination_strategy.should_stop(&s_t_next) {
                break;
            }
        }
    }
}
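
// The module below is an illustrative smoke-test sketch added for documentation purposes; the
// `LineState`, `Move` and `LineAgent` types are assumptions made up for this test, mirroring the
// crate-level example above. It only checks that `train` records values that can then be queried
// through `expected_values`.
#[cfg(test)]
mod tests {
    use super::AgentTrainer;
    use super::mdp::{Agent, State};
    use super::strategy::learn::QLearning;
    use super::strategy::explore::RandomExploration;
    use super::strategy::terminate::FixedIterations;

    #[derive(PartialEq, Eq, Hash, Clone)]
    struct LineState { x: i32 }

    #[derive(PartialEq, Eq, Hash, Clone)]
    struct Move { dx: i32 }

    impl State for LineState {
        type A = Move;
        fn reward(&self) -> f64 {
            // Positions closer to x = 5 get a higher (less negative) reward.
            -((5 - self.x).abs() as f64)
        }
        fn actions(&self) -> Vec<Move> {
            vec![Move { dx: -1 }, Move { dx: 1 }]
        }
    }

    struct LineAgent { state: LineState }

    impl Agent<LineState> for LineAgent {
        fn current_state(&self) -> &LineState {
            &self.state
        }
        fn take_action(&mut self, action: &Move) -> () {
            // Move along a line of 11 cells, wrapping around at the edges.
            self.state = LineState { x: (((self.state.x + action.dx) % 11) + 11) % 11 };
        }
    }

    #[test]
    fn training_records_values_for_visited_states() {
        let mut trainer = AgentTrainer::new();
        let mut agent = LineAgent { state: LineState { x: 0 } };
        trainer.train(&mut agent,
                      &QLearning::new(0.2, 0.01, 2.),
                      &mut FixedIterations::new(10000),
                      &RandomExploration::new());

        // The start state is visited on the very first iteration, so it must have at least one
        // learned action value afterwards.
        let start = LineState { x: 0 };
        let values = trainer.expected_values(&start);
        assert!(values.is_some());
        assert!(!values.unwrap().is_empty());
    }
}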