saorsa_core/
key_derivation.rs

1// Copyright 2024 Saorsa Labs Limited
2//
3// This software is dual-licensed under:
4// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
5// - Commercial License
6//
7// For AGPL-3.0 license, see LICENSE-AGPL-3.0
8// For commercial licensing, contact: saorsalabs@gmail.com
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under these licenses is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
14//! # Hierarchical Key Derivation System
15//!
16//! This module implements BIP32-style hierarchical deterministic key derivation
17//! adapted for Ed25519/X25519 key pairs used in the P2P network.
18//!
19//! ## Security Features
20//! - Secure entropy generation for master seeds
//! - HMAC-based key derivation and expansion (HKDF-SHA256)
22//! - Deterministic key derivation from hierarchical paths
23//! - Key isolation between different derivation contexts
24//! - Side-channel resistance through constant-time operations
25//!
26//! ## Performance Features
27//! - Batch key derivation for multiple paths
28//! - Intelligent caching of derived keys
29//! - Memory-efficient storage of key material
30//! - Async key generation for non-blocking operations
31
32use crate::error::SecurityError;
33use crate::secure_memory::SecureMemory;
34use crate::{P2PError, Result};
35use ed25519_dalek::{SigningKey, VerifyingKey};
36use hkdf::Hkdf;
37use rand::{RngCore, thread_rng};
38use sha2::Sha256;
39use std::collections::HashMap;
40use std::sync::{Arc, RwLock};
41
/// Size of master seed in bytes (256 bits for security)
const MASTER_SEED_SIZE: usize = 32;

/// Size of derived key material (32 bytes for Ed25519)
///
/// Not referenced by the code visible in this file, hence `allow(dead_code)`.
#[allow(dead_code)]
const DERIVED_KEY_SIZE: usize = 32;

/// Maximum number of components a `DerivationPath` may hold.
///
/// `DerivationPath::new` rejects longer paths. Derivation walks the path in a
/// loop, so this caps the per-key work rather than any recursion depth.
const MAX_DERIVATION_DEPTH: usize = 10;

/// Size of derivation path index in bytes (each component is a `u32`)
///
/// Not referenced by the code visible in this file, hence `allow(dead_code)`.
#[allow(dead_code)]
const PATH_INDEX_SIZE: usize = 4;

/// Hardened derivation marker (BIP32 style)
///
/// Components `>= HARDENED_OFFSET` are treated as hardened; the textual form
/// writes such a component as `index'` with the offset subtracted.
const HARDENED_OFFSET: u32 = 0x80000000;
58
/// Master seed for deterministic key derivation
///
/// Built by `MasterSeed::generate` or `MasterSeed::from_entropy`; both store
/// exactly `MASTER_SEED_SIZE` bytes in `seed`.
pub struct MasterSeed {
    /// Secure seed material
    seed: SecureMemory,
    /// Creation timestamp (Unix seconds, taken from `current_timestamp()`)
    _created_at: u64,
    /// Derivation counter for tracking usage; advanced by `increment_counter`
    derivation_counter: u64,
}
68
/// Hierarchical key derivation path
///
/// Rendered textually as `m/0'/1/2`: a leading `m` followed by one `/`-separated
/// entry per component. Also used as the key of the derivation cache, hence the
/// `Eq` + `Hash` derives.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DerivationPath {
    /// Path components (each can be hardened); a value `>= HARDENED_OFFSET`
    /// marks a hardened component
    components: Vec<u32>,
}
75
/// Derived key material with metadata
///
/// Holds secret key material in public fields; `Clone` is implemented by hand
/// further down in this file rather than derived.
pub struct DerivedKey {
    /// Ed25519 signing key
    pub secret_key: SigningKey,
    /// Ed25519 verifying key
    pub public_key: VerifyingKey,
    /// X25519 secret key (for key exchange); `derive_key_internal` fills this
    /// with the same 32 bytes it uses as the Ed25519 seed
    pub x25519_secret: [u8; 32],
    /// X25519 public key (for key exchange)
    pub x25519_public: [u8; 32],
    /// Derivation path used
    pub path: DerivationPath,
    /// Creation timestamp (Unix seconds); the cache evicts the entry with the
    /// smallest value first
    pub created_at: u64,
    /// Key usage counter; starts at 0 and is advanced by `increment_usage`
    pub usage_counter: u64,
}
93
94impl Clone for DerivedKey {
95    fn clone(&self) -> Self {
96        // Create new Ed25519 key from bytes
97        let signing_key = SigningKey::from_bytes(self.secret_key.as_bytes());
98        let verifying_key =
99            VerifyingKey::from_bytes(self.public_key.as_bytes()).unwrap_or(self.public_key);
100
101        Self {
102            secret_key: signing_key,
103            public_key: verifying_key,
104            x25519_secret: self.x25519_secret,
105            x25519_public: self.x25519_public,
106            path: self.path.clone(),
107            created_at: self.created_at,
108            usage_counter: self.usage_counter,
109        }
110    }
111}
112
/// Key derivation cache for performance
///
/// Maps a derivation path to the key previously derived for it. The map sits
/// behind an `RwLock` and the counters are atomics, so all methods take `&self`.
pub struct KeyDerivationCache {
    /// Cached derived keys
    cache: RwLock<HashMap<DerivationPath, DerivedKey>>,
    /// Cache size limit (number of entries) checked on insert
    max_size: usize,
    /// Cache hit statistics (lookups that found an entry)
    hits: std::sync::atomic::AtomicU64,
    /// Cache miss statistics (lookups that found nothing)
    misses: std::sync::atomic::AtomicU64,
}
124
/// Hierarchical key derivation engine
///
/// Owns the master seed and a cache of keys already derived from it.
pub struct HierarchicalKeyDerivation {
    /// Master seed; its counter is advanced on each fresh (non-cached) derivation
    master_seed: MasterSeed,
    /// Derivation cache
    cache: Arc<KeyDerivationCache>,
}
132
/// Batch key derivation request
///
/// Input to `HierarchicalKeyDerivation::derive_batch`.
pub struct BatchDerivationRequest {
    /// Derivation paths to process
    pub paths: Vec<DerivationPath>,
    /// Whether to use cache
    // NOTE(review): `derive_batch` as written in this file never reads this flag.
    pub use_cache: bool,
    /// Priority level for processing
    // NOTE(review): not consulted by `derive_batch` in this file.
    pub priority: DerivationPriority,
}
142
/// Priority levels for key derivation
///
/// Carried on `BatchDerivationRequest`. No ordering traits are derived, so the
/// variants can be compared for equality only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DerivationPriority {
    /// Low priority (background operations)
    Low,
    /// Normal priority (standard operations)
    Normal,
    /// High priority (time-sensitive operations)
    High,
    /// Critical priority (immediate operations)
    Critical,
}
155
/// Results of batch key derivation
///
/// Returned by `HierarchicalKeyDerivation::derive_batch`. Both maps are keyed
/// by path, so a path that appears more than once in the request yields a
/// single entry.
pub struct BatchDerivationResult {
    /// Successfully derived keys
    pub keys: HashMap<DerivationPath, DerivedKey>,
    /// Failed derivations with error messages (the error's `to_string()` output)
    pub failures: HashMap<DerivationPath, String>,
    /// Cache hit rate for this batch, as a fraction; 0.0 for an empty batch
    pub cache_hit_rate: f64,
    /// Total processing time (wall-clock, for the whole batch)
    pub processing_time: std::time::Duration,
}
167
/// Statistics for key derivation performance
///
/// Snapshot produced by `HierarchicalKeyDerivation::stats`.
#[derive(Debug, Clone, Default)]
pub struct DerivationStats {
    /// Total keys derived (the master seed's derivation counter)
    pub total_derived: u64,
    /// Total cache hits
    pub cache_hits: u64,
    /// Total cache misses
    pub cache_misses: u64,
    /// Average derivation time in microseconds (`stats` currently reports 0)
    pub avg_derivation_time_us: u64,
    /// Total batch operations (`stats` currently reports 0)
    pub batch_operations: u64,
    /// Current cache size (number of cached entries)
    pub cache_size: usize,
}
184
185impl MasterSeed {
186    /// Generate a new master seed with cryptographically secure randomness
187    pub fn generate() -> Result<Self> {
188        let mut seed_bytes = vec![0u8; MASTER_SEED_SIZE];
189        thread_rng().fill_bytes(&mut seed_bytes);
190
191        let seed = SecureMemory::from_slice(&seed_bytes)?;
192
193        // Zeroize the temporary buffer
194        seed_bytes.zeroize();
195
196        Ok(Self {
197            seed,
198            _created_at: current_timestamp(),
199            derivation_counter: 0,
200        })
201    }
202
203    /// Create master seed from existing entropy
204    pub fn from_entropy(entropy: &[u8]) -> Result<Self> {
205        if entropy.len() < MASTER_SEED_SIZE {
206            return Err(P2PError::Security(SecurityError::InvalidKey(
207                "Insufficient entropy for master seed".to_string().into(),
208            )));
209        }
210
211        let seed = SecureMemory::from_slice(&entropy[..MASTER_SEED_SIZE])?;
212
213        Ok(Self {
214            seed,
215            _created_at: current_timestamp(),
216            derivation_counter: 0,
217        })
218    }
219
220    /// Get the seed material for derivation
221    pub fn seed_material(&self) -> &[u8] {
222        self.seed.as_slice()
223    }
224
225    /// Increment derivation counter
226    pub fn increment_counter(&mut self) {
227        self.derivation_counter += 1;
228    }
229
230    /// Get derivation counter
231    pub fn derivation_counter(&self) -> u64 {
232        self.derivation_counter
233    }
234}
235
236impl DerivationPath {
237    /// Create a new derivation path
238    pub fn new(components: Vec<u32>) -> Result<Self> {
239        if components.len() > MAX_DERIVATION_DEPTH {
240            return Err(P2PError::Security(SecurityError::InvalidKey(
241                format!(
242                    "Derivation path too deep: {} > {}",
243                    components.len(),
244                    MAX_DERIVATION_DEPTH
245                )
246                .into(),
247            )));
248        }
249
250        Ok(Self { components })
251    }
252
253    /// Create path from string representation (e.g., "m/0'/1/2")
254    pub fn from_string(path_str: &str) -> Result<Self> {
255        let parts: Vec<&str> = path_str.split('/').collect();
256
257        if parts.is_empty() || parts[0] != "m" {
258            return Err(P2PError::Security(SecurityError::InvalidKey(
259                "Invalid derivation path format".to_string().into(),
260            )));
261        }
262
263        let mut components = Vec::new();
264
265        for part in parts.iter().skip(1) {
266            if part.is_empty() {
267                continue;
268            }
269
270            let (index_str, hardened) = if let Some(stripped) = part.strip_suffix('\'') {
271                (stripped, true)
272            } else {
273                (*part, false)
274            };
275
276            let index: u32 = index_str.parse().map_err(|_| {
277                P2PError::Security(SecurityError::InvalidKey(
278                    format!("Invalid path component: {part}").into(),
279                ))
280            })?;
281
282            let final_index = if hardened {
283                index + HARDENED_OFFSET
284            } else {
285                index
286            };
287
288            components.push(final_index);
289        }
290
291        Self::new(components)
292    }
293
294    // (Removed inherent to_string; rely on Display/ToString)
295    /// Get path components
296    pub fn components(&self) -> &[u32] {
297        &self.components
298    }
299
300    /// Check if path component is hardened
301    pub fn is_hardened(&self, index: usize) -> bool {
302        self.components
303            .get(index)
304            .map(|&c| c >= HARDENED_OFFSET)
305            .unwrap_or(false)
306    }
307
308    /// Get depth of derivation path
309    pub fn depth(&self) -> usize {
310        self.components.len()
311    }
312
313    /// Create child path by appending component
314    pub fn child(&self, component: u32) -> Result<Self> {
315        let mut new_components = self.components.clone();
316        new_components.push(component);
317        Self::new(new_components)
318    }
319
320    /// Create hardened child path
321    pub fn hardened_child(&self, index: u32) -> Result<Self> {
322        self.child(index + HARDENED_OFFSET)
323    }
324}
325
326impl std::fmt::Display for DerivationPath {
327    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
328        write!(f, "m")?;
329        for &component in &self.components {
330            write!(f, "/")?;
331            if component >= HARDENED_OFFSET {
332                write!(f, "{}'", component - HARDENED_OFFSET)?;
333            } else {
334                write!(f, "{}", component)?;
335            }
336        }
337        Ok(())
338    }
339}
340
341impl KeyDerivationCache {
342    /// Create new key derivation cache
343    pub fn new(max_size: usize) -> Self {
344        Self {
345            cache: RwLock::new(HashMap::new()),
346            max_size,
347            hits: std::sync::atomic::AtomicU64::new(0),
348            misses: std::sync::atomic::AtomicU64::new(0),
349        }
350    }
351
352    /// Get cached key
353    pub fn get(&self, path: &DerivationPath) -> Option<DerivedKey> {
354        let cache = match self.cache.read() {
355            Ok(c) => c,
356            Err(_) => return None,
357        };
358        if let Some(key) = cache.get(path) {
359            self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
360            Some(key.clone())
361        } else {
362            self.misses
363                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
364            None
365        }
366    }
367
368    /// Insert key into cache
369    pub fn insert(&self, path: DerivationPath, key: DerivedKey) {
370        let mut cache = match self.cache.write() {
371            Ok(c) => c,
372            Err(_) => return,
373        };
374
375        // Evict oldest entries if cache is full
376        if cache.len() >= self.max_size {
377            let oldest_path = cache
378                .iter()
379                .min_by_key(|(_, k)| k.created_at)
380                .map(|(p, _)| p.clone());
381
382            if let Some(path_to_remove) = oldest_path {
383                cache.remove(&path_to_remove);
384            }
385        }
386
387        cache.insert(path, key);
388    }
389
390    /// Clear the cache
391    pub fn clear(&self) {
392        let mut cache = match self.cache.write() {
393            Ok(c) => c,
394            Err(_) => return,
395        };
396        cache.clear();
397    }
398
399    /// Get cache statistics
400    pub fn stats(&self) -> (u64, u64, usize) {
401        let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
402        let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
403        let size = self.cache.read().map(|c| c.len()).unwrap_or(0);
404        (hits, misses, size)
405    }
406}
407
408impl HierarchicalKeyDerivation {
409    /// Create new hierarchical key derivation engine
410    pub fn new(master_seed: MasterSeed) -> Self {
411        let cache = Arc::new(KeyDerivationCache::new(1000)); // Default cache size
412        // Use thread-local RNG holder; generate randomness locally where needed
413        Self { master_seed, cache }
414    }
415
416    /// Create with custom cache size
417    pub fn with_cache_size(master_seed: MasterSeed, cache_size: usize) -> Self {
418        let cache = Arc::new(KeyDerivationCache::new(cache_size));
419        Self { master_seed, cache }
420    }
421
422    /// Derive key at specific path
423    pub fn derive_key(&mut self, path: &DerivationPath) -> Result<DerivedKey> {
424        // Check cache first
425        if let Some(cached_key) = self.cache.get(path) {
426            return Ok(cached_key);
427        }
428
429        // Perform actual derivation
430        let derived_key = self.derive_key_internal(path)?;
431
432        // Cache the result
433        self.cache.insert(path.clone(), derived_key.clone());
434
435        // Increment master seed counter
436        self.master_seed.increment_counter();
437
438        Ok(derived_key)
439    }
440
441    /// Internal key derivation implementation
442    fn derive_key_internal(&self, path: &DerivationPath) -> Result<DerivedKey> {
443        let mut current_key = self.master_seed.seed_material().to_vec();
444        let mut current_chaincode = [0u8; 32];
445
446        // Initial HKDF from master seed
447        let hkdf = Hkdf::<Sha256>::new(None, &current_key);
448        hkdf.expand(b"ed25519 seed", &mut current_key)
449            .map_err(|_| {
450                P2PError::Security(SecurityError::InvalidKey(
451                    "HKDF expansion failed".to_string().into(),
452                ))
453            })?;
454        hkdf.expand(b"chaincode", &mut current_chaincode)
455            .map_err(|_| {
456                P2PError::Security(SecurityError::InvalidKey(
457                    "HKDF expansion failed".to_string().into(),
458                ))
459            })?;
460
461        // Derive through each path component
462        for &component in path.components() {
463            let (new_key, new_chaincode) =
464                self.derive_child_key(&current_key, &current_chaincode, component)?;
465            current_key = new_key;
466            current_chaincode = new_chaincode;
467        }
468
469        // Generate Ed25519 key pair
470        let signing_key = SigningKey::from_bytes(&current_key[..32].try_into().map_err(|_| {
471            P2PError::Security(SecurityError::InvalidKey(
472                "Invalid Ed25519 secret key length".to_string().into(),
473            ))
474        })?);
475        let verifying_key = signing_key.verifying_key();
476
477        // Generate X25519 key pair for key exchange
478        let x25519_secret: [u8; 32] = current_key[..32].try_into().map_err(|_| {
479            P2PError::Security(SecurityError::InvalidKey(
480                "Invalid X25519 secret key".to_string().into(),
481            ))
482        })?;
483        let x25519_public = x25519_dalek::PublicKey::from(x25519_secret).to_bytes();
484
485        // Zeroize temporary key material
486        current_key.zeroize();
487        current_chaincode.zeroize();
488
489        Ok(DerivedKey {
490            secret_key: signing_key,
491            public_key: verifying_key,
492            x25519_secret,
493            x25519_public,
494            path: path.clone(),
495            created_at: current_timestamp(),
496            usage_counter: 0,
497        })
498    }
499
500    /// Derive child key from parent
501    fn derive_child_key(
502        &self,
503        parent_key: &[u8],
504        parent_chaincode: &[u8],
505        index: u32,
506    ) -> Result<(Vec<u8>, [u8; 32])> {
507        let mut data = Vec::new();
508
509        if index >= HARDENED_OFFSET {
510            // Hardened derivation
511            data.push(0x00);
512            data.extend_from_slice(parent_key);
513        } else {
514            // Non-hardened derivation
515            let signing_key =
516                SigningKey::from_bytes(&parent_key[..32].try_into().map_err(|_| {
517                    P2PError::Security(SecurityError::InvalidKey(
518                        "Invalid parent key length".to_string().into(),
519                    ))
520                })?);
521            let verifying_key = signing_key.verifying_key();
522            data.extend_from_slice(verifying_key.as_bytes());
523        }
524
525        data.extend_from_slice(&index.to_be_bytes());
526
527        let hkdf = Hkdf::<Sha256>::new(Some(parent_chaincode), &data);
528
529        let mut child_key = vec![0u8; 32];
530        let mut child_chaincode = [0u8; 32];
531
532        hkdf.expand(b"key", &mut child_key).map_err(|_| {
533            P2PError::Security(SecurityError::InvalidKey(
534                "Child key derivation failed".to_string().into(),
535            ))
536        })?;
537        hkdf.expand(b"chaincode", &mut child_chaincode)
538            .map_err(|_| {
539                P2PError::Security(SecurityError::InvalidKey(
540                    "Child chaincode derivation failed".to_string().into(),
541                ))
542            })?;
543
544        // Zeroize temporary data
545        data.zeroize();
546
547        Ok((child_key, child_chaincode))
548    }
549
550    /// Derive multiple keys in batch
551    pub fn derive_batch(
552        &mut self,
553        request: BatchDerivationRequest,
554    ) -> Result<BatchDerivationResult> {
555        let start_time = std::time::Instant::now();
556        let mut keys = HashMap::new();
557        let mut failures = HashMap::new();
558        let mut cache_hits = 0u64;
559
560        for path in request.paths {
561            match self.derive_key(&path) {
562                Ok(key) => {
563                    // Check if this was a cache hit
564                    if self.cache.get(&path).is_some() {
565                        cache_hits += 1;
566                    }
567                    keys.insert(path, key);
568                }
569                Err(e) => {
570                    failures.insert(path, e.to_string());
571                }
572            }
573        }
574
575        let processing_time = start_time.elapsed();
576        let total_requests = keys.len() + failures.len();
577        let cache_hit_rate = if total_requests > 0 {
578            cache_hits as f64 / total_requests as f64
579        } else {
580            0.0
581        };
582
583        Ok(BatchDerivationResult {
584            keys,
585            failures,
586            cache_hit_rate,
587            processing_time,
588        })
589    }
590
591    /// Get derivation statistics
592    pub fn stats(&self) -> DerivationStats {
593        let (cache_hits, cache_misses, cache_size) = self.cache.stats();
594
595        DerivationStats {
596            total_derived: self.master_seed.derivation_counter(),
597            cache_hits,
598            cache_misses,
599            avg_derivation_time_us: 0, // Would need to track this
600            batch_operations: 0,       // Would need to track this
601            cache_size,
602        }
603    }
604
605    /// Clear the derivation cache
606    pub fn clear_cache(&self) {
607        self.cache.clear();
608    }
609}
610
611impl DerivedKey {
612    /// Increment usage counter
613    pub fn increment_usage(&mut self) {
614        self.usage_counter += 1;
615    }
616
617    /// Get Ed25519 key pair
618    pub fn ed25519_keypair(&self) -> (SigningKey, VerifyingKey) {
619        (self.secret_key.clone(), self.public_key)
620    }
621
622    /// Get X25519 key pair
623    pub fn x25519_keypair(&self) -> ([u8; 32], [u8; 32]) {
624        (self.x25519_secret, self.x25519_public)
625    }
626}
627
/// Seconds elapsed since the Unix epoch according to the system clock.
///
/// Falls back to 0 when the clock reports a time before the epoch.
fn current_timestamp() -> u64 {
    let now = std::time::SystemTime::now();
    match now.duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
635
/// Zeroize trait for secure memory clearing
///
/// Implementations overwrite every byte with 0 using volatile writes. Ordinary
/// stores to a buffer that is about to be dropped are dead stores which the
/// optimizer is allowed to remove, leaving the secret in memory.
trait Zeroize {
    /// Overwrite the contents with zeros; the length is left unchanged.
    fn zeroize(&mut self);
}

/// Overwrites `bytes` with zeros in a way the compiler may not optimize out.
fn wipe_bytes(bytes: &mut [u8]) {
    for byte in bytes.iter_mut() {
        // SAFETY: `byte` comes from a live `&mut [u8]`, so it is a valid,
        // aligned, exclusively borrowed `u8` for the duration of the write.
        unsafe { std::ptr::write_volatile(byte, 0) };
    }
    // Keep the compiler from moving later operations ahead of the wipe.
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

impl Zeroize for Vec<u8> {
    /// Zeroes the initialized elements; spare capacity is not touched.
    fn zeroize(&mut self) {
        wipe_bytes(self.as_mut_slice());
    }
}

impl Zeroize for [u8; 32] {
    fn zeroize(&mut self) {
        wipe_bytes(self.as_mut_slice());
    }
}
656
#[cfg(test)]
mod tests {
    use super::*;

    /// Engine over a freshly generated random master seed.
    fn engine_with_random_seed() -> HierarchicalKeyDerivation {
        let master_seed = MasterSeed::generate().unwrap();
        HierarchicalKeyDerivation::new(master_seed)
    }

    /// Parses a textual path, panicking on malformed input.
    fn parse(path: &str) -> DerivationPath {
        DerivationPath::from_string(path).unwrap()
    }

    /// A generated seed has the configured length and a creation time.
    #[test]
    fn test_master_seed_generation() {
        let generated = MasterSeed::generate().unwrap();
        assert_eq!(generated.seed_material().len(), MASTER_SEED_SIZE);
        assert!(generated._created_at > 0);
    }

    /// Hardened markers survive parsing and formatting round-trips.
    #[test]
    fn test_derivation_path_parsing() {
        let parsed = parse("m/0'/1/2'");
        assert_eq!(parsed.components().len(), 3);
        assert!(parsed.is_hardened(0));
        assert!(!parsed.is_hardened(1));
        assert!(parsed.is_hardened(2));

        assert_eq!(parsed.to_string(), "m/0'/1/2'");
    }

    /// A derived key records its path and carries 32-byte X25519 material.
    #[test]
    fn test_key_derivation() {
        let mut engine = engine_with_random_seed();
        let target = parse("m/0'/1");

        let derived = engine.derive_key(&target).unwrap();

        assert_eq!(derived.path, target);
        assert_eq!(derived.x25519_secret.len(), 32);
        assert_eq!(derived.x25519_public.len(), 32);
    }

    /// Deriving the same path twice returns identical keys and records a hit.
    #[test]
    fn test_key_derivation_cache() {
        let mut engine = engine_with_random_seed();
        let target = parse("m/0'/1");

        // The first call derives; the second is expected to come from the cache.
        let first = engine.derive_key(&target).unwrap();
        let second = engine.derive_key(&target).unwrap();

        assert_eq!(first.secret_key.as_bytes(), second.secret_key.as_bytes());
        assert_eq!(first.public_key.as_bytes(), second.public_key.as_bytes());

        assert!(engine.stats().cache_hits > 0);
    }

    /// Every path in a batch request ends up in the result's key map.
    #[test]
    fn test_batch_derivation() {
        let mut engine = engine_with_random_seed();

        let requested: Vec<DerivationPath> = ["m/0'/1", "m/0'/2", "m/1'/0"]
            .iter()
            .map(|text| parse(text))
            .collect();

        let outcome = engine
            .derive_batch(BatchDerivationRequest {
                paths: requested.clone(),
                use_cache: true,
                priority: DerivationPriority::Normal,
            })
            .unwrap();

        assert_eq!(outcome.keys.len(), 3);
        assert_eq!(outcome.failures.len(), 0);

        for wanted in requested {
            assert!(outcome.keys.contains_key(&wanted));
        }
    }

    /// One component past the depth limit is rejected.
    #[test]
    fn test_derivation_path_depth_limit() {
        let too_deep = vec![0u32; MAX_DERIVATION_DEPTH + 1];
        assert!(DerivationPath::new(too_deep).is_err());
    }

    /// Hardened and plain children at the same index yield different keys.
    #[test]
    fn test_hardened_derivation() {
        let mut engine = engine_with_random_seed();

        let hardened = engine.derive_key(&parse("m/0'")).unwrap();
        let plain = engine.derive_key(&parse("m/0")).unwrap();

        assert_ne!(hardened.secret_key.as_bytes(), plain.secret_key.as_bytes());
        assert_ne!(hardened.public_key.as_bytes(), plain.public_key.as_bytes());
    }
}