Skip to main content

shard_core/
encryption.rs

1use aes_gcm::aead::{Aead, OsRng};
2use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce};
3use anyhow::Result;
4use rand::RngCore;
5use std::fs;
6use std::path::Path;
7
/// Length in bytes of a raw AES-256 key (256 bits).
const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of the random AES-GCM nonce (96 bits) prefixed to every ciphertext.
const NONCE_SIZE: usize = 96 / 8;
10
11pub struct RepoCipher {
12    cipher: Aes256Gcm,
13}
14
15impl RepoCipher {
16    pub fn generate() -> Self {
17        let key = Aes256Gcm::generate_key(OsRng);
18        Self {
19            cipher: Aes256Gcm::new(&key),
20        }
21    }
22
23    pub fn from_key(key_bytes: &[u8; KEY_SIZE]) -> Self {
24        let key = Key::<Aes256Gcm>::from_slice(key_bytes);
25        Self {
26            cipher: Aes256Gcm::new(key),
27        }
28    }
29
30    pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
31        let mut nonce_bytes = [0u8; NONCE_SIZE];
32        OsRng.fill_bytes(&mut nonce_bytes);
33        let nonce = Nonce::from_slice(&nonce_bytes);
34        let ciphertext = self
35            .cipher
36            .encrypt(nonce, plaintext)
37            .expect("encryption should never fail with given params");
38        let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
39        result.extend_from_slice(&nonce_bytes);
40        result.extend_from_slice(&ciphertext);
41        result
42    }
43
44    pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
45        if data.len() < NONCE_SIZE {
46            anyhow::bail!(
47                "encrypted data too short (need {} bytes for nonce)",
48                NONCE_SIZE
49            );
50        }
51        let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
52        let nonce = Nonce::from_slice(nonce_bytes);
53        let plaintext = self
54            .cipher
55            .decrypt(nonce, ciphertext)
56            .map_err(|_| anyhow::anyhow!("decryption failed — wrong key or corrupted data"))?;
57        Ok(plaintext)
58    }
59
60    pub fn key_bytes(&self) -> [u8; KEY_SIZE] {
61        // Aes256Gcm doesn't expose the key directly. We re-generate from a stored copy.
62        // This is a placeholder — keys are managed externally via save/load functions.
63        unimplemented!("use save_repo_key / load_repo_key instead")
64    }
65}
66
67pub fn generate_repo_key() -> [u8; KEY_SIZE] {
68    let mut key = [0u8; KEY_SIZE];
69    OsRng.fill_bytes(&mut key);
70    key
71}
72
73pub fn save_repo_key(keys_dir: &Path, key: &[u8; KEY_SIZE]) -> Result<()> {
74    fs::write(keys_dir.join("repo.key"), hex::encode(key))?;
75    Ok(())
76}
77
78pub fn load_repo_key(keys_dir: &Path) -> Result<[u8; KEY_SIZE]> {
79    let hex_key = fs::read_to_string(keys_dir.join("repo.key"))?;
80    let key_hex = hex_key.trim();
81    let bytes = hex::decode(key_hex)?;
82    if bytes.len() != KEY_SIZE {
83        anyhow::bail!(
84            "invalid repo.key: expected {} bytes, got {}",
85            KEY_SIZE,
86            bytes.len()
87        );
88    }
89    let mut key = [0u8; KEY_SIZE];
90    key.copy_from_slice(&bytes);
91    Ok(key)
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let cipher = RepoCipher::generate();
        let data = b"hello world this is test data";
        let encrypted = cipher.encrypt(data);
        assert_ne!(encrypted, data);
        assert!(encrypted.len() > NONCE_SIZE);
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn test_encrypt_empty_plaintext() {
        let cipher = RepoCipher::generate();
        let encrypted = cipher.encrypt(b"");
        // Even an empty message carries an authentication tag after the nonce.
        assert!(encrypted.len() > NONCE_SIZE);
        assert!(cipher.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_encrypt_different_nonces() {
        let cipher = RepoCipher::generate();
        let data = b"same data";
        let e1 = cipher.encrypt(data);
        let e2 = cipher.encrypt(data);
        assert_ne!(
            e1, e2,
            "two encryptions of same data should differ (random nonce)"
        );
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let cipher1 = RepoCipher::generate();
        let cipher2 = RepoCipher::generate();
        let data = b"secret message";
        let encrypted = cipher1.encrypt(data);
        let result = cipher2.decrypt(&encrypted);
        assert!(result.is_err(), "decrypt with wrong key should fail");
    }

    #[test]
    fn test_decrypt_tampered() {
        let cipher = RepoCipher::generate();
        let mut encrypted = cipher.encrypt(b"tamper me");
        // Flip the first byte of the ciphertext portion (right after the nonce).
        encrypted[NONCE_SIZE] ^= 0xFF;
        assert!(
            cipher.decrypt(&encrypted).is_err(),
            "decrypt of tampered ciphertext should fail"
        );
    }

    #[test]
    fn test_decrypt_tampered_tag() {
        let cipher = RepoCipher::generate();
        let mut encrypted = cipher.encrypt(b"tamper the tag");
        *encrypted.last_mut().unwrap() ^= 0x01;
        assert!(
            cipher.decrypt(&encrypted).is_err(),
            "decrypt with a modified trailing byte should fail"
        );
    }

    #[test]
    fn test_short_data() {
        let cipher = RepoCipher::generate();
        let result = cipher.decrypt(&[0u8; 5]);
        assert!(
            result.is_err(),
            "decrypt of data shorter than nonce should fail"
        );
    }

    #[test]
    fn test_nonce_only_data() {
        // Exactly NONCE_SIZE bytes passes the length guard but contains no AEAD output.
        let cipher = RepoCipher::generate();
        assert!(
            cipher.decrypt(&[0u8; NONCE_SIZE]).is_err(),
            "decrypt of a bare nonce should fail"
        );
    }

    #[test]
    fn test_key_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let key = generate_repo_key();
        save_repo_key(dir.path(), &key).unwrap();
        let loaded = load_repo_key(dir.path()).unwrap();
        assert_eq!(key, loaded);
    }

    #[test]
    fn test_load_key_tolerates_surrounding_whitespace() {
        let dir = tempdir().unwrap();
        let key = generate_repo_key();
        fs::write(
            dir.path().join("repo.key"),
            format!("  {}\n", hex::encode(key)),
        )
        .unwrap();
        assert_eq!(load_repo_key(dir.path()).unwrap(), key);
    }

    #[test]
    fn test_load_key_rejects_wrong_length() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("repo.key"), hex::encode([0u8; KEY_SIZE - 1])).unwrap();
        assert!(load_repo_key(dir.path()).is_err());
    }

    #[test]
    fn test_load_key_rejects_invalid_hex() {
        let dir = tempdir().unwrap();
        fs::write(dir.path().join("repo.key"), "zz".repeat(KEY_SIZE)).unwrap();
        assert!(load_repo_key(dir.path()).is_err());
    }

    #[test]
    fn test_load_key_missing_file() {
        let dir = tempdir().unwrap();
        assert!(load_repo_key(dir.path()).is_err());
    }

    #[test]
    fn test_from_key_roundtrip() {
        let key = generate_repo_key();
        let cipher = RepoCipher::from_key(&key);
        let data = b"roundtrip via from_key";
        let encrypted = cipher.encrypt(data);
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);

        // Same key should produce decryptable ciphertext
        let cipher2 = RepoCipher::from_key(&key);
        let decrypted2 = cipher2.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted2, data);
    }
}