1use crate::error::SecurityError;
33use crate::quantum_crypto::ant_quic_integration::{MlDsaPublicKey, MlDsaSecretKey};
34use crate::secure_memory::SecureMemory;
35use crate::{P2PError, Result};
36use rand::{RngCore, thread_rng};
37use saorsa_pqc::{HkdfSha3_256, api::traits::Kdf};
38use std::collections::HashMap;
39use std::sync::{Arc, RwLock};
40
/// Size in bytes of the master seed. `MasterSeed::from_entropy` keeps only
/// this many leading bytes of the entropy it is given.
const MASTER_SEED_SIZE: usize = 32;

/// Length in bytes of an encoded ML-DSA public key (the ML-DSA-65 size in FIPS 204).
const ML_DSA_PUB_LEN: usize = 1952;
/// Length in bytes of an encoded ML-DSA secret key (the ML-DSA-65 size in FIPS 204).
const ML_DSA_SEC_LEN: usize = 4032;

/// Maximum number of components accepted by `DerivationPath::new`.
const MAX_DERIVATION_DEPTH: usize = 10;

// NOTE(review): not referenced in this file (hence the `allow`); it appears to
// describe the 4-byte big-endian index appended in `derive_child_key` — confirm
// whether it is used elsewhere or can be removed.
#[allow(dead_code)]
const PATH_INDEX_SIZE: usize = 4;

/// Path components at or above this value (2^31) are hardened indices; the
/// displayed index is `component - HARDENED_OFFSET` followed by `'`.
const HARDENED_OFFSET: u32 = 0x8000_0000;
57
/// Root secret from which every hierarchical key is derived.
///
/// The struct derives neither `Debug` nor `Clone`.
pub struct MasterSeed {
    /// Seed bytes, held in `SecureMemory`.
    seed: SecureMemory,
    /// Unix timestamp (seconds) captured when the seed was constructed.
    _created_at: u64,
    /// Count of derivations performed with this seed; bumped through
    /// `increment_counter` (e.g. by `HierarchicalKeyDerivation::derive_key` on a cache miss).
    derivation_counter: u64,
}
67
/// A derivation path in BIP32-style notation such as `m/0'/1/2'`.
///
/// Each component is a `u32`; values at or above `HARDENED_OFFSET` are hardened
/// indices. `DerivationPath::new` rejects paths deeper than `MAX_DERIVATION_DEPTH`.
/// `Hash`/`Eq` let the path act as a cache and result-map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DerivationPath {
    /// Raw path components, root first.
    components: Vec<u32>,
}
74
/// An ML-DSA key pair produced for a specific derivation path.
pub struct DerivedKey {
    /// Secret key, reference-counted so clones of this struct share one allocation.
    pub secret_key: std::sync::Arc<MlDsaSecretKey>,
    /// Public key matching the derivation of `secret_key`.
    pub public_key: MlDsaPublicKey,
    /// Path this key was derived for.
    pub path: DerivationPath,
    /// Unix timestamp (seconds) at which the key was derived.
    pub created_at: u64,
    /// Caller-maintained usage count; starts at 0 and is bumped by `increment_usage`.
    pub usage_counter: u64,
}
88
89impl Clone for DerivedKey {
90 fn clone(&self) -> Self {
91 Self {
92 secret_key: std::sync::Arc::clone(&self.secret_key),
93 public_key: self.public_key.clone(),
94 path: self.path.clone(),
95 created_at: self.created_at,
96 usage_counter: self.usage_counter,
97 }
98 }
99}
100
/// Bounded, thread-safe cache of derived keys keyed by derivation path.
pub struct KeyDerivationCache {
    /// Cached keys guarded by a reader/writer lock.
    cache: RwLock<HashMap<DerivationPath, DerivedKey>>,
    /// Entry count at which `insert` evicts the entry with the smallest `created_at`.
    max_size: usize,
    /// Number of `get` calls that found an entry.
    hits: std::sync::atomic::AtomicU64,
    /// Number of `get` calls that found nothing.
    misses: std::sync::atomic::AtomicU64,
}
112
/// Derives ML-DSA key pairs from a master seed along derivation paths,
/// caching results by path.
pub struct HierarchicalKeyDerivation {
    /// Root secret all keys are derived from.
    master_seed: MasterSeed,
    /// Cache of previously derived keys, held behind an `Arc`.
    cache: Arc<KeyDerivationCache>,
}
120
/// Input to `HierarchicalKeyDerivation::derive_batch`.
pub struct BatchDerivationRequest {
    /// Paths to derive; each is processed independently.
    pub paths: Vec<DerivationPath>,
    /// Caller's cache preference for this batch.
    /// NOTE(review): check `derive_batch` for whether this flag is honored.
    pub use_cache: bool,
    /// Requested scheduling priority.
    /// NOTE(review): not consulted by `derive_batch` in this file — confirm intent.
    pub priority: DerivationPriority,
}
130
/// Scheduling priority attached to a batch derivation request.
///
/// Variants are declared from lowest to highest, so the derived `Ord` ranks
/// `Low < Normal < High < Critical`. This lets requests be sorted or placed
/// in a priority queue, and `Hash` lets a priority be used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum DerivationPriority {
    /// Background work.
    Low,
    /// Default priority.
    Normal,
    /// Elevated priority.
    High,
    /// Highest priority.
    Critical,
}
143
/// Output of `HierarchicalKeyDerivation::derive_batch`.
pub struct BatchDerivationResult {
    /// Successfully derived (or cached) keys by path.
    pub keys: HashMap<DerivationPath, DerivedKey>,
    /// Paths that failed, with the error rendered via `to_string`.
    pub failures: HashMap<DerivationPath, String>,
    /// Cache-hit ratio as computed by `derive_batch` (0.0 for an empty batch).
    pub cache_hit_rate: f64,
    /// Wall-clock time spent processing the batch.
    pub processing_time: std::time::Duration,
}
155
/// Snapshot of derivation and cache statistics returned by
/// `HierarchicalKeyDerivation::stats`.
#[derive(Debug, Clone, Default)]
pub struct DerivationStats {
    /// Value of the master seed's derivation counter.
    pub total_derived: u64,
    /// Cache lookups that found an entry.
    pub cache_hits: u64,
    /// Cache lookups that found nothing.
    pub cache_misses: u64,
    /// Average derivation time in microseconds.
    /// NOTE(review): `stats()` in this file always reports 0.
    pub avg_derivation_time_us: u64,
    /// Number of batch operations performed.
    /// NOTE(review): `stats()` in this file always reports 0.
    pub batch_operations: u64,
    /// Number of entries currently held by the cache.
    pub cache_size: usize,
}
172
173impl MasterSeed {
174 pub fn generate() -> Result<Self> {
176 let mut seed_bytes = vec![0u8; MASTER_SEED_SIZE];
177 thread_rng().fill_bytes(&mut seed_bytes);
178
179 let seed = SecureMemory::from_slice(&seed_bytes)?;
180
181 seed_bytes.zeroize();
183
184 Ok(Self {
185 seed,
186 _created_at: current_timestamp(),
187 derivation_counter: 0,
188 })
189 }
190
191 pub fn from_entropy(entropy: &[u8]) -> Result<Self> {
193 if entropy.len() < MASTER_SEED_SIZE {
194 return Err(P2PError::Security(SecurityError::InvalidKey(
195 "Insufficient entropy for master seed".to_string().into(),
196 )));
197 }
198
199 let seed = SecureMemory::from_slice(&entropy[..MASTER_SEED_SIZE])?;
200
201 Ok(Self {
202 seed,
203 _created_at: current_timestamp(),
204 derivation_counter: 0,
205 })
206 }
207
208 pub fn seed_material(&self) -> &[u8] {
210 self.seed.as_slice()
211 }
212
213 pub fn increment_counter(&mut self) {
215 self.derivation_counter += 1;
216 }
217
218 pub fn derivation_counter(&self) -> u64 {
220 self.derivation_counter
221 }
222}
223
224impl DerivationPath {
225 pub fn new(components: Vec<u32>) -> Result<Self> {
227 if components.len() > MAX_DERIVATION_DEPTH {
228 return Err(P2PError::Security(SecurityError::InvalidKey(
229 format!(
230 "Derivation path too deep: {} > {}",
231 components.len(),
232 MAX_DERIVATION_DEPTH
233 )
234 .into(),
235 )));
236 }
237
238 Ok(Self { components })
239 }
240
241 pub fn from_string(path_str: &str) -> Result<Self> {
243 let parts: Vec<&str> = path_str.split('/').collect();
244
245 if parts.is_empty() || parts[0] != "m" {
246 return Err(P2PError::Security(SecurityError::InvalidKey(
247 "Invalid derivation path format".to_string().into(),
248 )));
249 }
250
251 let mut components = Vec::new();
252
253 for part in parts.iter().skip(1) {
254 if part.is_empty() {
255 continue;
256 }
257
258 let (index_str, hardened) = if let Some(stripped) = part.strip_suffix('\'') {
259 (stripped, true)
260 } else {
261 (*part, false)
262 };
263
264 let index: u32 = index_str.parse().map_err(|_| {
265 P2PError::Security(SecurityError::InvalidKey(
266 format!("Invalid path component: {part}").into(),
267 ))
268 })?;
269
270 let final_index = if hardened {
271 index + HARDENED_OFFSET
272 } else {
273 index
274 };
275
276 components.push(final_index);
277 }
278
279 Self::new(components)
280 }
281
282 pub fn components(&self) -> &[u32] {
285 &self.components
286 }
287
288 pub fn is_hardened(&self, index: usize) -> bool {
290 self.components
291 .get(index)
292 .map(|&c| c >= HARDENED_OFFSET)
293 .unwrap_or(false)
294 }
295
296 pub fn depth(&self) -> usize {
298 self.components.len()
299 }
300
301 pub fn child(&self, component: u32) -> Result<Self> {
303 let mut new_components = self.components.clone();
304 new_components.push(component);
305 Self::new(new_components)
306 }
307
308 pub fn hardened_child(&self, index: u32) -> Result<Self> {
310 self.child(index + HARDENED_OFFSET)
311 }
312}
313
314impl std::fmt::Display for DerivationPath {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 write!(f, "m")?;
317 for &component in &self.components {
318 write!(f, "/")?;
319 if component >= HARDENED_OFFSET {
320 write!(f, "{}'", component - HARDENED_OFFSET)?;
321 } else {
322 write!(f, "{}", component)?;
323 }
324 }
325 Ok(())
326 }
327}
328
329impl KeyDerivationCache {
330 pub fn new(max_size: usize) -> Self {
332 Self {
333 cache: RwLock::new(HashMap::new()),
334 max_size,
335 hits: std::sync::atomic::AtomicU64::new(0),
336 misses: std::sync::atomic::AtomicU64::new(0),
337 }
338 }
339
340 pub fn get(&self, path: &DerivationPath) -> Option<DerivedKey> {
342 let cache = match self.cache.read() {
343 Ok(c) => c,
344 Err(_) => return None,
345 };
346 if let Some(key) = cache.get(path) {
347 self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
348 Some(key.clone())
349 } else {
350 self.misses
351 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
352 None
353 }
354 }
355
356 pub fn insert(&self, path: DerivationPath, key: DerivedKey) {
358 let mut cache = match self.cache.write() {
359 Ok(c) => c,
360 Err(_) => return,
361 };
362
363 if cache.len() >= self.max_size {
365 let oldest_path = cache
366 .iter()
367 .min_by_key(|(_, k)| k.created_at)
368 .map(|(p, _)| p.clone());
369
370 if let Some(path_to_remove) = oldest_path {
371 cache.remove(&path_to_remove);
372 }
373 }
374
375 cache.insert(path, key);
376 }
377
378 pub fn clear(&self) {
380 let mut cache = match self.cache.write() {
381 Ok(c) => c,
382 Err(_) => return,
383 };
384 cache.clear();
385 }
386
387 pub fn stats(&self) -> (u64, u64, usize) {
389 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
390 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
391 let size = self.cache.read().map(|c| c.len()).unwrap_or(0);
392 (hits, misses, size)
393 }
394}
395
396impl HierarchicalKeyDerivation {
397 pub fn new(master_seed: MasterSeed) -> Self {
399 let cache = Arc::new(KeyDerivationCache::new(1000)); Self { master_seed, cache }
402 }
403
404 pub fn with_cache_size(master_seed: MasterSeed, cache_size: usize) -> Self {
406 let cache = Arc::new(KeyDerivationCache::new(cache_size));
407 Self { master_seed, cache }
408 }
409
410 pub fn derive_key(&mut self, path: &DerivationPath) -> Result<DerivedKey> {
412 if let Some(cached_key) = self.cache.get(path) {
414 return Ok(cached_key);
415 }
416
417 let derived_key = self.derive_key_internal(path)?;
419
420 self.cache.insert(path.clone(), derived_key.clone());
422
423 self.master_seed.increment_counter();
425
426 Ok(derived_key)
427 }
428
429 fn derive_key_internal(&self, path: &DerivationPath) -> Result<DerivedKey> {
431 let mut current_key = self.master_seed.seed_material().to_vec();
432 let mut current_chaincode = [0u8; 32];
433
434 let mut temp_key = [0u8; 32];
436 HkdfSha3_256::derive(¤t_key, None, b"ml-dsa seed", &mut temp_key).map_err(|_| {
437 P2PError::Security(SecurityError::InvalidKey(
438 "HKDF derivation failed".to_string().into(),
439 ))
440 })?;
441 current_key.copy_from_slice(&temp_key);
442
443 HkdfSha3_256::derive(¤t_key, None, b"chaincode", &mut current_chaincode).map_err(
444 |_| {
445 P2PError::Security(SecurityError::InvalidKey(
446 "HKDF derivation failed".to_string().into(),
447 ))
448 },
449 )?;
450
451 for &component in path.components() {
453 let (new_key, new_chaincode) =
454 self.derive_child_key(¤t_key, ¤t_chaincode, component)?;
455 current_key = new_key;
456 current_chaincode = new_chaincode;
457 }
458
459 let mut derived = vec![0u8; ML_DSA_PUB_LEN + ML_DSA_SEC_LEN];
461 HkdfSha3_256::derive(
462 ¤t_key,
463 Some(¤t_chaincode),
464 b"ml-dsa keypair",
465 &mut derived,
466 )
467 .map_err(|_| {
468 P2PError::Security(SecurityError::InvalidKey(
469 "HKDF derivation failed".to_string().into(),
470 ))
471 })?;
472 let pub_bytes = &derived[..ML_DSA_PUB_LEN];
473 let sec_bytes = &derived[ML_DSA_PUB_LEN..];
474 let public_key = MlDsaPublicKey::from_bytes(pub_bytes).map_err(|e| {
475 P2PError::Security(SecurityError::InvalidKey(
476 format!("Invalid ML-DSA public key: {e}").into(),
477 ))
478 })?;
479 let secret_key = MlDsaSecretKey::from_bytes(sec_bytes).map_err(|e| {
480 P2PError::Security(SecurityError::InvalidKey(
481 format!("Invalid ML-DSA secret key: {e}").into(),
482 ))
483 })?;
484
485 crate::quantum_crypto::ant_quic_integration::register_debug_ml_dsa_keypair(
486 &secret_key,
487 &public_key,
488 );
489
490 current_key.zeroize();
492 current_chaincode.zeroize();
493 derived.zeroize();
494
495 Ok(DerivedKey {
496 secret_key: std::sync::Arc::new(secret_key),
497 public_key,
498 path: path.clone(),
499 created_at: current_timestamp(),
500 usage_counter: 0,
501 })
502 }
503
504 fn derive_child_key(
506 &self,
507 parent_key: &[u8],
508 parent_chaincode: &[u8],
509 index: u32,
510 ) -> Result<(Vec<u8>, [u8; 32])> {
511 let mut data = Vec::with_capacity(parent_key.len() + parent_chaincode.len() + 4);
513 data.extend_from_slice(parent_key);
514 data.extend_from_slice(parent_chaincode);
515 data.extend_from_slice(&index.to_be_bytes());
516
517 let mut child_key = vec![0u8; parent_key.len().max(32)];
518 let mut child_chaincode = [0u8; 32];
519
520 HkdfSha3_256::derive(&data, Some(parent_chaincode), b"key", &mut child_key).map_err(
521 |_| {
522 P2PError::Security(SecurityError::InvalidKey(
523 "Child key derivation failed".to_string().into(),
524 ))
525 },
526 )?;
527 HkdfSha3_256::derive(
528 &data,
529 Some(parent_chaincode),
530 b"chaincode",
531 &mut child_chaincode,
532 )
533 .map_err(|_| {
534 P2PError::Security(SecurityError::InvalidKey(
535 "Child chaincode derivation failed".to_string().into(),
536 ))
537 })?;
538
539 data.zeroize();
541
542 Ok((child_key, child_chaincode))
543 }
544
545 pub fn derive_batch(
547 &mut self,
548 request: BatchDerivationRequest,
549 ) -> Result<BatchDerivationResult> {
550 let start_time = std::time::Instant::now();
551 let mut keys = HashMap::new();
552 let mut failures = HashMap::new();
553 let mut cache_hits = 0u64;
554
555 for path in request.paths {
556 match self.derive_key(&path) {
557 Ok(key) => {
558 if self.cache.get(&path).is_some() {
560 cache_hits += 1;
561 }
562 keys.insert(path, key);
563 }
564 Err(e) => {
565 failures.insert(path, e.to_string());
566 }
567 }
568 }
569
570 let processing_time = start_time.elapsed();
571 let total_requests = keys.len() + failures.len();
572 let cache_hit_rate = if total_requests > 0 {
573 cache_hits as f64 / total_requests as f64
574 } else {
575 0.0
576 };
577
578 Ok(BatchDerivationResult {
579 keys,
580 failures,
581 cache_hit_rate,
582 processing_time,
583 })
584 }
585
586 pub fn stats(&self) -> DerivationStats {
588 let (cache_hits, cache_misses, cache_size) = self.cache.stats();
589
590 DerivationStats {
591 total_derived: self.master_seed.derivation_counter(),
592 cache_hits,
593 cache_misses,
594 avg_derivation_time_us: 0, batch_operations: 0, cache_size,
597 }
598 }
599
600 pub fn clear_cache(&self) {
602 self.cache.clear();
603 }
604}
605
606impl DerivedKey {
607 pub fn increment_usage(&mut self) {
609 self.usage_counter += 1;
610 }
611
612 pub fn ml_dsa_keypair(&self) -> (MlDsaPublicKey, std::sync::Arc<MlDsaSecretKey>) {
614 (
615 self.public_key.clone(),
616 std::sync::Arc::clone(&self.secret_key),
617 )
618 }
619}
620
/// Seconds elapsed since the Unix epoch, or 0 if the system clock reads
/// earlier than the epoch.
fn current_timestamp() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
628
/// Overwrites sensitive byte buffers with zeros.
trait Zeroize {
    /// Sets every byte of `self` to zero (length is unchanged).
    fn zeroize(&mut self);
}

/// Zeroes `bytes` with volatile writes followed by a compiler fence.
///
/// Ordinary stores into a buffer that is about to be dropped are dead stores
/// the optimizer may delete, which would leave secrets in freed memory.
/// Volatile writes cannot be elided, and the fence keeps them from being
/// reordered past later operations.
#[allow(unsafe_code)]
fn zeroize_bytes(bytes: &mut [u8]) {
    for byte in bytes.iter_mut() {
        // SAFETY: `byte` is a valid, aligned, exclusive reference to a live `u8`.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

impl Zeroize for Vec<u8> {
    fn zeroize(&mut self) {
        zeroize_bytes(self.as_mut_slice());
    }
}

impl Zeroize for [u8; 32] {
    fn zeroize(&mut self) {
        zeroize_bytes(self);
    }
}
649
// Unit tests covering seed generation, path parsing/formatting, single and
// batch derivation, caching, and the path depth limit.
#[cfg(test)]
mod tests {
    use super::*;

    // A generated seed has the configured size and a non-zero creation time.
    #[test]
    fn test_master_seed_generation() {
        let seed = MasterSeed::generate().unwrap();
        assert_eq!(seed.seed_material().len(), MASTER_SEED_SIZE);
        assert!(seed._created_at > 0);
    }

    // Parsing marks `'` components as hardened, and Display round-trips the string.
    #[test]
    fn test_derivation_path_parsing() {
        let path = DerivationPath::from_string("m/0'/1/2'").unwrap();
        assert_eq!(path.components().len(), 3);
        assert!(path.is_hardened(0));
        assert!(!path.is_hardened(1));
        assert!(path.is_hardened(2));

        let path_str = path.to_string();
        assert_eq!(path_str, "m/0'/1/2'");
    }

    // A derived key records its path and has ML-DSA-sized key encodings.
    #[test]
    fn test_key_derivation() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let path = DerivationPath::from_string("m/0'/1").unwrap();
        let derived_key = derivation.derive_key(&path).unwrap();

        assert_eq!(derived_key.path, path);
        assert_eq!(derived_key.public_key.as_bytes().len(), ML_DSA_PUB_LEN);
        assert_eq!(derived_key.secret_key.as_bytes().len(), ML_DSA_SEC_LEN);
    }

    // Deriving the same path twice yields identical keys, and the second call
    // is served from the cache (order matters: stats are read afterwards).
    #[test]
    fn test_key_derivation_cache() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let path = DerivationPath::from_string("m/0'/1").unwrap();

        // First call derives and populates the cache.
        let key1 = derivation.derive_key(&path).unwrap();

        // Second call should hit the cache.
        let key2 = derivation.derive_key(&path).unwrap();

        assert_eq!(key1.secret_key.as_bytes(), key2.secret_key.as_bytes());
        assert_eq!(key1.public_key.as_bytes(), key2.public_key.as_bytes());

        let stats = derivation.stats();
        assert!(stats.cache_hits > 0);
    }

    // A batch of distinct paths yields one key per path and no failures.
    #[test]
    fn test_batch_derivation() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let paths = vec![
            DerivationPath::from_string("m/0'/1").unwrap(),
            DerivationPath::from_string("m/0'/2").unwrap(),
            DerivationPath::from_string("m/1'/0").unwrap(),
        ];

        let request = BatchDerivationRequest {
            paths: paths.clone(),
            use_cache: true,
            priority: DerivationPriority::Normal,
        };

        let result = derivation.derive_batch(request).unwrap();

        assert_eq!(result.keys.len(), 3);
        assert_eq!(result.failures.len(), 0);

        for path in paths {
            assert!(result.keys.contains_key(&path));
        }
    }

    // One component past MAX_DERIVATION_DEPTH is rejected.
    #[test]
    fn test_derivation_path_depth_limit() {
        let components = vec![0u32; MAX_DERIVATION_DEPTH + 1];
        let result = DerivationPath::new(components);
        assert!(result.is_err());
    }

    // `m/0'` and `m/0` are different paths and must produce different keys.
    #[test]
    fn test_hardened_derivation() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let hardened_path = DerivationPath::from_string("m/0'").unwrap();
        let normal_path = DerivationPath::from_string("m/0").unwrap();

        let hardened_key = derivation.derive_key(&hardened_path).unwrap();
        let normal_key = derivation.derive_key(&normal_path).unwrap();
        assert_ne!(
            hardened_key.secret_key.as_bytes(),
            normal_key.secret_key.as_bytes()
        );
        assert_ne!(
            hardened_key.public_key.as_bytes(),
            normal_key.public_key.as_bytes()
        );
    }
}