1use aes_gcm::aead::{Aead, OsRng};
2use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce};
3use anyhow::Result;
4use rand::RngCore;
5use std::fs;
6use std::path::Path;
7
/// Length in bytes of an AES-256 key (256 bits).
const KEY_SIZE: usize = 256 / 8;
/// Length in bytes of the AES-GCM nonce (96 bits) that `RepoCipher::encrypt` prepends to its output.
const NONCE_SIZE: usize = 96 / 8;
10
11pub struct RepoCipher {
12 cipher: Aes256Gcm,
13}
14
15impl RepoCipher {
16 pub fn generate() -> Self {
17 let key = Aes256Gcm::generate_key(OsRng);
18 Self {
19 cipher: Aes256Gcm::new(&key),
20 }
21 }
22
23 pub fn from_key(key_bytes: &[u8; KEY_SIZE]) -> Self {
24 let key = Key::<Aes256Gcm>::from_slice(key_bytes);
25 Self {
26 cipher: Aes256Gcm::new(key),
27 }
28 }
29
30 pub fn encrypt(&self, plaintext: &[u8]) -> Vec<u8> {
31 let mut nonce_bytes = [0u8; NONCE_SIZE];
32 OsRng.fill_bytes(&mut nonce_bytes);
33 let nonce = Nonce::from_slice(&nonce_bytes);
34 let ciphertext = self
35 .cipher
36 .encrypt(nonce, plaintext)
37 .expect("encryption should never fail with given params");
38 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
39 result.extend_from_slice(&nonce_bytes);
40 result.extend_from_slice(&ciphertext);
41 result
42 }
43
44 pub fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
45 if data.len() < NONCE_SIZE {
46 anyhow::bail!(
47 "encrypted data too short (need {} bytes for nonce)",
48 NONCE_SIZE
49 );
50 }
51 let (nonce_bytes, ciphertext) = data.split_at(NONCE_SIZE);
52 let nonce = Nonce::from_slice(nonce_bytes);
53 let plaintext = self
54 .cipher
55 .decrypt(nonce, ciphertext)
56 .map_err(|_| anyhow::anyhow!("decryption failed — wrong key or corrupted data"))?;
57 Ok(plaintext)
58 }
59
60 pub fn key_bytes(&self) -> [u8; KEY_SIZE] {
61 unimplemented!("use save_repo_key / load_repo_key instead")
64 }
65}
66
67pub fn generate_repo_key() -> [u8; KEY_SIZE] {
68 let mut key = [0u8; KEY_SIZE];
69 OsRng.fill_bytes(&mut key);
70 key
71}
72
73pub fn save_repo_key(keys_dir: &Path, key: &[u8; KEY_SIZE]) -> Result<()> {
74 fs::write(keys_dir.join("repo.key"), hex::encode(key))?;
75 Ok(())
76}
77
78pub fn load_repo_key(keys_dir: &Path) -> Result<[u8; KEY_SIZE]> {
79 let hex_key = fs::read_to_string(keys_dir.join("repo.key"))?;
80 let key_hex = hex_key.trim();
81 let bytes = hex::decode(key_hex)?;
82 if bytes.len() != KEY_SIZE {
83 anyhow::bail!(
84 "invalid repo.key: expected {} bytes, got {}",
85 KEY_SIZE,
86 bytes.len()
87 );
88 }
89 let mut key = [0u8; KEY_SIZE];
90 key.copy_from_slice(&bytes);
91 Ok(key)
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let cipher = RepoCipher::generate();
        let data = b"hello world this is test data";
        let encrypted = cipher.encrypt(data);
        assert_ne!(encrypted, data);
        assert!(encrypted.len() > NONCE_SIZE);
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let cipher = RepoCipher::generate();
        let encrypted = cipher.encrypt(b"");
        // Even an empty message carries the nonce plus an authentication tag.
        assert!(encrypted.len() > NONCE_SIZE);
        assert!(cipher.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_encrypt_different_nonces() {
        let cipher = RepoCipher::generate();
        let data = b"same data";
        let e1 = cipher.encrypt(data);
        let e2 = cipher.encrypt(data);
        assert_ne!(
            &e1[..NONCE_SIZE],
            &e2[..NONCE_SIZE],
            "each encryption should use a fresh random nonce"
        );
        assert_ne!(
            e1, e2,
            "two encryptions of same data should differ (random nonce)"
        );
    }

    #[test]
    fn test_decrypt_wrong_key() {
        let cipher1 = RepoCipher::generate();
        let cipher2 = RepoCipher::generate();
        let data = b"secret message";
        let encrypted = cipher1.encrypt(data);
        let result = cipher2.decrypt(&encrypted);
        assert!(result.is_err(), "decrypt with wrong key should fail");
    }

    #[test]
    fn test_decrypt_tampered() {
        let cipher = RepoCipher::generate();
        let encrypted = cipher.encrypt(b"tamper me");
        // Flip a byte in the nonce, at the start of the ciphertext, and at the very end (tag);
        // authentication must reject every one of them.
        for idx in [0, NONCE_SIZE, encrypted.len() - 1] {
            let mut tampered = encrypted.clone();
            tampered[idx] ^= 0xFF;
            assert!(
                cipher.decrypt(&tampered).is_err(),
                "decrypt of ciphertext tampered at byte {} should fail",
                idx
            );
        }
    }

    #[test]
    fn test_short_data() {
        let cipher = RepoCipher::generate();
        let result = cipher.decrypt(&[0u8; 5]);
        assert!(
            result.is_err(),
            "decrypt of data shorter than nonce should fail"
        );
    }

    #[test]
    fn test_decrypt_truncated() {
        let cipher = RepoCipher::generate();
        let encrypted = cipher.encrypt(b"truncate me");
        // Nonce present but no ciphertext or tag at all.
        assert!(cipher.decrypt(&encrypted[..NONCE_SIZE]).is_err());
        // Final byte missing.
        assert!(cipher.decrypt(&encrypted[..encrypted.len() - 1]).is_err());
    }

    #[test]
    fn test_key_save_load_roundtrip() {
        let dir = tempdir().unwrap();
        let key = generate_repo_key();
        save_repo_key(dir.path(), &key).unwrap();
        let loaded = load_repo_key(dir.path()).unwrap();
        assert_eq!(key, loaded);
    }

    #[test]
    fn test_load_repo_key_trims_whitespace() {
        let dir = tempdir().unwrap();
        let key = generate_repo_key();
        let contents = format!("  {}\n", hex::encode(key));
        std::fs::write(dir.path().join("repo.key"), contents).unwrap();
        assert_eq!(load_repo_key(dir.path()).unwrap(), key);
    }

    #[test]
    fn test_load_repo_key_rejects_wrong_length() {
        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("repo.key"), hex::encode([0u8; 16])).unwrap();
        let err = load_repo_key(dir.path()).unwrap_err();
        assert!(err.to_string().contains("got 16"), "unexpected error: {}", err);
    }

    #[test]
    fn test_load_repo_key_rejects_invalid_hex() {
        let dir = tempdir().unwrap();
        std::fs::write(dir.path().join("repo.key"), "not a hex key").unwrap();
        assert!(load_repo_key(dir.path()).is_err());
    }

    #[test]
    fn test_load_repo_key_missing_file() {
        let dir = tempdir().unwrap();
        assert!(load_repo_key(dir.path()).is_err());
    }

    #[test]
    fn test_from_key_roundtrip() {
        let key = generate_repo_key();
        let cipher = RepoCipher::from_key(&key);
        let data = b"roundtrip via from_key";
        let encrypted = cipher.encrypt(data);
        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted, data);

        // A second cipher built from the same key must decrypt the first one's output.
        let cipher2 = RepoCipher::from_key(&key);
        let decrypted2 = cipher2.decrypt(&encrypted).unwrap();
        assert_eq!(decrypted2, data);
    }
}