ssh_vault/vault/crypto/
aes256.rs

1use aes_gcm::{
2    Aes256Gcm,
3    aead::{Aead, AeadCore, KeyInit, OsRng, Payload},
4};
5use anyhow::{Result, anyhow};
6use secrecy::{ExposeSecret, SecretSlice};
7
8pub struct Aes256Crypto {
9    key: SecretSlice<u8>,
10}
11
12impl super::Crypto for Aes256Crypto {
13    fn new(key: SecretSlice<u8>) -> Self {
14        Self { key }
15    }
16
17    // Encrypts data with a key and a fingerprint
18    fn encrypt(&self, data: &[u8], fingerprint: &[u8]) -> Result<Vec<u8>> {
19        let key = self.key.expose_secret().into();
20        let cipher = Aes256Gcm::new(key);
21        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
22        let payload = Payload {
23            msg: data,
24            aad: fingerprint,
25        };
26
27        cipher.encrypt(&nonce, payload).map_or_else(
28            |_| Err(anyhow!("Failed to encrypt data")),
29            |ciphertext| {
30                let mut encrypted_data = nonce.to_vec();
31                encrypted_data.extend_from_slice(&ciphertext);
32                Ok(encrypted_data)
33            },
34        )
35    }
36
37    // Decrypts data with a key and a fingerprint
38    fn decrypt(&self, data: &[u8], fingerprint: &[u8]) -> Result<Vec<u8>> {
39        let key = self.key.expose_secret().into();
40        let cipher = Aes256Gcm::new(key);
41        let nonce = (&data[..12]).into();
42        let ciphertext = &data[12..];
43        let payload = Payload {
44            msg: ciphertext,
45            aad: fingerprint,
46        };
47
48        cipher
49            .decrypt(nonce, payload)
50            .map_or_else(|_| Err(anyhow!("Failed to decrypt data")), Ok)
51    }
52}
53
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vault::crypto::Crypto;
    use rand::{RngCore, rngs::OsRng};
    use std::collections::HashSet;

    const TEST_DATA: &str = "The quick brown fox jumps over the lazy dog";
    const FINGERPRINT: &str = "SHA256:hgIL5fEHz5zuOWY1CDlUuotdaUl4MvYG7vAgE4q4TzM";

    /// Builds an `Aes256Crypto` around a freshly generated random 32-byte key.
    fn random_crypto() -> Aes256Crypto {
        let mut password = [0_u8; 32];
        OsRng.fill_bytes(&mut password);
        Aes256Crypto::new(SecretSlice::new(password.into()))
    }

    /// Round trip: decrypting with the same key and fingerprint returns the input.
    #[test]
    fn test_aes256() {
        let crypto = random_crypto();

        let encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let decrypted_data = crypto
            .decrypt(&encrypted_data, FINGERPRINT.as_bytes())
            .unwrap();

        assert_eq!(TEST_DATA.as_bytes(), decrypted_data);
    }

    /// A different fingerprint (associated data) must make decryption fail.
    #[test]
    fn test_aes256_invalid_fingerprint() {
        let crypto = random_crypto();

        let encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let decrypted_data = crypto.decrypt(&encrypted_data, b"SHA256:invalid_fingerprint");

        assert!(decrypted_data.is_err());
    }

    /// Flipping a single bit of the encrypted buffer must make decryption fail.
    #[test]
    fn test_aes256_tampered_ciphertext() {
        let crypto = random_crypto();

        let mut encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let last = encrypted_data.len() - 1;
        encrypted_data[last] ^= 0x01;

        assert!(
            crypto
                .decrypt(&encrypted_data, FINGERPRINT.as_bytes())
                .is_err()
        );
    }

    /// Encrypting the same input twice must not yield the same bytes, because
    /// each call draws a new random nonce.
    #[test]
    fn test_aes256_fresh_nonce_per_encryption() {
        let crypto = random_crypto();

        let first = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let second = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();

        assert_ne!(first, second);
    }

    /// Round trips 1000 random keys, payloads and fingerprints, and checks
    /// that no key is generated twice.
    #[test]
    fn test_aes256_rand() {
        let mut rng = OsRng;
        let mut unique_keys = HashSet::new();

        for _ in 0..1000 {
            let mut key_bytes = [0u8; 32];
            rng.fill_bytes(&mut key_bytes);

            // `insert` returns false when the value was already present.
            assert!(unique_keys.insert(key_bytes), "Duplicate key found");

            let key = SecretSlice::new(key_bytes.into());
            let crypto = Aes256Crypto::new(key);

            // Generate random data
            let mut data = vec![0u8; 300];
            rng.fill_bytes(&mut data);

            // Generate random fingerprint
            let mut fingerprint = vec![0u8; 100];
            rng.fill_bytes(&mut fingerprint);

            let encrypted_data = crypto.encrypt(&data, &fingerprint).unwrap();
            let decrypted_data = crypto.decrypt(&encrypted_data, &fingerprint).unwrap();
            assert_eq!(data, decrypted_data);
        }
    }
}