Skip to main content

saorsa_pqc/api/
kdf.rs

1//! Key Derivation Function (KDF) implementations
2//!
3//! Provides quantum-resistant KDF implementations:
4//! - HKDF-SHA3-256
5//! - HKDF-SHA3-512
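//!
//! HKDF-BLAKE3 is intentionally not provided: BLAKE3 ships its own key
//! derivation (`blake3::derive_key`), which is the better fit for that hash.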
7
8use crate::api::errors::{PqcError, PqcResult};
9use crate::api::traits::Kdf;
10use hkdf::Hkdf as HkdfImpl;
11use sha3::{Sha3_256, Sha3_512};
12use zeroize::Zeroizing;
13
14/// HKDF with SHA3-256
15pub struct HkdfSha3_256;
16
17impl Kdf for HkdfSha3_256 {
18    fn derive(ikm: &[u8], salt: Option<&[u8]>, info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
19        let hkdf = HkdfImpl::<Sha3_256>::new(salt, ikm);
20        hkdf.expand(info, okm)
21            .map_err(|_| PqcError::InvalidKeyLength)?;
22        Ok(())
23    }
24
25    fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
26        let (prk, _) = HkdfImpl::<Sha3_256>::extract(salt, ikm);
27        prk.to_vec()
28    }
29
30    fn expand(prk: &[u8], info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
31        let hkdf = HkdfImpl::<Sha3_256>::from_prk(prk).map_err(|_| PqcError::InvalidKeyLength)?;
32        hkdf.expand(info, okm)
33            .map_err(|_| PqcError::InvalidKeyLength)?;
34        Ok(())
35    }
36
37    fn name() -> &'static str {
38        "HKDF-SHA3-256"
39    }
40}
41
42/// HKDF with SHA3-512
43pub struct HkdfSha3_512;
44
45impl Kdf for HkdfSha3_512 {
46    fn derive(ikm: &[u8], salt: Option<&[u8]>, info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
47        let hkdf = HkdfImpl::<Sha3_512>::new(salt, ikm);
48        hkdf.expand(info, okm)
49            .map_err(|_| PqcError::InvalidKeyLength)?;
50        Ok(())
51    }
52
53    fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
54        let (prk, _) = HkdfImpl::<Sha3_512>::extract(salt, ikm);
55        prk.to_vec()
56    }
57
58    fn expand(prk: &[u8], info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
59        let hkdf = HkdfImpl::<Sha3_512>::from_prk(prk).map_err(|_| PqcError::InvalidKeyLength)?;
60        hkdf.expand(info, okm)
61            .map_err(|_| PqcError::InvalidKeyLength)?;
62        Ok(())
63    }
64
65    fn name() -> &'static str {
66        "HKDF-SHA3-512"
67    }
68}
69
70// Note: BLAKE3 has its own key derivation (blake3::derive_key) which is more suitable
71// than HKDF-BLAKE3. We provide HKDF with SHA3 variants for standard HKDF usage.
72
73/// High-level KDF selector
74#[derive(Debug, Clone, Copy, PartialEq, Eq)]
75pub enum KdfAlgorithm {
76    /// HKDF with SHA3-256
77    HkdfSha3_256,
78    /// HKDF with SHA3-512
79    HkdfSha3_512,
80}
81
82impl KdfAlgorithm {
83    /// Derive key material
84    ///
85    /// # Errors
86    ///
87    /// Returns an error if key derivation fails due to invalid parameters
88    /// or HKDF operation failures.
89    pub fn derive(
90        &self,
91        ikm: &[u8],
92        salt: Option<&[u8]>,
93        info: &[u8],
94        output_len: usize,
95    ) -> PqcResult<Vec<u8>> {
96        let mut okm = vec![0u8; output_len];
97
98        match self {
99            Self::HkdfSha3_256 => {
100                HkdfSha3_256::derive(ikm, salt, info, &mut okm)?;
101            }
102            Self::HkdfSha3_512 => {
103                HkdfSha3_512::derive(ikm, salt, info, &mut okm)?;
104            }
105        }
106
107        Ok(okm)
108    }
109
110    /// Get the algorithm name
111    #[must_use]
112    pub fn name(&self) -> &'static str {
113        match self {
114            Self::HkdfSha3_256 => HkdfSha3_256::name(),
115            Self::HkdfSha3_512 => HkdfSha3_512::name(),
116        }
117    }
118}
119
120/// Helper functions for common KDF operations
121pub mod helpers {
122    use super::{HkdfSha3_256, HkdfSha3_512, Kdf, KdfAlgorithm, PqcResult, Zeroizing};
123    use crate::api::errors::PqcError;
124
125    /// Type alias for encryption and authentication key pair
126    pub type EncAuthKeyPair = (Zeroizing<[u8; 32]>, Zeroizing<[u8; 32]>);
127
128    /// Derive encryption and authentication keys from a shared secret
129    ///
130    /// # Errors
131    ///
132    /// Returns an error if key derivation fails or the derived key material
133    /// cannot be properly split into encryption and authentication keys.
134    pub fn derive_enc_auth_keys(shared_secret: &[u8], context: &[u8]) -> PqcResult<EncAuthKeyPair> {
135        let mut okm = Zeroizing::new([0u8; 64]);
136        HkdfSha3_512::derive(shared_secret, None, context, &mut okm[..])?;
137
138        let mut enc_key = Zeroizing::new([0u8; 32]);
139        let mut auth_key = Zeroizing::new([0u8; 32]);
140
141        enc_key.copy_from_slice(okm.get(..32).ok_or(PqcError::InvalidKeyLength)?);
142        auth_key.copy_from_slice(okm.get(32..).ok_or(PqcError::InvalidKeyLength)?);
143
144        Ok((enc_key, auth_key))
145    }
146
147    /// Derive a symmetric key from a password and salt
148    ///
149    /// # Errors
150    ///
151    /// Returns an error if PBKDF2 key derivation fails due to invalid parameters.
152    pub fn derive_key_from_password(
153        password: &[u8],
154        salt: &[u8],
155        iterations: u32,
156    ) -> PqcResult<Zeroizing<[u8; 32]>> {
157        use pbkdf2::pbkdf2_hmac;
158        use sha3::Sha3_256;
159
160        let mut key = Zeroizing::new([0u8; 32]);
161        pbkdf2_hmac::<Sha3_256>(password, salt, iterations, &mut key[..]);
162        Ok(key)
163    }
164
165    /// Simple key stretching for session keys
166    ///
167    /// # Errors
168    ///
169    /// Returns an error if key stretching fails due to invalid parameters
170    /// or HKDF derivation failures.
171    pub fn stretch_key(key: &[u8], label: &[u8], output_len: usize) -> PqcResult<Vec<u8>> {
172        KdfAlgorithm::HkdfSha3_256.derive(key, None, label, output_len)
173    }
174
175    /// Derive multiple keys from a master key
176    ///
177    /// # Errors
178    ///
179    /// Returns an error if key derivation fails for any of the requested labels
180    /// or if HKDF operations fail.
181    pub fn derive_key_hierarchy(
182        master_key: &[u8],
183        labels: &[&[u8]],
184    ) -> PqcResult<Vec<Zeroizing<Vec<u8>>>> {
185        let mut keys = Vec::new();
186
187        for label in labels {
188            let mut key = Zeroizing::new(vec![0u8; 32]);
189            HkdfSha3_256::derive(master_key, None, label, &mut key)?;
190            keys.push(key);
191        }
192
193        Ok(keys)
194    }
195}
196
#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    // Fixtures shared by the HKDF tests below.
    const IKM: &[u8] = b"input key material";
    const SALT: &[u8] = b"salt";
    const INFO: &[u8] = b"info";

    /// Same inputs give the same output; a different salt changes it.
    #[test]
    fn test_hkdf_sha3_256_basic() {
        let derive_with_salt = |salt: &[u8]| {
            let mut out = [0u8; 32];
            HkdfSha3_256::derive(IKM, Some(salt), INFO, &mut out).unwrap();
            out
        };

        let baseline = derive_with_salt(SALT);

        // Verify deterministic output
        assert_eq!(baseline, derive_with_salt(SALT));

        // Different salt should give different output
        assert_ne!(baseline, derive_with_salt(b"different salt"));
    }

    /// A shorter request yields a prefix of the longer output.
    #[test]
    fn test_hkdf_sha3_512_basic() {
        let mut full = [0u8; 64];
        HkdfSha3_512::derive(IKM, Some(SALT), INFO, &mut full).unwrap();

        // Verify we can derive different lengths
        let mut prefix = [0u8; 16];
        HkdfSha3_512::derive(IKM, Some(SALT), INFO, &mut prefix).unwrap();

        // First 16 bytes should match
        assert_eq!(&full[..16], &prefix);
    }

    /// Separate extract + expand agrees with the one-shot `derive`.
    #[test]
    fn test_extract_expand_separate() {
        // Extract PRK
        let prk = HkdfSha3_256::extract(Some(SALT), IKM);
        assert_eq!(prk.len(), 32); // SHA3-256 output size

        // Expand from PRK
        let mut two_step = [0u8; 32];
        HkdfSha3_256::expand(&prk, INFO, &mut two_step).unwrap();

        // Compare with one-shot derive
        let mut one_shot = [0u8; 32];
        HkdfSha3_256::derive(IKM, Some(SALT), INFO, &mut one_shot).unwrap();

        assert_eq!(two_step, one_shot);
    }

    /// The enum front-end honours the requested length and reports names.
    #[test]
    fn test_kdf_algorithm_enum() {
        let sha3_256 = KdfAlgorithm::HkdfSha3_256;
        let out_256 = sha3_256.derive(IKM, Some(SALT), INFO, 32).unwrap();
        assert_eq!(out_256.len(), 32);
        assert_eq!(sha3_256.name(), "HKDF-SHA3-256");

        let sha3_512 = KdfAlgorithm::HkdfSha3_512;
        let out_512 = sha3_512.derive(IKM, Some(SALT), INFO, 64).unwrap();
        assert_eq!(out_512.len(), 64);
        assert_eq!(sha3_512.name(), "HKDF-SHA3-512");
    }

    /// Encryption and authentication keys differ and are reproducible.
    #[test]
    fn test_derive_enc_auth_keys() {
        let shared_secret = b"shared secret from key exchange";
        let context = b"application context";

        let (enc_a, auth_a) = helpers::derive_enc_auth_keys(shared_secret, context).unwrap();

        assert_eq!(enc_a.len(), 32);
        assert_eq!(auth_a.len(), 32);
        assert_ne!(&enc_a[..], &auth_a[..]);

        // Should be deterministic
        let (enc_b, auth_b) = helpers::derive_enc_auth_keys(shared_secret, context).unwrap();
        assert_eq!(&enc_a[..], &enc_b[..]);
        assert_eq!(&auth_a[..], &auth_b[..]);
    }

    /// Stretching honours the requested length and depends on the label.
    #[test]
    fn test_key_stretching() {
        let short_key = b"short key";

        let with_session_label = helpers::stretch_key(short_key, b"session key", 64).unwrap();
        assert_eq!(with_session_label.len(), 64);

        // Different label should give different output
        let with_other_label = helpers::stretch_key(short_key, b"different label", 64).unwrap();
        assert_ne!(with_session_label, with_other_label);
    }

    /// One 32-byte key per label, all pairwise distinct.
    #[test]
    fn test_key_hierarchy() {
        let master_key = b"master key material";
        let labels = [
            b"encryption key".as_slice(),
            b"authentication key".as_slice(),
            b"signing key".as_slice(),
        ];

        let hierarchy = helpers::derive_key_hierarchy(master_key, &labels).unwrap();

        assert_eq!(hierarchy.len(), 3);
        for derived in &hierarchy {
            assert_eq!(derived.len(), 32);
        }

        // All keys should be different
        assert_ne!(&hierarchy[0][..], &hierarchy[1][..]);
        assert_ne!(&hierarchy[1][..], &hierarchy[2][..]);
        assert_ne!(&hierarchy[0][..], &hierarchy[2][..]);
    }

    /// Password derivation is deterministic and depends on the salt.
    #[test]
    fn test_password_derivation() {
        let password = b"password123";
        let rounds = 100; // Low for testing, use higher in production

        let derive_with_salt = |salt: &[u8]| {
            helpers::derive_key_from_password(password, salt, rounds).unwrap()
        };

        let baseline = derive_with_salt(b"random salt");
        assert_eq!(baseline.len(), 32);

        // Should be deterministic
        let repeat = derive_with_salt(b"random salt");
        assert_eq!(&baseline[..], &repeat[..]);

        // Different salt should give different key
        let other_salt = derive_with_salt(b"different salt");
        assert_ne!(&baseline[..], &other_salt[..]);
    }
}