rustoku_lib/core/
masks.rs

/// Bitmask bookkeeping for a Rustoku board: which numbers already appear in
/// each row, each column, and each 3x3 box.
///
/// Each unit (row, column, or box) owns one `u16`. For a number `n` in
/// `1..=9`, bit `n - 1` of a unit's mask is set when `n` is present somewhere
/// in that unit and clear otherwise.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Masks {
    /// Presence mask for each row, indexed by row `0..9`.
    row_masks: [u16; 9],
    /// Presence mask for each column, indexed by column `0..9`.
    col_masks: [u16; 9],
    /// Presence mask for each 3x3 box, indexed by box number `0..9`.
    box_masks: [u16; 9],
}
12
13impl Masks {
14    pub(super) fn new() -> Self {
15        Masks {
16            row_masks: [0; 9],
17            col_masks: [0; 9],
18            box_masks: [0; 9],
19        }
20    }
21
22    /// Computes the index of the 3x3 box based on the row and column indices.
23    pub(super) fn get_box_idx(r: usize, c: usize) -> usize {
24        (r / 3) * 3 + (c / 3)
25    }
26
27    /// Adds a number to the masks for the specified row, column, and box.
28    pub(super) fn add_number(&mut self, r: usize, c: usize, num: u8) {
29        let bit_to_set = 1 << (num - 1);
30        let box_idx = Self::get_box_idx(r, c);
31        self.row_masks[r] |= bit_to_set;
32        self.col_masks[c] |= bit_to_set;
33        self.box_masks[box_idx] |= bit_to_set;
34    }
35
36    /// Removes a number from the masks for the specified row, column, and box.
37    pub(super) fn remove_number(&mut self, r: usize, c: usize, num: u8) {
38        let bit_to_unset = 1 << (num - 1);
39        let box_idx = Self::get_box_idx(r, c);
40        self.row_masks[r] &= !bit_to_unset;
41        self.col_masks[c] &= !bit_to_unset;
42        self.box_masks[box_idx] &= !bit_to_unset;
43    }
44
45    /// Checks if a number can be safely placed in the specified cell.
46    pub fn is_safe(&self, r: usize, c: usize, num: u8) -> bool {
47        let bit_to_check = 1 << (num - 1);
48        let box_idx = Self::get_box_idx(r, c);
49
50        (self.row_masks[r] & bit_to_check == 0)
51            && (self.col_masks[c] & bit_to_check == 0)
52            && (self.box_masks[box_idx] & bit_to_check == 0)
53    }
54
55    /// Computes the candidates mask for a specific cell based on the current masks.
56    pub(super) fn compute_candidates_mask_for_cell(&self, r: usize, c: usize) -> u16 {
57        let row_mask = self.row_masks[r];
58        let col_mask = self.col_masks[c];
59        let box_mask = self.box_masks[Self::get_box_idx(r, c)];
60        let used = row_mask | col_mask | box_mask;
61        !used & 0x1FF
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    // Helper for creating the single-bit mask of a number (bit `n - 1`).
    fn bit(n: u8) -> u16 {
        1u16 << (n - 1)
    }

    // Helper for OR-ing the bits of several numbers into one expected mask.
    fn bits(nums: &[u8]) -> u16 {
        nums.iter().map(|&n| bit(n)).fold(0, |acc, b| acc | b)
    }

    #[test]
    fn test_new_initializes_empty_masks() {
        let masks = Masks::new();
        assert_eq!(masks.row_masks, [0; 9]);
        assert_eq!(masks.col_masks, [0; 9]);
        assert_eq!(masks.box_masks, [0; 9]);
    }

    #[test]
    fn test_get_box_idx_top_left_box() {
        assert_eq!(Masks::get_box_idx(0, 0), 0);
        assert_eq!(Masks::get_box_idx(1, 1), 0);
        assert_eq!(Masks::get_box_idx(2, 2), 0);
    }

    #[test]
    fn test_get_box_idx_middle_box() {
        assert_eq!(Masks::get_box_idx(3, 3), 4);
        assert_eq!(Masks::get_box_idx(4, 4), 4);
        assert_eq!(Masks::get_box_idx(5, 5), 4);
    }

    #[test]
    fn test_get_box_idx_bottom_right_box() {
        assert_eq!(Masks::get_box_idx(6, 6), 8);
        assert_eq!(Masks::get_box_idx(7, 7), 8);
        assert_eq!(Masks::get_box_idx(8, 8), 8);
    }

    #[test]
    fn test_get_box_idx_various_boxes() {
        assert_eq!(Masks::get_box_idx(0, 3), 1); // Top-middle
        assert_eq!(Masks::get_box_idx(3, 0), 3); // Middle-left
        assert_eq!(Masks::get_box_idx(0, 8), 2); // Top-right
        assert_eq!(Masks::get_box_idx(8, 0), 6); // Bottom-left
    }

    #[test]
    fn test_add_number_single_cell() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], bit(1));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.box_masks[0], bit(1));
    }

    #[test]
    fn test_add_number_multiple_in_same_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.add_number(0, 1, 5);
        assert_eq!(masks.row_masks[0], bits(&[1, 5]));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.col_masks[1], bit(5));
        assert_eq!(masks.box_masks[0], bits(&[1, 5]));
    }

    #[test]
    fn test_add_number_to_different_box() {
        let mut masks = Masks::new();
        masks.add_number(8, 8, 9);
        assert_eq!(masks.row_masks[8], bit(9));
        assert_eq!(masks.col_masks[8], bit(9));
        assert_eq!(masks.box_masks[8], bit(9));
    }

    #[test]
    fn test_add_number_only_touches_its_units() {
        let mut masks = Masks::new();
        masks.add_number(4, 7, 6); // Row 4, column 7, box 5
        for i in 0..9 {
            let expected_row = if i == 4 { bit(6) } else { 0 };
            let expected_col = if i == 7 { bit(6) } else { 0 };
            let expected_box = if i == 5 { bit(6) } else { 0 };
            assert_eq!(masks.row_masks[i], expected_row, "row {}", i);
            assert_eq!(masks.col_masks[i], expected_col, "col {}", i);
            assert_eq!(masks.box_masks[i], expected_box, "box {}", i);
        }
    }

    #[test]
    fn test_add_number_already_present_no_change() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        let before = masks; // `Masks` is `Copy`
        masks.add_number(0, 0, 1); // Adding again must be idempotent
        // Check every unit, not only the row: the same call also updates the
        // column and box masks.
        assert_eq!(masks, before);
        assert_eq!(masks.row_masks[0], bit(1));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.box_masks[0], bit(1));
    }

    #[test]
    fn test_remove_number_single_value_from_cell() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1); // Add 1
        masks.remove_number(0, 0, 1); // Remove 1
        assert_eq!(masks.row_masks[0], 0);
        assert_eq!(masks.col_masks[0], 0);
        assert_eq!(masks.box_masks[0], 0);
        assert_eq!(masks, Masks::new()); // Nothing else left behind
    }

    #[test]
    fn test_remove_number_from_shared_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.add_number(0, 1, 5);
        masks.remove_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], bit(5)); // Only 5 should remain in row 0
        assert_eq!(masks.col_masks[0], 0); // Col 0 should be clear
        assert_eq!(masks.col_masks[1], bit(5)); // Col 1 should still have 5
        assert_eq!(masks.box_masks[0], bit(5)); // Box 0 should only have 5
    }

    #[test]
    fn test_remove_number_not_present_no_change() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        let before = masks; // `Masks` is `Copy`
        masks.remove_number(0, 0, 3); // Remove 3 (not present)
        // Row, column, and box masks must all be untouched.
        assert_eq!(masks, before);
        assert_eq!(masks.row_masks[0], bit(1));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.box_masks[0], bit(1));
    }

    #[test]
    fn test_is_safe_on_empty_board_always_true() {
        let masks = Masks::new();
        assert!(masks.is_safe(0, 0, 1)); // 1 should be safe anywhere
        assert!(masks.is_safe(8, 8, 9)); // 9 should be safe anywhere
    }

    #[test]
    fn test_is_safe_conflict_in_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1); // Place 1 at (0,0)
        assert!(!masks.is_safe(0, 1, 1)); // 1 should not be safe in same row
        assert!(masks.is_safe(0, 1, 2)); // 2 should be safe in same row
    }

    #[test]
    fn test_is_safe_conflict_in_column() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1); // Place 1 at (0,0)
        assert!(!masks.is_safe(1, 0, 1)); // 1 should not be safe in same col
        assert!(masks.is_safe(1, 0, 2)); // 2 should be safe in same col
    }

    #[test]
    fn test_is_safe_conflict_in_box() {
        let mut masks = Masks::new();
        masks.add_number(1, 1, 1); // Place 1 at (1,1) in box 0
        assert!(!masks.is_safe(0, 0, 1)); // 1 should not be safe in same box
        assert!(masks.is_safe(0, 0, 2)); // 2 should be safe in same box
    }

    #[test]
    fn test_is_safe_conflict_with_current_cell_value() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert!(!masks.is_safe(0, 0, 1)); // Should not be safe to place 1 where 1 already is
        assert!(masks.is_safe(0, 0, 2)); // Should be safe for other numbers
    }

    #[test]
    fn test_is_safe_again_after_remove() {
        let mut masks = Masks::new();
        masks.add_number(3, 3, 7);
        assert!(!masks.is_safe(3, 8, 7)); // Same row
        assert!(!masks.is_safe(4, 4, 7)); // Same box
        masks.remove_number(3, 3, 7);
        assert!(masks.is_safe(3, 8, 7));
        assert!(masks.is_safe(4, 4, 7));
    }

    #[test]
    fn test_compute_candidates_empty_cell_all_available() {
        let masks = Masks::new();
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, 0x1FF); // All 9 bits should be set
    }

    #[test]
    fn test_compute_candidates_row_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.row_masks[0] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]); // Directly set mask
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_col_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.col_masks[0] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]); // Directly set mask
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_box_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.box_masks[Masks::get_box_idx(1, 1)] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]); // Directly set mask for box 0
        let candidates = masks.compute_candidates_mask_for_cell(1, 1);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_mixed_restrictions() {
        let mut masks = Masks::new();
        // Row 0 has 1, 2
        masks.row_masks[0] = bits(&[1, 2]);
        // Col 0 has 3, 4
        masks.col_masks[0] = bits(&[3, 4]);
        // Box 0 (for 0,0) has 5, 6
        masks.box_masks[0] = bits(&[5, 6]);

        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bits(&[7, 8, 9])); // Should be 7, 8, 9
    }

    #[test]
    fn test_compute_candidates_no_candidates_left() {
        let mut masks = Masks::new();
        masks.row_masks[0] = 0x1FF; // All 1-9 used in row
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, 0); // No candidates
    }

    #[test]
    fn test_compute_candidates_tracks_add_and_remove() {
        let mut masks = Masks::new();
        masks.add_number(0, 5, 1); // Shares row 0 with (0,0); box 1
        masks.add_number(7, 0, 2); // Shares col 0 with (0,0); box 6
        masks.add_number(2, 2, 3); // Shares box 0 with (0,0)
        assert_eq!(
            masks.compute_candidates_mask_for_cell(0, 0),
            bits(&[4, 5, 6, 7, 8, 9])
        );

        masks.remove_number(0, 5, 1);
        masks.remove_number(7, 0, 2);
        masks.remove_number(2, 2, 3);
        assert_eq!(masks.compute_candidates_mask_for_cell(0, 0), 0x1FF);
        assert_eq!(masks, Masks::new());
    }
}