1use chacha20poly1305::{aead::Aead, KeyInit};
2use rand_core::RngCore;
3use sha2::{Digest, Sha256};
4use thiserror::Error;
5use x25519_dalek::{PublicKey, StaticSecret};
6use zeroize::Zeroizing;
7
8#[derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
9#[allow(dead_code)]
10pub struct MasterKey(Zeroizing<[u8; 32]>);
11
12#[derive(Debug, Error)]
13pub enum KdfError {
14 #[error("invalid kdf parameters")]
15 InvalidParams(argon2::Error),
16 #[error("key derivation failed")]
17 DerivationFailed(argon2::Error),
18}
19
/// KiB per MiB: Argon2's memory cost is expressed in KiB, so multiplying by
/// this converts a MiB figure into the unit argon2 expects.
const MIB: u32 = 1 << 10;
/// Argon2id memory cost in KiB (64 MiB).
const MEMORY_COST_KIB: u32 = MIB * 64;
22
23pub fn derive_master_key(pass: &str, salt: &[u8]) -> Result<MasterKey, KdfError> {
25 let mut key = Zeroizing::new([0u8; 32]);
26
27 let params =
28 argon2::Params::new(MEMORY_COST_KIB, 3, 1, None).map_err(KdfError::InvalidParams)?;
29
30 let argon2 = argon2::Argon2::new(argon2::Algorithm::Argon2id, argon2::Version::V0x13, params);
31
32 argon2
33 .hash_password_into(pass.as_bytes(), salt, key.as_mut())
34 .map_err(KdfError::DerivationFailed)?;
35
36 Ok(MasterKey(key))
37}
38
39#[derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
40pub struct Dek(Zeroizing<[u8; 32]>);
41impl Dek {
42 pub fn as_bytes(&self) -> &[u8; 32] {
43 &self.0
44 }
45
46 pub fn from_bytes(bytes: &[u8; 32]) -> Result<Self, &'static str> {
47 Ok(Dek(Zeroizing::new(*bytes)))
48 }
49}
50
51pub fn generate_dek() -> Dek {
53 let mut key = Zeroizing::new([0u8; 32]);
54 rand_core::OsRng.fill_bytes(key.as_mut());
55 Dek(key)
56}
57
/// A 24-byte XChaCha20-Poly1305 nonce, as returned by [`encrypt`] and [`wrap_key`].
///
/// Nonces are not secret, so this type can be cloned, compared and printed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Nonce(pub [u8; 24]);

/// AEAD output bytes (ciphertext including the authentication tag), as
/// returned by [`encrypt`] and [`wrap_key`].
///
/// Ciphertext is not secret, so this type can be cloned, compared and printed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Ciphertext(pub Vec<u8>);
60
61#[derive(Debug, Error)]
62pub enum EncryptError {
63 #[error("AEAD encryption failed")]
64 AeadFailed(chacha20poly1305::aead::Error),
65}
66
67pub fn encrypt(
69 plaintext: &[u8],
70 dek: &Dek,
71 aad: &[u8],
72) -> Result<(Nonce, Ciphertext), EncryptError> {
73 let key = chacha20poly1305::Key::from(*dek.as_bytes());
74 let cipher = chacha20poly1305::XChaCha20Poly1305::new(&key);
75
76 let mut nonce_bytes = [0u8; 24];
77 rand_core::OsRng.fill_bytes(&mut nonce_bytes);
78
79 let nonce = chacha20poly1305::XNonce::from(nonce_bytes);
80 let ct = cipher
81 .encrypt(
82 &nonce,
83 chacha20poly1305::aead::Payload {
84 msg: plaintext,
85 aad,
86 },
87 )
88 .map_err(EncryptError::AeadFailed)?;
89
90 Ok((Nonce(nonce_bytes), Ciphertext(ct)))
91}
92
93#[derive(Debug, Error)]
94pub enum DecryptError {
95 #[error("AEAD decryption failed")]
96 AeadFailed(chacha20poly1305::aead::Error),
97}
98
99pub fn decrypt(
101 ciphertext: &[u8],
102 nonce: &Nonce,
103 dek: &Dek,
104 aad: &[u8],
105) -> Result<Zeroizing<Vec<u8>>, DecryptError> {
106 let key = chacha20poly1305::Key::from(*dek.as_bytes());
107 let cipher = chacha20poly1305::XChaCha20Poly1305::new(&key);
108
109 let nonce = chacha20poly1305::XNonce::from(nonce.0);
110
111 let pt = cipher
112 .decrypt(
113 &nonce,
114 chacha20poly1305::aead::Payload {
115 msg: ciphertext,
116 aad,
117 },
118 )
119 .map_err(DecryptError::AeadFailed)?;
120
121 Ok(Zeroizing::new(pt))
122}
123
124pub struct Keypair {
130 secret: StaticSecret,
131 public: PublicKey,
132}
133
134impl Keypair {
135 pub fn generate() -> Self {
137 let secret = StaticSecret::random_from_rng(rand_core::OsRng);
138 let public = PublicKey::from(&secret);
139 Self { secret, public }
140 }
141
142 pub fn from_secret_bytes(bytes: &[u8; 32]) -> Self {
144 let secret = StaticSecret::from(*bytes);
145 let public = PublicKey::from(&secret);
146 Self { secret, public }
147 }
148
149 pub fn secret_key_bytes(&self) -> [u8; 32] {
151 self.secret.to_bytes()
152 }
153
154 pub fn public_key(&self) -> &PublicKey {
156 &self.public
157 }
158
159 pub fn public_key_bytes(&self) -> [u8; 32] {
161 *self.public.as_bytes()
162 }
163
164 pub fn shared_secret(&self, their_public: &PublicKey) -> SharedSecret {
166 let secret_bytes = self.secret.diffie_hellman(their_public);
167 SharedSecret(Zeroizing::new(*secret_bytes.as_bytes()))
168 }
169}
170
171impl zeroize::ZeroizeOnDrop for Keypair {}
172
173#[derive(zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
175pub struct SharedSecret(Zeroizing<[u8; 32]>);
176
177impl SharedSecret {
178 fn as_bytes(&self) -> &[u8; 32] {
179 &self.0
180 }
181}
182
183pub fn public_key_from_bytes(bytes: &[u8]) -> Result<PublicKey, &'static str> {
185 if bytes.len() != 32 {
186 return Err("public key must be 32 bytes");
187 }
188 let mut array = [0u8; 32];
189 array.copy_from_slice(bytes);
190 Ok(PublicKey::from(array))
191}
192
193#[derive(Debug, Error)]
194pub enum WrapError {
195 #[error("AEAD encryption failed")]
196 AeadFailed(chacha20poly1305::aead::Error),
197}
198
199pub fn wrap_key(
201 key: &[u8],
202 shared_secret: &SharedSecret,
203 aad: &[u8],
204) -> Result<(Nonce, Ciphertext), WrapError> {
205 let cipher_key = chacha20poly1305::Key::from(*shared_secret.as_bytes());
206 let cipher = chacha20poly1305::XChaCha20Poly1305::new(&cipher_key);
207
208 let mut nonce_bytes = [0u8; 24];
209 rand_core::OsRng.fill_bytes(&mut nonce_bytes);
210
211 let nonce = chacha20poly1305::XNonce::from(nonce_bytes);
212 let ct = cipher
213 .encrypt(&nonce, chacha20poly1305::aead::Payload { msg: key, aad })
214 .map_err(WrapError::AeadFailed)?;
215
216 Ok((Nonce(nonce_bytes), Ciphertext(ct)))
217}
218
219#[derive(Debug, Error)]
220pub enum UnwrapError {
221 #[error("AEAD decryption failed")]
222 AeadFailed(chacha20poly1305::aead::Error),
223}
224
225pub fn unwrap_key(
227 wrapped: &[u8],
228 nonce: &Nonce,
229 shared_secret: &SharedSecret,
230 aad: &[u8],
231) -> Result<Zeroizing<Vec<u8>>, UnwrapError> {
232 let cipher_key = chacha20poly1305::Key::from(*shared_secret.as_bytes());
233 let cipher = chacha20poly1305::XChaCha20Poly1305::new(&cipher_key);
234
235 let nonce = chacha20poly1305::XNonce::from(nonce.0);
236
237 let pt = cipher
238 .decrypt(
239 &nonce,
240 chacha20poly1305::aead::Payload { msg: wrapped, aad },
241 )
242 .map_err(UnwrapError::AeadFailed)?;
243
244 Ok(Zeroizing::new(pt))
245}
246
247pub fn hash_sha256(data: &[u8]) -> [u8; 32] {
253 let mut hasher = Sha256::new();
254 hasher.update(data);
255 hasher.finalize().into()
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn crypto_round_trip_basic() {
        let salt = b"not_random_salt_just_for_test";
        let master = derive_master_key("password", salt).unwrap();
        // The KDF must be deterministic for a given passphrase and salt,
        // otherwise anything protected by the master key could not be recovered.
        let again = derive_master_key("password", salt).unwrap();
        assert_eq!(*master.0, *again.0);
        assert_ne!(*master.0, [0u8; 32]);

        let dek = generate_dek();

        let plaintext = b"super-secret";
        let aad = b"project:foo|env:dev|key:DB_PASSWORD";

        let (nonce, ct) = encrypt(plaintext, &dek, aad).unwrap();
        let decrypted = decrypt(&ct.0, &nonce, &dek, aad).unwrap();

        assert_eq!(plaintext, &decrypted[..]);
    }

    #[test]
    fn decrypt_fails_on_tamper() {
        let dek = generate_dek();
        let (nonce, mut ct) = encrypt(b"hello", &dek, b"aad").unwrap();

        // Flip a bit in the trailing authentication tag; the first-byte case
        // is covered by `tampering_ciphertext_fails`.
        let last = ct.0.len() - 1;
        ct.0[last] ^= 0x01;
        assert!(decrypt(&ct.0, &nonce, &dek, b"aad").is_err());

        // A different DEK must not open an untouched ciphertext.
        let (nonce2, ct2) = encrypt(b"hello", &dek, b"aad").unwrap();
        let other_dek = generate_dek();
        assert!(decrypt(&ct2.0, &nonce2, &other_dek, b"aad").is_err());
    }

    #[test]
    fn truncated_ciphertext_fails() {
        let dek = generate_dek();
        let (nonce, ct) = encrypt(b"hello", &dek, b"aad").unwrap();

        let short = &ct.0[..ct.0.len() - 1];
        assert!(decrypt(short, &nonce, &dek, b"aad").is_err());
        assert!(decrypt(&[], &nonce, &dek, b"aad").is_err());
    }

    #[test]
    fn tampering_ciphertext_fails() {
        let dek = generate_dek();
        let (nonce, mut ct) = encrypt(b"hello", &dek, b"aad").unwrap();

        ct.0[0] ^= 0x01;

        assert!(decrypt(&ct.0, &nonce, &dek, b"aad").is_err());
    }

    #[test]
    fn tampering_nonce_fails() {
        let dek = generate_dek();
        let (nonce, ct) = encrypt(b"hello", &dek, b"aad").unwrap();

        let mut bad_nonce = nonce;
        bad_nonce.0[0] ^= 0x01;

        assert!(decrypt(&ct.0, &bad_nonce, &dek, b"aad").is_err());
    }

    #[test]
    fn tampering_aad_fails() {
        let dek = generate_dek();
        let (nonce, ct) = encrypt(b"hello", &dek, b"good-aad").unwrap();

        assert!(decrypt(&ct.0, &nonce, &dek, b"bad-aad").is_err());
    }

    #[test]
    fn empty_plaintext_ok() {
        let dek = generate_dek();
        let (nonce, ct) = encrypt(b"", &dek, b"aad").unwrap();
        let dec = decrypt(&ct.0, &nonce, &dek, b"aad").unwrap();
        assert_eq!(dec.len(), 0);
    }

    #[test]
    fn kdf_fails_on_short_salt() {
        assert!(derive_master_key("pwd", b"short").is_err());
    }

    #[test]
    fn sensitive_types_impl_zeroize() {
        fn assert_zeroize<T: zeroize::Zeroize>() {}
        assert_zeroize::<Dek>();
        assert_zeroize::<MasterKey>();
        assert_zeroize::<SharedSecret>();

        // Every type holding secret material must also wipe itself on drop.
        fn assert_zeroize_on_drop<T: zeroize::ZeroizeOnDrop>() {}
        assert_zeroize_on_drop::<Dek>();
        assert_zeroize_on_drop::<MasterKey>();
        assert_zeroize_on_drop::<SharedSecret>();
        assert_zeroize_on_drop::<Keypair>();
    }

    #[test]
    fn keypair_generation() {
        let kp = Keypair::generate();

        // Re-deriving from the exported secret must reproduce the cached public key.
        let restored = Keypair::from_secret_bytes(&kp.secret_key_bytes());
        assert_eq!(restored.public_key_bytes(), kp.public_key_bytes());

        // Two fresh key pairs should never collide.
        let other = Keypair::generate();
        assert_ne!(other.public_key_bytes(), kp.public_key_bytes());
    }

    #[test]
    fn public_key_roundtrip() {
        let kp = Keypair::generate();
        let bytes = kp.public_key_bytes();
        let pk = public_key_from_bytes(&bytes).unwrap();
        assert_eq!(pk.as_bytes(), &bytes);
    }

    #[test]
    fn public_key_from_bytes_validates_length() {
        assert!(public_key_from_bytes(&[0u8; 31]).is_err());
        assert!(public_key_from_bytes(&[0u8; 33]).is_err());
        assert!(public_key_from_bytes(&[0u8; 32]).is_ok());
    }

    #[test]
    fn ecdh_shared_secret_is_symmetric() {
        let alice = Keypair::generate();
        let bob = Keypair::generate();

        let alice_shared = alice.shared_secret(bob.public_key());
        let bob_shared = bob.shared_secret(alice.public_key());

        assert_eq!(alice_shared.as_bytes(), bob_shared.as_bytes());
    }

    #[test]
    fn key_wrap_unwrap_roundtrip() {
        let alice = Keypair::generate();
        let bob = Keypair::generate();

        let kek = b"workspace-key-encryption-key-32b";
        let shared = alice.shared_secret(bob.public_key());
        let aad = b"workspace:uuid-here";
        let (nonce, wrapped) = wrap_key(kek, &shared, aad).unwrap();

        let bob_shared = bob.shared_secret(alice.public_key());
        let unwrapped = unwrap_key(&wrapped.0, &nonce, &bob_shared, aad).unwrap();

        assert_eq!(&unwrapped[..], kek);
    }

    #[test]
    fn key_unwrap_fails_with_wrong_key() {
        let alice = Keypair::generate();
        let bob = Keypair::generate();
        let eve = Keypair::generate();

        let kek = b"secret-key";
        let shared = alice.shared_secret(bob.public_key());
        let (nonce, wrapped) = wrap_key(kek, &shared, b"aad").unwrap();

        let eve_shared = eve.shared_secret(alice.public_key());
        assert!(unwrap_key(&wrapped.0, &nonce, &eve_shared, b"aad").is_err());
    }

    #[test]
    fn key_unwrap_fails_with_tampered_ciphertext() {
        let alice = Keypair::generate();
        let bob = Keypair::generate();

        let kek = b"secret-key";
        let shared = alice.shared_secret(bob.public_key());
        let (nonce, mut wrapped) = wrap_key(kek, &shared, b"aad").unwrap();

        wrapped.0[0] ^= 0x01;

        let bob_shared = bob.shared_secret(alice.public_key());
        assert!(unwrap_key(&wrapped.0, &nonce, &bob_shared, b"aad").is_err());
    }

    #[test]
    fn key_unwrap_fails_with_wrong_aad() {
        let alice = Keypair::generate();
        let bob = Keypair::generate();

        let kek = b"secret-key";
        let shared = alice.shared_secret(bob.public_key());
        let (nonce, wrapped) = wrap_key(kek, &shared, b"good-aad").unwrap();

        let bob_shared = bob.shared_secret(alice.public_key());
        assert!(unwrap_key(&wrapped.0, &nonce, &bob_shared, b"bad-aad").is_err());
    }
}