1#![warn(missing_docs, clippy::all)]
64
65pub use crate::deserializer::Deserializer;
66use aes_gcm::aes::cipher::InOutBuf;
67use aes_gcm::aes::Aes256;
68use aes_gcm::{AeadInOut, Aes256Gcm, KeyInit, Nonce};
69use aes_gcm::{AesGcm, Tag};
70use base64::display::Base64Display;
71use base64::engine::general_purpose::STANDARD;
72use base64::Engine;
73use rand::{CryptoRng, RngExt};
74use serde::{Deserialize, Serialize};
75use std::error;
76use std::fmt;
77use std::fs;
78use std::io;
79use std::path::Path;
80use std::result;
81use std::str::FromStr;
82use std::string::FromUtf8Error;
83use typenum::U32;
84
/// Prefix that marks the start of a serialized AES key.
const KEY_PREFIX: &str = "AES:";

// Byte lengths are written as bit sizes divided by 8 so the underlying
// AES-GCM parameters are visible at a glance.

/// AES-256 key length in bytes.
const KEY_LEN: usize = 256 / 8;
/// Nonce length in bytes used by the legacy (non-JSON) encrypted format.
const LEGACY_IV_LEN: usize = 256 / 8;
/// Nonce length in bytes for the current JSON format (96-bit GCM nonce).
const IV_LEN: usize = 96 / 8;
/// GCM authentication tag length in bytes.
const TAG_LEN: usize = 128 / 8;
90
/// AES-256-GCM instantiated with a 32-byte (`U32`) nonce.
///
/// Used only to decrypt values in the legacy format, whose IVs are
/// `LEGACY_IV_LEN` bytes long instead of the standard `IV_LEN`.
type LegacyAes256Gcm = AesGcm<Aes256, U32>;
92
// Home of the `Deserializer` type that is re-exported at the crate root.
mod deserializer;
94
95pub type Result<T> = result::Result<T, Error>;
97
98#[derive(Debug)]
99enum ErrorCause {
100 AesGcm(aes_gcm::Error),
101 Io(io::Error),
102 Base64(base64::DecodeError),
103 Utf8(FromUtf8Error),
104 BadPrefix,
105 InvalidLength,
106 KeyExhausted,
107}
108
109#[derive(Debug)]
111pub struct Error(Box<ErrorCause>);
112
113impl fmt::Display for Error {
114 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
115 match *self.0 {
116 ErrorCause::AesGcm(ref e) => fmt::Display::fmt(e, fmt),
117 ErrorCause::Io(ref e) => fmt::Display::fmt(e, fmt),
118 ErrorCause::Base64(ref e) => fmt::Display::fmt(e, fmt),
119 ErrorCause::Utf8(ref e) => fmt::Display::fmt(e, fmt),
120 ErrorCause::BadPrefix => fmt.write_str("invalid key prefix"),
121 ErrorCause::InvalidLength => fmt.write_str("invalid encrypted value component length"),
122 ErrorCause::KeyExhausted => fmt.write_str("key cannot encrypt more than 2^64 values"),
123 }
124 }
125}
126
// `source()` is deliberately left at its default (`None`): the `Display` impl
// already forwards to the wrapped cause's message, so also exposing it as a
// source would likely make chained error reporters print it twice.
impl error::Error for Error {}
128
/// JSON envelope for a value encrypted in the current (non-legacy) format.
///
/// The enum is internally tagged on `"type"`, so an AES value serializes as
/// `{"type":"AES","mode":"GCM","iv":…,"ciphertext":…,"tag":…}`, with the
/// binary fields stored as standard base64 strings.
// Declaration order matters: serde emits fields in this order, which fixes the
// key order of the JSON that `encrypt` produces.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "SCREAMING_SNAKE_CASE")]
enum EncryptedValue {
    /// AES-encrypted value.
    Aes {
        /// Cipher mode used for this value.
        mode: AesMode,
        /// Nonce used for encryption.
        #[serde(with = "serde_base64")]
        iv: Vec<u8>,
        /// Encrypted payload, without the authentication tag.
        #[serde(with = "serde_base64")]
        ciphertext: Vec<u8>,
        /// Detached GCM authentication tag.
        #[serde(with = "serde_base64")]
        tag: Vec<u8>,
    },
}
142
143mod serde_base64 {
144 use base64::engine::general_purpose::STANDARD;
145 use base64::Engine;
146 use serde::de;
147 use serde::{Deserialize, Deserializer, Serialize, Serializer};
148
149 pub fn serialize<S>(buf: &[u8], s: S) -> Result<S::Ok, S::Error>
150 where
151 S: Serializer,
152 {
153 STANDARD.encode(buf).serialize(s)
154 }
155
156 pub fn deserialize<'a, D>(d: D) -> Result<Vec<u8>, D::Error>
157 where
158 D: Deserializer<'a>,
159 {
160 let s = String::deserialize(d)?;
161 STANDARD
162 .decode(&s)
163 .map_err(|_| de::Error::invalid_value(de::Unexpected::Str(&s), &"a base64 string"))
164 }
165}
166
/// Returns the random number generator used to create keys and base IVs.
///
/// The concrete generator is hidden behind `impl CryptoRng`, so callers can
/// only rely on the cryptographic-RNG bound.
fn secure_rng() -> impl CryptoRng {
    rand::rng()
}
171
/// AES cipher mode recorded in an encrypted value.
// Variant names are part of the wire format (`Gcm` serializes as "GCM").
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum AesMode {
    /// Galois/Counter Mode.
    Gcm,
}
177
/// Mode marker for a [`Key`] that can only decrypt values.
///
/// Read-only keys are produced by parsing a key string or loading a key file.
pub struct ReadOnly(());
180
181pub struct ReadWrite {
183 iv: [u8; IV_LEN],
190 counter: u64,
191}
192
193pub struct Key<T> {
204 key: [u8; KEY_LEN],
205 mode: T,
206}
207
208impl Key<ReadWrite> {
209 pub fn random_aes() -> Result<Key<ReadWrite>> {
211 Ok(Key {
212 key: secure_rng().random(),
213 mode: ReadWrite {
214 iv: secure_rng().random(),
215 counter: 0,
216 },
217 })
218 }
219
220 pub fn encrypt(&mut self, value: &str) -> Result<String> {
222 let counter = self.mode.counter;
223 self.mode.counter = match self.mode.counter.checked_add(1) {
224 Some(v) => v,
225 None => return Err(Error(Box::new(ErrorCause::KeyExhausted))),
226 };
227
228 let mut iv = Nonce::from(self.mode.iv);
229 for (i, byte) in counter.to_le_bytes().iter().enumerate() {
230 iv[i] ^= *byte;
231 }
232
233 let mut ciphertext = value.as_bytes().to_vec();
234 let tag = Aes256Gcm::new(&self.key.into())
235 .encrypt_inout_detached(&iv, &[], InOutBuf::from(&mut *ciphertext))
236 .map_err(|e| Error(Box::new(ErrorCause::AesGcm(e))))?;
237
238 let value = EncryptedValue::Aes {
239 mode: AesMode::Gcm,
240 iv: iv.to_vec(),
241 ciphertext,
242 tag: tag.to_vec(),
243 };
244
245 let value = serde_json::to_string(&value).unwrap();
246 Ok(STANDARD.encode(value.as_bytes()))
247 }
248}
249
250impl Key<ReadOnly> {
251 pub fn from_file<P>(path: P) -> Result<Option<Key<ReadOnly>>>
256 where
257 P: AsRef<Path>,
258 {
259 let s = match fs::read_to_string(path) {
260 Ok(s) => s,
261 Err(ref e) if e.kind() == io::ErrorKind::NotFound => return Ok(None),
262 Err(e) => return Err(Error(Box::new(ErrorCause::Io(e)))),
263 };
264 s.parse().map(Some)
265 }
266}
267
impl<T> Key<T> {
    /// Decrypts a value produced by `Key::encrypt` (or by the legacy format)
    /// and returns the plaintext.
    ///
    /// `value` is first decoded as standard base64. If the decoded bytes parse
    /// as the JSON envelope, its 12-byte IV and 16-byte tag are used with
    /// AES-256-GCM. Otherwise the bytes are treated as the legacy layout: a
    /// 32-byte IV, then the ciphertext, then a 16-byte tag.
    ///
    /// # Errors
    ///
    /// Returns an error if `value` is not valid base64, if the IV, tag, or
    /// legacy payload has the wrong length, if authentication/decryption fails,
    /// or if the plaintext is not valid UTF-8.
    pub fn decrypt(&self, value: &str) -> Result<String> {
        let value = STANDARD
            .decode(value)
            .map_err(|e| Error(Box::new(ErrorCause::Base64(e))))?;

        // Try the current JSON format first; anything that does not parse as
        // JSON falls back to the legacy raw-bytes layout.
        let (iv, mut ct, tag) = match serde_json::from_slice(&value) {
            Ok(EncryptedValue::Aes {
                mode: AesMode::Gcm,
                iv,
                ciphertext,
                tag,
            }) => {
                if iv.len() != IV_LEN || tag.len() != TAG_LEN {
                    return Err(Error(Box::new(ErrorCause::InvalidLength)));
                }

                let mut iv_arr = [0; IV_LEN];
                iv_arr.copy_from_slice(&iv);

                let mut tag_arr = [0; TAG_LEN];
                tag_arr.copy_from_slice(&tag);

                (Iv::Standard(iv_arr), ciphertext, tag_arr)
            }
            Err(_) => {
                // Legacy layout: IV || ciphertext || tag. The ciphertext may be
                // empty, so only the IV and tag lengths are required.
                if value.len() < LEGACY_IV_LEN + TAG_LEN {
                    return Err(Error(Box::new(ErrorCause::InvalidLength)));
                }

                let mut iv = [0; LEGACY_IV_LEN];
                iv.copy_from_slice(&value[..LEGACY_IV_LEN]);

                let ct = value[LEGACY_IV_LEN..value.len() - TAG_LEN].to_vec();

                let mut tag = [0; TAG_LEN];
                tag.copy_from_slice(&value[value.len() - TAG_LEN..]);

                (Iv::Legacy(iv), ct, tag)
            }
        };

        let tag = Tag::from(tag);

        // Each IV length needs its own cipher instantiation: the 32-byte legacy
        // nonce uses LegacyAes256Gcm, the 12-byte nonce uses Aes256Gcm. Both
        // decrypt `ct` in place and verify the detached tag.
        match iv {
            Iv::Legacy(iv) => {
                let iv = Nonce::from(iv);

                LegacyAes256Gcm::new(&self.key.into())
                    .decrypt_inout_detached(&iv, &[], InOutBuf::from(&mut *ct), &tag)
                    .map_err(|e| Error(Box::new(ErrorCause::AesGcm(e))))?;
            }
            Iv::Standard(iv) => {
                let iv = Nonce::from(iv);

                Aes256Gcm::new(&self.key.into())
                    .decrypt_inout_detached(&iv, &[], InOutBuf::from(&mut *ct), &tag)
                    .map_err(|e| Error(Box::new(ErrorCause::AesGcm(e))))?;
            }
        };

        // `ct` now holds the plaintext bytes.
        let pt = String::from_utf8(ct).map_err(|e| Error(Box::new(ErrorCause::Utf8(e))))?;

        Ok(pt)
    }
}
335
336impl<T> fmt::Display for Key<T> {
337 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
338 write!(fmt, "AES:{}", Base64Display::new(&self.key, &STANDARD))
339 }
340}
341
342impl FromStr for Key<ReadOnly> {
343 type Err = Error;
344
345 fn from_str(s: &str) -> Result<Key<ReadOnly>> {
346 if !s.starts_with(KEY_PREFIX) {
347 return Err(Error(Box::new(ErrorCause::BadPrefix)));
348 }
349
350 let key = STANDARD
351 .decode(&s[KEY_PREFIX.len()..])
352 .map_err(|e| Error(Box::new(ErrorCause::Base64(e))))?;
353
354 if key.len() != KEY_LEN {
355 return Err(Error(Box::new(ErrorCause::InvalidLength)));
356 }
357
358 let mut key_arr = [0; KEY_LEN];
359 key_arr.copy_from_slice(&key);
360
361 Ok(Key {
362 key: key_arr,
363 mode: ReadOnly(()),
364 })
365 }
366}
367
368enum Iv {
369 Legacy([u8; LEGACY_IV_LEN]),
370 Standard([u8; IV_LEN]),
371}
372
#[cfg(test)]
mod test {
    use serde::Deserialize;
    use tempfile::tempdir;

    use super::*;

    const KEY: &str = "AES:NwQZdNWsFmYMCNSQlfYPDJtFBgPzY8uZlFhMCLnxNQE=";

    /// A key file containing a valid key loads successfully.
    #[test]
    fn from_file_aes() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("encrypted-config-value.key");
        std::fs::write(&path, KEY).unwrap();

        assert!(Key::from_file(&path).unwrap().is_some());
    }

    /// A key file that does not exist yields `Ok(None)` rather than an error.
    #[test]
    fn from_file_missing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("encrypted-config-value.key");

        assert!(Key::from_file(path).unwrap().is_none());
    }

    /// Values in the legacy raw-bytes format still decrypt.
    #[test]
    fn decrypt_legacy() {
        let ct = "5BBfGvf90H6bApwfxUjNdoKRW1W+GZCbhBuBpzEogVBmQZyWFFxcKyf+UPV5FOhrw/wrVZyoL3npoDfYj\
                  PQV/zg0W/P9cVOw";
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";

        let key = KEY.parse::<Key<ReadOnly>>().unwrap();
        let actual = key.decrypt(ct).unwrap();
        assert_eq!(actual, pt);
    }

    /// Values in the JSON envelope format decrypt.
    #[test]
    fn decrypt() {
        let ct = "eyJ0eXBlIjoiQUVTIiwibW9kZSI6IkdDTSIsIml2IjoiUCtRQXM5aHo4VFJVOUpNLyIsImNpcGhlcnRle\
                  HQiOiJmUGpDaDVuMkR0cklPSVNXSklLcVQzSUtRNUtONVI3LyIsInRhZyI6ImlJRFIzYUtER1UyK1Brej\
                  NPSEdSL0E9PSJ9";
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";

        let key = KEY.parse::<Key<ReadOnly>>().unwrap();
        let actual = key.decrypt(ct).unwrap();
        assert_eq!(actual, pt);
    }

    /// Encrypting and then decrypting with the same key round-trips.
    #[test]
    fn encrypt_decrypt() {
        let mut key = Key::random_aes().unwrap();
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";
        let ct = key.encrypt(pt).unwrap();
        let actual = key.decrypt(&ct).unwrap();
        assert_eq!(pt, actual);
    }

    /// Encrypting the same plaintext twice yields different ciphertexts
    /// because each call uses a fresh nonce.
    #[test]
    fn unique_ivs() {
        let mut key = Key::random_aes().unwrap();
        let pt = "L/TqOWz7E4z0SoeiTYBrqbqu";
        let ct1 = key.encrypt(pt).unwrap();
        let ct2 = key.encrypt(pt).unwrap();
        assert_ne!(ct1, ct2);
    }

    /// The wrapping deserializer decrypts `${enc:…}` strings and leaves other
    /// `${…}` strings untouched.
    #[test]
    fn deserializer() {
        #[derive(Deserialize, PartialEq, Debug)]
        struct Config {
            sub: Subconfig,
        }

        #[derive(Deserialize, PartialEq, Debug)]
        struct Subconfig {
            encrypted: Vec<String>,
            plaintext: String,
        }

        let config = r#"
{
    "sub": {
        "encrypted": [
            "${enc:5BBfGvf90H6bApwfxUjNdoKRW1W+GZCbhBuBpzEogVBmQZyWFFxcKyf+UPV5FOhrw/wrVZyoL3npoDfYjPQV/zg0W/P9cVOw}"
        ],
        "plaintext": "${foobar}"
    }
}
        "#;

        let key = KEY.parse().unwrap();
        let mut deserializer = serde_json::Deserializer::from_str(config);
        let deserializer = Deserializer::new(&mut deserializer, Some(&key));

        let config = Config::deserialize(deserializer).unwrap();

        let expected = Config {
            sub: Subconfig {
                encrypted: vec!["L/TqOWz7E4z0SoeiTYBrqbqu".to_string()],
                plaintext: "${foobar}".to_string(),
            },
        };

        assert_eq!(config, expected);
    }
}