1use crate::error::SecurityError;
33use crate::quantum_crypto::ant_quic_integration::{MlDsaPublicKey, MlDsaSecretKey};
34use crate::secure_memory::SecureMemory;
35use crate::{P2PError, Result};
36use rand::{RngCore, thread_rng};
37use saorsa_pqc::{HkdfSha3_256, api::traits::Kdf};
38use std::collections::HashMap;
39use std::sync::{Arc, RwLock};
40
/// Length in bytes of the master seed material.
const MASTER_SEED_SIZE: usize = 32;

/// Encoded ML-DSA public key length in bytes (matches ML-DSA-65, FIPS 204).
const ML_DSA_PUB_LEN: usize = 1_952;
/// Encoded ML-DSA secret key length in bytes (matches ML-DSA-65, FIPS 204).
const ML_DSA_SEC_LEN: usize = 4_032;

/// Maximum number of components a `DerivationPath` may contain.
const MAX_DERIVATION_DEPTH: usize = 10;

/// Width of one serialized path index (a `u32`).
// Not referenced in this module; kept for external/forthcoming serialization code.
#[allow(dead_code)]
const PATH_INDEX_SIZE: usize = std::mem::size_of::<u32>();

/// Bit that marks a path component as hardened (`2^31`).
const HARDENED_OFFSET: u32 = 1 << 31;
57
58pub struct MasterSeed {
60 seed: SecureMemory,
62 _created_at: u64,
64 derivation_counter: u64,
66}
67
/// A BIP32-style derivation path such as `m/0'/1`.
///
/// Components are stored raw; hardened components carry `HARDENED_OFFSET`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct DerivationPath {
    /// Path components in root-to-leaf order.
    components: Vec<u32>,
}
74
75pub struct DerivedKey {
77 pub secret_key: std::sync::Arc<MlDsaSecretKey>,
79 pub public_key: MlDsaPublicKey,
81 pub path: DerivationPath,
83 pub created_at: u64,
85 pub usage_counter: u64,
87}
88
89impl Clone for DerivedKey {
90 fn clone(&self) -> Self {
91 Self {
92 secret_key: std::sync::Arc::clone(&self.secret_key),
93 public_key: self.public_key.clone(),
94 path: self.path.clone(),
95 created_at: self.created_at,
96 usage_counter: self.usage_counter,
97 }
98 }
99}
100
101pub struct KeyDerivationCache {
103 cache: RwLock<HashMap<DerivationPath, DerivedKey>>,
105 max_size: usize,
107 hits: std::sync::atomic::AtomicU64,
109 misses: std::sync::atomic::AtomicU64,
111}
112
113pub struct HierarchicalKeyDerivation {
115 master_seed: MasterSeed,
117 cache: Arc<KeyDerivationCache>,
119}
120
121pub struct BatchDerivationRequest {
123 pub paths: Vec<DerivationPath>,
125 pub use_cache: bool,
127 pub priority: DerivationPriority,
129}
130
/// Priority attached to a batch derivation request.
///
/// Variants are listed from least to most urgent by name; no ordering trait is
/// derived, so they can only be compared for equality.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DerivationPriority {
    /// Background work.
    Low,
    /// Default priority.
    Normal,
    /// Ahead of normal work.
    High,
    /// Most urgent.
    Critical,
}
143
144pub struct BatchDerivationResult {
146 pub keys: HashMap<DerivationPath, DerivedKey>,
148 pub failures: HashMap<DerivationPath, String>,
150 pub cache_hit_rate: f64,
152 pub processing_time: std::time::Duration,
154}
155
/// Snapshot of derivation and cache counters.
#[derive(Clone, Debug, Default)]
pub struct DerivationStats {
    /// Number of keys derived from scratch (cache misses that were derived).
    pub total_derived: u64,
    /// Cache lookups that found a key.
    pub cache_hits: u64,
    /// Cache lookups that found nothing.
    pub cache_misses: u64,
    /// Average derivation time in microseconds.
    // NOTE(review): `HierarchicalKeyDerivation::stats` in this module reports 0.
    pub avg_derivation_time_us: u64,
    /// Number of batch operations performed.
    // NOTE(review): `HierarchicalKeyDerivation::stats` in this module reports 0.
    pub batch_operations: u64,
    /// Current number of cached keys.
    pub cache_size: usize,
}
172
173impl MasterSeed {
174 pub fn generate() -> Result<Self> {
176 let mut seed_bytes = vec![0u8; MASTER_SEED_SIZE];
177 thread_rng().fill_bytes(&mut seed_bytes);
178
179 let seed = SecureMemory::from_slice(&seed_bytes)?;
180
181 seed_bytes.zeroize();
183
184 Ok(Self {
185 seed,
186 _created_at: current_timestamp(),
187 derivation_counter: 0,
188 })
189 }
190
191 pub fn from_entropy(entropy: &[u8]) -> Result<Self> {
193 if entropy.len() < MASTER_SEED_SIZE {
194 return Err(P2PError::Security(SecurityError::InvalidKey(
195 "Insufficient entropy for master seed".to_string().into(),
196 )));
197 }
198
199 let seed = SecureMemory::from_slice(&entropy[..MASTER_SEED_SIZE])?;
200
201 Ok(Self {
202 seed,
203 _created_at: current_timestamp(),
204 derivation_counter: 0,
205 })
206 }
207
208 pub fn seed_material(&self) -> &[u8] {
210 self.seed.as_slice()
211 }
212
213 pub fn increment_counter(&mut self) {
215 self.derivation_counter += 1;
216 }
217
218 pub fn derivation_counter(&self) -> u64 {
220 self.derivation_counter
221 }
222}
223
224impl DerivationPath {
225 pub fn new(components: Vec<u32>) -> Result<Self> {
227 if components.len() > MAX_DERIVATION_DEPTH {
228 return Err(P2PError::Security(SecurityError::InvalidKey(
229 format!(
230 "Derivation path too deep: {} > {}",
231 components.len(),
232 MAX_DERIVATION_DEPTH
233 )
234 .into(),
235 )));
236 }
237
238 Ok(Self { components })
239 }
240
241 pub fn from_string(path_str: &str) -> Result<Self> {
243 let parts: Vec<&str> = path_str.split('/').collect();
244
245 if parts.is_empty() || parts[0] != "m" {
246 return Err(P2PError::Security(SecurityError::InvalidKey(
247 "Invalid derivation path format".to_string().into(),
248 )));
249 }
250
251 let mut components = Vec::new();
252
253 for part in parts.iter().skip(1) {
254 if part.is_empty() {
255 continue;
256 }
257
258 let (index_str, hardened) = if let Some(stripped) = part.strip_suffix('\'') {
259 (stripped, true)
260 } else {
261 (*part, false)
262 };
263
264 let index: u32 = index_str.parse().map_err(|_| {
265 P2PError::Security(SecurityError::InvalidKey(
266 format!("Invalid path component: {part}").into(),
267 ))
268 })?;
269
270 let final_index = if hardened {
271 index + HARDENED_OFFSET
272 } else {
273 index
274 };
275
276 components.push(final_index);
277 }
278
279 Self::new(components)
280 }
281
282 pub fn components(&self) -> &[u32] {
285 &self.components
286 }
287
288 pub fn is_hardened(&self, index: usize) -> bool {
290 self.components
291 .get(index)
292 .map(|&c| c >= HARDENED_OFFSET)
293 .unwrap_or(false)
294 }
295
296 pub fn depth(&self) -> usize {
298 self.components.len()
299 }
300
301 pub fn child(&self, component: u32) -> Result<Self> {
303 let mut new_components = self.components.clone();
304 new_components.push(component);
305 Self::new(new_components)
306 }
307
308 pub fn hardened_child(&self, index: u32) -> Result<Self> {
310 self.child(index + HARDENED_OFFSET)
311 }
312}
313
314impl std::fmt::Display for DerivationPath {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 write!(f, "m")?;
317 for &component in &self.components {
318 write!(f, "/")?;
319 if component >= HARDENED_OFFSET {
320 write!(f, "{}'", component - HARDENED_OFFSET)?;
321 } else {
322 write!(f, "{}", component)?;
323 }
324 }
325 Ok(())
326 }
327}
328
329impl KeyDerivationCache {
330 pub fn new(max_size: usize) -> Self {
332 Self {
333 cache: RwLock::new(HashMap::new()),
334 max_size,
335 hits: std::sync::atomic::AtomicU64::new(0),
336 misses: std::sync::atomic::AtomicU64::new(0),
337 }
338 }
339
340 pub fn get(&self, path: &DerivationPath) -> Option<DerivedKey> {
342 let cache = match self.cache.read() {
343 Ok(c) => c,
344 Err(_) => return None,
345 };
346 if let Some(key) = cache.get(path) {
347 self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
348 Some(key.clone())
349 } else {
350 self.misses
351 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
352 None
353 }
354 }
355
356 pub fn insert(&self, path: DerivationPath, key: DerivedKey) {
358 let mut cache = match self.cache.write() {
359 Ok(c) => c,
360 Err(_) => return,
361 };
362
363 if cache.len() >= self.max_size {
365 let oldest_path = cache
366 .iter()
367 .min_by_key(|(_, k)| k.created_at)
368 .map(|(p, _)| p.clone());
369
370 if let Some(path_to_remove) = oldest_path {
371 cache.remove(&path_to_remove);
372 }
373 }
374
375 cache.insert(path, key);
376 }
377
378 pub fn clear(&self) {
380 let mut cache = match self.cache.write() {
381 Ok(c) => c,
382 Err(_) => return,
383 };
384 cache.clear();
385 }
386
387 pub fn stats(&self) -> (u64, u64, usize) {
389 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
390 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
391 let size = self.cache.read().map(|c| c.len()).unwrap_or(0);
392 (hits, misses, size)
393 }
394}
395
396impl HierarchicalKeyDerivation {
397 pub fn new(master_seed: MasterSeed) -> Self {
399 let cache = Arc::new(KeyDerivationCache::new(1000)); Self { master_seed, cache }
402 }
403
404 pub fn with_cache_size(master_seed: MasterSeed, cache_size: usize) -> Self {
406 let cache = Arc::new(KeyDerivationCache::new(cache_size));
407 Self { master_seed, cache }
408 }
409
410 pub fn derive_key(&mut self, path: &DerivationPath) -> Result<DerivedKey> {
412 if let Some(cached_key) = self.cache.get(path) {
414 return Ok(cached_key);
415 }
416
417 let derived_key = self.derive_key_internal(path)?;
419
420 self.cache.insert(path.clone(), derived_key.clone());
422
423 self.master_seed.increment_counter();
425
426 Ok(derived_key)
427 }
428
429 fn derive_key_internal(&self, path: &DerivationPath) -> Result<DerivedKey> {
431 let mut current_key = self.master_seed.seed_material().to_vec();
432 let mut current_chaincode = [0u8; 32];
433
434 let mut temp_key = [0u8; 32];
436 HkdfSha3_256::derive(¤t_key, None, b"ml-dsa seed", &mut temp_key).map_err(|_| {
437 P2PError::Security(SecurityError::InvalidKey(
438 "HKDF derivation failed".to_string().into(),
439 ))
440 })?;
441 current_key.copy_from_slice(&temp_key);
442
443 HkdfSha3_256::derive(¤t_key, None, b"chaincode", &mut current_chaincode).map_err(
444 |_| {
445 P2PError::Security(SecurityError::InvalidKey(
446 "HKDF derivation failed".to_string().into(),
447 ))
448 },
449 )?;
450
451 for &component in path.components() {
453 let (new_key, new_chaincode) =
454 self.derive_child_key(¤t_key, ¤t_chaincode, component)?;
455 current_key = new_key;
456 current_chaincode = new_chaincode;
457 }
458
459 let mut derived = vec![0u8; ML_DSA_PUB_LEN + ML_DSA_SEC_LEN];
461 HkdfSha3_256::derive(
462 ¤t_key,
463 Some(¤t_chaincode),
464 b"ml-dsa keypair",
465 &mut derived,
466 )
467 .map_err(|_| {
468 P2PError::Security(SecurityError::InvalidKey(
469 "HKDF derivation failed".to_string().into(),
470 ))
471 })?;
472 let pub_bytes = &derived[..ML_DSA_PUB_LEN];
473 let sec_bytes = &derived[ML_DSA_PUB_LEN..];
474 let public_key = MlDsaPublicKey::from_bytes(pub_bytes).map_err(|e| {
475 P2PError::Security(SecurityError::InvalidKey(
476 format!("Invalid ML-DSA public key: {e}").into(),
477 ))
478 })?;
479 let secret_key = MlDsaSecretKey::from_bytes(sec_bytes).map_err(|e| {
480 P2PError::Security(SecurityError::InvalidKey(
481 format!("Invalid ML-DSA secret key: {e}").into(),
482 ))
483 })?;
484
485 current_key.zeroize();
487 current_chaincode.zeroize();
488 derived.zeroize();
489
490 Ok(DerivedKey {
491 secret_key: std::sync::Arc::new(secret_key),
492 public_key,
493 path: path.clone(),
494 created_at: current_timestamp(),
495 usage_counter: 0,
496 })
497 }
498
499 fn derive_child_key(
501 &self,
502 parent_key: &[u8],
503 parent_chaincode: &[u8],
504 index: u32,
505 ) -> Result<(Vec<u8>, [u8; 32])> {
506 let mut data = Vec::with_capacity(parent_key.len() + parent_chaincode.len() + 4);
508 data.extend_from_slice(parent_key);
509 data.extend_from_slice(parent_chaincode);
510 data.extend_from_slice(&index.to_be_bytes());
511
512 let mut child_key = vec![0u8; parent_key.len().max(32)];
513 let mut child_chaincode = [0u8; 32];
514
515 HkdfSha3_256::derive(&data, Some(parent_chaincode), b"key", &mut child_key).map_err(
516 |_| {
517 P2PError::Security(SecurityError::InvalidKey(
518 "Child key derivation failed".to_string().into(),
519 ))
520 },
521 )?;
522 HkdfSha3_256::derive(
523 &data,
524 Some(parent_chaincode),
525 b"chaincode",
526 &mut child_chaincode,
527 )
528 .map_err(|_| {
529 P2PError::Security(SecurityError::InvalidKey(
530 "Child chaincode derivation failed".to_string().into(),
531 ))
532 })?;
533
534 data.zeroize();
536
537 Ok((child_key, child_chaincode))
538 }
539
540 pub fn derive_batch(
542 &mut self,
543 request: BatchDerivationRequest,
544 ) -> Result<BatchDerivationResult> {
545 let start_time = std::time::Instant::now();
546 let mut keys = HashMap::new();
547 let mut failures = HashMap::new();
548 let mut cache_hits = 0u64;
549
550 for path in request.paths {
551 match self.derive_key(&path) {
552 Ok(key) => {
553 if self.cache.get(&path).is_some() {
555 cache_hits += 1;
556 }
557 keys.insert(path, key);
558 }
559 Err(e) => {
560 failures.insert(path, e.to_string());
561 }
562 }
563 }
564
565 let processing_time = start_time.elapsed();
566 let total_requests = keys.len() + failures.len();
567 let cache_hit_rate = if total_requests > 0 {
568 cache_hits as f64 / total_requests as f64
569 } else {
570 0.0
571 };
572
573 Ok(BatchDerivationResult {
574 keys,
575 failures,
576 cache_hit_rate,
577 processing_time,
578 })
579 }
580
581 pub fn stats(&self) -> DerivationStats {
583 let (cache_hits, cache_misses, cache_size) = self.cache.stats();
584
585 DerivationStats {
586 total_derived: self.master_seed.derivation_counter(),
587 cache_hits,
588 cache_misses,
589 avg_derivation_time_us: 0, batch_operations: 0, cache_size,
592 }
593 }
594
595 pub fn clear_cache(&self) {
597 self.cache.clear();
598 }
599}
600
601impl DerivedKey {
602 pub fn increment_usage(&mut self) {
604 self.usage_counter += 1;
605 }
606
607 pub fn ml_dsa_keypair(&self) -> (MlDsaPublicKey, std::sync::Arc<MlDsaSecretKey>) {
609 (
610 self.public_key.clone(),
611 std::sync::Arc::clone(&self.secret_key),
612 )
613 }
614}
615
/// Wall-clock time in whole seconds since the Unix epoch, or `0` if the system
/// clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
623
/// Overwrites secret byte buffers with zeros.
///
/// A plain write loop immediately before a buffer is dropped is a dead store the
/// optimizer is allowed to remove. Implementations therefore go through
/// `wipe_bytes`, which uses volatile writes.
trait Zeroize {
    /// Sets every byte of `self` to zero.
    fn zeroize(&mut self);
}

/// Zeroes `bytes` using volatile writes followed by a compiler fence, so the
/// stores are neither elided nor reordered past later code.
fn wipe_bytes(bytes: &mut [u8]) {
    for byte in bytes.iter_mut() {
        // SAFETY: `byte` comes from a `&mut [u8]`, so it is a valid, aligned,
        // exclusively borrowed pointer to an initialized `u8`.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

impl Zeroize for Vec<u8> {
    /// Zeroes the initialized elements (`0..len`); spare capacity is not touched.
    fn zeroize(&mut self) {
        wipe_bytes(self.as_mut_slice());
    }
}

impl<const N: usize> Zeroize for [u8; N] {
    fn zeroize(&mut self) {
        wipe_bytes(self);
    }
}
644
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a path literal the test knows to be valid.
    fn parse(path: &str) -> DerivationPath {
        DerivationPath::from_string(path).unwrap()
    }

    /// Creates a deriver over a freshly generated random seed.
    fn new_deriver() -> HierarchicalKeyDerivation {
        HierarchicalKeyDerivation::new(MasterSeed::generate().unwrap())
    }

    #[test]
    fn test_master_seed_generation() {
        let seed = MasterSeed::generate().unwrap();
        assert_eq!(seed.seed_material().len(), MASTER_SEED_SIZE);
        assert!(seed._created_at > 0);
    }

    #[test]
    fn test_derivation_path_parsing() {
        let path = parse("m/0'/1/2'");
        assert_eq!(path.components().len(), 3);

        let hardened_flags: Vec<bool> = (0..3).map(|i| path.is_hardened(i)).collect();
        assert_eq!(hardened_flags, [true, false, true]);

        assert_eq!(path.to_string(), "m/0'/1/2'");
    }

    #[test]
    fn test_key_derivation() {
        let mut deriver = new_deriver();
        let path = parse("m/0'/1");

        let key = deriver.derive_key(&path).unwrap();

        assert_eq!(key.path, path);
        assert_eq!(key.public_key.as_bytes().len(), ML_DSA_PUB_LEN);
        assert_eq!(key.secret_key.as_bytes().len(), ML_DSA_SEC_LEN);
    }

    #[test]
    fn test_key_derivation_cache() {
        let mut deriver = new_deriver();
        let path = parse("m/0'/1");

        let first = deriver.derive_key(&path).unwrap();
        let second = deriver.derive_key(&path).unwrap();

        assert_eq!(first.secret_key.as_bytes(), second.secret_key.as_bytes());
        assert_eq!(first.public_key.as_bytes(), second.public_key.as_bytes());
        assert!(deriver.stats().cache_hits > 0);
    }

    #[test]
    fn test_batch_derivation() {
        let mut deriver = new_deriver();
        let paths: Vec<DerivationPath> = ["m/0'/1", "m/0'/2", "m/1'/0"]
            .iter()
            .copied()
            .map(parse)
            .collect();

        let request = BatchDerivationRequest {
            paths: paths.clone(),
            use_cache: true,
            priority: DerivationPriority::Normal,
        };
        let result = deriver.derive_batch(request).unwrap();

        assert_eq!(result.keys.len(), 3);
        assert!(result.failures.is_empty());
        assert!(paths.iter().all(|p| result.keys.contains_key(p)));
    }

    #[test]
    fn test_derivation_path_depth_limit() {
        let too_deep = vec![0u32; MAX_DERIVATION_DEPTH + 1];
        assert!(DerivationPath::new(too_deep).is_err());
    }

    #[test]
    fn test_hardened_derivation() {
        let mut deriver = new_deriver();

        let hardened = deriver.derive_key(&parse("m/0'")).unwrap();
        let normal = deriver.derive_key(&parse("m/0")).unwrap();

        assert_ne!(hardened.secret_key.as_bytes(), normal.secret_key.as_bytes());
        assert_ne!(hardened.public_key.as_bytes(), normal.public_key.as_bytes());
    }
}