1use crate::api::errors::{PqcError, PqcResult};
9use crate::api::traits::Kdf;
10use hkdf::Hkdf as HkdfImpl;
11use sha3::{Sha3_256, Sha3_512};
12use zeroize::Zeroizing;
13
14pub struct HkdfSha3_256;
16
17impl Kdf for HkdfSha3_256 {
18 fn derive(ikm: &[u8], salt: Option<&[u8]>, info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
19 let hkdf = HkdfImpl::<Sha3_256>::new(salt, ikm);
20 hkdf.expand(info, okm)
21 .map_err(|_| PqcError::InvalidKeyLength)?;
22 Ok(())
23 }
24
25 fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
26 let (prk, _) = HkdfImpl::<Sha3_256>::extract(salt, ikm);
27 prk.to_vec()
28 }
29
30 fn expand(prk: &[u8], info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
31 let hkdf = HkdfImpl::<Sha3_256>::from_prk(prk).map_err(|_| PqcError::InvalidKeyLength)?;
32 hkdf.expand(info, okm)
33 .map_err(|_| PqcError::InvalidKeyLength)?;
34 Ok(())
35 }
36
37 fn name() -> &'static str {
38 "HKDF-SHA3-256"
39 }
40}
41
42pub struct HkdfSha3_512;
44
45impl Kdf for HkdfSha3_512 {
46 fn derive(ikm: &[u8], salt: Option<&[u8]>, info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
47 let hkdf = HkdfImpl::<Sha3_512>::new(salt, ikm);
48 hkdf.expand(info, okm)
49 .map_err(|_| PqcError::InvalidKeyLength)?;
50 Ok(())
51 }
52
53 fn extract(salt: Option<&[u8]>, ikm: &[u8]) -> Vec<u8> {
54 let (prk, _) = HkdfImpl::<Sha3_512>::extract(salt, ikm);
55 prk.to_vec()
56 }
57
58 fn expand(prk: &[u8], info: &[u8], okm: &mut [u8]) -> PqcResult<()> {
59 let hkdf = HkdfImpl::<Sha3_512>::from_prk(prk).map_err(|_| PqcError::InvalidKeyLength)?;
60 hkdf.expand(info, okm)
61 .map_err(|_| PqcError::InvalidKeyLength)?;
62 Ok(())
63 }
64
65 fn name() -> &'static str {
66 "HKDF-SHA3-512"
67 }
68}
69
/// Runtime selector for the HKDF variants implemented in this module.
///
/// Lets callers pick a KDF by value (e.g. from configuration) instead of
/// naming a concrete `Kdf` implementor at compile time.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KdfAlgorithm {
    /// HKDF over SHA3-256 (dispatches to `HkdfSha3_256`).
    HkdfSha3_256,
    /// HKDF over SHA3-512 (dispatches to `HkdfSha3_512`).
    HkdfSha3_512,
}
81
82impl KdfAlgorithm {
83 pub fn derive(
90 &self,
91 ikm: &[u8],
92 salt: Option<&[u8]>,
93 info: &[u8],
94 output_len: usize,
95 ) -> PqcResult<Vec<u8>> {
96 let mut okm = vec![0u8; output_len];
97
98 match self {
99 Self::HkdfSha3_256 => {
100 HkdfSha3_256::derive(ikm, salt, info, &mut okm)?;
101 }
102 Self::HkdfSha3_512 => {
103 HkdfSha3_512::derive(ikm, salt, info, &mut okm)?;
104 }
105 }
106
107 Ok(okm)
108 }
109
110 #[must_use]
112 pub fn name(&self) -> &'static str {
113 match self {
114 Self::HkdfSha3_256 => HkdfSha3_256::name(),
115 Self::HkdfSha3_512 => HkdfSha3_512::name(),
116 }
117 }
118}
119
120pub mod helpers {
122 use super::{HkdfSha3_256, HkdfSha3_512, Kdf, KdfAlgorithm, PqcResult, Zeroizing};
123 use crate::api::errors::PqcError;
124
125 pub type EncAuthKeyPair = (Zeroizing<[u8; 32]>, Zeroizing<[u8; 32]>);
127
128 pub fn derive_enc_auth_keys(shared_secret: &[u8], context: &[u8]) -> PqcResult<EncAuthKeyPair> {
135 let mut okm = Zeroizing::new([0u8; 64]);
136 HkdfSha3_512::derive(shared_secret, None, context, &mut okm[..])?;
137
138 let mut enc_key = Zeroizing::new([0u8; 32]);
139 let mut auth_key = Zeroizing::new([0u8; 32]);
140
141 enc_key.copy_from_slice(okm.get(..32).ok_or(PqcError::InvalidKeyLength)?);
142 auth_key.copy_from_slice(okm.get(32..).ok_or(PqcError::InvalidKeyLength)?);
143
144 Ok((enc_key, auth_key))
145 }
146
147 pub fn derive_key_from_password(
153 password: &[u8],
154 salt: &[u8],
155 iterations: u32,
156 ) -> PqcResult<Zeroizing<[u8; 32]>> {
157 use pbkdf2::pbkdf2_hmac;
158 use sha3::Sha3_256;
159
160 let mut key = Zeroizing::new([0u8; 32]);
161 pbkdf2_hmac::<Sha3_256>(password, salt, iterations, &mut key[..]);
162 Ok(key)
163 }
164
165 pub fn stretch_key(key: &[u8], label: &[u8], output_len: usize) -> PqcResult<Vec<u8>> {
172 KdfAlgorithm::HkdfSha3_256.derive(key, None, label, output_len)
173 }
174
175 pub fn derive_key_hierarchy(
182 master_key: &[u8],
183 labels: &[&[u8]],
184 ) -> PqcResult<Vec<Zeroizing<Vec<u8>>>> {
185 let mut keys = Vec::new();
186
187 for label in labels {
188 let mut key = Zeroizing::new(vec![0u8; 32]);
189 HkdfSha3_256::derive(master_key, None, label, &mut key)?;
190 keys.push(key);
191 }
192
193 Ok(keys)
194 }
195}
196
#[cfg(test)]
#[allow(clippy::indexing_slicing)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_hkdf_sha3_256_basic() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = b"info";
        let mut okm = [0u8; 32];

        HkdfSha3_256::derive(ikm, Some(salt), info, &mut okm).unwrap();

        // Same inputs must be deterministic.
        let mut okm2 = [0u8; 32];
        HkdfSha3_256::derive(ikm, Some(salt), info, &mut okm2).unwrap();
        assert_eq!(okm, okm2);

        // A different salt must change the output.
        let mut okm3 = [0u8; 32];
        HkdfSha3_256::derive(ikm, Some(b"different salt"), info, &mut okm3).unwrap();
        assert_ne!(okm, okm3);
    }

    #[test]
    fn test_hkdf_sha3_512_basic() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = b"info";
        let mut okm = [0u8; 64];

        HkdfSha3_512::derive(ikm, Some(salt), info, &mut okm).unwrap();

        // HKDF output is prefix-consistent: a shorter request yields a prefix.
        let mut okm_short = [0u8; 16];
        HkdfSha3_512::derive(ikm, Some(salt), info, &mut okm_short).unwrap();

        assert_eq!(&okm[..16], &okm_short);
    }

    #[test]
    fn test_extract_expand_separate() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = b"info";

        let prk = HkdfSha3_256::extract(Some(salt), ikm);
        assert_eq!(prk.len(), 32);

        let mut okm1 = [0u8; 32];
        HkdfSha3_256::expand(&prk, info, &mut okm1).unwrap();

        let mut okm2 = [0u8; 32];
        HkdfSha3_256::derive(ikm, Some(salt), info, &mut okm2).unwrap();

        assert_eq!(okm1, okm2);
    }

    #[test]
    fn test_extract_sha3_512_prk_len() {
        let prk = HkdfSha3_512::extract(Some(b"salt"), b"input key material");
        assert_eq!(prk.len(), 64);
    }

    #[test]
    fn test_expand_rejects_short_prk() {
        let short_prk = [0u8; 16];
        let mut okm = [0u8; 32];

        let result = HkdfSha3_256::expand(&short_prk, b"info", &mut okm);
        assert!(matches!(result, Err(PqcError::InvalidKeyLength)));
    }

    #[test]
    fn test_kdf_algorithm_enum() {
        let ikm = b"input key material";
        let salt = b"salt";
        let info = b"info";

        let key1 = KdfAlgorithm::HkdfSha3_256
            .derive(ikm, Some(salt), info, 32)
            .unwrap();
        assert_eq!(key1.len(), 32);
        assert_eq!(KdfAlgorithm::HkdfSha3_256.name(), "HKDF-SHA3-256");

        let key2 = KdfAlgorithm::HkdfSha3_512
            .derive(ikm, Some(salt), info, 64)
            .unwrap();
        assert_eq!(key2.len(), 64);
        assert_eq!(KdfAlgorithm::HkdfSha3_512.name(), "HKDF-SHA3-512");
    }

    #[test]
    fn test_kdf_algorithm_matches_direct_impls() {
        let ikm = b"input key material";
        let info = b"info";

        let mut direct_256 = [0u8; 40];
        HkdfSha3_256::derive(ikm, None, info, &mut direct_256).unwrap();
        let via_enum_256 = KdfAlgorithm::HkdfSha3_256.derive(ikm, None, info, 40).unwrap();
        assert_eq!(&via_enum_256[..], &direct_256[..]);

        let mut direct_512 = [0u8; 40];
        HkdfSha3_512::derive(ikm, None, info, &mut direct_512).unwrap();
        let via_enum_512 = KdfAlgorithm::HkdfSha3_512.derive(ikm, None, info, 40).unwrap();
        assert_eq!(&via_enum_512[..], &direct_512[..]);
    }

    #[test]
    fn test_output_length_limit() {
        let ikm = b"input key material";
        let info = b"info";

        // RFC 5869 caps HKDF output at 255 * HashLen bytes.
        let max_256 = 255 * 32;
        let at_max = KdfAlgorithm::HkdfSha3_256.derive(ikm, None, info, max_256).unwrap();
        assert_eq!(at_max.len(), max_256);

        let over_256 = KdfAlgorithm::HkdfSha3_256.derive(ikm, None, info, max_256 + 1);
        assert!(matches!(over_256, Err(PqcError::InvalidKeyLength)));

        let max_512 = 255 * 64;
        let over_512 = KdfAlgorithm::HkdfSha3_512.derive(ikm, None, info, max_512 + 1);
        assert!(matches!(over_512, Err(PqcError::InvalidKeyLength)));
    }

    #[test]
    fn test_derive_enc_auth_keys() {
        let shared_secret = b"shared secret from key exchange";
        let context = b"application context";

        let (enc_key, auth_key) = helpers::derive_enc_auth_keys(shared_secret, context).unwrap();

        assert_eq!(enc_key.len(), 32);
        assert_eq!(auth_key.len(), 32);
        assert_ne!(&enc_key[..], &auth_key[..]);

        // Deterministic for identical inputs.
        let (enc_key2, auth_key2) = helpers::derive_enc_auth_keys(shared_secret, context).unwrap();
        assert_eq!(&enc_key[..], &enc_key2[..]);
        assert_eq!(&auth_key[..], &auth_key2[..]);

        // The pair is exactly the two halves of HKDF-SHA3-512's 64-byte output.
        let mut okm = [0u8; 64];
        HkdfSha3_512::derive(shared_secret, None, context, &mut okm).unwrap();
        assert_eq!(&enc_key[..], &okm[..32]);
        assert_eq!(&auth_key[..], &okm[32..]);
    }

    #[test]
    fn test_key_stretching() {
        let key = b"short key";
        let label = b"session key";

        let stretched = helpers::stretch_key(key, label, 64).unwrap();
        assert_eq!(stretched.len(), 64);

        // A different label must yield a different key.
        let stretched2 = helpers::stretch_key(key, b"different label", 64).unwrap();
        assert_ne!(stretched, stretched2);
    }

    #[test]
    fn test_key_hierarchy() {
        let master_key = b"master key material";
        let labels = vec![
            b"encryption key".as_slice(),
            b"authentication key".as_slice(),
            b"signing key".as_slice(),
        ];

        let derived_keys = helpers::derive_key_hierarchy(master_key, &labels).unwrap();

        assert_eq!(derived_keys.len(), 3);
        for key in &derived_keys {
            assert_eq!(key.len(), 32);
        }

        // Every label must produce a distinct key.
        assert_ne!(&derived_keys[0][..], &derived_keys[1][..]);
        assert_ne!(&derived_keys[1][..], &derived_keys[2][..]);
        assert_ne!(&derived_keys[0][..], &derived_keys[2][..]);
    }

    #[test]
    fn test_key_hierarchy_empty_labels() {
        let derived_keys = helpers::derive_key_hierarchy(b"master key material", &[]).unwrap();
        assert!(derived_keys.is_empty());
    }

    #[test]
    fn test_password_derivation() {
        let password = b"password123";
        let salt = b"random salt";
        // Deliberately low iteration count to keep the test fast.
        let iterations = 100;

        let key1 = helpers::derive_key_from_password(password, salt, iterations).unwrap();
        assert_eq!(key1.len(), 32);

        // Deterministic for identical inputs.
        let key2 = helpers::derive_key_from_password(password, salt, iterations).unwrap();
        assert_eq!(&key1[..], &key2[..]);

        // A different salt must change the key.
        let key3 =
            helpers::derive_key_from_password(password, b"different salt", iterations).unwrap();
        assert_ne!(&key1[..], &key3[..]);
    }
}