1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
use std::num::NonZeroUsize;
use std::ops::{Index, IndexMut};

use ordered_float::OrderedFloat;
use rand::Rng;

use crate::board::{Board, Coord, Player};
use crate::bot_game::Bot;
use crate::mcts::heuristic::{Heuristic, ZeroHeuristic};

pub mod heuristic;

/// A contiguous run of node indices: `start, start + 1, ..., start + length - 1`.
///
/// The start is a `NonZeroUsize`, so a zero start index can be used elsewhere as a
/// "no range" marker.
#[derive(Clone, Copy, Debug)]
pub struct IdxRange {
    /// First index of the run (never 0).
    pub start: NonZeroUsize,
    /// Number of indices in the run.
    pub length: u8,
}

impl IdxRange {
    /// Returns the half-open range of absolute indices covered by this run.
    pub fn iter(&self) -> std::ops::Range<usize> {
        let first = self.start.get();
        let past_end = first + usize::from(self.length);
        first..past_end
    }

    /// Maps a position inside the run to its absolute index.
    ///
    /// # Panics
    /// Panics if `index >= length`.
    pub fn get(&self, index: usize) -> usize {
        let in_bounds = index < usize::from(self.length);
        assert!(in_bounds, "Index {} out of bounds", index);
        self.start.get() + index
    }
}

impl IntoIterator for IdxRange {
    type Item = usize;
    type IntoIter = std::ops::Range<usize>;

    /// Consumes the (copyable) range and yields its absolute indices in order.
    fn into_iter(self) -> Self::IntoIter {
        IdxRange::iter(&self)
    }
}

/// One position in the search tree, reached by playing `coord` from its parent node.
///
/// The statistics are accumulated by `mcts_build_tree`.
#[derive(Debug)]
pub struct Node {
    /// The move played from the parent position to reach this node (arbitrary for the root).
    pub coord: Coord,
    // Stored as a raw (start, length) pair rather than `Option<IdxRange>` to keep the struct
    // layout compact; a start of 0 means "children not created yet".
    children_start: usize,
    children_length: u8,
    /// Playouts through this node credited as wins.
    pub wins: u64,
    /// Playouts through this node counted as draws (weighted as half a win below).
    pub draws: u64,
    /// Playouts that passed through this node.
    pub visits: u64,
}

impl Node {
    /// Creates an unvisited node with no children for the move `coord`.
    fn new(coord: Coord) -> Self {
        Self { coord, children_start: 0, children_length: 0, wins: 0, draws: 0, visits: 0 }
    }

    /// UCT selection score: mean result + UCB1 exploration term + a heuristic bias that fades
    /// as this node gathers visits.
    ///
    /// Only meaningful when `self.visits > 0`; an unvisited node yields NaN.
    pub fn uct(&self, parent_visits: u64, heuristic: f32) -> f32 {
        let n = self.visits as f32;

        //TODO is this really the best heuristic formula? maybe let the heuristic decide the weight as well?
        let exploitation = (self.wins as f32 + 0.5 * self.draws as f32) / n;
        let exploration = (2.0 * (parent_visits as f32).ln() / n).sqrt();
        let bias = heuristic / (n + 1.0);

        exploitation + exploration + bias
    }

    /// The estimated value of this node in the range -1..1 (a draw counts as 0).
    ///
    /// Returns NaN when the node has not been visited.
    pub fn signed_value(&self) -> f32 {
        let score = 2.0 * self.wins as f32 + self.draws as f32;
        score / self.visits as f32 - 1.0
    }

    /// The index range of this node's children, or `None` if they have not been created.
    pub fn children(&self) -> Option<IdxRange> {
        let start = NonZeroUsize::new(self.children_start)?;
        Some(IdxRange { start, length: self.children_length })
    }

    /// Records `children` as this node's child index range.
    pub fn set_children(&mut self, children: IdxRange) {
        let IdxRange { start, length } = children;
        self.children_start = start.get();
        self.children_length = length;
    }
}

/// The result of a search: the board the search started from plus a flat arena of nodes.
///
/// Node 0 is the root. A node's children occupy a contiguous index range of `nodes`
/// (see [`Node::children`]). The tree can be indexed directly with `usize` node indices.
#[derive(Debug)]
pub struct Tree {
    /// The position the search started from, i.e. the position at node 0.
    pub root_board: Board,
    /// All nodes of the tree, addressed by index.
    pub nodes: Vec<Node>,
}

impl Tree {
    /// Creates an empty tree (no nodes yet) for `root_board`.
    pub fn new(root_board: Board) -> Self {
        Tree { root_board, nodes: Vec::new() }
    }

    /// Returns the move of the root child with the most visits.
    ///
    /// Ties go to the earliest child: `max_by_key` returns the *last* maximal element, so the
    /// children are scanned in reverse.
    ///
    /// # Panics
    /// Panics if the tree has no root node, or if the root has no (or an empty range of) children.
    pub fn best_move(&self) -> Coord {
        let children = self[0].children()
            .expect("Root node must have children");

        let best_child = children.iter().rev().max_by_key(|&child| {
            self[child].visits
        }).expect("Root node must have non-empty children");

        self[best_child].coord
    }

    /// The root node's estimated value in the range -1..1 (see [`Node::signed_value`]).
    ///
    /// NOTE(review): in `mcts_build_tree` the root's win count uses the opposite parity of its
    /// children, so this is presumably from the perspective of the player who is *not* to move
    /// in `root_board` — confirm before relying on the sign.
    pub fn signed_value(&self) -> f32 {
        self[0].signed_value()
    }
}

impl Index<usize> for Tree {
    type Output = Node;

    fn index(&self, index: usize) -> &Self::Output {
        &self.nodes[index]
    }
}

impl IndexMut<usize> for Tree {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        &mut self.nodes[index]
    }
}

/// Runs `iterations` rounds of Monte Carlo tree search from `board` and returns the search tree.
///
/// Each iteration:
/// 1. **Descend** from the root. A node's children are created the first time the node is
///    reached. If the node has unvisited children, one is chosen uniformly at random and the
///    descent stops there. Otherwise the child with the highest [`Node::uct`] score is followed,
///    until the board is finished.
/// 2. **Simulate** the rest of the game from the reached position with random moves.
/// 3. **Update** every node on the path. `visits` is always incremented. `wins` is credited on
///    alternating nodes: the reached node gets a win when the playout winner is not the player
///    to move at that node, and the credit flips at each step toward the root.
///
/// A draw (`Player::Neutral`) is resolved with a coin flip from `rand` instead of being tallied
/// in `Node::draws`. This has the same expected value as counting it as half a win.
///
/// # Panics
/// Panics if `iterations == 0` or if `board` is already finished.
pub fn mcts_build_tree<H: Heuristic, R: Rng>(board: &Board, iterations: u64, heuristic: &H, rand: &mut R) -> Tree {
    assert!(iterations > 0, "MCTS must run for at least 1 iteration");
    assert!(!board.is_done(), "Cannot build MCTS tree for done board");

    let mut tree = Tree::new(board.clone());
    // Nodes visited during the current iteration, root first; reused across iterations.
    let mut parent_list = Vec::with_capacity(81);

    //the actual coord doesn't matter, just pick something
    tree.nodes.push(Node::new(Coord::from_o(0)));

    for _ in 0..iterations {
        parent_list.clear();

        let mut curr_node: usize = 0;
        let mut curr_board = board.clone();

        while !curr_board.is_done() {
            parent_list.push(curr_node);

            //Init children
            let children = match tree[curr_node].children() {
                Some(children) => children,
                None => {
                    // Guarantees the child count fits in `IdxRange::length`.
                    static_assertions::const_assert!(Board::MAX_AVAILABLE_MOVES <= u8::MAX as u32);

                    let start = tree.nodes.len();
                    tree.nodes.extend(curr_board.available_moves().map(Node::new));
                    let length = (tree.nodes.len() - start) as u8;

                    let children = IdxRange {
                        start: NonZeroUsize::new(start)
                            .expect("the root occupies index 0, so children always start at >= 1"),
                        length,
                    };
                    tree[curr_node].set_children(children);
                    children
                }
            };

            //Exploration
            let unexplored_children = children.iter()
                .filter(|&c| tree[c].visits == 0);
            let count = unexplored_children.clone().count();

            if count != 0 {
                let child = unexplored_children.clone().nth(rand.gen_range(0..count))
                    .expect("we specifically selected the index based on the count already");

                curr_node = child;
                curr_board.play(tree[curr_node].coord);

                break;
            }

            //Selection: every child has visits > 0 here, so `uct` is well-defined.
            let parent_visits = tree[curr_node].visits;

            // The heuristic only depends on `curr_board`, which does not change while the
            // children are compared, so evaluate it once instead of once per child.
            // NOTE(review): this evaluates the *parent* position, so every child gets the same
            // heuristic value; evaluating the position after each child's move may have been
            // the intent.
            let heuristic_value = heuristic.evaluate(&curr_board);

            let selected = children.iter().max_by_key(|&child| {
                OrderedFloat(tree[child].uct(parent_visits, heuristic_value))
            }).expect("Board is not done, this node should have a child");

            curr_node = selected;
            curr_board.play(tree[curr_node].coord);
        }

        //Simulate
        let curr_player = curr_board.next_player;

        let won_by = loop {
            if let Some(won_by) = curr_board.won_by {
                break won_by;
            }

            curr_board.play(curr_board.random_available_move(rand)
                .expect("No winner, so board is not done yet"));
        };

        // The reached node was only pushed above if the loop broke via exploration's parent push;
        // in every case it has not been pushed itself yet.
        parent_list.push(curr_node);

        //Update
        // `won` is flipped before being applied, so the reached node is credited when the
        // winner is NOT `curr_player`; draws are resolved with a coin flip.
        let mut won = if won_by != Player::Neutral {
            won_by == curr_player
        } else {
            rand.gen()
        };

        for &update_node in parent_list.iter().rev() {
            won = !won;

            let node = &mut tree[update_node];
            node.visits += 1;
            node.wins += won as u64;
        }
    }

    assert_eq!(iterations, tree[0].visits, "implementation error");
    tree
}

/// A [`Bot`] that chooses each move by running [`mcts_build_tree`] for a fixed number of
/// iterations and playing the most-visited root move.
pub struct MCTSBot<H: Heuristic, R: Rng> {
    /// Number of search iterations per move.
    iterations: u64,
    /// Heuristic handed to the tree search.
    heuristic: H,
    /// Random source used by the tree search.
    rand: R,
}

impl<R: Rng> MCTSBot<ZeroHeuristic, R> {
    /// Creates a bot that searches with [`ZeroHeuristic`].
    pub fn new(iterations: u64, rand: R) -> Self {
        Self::new_with_heuristic(iterations, rand, ZeroHeuristic)
    }
}

impl<H: Heuristic, R: Rng> MCTSBot<H, R> {
    /// Creates a bot that runs `iterations` search iterations per move, guided by `heuristic`.
    pub fn new_with_heuristic(iterations: u64, rand: R, heuristic: H) -> Self {
        Self { iterations, heuristic, rand }
    }
}

impl<H: Heuristic, R: Rng> Bot for MCTSBot<H, R> {
    /// Returns `None` for a finished board, otherwise the best move found by the search.
    fn play(&mut self, board: &Board) -> Option<Coord> {
        if board.is_done() {
            return None;
        }

        let search = mcts_build_tree(board, self.iterations, &self.heuristic, &mut self.rand);
        Some(search.best_move())
    }
}