sare_core/kdf/
mod.rs

1use hkdf::Hkdf;
2use rand::RngCore;
3use rand::rngs::OsRng;
4use secrecy::{ExposeSecret, SecretVec};
5use serde::{Deserialize, Serialize};
6use sha2::{Sha256, Sha512};
7pub mod error;
8
9use error::*;
10
11#[derive(Clone, Debug, Serialize, Deserialize)]
12pub enum HKDFAlgorithm {
13    SHA256,
14    SHA512,
15}
16
17impl HKDFAlgorithm {
18    pub fn get_output_size(&self) -> usize {
19        match &self {
20            HKDFAlgorithm::SHA256 => 32,
21            HKDFAlgorithm::SHA512 => 64,
22        }
23    }
24}
25
26pub trait KDF {
27    fn generate_salt() -> Vec<u8> {
28        let mut salt_buffer = [0u8; 8];
29
30        OsRng.fill_bytes(&mut salt_buffer);
31
32        salt_buffer.to_vec()
33    }
34}
35
36pub struct HKDF<'a> {
37    input_data: &'a SecretVec<u8>,
38    salt: Vec<u8>,
39    algorithm: HKDFAlgorithm,
40}
41
42impl<'a> HKDF<'a> {
43    pub fn new(input_data: &'a SecretVec<u8>, salt: Vec<u8>, algorithm: HKDFAlgorithm) -> Self {
44        HKDF {
45            input_data,
46            salt,
47            algorithm,
48        }
49    }
50
51    fn expand_sha256(
52        &self,
53        additional_context: Option<&[u8]>,
54        okm: &mut [u8],
55    ) -> Result<(), KDFError> {
56        let hkdf = Hkdf::<Sha256>::new(Some(&self.salt), self.input_data.expose_secret());
57        hkdf.expand(additional_context.unwrap_or(&[0]), okm)?;
58
59        Ok(())
60    }
61
62    fn expand_sha512(
63        &self,
64        additional_context: Option<&[u8]>,
65        okm: &mut [u8],
66    ) -> Result<(), KDFError> {
67        let hkdf = Hkdf::<Sha512>::new(Some(&self.salt), self.input_data.expose_secret());
68        hkdf.expand(additional_context.unwrap_or(&[0]), okm)?;
69
70        Ok(())
71    }
72
73    pub fn expand(&self, additional_context: Option<&[u8]>) -> Result<SecretVec<u8>, KDFError> {
74        let mut okm = vec![0u8; self.algorithm.get_output_size()];
75
76        match &self.algorithm {
77            HKDFAlgorithm::SHA256 => {
78                self.expand_sha256(additional_context, &mut okm)?;
79            }
80            HKDFAlgorithm::SHA512 => {
81                self.expand_sha512(additional_context, &mut okm)?;
82            }
83        }
84
85        Ok(SecretVec::from(okm))
86    }
87}
88
89#[derive(Clone, Debug, Serialize, Deserialize)]
90pub enum PKDFAlgorithm {
91    Scrypt(u8, u32, u32),
92}
93
94#[derive(Clone)]
95pub struct PKDF<'a> {
96    input_data: &'a SecretVec<u8>,
97    pub salt: Vec<u8>,
98    pub algorithm: PKDFAlgorithm,
99}
100
101impl KDF for PKDF<'_> {}
102
103impl<'a> PKDF<'a> {
104    pub fn new(input_data: &'a SecretVec<u8>, salt: Vec<u8>, algorithm: PKDFAlgorithm) -> Self {
105        PKDF {
106            input_data,
107            salt,
108            algorithm,
109        }
110    }
111
112    // NOTE: To be used later
113    /*
114    pub fn calculate_scrypt_workfactor(&self) -> (usize, usize, usize) {
115        let n: usize = (self.workfactor_scale / 4).max(2);
116        let r = 8usize;
117        let p: u64 = ((2i64.pow(n as u32) / 20).max(1).ilog2()).max(1).into();
118        (n, r, p as usize)
119    }
120    */
121
122    pub fn derive_key(&self, key_length: usize) -> Result<SecretVec<u8>, KDFError> {
123        match &self.algorithm {
124            PKDFAlgorithm::Scrypt(n, r, p) => {
125                let params = scrypt::Params::new(*n, *r, *p, key_length)?;
126
127                let mut output = vec![0u8; key_length];
128                scrypt::scrypt(
129                    self.input_data.expose_secret(),
130                    &self.salt,
131                    &params,
132                    &mut output,
133                )?;
134
135                Ok(SecretVec::from(output))
136            }
137        }
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed salt shared by the known-answer tests.
    const TEST_SALT: [u8; 9] = [1; 9];
    /// Fixed secret input shared by the known-answer tests.
    const TEST_INPUT_DATA: [u8; 6] = [6; 6];

    /// Expected HKDF-SHA512 output for the fixtures above with no context.
    const HKDF_SHA512_OUTPUT: [u8; 64] = [
        46, 30, 33, 140, 17, 103, 238, 164, 144, 144, 87, 205, 161, 83, 5, 128, 209, 210, 128, 15,
        170, 178, 211, 157, 130, 12, 111, 198, 100, 233, 90, 81, 165, 143, 136, 207, 139, 106, 43,
        238, 207, 132, 156, 252, 170, 83, 253, 239, 179, 216, 72, 37, 218, 57, 122, 202, 198, 175,
        44, 42, 8, 192, 18, 167,
    ];

    /// Expected 10-byte scrypt output for the fixtures above with `Scrypt(5, 8, 1)`.
    const SCRYPT_OUTPUT: [u8; 10] = [172, 240, 153, 61, 124, 223, 14, 128, 130, 37];

    #[test]
    fn hkdf() {
        let secret = SecretVec::from(TEST_INPUT_DATA.to_vec());
        let derived = HKDF::new(&secret, TEST_SALT.to_vec(), HKDFAlgorithm::SHA512)
            .expand(None)
            .unwrap();

        assert_eq!(&HKDF_SHA512_OUTPUT[..], derived.expose_secret().as_slice());
    }

    /*
    #[test]
    fn scrypt_workfactor_scale() {
        let input_data = SecretVec::from(TEST_INPUT_DATA.to_vec());
        let pkdf = PKDF::new(&input_data, &TEST_SALT, 60, PKDFAlgorithm::Scrypt);

        let workfactor = pkdf.calculate_scrypt_workfactor();

        assert_eq!((15, 8, 10), workfactor);
    }
    */

    #[test]
    fn scrypt_key_derive() {
        let secret = SecretVec::from(TEST_INPUT_DATA.to_vec());
        let algorithm = PKDFAlgorithm::Scrypt(5, 8, 1);
        let derived = PKDF::new(&secret, TEST_SALT.to_vec(), algorithm)
            .derive_key(10)
            .unwrap();

        assert_eq!(&SCRYPT_OUTPUT[..], derived.expose_secret().as_slice());
    }
}