1use hkdf::Hkdf;
2use rand::RngCore;
3use rand::rngs::OsRng;
4use secrecy::{ExposeSecret, SecretVec};
5use serde::{Deserialize, Serialize};
6use sha2::{Sha256, Sha512};
7pub mod error;
8
9use error::*;
10
11#[derive(Clone, Debug, Serialize, Deserialize)]
12pub enum HKDFAlgorithm {
13 SHA256,
14 SHA512,
15}
16
17impl HKDFAlgorithm {
18 pub fn get_output_size(&self) -> usize {
19 match &self {
20 HKDFAlgorithm::SHA256 => 32,
21 HKDFAlgorithm::SHA512 => 64,
22 }
23 }
24}
25
26pub trait KDF {
27 fn generate_salt() -> Vec<u8> {
28 let mut salt_buffer = [0u8; 8];
29
30 OsRng.fill_bytes(&mut salt_buffer);
31
32 salt_buffer.to_vec()
33 }
34}
35
36pub struct HKDF<'a> {
37 input_data: &'a SecretVec<u8>,
38 salt: Vec<u8>,
39 algorithm: HKDFAlgorithm,
40}
41
42impl<'a> HKDF<'a> {
43 pub fn new(input_data: &'a SecretVec<u8>, salt: Vec<u8>, algorithm: HKDFAlgorithm) -> Self {
44 HKDF {
45 input_data,
46 salt,
47 algorithm,
48 }
49 }
50
51 fn expand_sha256(
52 &self,
53 additional_context: Option<&[u8]>,
54 okm: &mut [u8],
55 ) -> Result<(), KDFError> {
56 let hkdf = Hkdf::<Sha256>::new(Some(&self.salt), self.input_data.expose_secret());
57 hkdf.expand(additional_context.unwrap_or(&[0]), okm)?;
58
59 Ok(())
60 }
61
62 fn expand_sha512(
63 &self,
64 additional_context: Option<&[u8]>,
65 okm: &mut [u8],
66 ) -> Result<(), KDFError> {
67 let hkdf = Hkdf::<Sha512>::new(Some(&self.salt), self.input_data.expose_secret());
68 hkdf.expand(additional_context.unwrap_or(&[0]), okm)?;
69
70 Ok(())
71 }
72
73 pub fn expand(&self, additional_context: Option<&[u8]>) -> Result<SecretVec<u8>, KDFError> {
74 let mut okm = vec![0u8; self.algorithm.get_output_size()];
75
76 match &self.algorithm {
77 HKDFAlgorithm::SHA256 => {
78 self.expand_sha256(additional_context, &mut okm)?;
79 }
80 HKDFAlgorithm::SHA512 => {
81 self.expand_sha512(additional_context, &mut okm)?;
82 }
83 }
84
85 Ok(SecretVec::from(okm))
86 }
87}
88
89#[derive(Clone, Debug, Serialize, Deserialize)]
90pub enum PKDFAlgorithm {
91 Scrypt(u8, u32, u32),
92}
93
94#[derive(Clone)]
95pub struct PKDF<'a> {
96 input_data: &'a SecretVec<u8>,
97 pub salt: Vec<u8>,
98 pub algorithm: PKDFAlgorithm,
99}
100
101impl KDF for PKDF<'_> {}
102
103impl<'a> PKDF<'a> {
104 pub fn new(input_data: &'a SecretVec<u8>, salt: Vec<u8>, algorithm: PKDFAlgorithm) -> Self {
105 PKDF {
106 input_data,
107 salt,
108 algorithm,
109 }
110 }
111
112 pub fn derive_key(&self, key_length: usize) -> Result<SecretVec<u8>, KDFError> {
123 match &self.algorithm {
124 PKDFAlgorithm::Scrypt(n, r, p) => {
125 let params = scrypt::Params::new(*n, *r, *p, key_length)?;
126
127 let mut output = vec![0u8; key_length];
128 scrypt::scrypt(
129 self.input_data.expose_secret(),
130 &self.salt,
131 ¶ms,
132 &mut output,
133 )?;
134
135 Ok(SecretVec::from(output))
136 }
137 }
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed inputs shared by both known-answer tests below.
    const TEST_SALT: [u8; 9] = [1; 9];
    const TEST_INPUT_DATA: [u8; 6] = [6; 6];

    // Expected HKDF-SHA512 output for the fixed inputs with no context given.
    const HKDF_SHA512_OUTPUT: [u8; 64] = [
        46, 30, 33, 140, 17, 103, 238, 164, 144, 144, 87, 205, 161, 83, 5, 128, 209, 210, 128, 15,
        170, 178, 211, 157, 130, 12, 111, 198, 100, 233, 90, 81, 165, 143, 136, 207, 139, 106, 43,
        238, 207, 132, 156, 252, 170, 83, 253, 239, 179, 216, 72, 37, 218, 57, 122, 202, 198, 175,
        44, 42, 8, 192, 18, 167,
    ];

    // Expected 10-byte scrypt output for the fixed inputs with Scrypt(5, 8, 1).
    const SCRYPT_OUTPUT: [u8; 10] = [172, 240, 153, 61, 124, 223, 14, 128, 130, 37];

    /// Known-answer test for `HKDF::expand` with SHA-512 and no context.
    #[test]
    fn hkdf() {
        let secret = SecretVec::from(TEST_INPUT_DATA.to_vec());
        let deriver = HKDF::new(&secret, TEST_SALT.to_vec(), HKDFAlgorithm::SHA512);

        let derived = deriver.expand(None).unwrap();

        assert_eq!(derived.expose_secret().as_slice(), &HKDF_SHA512_OUTPUT[..]);
    }

    /// Known-answer test for `PKDF::derive_key` using scrypt.
    #[test]
    fn scrypt_key_derive() {
        let secret = SecretVec::from(TEST_INPUT_DATA.to_vec());
        let algorithm = PKDFAlgorithm::Scrypt(5, 8, 1);
        let deriver = PKDF::new(&secret, TEST_SALT.to_vec(), algorithm);

        let derived = deriver.derive_key(10).unwrap();

        assert_eq!(derived.expose_secret().as_slice(), &SCRYPT_OUTPUT[..]);
    }
}