Skip to main content

saorsa_core/
key_derivation.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: david@saorsalabs.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! # Hierarchical Key Derivation System
15//!
//! This module implements BIP32-style hierarchical deterministic key derivation
//! adapted for the ML-DSA-65 (post-quantum) signing key pairs used in the P2P network.
18//!
19//! ## Security Features
20//! - Secure entropy generation for master seeds
//! - HKDF-SHA3-256 based key derivation and chain-code chaining
22//! - Deterministic key derivation from hierarchical paths
23//! - Key isolation between different derivation contexts
//! - Zeroization of temporary key material (side-channel resistance depends on the underlying HKDF implementation)
25//!
26//! ## Performance Features
27//! - Batch key derivation for multiple paths
//! - Bounded caching of derived keys with oldest-first eviction
29//! - Memory-efficient storage of key material
//! - Async key generation for non-blocking operations (planned; the current API is synchronous)
31
32use crate::error::SecurityError;
33use crate::quantum_crypto::ant_quic_integration::{MlDsaPublicKey, MlDsaSecretKey};
34use crate::secure_memory::SecureMemory;
35use crate::{P2PError, Result};
36use rand::{RngCore, thread_rng};
37use saorsa_pqc::{HkdfSha3_256, api::traits::Kdf};
38use std::collections::HashMap;
39use std::sync::{Arc, RwLock};
40
/// Size of the master seed in bytes (256 bits of entropy).
const MASTER_SEED_SIZE: usize = 32;

/// Encoded public-key length of ML-DSA-65 (FIPS 204), in bytes.
const ML_DSA_PUB_LEN: usize = 1952;
/// Encoded secret-key length of ML-DSA-65 (FIPS 204), in bytes.
const ML_DSA_SEC_LEN: usize = 4032;

/// Maximum number of components accepted in a derivation path.
///
/// Derivation walks the path iteratively (no recursion), so this limit bounds
/// the amount of HKDF work per derivation rather than stack depth.
const MAX_DERIVATION_DEPTH: usize = 10;

/// Size in bytes of a big-endian path index as mixed into child derivation.
// Not referenced by the derivation code, which serialises indices with
// `u32::to_be_bytes()` directly; kept as documentation of the encoding width.
#[allow(dead_code)]
const PATH_INDEX_SIZE: usize = 4;

/// Offset marking a path component as hardened (BIP32 convention: indices `>= 2^31`).
const HARDENED_OFFSET: u32 = 0x8000_0000;
57
58/// Master seed for deterministic key derivation
59pub struct MasterSeed {
60    /// Secure seed material
61    seed: SecureMemory,
62    /// Creation timestamp
63    _created_at: u64,
64    /// Derivation counter for tracking usage
65    derivation_counter: u64,
66}
67
/// A BIP32-style derivation path such as `m/0'/1/2`.
///
/// Each component is a raw `u32`; values at or above the hardened offset
/// (`2^31`) denote hardened components.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DerivationPath {
    /// Raw path components, root-first.
    components: Vec<u32>,
}
74
75/// Derived key material with metadata
76pub struct DerivedKey {
77    /// ML-DSA secret key (wrapped in Arc to avoid clone issues)
78    pub secret_key: std::sync::Arc<MlDsaSecretKey>,
79    /// ML-DSA public key
80    pub public_key: MlDsaPublicKey,
81    /// Derivation path used
82    pub path: DerivationPath,
83    /// Creation timestamp
84    pub created_at: u64,
85    /// Key usage counter
86    pub usage_counter: u64,
87}
88
89impl Clone for DerivedKey {
90    fn clone(&self) -> Self {
91        Self {
92            secret_key: std::sync::Arc::clone(&self.secret_key),
93            public_key: self.public_key.clone(),
94            path: self.path.clone(),
95            created_at: self.created_at,
96            usage_counter: self.usage_counter,
97        }
98    }
99}
100
101/// Key derivation cache for performance
102pub struct KeyDerivationCache {
103    /// Cached derived keys
104    cache: RwLock<HashMap<DerivationPath, DerivedKey>>,
105    /// Cache size limit
106    max_size: usize,
107    /// Cache hit statistics
108    hits: std::sync::atomic::AtomicU64,
109    /// Cache miss statistics
110    misses: std::sync::atomic::AtomicU64,
111}
112
113/// Hierarchical key derivation engine
114pub struct HierarchicalKeyDerivation {
115    /// Master seed
116    master_seed: MasterSeed,
117    /// Derivation cache
118    cache: Arc<KeyDerivationCache>,
119}
120
121/// Batch key derivation request
122pub struct BatchDerivationRequest {
123    /// Derivation paths to process
124    pub paths: Vec<DerivationPath>,
125    /// Whether to use cache
126    pub use_cache: bool,
127    /// Priority level for processing
128    pub priority: DerivationPriority,
129}
130
/// Priority levels for key derivation requests.
///
/// Variants are declared from lowest to highest, so the derived `Ord` ranks
/// them in priority order (`Low < Normal < High < Critical`). The default is
/// `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum DerivationPriority {
    /// Low priority (background operations)
    Low,
    /// Normal priority (standard operations)
    #[default]
    Normal,
    /// High priority (time-sensitive operations)
    High,
    /// Critical priority (immediate operations)
    Critical,
}
143
144/// Results of batch key derivation
145pub struct BatchDerivationResult {
146    /// Successfully derived keys
147    pub keys: HashMap<DerivationPath, DerivedKey>,
148    /// Failed derivations with error messages
149    pub failures: HashMap<DerivationPath, String>,
150    /// Cache hit rate for this batch
151    pub cache_hit_rate: f64,
152    /// Total processing time
153    pub processing_time: std::time::Duration,
154}
155
/// Snapshot of derivation and cache counters.
#[derive(Debug, Clone, Default)]
pub struct DerivationStats {
    /// Number of full (non-cached) derivations performed.
    pub total_derived: u64,
    /// Cache lookups that found an entry.
    pub cache_hits: u64,
    /// Cache lookups that found nothing.
    pub cache_misses: u64,
    /// Average derivation time in microseconds (not yet tracked; reported as 0).
    pub avg_derivation_time_us: u64,
    /// Number of batch operations (not yet tracked; reported as 0).
    pub batch_operations: u64,
    /// Number of entries in the cache when the snapshot was taken.
    pub cache_size: usize,
}
172
173impl MasterSeed {
174    /// Generate a new master seed with cryptographically secure randomness
175    pub fn generate() -> Result<Self> {
176        let mut seed_bytes = vec![0u8; MASTER_SEED_SIZE];
177        thread_rng().fill_bytes(&mut seed_bytes);
178
179        let seed = SecureMemory::from_slice(&seed_bytes)?;
180
181        // Zeroize the temporary buffer
182        seed_bytes.zeroize();
183
184        Ok(Self {
185            seed,
186            _created_at: current_timestamp(),
187            derivation_counter: 0,
188        })
189    }
190
191    /// Create master seed from existing entropy
192    pub fn from_entropy(entropy: &[u8]) -> Result<Self> {
193        if entropy.len() < MASTER_SEED_SIZE {
194            return Err(P2PError::Security(SecurityError::InvalidKey(
195                "Insufficient entropy for master seed".to_string().into(),
196            )));
197        }
198
199        let seed = SecureMemory::from_slice(&entropy[..MASTER_SEED_SIZE])?;
200
201        Ok(Self {
202            seed,
203            _created_at: current_timestamp(),
204            derivation_counter: 0,
205        })
206    }
207
208    /// Get the seed material for derivation
209    pub fn seed_material(&self) -> &[u8] {
210        self.seed.as_slice()
211    }
212
213    /// Increment derivation counter
214    pub fn increment_counter(&mut self) {
215        self.derivation_counter += 1;
216    }
217
218    /// Get derivation counter
219    pub fn derivation_counter(&self) -> u64 {
220        self.derivation_counter
221    }
222}
223
224impl DerivationPath {
225    /// Create a new derivation path
226    pub fn new(components: Vec<u32>) -> Result<Self> {
227        if components.len() > MAX_DERIVATION_DEPTH {
228            return Err(P2PError::Security(SecurityError::InvalidKey(
229                format!(
230                    "Derivation path too deep: {} > {}",
231                    components.len(),
232                    MAX_DERIVATION_DEPTH
233                )
234                .into(),
235            )));
236        }
237
238        Ok(Self { components })
239    }
240
241    /// Create path from string representation (e.g., "m/0'/1/2")
242    pub fn from_string(path_str: &str) -> Result<Self> {
243        let parts: Vec<&str> = path_str.split('/').collect();
244
245        if parts.is_empty() || parts[0] != "m" {
246            return Err(P2PError::Security(SecurityError::InvalidKey(
247                "Invalid derivation path format".to_string().into(),
248            )));
249        }
250
251        let mut components = Vec::new();
252
253        for part in parts.iter().skip(1) {
254            if part.is_empty() {
255                continue;
256            }
257
258            let (index_str, hardened) = if let Some(stripped) = part.strip_suffix('\'') {
259                (stripped, true)
260            } else {
261                (*part, false)
262            };
263
264            let index: u32 = index_str.parse().map_err(|_| {
265                P2PError::Security(SecurityError::InvalidKey(
266                    format!("Invalid path component: {part}").into(),
267                ))
268            })?;
269
270            let final_index = if hardened {
271                index + HARDENED_OFFSET
272            } else {
273                index
274            };
275
276            components.push(final_index);
277        }
278
279        Self::new(components)
280    }
281
282    // (Removed inherent to_string; rely on Display/ToString)
283    /// Get path components
284    pub fn components(&self) -> &[u32] {
285        &self.components
286    }
287
288    /// Check if path component is hardened
289    pub fn is_hardened(&self, index: usize) -> bool {
290        self.components
291            .get(index)
292            .map(|&c| c >= HARDENED_OFFSET)
293            .unwrap_or(false)
294    }
295
296    /// Get depth of derivation path
297    pub fn depth(&self) -> usize {
298        self.components.len()
299    }
300
301    /// Create child path by appending component
302    pub fn child(&self, component: u32) -> Result<Self> {
303        let mut new_components = self.components.clone();
304        new_components.push(component);
305        Self::new(new_components)
306    }
307
308    /// Create hardened child path
309    pub fn hardened_child(&self, index: u32) -> Result<Self> {
310        self.child(index + HARDENED_OFFSET)
311    }
312}
313
314impl std::fmt::Display for DerivationPath {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        write!(f, "m")?;
317        for &component in &self.components {
318            write!(f, "/")?;
319            if component >= HARDENED_OFFSET {
320                write!(f, "{}'", component - HARDENED_OFFSET)?;
321            } else {
322                write!(f, "{}", component)?;
323            }
324        }
325        Ok(())
326    }
327}
328
329impl KeyDerivationCache {
330    /// Create new key derivation cache
331    pub fn new(max_size: usize) -> Self {
332        Self {
333            cache: RwLock::new(HashMap::new()),
334            max_size,
335            hits: std::sync::atomic::AtomicU64::new(0),
336            misses: std::sync::atomic::AtomicU64::new(0),
337        }
338    }
339
340    /// Get cached key
341    pub fn get(&self, path: &DerivationPath) -> Option<DerivedKey> {
342        let cache = match self.cache.read() {
343            Ok(c) => c,
344            Err(_) => return None,
345        };
346        if let Some(key) = cache.get(path) {
347            self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
348            Some(key.clone())
349        } else {
350            self.misses
351                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
352            None
353        }
354    }
355
356    /// Insert key into cache
357    pub fn insert(&self, path: DerivationPath, key: DerivedKey) {
358        let mut cache = match self.cache.write() {
359            Ok(c) => c,
360            Err(_) => return,
361        };
362
363        // Evict oldest entries if cache is full
364        if cache.len() >= self.max_size {
365            let oldest_path = cache
366                .iter()
367                .min_by_key(|(_, k)| k.created_at)
368                .map(|(p, _)| p.clone());
369
370            if let Some(path_to_remove) = oldest_path {
371                cache.remove(&path_to_remove);
372            }
373        }
374
375        cache.insert(path, key);
376    }
377
378    /// Clear the cache
379    pub fn clear(&self) {
380        let mut cache = match self.cache.write() {
381            Ok(c) => c,
382            Err(_) => return,
383        };
384        cache.clear();
385    }
386
387    /// Get cache statistics
388    pub fn stats(&self) -> (u64, u64, usize) {
389        let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
390        let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
391        let size = self.cache.read().map(|c| c.len()).unwrap_or(0);
392        (hits, misses, size)
393    }
394}
395
396impl HierarchicalKeyDerivation {
397    /// Create new hierarchical key derivation engine
398    pub fn new(master_seed: MasterSeed) -> Self {
399        let cache = Arc::new(KeyDerivationCache::new(1000)); // Default cache size
400        // Use thread-local RNG holder; generate randomness locally where needed
401        Self { master_seed, cache }
402    }
403
404    /// Create with custom cache size
405    pub fn with_cache_size(master_seed: MasterSeed, cache_size: usize) -> Self {
406        let cache = Arc::new(KeyDerivationCache::new(cache_size));
407        Self { master_seed, cache }
408    }
409
410    /// Derive key at specific path
411    pub fn derive_key(&mut self, path: &DerivationPath) -> Result<DerivedKey> {
412        // Check cache first
413        if let Some(cached_key) = self.cache.get(path) {
414            return Ok(cached_key);
415        }
416
417        // Perform actual derivation
418        let derived_key = self.derive_key_internal(path)?;
419
420        // Cache the result
421        self.cache.insert(path.clone(), derived_key.clone());
422
423        // Increment master seed counter
424        self.master_seed.increment_counter();
425
426        Ok(derived_key)
427    }
428
429    /// Internal key derivation implementation (PQC: ML-DSA keys)
430    fn derive_key_internal(&self, path: &DerivationPath) -> Result<DerivedKey> {
431        let mut current_key = self.master_seed.seed_material().to_vec();
432        let mut current_chaincode = [0u8; 32];
433
434        // Initial HKDF from master seed
435        let mut temp_key = [0u8; 32];
436        HkdfSha3_256::derive(&current_key, None, b"ml-dsa seed", &mut temp_key).map_err(|_| {
437            P2PError::Security(SecurityError::InvalidKey(
438                "HKDF derivation failed".to_string().into(),
439            ))
440        })?;
441        current_key.copy_from_slice(&temp_key);
442
443        HkdfSha3_256::derive(&current_key, None, b"chaincode", &mut current_chaincode).map_err(
444            |_| {
445                P2PError::Security(SecurityError::InvalidKey(
446                    "HKDF derivation failed".to_string().into(),
447                ))
448            },
449        )?;
450
451        // Derive through each path component
452        for &component in path.components() {
453            let (new_key, new_chaincode) =
454                self.derive_child_key(&current_key, &current_chaincode, component)?;
455            current_key = new_key;
456            current_chaincode = new_chaincode;
457        }
458
459        // Derive ML-DSA key material deterministically
460        let mut derived = vec![0u8; ML_DSA_PUB_LEN + ML_DSA_SEC_LEN];
461        HkdfSha3_256::derive(
462            &current_key,
463            Some(&current_chaincode),
464            b"ml-dsa keypair",
465            &mut derived,
466        )
467        .map_err(|_| {
468            P2PError::Security(SecurityError::InvalidKey(
469                "HKDF derivation failed".to_string().into(),
470            ))
471        })?;
472        let pub_bytes = &derived[..ML_DSA_PUB_LEN];
473        let sec_bytes = &derived[ML_DSA_PUB_LEN..];
474        let public_key = MlDsaPublicKey::from_bytes(pub_bytes).map_err(|e| {
475            P2PError::Security(SecurityError::InvalidKey(
476                format!("Invalid ML-DSA public key: {e}").into(),
477            ))
478        })?;
479        let secret_key = MlDsaSecretKey::from_bytes(sec_bytes).map_err(|e| {
480            P2PError::Security(SecurityError::InvalidKey(
481                format!("Invalid ML-DSA secret key: {e}").into(),
482            ))
483        })?;
484
485        crate::quantum_crypto::ant_quic_integration::register_debug_ml_dsa_keypair(
486            &secret_key,
487            &public_key,
488        );
489
490        // Zeroize temporary key material
491        current_key.zeroize();
492        current_chaincode.zeroize();
493        derived.zeroize();
494
495        Ok(DerivedKey {
496            secret_key: std::sync::Arc::new(secret_key),
497            public_key,
498            path: path.clone(),
499            created_at: current_timestamp(),
500            usage_counter: 0,
501        })
502    }
503
504    /// Derive child key from parent (PQC-friendly HKDF)
505    fn derive_child_key(
506        &self,
507        parent_key: &[u8],
508        parent_chaincode: &[u8],
509        index: u32,
510    ) -> Result<(Vec<u8>, [u8; 32])> {
511        // Combine parent key, parent chaincode and index deterministically
512        let mut data = Vec::with_capacity(parent_key.len() + parent_chaincode.len() + 4);
513        data.extend_from_slice(parent_key);
514        data.extend_from_slice(parent_chaincode);
515        data.extend_from_slice(&index.to_be_bytes());
516
517        let mut child_key = vec![0u8; parent_key.len().max(32)];
518        let mut child_chaincode = [0u8; 32];
519
520        HkdfSha3_256::derive(&data, Some(parent_chaincode), b"key", &mut child_key).map_err(
521            |_| {
522                P2PError::Security(SecurityError::InvalidKey(
523                    "Child key derivation failed".to_string().into(),
524                ))
525            },
526        )?;
527        HkdfSha3_256::derive(
528            &data,
529            Some(parent_chaincode),
530            b"chaincode",
531            &mut child_chaincode,
532        )
533        .map_err(|_| {
534            P2PError::Security(SecurityError::InvalidKey(
535                "Child chaincode derivation failed".to_string().into(),
536            ))
537        })?;
538
539        // Zeroize temporary data
540        data.zeroize();
541
542        Ok((child_key, child_chaincode))
543    }
544
545    /// Derive multiple keys in batch
546    pub fn derive_batch(
547        &mut self,
548        request: BatchDerivationRequest,
549    ) -> Result<BatchDerivationResult> {
550        let start_time = std::time::Instant::now();
551        let mut keys = HashMap::new();
552        let mut failures = HashMap::new();
553        let mut cache_hits = 0u64;
554
555        for path in request.paths {
556            match self.derive_key(&path) {
557                Ok(key) => {
558                    // Check if this was a cache hit
559                    if self.cache.get(&path).is_some() {
560                        cache_hits += 1;
561                    }
562                    keys.insert(path, key);
563                }
564                Err(e) => {
565                    failures.insert(path, e.to_string());
566                }
567            }
568        }
569
570        let processing_time = start_time.elapsed();
571        let total_requests = keys.len() + failures.len();
572        let cache_hit_rate = if total_requests > 0 {
573            cache_hits as f64 / total_requests as f64
574        } else {
575            0.0
576        };
577
578        Ok(BatchDerivationResult {
579            keys,
580            failures,
581            cache_hit_rate,
582            processing_time,
583        })
584    }
585
586    /// Get derivation statistics
587    pub fn stats(&self) -> DerivationStats {
588        let (cache_hits, cache_misses, cache_size) = self.cache.stats();
589
590        DerivationStats {
591            total_derived: self.master_seed.derivation_counter(),
592            cache_hits,
593            cache_misses,
594            avg_derivation_time_us: 0, // Would need to track this
595            batch_operations: 0,       // Would need to track this
596            cache_size,
597        }
598    }
599
600    /// Clear the derivation cache
601    pub fn clear_cache(&self) {
602        self.cache.clear();
603    }
604}
605
606impl DerivedKey {
607    /// Increment usage counter
608    pub fn increment_usage(&mut self) {
609        self.usage_counter += 1;
610    }
611
612    /// Get ML-DSA key pair
613    pub fn ml_dsa_keypair(&self) -> (MlDsaPublicKey, std::sync::Arc<MlDsaSecretKey>) {
614        (
615            self.public_key.clone(),
616            std::sync::Arc::clone(&self.secret_key),
617        )
618    }
619}
620
/// Seconds since the Unix epoch, or 0 if the system clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
628
/// Best-effort wiping of sensitive byte buffers.
///
/// Plain stores into a buffer that is dropped right afterwards are dead stores
/// the optimizer is allowed to remove. Volatile writes followed by a compiler
/// fence keep the zeroing in the generated code.
trait Zeroize {
    /// Overwrite every byte with zero.
    fn zeroize(&mut self);
}

impl Zeroize for [u8] {
    fn zeroize(&mut self) {
        for byte in self.iter_mut() {
            // SAFETY: `byte` is a valid, aligned, exclusive reference to an
            // initialised `u8`, so a volatile write through it is sound.
            unsafe { std::ptr::write_volatile(byte, 0) };
        }
        // Prevent the compiler from reordering later operations before the wipe.
        std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
    }
}

impl Zeroize for Vec<u8> {
    fn zeroize(&mut self) {
        <[u8] as Zeroize>::zeroize(&mut self[..]);
    }
}

impl<const N: usize> Zeroize for [u8; N] {
    fn zeroize(&mut self) {
        <[u8] as Zeroize>::zeroize(&mut self[..]);
    }
}
649
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed entropy so derivations can be compared across engine instances.
    const FIXED_ENTROPY: [u8; MASTER_SEED_SIZE] = [0x42; MASTER_SEED_SIZE];

    /// Build an engine over a seed created from `entropy`.
    fn engine_from(entropy: &[u8]) -> HierarchicalKeyDerivation {
        HierarchicalKeyDerivation::new(MasterSeed::from_entropy(entropy).unwrap())
    }

    #[test]
    fn test_master_seed_generation() {
        let seed = MasterSeed::generate().unwrap();
        assert_eq!(seed.seed_material().len(), MASTER_SEED_SIZE);
        assert!(seed._created_at > 0);
    }

    #[test]
    fn test_from_entropy_rejects_short_input() {
        assert!(MasterSeed::from_entropy(&[0u8; MASTER_SEED_SIZE - 1]).is_err());
        assert!(MasterSeed::from_entropy(&[]).is_err());
    }

    #[test]
    fn test_derivation_path_parsing() {
        let path = DerivationPath::from_string("m/0'/1/2'").unwrap();
        assert_eq!(path.components().len(), 3);
        assert!(path.is_hardened(0));
        assert!(!path.is_hardened(1));
        assert!(path.is_hardened(2));

        let path_str = path.to_string();
        assert_eq!(path_str, "m/0'/1/2'");
    }

    #[test]
    fn test_root_path_round_trip() {
        let root = DerivationPath::from_string("m").unwrap();
        assert_eq!(root.depth(), 0);
        assert_eq!(root.to_string(), "m");
    }

    #[test]
    fn test_malformed_paths_rejected() {
        assert!(DerivationPath::from_string("").is_err());
        assert!(DerivationPath::from_string("0/1").is_err());
        assert!(DerivationPath::from_string("m/abc").is_err());
        assert!(DerivationPath::from_string("m/1''").is_err());
    }

    #[test]
    fn test_key_derivation() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let path = DerivationPath::from_string("m/0'/1").unwrap();
        let derived_key = derivation.derive_key(&path).unwrap();

        assert_eq!(derived_key.path, path);
        // Validate ML-DSA key sizes
        assert_eq!(derived_key.public_key.as_bytes().len(), ML_DSA_PUB_LEN);
        assert_eq!(derived_key.secret_key.as_bytes().len(), ML_DSA_SEC_LEN);
    }

    #[test]
    fn test_derivation_is_deterministic_for_same_entropy() {
        let path = DerivationPath::from_string("m/44'/0'/7").unwrap();

        let key_a = engine_from(&FIXED_ENTROPY).derive_key(&path).unwrap();
        let key_b = engine_from(&FIXED_ENTROPY).derive_key(&path).unwrap();

        assert_eq!(key_a.public_key.as_bytes(), key_b.public_key.as_bytes());
        assert_eq!(key_a.secret_key.as_bytes(), key_b.secret_key.as_bytes());
    }

    #[test]
    fn test_different_seeds_produce_different_keys() {
        let path = DerivationPath::from_string("m/0'").unwrap();

        let key_a = engine_from(&[1u8; MASTER_SEED_SIZE]).derive_key(&path).unwrap();
        let key_b = engine_from(&[2u8; MASTER_SEED_SIZE]).derive_key(&path).unwrap();

        assert_ne!(key_a.public_key.as_bytes(), key_b.public_key.as_bytes());
        assert_ne!(key_a.secret_key.as_bytes(), key_b.secret_key.as_bytes());
    }

    #[test]
    fn test_key_derivation_cache() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let path = DerivationPath::from_string("m/0'/1").unwrap();

        // First derivation
        let key1 = derivation.derive_key(&path).unwrap();

        // Second derivation should use cache
        let key2 = derivation.derive_key(&path).unwrap();

        // Keys should be identical
        assert_eq!(key1.secret_key.as_bytes(), key2.secret_key.as_bytes());
        assert_eq!(key1.public_key.as_bytes(), key2.public_key.as_bytes());

        // Check cache stats
        let stats = derivation.stats();
        assert!(stats.cache_hits > 0);
    }

    #[test]
    fn test_counter_and_cache_clear() {
        let mut derivation = engine_from(&FIXED_ENTROPY);
        let path = DerivationPath::from_string("m/3'").unwrap();

        derivation.derive_key(&path).unwrap();
        derivation.derive_key(&path).unwrap();

        // Only the first call performs a full derivation.
        let stats = derivation.stats();
        assert_eq!(stats.total_derived, 1);
        assert_eq!(stats.cache_size, 1);

        derivation.clear_cache();
        assert_eq!(derivation.stats().cache_size, 0);
    }

    #[test]
    fn test_batch_derivation() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let paths = vec![
            DerivationPath::from_string("m/0'/1").unwrap(),
            DerivationPath::from_string("m/0'/2").unwrap(),
            DerivationPath::from_string("m/1'/0").unwrap(),
        ];

        let request = BatchDerivationRequest {
            paths: paths.clone(),
            use_cache: true,
            priority: DerivationPriority::Normal,
        };

        let result = derivation.derive_batch(request).unwrap();

        assert_eq!(result.keys.len(), 3);
        assert_eq!(result.failures.len(), 0);

        // All paths should be present
        for path in paths {
            assert!(result.keys.contains_key(&path));
        }
    }

    #[test]
    fn test_derivation_path_depth_limit() {
        let components = vec![0u32; MAX_DERIVATION_DEPTH + 1];
        let result = DerivationPath::new(components);
        assert!(result.is_err());
    }

    #[test]
    fn test_hardened_derivation() {
        let master_seed = MasterSeed::generate().unwrap();
        let mut derivation = HierarchicalKeyDerivation::new(master_seed);

        let hardened_path = DerivationPath::from_string("m/0'").unwrap();
        let normal_path = DerivationPath::from_string("m/0").unwrap();

        let hardened_key = derivation.derive_key(&hardened_path).unwrap();
        let normal_key = derivation.derive_key(&normal_path).unwrap();
        // Keys should be different
        assert_ne!(
            hardened_key.secret_key.as_bytes(),
            normal_key.secret_key.as_bytes()
        );
        assert_ne!(
            hardened_key.public_key.as_bytes(),
            normal_key.public_key.as_bytes()
        );
    }
}