1use crate::error::BatchError;
7use crate::stream::KeyStream;
8use rustywallet_keys::private_key::PrivateKey;
9
/// Direction in which a [`KeyScanner`] walks through the key space.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScanDirection {
    /// Every emitted key is `step` greater than the one before it (the default).
    #[default]
    Forward,
    /// Every emitted key is `step` smaller than the one before it.
    Backward,
}
19
20#[derive(Debug, Clone)]
48pub struct KeyScanner {
49 base_key: PrivateKey,
51
52 direction: ScanDirection,
54
55 step: u64,
57}
58
59impl KeyScanner {
60 pub fn new(base_key: PrivateKey) -> Self {
62 Self {
63 base_key,
64 direction: ScanDirection::Forward,
65 step: 1,
66 }
67 }
68
69 pub fn direction(mut self, direction: ScanDirection) -> Self {
71 self.direction = direction;
72 self
73 }
74
75 pub fn step(mut self, step: u64) -> Self {
77 self.step = step;
78 self
79 }
80
81 pub fn scan_range(self, count: usize) -> KeyStream {
85 let iter = ScanIterator::new(
86 self.base_key,
87 self.direction,
88 self.step,
89 count,
90 );
91 KeyStream::new(iter, Some(count))
92 }
93
94 pub fn scan_until<F>(self, predicate: F) -> KeyStream
98 where
99 F: Fn(&PrivateKey) -> bool + Send + 'static,
100 {
101 let iter = ScanUntilIterator::new(
102 self.base_key,
103 self.direction,
104 self.step,
105 predicate,
106 );
107 KeyStream::new(iter, None)
108 }
109}
110
111struct ScanIterator {
113 current_bytes: [u8; 32],
114 direction: ScanDirection,
115 step: u64,
116 remaining: usize,
117}
118
119impl ScanIterator {
120 fn new(base_key: PrivateKey, direction: ScanDirection, step: u64, count: usize) -> Self {
121 Self {
122 current_bytes: base_key.to_bytes(),
123 direction,
124 step,
125 remaining: count,
126 }
127 }
128
129 fn add_step(&mut self) {
131 let step_bytes = self.step.to_be_bytes();
132 let mut carry: u64 = 0;
133
134 for i in (24..32).rev() {
136 let step_idx = 31 - i;
137 let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
138 let sum = self.current_bytes[i] as u64 + step_byte as u64 + carry;
139 self.current_bytes[i] = sum as u8;
140 carry = sum >> 8;
141 }
142
143 for i in (0..24).rev() {
145 if carry == 0 {
146 break;
147 }
148 let sum = self.current_bytes[i] as u64 + carry;
149 self.current_bytes[i] = sum as u8;
150 carry = sum >> 8;
151 }
152
153 if carry > 0 || !PrivateKey::is_valid(&self.current_bytes) {
155 self.current_bytes = [0u8; 32];
156 self.current_bytes[31] = 1;
157 }
158 }
159
160 fn sub_step(&mut self) {
162 let step_bytes = self.step.to_be_bytes();
163 let mut borrow: i64 = 0;
164
165 for i in (24..32).rev() {
167 let step_idx = 31 - i;
168 let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
169 let diff = self.current_bytes[i] as i64 - step_byte as i64 - borrow;
170 if diff < 0 {
171 self.current_bytes[i] = (diff + 256) as u8;
172 borrow = 1;
173 } else {
174 self.current_bytes[i] = diff as u8;
175 borrow = 0;
176 }
177 }
178
179 for i in (0..24).rev() {
181 if borrow == 0 {
182 break;
183 }
184 let diff = self.current_bytes[i] as i64 - borrow;
185 if diff < 0 {
186 self.current_bytes[i] = (diff + 256) as u8;
187 borrow = 1;
188 } else {
189 self.current_bytes[i] = diff as u8;
190 borrow = 0;
191 }
192 }
193
194 if borrow > 0 || !PrivateKey::is_valid(&self.current_bytes) {
196 self.current_bytes = [
198 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
199 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
200 0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
201 0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
202 ];
203 }
204 }
205}
206
207impl Iterator for ScanIterator {
208 type Item = Result<PrivateKey, BatchError>;
209
210 fn next(&mut self) -> Option<Self::Item> {
211 if self.remaining == 0 {
212 return None;
213 }
214
215 self.remaining -= 1;
216
217 let key = match PrivateKey::from_bytes(self.current_bytes) {
219 Ok(k) => k,
220 Err(e) => return Some(Err(BatchError::scanner_error(format!("Invalid key: {}", e)))),
221 };
222
223 match self.direction {
225 ScanDirection::Forward => self.add_step(),
226 ScanDirection::Backward => self.sub_step(),
227 }
228
229 Some(Ok(key))
230 }
231}
232
233struct ScanUntilIterator<F>
235where
236 F: Fn(&PrivateKey) -> bool,
237{
238 current_bytes: [u8; 32],
239 direction: ScanDirection,
240 step: u64,
241 predicate: F,
242 found: bool,
243}
244
245impl<F> ScanUntilIterator<F>
246where
247 F: Fn(&PrivateKey) -> bool,
248{
249 fn new(base_key: PrivateKey, direction: ScanDirection, step: u64, predicate: F) -> Self {
250 Self {
251 current_bytes: base_key.to_bytes(),
252 direction,
253 step,
254 predicate,
255 found: false,
256 }
257 }
258
259 fn add_step(&mut self) {
260 let step_bytes = self.step.to_be_bytes();
261 let mut carry: u64 = 0;
262
263 for i in (24..32).rev() {
264 let step_idx = 31 - i;
265 let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
266 let sum = self.current_bytes[i] as u64 + step_byte as u64 + carry;
267 self.current_bytes[i] = sum as u8;
268 carry = sum >> 8;
269 }
270
271 for i in (0..24).rev() {
272 if carry == 0 {
273 break;
274 }
275 let sum = self.current_bytes[i] as u64 + carry;
276 self.current_bytes[i] = sum as u8;
277 carry = sum >> 8;
278 }
279
280 if carry > 0 || !PrivateKey::is_valid(&self.current_bytes) {
281 self.current_bytes = [0u8; 32];
282 self.current_bytes[31] = 1;
283 }
284 }
285
286 fn sub_step(&mut self) {
287 let step_bytes = self.step.to_be_bytes();
288 let mut borrow: i64 = 0;
289
290 for i in (24..32).rev() {
291 let step_idx = 31 - i;
292 let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
293 let diff = self.current_bytes[i] as i64 - step_byte as i64 - borrow;
294 if diff < 0 {
295 self.current_bytes[i] = (diff + 256) as u8;
296 borrow = 1;
297 } else {
298 self.current_bytes[i] = diff as u8;
299 borrow = 0;
300 }
301 }
302
303 for i in (0..24).rev() {
304 if borrow == 0 {
305 break;
306 }
307 let diff = self.current_bytes[i] as i64 - borrow;
308 if diff < 0 {
309 self.current_bytes[i] = (diff + 256) as u8;
310 borrow = 1;
311 } else {
312 self.current_bytes[i] = diff as u8;
313 borrow = 0;
314 }
315 }
316
317 if borrow > 0 || !PrivateKey::is_valid(&self.current_bytes) {
318 self.current_bytes = [
319 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
320 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
321 0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
322 0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
323 ];
324 }
325 }
326}
327
328impl<F> Iterator for ScanUntilIterator<F>
329where
330 F: Fn(&PrivateKey) -> bool,
331{
332 type Item = Result<PrivateKey, BatchError>;
333
334 fn next(&mut self) -> Option<Self::Item> {
335 if self.found {
336 return None;
337 }
338
339 let key = match PrivateKey::from_bytes(self.current_bytes) {
340 Ok(k) => k,
341 Err(e) => return Some(Err(BatchError::scanner_error(format!("Invalid key: {}", e)))),
342 };
343
344 if (self.predicate)(&key) {
345 self.found = true;
346 }
347
348 match self.direction {
349 ScanDirection::Forward => self.add_step(),
350 ScanDirection::Backward => self.sub_step(),
351 }
352
353 Some(Ok(key))
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Hex form of the key a backward scan wraps to (secp256k1 group order minus one).
    const MAX_KEY_HEX: &str = "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364140";

    /// Formats an integer as the 64-digit lowercase hex string produced by `to_hex`.
    fn hex(value: u128) -> String {
        format!("{:064x}", value)
    }

    /// Parses a key from hex, panicking on malformed test input.
    fn key(s: &str) -> PrivateKey {
        PrivateKey::from_hex(s).unwrap()
    }

    /// Drains a stream, unwrapping every item, and returns the keys as hex strings.
    fn hex_keys(stream: KeyStream) -> Vec<String> {
        stream.map(|r| r.unwrap().to_hex()).collect()
    }

    #[test]
    fn test_scan_forward() {
        let keys = hex_keys(KeyScanner::new(key(&hex(1))).scan_range(5));

        let expected: Vec<String> = (1..=5).map(hex).collect();
        assert_eq!(keys, expected);
    }

    #[test]
    fn test_scan_backward() {
        let scanner = KeyScanner::new(key(&hex(5))).direction(ScanDirection::Backward);
        let keys = hex_keys(scanner.scan_range(5));

        let expected: Vec<String> = (1..=5).rev().map(hex).collect();
        assert_eq!(keys, expected);
    }

    #[test]
    fn test_scan_with_step() {
        let scanner = KeyScanner::new(key(&hex(1))).step(10);
        let keys = hex_keys(scanner.scan_range(3));

        assert_eq!(keys, vec![hex(0x01), hex(0x0b), hex(0x15)]);
    }

    #[test]
    fn test_scan_until() {
        let scanner = KeyScanner::new(key(&hex(1)));

        let keys: Vec<_> = scanner
            .scan_until(|k| k.to_hex().ends_with("05"))
            .collect();

        assert_eq!(keys.len(), 5);
        let last_key = keys.last().unwrap().as_ref().unwrap();
        assert!(last_key.to_hex().ends_with("05"));
    }

    #[test]
    fn test_bidirectional_consistency() {
        let base = key(&hex(0x64));

        let forward_scanner = KeyScanner::new(base.clone()).direction(ScanDirection::Forward);
        let forward_keys: Vec<_> = forward_scanner.scan_range(11).collect();
        let last_forward = forward_keys.last().unwrap().as_ref().unwrap().clone();

        let backward_scanner = KeyScanner::new(last_forward).direction(ScanDirection::Backward);
        let backward_keys: Vec<_> = backward_scanner.scan_range(11).collect();
        let last_backward = backward_keys.last().unwrap().as_ref().unwrap();

        assert_eq!(base.to_hex(), last_backward.to_hex());
    }

    #[test]
    fn test_forward_wraps_past_max_key_to_one() {
        let keys = hex_keys(KeyScanner::new(key(MAX_KEY_HEX)).scan_range(3));

        assert_eq!(keys, vec![MAX_KEY_HEX.to_string(), hex(1), hex(2)]);
    }

    #[test]
    fn test_backward_wraps_below_one_to_max_key() {
        let scanner = KeyScanner::new(key(&hex(1))).direction(ScanDirection::Backward);
        let keys = hex_keys(scanner.scan_range(3));

        assert_eq!(
            keys,
            vec![
                hex(1),
                MAX_KEY_HEX.to_string(),
                "fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd036413f".to_string(),
            ]
        );
    }

    #[test]
    fn test_carry_and_borrow_cross_low_word() {
        // 2^64 - 1 and 2^64: stepping between them must carry/borrow out of the
        // least-significant 8 bytes into the upper bytes.
        let low_max = hex(u128::from(u64::MAX));
        let high_one = hex(1u128 << 64);

        let forward = hex_keys(KeyScanner::new(key(&low_max)).scan_range(2));
        assert_eq!(forward, vec![low_max.clone(), high_one.clone()]);

        let backward_scanner = KeyScanner::new(key(&high_one)).direction(ScanDirection::Backward);
        let backward = hex_keys(backward_scanner.scan_range(2));
        assert_eq!(backward, vec![high_one, low_max]);
    }
}