rustywallet_batch/
scanner.rs

//! Incremental key scanning by stepping a private-key scalar.
//!
//! This module provides [`KeyScanner`] for walking sequential key ranges by
//! repeatedly adding or subtracting a fixed step from a base private key.
5
6use crate::error::BatchError;
7use crate::stream::KeyStream;
8use rustywallet_keys::private_key::PrivateKey;
9
/// Which way a key scan moves through the key space.
///
/// The default is [`ScanDirection::Forward`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ScanDirection {
    /// Move toward larger keys: the step is added on every iteration.
    #[default]
    Forward,
    /// Move toward smaller keys: the step is subtracted on every iteration.
    Backward,
}
19
20/// Incremental key scanner using EC point addition.
21///
22/// `KeyScanner` efficiently generates sequential keys by adding or subtracting
23/// a constant value from a base key, which is faster than generating random keys.
24///
25/// # Example
26///
27/// ```rust
28/// use rustywallet_batch::prelude::*;
29/// use rustywallet_keys::prelude::PrivateKey;
30///
31/// let base = PrivateKey::from_hex(
32///     "0000000000000000000000000000000000000000000000000000000000000001"
33/// ).unwrap();
34///
35/// // Scan forward from base key
36/// let scanner = KeyScanner::new(base.clone())
37///     .direction(ScanDirection::Forward);
38///
39/// for key in scanner.scan_range(10) {
40///     println!("{}", key.unwrap().to_hex());
41/// }
42///
43/// // Scan backward
44/// let scanner = KeyScanner::new(base)
45///     .direction(ScanDirection::Backward);
46/// ```
47#[derive(Debug, Clone)]
48pub struct KeyScanner {
49    /// The base private key to start scanning from.
50    base_key: PrivateKey,
51
52    /// The scanning direction.
53    direction: ScanDirection,
54
55    /// Step size for each increment/decrement.
56    step: u64,
57}
58
59impl KeyScanner {
60    /// Create a new key scanner starting from the given base key.
61    pub fn new(base_key: PrivateKey) -> Self {
62        Self {
63            base_key,
64            direction: ScanDirection::Forward,
65            step: 1,
66        }
67    }
68
69    /// Set the scanning direction.
70    pub fn direction(mut self, direction: ScanDirection) -> Self {
71        self.direction = direction;
72        self
73    }
74
75    /// Set the step size for each increment/decrement.
76    pub fn step(mut self, step: u64) -> Self {
77        self.step = step;
78        self
79    }
80
81    /// Scan a range of keys starting from the base key.
82    ///
83    /// Returns a `KeyStream` that yields keys incrementally.
84    pub fn scan_range(self, count: usize) -> KeyStream {
85        let iter = ScanIterator::new(
86            self.base_key,
87            self.direction,
88            self.step,
89            count,
90        );
91        KeyStream::new(iter, Some(count))
92    }
93
94    /// Scan keys until a predicate returns true.
95    ///
96    /// Returns a `KeyStream` that yields keys until the predicate matches.
97    pub fn scan_until<F>(self, predicate: F) -> KeyStream
98    where
99        F: Fn(&PrivateKey) -> bool + Send + 'static,
100    {
101        let iter = ScanUntilIterator::new(
102            self.base_key,
103            self.direction,
104            self.step,
105            predicate,
106        );
107        KeyStream::new(iter, None)
108    }
109}
110
111/// Iterator for scanning a fixed range of keys.
112struct ScanIterator {
113    current_bytes: [u8; 32],
114    direction: ScanDirection,
115    step: u64,
116    remaining: usize,
117}
118
119impl ScanIterator {
120    fn new(base_key: PrivateKey, direction: ScanDirection, step: u64, count: usize) -> Self {
121        Self {
122            current_bytes: base_key.to_bytes(),
123            direction,
124            step,
125            remaining: count,
126        }
127    }
128
129    /// Add step to current key bytes (handles overflow/wraparound).
130    fn add_step(&mut self) {
131        let step_bytes = self.step.to_be_bytes();
132        let mut carry: u64 = 0;
133        
134        // Add step to the last 8 bytes
135        for i in (24..32).rev() {
136            let step_idx = 31 - i;
137            let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
138            let sum = self.current_bytes[i] as u64 + step_byte as u64 + carry;
139            self.current_bytes[i] = sum as u8;
140            carry = sum >> 8;
141        }
142
143        // Propagate carry to remaining bytes
144        for i in (0..24).rev() {
145            if carry == 0 {
146                break;
147            }
148            let sum = self.current_bytes[i] as u64 + carry;
149            self.current_bytes[i] = sum as u8;
150            carry = sum >> 8;
151        }
152
153        // Handle wraparound at curve order (simplified - just wrap to 1)
154        if carry > 0 || !PrivateKey::is_valid(&self.current_bytes) {
155            self.current_bytes = [0u8; 32];
156            self.current_bytes[31] = 1;
157        }
158    }
159
160    /// Subtract step from current key bytes (handles underflow/wraparound).
161    fn sub_step(&mut self) {
162        let step_bytes = self.step.to_be_bytes();
163        let mut borrow: i64 = 0;
164        
165        // Subtract step from the last 8 bytes
166        for i in (24..32).rev() {
167            let step_idx = 31 - i;
168            let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
169            let diff = self.current_bytes[i] as i64 - step_byte as i64 - borrow;
170            if diff < 0 {
171                self.current_bytes[i] = (diff + 256) as u8;
172                borrow = 1;
173            } else {
174                self.current_bytes[i] = diff as u8;
175                borrow = 0;
176            }
177        }
178
179        // Propagate borrow to remaining bytes
180        for i in (0..24).rev() {
181            if borrow == 0 {
182                break;
183            }
184            let diff = self.current_bytes[i] as i64 - borrow;
185            if diff < 0 {
186                self.current_bytes[i] = (diff + 256) as u8;
187                borrow = 1;
188            } else {
189                self.current_bytes[i] = diff as u8;
190                borrow = 0;
191            }
192        }
193
194        // Handle underflow (wrap to max valid key)
195        if borrow > 0 || !PrivateKey::is_valid(&self.current_bytes) {
196            // Set to curve order - 1 (max valid key)
197            self.current_bytes = [
198                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
199                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
200                0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
201                0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
202            ];
203        }
204    }
205}
206
207impl Iterator for ScanIterator {
208    type Item = Result<PrivateKey, BatchError>;
209
210    fn next(&mut self) -> Option<Self::Item> {
211        if self.remaining == 0 {
212            return None;
213        }
214
215        self.remaining -= 1;
216
217        // Create key from current bytes
218        let key = match PrivateKey::from_bytes(self.current_bytes) {
219            Ok(k) => k,
220            Err(e) => return Some(Err(BatchError::scanner_error(format!("Invalid key: {}", e)))),
221        };
222
223        // Advance to next key
224        match self.direction {
225            ScanDirection::Forward => self.add_step(),
226            ScanDirection::Backward => self.sub_step(),
227        }
228
229        Some(Ok(key))
230    }
231}
232
233/// Iterator for scanning until a predicate matches.
234struct ScanUntilIterator<F>
235where
236    F: Fn(&PrivateKey) -> bool,
237{
238    current_bytes: [u8; 32],
239    direction: ScanDirection,
240    step: u64,
241    predicate: F,
242    found: bool,
243}
244
245impl<F> ScanUntilIterator<F>
246where
247    F: Fn(&PrivateKey) -> bool,
248{
249    fn new(base_key: PrivateKey, direction: ScanDirection, step: u64, predicate: F) -> Self {
250        Self {
251            current_bytes: base_key.to_bytes(),
252            direction,
253            step,
254            predicate,
255            found: false,
256        }
257    }
258
259    fn add_step(&mut self) {
260        let step_bytes = self.step.to_be_bytes();
261        let mut carry: u64 = 0;
262        
263        for i in (24..32).rev() {
264            let step_idx = 31 - i;
265            let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
266            let sum = self.current_bytes[i] as u64 + step_byte as u64 + carry;
267            self.current_bytes[i] = sum as u8;
268            carry = sum >> 8;
269        }
270
271        for i in (0..24).rev() {
272            if carry == 0 {
273                break;
274            }
275            let sum = self.current_bytes[i] as u64 + carry;
276            self.current_bytes[i] = sum as u8;
277            carry = sum >> 8;
278        }
279
280        if carry > 0 || !PrivateKey::is_valid(&self.current_bytes) {
281            self.current_bytes = [0u8; 32];
282            self.current_bytes[31] = 1;
283        }
284    }
285
286    fn sub_step(&mut self) {
287        let step_bytes = self.step.to_be_bytes();
288        let mut borrow: i64 = 0;
289        
290        for i in (24..32).rev() {
291            let step_idx = 31 - i;
292            let step_byte = if step_idx < 8 { step_bytes[7 - step_idx] } else { 0 };
293            let diff = self.current_bytes[i] as i64 - step_byte as i64 - borrow;
294            if diff < 0 {
295                self.current_bytes[i] = (diff + 256) as u8;
296                borrow = 1;
297            } else {
298                self.current_bytes[i] = diff as u8;
299                borrow = 0;
300            }
301        }
302
303        for i in (0..24).rev() {
304            if borrow == 0 {
305                break;
306            }
307            let diff = self.current_bytes[i] as i64 - borrow;
308            if diff < 0 {
309                self.current_bytes[i] = (diff + 256) as u8;
310                borrow = 1;
311            } else {
312                self.current_bytes[i] = diff as u8;
313                borrow = 0;
314            }
315        }
316
317        if borrow > 0 || !PrivateKey::is_valid(&self.current_bytes) {
318            self.current_bytes = [
319                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
320                0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
321                0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
322                0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
323            ];
324        }
325    }
326}
327
328impl<F> Iterator for ScanUntilIterator<F>
329where
330    F: Fn(&PrivateKey) -> bool,
331{
332    type Item = Result<PrivateKey, BatchError>;
333
334    fn next(&mut self) -> Option<Self::Item> {
335        if self.found {
336            return None;
337        }
338
339        let key = match PrivateKey::from_bytes(self.current_bytes) {
340            Ok(k) => k,
341            Err(e) => return Some(Err(BatchError::scanner_error(format!("Invalid key: {}", e)))),
342        };
343
344        if (self.predicate)(&key) {
345            self.found = true;
346        }
347
348        match self.direction {
349            ScanDirection::Forward => self.add_step(),
350            ScanDirection::Backward => self.sub_step(),
351        }
352
353        Some(Ok(key))
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest valid key.
    const ONE_HEX: &str = "0000000000000000000000000000000000000000000000000000000000000001";

    /// Largest valid key: the curve order minus one.
    const MAX_KEY_HEX: &str = concat!(
        "ffffffffffffffff",
        "fffffffffffffffe",
        "baaedce6af48a03b",
        "bfd25e8cd0364140",
    );

    #[test]
    fn test_scan_forward() {
        let base = PrivateKey::from_hex(ONE_HEX).unwrap();

        let hex_keys: Vec<_> = KeyScanner::new(base)
            .scan_range(5)
            .map(|r| r.unwrap().to_hex())
            .collect();

        // Check every key in the range, not just a prefix.
        assert_eq!(
            hex_keys,
            [
                "0000000000000000000000000000000000000000000000000000000000000001",
                "0000000000000000000000000000000000000000000000000000000000000002",
                "0000000000000000000000000000000000000000000000000000000000000003",
                "0000000000000000000000000000000000000000000000000000000000000004",
                "0000000000000000000000000000000000000000000000000000000000000005",
            ]
        );
    }

    #[test]
    fn test_scan_backward() {
        let base = PrivateKey::from_hex(
            "0000000000000000000000000000000000000000000000000000000000000005"
        ).unwrap();

        let hex_keys: Vec<_> = KeyScanner::new(base)
            .direction(ScanDirection::Backward)
            .scan_range(5)
            .map(|r| r.unwrap().to_hex())
            .collect();

        assert_eq!(
            hex_keys,
            [
                "0000000000000000000000000000000000000000000000000000000000000005",
                "0000000000000000000000000000000000000000000000000000000000000004",
                "0000000000000000000000000000000000000000000000000000000000000003",
                "0000000000000000000000000000000000000000000000000000000000000002",
                "0000000000000000000000000000000000000000000000000000000000000001",
            ]
        );
    }

    #[test]
    fn test_scan_with_step() {
        let base = PrivateKey::from_hex(ONE_HEX).unwrap();

        let hex_keys: Vec<_> = KeyScanner::new(base)
            .step(10)
            .scan_range(3)
            .map(|r| r.unwrap().to_hex())
            .collect();

        assert_eq!(
            hex_keys,
            [
                "0000000000000000000000000000000000000000000000000000000000000001",
                "000000000000000000000000000000000000000000000000000000000000000b", // 11
                "0000000000000000000000000000000000000000000000000000000000000015", // 21
            ]
        );
    }

    #[test]
    fn test_zero_step_repeats_base_key() {
        let base = PrivateKey::from_hex(ONE_HEX).unwrap();

        let hex_keys: Vec<_> = KeyScanner::new(base)
            .step(0)
            .scan_range(3)
            .map(|r| r.unwrap().to_hex())
            .collect();

        assert_eq!(hex_keys, [ONE_HEX, ONE_HEX, ONE_HEX]);
    }

    #[test]
    fn test_scan_until() {
        let base = PrivateKey::from_hex(ONE_HEX).unwrap();

        // Scan until we find a key ending with "05".
        let keys: Vec<_> = KeyScanner::new(base)
            .scan_until(|k| k.to_hex().ends_with("05"))
            .collect();

        assert_eq!(keys.len(), 5); // 1, 2, 3, 4, 5
        assert!(keys.iter().all(|r| r.is_ok()));

        let last_key = keys.last().unwrap().as_ref().unwrap();
        assert!(last_key.to_hex().ends_with("05"));
    }

    #[test]
    fn test_forward_wraps_from_max_key_to_one() {
        let max = PrivateKey::from_hex(MAX_KEY_HEX).unwrap();

        let hex_keys: Vec<_> = KeyScanner::new(max)
            .scan_range(2)
            .map(|r| r.unwrap().to_hex())
            .collect();

        assert_eq!(hex_keys, [MAX_KEY_HEX, ONE_HEX]);
    }

    #[test]
    fn test_backward_wraps_from_one_to_max_key() {
        let one = PrivateKey::from_hex(ONE_HEX).unwrap();

        let hex_keys: Vec<_> = KeyScanner::new(one)
            .direction(ScanDirection::Backward)
            .scan_range(2)
            .map(|r| r.unwrap().to_hex())
            .collect();

        assert_eq!(hex_keys, [ONE_HEX, MAX_KEY_HEX]);
    }

    #[test]
    fn test_bidirectional_consistency() {
        // Property 7: Bidirectional Scanning Consistency.
        // Forward N steps then backward N steps should return to the original key.
        let base = PrivateKey::from_hex(
            "0000000000000000000000000000000000000000000000000000000000000064" // 100
        ).unwrap();

        // 11 keys = the base plus 10 forward steps.
        let forward_keys: Vec<_> = KeyScanner::new(base.clone())
            .direction(ScanDirection::Forward)
            .scan_range(11)
            .collect();
        let last_forward = forward_keys.last().unwrap().as_ref().unwrap().clone();

        // 11 keys = the last forward key plus 10 backward steps.
        let backward_keys: Vec<_> = KeyScanner::new(last_forward)
            .direction(ScanDirection::Backward)
            .scan_range(11)
            .collect();
        let last_backward = backward_keys.last().unwrap().as_ref().unwrap();

        assert_eq!(base.to_hex(), last_backward.to_hex());
    }
}