Skip to main content

rustoku_lib/core/
masks.rs

/// Occupancy bitmasks for a 9x9 Rustoku board.
///
/// One `u16` is kept per row, per column, and per 3x3 box. Within each mask,
/// number `n` (1 through 9) is tracked by bit `n - 1`: a set bit means that
/// number is already present somewhere in that row, column, or box.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Masks {
    /// Numbers present in each row, indexed by row (0 through 8).
    row_masks: [u16; 9],
    /// Numbers present in each column, indexed by column (0 through 8).
    col_masks: [u16; 9],
    /// Numbers present in each 3x3 box, indexed by `(r / 3) * 3 + (c / 3)`.
    box_masks: [u16; 9],
}
12
13impl Masks {
14    pub(super) fn new() -> Self {
15        Masks {
16            row_masks: [0; 9],
17            col_masks: [0; 9],
18            box_masks: [0; 9],
19        }
20    }
21
22    /// Computes the index of the 3x3 box based on the row and column indices.
23    #[inline]
24    pub(super) fn get_box_idx(r: usize, c: usize) -> usize {
25        (r / 3) * 3 + (c / 3)
26    }
27
28    /// Adds a number to the masks for the specified row, column, and box.
29    #[inline]
30    pub(super) fn add_number(&mut self, r: usize, c: usize, num: u8) {
31        let bit_to_set = 1 << (num - 1);
32        let box_idx = Self::get_box_idx(r, c);
33        self.row_masks[r] |= bit_to_set;
34        self.col_masks[c] |= bit_to_set;
35        self.box_masks[box_idx] |= bit_to_set;
36    }
37
38    /// Removes a number from the masks for the specified row, column, and box.
39    #[inline]
40    pub(super) fn remove_number(&mut self, r: usize, c: usize, num: u8) {
41        let bit_to_unset = 1 << (num - 1);
42        let box_idx = Self::get_box_idx(r, c);
43        self.row_masks[r] &= !bit_to_unset;
44        self.col_masks[c] &= !bit_to_unset;
45        self.box_masks[box_idx] &= !bit_to_unset;
46    }
47
48    /// Checks if a number can be safely placed in the specified cell.
49    #[inline]
50    pub fn is_safe(&self, r: usize, c: usize, num: u8) -> bool {
51        let bit_to_check = 1 << (num - 1);
52        let box_idx = Self::get_box_idx(r, c);
53
54        (self.row_masks[r] & bit_to_check == 0)
55            && (self.col_masks[c] & bit_to_check == 0)
56            && (self.box_masks[box_idx] & bit_to_check == 0)
57    }
58
59    /// Computes the candidates mask for a specific cell based on the current masks.
60    #[inline]
61    pub(super) fn compute_candidates_mask_for_cell(&self, r: usize, c: usize) -> u16 {
62        let row_mask = self.row_masks[r];
63        let col_mask = self.col_masks[c];
64        let box_mask = self.box_masks[Self::get_box_idx(r, c)];
65        let used = row_mask | col_mask | box_mask;
66        !used & 0x1FF
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected single-bit mask for `n`: number `n` lives at bit `n - 1`.
    fn bit(n: u8) -> u16 {
        1u16 << (n - 1)
    }

    /// Expected mask with the bit of every number in `nums` set.
    fn bits(nums: &[u8]) -> u16 {
        nums.iter().fold(0, |acc, &n| acc | bit(n))
    }

    #[test]
    fn test_new_initializes_empty_masks() {
        let masks = Masks::new();
        assert_eq!(masks.row_masks, [0; 9]);
        assert_eq!(masks.col_masks, [0; 9]);
        assert_eq!(masks.box_masks, [0; 9]);
    }

    #[test]
    fn test_get_box_idx_top_left_box() {
        assert_eq!(Masks::get_box_idx(0, 0), 0);
        assert_eq!(Masks::get_box_idx(1, 1), 0);
        assert_eq!(Masks::get_box_idx(2, 2), 0);
    }

    #[test]
    fn test_get_box_idx_middle_box() {
        assert_eq!(Masks::get_box_idx(3, 3), 4);
        assert_eq!(Masks::get_box_idx(4, 4), 4);
        assert_eq!(Masks::get_box_idx(5, 5), 4);
    }

    #[test]
    fn test_get_box_idx_bottom_right_box() {
        assert_eq!(Masks::get_box_idx(6, 6), 8);
        assert_eq!(Masks::get_box_idx(7, 7), 8);
        assert_eq!(Masks::get_box_idx(8, 8), 8);
    }

    #[test]
    fn test_get_box_idx_various_boxes() {
        assert_eq!(Masks::get_box_idx(0, 3), 1); // Top-middle
        assert_eq!(Masks::get_box_idx(3, 0), 3); // Middle-left
        assert_eq!(Masks::get_box_idx(0, 8), 2); // Top-right
        assert_eq!(Masks::get_box_idx(8, 0), 6); // Bottom-left
    }

    #[test]
    fn test_add_number_single_cell() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], bit(1));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.box_masks[0], bit(1));
    }

    #[test]
    fn test_add_number_multiple_in_same_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.add_number(0, 1, 5);
        assert_eq!(masks.row_masks[0], bits(&[1, 5]));
        assert_eq!(masks.col_masks[0], bit(1));
        assert_eq!(masks.col_masks[1], bit(5));
        assert_eq!(masks.box_masks[0], bits(&[1, 5]));
    }

    #[test]
    fn test_add_number_to_different_box() {
        let mut masks = Masks::new();
        masks.add_number(8, 8, 9);
        assert_eq!(masks.row_masks[8], bit(9));
        assert_eq!(masks.col_masks[8], bit(9));
        assert_eq!(masks.box_masks[8], bit(9));
    }

    #[test]
    fn test_add_number_already_present_no_change() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        // `Masks` is `Copy`, so this snapshots the row, column, and box masks together.
        let snapshot = masks;
        masks.add_number(0, 0, 1); // Adding again
        assert_eq!(masks, snapshot);
    }

    #[test]
    fn test_remove_number_single_value_from_cell() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1); // Add 1
        masks.remove_number(0, 0, 1); // Remove 1
        assert_eq!(masks.row_masks[0], 0);
        assert_eq!(masks.col_masks[0], 0);
        assert_eq!(masks.box_masks[0], 0);
    }

    #[test]
    fn test_remove_number_from_shared_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        masks.add_number(0, 1, 5);
        masks.remove_number(0, 0, 1);
        assert_eq!(masks.row_masks[0], bit(5)); // Only 5 should remain in row 0
        assert_eq!(masks.col_masks[0], 0); // Col 0 should be clear
        assert_eq!(masks.col_masks[1], bit(5)); // Col 1 should still have 5
        assert_eq!(masks.box_masks[0], bit(5)); // Box 0 should only have 5
    }

    #[test]
    fn test_remove_number_not_present_no_change() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        // Snapshot every mask, not just row 0, so a stray column/box change is caught.
        let snapshot = masks;
        masks.remove_number(0, 0, 3); // Remove 3 (not present)
        assert_eq!(masks, snapshot);
    }

    #[test]
    fn test_is_safe_on_empty_board_always_true() {
        let masks = Masks::new();
        // Every number must be placeable in every cell of an empty board.
        for r in 0..9 {
            for c in 0..9 {
                for num in 1..=9u8 {
                    assert!(
                        masks.is_safe(r, c, num),
                        "{} should be safe at ({}, {}) on an empty board",
                        num,
                        r,
                        c
                    );
                }
            }
        }
    }

    #[test]
    fn test_is_safe_conflict_in_row() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1); // Place 1 at (0,0)
        assert!(!masks.is_safe(0, 1, 1)); // 1 should not be safe in same row
        assert!(masks.is_safe(0, 1, 2)); // 2 should be safe in same row
    }

    #[test]
    fn test_is_safe_conflict_in_column() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1); // Place 1 at (0,0)
        assert!(!masks.is_safe(1, 0, 1)); // 1 should not be safe in same col
        assert!(masks.is_safe(1, 0, 2)); // 2 should be safe in same col
    }

    #[test]
    fn test_is_safe_conflict_in_box() {
        let mut masks = Masks::new();
        masks.add_number(1, 1, 1); // Place 1 at (1,1) in box 0
        assert!(!masks.is_safe(0, 0, 1)); // 1 should not be safe in same box
        assert!(masks.is_safe(0, 0, 2)); // 2 should be safe in same box
    }

    #[test]
    fn test_is_safe_conflict_with_current_cell_value() {
        let mut masks = Masks::new();
        masks.add_number(0, 0, 1);
        assert!(!masks.is_safe(0, 0, 1)); // Should not be safe to place 1 where 1 already is
        assert!(masks.is_safe(0, 0, 2)); // Should be safe for other numbers
    }

    #[test]
    fn test_compute_candidates_empty_cell_all_available() {
        let masks = Masks::new();
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, 0x1FF); // All 9 bits should be set
    }

    #[test]
    fn test_compute_candidates_row_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.row_masks[0] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]); // Directly set mask
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_col_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        masks.col_masks[0] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]); // Directly set mask
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_box_has_1_to_8_only_9_available() {
        let mut masks = Masks::new();
        // Directly set mask for box 0
        masks.box_masks[Masks::get_box_idx(1, 1)] = bits(&[1, 2, 3, 4, 5, 6, 7, 8]);
        let candidates = masks.compute_candidates_mask_for_cell(1, 1);
        assert_eq!(candidates, bit(9));
    }

    #[test]
    fn test_compute_candidates_mixed_restrictions() {
        let mut masks = Masks::new();
        // Row 0 has 1, 2
        masks.row_masks[0] = bits(&[1, 2]);
        // Col 0 has 3, 4
        masks.col_masks[0] = bits(&[3, 4]);
        // Box 0 (for 0,0) has 5, 6
        masks.box_masks[0] = bits(&[5, 6]);

        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, bits(&[7, 8, 9])); // Should be 7, 8, 9
    }

    #[test]
    fn test_compute_candidates_no_candidates_left() {
        let mut masks = Masks::new();
        masks.row_masks[0] = 0x1FF; // All 1-9 used in row
        let candidates = masks.compute_candidates_mask_for_cell(0, 0);
        assert_eq!(candidates, 0); // No candidates
    }
}