tower_sessions_ext_core/session.rs
1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3 collections::HashMap,
4 fmt::{self, Display},
5 hash::Hash,
6 result,
7 str::{self, FromStr},
8 sync::{
9 Arc,
10 atomic::{self, AtomicBool},
11 },
12};
13
14use base64::{DecodeError, Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize, de::DeserializeOwned};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{SessionStore, session_store};
21
/// Lifetime applied to a session when no [`Expiry`] has been configured.
///
/// Used by [`Session::expiry_date`] when the session's expiry is `None`.
pub const DEFAULT_DURATION: Duration = Duration::weeks(2);

/// Callback invoked when a session is discovered to have expired (store returned `None` for a
/// previously valid session id). Useful for cleanup (e.g. revoking tokens, logging out elsewhere).
///
/// The callback is synchronous (it returns `()`, not a future) and receives the session ID that
/// could not be found in the store.
pub type OnExpireCallback = Arc<dyn Fn(Id) + Send + Sync>;

// Module-local result alias: every fallible session operation reports an [`Error`].
type Result<T> = result::Result<T, Error>;

// In-memory representation of a session's key-value data.
type Data = HashMap<String, Value>;
31
/// Session errors.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Maps `serde_json` errors.
    ///
    /// Produced when a value cannot be converted to or from a
    /// `serde_json::Value` (e.g. by `Session::insert` or `Session::get`).
    #[error(transparent)]
    SerdeJson(#[from] serde_json::Error),

    /// Maps `session_store::Error` errors.
    ///
    /// Produced when a call to the backing [`SessionStore`] fails.
    #[error(transparent)]
    Store(#[from] session_store::Error),
}
43
/// Shared, interior-mutable state behind every clone of a [`Session`].
#[derive(Debug)]
struct Inner {
    // This will be `None` when:
    //
    // 1. We have not been provided a session cookie or have failed to parse it,
    // 2. The store has not found the session.
    //
    // It is also reset to `None` by `Session::cycle_id` and `Session::flush`, and is set
    // by `Session::save` once the store has created a new record.
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    session_id: parking_lot::Mutex<Option<Id>>,

    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
    // `None` value indicates we have not tried to access it yet. After access through
    // `Session::get_record` it contains `Some(Record)`; note that `Session::load` may put it
    // back to `None` when the store has no record for the current ID.
    //
    // Async (tokio) lock because the guard is held across `.await` while loading from and
    // saving to the store.
    record: Mutex<Option<Record>>,

    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    expiry: parking_lot::Mutex<Option<Expiry>>,

    // Set to `true` by operations that change the session (e.g. insert, clear, set_expiry,
    // cycle_id); read by `Session::is_modified`.
    is_modified: AtomicBool,
}
64
/// A session which allows HTTP applications to associate key-value pairs with
/// visitors.
///
/// Cloning a `Session` is cheap: clones share the same store handle and the
/// same inner state (both behind `Arc`), so changes made through one clone are
/// visible through the others.
#[derive(Clone)]
pub struct Session {
    // Backing store used to load, create, save, and delete records.
    store: Arc<dyn SessionStore>,
    // Shared state: session ID, lazily loaded record, expiry, and modified flag.
    inner: Arc<Inner>,
    // Optional hook fired when a known session ID is not found in the store during
    // lazy loading.
    on_expire: Option<OnExpireCallback>,
}
73
74impl fmt::Debug for Session {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_struct("Session")
77 .field("store", &self.store)
78 .field("inner", &self.inner)
79 .field("on_expire", &self.on_expire.as_ref().map(|_| "Some(_)"))
80 .finish()
81 }
82}
83
84impl Session {
85 /// Creates a new session with the session ID, store, and expiry.
86 ///
87 /// This method is lazy and does not invoke the overhead of talking to the
88 /// backing store.
89 ///
90 /// The optional `on_expire` callback is invoked when the store returns `None`
91 /// for a session id (e.g. session expired or was deleted).
92 ///
93 /// # Examples
94 ///
95 /// ```rust
96 /// use std::sync::Arc;
97 ///
98 /// use tower_sessions_ext::{MemoryStore, Session};
99 ///
100 /// let store = Arc::new(MemoryStore::default());
101 /// Session::new(None, store, None, None);
102 /// ```
103 pub fn new(
104 session_id: Option<Id>,
105 store: Arc<impl SessionStore>,
106 expiry: Option<Expiry>,
107 on_expire: Option<OnExpireCallback>,
108 ) -> Self {
109 let inner = Inner {
110 session_id: parking_lot::Mutex::new(session_id),
111 record: Mutex::new(None), // `None` indicates we have not loaded from store.
112 expiry: parking_lot::Mutex::new(expiry),
113 is_modified: AtomicBool::new(false),
114 };
115
116 Self {
117 store,
118 inner: Arc::new(inner),
119 on_expire,
120 }
121 }
122
123 fn create_record(&self) -> Record {
124 Record::new(self.expiry_date())
125 }
126
127 #[tracing::instrument(skip(self), err, level = "trace")]
128 async fn get_record(&'_ self) -> Result<MappedMutexGuard<'_, Record>> {
129 let mut record_guard = self.inner.record.lock().await;
130
131 // Lazily load the record since `None` here indicates we have no yet loaded it.
132 if record_guard.is_none() {
133 tracing::trace!("record not loaded from store; loading");
134
135 let session_id = *self.inner.session_id.lock();
136 *record_guard = Some(if let Some(session_id) = session_id {
137 match self.store.load(&session_id).await? {
138 Some(loaded_record) => {
139 tracing::trace!("record found in store");
140 loaded_record
141 }
142
143 None => {
144 // A well-behaved user agent should not send session cookies after
145 // expiration. Even so it's possible for an expired session to be removed
146 // from the store after a request was initiated. However, such a race should
147 // be relatively uncommon and as such entering this branch could indicate
148 // malicious behavior.
149 tracing::warn!("possibly suspicious activity: record not found in store");
150 if let Some(ref on_expire) = self.on_expire {
151 on_expire(session_id);
152 }
153 *self.inner.session_id.lock() = None;
154 self.create_record()
155 }
156 }
157 } else {
158 tracing::trace!("session id not found");
159 self.create_record()
160 })
161 }
162
163 Ok(MutexGuard::map(record_guard, |opt| {
164 opt.as_mut()
165 .expect("Record should always be `Option::Some` at this point")
166 }))
167 }
168
169 /// Inserts a `impl Serialize` value into the session.
170 ///
171 /// # Examples
172 ///
173 /// ```rust
174 /// # tokio_test::block_on(async {
175 /// use std::sync::Arc;
176 ///
177 /// use tower_sessions_ext::{MemoryStore, Session};
178 ///
179 /// let store = Arc::new(MemoryStore::default());
180 /// let session = Session::new(None, store, None, None);
181 ///
182 /// session.insert("foo", 42).await.unwrap();
183 ///
184 /// let value = session.get::<usize>("foo").await.unwrap();
185 /// assert_eq!(value, Some(42));
186 /// # });
187 /// ```
188 ///
189 /// # Errors
190 ///
191 /// - This method can fail when [`serde_json::to_value`] fails.
192 /// - If the session has not been hydrated and loading from the store fails,
193 /// we fail with [`Error::Store`].
194 pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
195 self.insert_value(key, serde_json::to_value(&value)?)
196 .await?;
197 Ok(())
198 }
199
200 /// Inserts a `serde_json::Value` into the session.
201 ///
202 /// If the key was not present in the underlying map, `None` is returned and
203 /// `modified` is set to `true`.
204 ///
205 /// If the underlying map did have the key and its value is the same as the
206 /// provided value, `None` is returned and `modified` is not set.
207 ///
208 /// # Examples
209 ///
210 /// ```rust
211 /// # tokio_test::block_on(async {
212 /// use std::sync::Arc;
213 ///
214 /// use tower_sessions_ext::{MemoryStore, Session};
215 ///
216 /// let store = Arc::new(MemoryStore::default());
217 /// let session = Session::new(None, store, None, None);
218 ///
219 /// let value = session
220 /// .insert_value("foo", serde_json::json!(42))
221 /// .await
222 /// .unwrap();
223 /// assert!(value.is_none());
224 ///
225 /// let value = session
226 /// .insert_value("foo", serde_json::json!(42))
227 /// .await
228 /// .unwrap();
229 /// assert!(value.is_none());
230 ///
231 /// let value = session
232 /// .insert_value("foo", serde_json::json!("bar"))
233 /// .await
234 /// .unwrap();
235 /// assert_eq!(value, Some(serde_json::json!(42)));
236 /// # });
237 /// ```
238 ///
239 /// # Errors
240 ///
241 /// - If the session has not been hydrated and loading from the store fails,
242 /// we fail with [`Error::Store`].
243 pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
244 let mut record_guard = self.get_record().await?;
245 Ok(if record_guard.data.get(key) != Some(&value) {
246 self.inner
247 .is_modified
248 .store(true, atomic::Ordering::Release);
249 record_guard.data.insert(key.to_string(), value)
250 } else {
251 None
252 })
253 }
254
255 /// Gets a value from the store.
256 ///
257 /// # Examples
258 ///
259 /// ```rust
260 /// # tokio_test::block_on(async {
261 /// use std::sync::Arc;
262 ///
263 /// use tower_sessions_ext::{MemoryStore, Session};
264 ///
265 /// let store = Arc::new(MemoryStore::default());
266 /// let session = Session::new(None, store, None, None);
267 ///
268 /// session.insert("foo", 42).await.unwrap();
269 ///
270 /// let value = session.get::<usize>("foo").await.unwrap();
271 /// assert_eq!(value, Some(42));
272 /// # });
273 /// ```
274 ///
275 /// # Errors
276 ///
277 /// - This method can fail when [`serde_json::from_value`] fails.
278 /// - If the session has not been hydrated and loading from the store fails,
279 /// we fail with [`Error::Store`].
280 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
281 Ok(self
282 .get_value(key)
283 .await?
284 .map(serde_json::from_value)
285 .transpose()?)
286 }
287
288 /// Gets a `serde_json::Value` from the store.
289 ///
290 /// # Examples
291 ///
292 /// ```rust
293 /// # tokio_test::block_on(async {
294 /// use std::sync::Arc;
295 ///
296 /// use tower_sessions_ext::{MemoryStore, Session};
297 ///
298 /// let store = Arc::new(MemoryStore::default());
299 /// let session = Session::new(None, store, None, None);
300 ///
301 /// session.insert("foo", 42).await.unwrap();
302 ///
303 /// let value = session.get_value("foo").await.unwrap().unwrap();
304 /// assert_eq!(value, serde_json::json!(42));
305 /// # });
306 /// ```
307 ///
308 /// # Errors
309 ///
310 /// - If the session has not been hydrated and loading from the store fails,
311 /// we fail with [`Error::Store`].
312 pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
313 let record_guard = self.get_record().await?;
314 Ok(record_guard.data.get(key).cloned())
315 }
316
317 /// Removes a value from the store, retuning the value of the key if it was
318 /// present in the underlying map.
319 ///
320 /// # Examples
321 ///
322 /// ```rust
323 /// # tokio_test::block_on(async {
324 /// use std::sync::Arc;
325 ///
326 /// use tower_sessions_ext::{MemoryStore, Session};
327 ///
328 /// let store = Arc::new(MemoryStore::default());
329 /// let session = Session::new(None, store, None, None);
330 ///
331 /// session.insert("foo", 42).await.unwrap();
332 ///
333 /// let value: Option<usize> = session.remove("foo").await.unwrap();
334 /// assert_eq!(value, Some(42));
335 ///
336 /// let value: Option<usize> = session.get("foo").await.unwrap();
337 /// assert!(value.is_none());
338 /// # });
339 /// ```
340 ///
341 /// # Errors
342 ///
343 /// - This method can fail when [`serde_json::from_value`] fails.
344 /// - If the session has not been hydrated and loading from the store fails,
345 /// we fail with [`Error::Store`].
346 pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
347 Ok(self
348 .remove_value(key)
349 .await?
350 .map(serde_json::from_value)
351 .transpose()?)
352 }
353
354 /// Removes a `serde_json::Value` from the session.
355 ///
356 /// # Examples
357 ///
358 /// ```rust
359 /// # tokio_test::block_on(async {
360 /// use std::sync::Arc;
361 ///
362 /// use tower_sessions_ext::{MemoryStore, Session};
363 ///
364 /// let store = Arc::new(MemoryStore::default());
365 /// let session = Session::new(None, store, None, None);
366 ///
367 /// session.insert("foo", 42).await.unwrap();
368 /// let value = session.remove_value("foo").await.unwrap().unwrap();
369 /// assert_eq!(value, serde_json::json!(42));
370 ///
371 /// let value: Option<usize> = session.get("foo").await.unwrap();
372 /// assert!(value.is_none());
373 /// # });
374 /// ```
375 ///
376 /// # Errors
377 ///
378 /// - If the session has not been hydrated and loading from the store fails,
379 /// we fail with [`Error::Store`].
380 pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
381 let mut record_guard = self.get_record().await?;
382 self.inner
383 .is_modified
384 .store(true, atomic::Ordering::Release);
385 Ok(record_guard.data.remove(key))
386 }
387
388 /// Clears the session of all data but does not delete it from the store.
389 ///
390 /// # Examples
391 ///
392 /// ```rust
393 /// # tokio_test::block_on(async {
394 /// use std::sync::Arc;
395 ///
396 /// use tower_sessions_ext::{MemoryStore, Session};
397 ///
398 /// let store = Arc::new(MemoryStore::default());
399 ///
400 /// let session = Session::new(None, store.clone(), None, None);
401 /// session.insert("foo", 42).await.unwrap();
402 /// assert!(!session.is_empty().await);
403 ///
404 /// session.save().await.unwrap();
405 ///
406 /// session.clear().await;
407 ///
408 /// // Not empty! (We have an ID still.)
409 /// assert!(!session.is_empty().await);
410 /// // Data is cleared...
411 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
412 ///
413 /// // ...data is cleared before loading from the backend...
414 /// let session = Session::new(session.id(), store.clone(), None, None);
415 /// session.clear().await;
416 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
417 ///
418 /// let session = Session::new(session.id(), store, None, None);
419 /// // ...but data is not deleted from the store.
420 /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
421 /// # });
422 /// ```
423 pub async fn clear(&self) {
424 let mut record_guard = self.inner.record.lock().await;
425 if let Some(record) = record_guard.as_mut() {
426 record.data.clear();
427 } else if let Some(session_id) = *self.inner.session_id.lock() {
428 let mut new_record = self.create_record();
429 new_record.id = session_id;
430 *record_guard = Some(new_record);
431 }
432
433 self.inner
434 .is_modified
435 .store(true, atomic::Ordering::Release);
436 }
437
438 /// Returns `true` if there is no session ID and the session is empty.
439 ///
440 /// # Examples
441 ///
442 /// ```rust
443 /// # tokio_test::block_on(async {
444 /// use std::sync::Arc;
445 ///
446 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
447 ///
448 /// let store = Arc::new(MemoryStore::default());
449 ///
450 /// let session = Session::new(None, store.clone(), None, None);
451 /// // Empty if we have no ID and record is not loaded.
452 /// assert!(session.is_empty().await);
453 ///
454 /// let session = Session::new(Some(Id::default()), store.clone(), None, None);
455 /// // Not empty if we have an ID but no record. (Record is not loaded here.)
456 /// assert!(!session.is_empty().await);
457 ///
458 /// let session = Session::new(Some(Id::default()), store.clone(), None, None);
459 /// session.insert("foo", 42).await.unwrap();
460 /// // Not empty after inserting.
461 /// assert!(!session.is_empty().await);
462 /// session.save().await.unwrap();
463 /// // Not empty after saving.
464 /// assert!(!session.is_empty().await);
465 ///
466 /// let session = Session::new(session.id(), store.clone(), None, None);
467 /// session.load().await.unwrap();
468 /// // Not empty after loading from store...
469 /// assert!(!session.is_empty().await);
470 /// // ...and not empty after accessing the session.
471 /// session.get::<usize>("foo").await.unwrap();
472 /// assert!(!session.is_empty().await);
473 ///
474 /// let session = Session::new(session.id(), store.clone(), None, None);
475 /// session.delete().await.unwrap();
476 /// // Not empty after deleting from store...
477 /// assert!(!session.is_empty().await);
478 /// session.get::<usize>("foo").await.unwrap();
479 /// // ...but empty after trying to access the deleted session.
480 /// assert!(session.is_empty().await);
481 ///
482 /// let session = Session::new(None, store, None, None);
483 /// session.insert("foo", 42).await.unwrap();
484 /// session.flush().await.unwrap();
485 /// // Empty after flushing.
486 /// assert!(session.is_empty().await);
487 /// # });
488 /// ```
489 pub async fn is_empty(&self) -> bool {
490 let record_guard = self.inner.record.lock().await;
491
492 // N.B.: Session IDs are `None` if:
493 //
494 // 1. The cookie was not provided or otherwise could not be parsed,
495 // 2. Or the session could not be loaded from the store.
496 let session_id = self.inner.session_id.lock();
497
498 let Some(record) = record_guard.as_ref() else {
499 return session_id.is_none();
500 };
501
502 session_id.is_none() && record.data.is_empty()
503 }
504
505 /// Get the session ID.
506 ///
507 /// # Examples
508 ///
509 /// ```rust
510 /// use std::sync::Arc;
511 ///
512 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
513 ///
514 /// let store = Arc::new(MemoryStore::default());
515 ///
516 /// let session = Session::new(None, store.clone(), None, None);
517 /// assert!(session.id().is_none());
518 ///
519 /// let id = Some(Id::default());
520 /// let session = Session::new(id, store, None, None);
521 /// assert_eq!(id, session.id());
522 /// ```
523 pub fn id(&self) -> Option<Id> {
524 *self.inner.session_id.lock()
525 }
526
527 /// Get the session expiry.
528 ///
529 /// # Examples
530 ///
531 /// ```rust
532 /// use std::sync::Arc;
533 ///
534 /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
535 ///
536 /// let store = Arc::new(MemoryStore::default());
537 /// let session = Session::new(None, store, None, None);
538 ///
539 /// assert_eq!(session.expiry(), None);
540 /// ```
541 pub fn expiry(&self) -> Option<Expiry> {
542 *self.inner.expiry.lock()
543 }
544
545 /// Set `expiry` to the given value.
546 ///
547 /// This may be used within applications directly to alter the session's
548 /// time to live.
549 ///
550 /// # Examples
551 ///
552 /// ```rust
553 /// use std::sync::Arc;
554 ///
555 /// use time::OffsetDateTime;
556 /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
557 ///
558 /// let store = Arc::new(MemoryStore::default());
559 /// let session = Session::new(None, store, None, None);
560 ///
561 /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
562 /// session.set_expiry(Some(expiry));
563 ///
564 /// assert_eq!(session.expiry(), Some(expiry));
565 /// ```
566 pub fn set_expiry(&self, expiry: Option<Expiry>) {
567 *self.inner.expiry.lock() = expiry;
568 self.inner
569 .is_modified
570 .store(true, atomic::Ordering::Release);
571 }
572
573 /// Get session expiry as `OffsetDateTime`.
574 ///
575 /// # Examples
576 ///
577 /// ```rust
578 /// use std::sync::Arc;
579 ///
580 /// use time::{Duration, OffsetDateTime};
581 /// use tower_sessions_ext::{MemoryStore, Session};
582 ///
583 /// let store = Arc::new(MemoryStore::default());
584 /// let session = Session::new(None, store, None, None);
585 ///
586 /// // Our default duration is two weeks.
587 /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
588 ///
589 /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
590 /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
591 /// ```
592 pub fn expiry_date(&self) -> OffsetDateTime {
593 let expiry = self.inner.expiry.lock();
594 match *expiry {
595 Some(Expiry::OnInactivity(duration)) => {
596 OffsetDateTime::now_utc().saturating_add(duration)
597 }
598 Some(Expiry::AtDateTime(datetime)) => datetime,
599 Some(Expiry::OnSessionEnd(datetime)) => {
600 OffsetDateTime::now_utc().saturating_add(datetime)
601 },
602 None =>
603 OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION)
604 }
605 }
606
607 /// Get session expiry as `Duration`.
608 ///
609 /// # Examples
610 ///
611 /// ```rust
612 /// use std::sync::Arc;
613 ///
614 /// use time::Duration;
615 /// use tower_sessions_ext::{MemoryStore, Session};
616 ///
617 /// let store = Arc::new(MemoryStore::default());
618 /// let session = Session::new(None, store, None, None);
619 ///
620 /// let expected_duration = Duration::weeks(2);
621 ///
622 /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
623 /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
624 /// ```
625 pub fn expiry_age(&self) -> Duration {
626 std::cmp::max(
627 self.expiry_date() - OffsetDateTime::now_utc(),
628 Duration::ZERO,
629 )
630 }
631
632 /// Returns `true` if the session has been modified during the request.
633 ///
634 /// # Examples
635 ///
636 /// ```rust
637 /// # tokio_test::block_on(async {
638 /// use std::sync::Arc;
639 ///
640 /// use tower_sessions_ext::{MemoryStore, Session};
641 ///
642 /// let store = Arc::new(MemoryStore::default());
643 /// let session = Session::new(None, store, None, None);
644 ///
645 /// // Not modified initially.
646 /// assert!(!session.is_modified());
647 ///
648 /// // Getting doesn't count as a modification.
649 /// session.get::<usize>("foo").await.unwrap();
650 /// assert!(!session.is_modified());
651 ///
652 /// // Insertions and removals do though.
653 /// session.insert("foo", 42).await.unwrap();
654 /// assert!(session.is_modified());
655 /// # });
656 /// ```
657 pub fn is_modified(&self) -> bool {
658 self.inner.is_modified.load(atomic::Ordering::Acquire)
659 }
660
661 /// Saves the session record to the store.
662 ///
663 /// Note that this method is generally not needed and is reserved for
664 /// situations where the session store must be updated during the
665 /// request.
666 ///
667 /// # Examples
668 ///
669 /// ```rust
670 /// # tokio_test::block_on(async {
671 /// use std::sync::Arc;
672 ///
673 /// use tower_sessions_ext::{MemoryStore, Session};
674 ///
675 /// let store = Arc::new(MemoryStore::default());
676 /// let session = Session::new(None, store.clone(), None, None);
677 ///
678 /// session.insert("foo", 42).await.unwrap();
679 /// session.save().await.unwrap();
680 ///
681 /// let session = Session::new(session.id(), store, None, None);
682 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
683 /// # });
684 /// ```
685 ///
686 /// # Errors
687 ///
688 /// - If saving to the store fails, we fail with [`Error::Store`].
689 #[tracing::instrument(skip(self), err, level = "trace")]
690 pub async fn save(&self) -> Result<()> {
691 let mut record_guard = self.get_record().await?;
692 record_guard.expiry_date = self.expiry_date();
693
694 // Session ID is `None` if:
695 //
696 // 1. No valid cookie was found on the request or,
697 // 2. No valid session was found in the store.
698 //
699 // In either case, we must create a new session via the store interface.
700 //
701 // Potential ID collisions must be handled by session store implementers.
702 if self.inner.session_id.lock().is_none() {
703 self.store.create(&mut record_guard).await?;
704 *self.inner.session_id.lock() = Some(record_guard.id);
705 } else {
706 self.store.save(&record_guard).await?;
707 }
708 Ok(())
709 }
710
711 /// Loads the session record from the store.
712 ///
713 /// Note that this method is generally not needed and is reserved for
714 /// situations where the session must be updated during the request.
715 ///
716 /// # Examples
717 ///
718 /// ```rust
719 /// # tokio_test::block_on(async {
720 /// use std::sync::Arc;
721 ///
722 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
723 ///
724 /// let store = Arc::new(MemoryStore::default());
725 /// let id = Some(Id::default());
726 /// let session = Session::new(id, store.clone(), None, None);
727 ///
728 /// session.insert("foo", 42).await.unwrap();
729 /// session.save().await.unwrap();
730 ///
731 /// let session = Session::new(session.id(), store, None, None);
732 /// session.load().await.unwrap();
733 ///
734 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
735 /// # });
736 /// ```
737 ///
738 /// # Errors
739 ///
740 /// - If loading from the store fails, we fail with [`Error::Store`].
741 #[tracing::instrument(skip(self), err, level = "trace")]
742 pub async fn load(&self) -> Result<()> {
743 let session_id = *self.inner.session_id.lock();
744 let Some(ref id) = session_id else {
745 tracing::warn!("called load with no session id");
746 return Ok(());
747 };
748 let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
749 let mut record_guard = self.inner.record.lock().await;
750 *record_guard = loaded_record;
751 Ok(())
752 }
753
754 /// Deletes the session from the store.
755 ///
756 /// # Examples
757 ///
758 /// ```rust
759 /// # tokio_test::block_on(async {
760 /// use std::sync::Arc;
761 ///
762 /// use tower_sessions_ext::{MemoryStore, Session, SessionStore, session::Id};
763 ///
764 /// let store = Arc::new(MemoryStore::default());
765 /// let session = Session::new(Some(Id::default()), store.clone(), None, None);
766 ///
767 /// // Save before deleting.
768 /// session.save().await.unwrap();
769 ///
770 /// // Delete from the store.
771 /// session.delete().await.unwrap();
772 ///
773 /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
774 /// # });
775 /// ```
776 ///
777 /// # Errors
778 ///
779 /// - If deleting from the store fails, we fail with [`Error::Store`].
780 #[tracing::instrument(skip(self), err, level = "trace")]
781 pub async fn delete(&self) -> Result<()> {
782 let session_id = *self.inner.session_id.lock();
783 let Some(ref session_id) = session_id else {
784 tracing::warn!("called delete with no session id");
785 return Ok(());
786 };
787 self.store.delete(session_id).await.map_err(Error::Store)?;
788 Ok(())
789 }
790
791 /// Flushes the session by removing all data contained in the session and
792 /// then deleting it from the store.
793 ///
794 /// # Examples
795 ///
796 /// ```rust
797 /// # tokio_test::block_on(async {
798 /// use std::sync::Arc;
799 ///
800 /// use tower_sessions_ext::{MemoryStore, Session, SessionStore};
801 ///
802 /// let store = Arc::new(MemoryStore::default());
803 /// let session = Session::new(None, store.clone(), None, None);
804 ///
805 /// session.insert("foo", "bar").await.unwrap();
806 /// session.save().await.unwrap();
807 ///
808 /// let id = session.id().unwrap();
809 ///
810 /// session.flush().await.unwrap();
811 ///
812 /// assert!(session.id().is_none());
813 /// assert!(session.is_empty().await);
814 /// assert!(store.load(&id).await.unwrap().is_none());
815 /// # });
816 /// ```
817 ///
818 /// # Errors
819 ///
820 /// - If deleting from the store fails, we fail with [`Error::Store`].
821 pub async fn flush(&self) -> Result<()> {
822 self.clear().await;
823 self.delete().await?;
824 *self.inner.session_id.lock() = None;
825 Ok(())
826 }
827
828 /// Cycles the session ID while retaining any data that was associated with
829 /// it.
830 ///
831 /// Using this method helps prevent session fixation attacks by ensuring a
832 /// new ID is assigned to the session.
833 ///
834 /// # Examples
835 ///
836 /// ```rust
837 /// # tokio_test::block_on(async {
838 /// use std::sync::Arc;
839 ///
840 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
841 ///
842 /// let store = Arc::new(MemoryStore::default());
843 /// let session = Session::new(None, store.clone(), None, None);
844 ///
845 /// session.insert("foo", 42).await.unwrap();
846 /// session.save().await.unwrap();
847 /// let id = session.id();
848 ///
849 /// let session = Session::new(session.id(), store.clone(), None, None);
850 /// session.cycle_id().await.unwrap();
851 ///
852 /// assert!(!session.is_empty().await);
853 /// assert!(session.is_modified());
854 ///
855 /// session.save().await.unwrap();
856 ///
857 /// let session = Session::new(session.id(), store, None, None);
858 ///
859 /// assert_ne!(id, session.id());
860 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
861 /// # });
862 /// ```
863 ///
864 /// # Errors
865 ///
866 /// - If deleting from the store fails or saving to the store fails, we fail
867 /// with [`Error::Store`].
868 pub async fn cycle_id(&self) -> Result<()> {
869 let mut record_guard = self.get_record().await?;
870
871 let old_session_id = record_guard.id;
872 record_guard.id = Id::default();
873 *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
874 // `create` method.
875
876 self.store
877 .delete(&old_session_id)
878 .await
879 .map_err(Error::Store)?;
880
881 self.inner
882 .is_modified
883 .store(true, atomic::Ordering::Release);
884
885 Ok(())
886 }
887}
888
/// ID type for sessions.
///
/// Wraps a 128-bit signed integer (16 bytes). [`Id::default`] draws a random
/// value. The [`Display`] form is the little-endian bytes encoded as 22
/// characters of unpadded URL-safe base64, and [`FromStr`] parses that same
/// form back.
///
/// # Examples
///
/// ```rust
/// use tower_sessions_ext::session::Id;
///
/// Id::default();
/// ```
#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
// session ID, which is undesirable.
903
904impl Default for Id {
905 fn default() -> Self {
906 use rand::prelude::*;
907
908 Self(rand::rng().random())
909 }
910}
911
912impl Display for Id {
913 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
914 let mut encoded = [0; 22];
915 URL_SAFE_NO_PAD
916 .encode_slice(self.0.to_le_bytes(), &mut encoded)
917 .expect("Encoded ID must be exactly 22 bytes");
918 let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
919
920 f.write_str(encoded)
921 }
922}
923
924impl FromStr for Id {
925 type Err = base64::DecodeSliceError;
926
927 fn from_str(s: &str) -> result::Result<Self, Self::Err> {
928 let mut decoded = [0; 16];
929 let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
930 if bytes_decoded != 16 {
931 let err = DecodeError::InvalidLength(bytes_decoded);
932 return Err(base64::DecodeSliceError::DecodeError(err));
933 }
934
935 Ok(Self(i128::from_le_bytes(decoded)))
936 }
937}
938
/// Record type that's appropriate for encoding and decoding sessions to and
/// from session stores.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Record {
    /// The session ID this record is stored under.
    pub id: Id,
    /// The session's key-value data.
    pub data: Data,
    /// When the record expires; `Session::save` refreshes this from the
    /// session's expiry before writing to the store.
    pub expiry_date: OffsetDateTime,
}
947
948impl Record {
949 fn new(expiry_date: OffsetDateTime) -> Self {
950 Self {
951 id: Id::default(),
952 data: Data::default(),
953 expiry_date,
954 }
955 }
956}
957
958/// Session expiry configuration.
959///
960/// # Examples
961///
962/// ```rust
963/// use time::{Duration, OffsetDateTime};
964/// use tower_sessions_ext::Expiry;
965///
966/// // Will be expired on "session end".
967/// let expiry = Expiry::OnSessionEnd;
968///
969/// // Will be expired in five minutes from last acitve.
970/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
971///
972/// // Will be expired at the given timestamp.
973/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
974/// let expiry = Expiry::AtDateTime(expired_at);
975/// ```
976#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
977pub enum Expiry {
978 /// Expire on [current session end][current-session-end], as defined by the
979 /// browser.
980 ///
981 /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
982 OnSessionEnd(Duration),
983
984 /// Expire on inactivity.
985 ///
986 /// Reading a session is not considered activity for expiration purposes.
987 /// [`Session`] expiration is computed from the last time the session was
988 /// _modified_.
989 OnInactivity(Duration),
990
991 /// Expire at a specific date and time.
992 ///
993 /// This value may be extended manually with
994 /// [`set_expiry`](Session::set_expiry).
995 AtDateTime(OffsetDateTime),
996}
997
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    // Mock `SessionStore` so tests can assert exactly which store operations a
    // `Session` performs, with which arguments, and how many times.
    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    /// `cycle_id` must delete the old record, keep the in-memory data, and make
    /// the next `save` go through the store's `create` path with a fresh ID.
    #[tokio::test]
    async fn test_cycle_id() {
        let mut mock_store = MockStore::new();

        let initial_id = Id::default();
        let new_id = Id::default();

        // Set up expectations for the mock store
        // First `save` happens while the session still has its initial ID.
        mock_store
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));
        // Lazy load triggered by the first `insert`.
        mock_store
            .expect_load()
            .with(predicate::eq(initial_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: initial_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });
        // `cycle_id` removes the record stored under the old ID.
        mock_store
            .expect_delete()
            .with(predicate::eq(initial_id))
            .times(1)
            .returning(|_| Ok(()));
        // Second `save` (after cycling) creates the record and assigns `new_id`.
        mock_store
            .expect_create()
            .times(1)
            .returning(move |record| {
                record.id = new_id;
                Ok(())
            });

        let store = Arc::new(mock_store);
        let session = Session::new(Some(initial_id), store.clone(), None, None);

        // Insert some data and save the session
        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        // Cycle the session ID
        session.cycle_id().await.unwrap();

        // Verify that the session ID has changed and the data is still present
        assert_ne!(session.id(), Some(initial_id));
        assert!(session.id().is_none()); // The session ID should be None
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Save the session to update the ID in the session object
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(new_id));
    }
}