tower_sessions_ext_core/session.rs
1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3 collections::HashMap,
4 fmt::{self, Display},
5 hash::Hash,
6 result,
7 str::{self, FromStr},
8 sync::{
9 Arc,
10 atomic::{self, AtomicBool},
11 },
12};
13
14use base64::{DecodeError, Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize, de::DeserializeOwned};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{SessionStore, session_store};
21
22pub const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24type Result<T> = result::Result<T, Error>;
25
26type Data = HashMap<String, Value>;
27
/// Errors returned by [`Session`] operations.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// A value could not be converted to or from JSON; wraps the underlying
    /// [`serde_json::Error`].
    #[error(transparent)]
    SerdeJson(#[from] serde_json::Error),

    /// The backing session store reported a failure; wraps the underlying
    /// [`session_store::Error`].
    #[error(transparent)]
    Store(#[from] session_store::Error),
}
39
/// State shared by every clone of a [`Session`] (each clone holds it behind an `Arc`).
#[derive(Debug)]
struct Inner {
    // This will be `None` when:
    //
    // 1. We have not been provided a session cookie or have failed to parse it,
    // 2. The store has not found the session,
    // 3. The ID has been discarded by `Session::cycle_id` or `Session::flush`; the next
    //    `Session::save` then goes through the store's `create` method.
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    session_id: parking_lot::Mutex<Option<Id>>,

    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
    // `None` value indicates the record has not been loaded (yet). After the first access it
    // contains `Some(Record)`; `Session::load` may reset it to `None` when the store has no
    // record, in which case the next access consults the store again.
    //
    // Async lock: it is held across awaits on the store (e.g. while lazily loading).
    record: Mutex<Option<Record>>,

    // `None` means no expiry was configured; `Session::expiry_date` then falls back to
    // `DEFAULT_DURATION`.
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    expiry: parking_lot::Mutex<Option<Expiry>>,

    // Set when the session's data or expiry is changed; read by `Session::is_modified`.
    is_modified: AtomicBool,
}
60
/// A session which allows HTTP applications to associate key-value pairs with
/// visitors.
///
/// Cloning a `Session` only bumps two reference counts: every clone shares the
/// same underlying state, so changes made through one clone are visible through
/// the others.
#[derive(Clone)]
pub struct Session {
    // Backend used to load, create, save and delete session records.
    store: Arc<dyn SessionStore>,
    // Shared, lazily hydrated session state.
    inner: Arc<Inner>,
}
68
69impl fmt::Debug for Session {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("Session")
72 .field("store", &self.store)
73 .field("inner", &self.inner)
74 .finish()
75 }
76}
77
78impl Session {
79 /// Creates a new session with the session ID, store, and expiry.
80 ///
81 /// This method is lazy and does not invoke the overhead of talking to the
82 /// backing store.
83 ///
84 /// # Examples
85 ///
86 /// ```rust
87 /// use std::sync::Arc;
88 ///
89 /// use tower_sessions_ext::{MemoryStore, Session};
90 ///
91 /// let store = Arc::new(MemoryStore::default());
92 /// Session::new(None, store, None);
93 /// ```
94 pub fn new(
95 session_id: Option<Id>,
96 store: Arc<impl SessionStore>,
97 expiry: Option<Expiry>,
98 ) -> Self {
99 let inner = Inner {
100 session_id: parking_lot::Mutex::new(session_id),
101 record: Mutex::new(None), // `None` indicates we have not loaded from store.
102 expiry: parking_lot::Mutex::new(expiry),
103 is_modified: AtomicBool::new(false),
104 };
105
106 Self {
107 store,
108 inner: Arc::new(inner),
109 }
110 }
111
112 fn create_record(&self) -> Record {
113 Record::new(self.expiry_date())
114 }
115
116 #[tracing::instrument(skip(self), err, level = "trace")]
117 async fn get_record(&'_ self) -> Result<MappedMutexGuard<'_, Record>> {
118 let mut record_guard = self.inner.record.lock().await;
119
120 // Lazily load the record since `None` here indicates we have no yet loaded it.
121 if record_guard.is_none() {
122 tracing::trace!("record not loaded from store; loading");
123
124 let session_id = *self.inner.session_id.lock();
125 *record_guard = Some(if let Some(session_id) = session_id {
126 match self.store.load(&session_id).await? {
127 Some(loaded_record) => {
128 tracing::trace!("record found in store");
129 loaded_record
130 }
131
132 None => {
133 // A well-behaved user agent should not send session cookies after
134 // expiration. Even so it's possible for an expired session to be removed
135 // from the store after a request was initiated. However, such a race should
136 // be relatively uncommon and as such entering this branch could indicate
137 // malicious behavior.
138 tracing::warn!("possibly suspicious activity: record not found in store");
139 *self.inner.session_id.lock() = None;
140 self.create_record()
141 }
142 }
143 } else {
144 tracing::trace!("session id not found");
145 self.create_record()
146 })
147 }
148
149 Ok(MutexGuard::map(record_guard, |opt| {
150 opt.as_mut()
151 .expect("Record should always be `Option::Some` at this point")
152 }))
153 }
154
155 /// Inserts a `impl Serialize` value into the session.
156 ///
157 /// # Examples
158 ///
159 /// ```rust
160 /// # tokio_test::block_on(async {
161 /// use std::sync::Arc;
162 ///
163 /// use tower_sessions_ext::{MemoryStore, Session};
164 ///
165 /// let store = Arc::new(MemoryStore::default());
166 /// let session = Session::new(None, store, None);
167 ///
168 /// session.insert("foo", 42).await.unwrap();
169 ///
170 /// let value = session.get::<usize>("foo").await.unwrap();
171 /// assert_eq!(value, Some(42));
172 /// # });
173 /// ```
174 ///
175 /// # Errors
176 ///
177 /// - This method can fail when [`serde_json::to_value`] fails.
178 /// - If the session has not been hydrated and loading from the store fails,
179 /// we fail with [`Error::Store`].
180 pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
181 self.insert_value(key, serde_json::to_value(&value)?)
182 .await?;
183 Ok(())
184 }
185
186 /// Inserts a `serde_json::Value` into the session.
187 ///
188 /// If the key was not present in the underlying map, `None` is returned and
189 /// `modified` is set to `true`.
190 ///
191 /// If the underlying map did have the key and its value is the same as the
192 /// provided value, `None` is returned and `modified` is not set.
193 ///
194 /// # Examples
195 ///
196 /// ```rust
197 /// # tokio_test::block_on(async {
198 /// use std::sync::Arc;
199 ///
200 /// use tower_sessions_ext::{MemoryStore, Session};
201 ///
202 /// let store = Arc::new(MemoryStore::default());
203 /// let session = Session::new(None, store, None);
204 ///
205 /// let value = session
206 /// .insert_value("foo", serde_json::json!(42))
207 /// .await
208 /// .unwrap();
209 /// assert!(value.is_none());
210 ///
211 /// let value = session
212 /// .insert_value("foo", serde_json::json!(42))
213 /// .await
214 /// .unwrap();
215 /// assert!(value.is_none());
216 ///
217 /// let value = session
218 /// .insert_value("foo", serde_json::json!("bar"))
219 /// .await
220 /// .unwrap();
221 /// assert_eq!(value, Some(serde_json::json!(42)));
222 /// # });
223 /// ```
224 ///
225 /// # Errors
226 ///
227 /// - If the session has not been hydrated and loading from the store fails,
228 /// we fail with [`Error::Store`].
229 pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
230 let mut record_guard = self.get_record().await?;
231 Ok(if record_guard.data.get(key) != Some(&value) {
232 self.inner
233 .is_modified
234 .store(true, atomic::Ordering::Release);
235 record_guard.data.insert(key.to_string(), value)
236 } else {
237 None
238 })
239 }
240
241 /// Gets a value from the store.
242 ///
243 /// # Examples
244 ///
245 /// ```rust
246 /// # tokio_test::block_on(async {
247 /// use std::sync::Arc;
248 ///
249 /// use tower_sessions_ext::{MemoryStore, Session};
250 ///
251 /// let store = Arc::new(MemoryStore::default());
252 /// let session = Session::new(None, store, None);
253 ///
254 /// session.insert("foo", 42).await.unwrap();
255 ///
256 /// let value = session.get::<usize>("foo").await.unwrap();
257 /// assert_eq!(value, Some(42));
258 /// # });
259 /// ```
260 ///
261 /// # Errors
262 ///
263 /// - This method can fail when [`serde_json::from_value`] fails.
264 /// - If the session has not been hydrated and loading from the store fails,
265 /// we fail with [`Error::Store`].
266 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
267 Ok(self
268 .get_value(key)
269 .await?
270 .map(serde_json::from_value)
271 .transpose()?)
272 }
273
274 /// Gets a `serde_json::Value` from the store.
275 ///
276 /// # Examples
277 ///
278 /// ```rust
279 /// # tokio_test::block_on(async {
280 /// use std::sync::Arc;
281 ///
282 /// use tower_sessions_ext::{MemoryStore, Session};
283 ///
284 /// let store = Arc::new(MemoryStore::default());
285 /// let session = Session::new(None, store, None);
286 ///
287 /// session.insert("foo", 42).await.unwrap();
288 ///
289 /// let value = session.get_value("foo").await.unwrap().unwrap();
290 /// assert_eq!(value, serde_json::json!(42));
291 /// # });
292 /// ```
293 ///
294 /// # Errors
295 ///
296 /// - If the session has not been hydrated and loading from the store fails,
297 /// we fail with [`Error::Store`].
298 pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
299 let record_guard = self.get_record().await?;
300 Ok(record_guard.data.get(key).cloned())
301 }
302
303 /// Removes a value from the store, retuning the value of the key if it was
304 /// present in the underlying map.
305 ///
306 /// # Examples
307 ///
308 /// ```rust
309 /// # tokio_test::block_on(async {
310 /// use std::sync::Arc;
311 ///
312 /// use tower_sessions_ext::{MemoryStore, Session};
313 ///
314 /// let store = Arc::new(MemoryStore::default());
315 /// let session = Session::new(None, store, None);
316 ///
317 /// session.insert("foo", 42).await.unwrap();
318 ///
319 /// let value: Option<usize> = session.remove("foo").await.unwrap();
320 /// assert_eq!(value, Some(42));
321 ///
322 /// let value: Option<usize> = session.get("foo").await.unwrap();
323 /// assert!(value.is_none());
324 /// # });
325 /// ```
326 ///
327 /// # Errors
328 ///
329 /// - This method can fail when [`serde_json::from_value`] fails.
330 /// - If the session has not been hydrated and loading from the store fails,
331 /// we fail with [`Error::Store`].
332 pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
333 Ok(self
334 .remove_value(key)
335 .await?
336 .map(serde_json::from_value)
337 .transpose()?)
338 }
339
340 /// Removes a `serde_json::Value` from the session.
341 ///
342 /// # Examples
343 ///
344 /// ```rust
345 /// # tokio_test::block_on(async {
346 /// use std::sync::Arc;
347 ///
348 /// use tower_sessions_ext::{MemoryStore, Session};
349 ///
350 /// let store = Arc::new(MemoryStore::default());
351 /// let session = Session::new(None, store, None);
352 ///
353 /// session.insert("foo", 42).await.unwrap();
354 /// let value = session.remove_value("foo").await.unwrap().unwrap();
355 /// assert_eq!(value, serde_json::json!(42));
356 ///
357 /// let value: Option<usize> = session.get("foo").await.unwrap();
358 /// assert!(value.is_none());
359 /// # });
360 /// ```
361 ///
362 /// # Errors
363 ///
364 /// - If the session has not been hydrated and loading from the store fails,
365 /// we fail with [`Error::Store`].
366 pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
367 let mut record_guard = self.get_record().await?;
368 self.inner
369 .is_modified
370 .store(true, atomic::Ordering::Release);
371 Ok(record_guard.data.remove(key))
372 }
373
374 /// Clears the session of all data but does not delete it from the store.
375 ///
376 /// # Examples
377 ///
378 /// ```rust
379 /// # tokio_test::block_on(async {
380 /// use std::sync::Arc;
381 ///
382 /// use tower_sessions_ext::{MemoryStore, Session};
383 ///
384 /// let store = Arc::new(MemoryStore::default());
385 ///
386 /// let session = Session::new(None, store.clone(), None);
387 /// session.insert("foo", 42).await.unwrap();
388 /// assert!(!session.is_empty().await);
389 ///
390 /// session.save().await.unwrap();
391 ///
392 /// session.clear().await;
393 ///
394 /// // Not empty! (We have an ID still.)
395 /// assert!(!session.is_empty().await);
396 /// // Data is cleared...
397 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
398 ///
399 /// // ...data is cleared before loading from the backend...
400 /// let session = Session::new(session.id(), store.clone(), None);
401 /// session.clear().await;
402 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
403 ///
404 /// let session = Session::new(session.id(), store, None);
405 /// // ...but data is not deleted from the store.
406 /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
407 /// # });
408 /// ```
409 pub async fn clear(&self) {
410 let mut record_guard = self.inner.record.lock().await;
411 if let Some(record) = record_guard.as_mut() {
412 record.data.clear();
413 } else if let Some(session_id) = *self.inner.session_id.lock() {
414 let mut new_record = self.create_record();
415 new_record.id = session_id;
416 *record_guard = Some(new_record);
417 }
418
419 self.inner
420 .is_modified
421 .store(true, atomic::Ordering::Release);
422 }
423
424 /// Returns `true` if there is no session ID and the session is empty.
425 ///
426 /// # Examples
427 ///
428 /// ```rust
429 /// # tokio_test::block_on(async {
430 /// use std::sync::Arc;
431 ///
432 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
433 ///
434 /// let store = Arc::new(MemoryStore::default());
435 ///
436 /// let session = Session::new(None, store.clone(), None);
437 /// // Empty if we have no ID and record is not loaded.
438 /// assert!(session.is_empty().await);
439 ///
440 /// let session = Session::new(Some(Id::default()), store.clone(), None);
441 /// // Not empty if we have an ID but no record. (Record is not loaded here.)
442 /// assert!(!session.is_empty().await);
443 ///
444 /// let session = Session::new(Some(Id::default()), store.clone(), None);
445 /// session.insert("foo", 42).await.unwrap();
446 /// // Not empty after inserting.
447 /// assert!(!session.is_empty().await);
448 /// session.save().await.unwrap();
449 /// // Not empty after saving.
450 /// assert!(!session.is_empty().await);
451 ///
452 /// let session = Session::new(session.id(), store.clone(), None);
453 /// session.load().await.unwrap();
454 /// // Not empty after loading from store...
455 /// assert!(!session.is_empty().await);
456 /// // ...and not empty after accessing the session.
457 /// session.get::<usize>("foo").await.unwrap();
458 /// assert!(!session.is_empty().await);
459 ///
460 /// let session = Session::new(session.id(), store.clone(), None);
461 /// session.delete().await.unwrap();
462 /// // Not empty after deleting from store...
463 /// assert!(!session.is_empty().await);
464 /// session.get::<usize>("foo").await.unwrap();
465 /// // ...but empty after trying to access the deleted session.
466 /// assert!(session.is_empty().await);
467 ///
468 /// let session = Session::new(None, store, None);
469 /// session.insert("foo", 42).await.unwrap();
470 /// session.flush().await.unwrap();
471 /// // Empty after flushing.
472 /// assert!(session.is_empty().await);
473 /// # });
474 /// ```
475 pub async fn is_empty(&self) -> bool {
476 let record_guard = self.inner.record.lock().await;
477
478 // N.B.: Session IDs are `None` if:
479 //
480 // 1. The cookie was not provided or otherwise could not be parsed,
481 // 2. Or the session could not be loaded from the store.
482 let session_id = self.inner.session_id.lock();
483
484 let Some(record) = record_guard.as_ref() else {
485 return session_id.is_none();
486 };
487
488 session_id.is_none() && record.data.is_empty()
489 }
490
491 /// Get the session ID.
492 ///
493 /// # Examples
494 ///
495 /// ```rust
496 /// use std::sync::Arc;
497 ///
498 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
499 ///
500 /// let store = Arc::new(MemoryStore::default());
501 ///
502 /// let session = Session::new(None, store.clone(), None);
503 /// assert!(session.id().is_none());
504 ///
505 /// let id = Some(Id::default());
506 /// let session = Session::new(id, store, None);
507 /// assert_eq!(id, session.id());
508 /// ```
509 pub fn id(&self) -> Option<Id> {
510 *self.inner.session_id.lock()
511 }
512
513 /// Get the session expiry.
514 ///
515 /// # Examples
516 ///
517 /// ```rust
518 /// use std::sync::Arc;
519 ///
520 /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
521 ///
522 /// let store = Arc::new(MemoryStore::default());
523 /// let session = Session::new(None, store, None);
524 ///
525 /// assert_eq!(session.expiry(), None);
526 /// ```
527 pub fn expiry(&self) -> Option<Expiry> {
528 *self.inner.expiry.lock()
529 }
530
531 /// Set `expiry` to the given value.
532 ///
533 /// This may be used within applications directly to alter the session's
534 /// time to live.
535 ///
536 /// # Examples
537 ///
538 /// ```rust
539 /// use std::sync::Arc;
540 ///
541 /// use time::OffsetDateTime;
542 /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
543 ///
544 /// let store = Arc::new(MemoryStore::default());
545 /// let session = Session::new(None, store, None);
546 ///
547 /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
548 /// session.set_expiry(Some(expiry));
549 ///
550 /// assert_eq!(session.expiry(), Some(expiry));
551 /// ```
552 pub fn set_expiry(&self, expiry: Option<Expiry>) {
553 *self.inner.expiry.lock() = expiry;
554 self.inner
555 .is_modified
556 .store(true, atomic::Ordering::Release);
557 }
558
559 /// Get session expiry as `OffsetDateTime`.
560 ///
561 /// # Examples
562 ///
563 /// ```rust
564 /// use std::sync::Arc;
565 ///
566 /// use time::{Duration, OffsetDateTime};
567 /// use tower_sessions_ext::{MemoryStore, Session};
568 ///
569 /// let store = Arc::new(MemoryStore::default());
570 /// let session = Session::new(None, store, None);
571 ///
572 /// // Our default duration is two weeks.
573 /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
574 ///
575 /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
576 /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
577 /// ```
578 pub fn expiry_date(&self) -> OffsetDateTime {
579 let expiry = self.inner.expiry.lock();
580 match *expiry {
581 Some(Expiry::OnInactivity(duration)) => {
582 OffsetDateTime::now_utc().saturating_add(duration)
583 }
584 Some(Expiry::AtDateTime(datetime)) => datetime,
585 Some(Expiry::OnSessionEnd(datetime)) => {
586 OffsetDateTime::now_utc().saturating_add(datetime)
587 }
588 None => OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION),
589 }
590 }
591
592 /// Get session expiry as `Duration`.
593 ///
594 /// # Examples
595 ///
596 /// ```rust
597 /// use std::sync::Arc;
598 ///
599 /// use time::Duration;
600 /// use tower_sessions_ext::{MemoryStore, Session};
601 ///
602 /// let store = Arc::new(MemoryStore::default());
603 /// let session = Session::new(None, store, None);
604 ///
605 /// let expected_duration = Duration::weeks(2);
606 ///
607 /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
608 /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
609 /// ```
610 pub fn expiry_age(&self) -> Duration {
611 std::cmp::max(
612 self.expiry_date() - OffsetDateTime::now_utc(),
613 Duration::ZERO,
614 )
615 }
616
617 /// Returns `true` if the session has been modified during the request.
618 ///
619 /// # Examples
620 ///
621 /// ```rust
622 /// # tokio_test::block_on(async {
623 /// use std::sync::Arc;
624 ///
625 /// use tower_sessions_ext::{MemoryStore, Session};
626 ///
627 /// let store = Arc::new(MemoryStore::default());
628 /// let session = Session::new(None, store, None);
629 ///
630 /// // Not modified initially.
631 /// assert!(!session.is_modified());
632 ///
633 /// // Getting doesn't count as a modification.
634 /// session.get::<usize>("foo").await.unwrap();
635 /// assert!(!session.is_modified());
636 ///
637 /// // Insertions and removals do though.
638 /// session.insert("foo", 42).await.unwrap();
639 /// assert!(session.is_modified());
640 /// # });
641 /// ```
642 pub fn is_modified(&self) -> bool {
643 self.inner.is_modified.load(atomic::Ordering::Acquire)
644 }
645
646 /// Saves the session record to the store.
647 ///
648 /// Note that this method is generally not needed and is reserved for
649 /// situations where the session store must be updated during the
650 /// request.
651 ///
652 /// # Examples
653 ///
654 /// ```rust
655 /// # tokio_test::block_on(async {
656 /// use std::sync::Arc;
657 ///
658 /// use tower_sessions_ext::{MemoryStore, Session};
659 ///
660 /// let store = Arc::new(MemoryStore::default());
661 /// let session = Session::new(None, store.clone(), None);
662 ///
663 /// session.insert("foo", 42).await.unwrap();
664 /// session.save().await.unwrap();
665 ///
666 /// let session = Session::new(session.id(), store, None);
667 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
668 /// # });
669 /// ```
670 ///
671 /// # Errors
672 ///
673 /// - If saving to the store fails, we fail with [`Error::Store`].
674 #[tracing::instrument(skip(self), err, level = "trace")]
675 pub async fn save(&self) -> Result<()> {
676 let mut record_guard = self.get_record().await?;
677 record_guard.expiry_date = self.expiry_date();
678
679 // Session ID is `None` if:
680 //
681 // 1. No valid cookie was found on the request or,
682 // 2. No valid session was found in the store.
683 //
684 // In either case, we must create a new session via the store interface.
685 //
686 // Potential ID collisions must be handled by session store implementers.
687 if self.inner.session_id.lock().is_none() {
688 self.store.create(&mut record_guard).await?;
689 *self.inner.session_id.lock() = Some(record_guard.id);
690 } else {
691 self.store.save(&record_guard).await?;
692 }
693 Ok(())
694 }
695
696 /// Loads the session record from the store.
697 ///
698 /// Note that this method is generally not needed and is reserved for
699 /// situations where the session must be updated during the request.
700 ///
701 /// # Examples
702 ///
703 /// ```rust
704 /// # tokio_test::block_on(async {
705 /// use std::sync::Arc;
706 ///
707 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
708 ///
709 /// let store = Arc::new(MemoryStore::default());
710 /// let id = Some(Id::default());
711 /// let session = Session::new(id, store.clone(), None);
712 ///
713 /// session.insert("foo", 42).await.unwrap();
714 /// session.save().await.unwrap();
715 ///
716 /// let session = Session::new(session.id(), store, None);
717 /// session.load().await.unwrap();
718 ///
719 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
720 /// # });
721 /// ```
722 ///
723 /// # Errors
724 ///
725 /// - If loading from the store fails, we fail with [`Error::Store`].
726 #[tracing::instrument(skip(self), err, level = "trace")]
727 pub async fn load(&self) -> Result<()> {
728 let session_id = *self.inner.session_id.lock();
729 let Some(ref id) = session_id else {
730 tracing::warn!("called load with no session id");
731 return Ok(());
732 };
733 let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
734 let mut record_guard = self.inner.record.lock().await;
735 *record_guard = loaded_record;
736 Ok(())
737 }
738
739 /// Deletes the session from the store.
740 ///
741 /// # Examples
742 ///
743 /// ```rust
744 /// # tokio_test::block_on(async {
745 /// use std::sync::Arc;
746 ///
747 /// use tower_sessions_ext::{MemoryStore, Session, SessionStore, session::Id};
748 ///
749 /// let store = Arc::new(MemoryStore::default());
750 /// let session = Session::new(Some(Id::default()), store.clone(), None);
751 ///
752 /// // Save before deleting.
753 /// session.save().await.unwrap();
754 ///
755 /// // Delete from the store.
756 /// session.delete().await.unwrap();
757 ///
758 /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
759 /// # });
760 /// ```
761 ///
762 /// # Errors
763 ///
764 /// - If deleting from the store fails, we fail with [`Error::Store`].
765 #[tracing::instrument(skip(self), err, level = "trace")]
766 pub async fn delete(&self) -> Result<()> {
767 let session_id = *self.inner.session_id.lock();
768 let Some(ref session_id) = session_id else {
769 tracing::warn!("called delete with no session id");
770 return Ok(());
771 };
772 self.store.delete(session_id).await.map_err(Error::Store)?;
773 Ok(())
774 }
775
776 /// Flushes the session by removing all data contained in the session and
777 /// then deleting it from the store.
778 ///
779 /// # Examples
780 ///
781 /// ```rust
782 /// # tokio_test::block_on(async {
783 /// use std::sync::Arc;
784 ///
785 /// use tower_sessions_ext::{MemoryStore, Session, SessionStore};
786 ///
787 /// let store = Arc::new(MemoryStore::default());
788 /// let session = Session::new(None, store.clone(), None);
789 ///
790 /// session.insert("foo", "bar").await.unwrap();
791 /// session.save().await.unwrap();
792 ///
793 /// let id = session.id().unwrap();
794 ///
795 /// session.flush().await.unwrap();
796 ///
797 /// assert!(session.id().is_none());
798 /// assert!(session.is_empty().await);
799 /// assert!(store.load(&id).await.unwrap().is_none());
800 /// # });
801 /// ```
802 ///
803 /// # Errors
804 ///
805 /// - If deleting from the store fails, we fail with [`Error::Store`].
806 pub async fn flush(&self) -> Result<()> {
807 self.clear().await;
808 self.delete().await?;
809 *self.inner.session_id.lock() = None;
810 Ok(())
811 }
812
813 /// Cycles the session ID while retaining any data that was associated with
814 /// it.
815 ///
816 /// Using this method helps prevent session fixation attacks by ensuring a
817 /// new ID is assigned to the session.
818 ///
819 /// # Examples
820 ///
821 /// ```rust
822 /// # tokio_test::block_on(async {
823 /// use std::sync::Arc;
824 ///
825 /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
826 ///
827 /// let store = Arc::new(MemoryStore::default());
828 /// let session = Session::new(None, store.clone(), None);
829 ///
830 /// session.insert("foo", 42).await.unwrap();
831 /// session.save().await.unwrap();
832 /// let id = session.id();
833 ///
834 /// let session = Session::new(session.id(), store.clone(), None);
835 /// session.cycle_id().await.unwrap();
836 ///
837 /// assert!(!session.is_empty().await);
838 /// assert!(session.is_modified());
839 ///
840 /// session.save().await.unwrap();
841 ///
842 /// let session = Session::new(session.id(), store, None);
843 ///
844 /// assert_ne!(id, session.id());
845 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
846 /// # });
847 /// ```
848 ///
849 /// # Errors
850 ///
851 /// - If deleting from the store fails or saving to the store fails, we fail
852 /// with [`Error::Store`].
853 pub async fn cycle_id(&self) -> Result<()> {
854 let mut record_guard = self.get_record().await?;
855
856 let old_session_id = record_guard.id;
857 record_guard.id = Id::default();
858 *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
859 // `create` method.
860
861 self.store
862 .delete(&old_session_id)
863 .await
864 .map_err(Error::Store)?;
865
866 self.inner
867 .is_modified
868 .store(true, atomic::Ordering::Release);
869
870 Ok(())
871 }
872}
873
/// ID type for sessions.
///
/// Wraps a 128-bit integer (16 bytes); [`Default`] generates a random value.
/// The [`Display`] form is the URL-safe, unpadded base64 encoding of the
/// integer's little-endian bytes (22 characters), and [`FromStr`] parses that
/// form back.
///
/// # Examples
///
/// ```rust
/// use tower_sessions_ext::session::Id;
///
/// Id::default();
/// ```
#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
// session ID, which is undesirable.
888
889impl Default for Id {
890 fn default() -> Self {
891 use rand::prelude::*;
892
893 Self(rand::rng().random())
894 }
895}
896
897impl Display for Id {
898 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899 let mut encoded = [0; 22];
900 URL_SAFE_NO_PAD
901 .encode_slice(self.0.to_le_bytes(), &mut encoded)
902 .expect("Encoded ID must be exactly 22 bytes");
903 let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
904
905 f.write_str(encoded)
906 }
907}
908
909impl FromStr for Id {
910 type Err = base64::DecodeSliceError;
911
912 fn from_str(s: &str) -> result::Result<Self, Self::Err> {
913 let mut decoded = [0; 16];
914 let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
915 if bytes_decoded != 16 {
916 let err = DecodeError::InvalidLength(bytes_decoded);
917 return Err(base64::DecodeSliceError::DecodeError(err));
918 }
919
920 Ok(Self(i128::from_le_bytes(decoded)))
921 }
922}
923
/// Record type that's appropriate for encoding and decoding sessions to and
/// from session stores.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Record {
    /// Identifier the session store uses to load, save and delete this record.
    pub id: Id,
    /// The session's key-value data.
    pub data: Data,
    /// Instant at which the record expires; refreshed from
    /// [`Session::expiry_date`] every time the session is saved.
    pub expiry_date: OffsetDateTime,
}
932
933impl Record {
934 fn new(expiry_date: OffsetDateTime) -> Self {
935 Self {
936 id: Id::default(),
937 data: Data::default(),
938 expiry_date,
939 }
940 }
941}
942
943/// Session expiry configuration.
944///
945/// # Examples
946///
947/// ```rust
948/// use time::{Duration, OffsetDateTime};
949/// use tower_sessions_ext::Expiry;
950///
951/// // Will be expired on "session end".
952/// let expiry = Expiry::OnSessionEnd;
953///
954/// // Will be expired in five minutes from last acitve.
955/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
956///
957/// // Will be expired at the given timestamp.
958/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
959/// let expiry = Expiry::AtDateTime(expired_at);
960/// ```
961#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
962pub enum Expiry {
963 /// Expire on [current session end][current-session-end], as defined by the
964 /// browser.
965 ///
966 /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
967 OnSessionEnd(Duration),
968
969 /// Expire on inactivity.
970 ///
971 /// Reading a session is not considered activity for expiration purposes.
972 /// [`Session`] expiration is computed from the last time the session was
973 /// _modified_.
974 OnInactivity(Duration),
975
976 /// Expire at a specific date and time.
977 ///
978 /// This value may be extended manually with
979 /// [`set_expiry`](Session::set_expiry).
980 AtDateTime(OffsetDateTime),
981}
982
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{mock, predicate};

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    /// Cycling the ID must delete the old record, keep the data in memory and,
    /// on the next save, obtain the new ID through the store's `create`.
    #[tokio::test]
    async fn test_cycle_id() {
        let original_id = Id::default();
        let replacement_id = Id::default();

        let mut store = MockStore::new();

        // Hydration: the first access loads the existing (empty) record.
        store
            .expect_load()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: original_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });
        // The first save targets the existing record.
        store
            .expect_save()
            .with(predicate::always())
            .times(1)
            .returning(|_| Ok(()));
        // Cycling removes the old record...
        store
            .expect_delete()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(|_| Ok(()));
        // ...and the following save creates one under the store-assigned ID.
        store.expect_create().times(1).returning(move |record| {
            record.id = replacement_id;
            Ok(())
        });

        let session = Session::new(Some(original_id), Arc::new(store), None);

        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        session.cycle_id().await.unwrap();

        // The old ID is gone (the session has no ID until the next save) but the data stays.
        assert_ne!(session.id(), Some(original_id));
        assert!(session.id().is_none());
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Saving goes through `create`, which hands back the new ID.
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(replacement_id));
    }
}