Skip to main content

ssh_vault/vault/crypto/
aes256.rs

1use aes_gcm::{
2    Aes256Gcm,
3    aead::{Aead, AeadCore, KeyInit, OsRng, Payload},
4};
5use anyhow::{Result, anyhow};
6use secrecy::{ExposeSecret, SecretSlice};
7
8pub struct Aes256Crypto {
9    key: SecretSlice<u8>,
10}
11
12impl super::Crypto for Aes256Crypto {
13    fn new(key: SecretSlice<u8>) -> Self {
14        Self { key }
15    }
16
17    // Encrypts data with a key and a fingerprint
18    fn encrypt(&self, data: &[u8], fingerprint: &[u8]) -> Result<Vec<u8>> {
19        let key = self.key.expose_secret().into();
20        let cipher = Aes256Gcm::new(key);
21        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
22        let payload = Payload {
23            msg: data,
24            aad: fingerprint,
25        };
26
27        cipher.encrypt(&nonce, payload).map_or_else(
28            |_| Err(anyhow!("Failed to encrypt data")),
29            |ciphertext| {
30                let mut encrypted_data = nonce.to_vec();
31                encrypted_data.extend_from_slice(&ciphertext);
32                Ok(encrypted_data)
33            },
34        )
35    }
36
37    // Decrypts data with a key and a fingerprint
38    fn decrypt(&self, data: &[u8], fingerprint: &[u8]) -> Result<Vec<u8>> {
39        // Validate data length before slicing
40        if data.len() < 12 {
41            return Err(anyhow!(
42                "Invalid encrypted data: too short (expected at least 12 bytes, got {})",
43                data.len()
44            ));
45        }
46
47        let key = self.key.expose_secret().into();
48        let cipher = Aes256Gcm::new(key);
49        let (nonce, ciphertext) = data.split_at(12);
50        let payload = Payload {
51            msg: ciphertext,
52            aad: fingerprint,
53        };
54
55        cipher
56            .decrypt(nonce.into(), payload)
57            .map_or_else(|_| Err(anyhow!("Failed to decrypt data")), Ok)
58    }
59}
60
#[cfg(test)]
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used, clippy::unwrap_in_result)]
mod tests {
    use super::*;
    use crate::vault::crypto::Crypto;
    use rand::{RngCore, rngs::OsRng};
    use std::collections::HashSet;

    const TEST_DATA: &str = "The quick brown fox jumps over the lazy dog";
    const FINGERPRINT: &str = "SHA256:hgIL5fEHz5zuOWY1CDlUuotdaUl4MvYG7vAgE4q4TzM";

    /// Size of the nonce prefix that `encrypt` writes before the ciphertext.
    const NONCE_LEN: usize = 12;
    /// Size of the GCM authentication tag carried at the end of the ciphertext.
    const TAG_LEN: usize = 16;

    /// Returns a cipher keyed with 32 fresh random bytes.
    fn random_crypto() -> Aes256Crypto {
        let mut key_bytes = [0_u8; 32];
        OsRng.fill_bytes(&mut key_bytes);
        Aes256Crypto::new(SecretSlice::new(key_bytes.into()))
    }

    #[test]
    fn test_aes256() {
        let crypto = random_crypto();

        let encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        // Layout is nonce || ciphertext || tag; GCM ciphertext is as long as the plaintext.
        assert_eq!(encrypted_data.len(), NONCE_LEN + TEST_DATA.len() + TAG_LEN);

        let decrypted_data = crypto
            .decrypt(&encrypted_data, FINGERPRINT.as_bytes())
            .unwrap();
        assert_eq!(TEST_DATA.as_bytes(), decrypted_data);
    }

    #[test]
    fn test_aes256_empty_plaintext() {
        let crypto = random_crypto();

        let encrypted_data = crypto.encrypt(&[], FINGERPRINT.as_bytes()).unwrap();
        assert_eq!(encrypted_data.len(), NONCE_LEN + TAG_LEN);

        let decrypted_data = crypto
            .decrypt(&encrypted_data, FINGERPRINT.as_bytes())
            .unwrap();
        assert!(decrypted_data.is_empty());
    }

    #[test]
    fn test_aes256_invalid_fingerprint() {
        let crypto = random_crypto();

        let encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let decrypted_data = crypto.decrypt(&encrypted_data, b"SHA256:invalid_fingerprint");

        assert!(decrypted_data.is_err());
    }

    #[test]
    fn test_aes256_wrong_key() {
        let encrypted_data = random_crypto()
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();

        let result = random_crypto().decrypt(&encrypted_data, FINGERPRINT.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_aes256_tampered_ciphertext() {
        let crypto = random_crypto();
        let mut encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();

        // Flip one bit of the first ciphertext byte (just past the nonce).
        encrypted_data[NONCE_LEN] ^= 0x01;

        assert!(
            crypto
                .decrypt(&encrypted_data, FINGERPRINT.as_bytes())
                .is_err()
        );
    }

    #[test]
    fn test_aes256_fresh_nonce_per_encryption() {
        let crypto = random_crypto();

        let first = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let second = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();

        // Same key and plaintext must still yield distinct nonces and ciphertexts.
        assert_ne!(&first[..NONCE_LEN], &second[..NONCE_LEN]);
        assert_ne!(first, second);
    }

    #[test]
    fn test_aes256_rand() {
        let mut rng = OsRng;
        let mut unique_keys = HashSet::new();

        for _ in 0..1000 {
            let mut key_bytes = [0u8; 32];
            rng.fill_bytes(&mut key_bytes);

            // Insert the key into the HashSet and ensure it's unique
            assert!(unique_keys.insert(key_bytes), "Duplicate key found");

            let crypto = Aes256Crypto::new(SecretSlice::new(key_bytes.into()));

            let mut data = vec![0u8; 300];
            rng.fill_bytes(&mut data);

            let mut fingerprint = vec![0u8; 100];
            rng.fill_bytes(&mut fingerprint);

            let encrypted_data = crypto.encrypt(&data, &fingerprint).unwrap();
            let decrypted_data = crypto.decrypt(&encrypted_data, &fingerprint).unwrap();
            assert_eq!(data, decrypted_data);
        }
    }

    #[test]
    fn test_aes256_decrypt_empty_data() {
        let crypto = random_crypto();

        let result = crypto.decrypt(&[], FINGERPRINT.as_bytes());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too short"));
    }

    #[test]
    fn test_aes256_decrypt_short_data() {
        let crypto = random_crypto();

        // Every length strictly below the nonce size must be rejected up front.
        for len in 1..NONCE_LEN {
            let short_data = vec![0u8; len];
            let result = crypto.decrypt(&short_data, FINGERPRINT.as_bytes());
            assert!(result.is_err(), "Should fail with {len} bytes");
            let err_msg = result.unwrap_err().to_string();
            assert!(
                err_msg.contains("too short"),
                "Error message should mention 'too short', got: {err_msg}",
            );
            // Match "got {len}" specifically: the message always contains "12", so a
            // bare digit search would pass vacuously for lengths like 1 and 2.
            assert!(
                err_msg.contains(&format!("got {len}")),
                "Error message should report length {len}, got: {err_msg}",
            );
        }
    }

    #[test]
    fn test_aes256_decrypt_exact_minimum() {
        let crypto = random_crypto();

        // 12 bytes passes the length check (nonce only, no ciphertext or tag)
        let data = vec![0u8; NONCE_LEN];
        let result = crypto.decrypt(&data, FINGERPRINT.as_bytes());
        // Should not panic, but will fail authentication
        assert!(result.is_err());
    }
}