Skip to main content

ssh_vault/vault/crypto/
chacha20poly1305.rs

1use anyhow::{Result, anyhow};
2use chacha20poly1305::{
3    ChaCha20Poly1305,
4    aead::{Aead, AeadCore, KeyInit, OsRng, Payload},
5};
6use secrecy::{ExposeSecret, SecretSlice};
7
8pub struct ChaCha20Poly1305Crypto {
9    key: SecretSlice<u8>,
10}
11
12impl super::Crypto for ChaCha20Poly1305Crypto {
13    fn new(key: SecretSlice<u8>) -> Self {
14        Self { key }
15    }
16
17    // Encrypts data with a key and a fingerprint
18    fn encrypt(&self, data: &[u8], fingerprint: &[u8]) -> Result<Vec<u8>, anyhow::Error> {
19        let cipher = ChaCha20Poly1305::new(self.key.expose_secret().into());
20        let nonce = ChaCha20Poly1305::generate_nonce(&mut OsRng);
21        let payload = Payload {
22            msg: data,
23            aad: fingerprint,
24        };
25
26        cipher.encrypt(&nonce, payload).map_or_else(
27            |_| Err(anyhow!("Failed to encrypt data")),
28            |ciphertext| {
29                let mut encrypted_data = nonce.to_vec();
30                encrypted_data.extend_from_slice(&ciphertext);
31                Ok(encrypted_data)
32            },
33        )
34    }
35
36    // Decrypts data with a key and a fingerprint
37    fn decrypt(&self, data: &[u8], fingerprint: &[u8]) -> Result<Vec<u8>, anyhow::Error> {
38        // Validate data length before slicing
39        if data.len() < 12 {
40            return Err(anyhow!(
41                "Invalid encrypted data: too short (expected at least 12 bytes, got {})",
42                data.len()
43            ));
44        }
45
46        let cipher = ChaCha20Poly1305::new(self.key.expose_secret().into());
47        let (nonce, ciphertext) = data.split_at(12);
48        let decrypted_data = cipher
49            .decrypt(
50                nonce.into(),
51                Payload {
52                    msg: ciphertext,
53                    aad: fingerprint,
54                },
55            )
56            .map_err(|err| anyhow!("Error decrypting password: {err}"))?;
57
58        Ok(decrypted_data)
59    }
60}
61
#[cfg(test)]
// Unwrapping is acceptable in tests: any failure should abort the test loudly.
#[allow(clippy::unwrap_used, clippy::unwrap_in_result)]
mod tests {
    use super::*;
    use crate::vault::crypto::Crypto;
    use rand::{RngCore, rngs::OsRng};
    use std::collections::HashSet;

    const TEST_DATA: &str = "The quick brown fox jumps over the lazy dog";
    const FINGERPRINT: &str = "SHA256:hgIL5fEHz5zuOWY1CDlUuotdaUl4MvYG7vAgE4q4TzM";

    /// Builds a cipher instance around a fresh random 32-byte key.
    fn random_crypto() -> ChaCha20Poly1305Crypto {
        let mut password = [0_u8; 32];
        OsRng.fill_bytes(&mut password);
        ChaCha20Poly1305Crypto::new(SecretSlice::new(password.into()))
    }

    #[test]
    fn test_chacha20poly1305() {
        let crypto = random_crypto();

        let encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let decrypted_data = crypto
            .decrypt(&encrypted_data, FINGERPRINT.as_bytes())
            .unwrap();

        assert_eq!(TEST_DATA.as_bytes(), decrypted_data);
    }

    #[test]
    fn test_chacha20poly1305_wrong_fingerprint() {
        let crypto = random_crypto();

        let encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let decrypted_data = crypto.decrypt(&encrypted_data, b"SHA256:invalid_fingerprint");

        assert!(decrypted_data.is_err());
    }

    #[test]
    fn test_chacha20poly1305_wrong_key() {
        let encrypted_data = random_crypto()
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();

        // A different random key must fail authentication.
        let result = random_crypto().decrypt(&encrypted_data, FINGERPRINT.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_chacha20poly1305_tampered_ciphertext() {
        let crypto = random_crypto();

        let mut encrypted_data = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        // Flip one bit in the first ciphertext byte (just past the 12-byte nonce).
        encrypted_data[12] ^= 0x01;

        let result = crypto.decrypt(&encrypted_data, FINGERPRINT.as_bytes());
        assert!(result.is_err());
    }

    #[test]
    fn test_chacha20poly1305_fresh_nonce_per_call() {
        let crypto = random_crypto();

        let first = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();
        let second = crypto
            .encrypt(TEST_DATA.as_bytes(), FINGERPRINT.as_bytes())
            .unwrap();

        // Same key and plaintext must still yield distinct nonces.
        assert_ne!(first[..12], second[..12]);
        assert_ne!(first, second);
    }

    #[test]
    fn test_chacha20poly1305_rand() {
        let mut unique_keys = HashSet::new();

        for _ in 0..1000 {
            let mut rng = OsRng;
            let mut key_bytes = [0u8; 32];
            rng.fill_bytes(&mut key_bytes);

            // Insert the key into the HashSet and ensure it's unique
            assert!(unique_keys.insert(key_bytes), "Duplicate key found");

            let key = SecretSlice::new(key_bytes.into());
            let crypto = ChaCha20Poly1305Crypto::new(key);

            // Generate random data
            let mut data = vec![0u8; 300];
            rng.fill_bytes(&mut data);

            // Generate random fingerprint
            let mut fingerprint = vec![0u8; 100];
            rng.fill_bytes(&mut fingerprint);

            let encrypted_data = crypto.encrypt(&data, &fingerprint).unwrap();
            let decrypted_data = crypto.decrypt(&encrypted_data, &fingerprint).unwrap();
            assert_eq!(data, decrypted_data);
        }
    }

    #[test]
    fn test_chacha20poly1305_decrypt_empty_data() {
        let crypto = random_crypto();

        let result = crypto.decrypt(&[], FINGERPRINT.as_bytes());
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("too short"));
    }

    #[test]
    fn test_chacha20poly1305_decrypt_short_data() {
        let crypto = random_crypto();

        // Test with various short lengths
        for len in 1..12 {
            let short_data = vec![0u8; len];
            let result = crypto.decrypt(&short_data, FINGERPRINT.as_bytes());
            assert!(result.is_err(), "Should fail with {len} bytes");
            let err_msg = result.unwrap_err().to_string();
            assert!(
                err_msg.contains("too short"),
                "Error message should mention 'too short', got: {err_msg}",
            );
            // Match the exact "got N)" suffix; a bare `contains(len)` is vacuous
            // for 1 and 2 because the message always contains "12".
            assert!(
                err_msg.contains(&format!("got {len})")),
                "Error message should mention length {len}, got: {err_msg}",
            );
        }
    }

    #[test]
    fn test_chacha20poly1305_decrypt_exact_minimum() {
        let crypto = random_crypto();

        // 12 bytes is minimum (nonce only, no ciphertext)
        let data = vec![0u8; 12];
        let result = crypto.decrypt(&data, FINGERPRINT.as_bytes());
        // Should not panic, but will fail authentication
        assert!(result.is_err());
    }
}