tower_sessions_ext_core/session.rs
1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3 collections::HashMap,
4 fmt::{self, Display},
5 hash::Hash,
6 result,
7 str::{self, FromStr},
8 sync::{
9 Arc,
10 atomic::{self, AtomicBool},
11 },
12};
13
14use base64::{DecodeError, Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize, de::DeserializeOwned};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{SessionStore, session_store};
21
22pub const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24type Result<T> = result::Result<T, Error>;
25
26type Data = HashMap<String, Value>;
27
28/// Session errors.
29#[derive(thiserror::Error, Debug)]
30pub enum Error {
31 /// Maps `serde_json` errors.
32 #[error(transparent)]
33 SerdeJson(#[from] serde_json::Error),
34
35 /// Maps `session_store::Error` errors.
36 #[error(transparent)]
37 Store(#[from] session_store::Error),
38}
39
40#[derive(Debug)]
41struct Inner {
42 // This will be `None` when:
43 //
44 // 1. We have not been provided a session cookie or have failed to parse it,
45 // 2. The store has not found the session.
46 //
47 // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
48 session_id: parking_lot::Mutex<Option<Id>>,
49
50 // A lazy representation of the session's value, hydrated on a just-in-time basis. A
51 // `None` value indicates we have not tried to access it yet. After access, it will always
52 // contain `Some(Record)`.
53 record: Mutex<Option<Record>>,
54
55 // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
56 expiry: parking_lot::Mutex<Option<Expiry>>,
57
58 is_modified: AtomicBool,
59}
60
61/// A session which allows HTTP applications to associate key-value pairs with
62/// visitors.
63#[derive(Clone)]
64pub struct Session {
65 store: Arc<dyn SessionStore>,
66 inner: Arc<Inner>,
67}
68
69impl fmt::Debug for Session {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("Session")
72 .field("store", &self.store)
73 .field("inner", &self.inner)
74 .finish()
75 }
76}
77
78impl Session {
79 /// Creates a new session with the session ID, store, and expiry.
80 ///
81 /// This method is lazy and does not invoke the overhead of talking to the
82 /// backing store.
83 ///
84 /// # Examples
85 ///
86 /// ```rust
87 /// use std::sync::Arc;
88 ///
89 /// use tower_sessions_ext::{MemoryStore, Session};
90 ///
91 /// let store = Arc::new(MemoryStore::default());
92 /// Session::new(None, store, None);
93 /// ```
94 pub fn new(
95 session_id: Option<Id>,
96 store: Arc<impl SessionStore>,
97 expiry: Option<Expiry>,
98 ) -> Self {
99 let inner = Inner {
100 session_id: parking_lot::Mutex::new(session_id),
101 record: Mutex::new(None), // `None` indicates we have not loaded from store.
102 expiry: parking_lot::Mutex::new(expiry),
103 is_modified: AtomicBool::new(false),
104 };
105
106 Self {
107 store,
108 inner: Arc::new(inner),
109 }
110 }
111
112 fn create_record(&self) -> Record {
113 Record::new(self.expiry_date())
114 }
115
116 #[tracing::instrument(skip(self), err, level = "trace")]
117 async fn get_record(&'_ self) -> Result<MappedMutexGuard<'_, Record>> {
118 let mut record_guard = self.inner.record.lock().await;
119
120 // Lazily load the record since `None` here indicates we have no yet loaded it.
121 if record_guard.is_none() {
122 tracing::trace!("record not loaded from store; loading");
123
124 let session_id = *self.inner.session_id.lock();
125 *record_guard = Some(if let Some(session_id) = session_id {
126 match self.store.load(&session_id).await? {
127 Some(loaded_record) => {
128 tracing::trace!("record found in store");
129 loaded_record
130 }
131
132 None => {
133 // A well-behaved user agent should not send session cookies after
134 // expiration. Even so it's possible for an expired session to be removed
135 // from the store after a request was initiated. However, such a race should
136 // be relatively uncommon and as such entering this branch could indicate
137 // malicious behavior.
138 tracing::warn!("possibly suspicious activity: record not found in store");
139 *self.inner.session_id.lock() = None;
140 self.create_record()
141 }
142 }
143 } else {
144 tracing::trace!("session id not found");
145 self.create_record()
146 })
147 }
148
149 Ok(MutexGuard::map(record_guard, |opt| {
150 opt.as_mut()
151 .expect("Record should always be `Option::Some` at this point")
152 }))
153 }
154
155 /// Inserts a `impl Serialize` value into the session.
156 ///
157 /// # Examples
158 ///
159 /// ```rust
160 /// # tokio_test::block_on(async {
161 /// use std::sync::Arc;
162 ///
163 /// use tower_sessions_ext::{MemoryStore, Session};
164 ///
165 /// let store = Arc::new(MemoryStore::default());
166 /// let session = Session::new(None, store, None);
167 ///
168 /// session.insert("foo", 42).await.unwrap();
169 ///
170 /// let value = session.get::<usize>("foo").await.unwrap();
171 /// assert_eq!(value, Some(42));
172 /// # });
173 /// ```
174 ///
175 /// # Errors
176 ///
177 /// - This method can fail when [`serde_json::to_value`] fails.
178 /// - If the session has not been hydrated and loading from the store fails,
179 /// we fail with [`Error::Store`].
180 pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
181 self.insert_value(key, serde_json::to_value(&value)?)
182 .await?;
183 Ok(())
184 }
185
186 /// Inserts a `serde_json::Value` into the session.
187 ///
188 /// If the key was not present in the underlying map, `None` is returned and
189 /// `modified` is set to `true`.
190 ///
191 /// If the underlying map did have the key and its value is the same as the
192 /// provided value, `None` is returned and `modified` is not set.
193 ///
194 /// # Examples
195 ///
196 /// ```rust
197 /// # tokio_test::block_on(async {
198 /// use std::sync::Arc;
199 ///
200 /// use tower_sessions_ext::{MemoryStore, Session};
201 ///
202 /// let store = Arc::new(MemoryStore::default());
203 /// let session = Session::new(None, store, None);
204 ///
205 /// let value = session
206 /// .insert_value("foo", serde_json::json!(42))
207 /// .await
208 /// .unwrap();
209 /// assert!(value.is_none());
210 ///
211 /// let value = session
212 /// .insert_value("foo", serde_json::json!(42))
213 /// .await
214 /// .unwrap();
215 /// assert!(value.is_none());
216 ///
217 /// let value = session
218 /// .insert_value("foo", serde_json::json!("bar"))
219 /// .await
220 /// .unwrap();
221 /// assert_eq!(value, Some(serde_json::json!(42)));
222 /// # });
223 /// ```
224 ///
225 /// # Errors
226 ///
227 /// - If the session has not been hydrated and loading from the store fails,
228 /// we fail with [`Error::Store`].
229 pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
230 let mut record_guard = self.get_record().await?;
231 Ok(if record_guard.data.get(key) != Some(&value) {
232 self.inner
233 .is_modified
234 .store(true, atomic::Ordering::Release);
235 record_guard.data.insert(key.to_string(), value)
236 } else {
237 None
238 })
239 }
240
241 /// Gets a value from the store.
242 ///
243 /// # Examples
244 ///
245 /// ```rust
246 /// # tokio_test::block_on(async {
247 /// use std::sync::Arc;
248 ///
249 /// use tower_sessions_ext::{MemoryStore, Session};
250 ///
251 /// let store = Arc::new(MemoryStore::default());
252 /// let session = Session::new(None, store, None);
253 ///
254 /// session.insert("foo", 42).await.unwrap();
255 ///
256 /// let value = session.get::<usize>("foo").await.unwrap();
257 /// assert_eq!(value, Some(42));
258 /// # });
259 /// ```
260 ///
261 /// # Errors
262 ///
263 /// - This method can fail when [`serde_json::from_value`] fails.
264 /// - If the session has not been hydrated and loading from the store fails,
265 /// we fail with [`Error::Store`].
266 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
267 Ok(self
268 .get_value(key)
269 .await?
270 .map(serde_json::from_value)
271 .transpose()?)
272 }
273
274 /// Gets a `serde_json::Value` from the store.
275 ///
276 /// # Examples
277 ///
278 /// ```rust
279 /// # tokio_test::block_on(async {
280 /// use std::sync::Arc;
281 ///
282 /// use tower_sessions_ext::{MemoryStore, Session};
283 ///
284 /// let store = Arc::new(MemoryStore::default());
285 /// let session = Session::new(None, store, None);
286 ///
287 /// session.insert("foo", 42).await.unwrap();
288 ///
289 /// let value = session.get_value("foo").await.unwrap().unwrap();
290 /// assert_eq!(value, serde_json::json!(42));
291 /// # });
292 /// ```
293 ///
294 /// # Errors
295 ///
296 /// - If the session has not been hydrated and loading from the store fails,
297 /// we fail with [`Error::Store`].
298 pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
299 let record_guard = self.get_record().await?;
300 Ok(record_guard.data.get(key).cloned())
301 }
302
303 /// Removes a value from the store, retuning the value of the key if it was
304 /// present in the underlying map.
305 ///
306 /// # Examples
307 ///
308 /// ```rust
309 /// # tokio_test::block_on(async {
310 /// use std::sync::Arc;
311 ///
312 /// use tower_sessions_ext::{MemoryStore, Session};
313 ///
314 /// let store = Arc::new(MemoryStore::default());
315 /// let session = Session::new(None, store, None);
316 ///
317 /// session.insert("foo", 42).await.unwrap();
318 ///
319 /// let value: Option<usize> = session.remove("foo").await.unwrap();
320 /// assert_eq!(value, Some(42));
321 ///
322 /// let value: Option<usize> = session.get("foo").await.unwrap();
323 /// assert!(value.is_none());
324 /// # });
325 /// ```
326 ///
327 /// # Errors
328 ///
329 /// - This method can fail when [`serde_json::from_value`] fails.
330 /// - If the session has not been hydrated and loading from the store fails,
331 /// we fail with [`Error::Store`].
332 pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
333 Ok(self
334 .remove_value(key)
335 .await?
336 .map(serde_json::from_value)
337 .transpose()?)
338 }
339
340 /// Removes a `serde_json::Value` from the session.
341 ///
342 /// # Examples
343 ///
344 /// ```rust
345 /// # tokio_test::block_on(async {
346 /// use std::sync::Arc;
347 ///
348 /// use tower_sessions_ext::{MemoryStore, Session};
349 ///
350 /// let store = Arc::new(MemoryStore::default());
351 /// let session = Session::new(None, store, None);
352 ///
353 /// session.insert("foo", 42).await.unwrap();
354 /// let value = session.remove_value("foo").await.unwrap().unwrap();
355 /// assert_eq!(value, serde_json::json!(42));
356 ///
357 /// let value: Option<usize> = session.get("foo").await.unwrap();
358 /// assert!(value.is_none());
359 /// # });
360 /// ```
361 ///
362 /// # Errors
363 ///
364 /// - If the session has not been hydrated and loading from the store fails,
365 /// we fail with [`Error::Store`].
366 pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
367 let mut record_guard = self.get_record().await?;
368 self.inner
369 .is_modified
370 .store(true, atomic::Ordering::Release);
371 Ok(record_guard.data.remove(key))
372 }
373
374 /// Clears the session of all data but does not delete it from the store.
375 ///
376 /// # Examples
377 ///
378 /// ```rust
379 /// # tokio_test::block_on(async {
380 /// use std::sync::Arc;
381 ///
382 /// use tower_sessions_ext::{MemoryStore, Session};
383 ///
384 /// let store = Arc::new(MemoryStore::default());
385 ///
386 /// let session = Session::new(None, store.clone(), None);
387 /// session.insert("foo", 42).await.unwrap();
388 /// assert!(!session.is_empty().await);
389 ///
390 /// session.save().await.unwrap();
391 ///
392 /// session.clear().await;
393 ///
394 /// // Not empty! (We have an ID still.)
395 /// assert!(!session.is_empty().await);
396 /// // Data is cleared...
397 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
398 ///
399 /// // ...data is cleared before loading from the backend...
400 /// let session = Session::new(session.id(), store.clone(), None);
401 /// session.clear().await;
402 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
403 ///
404 /// let session = Session::new(session.id(), store, None);
405 /// // ...but data is not deleted from the store.
406 /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
407 /// # });
408 /// ```
409 pub async fn clear(&self) {
410 let mut record_guard = self.inner.record.lock().await;
411 if let Some(record) = record_guard.as_mut() {
412 record.data.clear();
413 } else if let Some(session_id) = *self.inner.session_id.lock() {
414 let mut new_record = self.create_record();
415 new_record.id = session_id;
416 *record_guard = Some(new_record);
417 }
418
419 self.inner
420 .is_modified
421 .store(true, atomic::Ordering::Release);
422 }
423
424 /// Returns `true` if there is no session ID and the session is empty.
425 ///
426 /// # Examples
427 ///
428 /// ```rust
429 /// # tokio_test::block_on(async {
430 /// use std::sync::Arc;
431 ///
432 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
433 ///
434 /// let store = Arc::new(MemoryStore::default());
435 ///
436 /// let session = Session::new(None, store.clone(), None);
437 /// // Empty if we have no ID and record is not loaded.
438 /// assert!(session.is_empty().await);
439 ///
440 /// let session = Session::new(Some(Id::default()), store.clone(), None);
441 /// // Not empty if we have an ID but no record. (Record is not loaded here.)
442 /// assert!(!session.is_empty().await);
443 ///
444 /// let session = Session::new(Some(Id::default()), store.clone(), None);
445 /// session.insert("foo", 42).await.unwrap();
446 /// // Not empty after inserting.
447 /// assert!(!session.is_empty().await);
448 /// session.save().await.unwrap();
449 /// // Not empty after saving.
450 /// assert!(!session.is_empty().await);
451 ///
452 /// let session = Session::new(session.id(), store.clone(), None);
453 /// session.load().await.unwrap();
454 /// // Not empty after loading from store...
455 /// assert!(!session.is_empty().await);
456 /// // ...and not empty after accessing the session.
457 /// session.get::<usize>("foo").await.unwrap();
458 /// assert!(!session.is_empty().await);
459 ///
460 /// let session = Session::new(session.id(), store.clone(), None);
461 /// session.delete().await.unwrap();
462 /// // Not empty after deleting from store...
463 /// assert!(!session.is_empty().await);
464 /// session.get::<usize>("foo").await.unwrap();
465 /// // ...but empty after trying to access the deleted session.
466 /// assert!(session.is_empty().await);
467 ///
468 /// let session = Session::new(None, store, None);
469 /// session.insert("foo", 42).await.unwrap();
470 /// session.flush().await.unwrap();
471 /// // Empty after flushing.
472 /// assert!(session.is_empty().await);
473 /// # });
474 /// ```
475 pub async fn is_empty(&self) -> bool {
476 let record_guard = self.inner.record.lock().await;
477
478 // N.B.: Session IDs are `None` if:
479 //
480 // 1. The cookie was not provided or otherwise could not be parsed,
481 // 2. Or the session could not be loaded from the store.
482 let session_id = self.inner.session_id.lock();
483
484 let Some(record) = record_guard.as_ref() else {
485 return session_id.is_none();
486 };
487
488 session_id.is_none() && record.data.is_empty()
489 }
490
491 /// Get the session ID.
492 ///
493 /// # Examples
494 ///
495 /// ```rust
496 /// use std::sync::Arc;
497 ///
498 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
499 ///
500 /// let store = Arc::new(MemoryStore::default());
501 ///
502 /// let session = Session::new(None, store.clone(), None);
503 /// assert!(session.id().is_none());
504 ///
505 /// let id = Some(Id::default());
506 /// let session = Session::new(id, store, None);
507 /// assert_eq!(id, session.id());
508 /// ```
509 pub fn id(&self) -> Option<Id> {
510 *self.inner.session_id.lock()
511 }
512
513 /// Get the session expiry.
514 ///
515 /// # Examples
516 ///
517 /// ```rust
518 /// use std::sync::Arc;
519 ///
520 /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
521 ///
522 /// let store = Arc::new(MemoryStore::default());
523 /// let session = Session::new(None, store, None);
524 ///
525 /// assert_eq!(session.expiry(), None);
526 /// ```
527 pub fn expiry(&self) -> Option<Expiry> {
528 *self.inner.expiry.lock()
529 }
530
531 /// Set `expiry` to the given value.
532 ///
533 /// This may be used within applications directly to alter the session's
534 /// time to live.
535 ///
536 /// # Examples
537 ///
538 /// ```rust
539 /// use std::sync::Arc;
540 ///
541 /// use time::OffsetDateTime;
542 /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
543 ///
544 /// let store = Arc::new(MemoryStore::default());
545 /// let session = Session::new(None, store, None);
546 ///
547 /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
548 /// session.set_expiry(Some(expiry));
549 ///
550 /// assert_eq!(session.expiry(), Some(expiry));
551 /// ```
552 pub fn set_expiry(&self, expiry: Option<Expiry>) {
553 *self.inner.expiry.lock() = expiry;
554 self.inner
555 .is_modified
556 .store(true, atomic::Ordering::Release);
557 }
558
559 /// Get session expiry as `OffsetDateTime`.
560 ///
561 /// # Examples
562 ///
563 /// ```rust
564 /// use std::sync::Arc;
565 ///
566 /// use time::{Duration, OffsetDateTime};
567 /// use tower_sessions_ext::{MemoryStore, Session};
568 ///
569 /// let store = Arc::new(MemoryStore::default());
570 /// let session = Session::new(None, store, None);
571 ///
572 /// // Our default duration is two weeks.
573 /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
574 ///
575 /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
576 /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
577 /// ```
578 pub fn expiry_date(&self) -> OffsetDateTime {
579 let expiry = self.inner.expiry.lock();
580 match *expiry {
581 Some(Expiry::OnInactivity(duration)) => {
582 OffsetDateTime::now_utc().saturating_add(duration)
583 }
584 Some(Expiry::AtDateTime(datetime)) => datetime,
585 Some(Expiry::OnSessionEnd(datetime)) => {
586 OffsetDateTime::now_utc().saturating_add(datetime)
587 },
588 None =>
589 OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION)
590 }
591 }
592
593 /// Get session expiry as `Duration`.
594 ///
595 /// # Examples
596 ///
597 /// ```rust
598 /// use std::sync::Arc;
599 ///
600 /// use time::Duration;
601 /// use tower_sessions_ext::{MemoryStore, Session};
602 ///
603 /// let store = Arc::new(MemoryStore::default());
604 /// let session = Session::new(None, store, None);
605 ///
606 /// let expected_duration = Duration::weeks(2);
607 ///
608 /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
609 /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
610 /// ```
611 pub fn expiry_age(&self) -> Duration {
612 std::cmp::max(
613 self.expiry_date() - OffsetDateTime::now_utc(),
614 Duration::ZERO,
615 )
616 }
617
618 /// Returns `true` if the session has been modified during the request.
619 ///
620 /// # Examples
621 ///
622 /// ```rust
623 /// # tokio_test::block_on(async {
624 /// use std::sync::Arc;
625 ///
626 /// use tower_sessions_ext::{MemoryStore, Session};
627 ///
628 /// let store = Arc::new(MemoryStore::default());
629 /// let session = Session::new(None, store, None);
630 ///
631 /// // Not modified initially.
632 /// assert!(!session.is_modified());
633 ///
634 /// // Getting doesn't count as a modification.
635 /// session.get::<usize>("foo").await.unwrap();
636 /// assert!(!session.is_modified());
637 ///
638 /// // Insertions and removals do though.
639 /// session.insert("foo", 42).await.unwrap();
640 /// assert!(session.is_modified());
641 /// # });
642 /// ```
643 pub fn is_modified(&self) -> bool {
644 self.inner.is_modified.load(atomic::Ordering::Acquire)
645 }
646
647 /// Saves the session record to the store.
648 ///
649 /// Note that this method is generally not needed and is reserved for
650 /// situations where the session store must be updated during the
651 /// request.
652 ///
653 /// # Examples
654 ///
655 /// ```rust
656 /// # tokio_test::block_on(async {
657 /// use std::sync::Arc;
658 ///
659 /// use tower_sessions_ext::{MemoryStore, Session};
660 ///
661 /// let store = Arc::new(MemoryStore::default());
662 /// let session = Session::new(None, store.clone(), None);
663 ///
664 /// session.insert("foo", 42).await.unwrap();
665 /// session.save().await.unwrap();
666 ///
667 /// let session = Session::new(session.id(), store, None);
668 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
669 /// # });
670 /// ```
671 ///
672 /// # Errors
673 ///
674 /// - If saving to the store fails, we fail with [`Error::Store`].
675 #[tracing::instrument(skip(self), err, level = "trace")]
676 pub async fn save(&self) -> Result<()> {
677 let mut record_guard = self.get_record().await?;
678 record_guard.expiry_date = self.expiry_date();
679
680 // Session ID is `None` if:
681 //
682 // 1. No valid cookie was found on the request or,
683 // 2. No valid session was found in the store.
684 //
685 // In either case, we must create a new session via the store interface.
686 //
687 // Potential ID collisions must be handled by session store implementers.
688 if self.inner.session_id.lock().is_none() {
689 self.store.create(&mut record_guard).await?;
690 *self.inner.session_id.lock() = Some(record_guard.id);
691 } else {
692 self.store.save(&record_guard).await?;
693 }
694 Ok(())
695 }
696
697 /// Loads the session record from the store.
698 ///
699 /// Note that this method is generally not needed and is reserved for
700 /// situations where the session must be updated during the request.
701 ///
702 /// # Examples
703 ///
704 /// ```rust
705 /// # tokio_test::block_on(async {
706 /// use std::sync::Arc;
707 ///
708 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
709 ///
710 /// let store = Arc::new(MemoryStore::default());
711 /// let id = Some(Id::default());
712 /// let session = Session::new(id, store.clone(), None);
713 ///
714 /// session.insert("foo", 42).await.unwrap();
715 /// session.save().await.unwrap();
716 ///
717 /// let session = Session::new(session.id(), store, None);
718 /// session.load().await.unwrap();
719 ///
720 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
721 /// # });
722 /// ```
723 ///
724 /// # Errors
725 ///
726 /// - If loading from the store fails, we fail with [`Error::Store`].
727 #[tracing::instrument(skip(self), err, level = "trace")]
728 pub async fn load(&self) -> Result<()> {
729 let session_id = *self.inner.session_id.lock();
730 let Some(ref id) = session_id else {
731 tracing::warn!("called load with no session id");
732 return Ok(());
733 };
734 let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
735 let mut record_guard = self.inner.record.lock().await;
736 *record_guard = loaded_record;
737 Ok(())
738 }
739
740 /// Deletes the session from the store.
741 ///
742 /// # Examples
743 ///
744 /// ```rust
745 /// # tokio_test::block_on(async {
746 /// use std::sync::Arc;
747 ///
748 /// use tower_sessions_ext::{MemoryStore, Session, SessionStore, session::Id};
749 ///
750 /// let store = Arc::new(MemoryStore::default());
751 /// let session = Session::new(Some(Id::default()), store.clone(), None);
752 ///
753 /// // Save before deleting.
754 /// session.save().await.unwrap();
755 ///
756 /// // Delete from the store.
757 /// session.delete().await.unwrap();
758 ///
759 /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
760 /// # });
761 /// ```
762 ///
763 /// # Errors
764 ///
765 /// - If deleting from the store fails, we fail with [`Error::Store`].
766 #[tracing::instrument(skip(self), err, level = "trace")]
767 pub async fn delete(&self) -> Result<()> {
768 let session_id = *self.inner.session_id.lock();
769 let Some(ref session_id) = session_id else {
770 tracing::warn!("called delete with no session id");
771 return Ok(());
772 };
773 self.store.delete(session_id).await.map_err(Error::Store)?;
774 Ok(())
775 }
776
777 /// Flushes the session by removing all data contained in the session and
778 /// then deleting it from the store.
779 ///
780 /// # Examples
781 ///
782 /// ```rust
783 /// # tokio_test::block_on(async {
784 /// use std::sync::Arc;
785 ///
786 /// use tower_sessions_ext::{MemoryStore, Session, SessionStore};
787 ///
788 /// let store = Arc::new(MemoryStore::default());
789 /// let session = Session::new(None, store.clone(), None);
790 ///
791 /// session.insert("foo", "bar").await.unwrap();
792 /// session.save().await.unwrap();
793 ///
794 /// let id = session.id().unwrap();
795 ///
796 /// session.flush().await.unwrap();
797 ///
798 /// assert!(session.id().is_none());
799 /// assert!(session.is_empty().await);
800 /// assert!(store.load(&id).await.unwrap().is_none());
801 /// # });
802 /// ```
803 ///
804 /// # Errors
805 ///
806 /// - If deleting from the store fails, we fail with [`Error::Store`].
807 pub async fn flush(&self) -> Result<()> {
808 self.clear().await;
809 self.delete().await?;
810 *self.inner.session_id.lock() = None;
811 Ok(())
812 }
813
814 /// Cycles the session ID while retaining any data that was associated with
815 /// it.
816 ///
817 /// Using this method helps prevent session fixation attacks by ensuring a
818 /// new ID is assigned to the session.
819 ///
820 /// # Examples
821 ///
822 /// ```rust
823 /// # tokio_test::block_on(async {
824 /// use std::sync::Arc;
825 ///
826 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
827 ///
828 /// let store = Arc::new(MemoryStore::default());
829 /// let session = Session::new(None, store.clone(), None);
830 ///
831 /// session.insert("foo", 42).await.unwrap();
832 /// session.save().await.unwrap();
833 /// let id = session.id();
834 ///
835 /// let session = Session::new(session.id(), store.clone(), None);
836 /// session.cycle_id().await.unwrap();
837 ///
838 /// assert!(!session.is_empty().await);
839 /// assert!(session.is_modified());
840 ///
841 /// session.save().await.unwrap();
842 ///
843 /// let session = Session::new(session.id(), store, None);
844 ///
845 /// assert_ne!(id, session.id());
846 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
847 /// # });
848 /// ```
849 ///
850 /// # Errors
851 ///
852 /// - If deleting from the store fails or saving to the store fails, we fail
853 /// with [`Error::Store`].
854 pub async fn cycle_id(&self) -> Result<()> {
855 let mut record_guard = self.get_record().await?;
856
857 let old_session_id = record_guard.id;
858 record_guard.id = Id::default();
859 *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
860 // `create` method.
861
862 self.store
863 .delete(&old_session_id)
864 .await
865 .map_err(Error::Store)?;
866
867 self.inner
868 .is_modified
869 .store(true, atomic::Ordering::Release);
870
871 Ok(())
872 }
873}
874
875/// ID type for sessions.
876///
877/// Wraps an array of 16 bytes.
878///
879/// # Examples
880///
881/// ```rust
882/// use tower_sessions_ext::session::Id;
883///
884/// Id::default();
885/// ```
886#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
887pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
888// session ID, which is undesirable.
889
890impl Default for Id {
891 fn default() -> Self {
892 use rand::prelude::*;
893
894 Self(rand::rng().random())
895 }
896}
897
898impl Display for Id {
899 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
900 let mut encoded = [0; 22];
901 URL_SAFE_NO_PAD
902 .encode_slice(self.0.to_le_bytes(), &mut encoded)
903 .expect("Encoded ID must be exactly 22 bytes");
904 let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
905
906 f.write_str(encoded)
907 }
908}
909
910impl FromStr for Id {
911 type Err = base64::DecodeSliceError;
912
913 fn from_str(s: &str) -> result::Result<Self, Self::Err> {
914 let mut decoded = [0; 16];
915 let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
916 if bytes_decoded != 16 {
917 let err = DecodeError::InvalidLength(bytes_decoded);
918 return Err(base64::DecodeSliceError::DecodeError(err));
919 }
920
921 Ok(Self(i128::from_le_bytes(decoded)))
922 }
923}
924
925/// Record type that's appropriate for encoding and decoding sessions to and
926/// from session stores.
927#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
928pub struct Record {
929 pub id: Id,
930 pub data: Data,
931 pub expiry_date: OffsetDateTime,
932}
933
934impl Record {
935 fn new(expiry_date: OffsetDateTime) -> Self {
936 Self {
937 id: Id::default(),
938 data: Data::default(),
939 expiry_date,
940 }
941 }
942}
943
944/// Session expiry configuration.
945///
946/// # Examples
947///
948/// ```rust
949/// use time::{Duration, OffsetDateTime};
950/// use tower_sessions_ext::Expiry;
951///
952/// // Will be expired on "session end".
953/// let expiry = Expiry::OnSessionEnd;
954///
955/// // Will be expired in five minutes from last acitve.
956/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
957///
958/// // Will be expired at the given timestamp.
959/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
960/// let expiry = Expiry::AtDateTime(expired_at);
961/// ```
962#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
963pub enum Expiry {
964 /// Expire on [current session end][current-session-end], as defined by the
965 /// browser.
966 ///
967 /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
968 OnSessionEnd(Duration),
969
970 /// Expire on inactivity.
971 ///
972 /// Reading a session is not considered activity for expiration purposes.
973 /// [`Session`] expiration is computed from the last time the session was
974 /// _modified_.
975 OnInactivity(Duration),
976
977 /// Expire at a specific date and time.
978 ///
979 /// This value may be extended manually with
980 /// [`set_expiry`](Session::set_expiry).
981 AtDateTime(OffsetDateTime),
982}
983
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    /// Cycling the ID must keep the session data, delete the record stored
    /// under the old ID, and leave the session ID unset until the next save,
    /// which creates the record under a new ID.
    #[tokio::test]
    async fn test_cycle_id() {
        let old_id = Id::default();
        let fresh_id = Id::default();

        let mut store = MockStore::new();

        // The first access hydrates the record from the store under the old ID.
        store
            .expect_load()
            .with(predicate::eq(old_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: old_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });

        // The explicit save before cycling updates the existing record.
        store
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));

        // Cycling removes the record stored under the old ID...
        store
            .expect_delete()
            .with(predicate::eq(old_id))
            .times(1)
            .returning(|_| Ok(()));

        // ...and the following save creates it again under a fresh ID.
        store.expect_create().times(1).returning(move |record| {
            record.id = fresh_id;
            Ok(())
        });

        let session = Session::new(Some(old_id), Arc::new(store), None);

        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        session.cycle_id().await.unwrap();

        // The old ID is gone and no new ID is assigned until the next save...
        assert_ne!(session.id(), Some(old_id));
        assert!(session.id().is_none());
        // ...but the data survives the cycle.
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        session.save().await.unwrap();
        assert_eq!(session.id(), Some(fresh_id));
    }
}