rps/
rps.rs

1use rand::Rng;
2use std::io::{self, Write};
3use std::thread;
4use std::time::Duration;
5use vexus::{Activation, NeuralNetwork, Sigmoid};
6
/// Model file name handed to the first predictor (used by both game modes).
const MODEL_FILE_1: &str = "rps_model_1.json";
/// Model file name handed to the second predictor in AI-vs-AI mode.
const MODEL_FILE_2: &str = "rps_model_2.json";
/// Number of previous opponent moves fed to the network as one input window.
const HISTORY_LENGTH: usize = 32;
/// Forward/backward passes run on each training window every time `train` is called.
const TRAINING_ITERATIONS: usize = 10;
/// Number of rounds played in AI-vs-AI mode.
const AI_VS_AI_GAMES: usize = 2_000;
/// Pause between rounds in AI-vs-AI mode, in milliseconds.
const AI_VS_AI_DELAY_MS: u64 = 0;
13
/// One of the three hands a player can throw.
///
/// The explicit discriminants double as indices: `Move::from_index` maps
/// 0/1/2 back to Rock/Paper/Scissors, and `to_input_vec` sets the one-hot
/// slot at `self as usize`.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
enum Move {
    Rock = 0,
    Paper = 1,
    Scissors = 2,
}
20
21impl Move {
22    fn from_str(s: &str) -> Option<Move> {
23        match s.to_lowercase().as_str() {
24            "r" | "rock" => Some(Move::Rock),
25            "p" | "paper" => Some(Move::Paper),
26            "s" | "scissors" => Some(Move::Scissors),
27            _ => None,
28        }
29    }
30
31    fn from_index(idx: usize) -> Move {
32        match idx % 3 {
33            0 => Move::Rock,
34            1 => Move::Paper,
35            2 => Move::Scissors,
36            _ => unreachable!(),
37        }
38    }
39
40    fn beats(&self, other: &Move) -> bool {
41        match (self, other) {
42            (Move::Rock, Move::Scissors) => true,
43            (Move::Paper, Move::Rock) => true,
44            (Move::Scissors, Move::Paper) => true,
45            _ => false,
46        }
47    }
48
49    fn random() -> Move {
50        let mut rng = rand::rng();
51        match rng.random_range(0..3) {
52            0 => Move::Rock,
53            1 => Move::Paper,
54            _ => Move::Scissors,
55        }
56    }
57
58    fn counter(&self) -> Move {
59        match self {
60            Move::Rock => Move::Paper,
61            Move::Paper => Move::Scissors,
62            Move::Scissors => Move::Rock,
63        }
64    }
65
66    fn to_string(&self) -> &'static str {
67        match self {
68            Move::Rock => "Rock",
69            Move::Paper => "Paper",
70            Move::Scissors => "Scissors",
71        }
72    }
73
74    fn to_input_vec(&self) -> Vec<f32> {
75        let mut result = vec![0.0, 0.0, 0.0];
76        result[*self as usize] = 1.0;
77        result
78    }
79}
80
81struct MovePredictor {
82    nn: NeuralNetwork,
83    player_history: Vec<Move>,
84    initialized: bool,
85    model_file: String,
86}
87
88impl MovePredictor {
89    fn new(model_file: &str) -> Self {
90        // Try to load an existing model or create a new one
91        let nn = NeuralNetwork::new(
92            vec![HISTORY_LENGTH * 3, 32, 32, 16, 3],
93            0.1,
94            Box::new(Sigmoid),
95        );
96
97        MovePredictor {
98            nn,
99            player_history: Vec::new(),
100            initialized: false,
101            model_file: model_file.to_string(),
102        }
103    }
104
105    fn record_move(&mut self, player_move: Move) {
106        self.player_history.push(player_move);
107
108        // Keep only the most recent moves
109        if self.player_history.len() > HISTORY_LENGTH * 2 {
110            self.player_history =
111                self.player_history[self.player_history.len() - HISTORY_LENGTH * 2..].to_vec();
112        }
113
114        // Mark as initialized once we have enough history
115        if self.player_history.len() >= HISTORY_LENGTH {
116            self.initialized = true;
117        }
118    }
119
120    fn train(&mut self) {
121        if !self.initialized || self.player_history.len() < HISTORY_LENGTH + 1 {
122            return;
123        }
124
125        // Train on sequences in the history
126        for i in 0..self.player_history.len() - HISTORY_LENGTH {
127            let inputs = self.history_to_input(&self.player_history[i..i + HISTORY_LENGTH]);
128            let target = self.player_history[i + HISTORY_LENGTH].to_input_vec();
129
130            // Train multiple times on each sequence to reinforce learning
131            for _ in 0..TRAINING_ITERATIONS {
132                let outputs = self.nn.forward(&inputs);
133
134                // Calculate errors (expected - actual)
135
136                self.nn.backpropagate(&target);
137            }
138        }
139    }
140
141    fn predict_next_move(&mut self) -> Move {
142        if !self.initialized || self.player_history.len() < HISTORY_LENGTH {
143            return Move::random();
144        }
145
146        // Get the last HISTORY_LENGTH moves
147        let recent_history = &self.player_history[self.player_history.len() - HISTORY_LENGTH..];
148        let inputs = self.history_to_input(recent_history);
149
150        // Forward pass through the neural network
151        let outputs = self.nn.forward(&inputs);
152
153        // Find the move with highest probability
154        let mut max_idx = 0;
155        let mut max_val = outputs[0];
156
157        for (i, &val) in outputs.iter().enumerate().skip(1) {
158            if val > max_val {
159                max_val = val;
160                max_idx = i;
161            }
162        }
163
164        // Return the move corresponding to the highest output
165        Move::from_index(max_idx)
166    }
167
168    fn make_move(&mut self) -> Move {
169        // For AI gameplay: predict the next move and choose a strategic response
170        if !self.initialized || self.player_history.len() < HISTORY_LENGTH {
171            // When not enough history, pick a random move
172            return Move::random();
173        }
174
175        // More complex strategy than just counter-picking:
176        // Occasionally be random to prevent being too predictable
177        let mut rng = rand::rng();
178        if rng.random_bool(0.2) {
179            // 20% chance of random move
180            return Move::random();
181        }
182
183        // Otherwise, predict and counter
184        let predicted_next = self.predict_next_move();
185        predicted_next.counter()
186    }
187
188    fn history_to_input(&self, history: &[Move]) -> Vec<f32> {
189        let mut inputs = Vec::with_capacity(history.len() * 3);
190        for &m in history {
191            inputs.extend_from_slice(&m.to_input_vec());
192        }
193        inputs
194    }
195}
196
/// Game variant selected from the start-up menu.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GameMode {
    /// A human plays against one predictor.
    PlayerVsAI,
    /// Two predictors play each other for a fixed number of rounds.
    AIVsAI,
}
201
202fn get_game_mode() -> GameMode {
203    loop {
204        println!("=== ROCK PAPER SCISSORS with AI ===");
205        println!("Select game mode:");
206        println!("1. Player vs AI");
207        println!("2. AI vs AI");
208        print!("Enter your choice (1-2): ");
209        io::stdout().flush().unwrap();
210
211        let mut input = String::new();
212        io::stdin()
213            .read_line(&mut input)
214            .expect("Failed to read line");
215
216        match input.trim() {
217            "1" => return GameMode::PlayerVsAI,
218            "2" => return GameMode::AIVsAI,
219            _ => {
220                println!("Invalid choice! Please enter 1 or 2.");
221                println!();
222            }
223        }
224    }
225}
226
227fn player_vs_ai_mode() {
228    println!("\n=== PLAYER VS AI MODE ===");
229    println!("Enter 'q' to quit at any time");
230
231    let mut player_score = 0;
232    let mut computer_score = 0;
233    let mut predictor = MovePredictor::new(MODEL_FILE_1);
234
235    loop {
236        // Get player's move
237        print!("Enter your move (r)ock, (p)aper, (s)cissors: ");
238        io::stdout().flush().unwrap();
239
240        let mut input = String::new();
241        io::stdin()
242            .read_line(&mut input)
243            .expect("Failed to read line");
244
245        let input = input.trim();
246        if input == "q" || input == "quit" {
247            break;
248        }
249
250        let player_move = match Move::from_str(input) {
251            Some(m) => m,
252            None => {
253                println!("Invalid move! Please enter 'r', 'p', or 's'.");
254                continue;
255            }
256        };
257
258        // Computer predicts and makes its move
259        let predicted_move = predictor.predict_next_move();
260        let computer_move = predicted_move.counter();
261
262        // Record the player's move for future predictions
263        predictor.record_move(player_move);
264
265        // Train the neural network with the updated history
266        predictor.train();
267
268        println!("You chose: {}", player_move.to_string());
269        println!("Computer chose: {}", computer_move.to_string());
270
271        // Determine the winner
272        if player_move == computer_move {
273            println!("It's a tie!");
274        } else if player_move.beats(&computer_move) {
275            println!("You win this round!");
276            player_score += 1;
277        } else {
278            println!("Computer wins this round!");
279            computer_score += 1;
280        }
281
282        println!(
283            "Score - You: {}, Computer: {}",
284            player_score, computer_score
285        );
286        println!();
287    }
288
289    // Save the final model
290
291    println!("\nFinal Score:");
292    println!("You: {}", player_score);
293    println!("Computer: {}", computer_score);
294
295    if player_score > computer_score {
296        println!("Congratulations! You won the game!");
297    } else if player_score < computer_score {
298        println!("Better luck next time! Computer won the game.");
299    } else {
300        println!("It's a tie game!");
301    }
302}
303
304fn ai_vs_ai_mode() {
305    println!("\n=== AI VS AI MODE ===");
306    println!(
307        "The AIs will play {} games against each other",
308        AI_VS_AI_GAMES
309    );
310    println!("Press Ctrl+C to stop at any time");
311    println!();
312
313    let mut ai1 = MovePredictor::new(MODEL_FILE_1);
314    let mut ai2 = MovePredictor::new(MODEL_FILE_2);
315
316    let mut ai1_score = 0;
317    let mut ai2_score = 0;
318    let mut ties = 0;
319
320    println!("Game starting...");
321
322    // Initialize with some random moves to get history
323    for _ in 0..HISTORY_LENGTH {
324        let random_move = Move::random();
325        ai1.record_move(random_move);
326        ai2.record_move(random_move);
327    }
328
329    for game in 1..=AI_VS_AI_GAMES {
330        // AI 1 makes a move
331        let ai1_move = ai1.make_move();
332
333        // AI 2 makes a move
334        let ai2_move = ai2.make_move();
335
336        // Record each other's moves
337        ai1.record_move(ai2_move);
338        ai2.record_move(ai1_move);
339
340        // Train both AIs
341        ai1.train();
342        ai2.train();
343
344        if game % 100 == 0 {
345            // Determine winner
346            println!(
347                "Game {}: AI1 chose {}, AI2 chose {}",
348                game,
349                ai1_move.to_string(),
350                ai2_move.to_string()
351            );
352        }
353        if ai1_move == ai2_move {
354            // println!("Game {}: It's a tie!", game);
355            ties += 1;
356        } else if ai1_move.beats(&ai2_move) {
357            // println!("Game {}: AI1 wins!", game);
358            ai1_score += 1;
359        } else {
360            // println!("Game {}: AI2 wins!", game);
361            ai2_score += 1;
362        }
363
364        // println!(
365        //      "Current score - AI1: {}, AI2: {}, Ties: {}",
366        //      ai1_score, ai2_score, ties
367        //  );
368        //  println!();
369
370        // Save models occasionally
371
372        // Delay to make it easier to follow
373        thread::sleep(Duration::from_millis(AI_VS_AI_DELAY_MS));
374    }
375
376    // Save final models
377
378    println!("\nFinal Score after {} games:", AI_VS_AI_GAMES);
379    println!("AI1: {}", ai1_score);
380    println!("AI2: {}", ai2_score);
381    println!("Ties: {}", ties);
382
383    if ai1_score > ai2_score {
384        println!("AI1 won the tournament!");
385    } else if ai1_score < ai2_score {
386        println!("AI2 won the tournament!");
387    } else {
388        println!("The tournament ended in a tie!");
389    }
390}
391
392fn main() {
393    let game_mode = get_game_mode();
394
395    match game_mode {
396        GameMode::PlayerVsAI => player_vs_ai_mode(),
397        GameMode::AIVsAI => ai_vs_ai_mode(),
398    }
399}