1use std::sync::OnceLock;
43
44use base64::Engine as _;
45use crypto_box::{
46 SalsaBox,
47 aead::{Aead, AeadCore},
48};
49use rand_core::OsRng;
50use serde::{Deserialize, Deserializer, Serialize, Serializer};
51
/// Placeholder written by `Display` for a masked value; `Deserialize` maps an
/// incoming string equal to this marker to an empty plaintext.
const REDACTED: &str = "••••••";
/// Base64 engine (the `base64` crate's `STANDARD` engine) used for both keys
/// and sealed ciphertext.
const B64: base64::engine::general_purpose::GeneralPurpose =
    base64::engine::general_purpose::STANDARD;
55
/// Errors produced while loading mask keys or sealing/opening masked values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MaskError {
    /// No keyring is available: none was installed with `set_mask_keyring` and
    /// no public key was found in `UMBRAL_MASK_PUBLIC_KEY`.
    NoKeyring,
    /// The keyring holds only a public key, so values can be sealed but not
    /// opened.
    NoPrivateKey,
    /// A supplied key could not be used; the payload says why (e.g. invalid
    /// base64 or a decoded length other than 32 bytes).
    BadKey(String),
    /// The ciphertext is not valid base64, or is too short to contain the
    /// ephemeral public key and nonce.
    Malformed,
    /// Authenticated decryption failed (e.g. wrong key or altered ciphertext),
    /// or the decrypted bytes were not valid UTF-8.
    Decrypt,
}
72
73impl std::fmt::Display for MaskError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 match self {
76 MaskError::NoKeyring => f.write_str(
77 "no mask keyring configured (set UMBRAL_MASK_PUBLIC_KEY or call set_mask_keyring)",
78 ),
79 MaskError::NoPrivateKey => f.write_str(
80 "mask keyring has no private key; cannot reveal (set UMBRAL_MASK_PRIVATE_KEY)",
81 ),
82 MaskError::BadKey(why) => write!(f, "invalid mask key: {why}"),
83 MaskError::Malformed => f.write_str("masked ciphertext is malformed"),
84 MaskError::Decrypt => f.write_str("masked ciphertext failed to decrypt"),
85 }
86 }
87}
88
89impl std::error::Error for MaskError {}
90
91#[derive(Clone)]
98pub struct MaskKeyring {
99 public: crypto_box::PublicKey,
100 secret: Option<crypto_box::SecretKey>,
101}
102
103impl MaskKeyring {
104 pub fn from_base64(public_b64: &str, secret_b64: Option<&str>) -> Result<Self, MaskError> {
107 let public = decode_key(public_b64)?;
108 let public = crypto_box::PublicKey::from(public);
109 let secret = match secret_b64 {
110 Some(s) if !s.is_empty() => Some(crypto_box::SecretKey::from(decode_key(s)?)),
111 _ => None,
112 };
113 Ok(Self { public, secret })
114 }
115
116 pub fn from_env() -> Result<Self, MaskError> {
120 let public = std::env::var("UMBRAL_MASK_PUBLIC_KEY").map_err(|_| MaskError::NoKeyring)?;
121 let secret = std::env::var("UMBRAL_MASK_PRIVATE_KEY").ok();
122 Self::from_base64(&public, secret.as_deref())
123 }
124
125 pub fn generate() -> (String, String) {
128 let secret = crypto_box::SecretKey::generate(&mut OsRng);
129 let public = secret.public_key();
130 (B64.encode(public.as_bytes()), B64.encode(secret.to_bytes()))
131 }
132
133 pub fn seal(&self, plaintext: &[u8]) -> String {
136 let eph_secret = crypto_box::SecretKey::generate(&mut OsRng);
140 let eph_public = eph_secret.public_key();
141 let salsa = SalsaBox::new(&self.public, &eph_secret);
142 let nonce = SalsaBox::generate_nonce(&mut OsRng);
143 let ciphertext = salsa
144 .encrypt(&nonce, plaintext)
145 .expect("XSalsa20-Poly1305 encryption is infallible for in-memory plaintext");
146 let mut out = Vec::with_capacity(32 + nonce.len() + ciphertext.len());
147 out.extend_from_slice(eph_public.as_bytes());
148 out.extend_from_slice(nonce.as_slice());
149 out.extend_from_slice(&ciphertext);
150 B64.encode(out)
151 }
152
153 pub fn open(&self, b64_ciphertext: &str) -> Result<String, MaskError> {
156 let secret = self.secret.as_ref().ok_or(MaskError::NoPrivateKey)?;
157 let sealed = B64
158 .decode(b64_ciphertext)
159 .map_err(|_| MaskError::Malformed)?;
160 if sealed.len() < 32 + 24 {
161 return Err(MaskError::Malformed);
162 }
163 let eph_public: [u8; 32] = sealed[..32].try_into().map_err(|_| MaskError::Malformed)?;
164 let eph_public = crypto_box::PublicKey::from(eph_public);
165 let nonce = crypto_box::Nonce::from_slice(&sealed[32..56]);
166 let ciphertext = &sealed[56..];
167 let salsa = SalsaBox::new(&eph_public, secret);
168 let plaintext = salsa
169 .decrypt(nonce, ciphertext)
170 .map_err(|_| MaskError::Decrypt)?;
171 String::from_utf8(plaintext).map_err(|_| MaskError::Decrypt)
172 }
173
174 pub fn can_reveal(&self) -> bool {
176 self.secret.is_some()
177 }
178}
179
180fn decode_key(b64: &str) -> Result<[u8; 32], MaskError> {
181 let bytes = B64
182 .decode(b64.trim())
183 .map_err(|e| MaskError::BadKey(e.to_string()))?;
184 bytes
185 .try_into()
186 .map_err(|_| MaskError::BadKey("key is not 32 bytes".to_string()))
187}
188
189static KEYRING: OnceLock<Result<Option<MaskKeyring>, MaskError>> = OnceLock::new();
196
197pub fn set_mask_keyring(keyring: MaskKeyring) -> bool {
201 KEYRING.set(Ok(Some(keyring))).is_ok()
202}
203
204fn keyring() -> Result<Option<&'static MaskKeyring>, &'static MaskError> {
214 KEYRING
215 .get_or_init(|| match MaskKeyring::from_env() {
216 Ok(k) => Ok(Some(k)),
217 Err(MaskError::NoKeyring) => Ok(None),
220 Err(e) => {
225 tracing::error!(
226 "UMBRAL_MASK_PUBLIC_KEY/UMBRAL_MASK_PRIVATE_KEY is set but could not be \
227 parsed ({e}); all Masked<T> seal/reveal calls will fail with BadKey. \
228 Fix the key or unset the variable."
229 );
230 Err(e)
231 }
232 })
233 .as_ref()
234 .map(|opt| opt.as_ref())
235 .map_err(|e| e)
236}
237
238fn ambient_seal(plaintext: &str) -> Result<String, MaskError> {
240 match keyring() {
241 Ok(Some(k)) => Ok(k.seal(plaintext.as_bytes())),
242 Ok(None) => Err(MaskError::NoKeyring),
243 Err(e) => Err(e.clone()),
244 }
245}
246
247fn ambient_open(ciphertext: &str) -> Result<String, MaskError> {
249 match keyring() {
250 Ok(Some(k)) => k.open(ciphertext),
251 Ok(None) => Err(MaskError::NoKeyring),
252 Err(e) => Err(e.clone()),
253 }
254}
255
256#[derive(Clone)]
268pub struct Masked<T = String> {
269 inner: MaskInner,
270 _marker: std::marker::PhantomData<T>,
271}
272
/// Storage state of a [`Masked`] value.
#[derive(Clone)]
enum MaskInner {
    /// Plaintext held in memory (built by `Masked::new`). It is sealed only
    /// when converted to its stored form.
    Plain(String),
    /// Base64 ciphertext exactly as read from the database (built by the sqlx
    /// `Decode` impl). It is decrypted only on `reveal`.
    Sealed(String),
}
282
283impl<T> Masked<T> {
284 pub fn new(plaintext: impl Into<String>) -> Self {
287 Self {
288 inner: MaskInner::Plain(plaintext.into()),
289 _marker: std::marker::PhantomData,
290 }
291 }
292
293 pub fn reveal(&self) -> Result<String, MaskError> {
297 match &self.inner {
298 MaskInner::Plain(p) => Ok(p.clone()),
299 MaskInner::Sealed(c) => ambient_open(c),
300 }
301 }
302
303 pub fn is_revealable(&self) -> bool {
306 match &self.inner {
307 MaskInner::Plain(_) => true,
308 MaskInner::Sealed(_) => keyring()
309 .ok()
310 .and_then(|opt| opt)
311 .map(MaskKeyring::can_reveal)
312 .unwrap_or(false),
313 }
314 }
315
316 fn to_stored(&self) -> Result<String, MaskError> {
320 match &self.inner {
321 MaskInner::Plain(p) => ambient_seal(p),
322 MaskInner::Sealed(c) => Ok(c.clone()),
323 }
324 }
325}
326
327impl<T> Default for Masked<T> {
328 fn default() -> Self {
331 Masked::new(String::new())
332 }
333}
334
335impl<T> std::fmt::Debug for Masked<T> {
336 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
337 f.write_str("Masked(••••••)")
338 }
339}
340
341impl<T> std::fmt::Display for Masked<T> {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 f.write_str(REDACTED)
344 }
345}
346
347impl<T> From<String> for Masked<T> {
348 fn from(plaintext: String) -> Self {
349 Masked::new(plaintext)
350 }
351}
352
353impl<T> From<&str> for Masked<T> {
354 fn from(plaintext: &str) -> Self {
355 Masked::new(plaintext)
356 }
357}
358
359impl<T> Serialize for Masked<T> {
362 fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
375 let stored = self.to_stored().map_err(serde::ser::Error::custom)?;
376 s.serialize_str(&stored)
377 }
378}
379
380impl<'de, T> Deserialize<'de> for Masked<T> {
381 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
386 let s = String::deserialize(d)?;
387 if s == REDACTED {
388 Ok(Masked::new(String::new()))
390 } else {
391 Ok(Masked::new(s))
392 }
393 }
394}
395
/// Implements `sqlx::Type`, `sqlx::Decode` and `sqlx::Encode` for `Masked<T>` on
/// one database backend, delegating to that backend's `String` implementations.
///
/// Arguments:
/// - `$db`: the `sqlx::Database` type.
/// - `$valueref`: the backend's value-reference type for `Decode`, written with
///   the `'r` lifetime of the generated `Decode` impl.
/// - `$argbuf`: the backend's argument-buffer type for `Encode`, written with
///   the `'q` lifetime of the generated `Encode` impl.
///
/// A masked column therefore has the same SQL type as a `String` column. The
/// text stored in it is the sealed (base64 ciphertext) form.
macro_rules! impl_masked_sqlx {
    ($db:ty, $valueref:ty, $argbuf:ty) => {
        impl<T> sqlx::Type<$db> for Masked<T> {
            fn type_info() -> <$db as sqlx::Database>::TypeInfo {
                <String as sqlx::Type<$db>>::type_info()
            }
            fn compatible(ty: &<$db as sqlx::Database>::TypeInfo) -> bool {
                <String as sqlx::Type<$db>>::compatible(ty)
            }
        }

        impl<'r, T> sqlx::Decode<'r, $db> for Masked<T> {
            // Rows are kept sealed on load; no keyring is needed until `reveal()`.
            fn decode(value: $valueref) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
                let ciphertext = <String as sqlx::Decode<$db>>::decode(value)?;
                Ok(Masked {
                    inner: MaskInner::Sealed(ciphertext),
                    _marker: std::marker::PhantomData,
                })
            }
        }

        impl<'q, T> sqlx::Encode<'q, $db> for Masked<T> {
            // Seals in-memory plaintext via the ambient keyring and passes
            // already-sealed values through unchanged. A `MaskError` (e.g.
            // NoKeyring) is boxed by `?` and surfaces as an encode error.
            fn encode_by_ref(
                &self,
                buf: &mut $argbuf,
            ) -> Result<sqlx::encode::IsNull, Box<dyn std::error::Error + Send + Sync>> {
                let stored = self.to_stored()?;
                <String as sqlx::Encode<'q, $db>>::encode_by_ref(&stored, buf)
            }
        }
    };
}

// Backends supported for masked columns.
impl_masked_sqlx!(
    sqlx::Sqlite,
    sqlx::sqlite::SqliteValueRef<'r>,
    <sqlx::Sqlite as sqlx::Database>::ArgumentBuffer<'q>
);
impl_masked_sqlx!(
    sqlx::Postgres,
    sqlx::postgres::PgValueRef<'r>,
    <sqlx::Postgres as sqlx::Database>::ArgumentBuffer<'q>
);
441
#[cfg(test)]
mod tests {
    use super::*;
    // Needed for `B64.encode` / `B64.decode` in the tests below.
    use base64::Engine as _;

    /// A fresh keyring holding both halves of a newly generated keypair.
    fn test_keyring() -> MaskKeyring {
        let (public, secret) = MaskKeyring::generate();
        MaskKeyring::from_base64(&public, Some(&secret)).unwrap()
    }

    #[test]
    fn seal_open_round_trips() {
        let kr = test_keyring();
        let sealed = kr.seal(b"+254712345678");
        assert_ne!(sealed, "+254712345678", "stored form is not plaintext");
        assert_eq!(kr.open(&sealed).unwrap(), "+254712345678");
    }

    #[test]
    fn each_seal_is_distinct_ciphertext() {
        let kr = test_keyring();
        let a = kr.seal(b"secret");
        let b = kr.seal(b"secret");
        assert_ne!(a, b, "ephemeral keypair makes ciphertext non-deterministic");
        assert_eq!(kr.open(&a).unwrap(), "secret");
        assert_eq!(kr.open(&b).unwrap(), "secret");
    }

    #[test]
    fn public_key_only_cannot_open() {
        let (public, secret) = MaskKeyring::generate();
        let write_only = MaskKeyring::from_base64(&public, None).unwrap();
        let sealed = write_only.seal(b"pii");
        assert_eq!(write_only.open(&sealed), Err(MaskError::NoPrivateKey));
        let full = MaskKeyring::from_base64(&public, Some(&secret)).unwrap();
        assert_eq!(full.open(&sealed).unwrap(), "pii");
    }

    #[test]
    fn wrong_key_fails_to_decrypt() {
        let a = test_keyring();
        let b = test_keyring();
        let sealed = a.seal(b"private");
        assert_eq!(b.open(&sealed), Err(MaskError::Decrypt));
    }

    #[test]
    fn open_rejects_malformed_input() {
        let kr = test_keyring();
        assert_eq!(kr.open("not base64!"), Err(MaskError::Malformed));
        // Valid base64, but too short to hold the ephemeral key and nonce.
        assert_eq!(kr.open(&B64.encode([0u8; 10])), Err(MaskError::Malformed));
    }

    #[test]
    fn tampered_ciphertext_fails_to_decrypt() {
        let kr = test_keyring();
        let mut bytes = B64.decode(kr.seal(b"secret")).unwrap();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        assert_eq!(kr.open(&B64.encode(bytes)), Err(MaskError::Decrypt));
    }

    #[test]
    fn from_base64_rejects_bad_keys() {
        assert!(matches!(
            MaskKeyring::from_base64("%%%", None),
            Err(MaskError::BadKey(_))
        ));
        let short = B64.encode([0u8; 16]);
        assert_eq!(
            MaskKeyring::from_base64(&short, None).err(),
            Some(MaskError::BadKey("key is not 32 bytes".to_string()))
        );
    }

    #[test]
    fn masked_redacts_in_debug_and_display() {
        let m: Masked = Masked::new("0712-secret");
        assert_eq!(m.to_string(), REDACTED, "Display is redacted");
        let debug = format!("{m:?}");
        assert!(debug.contains("••••••"), "Debug is redacted");
        assert!(!debug.contains("0712-secret"), "Debug does not leak plaintext");
    }

    #[test]
    fn serialize_emits_ciphertext_not_plaintext() {
        let (public, secret) = MaskKeyring::generate();
        set_mask_keyring(MaskKeyring::from_base64(&public, Some(&secret)).unwrap());
        let m: Masked = Masked::new("0712-secret");
        let json = serde_json::to_string(&m).unwrap();
        assert!(
            !json.contains("0712-secret"),
            "serialized form must not be the plaintext"
        );
        assert_ne!(
            json,
            format!("\"{REDACTED}\""),
            "serialized form is ciphertext, not the redaction marker"
        );
    }

    #[test]
    fn in_memory_plaintext_reveals_without_keyring() {
        let m: Masked = Masked::new("hello");
        assert!(m.is_revealable(), "in-memory plaintext is always revealable");
        assert_eq!(m.reveal().unwrap(), "hello");
    }
}