1mod aes_gcm;
6mod xchacha20;
7mod key;
8
9pub use aes_gcm::{encrypt_aes_gcm, decrypt_aes_gcm};
10pub use xchacha20::{encrypt_xchacha20, decrypt_xchacha20};
11pub use key::{generate_key, derive_key_hkdf, derive_key_pbkdf2, Key};
12
13use crate::{Error, Result, MAGIC_ENCRYPTED, FORMAT_VERSION};
14use alloc::vec::Vec;
15use serde::{Deserialize, Serialize};
16use zeroize::Zeroize;
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[repr(u8)]
21pub enum Algorithm {
22 Aes256Gcm = 0x01,
24 XChaCha20Poly1305 = 0x02,
26}
27
28impl Algorithm {
29 pub fn from_byte(byte: u8) -> Result<Self> {
31 match byte {
32 0x01 => Ok(Algorithm::Aes256Gcm),
33 0x02 => Ok(Algorithm::XChaCha20Poly1305),
34 _ => Err(Error::UnsupportedAlgorithm(byte)),
35 }
36 }
37
38 pub fn nonce_len(&self) -> usize {
40 match self {
41 Algorithm::Aes256Gcm => 12,
42 Algorithm::XChaCha20Poly1305 => 24,
43 }
44 }
45
46 pub fn name(&self) -> &'static str {
48 match self {
49 Algorithm::Aes256Gcm => "aes-256-gcm",
50 Algorithm::XChaCha20Poly1305 => "xchacha20-poly1305",
51 }
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct EncryptionResult {
58 pub ciphertext: Vec<u8>,
60 pub algorithm: Algorithm,
62 pub nonce: Vec<u8>,
64 pub tag: Vec<u8>,
66}
67
68impl EncryptionResult {
69 pub fn to_bytes(&self) -> Vec<u8> {
71 let nonce_len = self.nonce.len();
72 let total_len = 8 + nonce_len + self.ciphertext.len() + self.tag.len();
73 let mut buf = Vec::with_capacity(total_len);
74
75 buf.extend_from_slice(MAGIC_ENCRYPTED);
77 buf.push(FORMAT_VERSION);
79 buf.push(self.algorithm as u8);
81 buf.push(nonce_len as u8);
83 buf.push(0x00);
85 buf.extend_from_slice(&self.nonce);
87 buf.extend_from_slice(&self.ciphertext);
89 buf.extend_from_slice(&self.tag);
91
92 buf
93 }
94
95 pub fn from_bytes(data: &[u8]) -> Result<Self> {
97 if data.len() < 8 {
98 return Err(Error::TruncatedPayload {
99 expected: 8,
100 actual: data.len(),
101 });
102 }
103
104 if &data[0..4] != MAGIC_ENCRYPTED {
106 return Err(Error::InvalidFormat);
107 }
108
109 let version = data[4];
111 if version != FORMAT_VERSION {
112 return Err(Error::UnsupportedVersion(version));
113 }
114
115 let algorithm = Algorithm::from_byte(data[5])?;
117
118 let nonce_len = data[6] as usize;
120 if nonce_len != algorithm.nonce_len() {
121 return Err(Error::InvalidNonceLength {
122 expected: algorithm.nonce_len(),
123 actual: nonce_len,
124 });
125 }
126
127 let min_size = 8 + nonce_len + 16; if data.len() < min_size {
130 return Err(Error::TruncatedPayload {
131 expected: min_size,
132 actual: data.len(),
133 });
134 }
135
136 let nonce = data[8..8 + nonce_len].to_vec();
138
139 let tag = data[data.len() - 16..].to_vec();
141
142 let ciphertext = data[8 + nonce_len..data.len() - 16].to_vec();
144
145 Ok(EncryptionResult {
146 ciphertext,
147 algorithm,
148 nonce,
149 tag,
150 })
151 }
152
153 pub fn to_json(&self) -> Result<String> {
155 #[derive(Serialize)]
156 struct JsonFormat<'a> {
157 v: &'static str,
158 alg: &'a str,
159 nonce: String,
160 ct: String,
161 tag: String,
162 }
163
164 use base64::{Engine, engine::general_purpose::STANDARD};
165
166 let json = JsonFormat {
167 v: "1.0",
168 alg: self.algorithm.name(),
169 nonce: STANDARD.encode(&self.nonce),
170 ct: STANDARD.encode(&self.ciphertext),
171 tag: STANDARD.encode(&self.tag),
172 };
173
174 serde_json::to_string(&json).map_err(|e| Error::SerializationError(e.to_string()))
175 }
176
177 pub fn from_json(json: &str) -> Result<Self> {
179 #[derive(Deserialize)]
180 struct JsonFormat {
181 v: String,
182 alg: String,
183 nonce: String,
184 ct: String,
185 tag: String,
186 }
187
188 let parsed: JsonFormat = serde_json::from_str(json)?;
189
190 if parsed.v != "1.0" {
191 return Err(Error::UnsupportedVersion(0));
192 }
193
194 let algorithm = match parsed.alg.as_str() {
195 "aes-256-gcm" => Algorithm::Aes256Gcm,
196 "xchacha20-poly1305" => Algorithm::XChaCha20Poly1305,
197 _ => return Err(Error::UnsupportedAlgorithm(0)),
198 };
199
200 use base64::{Engine, engine::general_purpose::STANDARD};
201
202 Ok(EncryptionResult {
203 ciphertext: STANDARD.decode(&parsed.ct)?,
204 algorithm,
205 nonce: STANDARD.decode(&parsed.nonce)?,
206 tag: STANDARD.decode(&parsed.tag)?,
207 })
208 }
209}
210
211impl Drop for EncryptionResult {
212 fn drop(&mut self) {
213 self.ciphertext.zeroize();
214 self.nonce.zeroize();
215 self.tag.zeroize();
216 }
217}
218
219#[derive(Debug, Clone, Default)]
221pub struct EncryptOptions {
222 pub algorithm: Option<Algorithm>,
224 pub aad: Option<Vec<u8>>,
226}
227
228pub fn encrypt(plaintext: &[u8], key: &Key, options: Option<EncryptOptions>) -> Result<EncryptionResult> {
242 let opts = options.unwrap_or_default();
243 let algorithm = opts.algorithm.unwrap_or(Algorithm::Aes256Gcm);
244 let aad = opts.aad.as_deref().unwrap_or(&[]);
245
246 match algorithm {
247 Algorithm::Aes256Gcm => encrypt_aes_gcm(plaintext, key, aad),
248 Algorithm::XChaCha20Poly1305 => encrypt_xchacha20(plaintext, key, aad),
249 }
250}
251
252pub fn decrypt(encrypted: &EncryptionResult, key: &Key) -> Result<Vec<u8>> {
265 decrypt_with_aad(encrypted, key, &[])
266}
267
268pub fn decrypt_with_aad(encrypted: &EncryptionResult, key: &Key, aad: &[u8]) -> Result<Vec<u8>> {
280 match encrypted.algorithm {
281 Algorithm::Aes256Gcm => decrypt_aes_gcm(encrypted, key, aad),
282 Algorithm::XChaCha20Poly1305 => decrypt_xchacha20(encrypted, key, aad),
283 }
284}
285
286pub fn decrypt_bytes(data: &[u8], key: &Key) -> Result<Vec<u8>> {
288 let encrypted = EncryptionResult::from_bytes(data)?;
289 decrypt(&encrypted, key)
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Every algorithm exercised by the parameterized tests below.
    const ALGORITHMS: [Algorithm; 2] = [Algorithm::Aes256Gcm, Algorithm::XChaCha20Poly1305];

    /// Builds encryption options selecting `algorithm` and, optionally, some AAD.
    fn opts_for(algorithm: Algorithm, aad: Option<&[u8]>) -> Option<EncryptOptions> {
        Some(EncryptOptions {
            algorithm: Some(algorithm),
            aad: aad.map(|bytes| bytes.to_vec()),
        })
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let key = generate_key();
        let plaintext = b"Hello, World!";

        let encrypted = encrypt(plaintext, &key, None).unwrap();
        let decrypted = decrypt(&encrypted, &key).unwrap();

        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn test_default_algorithm_is_aes_gcm() {
        let key = generate_key();
        let encrypted = encrypt(b"defaults", &key, None).unwrap();

        assert_eq!(encrypted.algorithm, Algorithm::Aes256Gcm);
        assert_eq!(encrypted.nonce.len(), Algorithm::Aes256Gcm.nonce_len());
    }

    #[test]
    fn test_roundtrip_every_algorithm() {
        for algorithm in ALGORITHMS {
            let key = generate_key();
            let plaintext = b"round trip for each algorithm";

            let encrypted = encrypt(plaintext, &key, opts_for(algorithm, None)).unwrap();
            assert_eq!(encrypted.algorithm, algorithm);
            assert_eq!(encrypted.nonce.len(), algorithm.nonce_len());

            let decrypted = decrypt(&encrypted, &key).unwrap();
            assert_eq!(plaintext, &decrypted[..]);
        }
    }

    #[test]
    fn test_aad_must_match() {
        for algorithm in ALGORITHMS {
            let key = generate_key();
            let plaintext = b"authenticated payload";
            let aad: &[u8] = b"header";

            let encrypted = encrypt(plaintext, &key, opts_for(algorithm, Some(aad))).unwrap();

            let decrypted = decrypt_with_aad(&encrypted, &key, aad).unwrap();
            assert_eq!(plaintext, &decrypted[..]);

            assert!(decrypt_with_aad(&encrypted, &key, b"other").is_err());
            assert!(decrypt(&encrypted, &key).is_err());
        }
    }

    #[test]
    fn test_wrong_key_fails() {
        for algorithm in ALGORITHMS {
            let key = generate_key();
            let other_key = generate_key();

            let encrypted = encrypt(b"secret", &key, opts_for(algorithm, None)).unwrap();
            assert!(decrypt(&encrypted, &other_key).is_err());
        }
    }

    #[test]
    fn test_tampering_is_detected() {
        for algorithm in ALGORITHMS {
            let key = generate_key();

            let mut bad_ciphertext =
                encrypt(b"integrity matters", &key, opts_for(algorithm, None)).unwrap();
            bad_ciphertext.ciphertext[0] ^= 0x01;
            assert!(decrypt(&bad_ciphertext, &key).is_err());

            let mut bad_tag =
                encrypt(b"integrity matters", &key, opts_for(algorithm, None)).unwrap();
            bad_tag.tag[0] ^= 0x01;
            assert!(decrypt(&bad_tag, &key).is_err());
        }
    }

    #[test]
    fn test_binary_serialization() {
        let key = generate_key();
        let plaintext = b"Test data for serialization";

        let encrypted = encrypt(plaintext, &key, None).unwrap();
        let bytes = encrypted.to_bytes();
        let restored = EncryptionResult::from_bytes(&bytes).unwrap();

        assert_eq!(encrypted.algorithm, restored.algorithm);
        assert_eq!(encrypted.nonce, restored.nonce);
        assert_eq!(encrypted.ciphertext, restored.ciphertext);
        assert_eq!(encrypted.tag, restored.tag);
    }

    #[test]
    fn test_decrypt_bytes_roundtrip() {
        for algorithm in ALGORITHMS {
            let key = generate_key();
            let plaintext = b"binary envelope";

            let bytes = encrypt(plaintext, &key, opts_for(algorithm, None))
                .unwrap()
                .to_bytes();
            let decrypted = decrypt_bytes(&bytes, &key).unwrap();

            assert_eq!(plaintext, &decrypted[..]);
        }
    }

    #[test]
    fn test_from_bytes_rejects_malformed_input() {
        assert!(EncryptionResult::from_bytes(&[]).is_err());

        let key = generate_key();
        let bytes = encrypt(b"payload", &key, None).unwrap().to_bytes();

        let mut bad_magic = bytes.clone();
        bad_magic[0] ^= 0xFF;
        assert!(EncryptionResult::from_bytes(&bad_magic).is_err());

        let mut bad_algorithm = bytes.clone();
        bad_algorithm[5] = 0xFF;
        assert!(EncryptionResult::from_bytes(&bad_algorithm).is_err());

        let mut bad_nonce_len = bytes.clone();
        bad_nonce_len[6] = 0x00;
        assert!(EncryptionResult::from_bytes(&bad_nonce_len).is_err());

        // Header (8) + AES-GCM nonce (12) but no room for the 16-byte tag.
        assert!(EncryptionResult::from_bytes(&bytes[..20]).is_err());
    }

    #[test]
    fn test_json_serialization() {
        for algorithm in ALGORITHMS {
            let key = generate_key();
            let plaintext = b"Test data for JSON";

            let encrypted = encrypt(plaintext, &key, opts_for(algorithm, None)).unwrap();
            let json = encrypted.to_json().unwrap();
            let restored = EncryptionResult::from_json(&json).unwrap();

            assert_eq!(encrypted.algorithm, restored.algorithm);
            assert_eq!(encrypted.nonce, restored.nonce);
            assert_eq!(encrypted.ciphertext, restored.ciphertext);
            assert_eq!(encrypted.tag, restored.tag);

            let decrypted = decrypt(&restored, &key).unwrap();
            assert_eq!(plaintext, &decrypted[..]);
        }
    }

    #[test]
    fn test_from_json_rejects_malformed_input() {
        assert!(EncryptionResult::from_json("not json").is_err());
        assert!(EncryptionResult::from_json(
            r#"{"v":"2.0","alg":"aes-256-gcm","nonce":"","ct":"","tag":""}"#
        )
        .is_err());
        assert!(EncryptionResult::from_json(
            r#"{"v":"1.0","alg":"rot13","nonce":"","ct":"","tag":""}"#
        )
        .is_err());
    }

    #[test]
    fn test_algorithm_from_byte() {
        assert_eq!(Algorithm::from_byte(0x01).unwrap(), Algorithm::Aes256Gcm);
        assert_eq!(Algorithm::from_byte(0x02).unwrap(), Algorithm::XChaCha20Poly1305);
        assert!(Algorithm::from_byte(0x00).is_err());
        assert!(Algorithm::from_byte(0xFF).is_err());
    }
}
337