tower_sessions_core/session.rs
1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3 collections::HashMap,
4 fmt::{self, Display},
5 hash::Hash,
6 result,
7 str::{self, FromStr},
8 sync::{
9 atomic::{self, AtomicBool},
10 Arc,
11 },
12};
13
14use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{session_store, SessionStore};
21
22const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24type Result<T> = result::Result<T, Error>;
25
26type Data = HashMap<String, Value>;
27
28/// Session errors.
29#[derive(thiserror::Error, Debug)]
30pub enum Error {
31 /// Maps `serde_json` errors.
32 #[error(transparent)]
33 SerdeJson(#[from] serde_json::Error),
34
35 /// Maps `session_store::Error` errors.
36 #[error(transparent)]
37 Store(#[from] session_store::Error),
38}
39
40#[derive(Debug)]
41struct Inner {
42 // This will be `None` when:
43 //
44 // 1. We have not been provided a session cookie or have failed to parse it,
45 // 2. The store has not found the session.
46 //
47 // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
48 session_id: parking_lot::Mutex<Option<Id>>,
49
50 // A lazy representation of the session's value, hydrated on a just-in-time basis. A
51 // `None` value indicates we have not tried to access it yet. After access, it will always
52 // contain `Some(Record)`.
53 record: Mutex<Option<Record>>,
54
55 // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
56 expiry: parking_lot::Mutex<Option<Expiry>>,
57
58 is_modified: AtomicBool,
59}
60
61/// A session which allows HTTP applications to associate key-value pairs with
62/// visitors.
63#[derive(Debug, Clone)]
64pub struct Session {
65 store: Arc<dyn SessionStore>,
66 inner: Arc<Inner>,
67}
68
69impl Session {
70 /// Creates a new session with the session ID, store, and expiry.
71 ///
72 /// This method is lazy and does not invoke the overhead of talking to the
73 /// backing store.
74 ///
75 /// # Examples
76 ///
77 /// ```rust
78 /// use std::sync::Arc;
79 ///
80 /// use tower_sessions::{MemoryStore, Session};
81 ///
82 /// let store = Arc::new(MemoryStore::default());
83 /// Session::new(None, store, None);
84 /// ```
85 pub fn new(
86 session_id: Option<Id>,
87 store: Arc<impl SessionStore>,
88 expiry: Option<Expiry>,
89 ) -> Self {
90 let inner = Inner {
91 session_id: parking_lot::Mutex::new(session_id),
92 record: Mutex::new(None), // `None` indicates we have not loaded from store.
93 expiry: parking_lot::Mutex::new(expiry),
94 is_modified: AtomicBool::new(false),
95 };
96
97 Self {
98 store,
99 inner: Arc::new(inner),
100 }
101 }
102
103 fn create_record(&self) -> Record {
104 Record::new(self.expiry_date())
105 }
106
107 #[tracing::instrument(skip(self), err)]
108 async fn get_record(&self) -> Result<MappedMutexGuard<'_, Record>> {
109 let mut record_guard = self.inner.record.lock().await;
110
111 // Lazily load the record since `None` here indicates we have no yet loaded it.
112 if record_guard.is_none() {
113 tracing::trace!("record not loaded from store; loading");
114
115 let session_id = *self.inner.session_id.lock();
116 *record_guard = Some(if let Some(session_id) = session_id {
117 match self.store.load(&session_id).await? {
118 Some(loaded_record) => {
119 tracing::trace!("record found in store");
120 loaded_record
121 }
122
123 None => {
124 // A well-behaved user agent should not send session cookies after
125 // expiration. Even so it's possible for an expired session to be removed
126 // from the store after a request was initiated. However, such a race should
127 // be relatively uncommon and as such entering this branch could indicate
128 // malicious behavior.
129 tracing::warn!("possibly suspicious activity: record not found in store");
130 *self.inner.session_id.lock() = None;
131 self.create_record()
132 }
133 }
134 } else {
135 tracing::trace!("session id not found");
136 self.create_record()
137 })
138 }
139
140 Ok(MutexGuard::map(record_guard, |opt| {
141 opt.as_mut()
142 .expect("Record should always be `Option::Some` at this point")
143 }))
144 }
145
146 /// Inserts a `impl Serialize` value into the session.
147 ///
148 /// # Examples
149 ///
150 /// ```rust
151 /// # tokio_test::block_on(async {
152 /// use std::sync::Arc;
153 ///
154 /// use tower_sessions::{MemoryStore, Session};
155 ///
156 /// let store = Arc::new(MemoryStore::default());
157 /// let session = Session::new(None, store, None);
158 ///
159 /// session.insert("foo", 42).await.unwrap();
160 ///
161 /// let value = session.get::<usize>("foo").await.unwrap();
162 /// assert_eq!(value, Some(42));
163 /// # });
164 /// ```
165 ///
166 /// # Errors
167 ///
168 /// - This method can fail when [`serde_json::to_value`] fails.
169 /// - If the session has not been hydrated and loading from the store fails,
170 /// we fail with [`Error::Store`].
171 pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
172 self.insert_value(key, serde_json::to_value(&value)?)
173 .await?;
174 Ok(())
175 }
176
177 /// Inserts a `serde_json::Value` into the session.
178 ///
179 /// If the key was not present in the underlying map, `None` is returned and
180 /// `modified` is set to `true`.
181 ///
182 /// If the underlying map did have the key and its value is the same as the
183 /// provided value, `None` is returned and `modified` is not set.
184 ///
185 /// # Examples
186 ///
187 /// ```rust
188 /// # tokio_test::block_on(async {
189 /// use std::sync::Arc;
190 ///
191 /// use tower_sessions::{MemoryStore, Session};
192 ///
193 /// let store = Arc::new(MemoryStore::default());
194 /// let session = Session::new(None, store, None);
195 ///
196 /// let value = session
197 /// .insert_value("foo", serde_json::json!(42))
198 /// .await
199 /// .unwrap();
200 /// assert!(value.is_none());
201 ///
202 /// let value = session
203 /// .insert_value("foo", serde_json::json!(42))
204 /// .await
205 /// .unwrap();
206 /// assert!(value.is_none());
207 ///
208 /// let value = session
209 /// .insert_value("foo", serde_json::json!("bar"))
210 /// .await
211 /// .unwrap();
212 /// assert_eq!(value, Some(serde_json::json!(42)));
213 /// # });
214 /// ```
215 ///
216 /// # Errors
217 ///
218 /// - If the session has not been hydrated and loading from the store fails,
219 /// we fail with [`Error::Store`].
220 pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
221 let mut record_guard = self.get_record().await?;
222 Ok(if record_guard.data.get(key) != Some(&value) {
223 let previous = record_guard.data.insert(key.to_string(), value);
224 self.inner
225 .is_modified
226 .store(true, atomic::Ordering::Release);
227 previous
228 } else {
229 None
230 })
231 }
232
233 /// Gets a value from the store.
234 ///
235 /// # Examples
236 ///
237 /// ```rust
238 /// # tokio_test::block_on(async {
239 /// use std::sync::Arc;
240 ///
241 /// use tower_sessions::{MemoryStore, Session};
242 ///
243 /// let store = Arc::new(MemoryStore::default());
244 /// let session = Session::new(None, store, None);
245 ///
246 /// session.insert("foo", 42).await.unwrap();
247 ///
248 /// let value = session.get::<usize>("foo").await.unwrap();
249 /// assert_eq!(value, Some(42));
250 /// # });
251 /// ```
252 ///
253 /// # Errors
254 ///
255 /// - This method can fail when [`serde_json::from_value`] fails.
256 /// - If the session has not been hydrated and loading from the store fails,
257 /// we fail with [`Error::Store`].
258 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
259 Ok(self
260 .get_value(key)
261 .await?
262 .map(serde_json::from_value)
263 .transpose()?)
264 }
265
266 /// Gets a `serde_json::Value` from the store.
267 ///
268 /// # Examples
269 ///
270 /// ```rust
271 /// # tokio_test::block_on(async {
272 /// use std::sync::Arc;
273 ///
274 /// use tower_sessions::{MemoryStore, Session};
275 ///
276 /// let store = Arc::new(MemoryStore::default());
277 /// let session = Session::new(None, store, None);
278 ///
279 /// session.insert("foo", 42).await.unwrap();
280 ///
281 /// let value = session.get_value("foo").await.unwrap().unwrap();
282 /// assert_eq!(value, serde_json::json!(42));
283 /// # });
284 /// ```
285 ///
286 /// # Errors
287 ///
288 /// - If the session has not been hydrated and loading from the store fails,
289 /// we fail with [`Error::Store`].
290 pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
291 let record_guard = self.get_record().await?;
292 Ok(record_guard.data.get(key).cloned())
293 }
294
295 /// Removes a value from the store, retuning the value of the key if it was
296 /// present in the underlying map.
297 ///
298 /// # Examples
299 ///
300 /// ```rust
301 /// # tokio_test::block_on(async {
302 /// use std::sync::Arc;
303 ///
304 /// use tower_sessions::{MemoryStore, Session};
305 ///
306 /// let store = Arc::new(MemoryStore::default());
307 /// let session = Session::new(None, store, None);
308 ///
309 /// session.insert("foo", 42).await.unwrap();
310 ///
311 /// let value: Option<usize> = session.remove("foo").await.unwrap();
312 /// assert_eq!(value, Some(42));
313 ///
314 /// let value: Option<usize> = session.get("foo").await.unwrap();
315 /// assert!(value.is_none());
316 /// # });
317 /// ```
318 ///
319 /// # Errors
320 ///
321 /// - This method can fail when [`serde_json::from_value`] fails.
322 /// - If the session has not been hydrated and loading from the store fails,
323 /// we fail with [`Error::Store`].
324 pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
325 Ok(self
326 .remove_value(key)
327 .await?
328 .map(serde_json::from_value)
329 .transpose()?)
330 }
331
332 /// Removes a `serde_json::Value` from the session.
333 ///
334 /// # Examples
335 ///
336 /// ```rust
337 /// # tokio_test::block_on(async {
338 /// use std::sync::Arc;
339 ///
340 /// use tower_sessions::{MemoryStore, Session};
341 ///
342 /// let store = Arc::new(MemoryStore::default());
343 /// let session = Session::new(None, store, None);
344 ///
345 /// session.insert("foo", 42).await.unwrap();
346 /// let value = session.remove_value("foo").await.unwrap().unwrap();
347 /// assert_eq!(value, serde_json::json!(42));
348 ///
349 /// let value: Option<usize> = session.get("foo").await.unwrap();
350 /// assert!(value.is_none());
351 /// # });
352 /// ```
353 ///
354 /// # Errors
355 ///
356 /// - If the session has not been hydrated and loading from the store fails,
357 /// we fail with [`Error::Store`].
358 pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
359 let mut record_guard = self.get_record().await?;
360 let previous = record_guard.data.remove(key);
361 self.inner
362 .is_modified
363 .store(true, atomic::Ordering::Release);
364 Ok(previous)
365 }
366
367 /// Clears the session of all data but does not delete it from the store.
368 ///
369 /// # Examples
370 ///
371 /// ```rust
372 /// # tokio_test::block_on(async {
373 /// use std::sync::Arc;
374 ///
375 /// use tower_sessions::{MemoryStore, Session};
376 ///
377 /// let store = Arc::new(MemoryStore::default());
378 ///
379 /// let session = Session::new(None, store.clone(), None);
380 /// session.insert("foo", 42).await.unwrap();
381 /// assert!(!session.is_empty().await);
382 ///
383 /// session.save().await.unwrap();
384 ///
385 /// session.clear().await;
386 ///
387 /// // Not empty! (We have an ID still.)
388 /// assert!(!session.is_empty().await);
389 /// // Data is cleared...
390 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
391 ///
392 /// // ...data is cleared before loading from the backend...
393 /// let session = Session::new(session.id(), store.clone(), None);
394 /// session.clear().await;
395 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
396 ///
397 /// let session = Session::new(session.id(), store, None);
398 /// // ...but data is not deleted from the store.
399 /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
400 /// # });
401 /// ```
402 pub async fn clear(&self) {
403 let mut record_guard = self.inner.record.lock().await;
404 if let Some(record) = record_guard.as_mut() {
405 record.data.clear();
406 } else if let Some(session_id) = *self.inner.session_id.lock() {
407 let mut new_record = self.create_record();
408 new_record.id = session_id;
409 *record_guard = Some(new_record);
410 }
411
412 self.inner
413 .is_modified
414 .store(true, atomic::Ordering::Release);
415 }
416
417 /// Returns `true` if there is no session ID and the session is empty.
418 ///
419 /// # Examples
420 ///
421 /// ```rust
422 /// # tokio_test::block_on(async {
423 /// use std::sync::Arc;
424 ///
425 /// use tower_sessions::{session::Id, MemoryStore, Session};
426 ///
427 /// let store = Arc::new(MemoryStore::default());
428 ///
429 /// let session = Session::new(None, store.clone(), None);
430 /// // Empty if we have no ID and record is not loaded.
431 /// assert!(session.is_empty().await);
432 ///
433 /// let session = Session::new(Some(Id::default()), store.clone(), None);
434 /// // Not empty if we have an ID but no record. (Record is not loaded here.)
435 /// assert!(!session.is_empty().await);
436 ///
437 /// let session = Session::new(Some(Id::default()), store.clone(), None);
438 /// session.insert("foo", 42).await.unwrap();
439 /// // Not empty after inserting.
440 /// assert!(!session.is_empty().await);
441 /// session.save().await.unwrap();
442 /// // Not empty after saving.
443 /// assert!(!session.is_empty().await);
444 ///
445 /// let session = Session::new(session.id(), store.clone(), None);
446 /// session.load().await.unwrap();
447 /// // Not empty after loading from store...
448 /// assert!(!session.is_empty().await);
449 /// // ...and not empty after accessing the session.
450 /// session.get::<usize>("foo").await.unwrap();
451 /// assert!(!session.is_empty().await);
452 ///
453 /// let session = Session::new(session.id(), store.clone(), None);
454 /// session.delete().await.unwrap();
455 /// // Not empty after deleting from store...
456 /// assert!(!session.is_empty().await);
457 /// session.get::<usize>("foo").await.unwrap();
458 /// // ...but empty after trying to access the deleted session.
459 /// assert!(session.is_empty().await);
460 ///
461 /// let session = Session::new(None, store, None);
462 /// session.insert("foo", 42).await.unwrap();
463 /// session.flush().await.unwrap();
464 /// // Empty after flushing.
465 /// assert!(session.is_empty().await);
466 /// # });
467 /// ```
468 pub async fn is_empty(&self) -> bool {
469 let record_guard = self.inner.record.lock().await;
470
471 // N.B.: Session IDs are `None` if:
472 //
473 // 1. The cookie was not provided or otherwise could not be parsed,
474 // 2. Or the session could not be loaded from the store.
475 let session_id = self.inner.session_id.lock();
476
477 let Some(record) = record_guard.as_ref() else {
478 return session_id.is_none();
479 };
480
481 session_id.is_none() && record.data.is_empty()
482 }
483
484 /// Get the session ID.
485 ///
486 /// # Examples
487 ///
488 /// ```rust
489 /// use std::sync::Arc;
490 ///
491 /// use tower_sessions::{session::Id, MemoryStore, Session};
492 ///
493 /// let store = Arc::new(MemoryStore::default());
494 ///
495 /// let session = Session::new(None, store.clone(), None);
496 /// assert!(session.id().is_none());
497 ///
498 /// let id = Some(Id::default());
499 /// let session = Session::new(id, store, None);
500 /// assert_eq!(id, session.id());
501 /// ```
502 pub fn id(&self) -> Option<Id> {
503 *self.inner.session_id.lock()
504 }
505
506 /// Get the session expiry.
507 ///
508 /// # Examples
509 ///
510 /// ```rust
511 /// use std::sync::Arc;
512 ///
513 /// use tower_sessions::{session::Expiry, MemoryStore, Session};
514 ///
515 /// let store = Arc::new(MemoryStore::default());
516 /// let session = Session::new(None, store, None);
517 ///
518 /// assert_eq!(session.expiry(), None);
519 /// ```
520 pub fn expiry(&self) -> Option<Expiry> {
521 *self.inner.expiry.lock()
522 }
523
524 /// Set `expiry` to the given value.
525 ///
526 /// This may be used within applications directly to alter the session's
527 /// time to live.
528 ///
529 /// # Examples
530 ///
531 /// ```rust
532 /// use std::sync::Arc;
533 ///
534 /// use time::OffsetDateTime;
535 /// use tower_sessions::{session::Expiry, MemoryStore, Session};
536 ///
537 /// let store = Arc::new(MemoryStore::default());
538 /// let session = Session::new(None, store, None);
539 ///
540 /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
541 /// session.set_expiry(Some(expiry));
542 ///
543 /// assert_eq!(session.expiry(), Some(expiry));
544 /// ```
545 pub fn set_expiry(&self, expiry: Option<Expiry>) {
546 *self.inner.expiry.lock() = expiry;
547 self.inner
548 .is_modified
549 .store(true, atomic::Ordering::Release);
550 }
551
552 /// Get session expiry as `OffsetDateTime`.
553 ///
554 /// # Examples
555 ///
556 /// ```rust
557 /// use std::sync::Arc;
558 ///
559 /// use time::{Duration, OffsetDateTime};
560 /// use tower_sessions::{MemoryStore, Session};
561 ///
562 /// let store = Arc::new(MemoryStore::default());
563 /// let session = Session::new(None, store, None);
564 ///
565 /// // Our default duration is two weeks.
566 /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
567 ///
568 /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
569 /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
570 /// ```
571 pub fn expiry_date(&self) -> OffsetDateTime {
572 let expiry = self.inner.expiry.lock();
573 match *expiry {
574 Some(Expiry::OnInactivity(duration)) => {
575 OffsetDateTime::now_utc().saturating_add(duration)
576 }
577 Some(Expiry::AtDateTime(datetime)) => datetime,
578 Some(Expiry::OnSessionEnd) | None => {
579 OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION) // TODO: The default should probably be configurable.
580 }
581 }
582 }
583
584 /// Get session expiry as `Duration`.
585 ///
586 /// # Examples
587 ///
588 /// ```rust
589 /// use std::sync::Arc;
590 ///
591 /// use time::Duration;
592 /// use tower_sessions::{MemoryStore, Session};
593 ///
594 /// let store = Arc::new(MemoryStore::default());
595 /// let session = Session::new(None, store, None);
596 ///
597 /// let expected_duration = Duration::weeks(2);
598 ///
599 /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
600 /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
601 /// ```
602 pub fn expiry_age(&self) -> Duration {
603 std::cmp::max(
604 self.expiry_date() - OffsetDateTime::now_utc(),
605 Duration::ZERO,
606 )
607 }
608
609 /// Returns `true` if the session has been modified during the request.
610 ///
611 /// # Examples
612 ///
613 /// ```rust
614 /// # tokio_test::block_on(async {
615 /// use std::sync::Arc;
616 ///
617 /// use tower_sessions::{MemoryStore, Session};
618 ///
619 /// let store = Arc::new(MemoryStore::default());
620 /// let session = Session::new(None, store, None);
621 ///
622 /// // Not modified initially.
623 /// assert!(!session.is_modified());
624 ///
625 /// // Getting doesn't count as a modification.
626 /// session.get::<usize>("foo").await.unwrap();
627 /// assert!(!session.is_modified());
628 ///
629 /// // Insertions and removals do though.
630 /// session.insert("foo", 42).await.unwrap();
631 /// assert!(session.is_modified());
632 /// # });
633 /// ```
634 pub fn is_modified(&self) -> bool {
635 self.inner.is_modified.load(atomic::Ordering::Acquire)
636 }
637
638 /// Saves the session record to the store.
639 ///
640 /// Note that this method is generally not needed and is reserved for
641 /// situations where the session store must be updated during the
642 /// request.
643 ///
644 /// # Examples
645 ///
646 /// ```rust
647 /// # tokio_test::block_on(async {
648 /// use std::sync::Arc;
649 ///
650 /// use tower_sessions::{MemoryStore, Session};
651 ///
652 /// let store = Arc::new(MemoryStore::default());
653 /// let session = Session::new(None, store.clone(), None);
654 ///
655 /// session.insert("foo", 42).await.unwrap();
656 /// session.save().await.unwrap();
657 ///
658 /// let session = Session::new(session.id(), store, None);
659 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
660 /// # });
661 /// ```
662 ///
663 /// # Errors
664 ///
665 /// - If saving to the store fails, we fail with [`Error::Store`].
666 #[tracing::instrument(skip(self), err)]
667 pub async fn save(&self) -> Result<()> {
668 let mut record_guard = self.get_record().await?;
669 record_guard.expiry_date = self.expiry_date();
670
671 // Session ID is `None` if:
672 //
673 // 1. No valid cookie was found on the request or,
674 // 2. No valid session was found in the store.
675 //
676 // In either case, we must create a new session via the store interface.
677 //
678 // Potential ID collisions must be handled by session store implementers.
679 if self.inner.session_id.lock().is_none() {
680 self.store.create(&mut record_guard).await?;
681 *self.inner.session_id.lock() = Some(record_guard.id);
682 } else {
683 self.store.save(&record_guard).await?;
684 }
685 Ok(())
686 }
687
688 /// Loads the session record from the store.
689 ///
690 /// Note that this method is generally not needed and is reserved for
691 /// situations where the session must be updated during the request.
692 ///
693 /// # Examples
694 ///
695 /// ```rust
696 /// # tokio_test::block_on(async {
697 /// use std::sync::Arc;
698 ///
699 /// use tower_sessions::{session::Id, MemoryStore, Session};
700 ///
701 /// let store = Arc::new(MemoryStore::default());
702 /// let id = Some(Id::default());
703 /// let session = Session::new(id, store.clone(), None);
704 ///
705 /// session.insert("foo", 42).await.unwrap();
706 /// session.save().await.unwrap();
707 ///
708 /// let session = Session::new(session.id(), store, None);
709 /// session.load().await.unwrap();
710 ///
711 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
712 /// # });
713 /// ```
714 ///
715 /// # Errors
716 ///
717 /// - If loading from the store fails, we fail with [`Error::Store`].
718 #[tracing::instrument(skip(self), err)]
719 pub async fn load(&self) -> Result<()> {
720 let session_id = *self.inner.session_id.lock();
721 let Some(ref id) = session_id else {
722 tracing::warn!("called load with no session id");
723 return Ok(());
724 };
725 let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
726 let mut record_guard = self.inner.record.lock().await;
727 *record_guard = loaded_record;
728 Ok(())
729 }
730
731 /// Deletes the session from the store.
732 ///
733 /// # Examples
734 ///
735 /// ```rust
736 /// # tokio_test::block_on(async {
737 /// use std::sync::Arc;
738 ///
739 /// use tower_sessions::{session::Id, MemoryStore, Session, SessionStore};
740 ///
741 /// let store = Arc::new(MemoryStore::default());
742 /// let session = Session::new(Some(Id::default()), store.clone(), None);
743 ///
744 /// // Save before deleting.
745 /// session.save().await.unwrap();
746 ///
747 /// // Delete from the store.
748 /// session.delete().await.unwrap();
749 ///
750 /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
751 /// # });
752 /// ```
753 ///
754 /// # Errors
755 ///
756 /// - If deleting from the store fails, we fail with [`Error::Store`].
757 #[tracing::instrument(skip(self), err)]
758 pub async fn delete(&self) -> Result<()> {
759 let session_id = *self.inner.session_id.lock();
760 let Some(ref session_id) = session_id else {
761 tracing::warn!("called delete with no session id");
762 return Ok(());
763 };
764 self.store.delete(session_id).await.map_err(Error::Store)?;
765 Ok(())
766 }
767
768 /// Flushes the session by removing all data contained in the session and
769 /// then deleting it from the store.
770 ///
771 /// # Examples
772 ///
773 /// ```rust
774 /// # tokio_test::block_on(async {
775 /// use std::sync::Arc;
776 ///
777 /// use tower_sessions::{MemoryStore, Session, SessionStore};
778 ///
779 /// let store = Arc::new(MemoryStore::default());
780 /// let session = Session::new(None, store.clone(), None);
781 ///
782 /// session.insert("foo", "bar").await.unwrap();
783 /// session.save().await.unwrap();
784 ///
785 /// let id = session.id().unwrap();
786 ///
787 /// session.flush().await.unwrap();
788 ///
789 /// assert!(session.id().is_none());
790 /// assert!(session.is_empty().await);
791 /// assert!(store.load(&id).await.unwrap().is_none());
792 /// # });
793 /// ```
794 ///
795 /// # Errors
796 ///
797 /// - If deleting from the store fails, we fail with [`Error::Store`].
798 pub async fn flush(&self) -> Result<()> {
799 self.clear().await;
800 self.delete().await?;
801 *self.inner.session_id.lock() = None;
802 Ok(())
803 }
804
805 /// Cycles the session ID while retaining any data that was associated with
806 /// it.
807 ///
808 /// Using this method helps prevent session fixation attacks by ensuring a
809 /// new ID is assigned to the session.
810 ///
811 /// # Examples
812 ///
813 /// ```rust
814 /// # tokio_test::block_on(async {
815 /// use std::sync::Arc;
816 ///
817 /// use tower_sessions::{session::Id, MemoryStore, Session};
818 ///
819 /// let store = Arc::new(MemoryStore::default());
820 /// let session = Session::new(None, store.clone(), None);
821 ///
822 /// session.insert("foo", 42).await.unwrap();
823 /// session.save().await.unwrap();
824 /// let id = session.id();
825 ///
826 /// let session = Session::new(session.id(), store.clone(), None);
827 /// session.cycle_id().await.unwrap();
828 ///
829 /// assert!(!session.is_empty().await);
830 /// assert!(session.is_modified());
831 ///
832 /// session.save().await.unwrap();
833 ///
834 /// let session = Session::new(session.id(), store, None);
835 ///
836 /// assert_ne!(id, session.id());
837 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
838 /// # });
839 /// ```
840 ///
841 /// # Errors
842 ///
843 /// - If deleting from the store fails or saving to the store fails, we fail
844 /// with [`Error::Store`].
845 pub async fn cycle_id(&self) -> Result<()> {
846 let mut record_guard = self.get_record().await?;
847
848 let old_session_id = record_guard.id;
849 record_guard.id = Id::default();
850 *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
851 // `create` method.
852
853 self.store
854 .delete(&old_session_id)
855 .await
856 .map_err(Error::Store)?;
857
858 self.inner
859 .is_modified
860 .store(true, atomic::Ordering::Release);
861
862 Ok(())
863 }
864}
865
866/// ID type for sessions.
867///
868/// Wraps an array of 16 bytes.
869///
870/// # Examples
871///
872/// ```rust
873/// use tower_sessions::session::Id;
874///
875/// Id::default();
876/// ```
877#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
878pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
879 // session ID, which is undesirable.
880
881impl Default for Id {
882 fn default() -> Self {
883 use rand::prelude::*;
884
885 Self(rand::rng().random())
886 }
887}
888
889impl Display for Id {
890 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
891 let mut encoded = [0; 22];
892 URL_SAFE_NO_PAD
893 .encode_slice(self.0.to_le_bytes(), &mut encoded)
894 .expect("Encoded ID must be exactly 22 bytes");
895 let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
896
897 f.write_str(encoded)
898 }
899}
900
901impl FromStr for Id {
902 type Err = base64::DecodeSliceError;
903
904 fn from_str(s: &str) -> result::Result<Self, Self::Err> {
905 let mut decoded = [0; 16];
906 let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
907 if bytes_decoded != 16 {
908 let err = DecodeError::InvalidLength(bytes_decoded);
909 return Err(base64::DecodeSliceError::DecodeError(err));
910 }
911
912 Ok(Self(i128::from_le_bytes(decoded)))
913 }
914}
915
916/// Record type that's appropriate for encoding and decoding sessions to and
917/// from session stores.
918#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
919pub struct Record {
920 pub id: Id,
921 pub data: Data,
922 pub expiry_date: OffsetDateTime,
923}
924
925impl Record {
926 fn new(expiry_date: OffsetDateTime) -> Self {
927 Self {
928 id: Id::default(),
929 data: Data::default(),
930 expiry_date,
931 }
932 }
933}
934
935/// Session expiry configuration.
936///
937/// # Examples
938///
939/// ```rust
940/// use time::{Duration, OffsetDateTime};
941/// use tower_sessions::Expiry;
942///
943/// // Will be expired on "session end".
944/// let expiry = Expiry::OnSessionEnd;
945///
946/// // Will be expired in five minutes from last acitve.
947/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
948///
949/// // Will be expired at the given timestamp.
950/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
951/// let expiry = Expiry::AtDateTime(expired_at);
952/// ```
953#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
954pub enum Expiry {
955 /// Expire on [current session end][current-session-end], as defined by the
956 /// browser.
957 ///
958 /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
959 OnSessionEnd,
960
961 /// Expire on inactivity.
962 ///
963 /// Reading a session is not considered activity for expiration purposes.
964 /// [`Session`] expiration is computed from the last time the session was
965 /// _modified_.
966 OnInactivity(Duration),
967
968 /// Expire at a specific date and time.
969 ///
970 /// This value may be extended manually with
971 /// [`set_expiry`](Session::set_expiry).
972 AtDateTime(OffsetDateTime),
973}
974
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    #[tokio::test]
    async fn test_cycle_id() {
        let original_id = Id::default();
        let replacement_id = Id::default();

        let mut mock = MockStore::new();

        // The first access hydrates the record stored under the original ID.
        mock.expect_load()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: original_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });

        // The save before cycling goes through `save`, since an ID exists.
        mock.expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));

        // Cycling removes the record stored under the original ID...
        mock.expect_delete()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(|_| Ok(()));

        // ...so the following save must create a brand new record.
        mock.expect_create().times(1).returning(move |record| {
            record.id = replacement_id;
            Ok(())
        });

        let shared_store = Arc::new(mock);
        let session = Session::new(Some(original_id), shared_store.clone(), None);

        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        session.cycle_id().await.unwrap();

        // The ID is cleared until the next save, but the data survives the cycle.
        assert_ne!(session.id(), Some(original_id));
        assert!(session.id().is_none());
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Saving creates the record anew and adopts the ID the store assigned.
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(replacement_id));
    }
}