1use crate::error::BatchError;
28use rayon::prelude::*;
29use rustywallet_address::{Network, P2PKHAddress, P2TRAddress, P2WPKHAddress};
30use rustywallet_keys::private_key::PrivateKey;
31
/// Bitcoin address formats the batch generator can produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BatchAddressType {
    /// Legacy pay-to-public-key-hash address (mainnet prefix `1`).
    P2PKH,
    /// Native SegWit v0 pay-to-witness-public-key-hash address (mainnet prefix `bc1q`).
    P2WPKH,
    /// Taproot pay-to-taproot address (mainnet prefix `bc1p`).
    P2TR,
}

impl BatchAddressType {
    /// Returns the string every Bitcoin mainnet address of this type starts with.
    pub fn mainnet_prefix(&self) -> &'static str {
        match self {
            Self::P2PKH => "1",
            Self::P2WPKH => "bc1q",
            Self::P2TR => "bc1p",
        }
    }

    /// Returns a human-readable description of the Bitcoin testnet prefix.
    ///
    /// Testnet P2PKH addresses begin with either `m` or `n`, so for
    /// [`BatchAddressType::P2PKH`] this returns the label `"m/n"`, which is NOT
    /// itself a prefix. Use [`testnet_prefixes`](Self::testnet_prefixes) when
    /// checking an address with `starts_with`.
    pub fn testnet_prefix(&self) -> &'static str {
        match self {
            Self::P2PKH => "m/n",
            Self::P2WPKH => "tb1q",
            Self::P2TR => "tb1p",
        }
    }

    /// Returns every string a Bitcoin testnet address of this type can start with.
    ///
    /// Unlike [`testnet_prefix`](Self::testnet_prefix), each entry is a real
    /// prefix suitable for `str::starts_with`.
    pub fn testnet_prefixes(&self) -> &'static [&'static str] {
        match self {
            Self::P2PKH => &["m", "n"],
            Self::P2WPKH => &["tb1q"],
            Self::P2TR => &["tb1p"],
        }
    }
}

impl std::fmt::Display for BatchAddressType {
    /// Writes the format's short name (`P2PKH`, `P2WPKH` or `P2TR`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::P2PKH => "P2PKH",
            Self::P2WPKH => "P2WPKH",
            Self::P2TR => "P2TR",
        };
        // `pad` (instead of a literal `write!`) honours width/fill/alignment,
        // so `{:>8}` behaves the same as it does for a plain `str`.
        f.pad(name)
    }
}
75
76
/// Generates random private keys paired with their encoded Bitcoin addresses.
///
/// Configure it with the builder methods (`parallel`, `chunk_size`). Results can
/// be produced eagerly with `generate_vec` or lazily with `generate_stream`.
#[derive(Debug, Clone)]
pub struct BatchAddressGenerator {
    /// Address format produced for each generated key.
    address_type: BatchAddressType,
    /// Network the addresses are encoded for.
    network: Network,
    /// Whether generation runs on the rayon thread pool.
    parallel: bool,
    /// Number of pairs an `AddressStream` generates per refill.
    chunk_size: usize,
}
107
108impl BatchAddressGenerator {
109 pub fn new(address_type: BatchAddressType, network: Network) -> Self {
125 Self {
126 address_type,
127 network,
128 parallel: true, chunk_size: 1000,
130 }
131 }
132
133 pub fn parallel(mut self, enabled: bool) -> Self {
137 self.parallel = enabled;
138 self
139 }
140
141 pub fn chunk_size(mut self, size: usize) -> Self {
146 self.chunk_size = size;
147 self
148 }
149
150 #[inline]
152 pub fn address_type(&self) -> BatchAddressType {
153 self.address_type
154 }
155
156 #[inline]
158 pub fn network(&self) -> Network {
159 self.network
160 }
161
162 pub fn generate_stream(&self, count: usize) -> AddressStream {
185 AddressStream::new(
186 self.address_type,
187 self.network,
188 count,
189 self.parallel,
190 self.chunk_size,
191 )
192 }
193
194 pub fn generate_vec(&self, count: usize) -> Result<Vec<(PrivateKey, String)>, BatchError> {
217 if !self.network.is_bitcoin() {
218 return Err(BatchError::invalid_config(format!(
219 "Network {} is not supported for Bitcoin address generation",
220 self.network
221 )));
222 }
223
224 if self.parallel {
225 self.generate_parallel_vec(count)
226 } else {
227 self.generate_sequential_vec(count)
228 }
229 }
230
231 fn generate_single(&self) -> Result<(PrivateKey, String), BatchError> {
233 let key = PrivateKey::random();
234 let pubkey = key.public_key();
235
236 let address = match self.address_type {
237 BatchAddressType::P2PKH => {
238 P2PKHAddress::from_public_key(&pubkey, self.network)
239 .map_err(|e| BatchError::generation_error(e.to_string()))?
240 .to_string()
241 }
242 BatchAddressType::P2WPKH => {
243 P2WPKHAddress::from_public_key(&pubkey, self.network)
244 .map_err(|e| BatchError::generation_error(e.to_string()))?
245 .to_string()
246 }
247 BatchAddressType::P2TR => {
248 P2TRAddress::from_public_key(&pubkey, self.network)
249 .map_err(|e| BatchError::generation_error(e.to_string()))?
250 .to_string()
251 }
252 };
253
254 Ok((key, address))
255 }
256
257 fn generate_sequential_vec(&self, count: usize) -> Result<Vec<(PrivateKey, String)>, BatchError> {
259 (0..count)
260 .map(|_| self.generate_single())
261 .collect()
262 }
263
264 fn generate_parallel_vec(&self, count: usize) -> Result<Vec<(PrivateKey, String)>, BatchError> {
266 let address_type = self.address_type;
267 let network = self.network;
268
269 let results: Vec<_> = (0..count)
270 .into_par_iter()
271 .map(|_| generate_address_pair(address_type, network))
272 .collect();
273
274 results.into_iter().collect()
276 }
277}
278
279fn generate_address_pair(
281 address_type: BatchAddressType,
282 network: Network,
283) -> Result<(PrivateKey, String), BatchError> {
284 let key = PrivateKey::random();
285 let pubkey = key.public_key();
286
287 let address = match address_type {
288 BatchAddressType::P2PKH => {
289 P2PKHAddress::from_public_key(&pubkey, network)
290 .map_err(|e| BatchError::generation_error(e.to_string()))?
291 .to_string()
292 }
293 BatchAddressType::P2WPKH => {
294 P2WPKHAddress::from_public_key(&pubkey, network)
295 .map_err(|e| BatchError::generation_error(e.to_string()))?
296 .to_string()
297 }
298 BatchAddressType::P2TR => {
299 P2TRAddress::from_public_key(&pubkey, network)
300 .map_err(|e| BatchError::generation_error(e.to_string()))?
301 .to_string()
302 }
303 };
304
305 Ok((key, address))
306}
307
308
309pub struct AddressStream {
314 address_type: BatchAddressType,
315 network: Network,
316 remaining: usize,
317 parallel: bool,
318 chunk_size: usize,
319 current_chunk: std::vec::IntoIter<(PrivateKey, String)>,
320}
321
322impl AddressStream {
323 fn new(
325 address_type: BatchAddressType,
326 network: Network,
327 count: usize,
328 parallel: bool,
329 chunk_size: usize,
330 ) -> Self {
331 Self {
332 address_type,
333 network,
334 remaining: count,
335 parallel,
336 chunk_size,
337 current_chunk: Vec::new().into_iter(),
338 }
339 }
340
341 fn generate_chunk(&mut self) -> Vec<(PrivateKey, String)> {
343 let chunk_count = self.remaining.min(self.chunk_size);
344 self.remaining -= chunk_count;
345
346 let address_type = self.address_type;
347 let network = self.network;
348
349 if self.parallel {
350 (0..chunk_count)
351 .into_par_iter()
352 .filter_map(|_| generate_address_pair(address_type, network).ok())
353 .collect()
354 } else {
355 (0..chunk_count)
356 .filter_map(|_| generate_address_pair(address_type, network).ok())
357 .collect()
358 }
359 }
360
361 #[inline]
363 pub fn remaining(&self) -> usize {
364 self.remaining + self.current_chunk.len()
365 }
366}
367
368impl Iterator for AddressStream {
369 type Item = (PrivateKey, String);
370
371 fn next(&mut self) -> Option<Self::Item> {
372 if let Some(pair) = self.current_chunk.next() {
374 return Some(pair);
375 }
376
377 if self.remaining > 0 {
379 let chunk = self.generate_chunk();
380 self.current_chunk = chunk.into_iter();
381 self.current_chunk.next()
382 } else {
383 None
384 }
385 }
386
387 fn size_hint(&self) -> (usize, Option<usize>) {
388 let remaining = self.remaining();
389 (remaining, Some(remaining))
390 }
391}
392
393impl ExactSizeIterator for AddressStream {}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds a default-configured mainnet generator for `kind`.
    fn mainnet(kind: BatchAddressType) -> BatchAddressGenerator {
        BatchAddressGenerator::new(kind, Network::BitcoinMainnet)
    }

    #[test]
    fn test_batch_address_type_display() {
        let cases = [
            (BatchAddressType::P2PKH, "P2PKH"),
            (BatchAddressType::P2WPKH, "P2WPKH"),
            (BatchAddressType::P2TR, "P2TR"),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.to_string(), expected);
        }
    }

    #[test]
    fn test_batch_address_type_prefixes() {
        let cases = [
            (BatchAddressType::P2PKH, "1"),
            (BatchAddressType::P2WPKH, "bc1q"),
            (BatchAddressType::P2TR, "bc1p"),
        ];
        for (kind, expected) in cases {
            assert_eq!(kind.mainnet_prefix(), expected);
        }
    }

    #[test]
    fn test_generate_p2pkh_addresses() {
        let pairs = mainnet(BatchAddressType::P2PKH).generate_vec(10).unwrap();

        assert_eq!(pairs.len(), 10);
        for address in pairs.iter().map(|(_, a)| a) {
            assert!(address.starts_with('1'), "P2PKH address should start with '1': {}", address);
        }
    }

    #[test]
    fn test_generate_p2wpkh_addresses() {
        let pairs = mainnet(BatchAddressType::P2WPKH).generate_vec(10).unwrap();

        assert_eq!(pairs.len(), 10);
        for address in pairs.iter().map(|(_, a)| a) {
            assert!(address.starts_with("bc1q"), "P2WPKH address should start with 'bc1q': {}", address);
        }
    }

    #[test]
    fn test_generate_p2tr_addresses() {
        let pairs = mainnet(BatchAddressType::P2TR).generate_vec(10).unwrap();

        assert_eq!(pairs.len(), 10);
        for address in pairs.iter().map(|(_, a)| a) {
            assert!(address.starts_with("bc1p"), "P2TR address should start with 'bc1p': {}", address);
        }
    }

    #[test]
    fn test_generate_testnet_addresses() {
        let pairs = BatchAddressGenerator::new(BatchAddressType::P2WPKH, Network::BitcoinTestnet)
            .generate_vec(10)
            .unwrap();

        assert_eq!(pairs.len(), 10);
        for address in pairs.iter().map(|(_, a)| a) {
            assert!(address.starts_with("tb1q"), "Testnet P2WPKH should start with 'tb1q': {}", address);
        }
    }

    #[test]
    fn test_generate_stream() {
        let collected: Vec<_> = mainnet(BatchAddressType::P2WPKH).generate_stream(100).collect();
        assert_eq!(collected.len(), 100);
    }

    #[test]
    fn test_generate_stream_parallel() {
        let collected: Vec<_> = mainnet(BatchAddressType::P2TR)
            .parallel(true)
            .chunk_size(50)
            .generate_stream(200)
            .collect();

        assert_eq!(collected.len(), 200);
        for address in collected.iter().map(|(_, a)| a) {
            assert!(address.starts_with("bc1p"));
        }
    }

    #[test]
    fn test_generate_sequential() {
        let pairs = mainnet(BatchAddressType::P2PKH)
            .parallel(false)
            .generate_vec(50)
            .unwrap();
        assert_eq!(pairs.len(), 50);
    }

    #[test]
    fn test_addresses_are_unique() {
        let pairs = mainnet(BatchAddressType::P2WPKH).generate_vec(100).unwrap();

        let distinct: HashSet<&str> = pairs.iter().map(|(_, a)| a.as_str()).collect();
        assert_eq!(distinct.len(), pairs.len(), "All addresses should be unique");
    }

    #[test]
    fn test_key_derives_to_address() {
        let pairs = mainnet(BatchAddressType::P2WPKH).generate_vec(10).unwrap();

        for (key, addr) in pairs {
            // Re-derive the address independently from the returned key.
            let expected = P2WPKHAddress::from_public_key(&key.public_key(), Network::BitcoinMainnet)
                .unwrap()
                .to_string();
            assert_eq!(addr, expected, "Address should match derived address");
        }
    }
}