1use crate::error::SessionError;
4use aead::{generic_array::GenericArray, Aead, NewAead};
5use aes_gcm::Aes256Gcm;
6use chacha20poly1305::ChaCha20Poly1305;
7use rand::rngs::OsRng;
8use rand::RngCore;
9use serde::de::DeserializeOwned;
10use serde::{Deserialize, Serialize};
11use std::marker::PhantomData;
12use time::OffsetDateTime;
13
14#[derive(Clone, Serialize, Deserialize, Eq, PartialEq)]
16pub struct Session<V> {
17 pub expires: Option<OffsetDateTime>,
19 pub value: Option<V>,
21}
22
23pub trait SessionManager<V: Serialize + DeserializeOwned>: Send + Sync {
25 fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError>;
31
32 fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError>;
38
39 fn is_encrypted(&self) -> bool;
41}
42
43pub struct ChaCha20Poly1305SessionManager<V: Serialize + DeserializeOwned> {
45 aead_key: [u8; 32],
46 _value: PhantomData<V>,
47}
48
49impl<V: Serialize + DeserializeOwned> ChaCha20Poly1305SessionManager<V> {
50 pub fn from_key(aead_key: [u8; 32]) -> Self {
52 ChaCha20Poly1305SessionManager {
53 aead_key: aead_key,
54 _value: PhantomData,
55 }
56 }
57
58 fn random_bytes(&self, buf: &mut [u8]) {
59 OsRng.fill_bytes(buf);
60 }
61
62 fn aead(&self) -> ChaCha20Poly1305 {
63 ChaCha20Poly1305::new(&GenericArray::clone_from_slice(&self.aead_key))
64 }
65}
66
67impl<V: Serialize + DeserializeOwned + Send + Sync> SessionManager<V>
68 for ChaCha20Poly1305SessionManager<V>
69{
70 fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError> {
71 if bytes.len() <= 60 {
72 return Err(SessionError::ValidationError);
73 }
74
75 let nonce = GenericArray::from_slice(&bytes[0..12]);
76 let plaintext = self
77 .aead()
78 .decrypt(&nonce, bytes[12..].as_ref())
79 .map_err(|_| SessionError::InternalError)?;
80
81 serde_cbor::from_slice(&plaintext[32..plaintext.len()]).map_err(|err| {
83 warn!("Failed to deserialize session: {}", err);
84 SessionError::InternalError
85 })
86 }
87
88 fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError> {
89 let session_bytes = serde_cbor::to_vec(&session).map_err(|err| {
90 warn!("Failed to serialize session: {}", err);
91 SessionError::InternalError
92 })?;
93
94 let mut padding = [0; 32];
95 self.random_bytes(&mut padding);
96
97 let mut plaintext = vec![0; session_bytes.len() + 32];
98 plaintext[0..32].copy_from_slice(&padding);
99 plaintext[32..].copy_from_slice(&session_bytes);
100
101 let mut nonce = [0; 12];
102 self.random_bytes(&mut nonce);
103 let nonce = GenericArray::from_slice(&nonce);
104
105 let ciphertext = self
106 .aead()
107 .encrypt(&nonce, plaintext.as_ref())
108 .map_err(|_| SessionError::InternalError)?;
109
110 let mut transport = vec![0; ciphertext.len() + 12];
111 transport[0..12].copy_from_slice(&nonce);
112 transport[12..].copy_from_slice(&ciphertext);
113
114 Ok(transport)
115 }
116
117 fn is_encrypted(&self) -> bool {
119 true
120 }
121}
122
123pub struct AesGcmSessionManager<V: Serialize + DeserializeOwned> {
125 aead_key: [u8; 32],
126 _value: PhantomData<V>,
127}
128
129impl<V: Serialize + DeserializeOwned> AesGcmSessionManager<V> {
130 pub fn from_key(aead_key: [u8; 32]) -> Self {
132 AesGcmSessionManager {
133 aead_key: aead_key,
134 _value: PhantomData,
135 }
136 }
137
138 fn random_bytes(&self, buf: &mut [u8]) {
139 OsRng.fill_bytes(buf);
140 }
141
142 fn aead(&self) -> Aes256Gcm {
143 Aes256Gcm::new(&GenericArray::clone_from_slice(&self.aead_key))
144 }
145}
146
147impl<V: Serialize + DeserializeOwned + Send + Sync> SessionManager<V> for AesGcmSessionManager<V> {
148 fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError> {
149 if bytes.len() <= 60 {
150 return Err(SessionError::ValidationError);
151 }
152
153 let nonce = GenericArray::from_slice(&bytes[0..12]);
154 let plaintext = self
155 .aead()
156 .decrypt(&nonce, bytes[12..].as_ref())
157 .map_err(|_| SessionError::InternalError)?;
158
159 serde_cbor::from_slice(&plaintext[32..plaintext.len()]).map_err(|err| {
161 warn!("Failed to deserialize session: {}", err);
162 SessionError::InternalError
163 })
164 }
165
166 fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError> {
167 let session_bytes = serde_cbor::to_vec(&session).map_err(|err| {
168 warn!("Failed to serialize session: {}", err);
169 SessionError::InternalError
170 })?;
171
172 let mut padding = [0; 32];
173 self.random_bytes(&mut padding);
174
175 let mut plaintext = vec![0; session_bytes.len() + 32];
176 plaintext[0..32].copy_from_slice(&padding);
177 plaintext[32..].copy_from_slice(&session_bytes);
178
179 let mut nonce = [0; 12];
180 self.random_bytes(&mut nonce);
181 let nonce = GenericArray::from_slice(&nonce);
182
183 let ciphertext = self
184 .aead()
185 .encrypt(&nonce, plaintext.as_ref())
186 .map_err(|_| SessionError::InternalError)?;
187
188 let mut transport = vec![0; ciphertext.len() + 12];
189 transport[0..12].copy_from_slice(&nonce);
190 transport[12..].copy_from_slice(&ciphertext);
191
192 Ok(transport)
193 }
194
195 fn is_encrypted(&self) -> bool {
197 true
198 }
199}
200
201pub struct MultiSessionManager<V: Serialize + DeserializeOwned + Send + Sync> {
206 current: Box<dyn SessionManager<V>>,
207 previous: Vec<Box<dyn SessionManager<V>>>,
208}
209
210impl<V: Serialize + DeserializeOwned + Send + Sync> MultiSessionManager<V> {
211 pub fn new(
214 current: Box<dyn SessionManager<V>>,
215 previous: Vec<Box<dyn SessionManager<V>>>,
216 ) -> Self {
217 Self { current, previous }
218 }
219}
220
221impl<V: Serialize + DeserializeOwned + Send + Sync> SessionManager<V> for MultiSessionManager<V> {
222 fn deserialize(&self, bytes: &[u8]) -> Result<Session<V>, SessionError> {
223 match self.current.deserialize(bytes) {
224 ok @ Ok(_) => return ok,
225 Err(_) => {
226 for manager in self.previous.iter() {
227 match manager.deserialize(bytes) {
228 ok @ Ok(_) => return ok,
229 Err(_) => (),
230 }
231 }
232 }
233 }
234 Err(SessionError::ValidationError)
235 }
236
237 fn serialize(&self, session: &Session<V>) -> Result<Vec<u8>, SessionError> {
238 self.current.serialize(session)
239 }
240
241 fn is_encrypted(&self) -> bool {
242 self.current.is_encrypted()
243 }
244}
245
#[cfg(test)]
mod tests {
    // Keys for the managers under test. KEY_3 is never configured on a manager
    // that is asked to read, so tokens sealed with it must always be rejected.
    const KEY_1: [u8; 32] = *b"01234567012345670123456701234567";
    const KEY_2: [u8; 32] = *b"76543210765432107654321076543210";
    const KEY_3: [u8; 32] = *b"abcdefghabcdefghabcdefghabcdefgh";

    // Generates a round-trip and tamper-detection suite for one single-cipher manager.
    macro_rules! test_cases {
        ($strct: ident, $md: ident) => {
            mod $md {
                use super::{KEY_1, KEY_2};
                use serde::{Deserialize, Serialize};
                use $crate::error::SessionError;
                use $crate::session::{$strct, Session, SessionManager};

                #[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
                struct Data {
                    string: String,
                }

                // Payload round-tripped by every test in this module.
                fn sample_data() -> Data {
                    Data {
                        string: "boots and cats".to_string(),
                    }
                }

                // Non-expiring session wrapping `sample_data()`.
                fn sample_session() -> Session<Data> {
                    Session {
                        expires: None,
                        value: Some(sample_data()),
                    }
                }

                #[test]
                fn serde_happy_path() {
                    let manager = $strct::from_key(KEY_1);
                    let bytes = manager
                        .serialize(&sample_session())
                        .expect("couldn't serialize");
                    let parsed_session = manager.deserialize(&bytes).expect("couldn't deserialize");
                    assert_eq!(parsed_session.value, Some(sample_data()));
                }

                #[test]
                fn serde_bad_data_end() {
                    let manager = $strct::from_key(KEY_1);
                    let mut bytes = manager
                        .serialize(&sample_session())
                        .expect("couldn't serialize");
                    let last = bytes.len() - 1;
                    bytes[last] ^= 0x01;

                    let deserialized: Result<Session<Data>, SessionError> =
                        manager.deserialize(&bytes);
                    assert!(deserialized.is_err());
                }

                #[test]
                fn serde_bad_data_start() {
                    let manager = $strct::from_key(KEY_1);
                    let mut bytes = manager
                        .serialize(&sample_session())
                        .expect("couldn't serialize");
                    bytes[0] ^= 0x01;

                    let deserialized: Result<Session<Data>, SessionError> =
                        manager.deserialize(&bytes);
                    assert!(deserialized.is_err());
                }

                // Empty, header-sized, and one-byte-short tokens must all be rejected
                // without panicking.
                #[test]
                fn serde_truncated_input() {
                    let manager = $strct::from_key(KEY_1);
                    let bytes = manager
                        .serialize(&sample_session())
                        .expect("couldn't serialize");

                    for &len in &[0, 1, 12, 60, bytes.len() - 1] {
                        let deserialized: Result<Session<Data>, SessionError> =
                            manager.deserialize(&bytes[..len]);
                        assert!(
                            deserialized.is_err(),
                            "accepted a token truncated to {} bytes",
                            len
                        );
                    }
                }

                // A token sealed under one key must not open under another.
                #[test]
                fn serde_wrong_key() {
                    let writer = $strct::from_key(KEY_1);
                    let bytes = writer
                        .serialize(&sample_session())
                        .expect("couldn't serialize");

                    let reader = $strct::from_key(KEY_2);
                    let deserialized: Result<Session<Data>, SessionError> =
                        reader.deserialize(&bytes);
                    assert!(deserialized.is_err());
                }
            }
        };
    }

    test_cases!(AesGcmSessionManager, aesgcm);
    test_cases!(ChaCha20Poly1305SessionManager, chacha20poly1305);

    mod multi {
        // Generates rotation tests for a `MultiSessionManager` whose current manager is
        // `$strct1` (KEY_1) and whose single previous manager is `$strct2` (KEY_2).
        macro_rules! test_cases {
            ($strct1: ident, $strct2: ident, $name: ident) => {
                mod $name {
                    use super::super::{KEY_1, KEY_2, KEY_3};
                    use $crate::error::SessionError;
                    use $crate::session::*;

                    #[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
                    struct Data {
                        string: String,
                    }

                    // Payload round-tripped by every test in this module.
                    fn sample_data() -> Data {
                        Data {
                            string: "boots and cats".to_string(),
                        }
                    }

                    // Non-expiring session wrapping `sample_data()`.
                    fn sample_session() -> Session<Data> {
                        Session {
                            expires: None,
                            value: Some(sample_data()),
                        }
                    }

                    #[test]
                    fn no_previous() {
                        let manager = $strct1::from_key(KEY_1);
                        let session = sample_session();
                        let mut sessions =
                            vec![manager.serialize(&session).expect("couldn't serialize")];

                        let multi = MultiSessionManager::new(Box::new(manager), vec![]);
                        sessions.push(multi.serialize(&session).expect("couldn't serialize"));

                        for bytes in sessions.iter() {
                            let parsed_session =
                                multi.deserialize(bytes).expect("couldn't deserialize");
                            assert_eq!(parsed_session.value, Some(sample_data()));
                        }
                    }

                    #[test]
                    fn $name() {
                        let manager_1 = $strct1::from_key(KEY_1);
                        let manager_2 = $strct2::from_key(KEY_2);
                        let session = sample_session();
                        let mut sessions = vec![
                            manager_1.serialize(&session).expect("couldn't serialize"),
                            manager_2.serialize(&session).expect("couldn't serialize"),
                        ];

                        let multi = MultiSessionManager::new(
                            Box::new(manager_1),
                            vec![Box::new(manager_2)],
                        );
                        sessions.push(multi.serialize(&session).expect("couldn't serialize"));

                        for bytes in sessions.iter() {
                            let parsed_session =
                                multi.deserialize(bytes).expect("couldn't deserialize");
                            assert_eq!(parsed_session.value, Some(sample_data()));
                        }
                    }

                    // Neither the current nor the previous manager knows KEY_3, so a
                    // token sealed with it must fall through every manager and fail.
                    #[test]
                    fn unknown_key_rejected() {
                        let stranger = $strct1::from_key(KEY_3);
                        let bytes = stranger
                            .serialize(&sample_session())
                            .expect("couldn't serialize");

                        let current: $strct1<Data> = $strct1::from_key(KEY_1);
                        let previous: $strct2<Data> = $strct2::from_key(KEY_2);
                        let multi =
                            MultiSessionManager::new(Box::new(current), vec![Box::new(previous)]);

                        let parsed: Result<Session<Data>, SessionError> =
                            multi.deserialize(&bytes);
                        assert!(parsed.is_err());
                    }
                }
            };
        }

        test_cases!(
            AesGcmSessionManager,
            AesGcmSessionManager,
            aesgcm_then_aesgcm
        );

        test_cases!(
            ChaCha20Poly1305SessionManager,
            ChaCha20Poly1305SessionManager,
            chacha20poly1305_then_chacha20poly1305
        );

        test_cases!(
            ChaCha20Poly1305SessionManager,
            AesGcmSessionManager,
            chacha20poly1305_then_aesgcm
        );

        test_cases!(
            AesGcmSessionManager,
            ChaCha20Poly1305SessionManager,
            aesgcm_then_chacha20poly1305
        );
    }
}