1use crate::error::SecurityError;
33use crate::secure_memory::SecureMemory;
34use crate::{P2PError, Result};
35use ed25519_dalek::{SigningKey, VerifyingKey};
36use hkdf::Hkdf;
37use rand::{RngCore, thread_rng};
38use sha2::Sha256;
39use std::collections::HashMap;
40use std::sync::{Arc, RwLock};
41
/// Length in bytes of the master seed material.
const MASTER_SEED_SIZE: usize = 32;

/// Length in bytes of every derived child key.
// `dead_code` allowed: the constant documents the key-size contract even if
// nothing in this module references it.
#[allow(dead_code)]
const DERIVED_KEY_SIZE: usize = 32;

/// Maximum number of components a derivation path may contain.
const MAX_DERIVATION_DEPTH: usize = 10;

/// Size in bytes of one big-endian path index fed into child derivation.
// `dead_code` allowed for the same reason as `DERIVED_KEY_SIZE`.
#[allow(dead_code)]
const PATH_INDEX_SIZE: usize = 4;

/// Bit 31 marks a hardened path component (written with a `'` suffix).
const HARDENED_OFFSET: u32 = 1 << 31;
58
59pub struct MasterSeed {
61 seed: SecureMemory,
63 _created_at: u64,
65 derivation_counter: u64,
67}
68
/// A derivation path such as `m/0'/1`, stored as raw component indices.
///
/// Components with bit 31 (`HARDENED_OFFSET`) set are hardened.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DerivationPath {
    /// Raw indices, root first; hardened ones include `HARDENED_OFFSET`.
    components: Vec<u32>,
}
75
76pub struct DerivedKey {
78 pub secret_key: SigningKey,
80 pub public_key: VerifyingKey,
82 pub x25519_secret: [u8; 32],
84 pub x25519_public: [u8; 32],
86 pub path: DerivationPath,
88 pub created_at: u64,
90 pub usage_counter: u64,
92}
93
94impl Clone for DerivedKey {
95 fn clone(&self) -> Self {
96 let signing_key = SigningKey::from_bytes(self.secret_key.as_bytes());
98 let verifying_key =
99 VerifyingKey::from_bytes(self.public_key.as_bytes()).unwrap_or(self.public_key);
100
101 Self {
102 secret_key: signing_key,
103 public_key: verifying_key,
104 x25519_secret: self.x25519_secret,
105 x25519_public: self.x25519_public,
106 path: self.path.clone(),
107 created_at: self.created_at,
108 usage_counter: self.usage_counter,
109 }
110 }
111}
112
113pub struct KeyDerivationCache {
115 cache: RwLock<HashMap<DerivationPath, DerivedKey>>,
117 max_size: usize,
119 hits: std::sync::atomic::AtomicU64,
121 misses: std::sync::atomic::AtomicU64,
123}
124
125pub struct HierarchicalKeyDerivation {
127 master_seed: MasterSeed,
129 cache: Arc<KeyDerivationCache>,
131}
132
133pub struct BatchDerivationRequest {
135 pub paths: Vec<DerivationPath>,
137 pub use_cache: bool,
139 pub priority: DerivationPriority,
141}
142
/// Priority levels a batch derivation request can declare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivationPriority {
    /// Lowest priority.
    Low,
    /// Default priority.
    Normal,
    /// Elevated priority.
    High,
    /// Highest priority.
    Critical,
}
155
156pub struct BatchDerivationResult {
158 pub keys: HashMap<DerivationPath, DerivedKey>,
160 pub failures: HashMap<DerivationPath, String>,
162 pub cache_hit_rate: f64,
164 pub processing_time: std::time::Duration,
166}
167
/// Snapshot of statistics returned by `HierarchicalKeyDerivation::stats`.
#[derive(Debug, Clone, Default)]
pub struct DerivationStats {
    /// Value of the master seed's derivation counter.
    pub total_derived: u64,
    /// Cache lookups that found a key.
    pub cache_hits: u64,
    /// Cache lookups that found nothing.
    pub cache_misses: u64,
    /// Average derivation time in microseconds (currently reported as 0).
    pub avg_derivation_time_us: u64,
    /// Number of batch operations (currently reported as 0).
    pub batch_operations: u64,
    /// Number of keys currently held in the cache.
    pub cache_size: usize,
}
184
185impl MasterSeed {
186 pub fn generate() -> Result<Self> {
188 let mut seed_bytes = vec![0u8; MASTER_SEED_SIZE];
189 thread_rng().fill_bytes(&mut seed_bytes);
190
191 let seed = SecureMemory::from_slice(&seed_bytes)?;
192
193 seed_bytes.zeroize();
195
196 Ok(Self {
197 seed,
198 _created_at: current_timestamp(),
199 derivation_counter: 0,
200 })
201 }
202
203 pub fn from_entropy(entropy: &[u8]) -> Result<Self> {
205 if entropy.len() < MASTER_SEED_SIZE {
206 return Err(P2PError::Security(SecurityError::InvalidKey(
207 "Insufficient entropy for master seed".to_string().into(),
208 )));
209 }
210
211 let seed = SecureMemory::from_slice(&entropy[..MASTER_SEED_SIZE])?;
212
213 Ok(Self {
214 seed,
215 _created_at: current_timestamp(),
216 derivation_counter: 0,
217 })
218 }
219
220 pub fn seed_material(&self) -> &[u8] {
222 self.seed.as_slice()
223 }
224
225 pub fn increment_counter(&mut self) {
227 self.derivation_counter += 1;
228 }
229
230 pub fn derivation_counter(&self) -> u64 {
232 self.derivation_counter
233 }
234}
235
236impl DerivationPath {
237 pub fn new(components: Vec<u32>) -> Result<Self> {
239 if components.len() > MAX_DERIVATION_DEPTH {
240 return Err(P2PError::Security(SecurityError::InvalidKey(
241 format!(
242 "Derivation path too deep: {} > {}",
243 components.len(),
244 MAX_DERIVATION_DEPTH
245 )
246 .into(),
247 )));
248 }
249
250 Ok(Self { components })
251 }
252
253 pub fn from_string(path_str: &str) -> Result<Self> {
255 let parts: Vec<&str> = path_str.split('/').collect();
256
257 if parts.is_empty() || parts[0] != "m" {
258 return Err(P2PError::Security(SecurityError::InvalidKey(
259 "Invalid derivation path format".to_string().into(),
260 )));
261 }
262
263 let mut components = Vec::new();
264
265 for part in parts.iter().skip(1) {
266 if part.is_empty() {
267 continue;
268 }
269
270 let (index_str, hardened) = if let Some(stripped) = part.strip_suffix('\'') {
271 (stripped, true)
272 } else {
273 (*part, false)
274 };
275
276 let index: u32 = index_str.parse().map_err(|_| {
277 P2PError::Security(SecurityError::InvalidKey(
278 format!("Invalid path component: {part}").into(),
279 ))
280 })?;
281
282 let final_index = if hardened {
283 index + HARDENED_OFFSET
284 } else {
285 index
286 };
287
288 components.push(final_index);
289 }
290
291 Self::new(components)
292 }
293
294 pub fn components(&self) -> &[u32] {
297 &self.components
298 }
299
300 pub fn is_hardened(&self, index: usize) -> bool {
302 self.components
303 .get(index)
304 .map(|&c| c >= HARDENED_OFFSET)
305 .unwrap_or(false)
306 }
307
308 pub fn depth(&self) -> usize {
310 self.components.len()
311 }
312
313 pub fn child(&self, component: u32) -> Result<Self> {
315 let mut new_components = self.components.clone();
316 new_components.push(component);
317 Self::new(new_components)
318 }
319
320 pub fn hardened_child(&self, index: u32) -> Result<Self> {
322 self.child(index + HARDENED_OFFSET)
323 }
324}
325
326impl std::fmt::Display for DerivationPath {
327 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328 write!(f, "m")?;
329 for &component in &self.components {
330 write!(f, "/")?;
331 if component >= HARDENED_OFFSET {
332 write!(f, "{}'", component - HARDENED_OFFSET)?;
333 } else {
334 write!(f, "{}", component)?;
335 }
336 }
337 Ok(())
338 }
339}
340
341impl KeyDerivationCache {
342 pub fn new(max_size: usize) -> Self {
344 Self {
345 cache: RwLock::new(HashMap::new()),
346 max_size,
347 hits: std::sync::atomic::AtomicU64::new(0),
348 misses: std::sync::atomic::AtomicU64::new(0),
349 }
350 }
351
352 pub fn get(&self, path: &DerivationPath) -> Option<DerivedKey> {
354 let cache = match self.cache.read() {
355 Ok(c) => c,
356 Err(_) => return None,
357 };
358 if let Some(key) = cache.get(path) {
359 self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
360 Some(key.clone())
361 } else {
362 self.misses
363 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
364 None
365 }
366 }
367
368 pub fn insert(&self, path: DerivationPath, key: DerivedKey) {
370 let mut cache = match self.cache.write() {
371 Ok(c) => c,
372 Err(_) => return,
373 };
374
375 if cache.len() >= self.max_size {
377 let oldest_path = cache
378 .iter()
379 .min_by_key(|(_, k)| k.created_at)
380 .map(|(p, _)| p.clone());
381
382 if let Some(path_to_remove) = oldest_path {
383 cache.remove(&path_to_remove);
384 }
385 }
386
387 cache.insert(path, key);
388 }
389
390 pub fn clear(&self) {
392 let mut cache = match self.cache.write() {
393 Ok(c) => c,
394 Err(_) => return,
395 };
396 cache.clear();
397 }
398
399 pub fn stats(&self) -> (u64, u64, usize) {
401 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
402 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
403 let size = self.cache.read().map(|c| c.len()).unwrap_or(0);
404 (hits, misses, size)
405 }
406}
407
408impl HierarchicalKeyDerivation {
409 pub fn new(master_seed: MasterSeed) -> Self {
411 let cache = Arc::new(KeyDerivationCache::new(1000)); Self { master_seed, cache }
414 }
415
416 pub fn with_cache_size(master_seed: MasterSeed, cache_size: usize) -> Self {
418 let cache = Arc::new(KeyDerivationCache::new(cache_size));
419 Self { master_seed, cache }
420 }
421
422 pub fn derive_key(&mut self, path: &DerivationPath) -> Result<DerivedKey> {
424 if let Some(cached_key) = self.cache.get(path) {
426 return Ok(cached_key);
427 }
428
429 let derived_key = self.derive_key_internal(path)?;
431
432 self.cache.insert(path.clone(), derived_key.clone());
434
435 self.master_seed.increment_counter();
437
438 Ok(derived_key)
439 }
440
441 fn derive_key_internal(&self, path: &DerivationPath) -> Result<DerivedKey> {
443 let mut current_key = self.master_seed.seed_material().to_vec();
444 let mut current_chaincode = [0u8; 32];
445
446 let hkdf = Hkdf::<Sha256>::new(None, ¤t_key);
448 hkdf.expand(b"ed25519 seed", &mut current_key)
449 .map_err(|_| {
450 P2PError::Security(SecurityError::InvalidKey(
451 "HKDF expansion failed".to_string().into(),
452 ))
453 })?;
454 hkdf.expand(b"chaincode", &mut current_chaincode)
455 .map_err(|_| {
456 P2PError::Security(SecurityError::InvalidKey(
457 "HKDF expansion failed".to_string().into(),
458 ))
459 })?;
460
461 for &component in path.components() {
463 let (new_key, new_chaincode) =
464 self.derive_child_key(¤t_key, ¤t_chaincode, component)?;
465 current_key = new_key;
466 current_chaincode = new_chaincode;
467 }
468
469 let signing_key = SigningKey::from_bytes(¤t_key[..32].try_into().map_err(|_| {
471 P2PError::Security(SecurityError::InvalidKey(
472 "Invalid Ed25519 secret key length".to_string().into(),
473 ))
474 })?);
475 let verifying_key = signing_key.verifying_key();
476
477 let x25519_secret: [u8; 32] = current_key[..32].try_into().map_err(|_| {
479 P2PError::Security(SecurityError::InvalidKey(
480 "Invalid X25519 secret key".to_string().into(),
481 ))
482 })?;
483 let x25519_public = x25519_dalek::PublicKey::from(x25519_secret).to_bytes();
484
485 current_key.zeroize();
487 current_chaincode.zeroize();
488
489 Ok(DerivedKey {
490 secret_key: signing_key,
491 public_key: verifying_key,
492 x25519_secret,
493 x25519_public,
494 path: path.clone(),
495 created_at: current_timestamp(),
496 usage_counter: 0,
497 })
498 }
499
500 fn derive_child_key(
502 &self,
503 parent_key: &[u8],
504 parent_chaincode: &[u8],
505 index: u32,
506 ) -> Result<(Vec<u8>, [u8; 32])> {
507 let mut data = Vec::new();
508
509 if index >= HARDENED_OFFSET {
510 data.push(0x00);
512 data.extend_from_slice(parent_key);
513 } else {
514 let signing_key =
516 SigningKey::from_bytes(&parent_key[..32].try_into().map_err(|_| {
517 P2PError::Security(SecurityError::InvalidKey(
518 "Invalid parent key length".to_string().into(),
519 ))
520 })?);
521 let verifying_key = signing_key.verifying_key();
522 data.extend_from_slice(verifying_key.as_bytes());
523 }
524
525 data.extend_from_slice(&index.to_be_bytes());
526
527 let hkdf = Hkdf::<Sha256>::new(Some(parent_chaincode), &data);
528
529 let mut child_key = vec![0u8; 32];
530 let mut child_chaincode = [0u8; 32];
531
532 hkdf.expand(b"key", &mut child_key).map_err(|_| {
533 P2PError::Security(SecurityError::InvalidKey(
534 "Child key derivation failed".to_string().into(),
535 ))
536 })?;
537 hkdf.expand(b"chaincode", &mut child_chaincode)
538 .map_err(|_| {
539 P2PError::Security(SecurityError::InvalidKey(
540 "Child chaincode derivation failed".to_string().into(),
541 ))
542 })?;
543
544 data.zeroize();
546
547 Ok((child_key, child_chaincode))
548 }
549
550 pub fn derive_batch(
552 &mut self,
553 request: BatchDerivationRequest,
554 ) -> Result<BatchDerivationResult> {
555 let start_time = std::time::Instant::now();
556 let mut keys = HashMap::new();
557 let mut failures = HashMap::new();
558 let mut cache_hits = 0u64;
559
560 for path in request.paths {
561 match self.derive_key(&path) {
562 Ok(key) => {
563 if self.cache.get(&path).is_some() {
565 cache_hits += 1;
566 }
567 keys.insert(path, key);
568 }
569 Err(e) => {
570 failures.insert(path, e.to_string());
571 }
572 }
573 }
574
575 let processing_time = start_time.elapsed();
576 let total_requests = keys.len() + failures.len();
577 let cache_hit_rate = if total_requests > 0 {
578 cache_hits as f64 / total_requests as f64
579 } else {
580 0.0
581 };
582
583 Ok(BatchDerivationResult {
584 keys,
585 failures,
586 cache_hit_rate,
587 processing_time,
588 })
589 }
590
591 pub fn stats(&self) -> DerivationStats {
593 let (cache_hits, cache_misses, cache_size) = self.cache.stats();
594
595 DerivationStats {
596 total_derived: self.master_seed.derivation_counter(),
597 cache_hits,
598 cache_misses,
599 avg_derivation_time_us: 0, batch_operations: 0, cache_size,
602 }
603 }
604
605 pub fn clear_cache(&self) {
607 self.cache.clear();
608 }
609}
610
611impl DerivedKey {
612 pub fn increment_usage(&mut self) {
614 self.usage_counter += 1;
615 }
616
617 pub fn ed25519_keypair(&self) -> (SigningKey, VerifyingKey) {
619 (self.secret_key.clone(), self.public_key)
620 }
621
622 pub fn x25519_keypair(&self) -> ([u8; 32], [u8; 32]) {
624 (self.x25519_secret, self.x25519_public)
625 }
626}
627
/// Seconds since the Unix epoch, or 0 if the system clock reads earlier than it.
fn current_timestamp() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
635
/// In-module helper for wiping buffers that hold secret material.
trait Zeroize {
    /// Sets every byte of `self` to zero, in a way the optimizer cannot elide.
    fn zeroize(&mut self);
}

/// Zeroes `bytes` with volatile writes, followed by a compiler fence.
///
/// A plain store loop is a dead store when the buffer is dropped right
/// afterwards, so the compiler may delete it entirely. Volatile writes must be
/// emitted, and the fence stops later code from being reordered before them.
fn volatile_zero(bytes: &mut [u8]) {
    for byte in bytes.iter_mut() {
        // SAFETY: `byte` is a unique, aligned reference to an initialized
        // `u8`, so writing through it (as a raw pointer) is sound.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

impl Zeroize for Vec<u8> {
    /// Wipes the initialized elements (`0..len`); spare capacity is untouched.
    fn zeroize(&mut self) {
        volatile_zero(self.as_mut_slice());
    }
}

impl Zeroize for [u8; 32] {
    fn zeroize(&mut self) {
        volatile_zero(self.as_mut_slice());
    }
}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a path literal, panicking on malformed test input.
    fn parse(path: &str) -> DerivationPath {
        DerivationPath::from_string(path).expect("test path must parse")
    }

    /// A deriver over a freshly generated master seed, with the default cache.
    fn fresh_deriver() -> HierarchicalKeyDerivation {
        HierarchicalKeyDerivation::new(MasterSeed::generate().expect("seed generation"))
    }

    #[test]
    fn test_master_seed_generation() {
        let seed = MasterSeed::generate().expect("seed generation");
        assert_eq!(seed.seed_material().len(), MASTER_SEED_SIZE);
        assert!(seed._created_at > 0);
    }

    #[test]
    fn test_derivation_path_parsing() {
        let path = parse("m/0'/1/2'");
        assert_eq!(path.components().len(), 3);

        let hardened_flags: Vec<bool> = (0..3).map(|i| path.is_hardened(i)).collect();
        assert_eq!(hardened_flags, [true, false, true]);

        assert_eq!(path.to_string(), "m/0'/1/2'");
    }

    #[test]
    fn test_key_derivation() {
        let mut deriver = fresh_deriver();
        let path = parse("m/0'/1");

        let key = deriver.derive_key(&path).expect("derivation");

        assert_eq!(key.path, path);
        assert_eq!(key.x25519_secret.len(), 32);
        assert_eq!(key.x25519_public.len(), 32);
    }

    #[test]
    fn test_key_derivation_cache() {
        let mut deriver = fresh_deriver();
        let path = parse("m/0'/1");

        let first = deriver.derive_key(&path).expect("first derivation");
        // The second request for the same path should be served from the cache.
        let second = deriver.derive_key(&path).expect("second derivation");

        assert_eq!(first.secret_key.as_bytes(), second.secret_key.as_bytes());
        assert_eq!(first.public_key.as_bytes(), second.public_key.as_bytes());
        assert!(deriver.stats().cache_hits > 0);
    }

    #[test]
    fn test_batch_derivation() {
        let mut deriver = fresh_deriver();
        let paths: Vec<DerivationPath> = ["m/0'/1", "m/0'/2", "m/1'/0"]
            .iter()
            .copied()
            .map(parse)
            .collect();

        let request = BatchDerivationRequest {
            paths: paths.clone(),
            use_cache: true,
            priority: DerivationPriority::Normal,
        };
        let result = deriver.derive_batch(request).expect("batch derivation");

        assert_eq!(result.keys.len(), 3);
        assert!(result.failures.is_empty());
        assert!(paths.iter().all(|p| result.keys.contains_key(p)));
    }

    #[test]
    fn test_derivation_path_depth_limit() {
        let too_deep = vec![0u32; MAX_DERIVATION_DEPTH + 1];
        assert!(DerivationPath::new(too_deep).is_err());
    }

    #[test]
    fn test_hardened_derivation() {
        let mut deriver = fresh_deriver();

        let hardened = deriver.derive_key(&parse("m/0'")).expect("hardened derivation");
        let normal = deriver.derive_key(&parse("m/0")).expect("normal derivation");

        assert_ne!(hardened.secret_key.as_bytes(), normal.secret_key.as_bytes());
        assert_ne!(hardened.public_key.as_bytes(), normal.public_key.as_bytes());
    }
}