// shadow_crypt_core/v1/key_ops.rs

1use argon2::{Algorithm, Argon2, Params, Version};
2use zeroize::Zeroize;
3
4use crate::{
5    errors::KeyDerivationError, memory::SecureKey, report::KeyDerivationReport,
6    v1::key::KeyDerivationParams,
7};
8
9pub fn derive_key(
10    password: &[u8],
11    salt: &[u8],
12    kdf_params: &KeyDerivationParams,
13) -> Result<(SecureKey, KeyDerivationReport), KeyDerivationError> {
14    let start_time = std::time::Instant::now();
15    let algorithm = Algorithm::Argon2id;
16    let version = Version::V0x13; // Version 19
17    let params = Params::new(
18        kdf_params.memory_cost,
19        kdf_params.time_cost,
20        kdf_params.parallelism,
21        Some(kdf_params.key_size as usize),
22    )
23    .map_err(|e| KeyDerivationError::InvalidParameters(format!("Invalid KDF parameters: {}", e)))?;
24    let context = Argon2::new(algorithm, version, params);
25
26    let mut buffer = [0u8; 32];
27    context
28        .hash_password_into(password, salt, &mut buffer)
29        .map_err(|e| {
30            KeyDerivationError::DerivationFailed(format!("Key derivation failed: {}", e))
31        })?;
32
33    let key = SecureKey::new(buffer);
34    buffer.zeroize(); // Clear sensitive data from memory
35
36    let duration = start_time.elapsed();
37    let report = KeyDerivationReport::new(
38        "Argon2id".to_string(),
39        format!("{}", version as u8),
40        kdf_params,
41        duration,
42    );
43
44    Ok((key, report))
45}
46
#[cfg(test)]
mod tests {
    // `KeyDerivationParams`, `KeyDerivationError` and `derive_key` come in via the glob import.
    use super::*;
    use crate::profile::SecurityProfile;

    const PASSWORD: &[u8] = b"test_password";
    // Typed as a 16-byte array so the compiler checks the length rather than a name claiming it.
    const SALT: &[u8; 16] = b"0123456789abcdef";

    /// Parameters from the `SecurityProfile::Test` profile, used by the success-path tests.
    fn get_test_params() -> KeyDerivationParams {
        KeyDerivationParams::from(SecurityProfile::Test)
    }

    #[test]
    fn test_derive_key_success() {
        let params = get_test_params();

        let (key, report) =
            derive_key(PASSWORD, SALT, &params).expect("key derivation should succeed");

        assert_eq!(key.as_bytes().len(), 32);
        assert_eq!(report.algorithm, "Argon2id");
        assert_eq!(report.algorithm_version, "19");
        assert_eq!(report.memory_cost_kib, params.memory_cost);
        assert_eq!(report.time_cost_iterations, params.time_cost);
        assert_eq!(report.parallelism, params.parallelism);
        assert_eq!(report.key_size_bytes, params.key_size);
        assert!(report.duration.as_nanos() > 0);
    }

    #[test]
    fn test_derive_key_deterministic() {
        let params = get_test_params();

        let (key1, _) = derive_key(PASSWORD, SALT, &params).unwrap();
        let (key2, _) = derive_key(PASSWORD, SALT, &params).unwrap();

        assert_eq!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_derive_key_different_passwords() {
        let params = get_test_params();

        let (key1, _) = derive_key(b"password1", SALT, &params).unwrap();
        let (key2, _) = derive_key(b"password2", SALT, &params).unwrap();

        assert_ne!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_derive_key_different_salts() {
        let params = get_test_params();

        let (key1, _) = derive_key(PASSWORD, b"salt1_16_bytes_!", &params).unwrap();
        let (key2, _) = derive_key(PASSWORD, b"salt2_16_bytes_!", &params).unwrap();

        assert_ne!(key1.as_bytes(), key2.as_bytes());
    }

    #[test]
    fn test_derive_key_invalid_parameters() {
        // memory_cost = 0 is rejected by Argon2's parameter validation.
        let invalid_params = KeyDerivationParams::new(0, 1, 1, 32);
        let result = derive_key(PASSWORD, SALT, &invalid_params);
        assert!(
            matches!(result, Err(KeyDerivationError::InvalidParameters(_))),
            "Expected InvalidParameters error"
        );
    }

    #[test]
    fn test_derive_key_invalid_key_size() {
        // key_size = 0 can never be produced.
        let invalid_params = KeyDerivationParams::new(1024, 1, 1, 0);
        let result = derive_key(PASSWORD, SALT, &invalid_params);
        assert!(
            matches!(result, Err(KeyDerivationError::InvalidParameters(_))),
            "Expected InvalidParameters error"
        );
    }

    #[test]
    fn test_derive_key_rejects_unsupported_key_size() {
        // The derived key is always 32 bytes, so any other requested size must fail rather
        // than yield a key whose length disagrees with the request.
        for key_size in [16, 64] {
            let params = KeyDerivationParams::new(1024, 1, 1, key_size);
            assert!(
                derive_key(PASSWORD, SALT, &params).is_err(),
                "key_size = {} should be rejected",
                key_size
            );
        }
    }
}