1use super::board::Board;
2use super::masks::Masks;
3
/// Per-cell cache of candidate bitmasks for a 9×9 grid.
///
/// `cache[r][c]` holds the mask for row `r`, column `c`. Bit `d - 1` set means
/// digit `d` (1..=9) is currently a candidate for that cell.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Candidates {
    cache: [[u16; 9]; 9],
}
15
16impl Candidates {
17 pub(super) fn new() -> Self {
18 Candidates { cache: [[0; 9]; 9] }
19 }
20
21 pub(super) fn get(&self, r: usize, c: usize) -> u16 {
23 self.cache[r][c]
24 }
25
26 pub fn get_candidates(&self, r: usize, c: usize) -> Vec<u8> {
28 let mask = self.get(r, c);
29 let mut candidates = Vec::new();
30 for i in 0..9 {
31 if (mask >> i) & 1 == 1 {
33 candidates.push((i + 1) as u8);
34 }
35 }
36 candidates
37 }
38
39 pub(super) fn set(&mut self, r: usize, c: usize, mask: u16) {
41 self.cache[r][c] = mask;
42 }
43
44 pub(super) fn update_affected_cells(
47 &mut self,
48 r: usize,
49 c: usize,
50 masks: &Masks,
51 board: &Board,
52 ) {
53 self.update_affected_cells_for(r, c, masks, board, None);
54 }
55
56 pub(super) fn update_affected_cells_for(
61 &mut self,
62 r: usize,
63 c: usize,
64 masks: &Masks,
65 board: &Board,
66 placed_num: Option<u8>,
67 ) {
68 if board.is_empty(r, c) {
70 self.cache[r][c] = masks.compute_candidates_mask_for_cell(r, c);
72 } else {
73 self.cache[r][c] = 0;
75 }
76
77 let num_bit = placed_num.map(|n| 1u16 << (n - 1));
78
79 for i in 0..9 {
81 if board.is_empty(r, i) && i != c {
82 if let Some(bit) = num_bit {
84 if self.cache[r][i] & bit == 0 {
85 continue;
86 }
87 }
88 self.cache[r][i] = masks.compute_candidates_mask_for_cell(r, i);
89 }
90 if board.is_empty(i, c) && i != r {
91 if let Some(bit) = num_bit {
92 if self.cache[i][c] & bit == 0 {
93 continue;
94 }
95 }
96 self.cache[i][c] = masks.compute_candidates_mask_for_cell(i, c);
97 }
98 }
99
100 let box_idx = Masks::get_box_idx(r, c);
102 let start_row = (box_idx / 3) * 3;
103 let start_col = (box_idx % 3) * 3;
104 for r_offset in 0..3 {
105 for c_offset in 0..3 {
106 let cur_r = start_row + r_offset;
107 let cur_c = start_col + c_offset;
108 if (cur_r == r) || (cur_c == c) {
110 continue;
111 }
112 if board.is_empty(cur_r, cur_c) {
113 if let Some(bit) = num_bit {
114 if self.cache[cur_r][cur_c] & bit == 0 {
115 continue;
116 }
117 }
118 self.cache[cur_r][cur_c] = masks.compute_candidates_mask_for_cell(cur_r, cur_c);
119 }
120 }
121 }
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores `mask` at `(r, c)` in `cands` and returns the decoded digits.
    fn set_and_read(cands: &mut Candidates, r: usize, c: usize, mask: u16) -> Vec<u8> {
        cands.set(r, c, mask);
        cands.get_candidates(r, c)
    }

    #[test]
    fn test_get_candidates_empty_mask() {
        let fresh = Candidates::new();
        assert!(fresh.get_candidates(0, 0).is_empty());
    }

    #[test]
    fn test_get_candidates_full_mask() {
        let mut cands = Candidates::new();
        let all_digits: u16 = 0b1_1111_1111;
        let expected: Vec<u8> = (1..=9).collect();
        assert_eq!(set_and_read(&mut cands, 0, 0, all_digits), expected);
    }

    #[test]
    fn test_get_candidates_single_candidate() {
        let mut cands = Candidates::new();
        for (bit, digit) in [(0u32, 1u8), (4, 5), (8, 9)] {
            assert_eq!(set_and_read(&mut cands, 1, 2, 1u16 << bit), vec![digit]);
        }
    }

    #[test]
    fn test_get_candidates_multiple_candidates() {
        let mut cands = Candidates::new();
        let cases: [(u16, &[u8]); 2] = [
            (0b0100_1010, &[2, 4, 7]),
            (0b1_0000_0001, &[1, 9]),
        ];
        for (mask, expected) in cases {
            assert_eq!(set_and_read(&mut cands, 3, 4, mask), expected.to_vec());
        }
    }

    #[test]
    fn test_get_candidates_different_cells() {
        let mut cands = Candidates::new();
        cands.set(0, 0, 0b101);
        cands.set(8, 8, 0b1010_0000);
        assert_eq!(cands.get_candidates(0, 0), vec![1, 3]);
        assert_eq!(cands.get_candidates(8, 8), vec![6, 8]);
        assert!(cands.get_candidates(0, 1).is_empty());
    }
}