toolkit_zero/serialization/
aead.rs1use chacha20poly1305::{
15 aead::{Aead, KeyInit},
16 ChaCha20Poly1305, Key, Nonce,
17};
18use bincode::{
19 config::standard,
20 encode_to_vec, decode_from_slice,
21 Encode, Decode,
22 error::{EncodeError, DecodeError},
23};
24use rand::RngCore as _;
25use zeroize::Zeroizing;
26
27#[derive(Debug)]
31pub enum SerializationError {
32 Encode(EncodeError),
34 Decode(DecodeError),
36 Cipher,
38}
39
40impl std::fmt::Display for SerializationError {
41 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42 match self {
43 Self::Encode(e) => write!(f, "seal encode error: {e}"),
44 Self::Decode(e) => write!(f, "open decode error: {e}"),
45 Self::Cipher => write!(f, "AEAD cipher error: wrong key or tampered ciphertext"),
46 }
47 }
48}
49
50impl std::error::Error for SerializationError {
51 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
52 match self {
53 Self::Encode(e) => Some(e),
54 Self::Decode(e) => Some(e),
55 Self::Cipher => None,
56 }
57 }
58}
59
60impl From<EncodeError> for SerializationError {
61 fn from(e: EncodeError) -> Self { Self::Encode(e) }
62}
63
64impl From<DecodeError> for SerializationError {
65 fn from(e: DecodeError) -> Self { Self::Decode(e) }
66}
67
/// Passphrase used when the caller passes `None` as the key.
///
/// This value is compiled into the source, so blobs sealed with it are only
/// obfuscated, not confidential.
const DEFAULT_KEY: &str = "serialization/deserialization";

/// Nonce length in bytes for ChaCha20-Poly1305 (96 bits, RFC 8439).
/// Every sealed blob begins with a nonce of exactly this length.
const NONCE_LEN: usize = 96 / 8;
72
73pub fn seal<T, K>(value: &T, key: Option<K>) -> Result<Vec<u8>, SerializationError>
90where
91 T: Encode,
92 K: AsRef<str>,
93{
94 let key_str = key.as_ref().map(|k| k.as_ref()).unwrap_or(DEFAULT_KEY);
95 let cipher_key = derive_key(key_str.as_bytes());
96
97 let plain: Zeroizing<Vec<u8>> = Zeroizing::new(encode_to_vec(value, standard())?);
99
100 let mut nonce_bytes = [0u8; NONCE_LEN];
102 rand::rng().fill_bytes(&mut nonce_bytes);
103 let nonce = Nonce::from_slice(&nonce_bytes);
104
105 let cipher = ChaCha20Poly1305::new(Key::from_slice(&*cipher_key));
106 let ciphertext = cipher
107 .encrypt(nonce, plain.as_slice())
108 .map_err(|_| SerializationError::Cipher)?;
109
110 let mut blob = Vec::with_capacity(NONCE_LEN + ciphertext.len());
112 blob.extend_from_slice(&nonce_bytes);
113 blob.extend_from_slice(&ciphertext);
114 Ok(blob)
115}
116
117pub fn open<T, K>(blob: &[u8], key: Option<K>) -> Result<T, SerializationError>
129where
130 T: Decode<()>,
131 K: AsRef<str>,
132{
133 if blob.len() < NONCE_LEN {
134 return Err(SerializationError::Cipher);
135 }
136
137 let key_str = key.as_ref().map(|k| k.as_ref()).unwrap_or(DEFAULT_KEY);
138 let cipher_key = derive_key(key_str.as_bytes());
139
140 let nonce = Nonce::from_slice(&blob[..NONCE_LEN]);
141 let cipher = ChaCha20Poly1305::new(Key::from_slice(&*cipher_key));
142
143 let plain: Zeroizing<Vec<u8>> = Zeroizing::new(
144 cipher
145 .decrypt(nonce, &blob[NONCE_LEN..])
146 .map_err(|_| SerializationError::Cipher)?,
147 );
148
149 let (value, _): (T, _) = decode_from_slice(&*plain, standard())?;
150 Ok(value)
151}
152
153#[inline]
163fn derive_key(key_bytes: &[u8]) -> Zeroizing<[u8; 32]> {
164 use sha2::Digest as _;
165 let digest = sha2::Sha256::digest(key_bytes);
166 let mut out = Zeroizing::new([0u8; 32]);
167 out.copy_from_slice(digest.as_slice());
168 out
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use bincode::{Decode, Encode};

    #[derive(Encode, Decode, Debug, PartialEq)]
    struct Point {
        x: f64,
        y: f64,
        label: String,
    }

    #[derive(Encode, Decode, Debug, PartialEq)]
    struct Nested {
        id: u64,
        inner: Point,
        tags: Vec<String>,
    }

    /// Length of the Poly1305 authentication tag appended by ChaCha20-Poly1305.
    const TAG_LEN: usize = 16;

    #[test]
    fn round_trip_default_key() {
        let p = Point { x: 1.5, y: -3.0, label: "origin".into() };
        let blob = seal(&p, None::<&str>).unwrap();
        let back: Point = open(&blob, None::<&str>).unwrap();
        assert_eq!(p, back);
    }

    #[test]
    fn round_trip_custom_key_str_literal() {
        let p = Point { x: 42.0, y: 0.001, label: "custom".into() };
        let blob = seal(&p, Some("hunter2")).unwrap();
        let back: Point = open(&blob, Some("hunter2")).unwrap();
        assert_eq!(p, back);
    }

    #[test]
    fn round_trip_custom_key_owned_string() {
        let p = Point { x: 42.0, y: 0.001, label: "owned".into() };
        let key = String::from("hunter2");
        let blob = seal(&p, Some(key.as_str())).unwrap();
        let back: Point = open(&blob, Some(key.as_str())).unwrap();
        assert_eq!(p, back);
    }

    #[test]
    fn round_trip_nested() {
        let n = Nested {
            id: 9999,
            inner: Point { x: -1.0, y: 2.5, label: "nested".into() },
            tags: vec!["a".into(), "bb".into(), "ccc".into()],
        };
        let blob = seal(&n, Some("nested-key")).unwrap();
        let back: Nested = open(&blob, Some("nested-key")).unwrap();
        assert_eq!(n, back);
    }

    #[test]
    fn wrong_key_fails() {
        let p = Point { x: 1.0, y: 2.0, label: "x".into() };
        let blob = seal(&p, Some("correct")).unwrap();
        let result: Result<Point, _> = open(&blob, Some("wrong"));
        assert!(matches!(result, Err(SerializationError::Cipher)));
    }

    #[test]
    fn default_key_blob_rejected_with_custom_key() {
        let p = Point { x: 1.0, y: 2.0, label: "default".into() };
        let blob = seal(&p, None::<&str>).unwrap();
        let result: Result<Point, _> = open(&blob, Some("not-the-default"));
        assert!(matches!(result, Err(SerializationError::Cipher)));
    }

    #[test]
    fn ciphertext_differs_from_plaintext() {
        let p = Point { x: 0.0, y: 0.0, label: "zero".into() };
        let plain = bincode::encode_to_vec(&p, bincode::config::standard()).unwrap();
        let blob = seal(&p, None::<&str>).unwrap();
        assert_ne!(blob, plain);
    }

    #[test]
    fn blob_is_nonce_then_ciphertext_and_tag() {
        let p = Point { x: 5.0, y: 6.0, label: "layout".into() };
        let plain = bincode::encode_to_vec(&p, bincode::config::standard()).unwrap();
        let blob = seal(&p, Some("layout")).unwrap();
        assert_eq!(blob.len(), NONCE_LEN + plain.len() + TAG_LEN);
    }

    #[test]
    fn same_plaintext_same_key_produces_different_ciphertext_each_time() {
        let p = Point { x: 1.0, y: 2.0, label: "det".into() };
        let b1 = seal(&p, Some("k")).unwrap();
        let b2 = seal(&p, Some("k")).unwrap();
        assert_ne!(b1, b2, "random nonces must produce distinct ciphertexts");
    }

    #[test]
    fn tampered_blob_rejected() {
        let p = Point { x: 3.0, y: 4.0, label: "t".into() };
        let mut blob = seal(&p, Some("key")).unwrap();
        let len = blob.len();
        blob[len - 1] ^= 0xff;
        let result: Result<Point, _> = open(&blob, Some("key"));
        assert!(
            matches!(result, Err(SerializationError::Cipher)),
            "tampered ciphertext must be rejected"
        );
    }

    #[test]
    fn tampered_nonce_rejected() {
        let p = Point { x: 3.0, y: 4.0, label: "n".into() };
        let mut blob = seal(&p, Some("key")).unwrap();
        blob[0] ^= 0x01;
        let result: Result<Point, _> = open(&blob, Some("key"));
        assert!(matches!(result, Err(SerializationError::Cipher)));
    }

    #[test]
    fn blob_shorter_than_nonce_rejected() {
        for len in 0..NONCE_LEN {
            let short = vec![0u8; len];
            let result: Result<Point, _> = open(&short, None::<&str>);
            assert!(
                matches!(result, Err(SerializationError::Cipher)),
                "{len}-byte blob must be rejected"
            );
        }
    }

    #[test]
    fn truncated_blob_rejected() {
        let p = Point { x: 7.0, y: 8.0, label: "cut".into() };
        let blob = seal(&p, Some("key")).unwrap();
        // Drop one byte, cut into the tag, and keep only the nonce.
        for cut in [blob.len() - 1, NONCE_LEN + TAG_LEN - 1, NONCE_LEN] {
            let result: Result<Point, _> = open(&blob[..cut], Some("key"));
            assert!(
                matches!(result, Err(SerializationError::Cipher)),
                "blob truncated to {cut} bytes must be rejected"
            );
        }
    }

    #[test]
    fn type_mismatch_reports_decode_error() {
        // One encoded byte authenticates fine but is far too short for a Point.
        let blob = seal(&7u8, Some("key")).unwrap();
        let result: Result<Point, _> = open(&blob, Some("key"));
        assert!(matches!(result, Err(SerializationError::Decode(_))));
    }
}