// weightedcoin/weightedcoin.rs

/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use rurel::mdp::{Agent, State};
6use rurel::strategy::explore::RandomExploration;
7use rurel::strategy::learn::QLearning;
8use rurel::strategy::terminate::SinkStates;
9use rurel::AgentTrainer;
10
/// The gambler wins (and the episode ends) once the balance reaches this amount.
const TARGET: i32 = 100;
/// Heads threshold for the weighted coin. A uniform `u8` draw in `0..=255`
/// counts as heads when it is `<= WEIGHT`, so P(heads) = (WEIGHT + 1) / 256
/// (see `take_action`), i.e. ~0.395 here — a coin biased against the gambler.
const WEIGHT: u8 = 100;
13
/// An MDP state: the gambler's current bankroll.
///
/// `Debug` is derived so states can be printed when inspecting the trainer;
/// `PartialEq + Eq + Hash + Clone` are required for use as a Q-table key.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct CoinState {
    // Current balance; the episode is over at 0 (broke) or >= TARGET (won).
    balance: i32,
}
18
/// An MDP action: how much of the bankroll to stake on one coin flip.
///
/// `Debug` is derived so chosen actions can be printed when inspecting the
/// trainer; the remaining derives are required for use as a Q-table key.
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct CoinAction {
    // Stake for a single flip; always between 1 and the current balance.
    bet: i32,
}
23
24impl State for CoinState {
25    type A = CoinAction;
26
27    fn reward(&self) -> f64 {
28        if self.balance >= TARGET {
29            1.0
30        } else {
31            0.0
32        }
33    }
34
35    fn actions(&self) -> Vec<CoinAction> {
36        let bet_range = {
37            if self.balance < TARGET / 2 {
38                1..self.balance + 1
39            } else {
40                1..(TARGET - self.balance) + 1
41            }
42        };
43        bet_range.map(|bet| CoinAction { bet }).collect()
44    }
45}
46
/// The learning agent: wraps the gambler's current state so the trainer
/// can observe it and apply chosen actions to it.
struct CoinAgent {
    // Current balance state; replaced wholesale on every coin flip.
    state: CoinState,
}
50
51impl Agent<CoinState> for CoinAgent {
52    fn current_state(&self) -> &CoinState {
53        &self.state
54    }
55    fn take_action(&mut self, action: &CoinAction) {
56        //Update the state to:
57        self.state = CoinState {
58            balance: if rand::random::<u8>() <= WEIGHT {
59                self.state.balance + action.bet
60            }
61            //If the coin is heads, balance + bet
62            else {
63                self.state.balance - action.bet
64            }, //If the coin is tails, balance - bet
65        }
66    }
67}
68
69fn main() {
70    const TRIALS: i32 = 100000;
71    let mut trainer = AgentTrainer::new();
72    for trial in 0..TRIALS {
73        let mut agent = CoinAgent {
74            state: CoinState {
75                balance: 1 + trial % 98,
76            },
77        };
78        trainer.train(
79            &mut agent,
80            &QLearning::new(0.2, 1.0, 0.0),
81            &mut SinkStates {},
82            &RandomExploration::new(),
83        );
84    }
85
86    println!("Balance\tBet\tQ-value");
87    for balance in 1..TARGET {
88        let state = CoinState { balance };
89        let action = trainer.best_action(&state).unwrap();
90        println!(
91            "{}\t{}\t{}",
92            balance,
93            action.bet,
94            trainer.expected_value(&state, &action).unwrap(),
95        );
96    }
97}