Skip to main content

tower_sessions_core/
session.rs

1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3    collections::HashMap,
4    fmt::{self, Display},
5    hash::Hash,
6    result,
7    str::{self, FromStr},
8    sync::{
9        atomic::{self, AtomicBool},
10        Arc,
11    },
12};
13
14use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{session_store, SessionStore};
21
22const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24type Result<T> = result::Result<T, Error>;
25
26type Data = HashMap<String, Value>;
27
28/// Session errors.
29#[derive(thiserror::Error, Debug)]
30pub enum Error {
31    /// Maps `serde_json` errors.
32    #[error(transparent)]
33    SerdeJson(#[from] serde_json::Error),
34
35    /// Maps `session_store::Error` errors.
36    #[error(transparent)]
37    Store(#[from] session_store::Error),
38}
39
40#[derive(Debug)]
41struct Inner {
42    // This will be `None` when:
43    //
44    // 1. We have not been provided a session cookie or have failed to parse it,
45    // 2. The store has not found the session.
46    //
47    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
48    session_id: parking_lot::Mutex<Option<Id>>,
49
50    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
51    // `None` value indicates we have not tried to access it yet. After access, it will always
52    // contain `Some(Record)`.
53    record: Mutex<Option<Record>>,
54
55    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
56    expiry: parking_lot::Mutex<Option<Expiry>>,
57
58    is_modified: AtomicBool,
59}
60
61/// A session which allows HTTP applications to associate key-value pairs with
62/// visitors.
63#[derive(Debug, Clone)]
64pub struct Session {
65    store: Arc<dyn SessionStore>,
66    inner: Arc<Inner>,
67}
68
69impl Session {
70    /// Creates a new session with the session ID, store, and expiry.
71    ///
72    /// This method is lazy and does not invoke the overhead of talking to the
73    /// backing store.
74    ///
75    /// # Examples
76    ///
77    /// ```rust
78    /// use std::sync::Arc;
79    ///
80    /// use tower_sessions::{MemoryStore, Session};
81    ///
82    /// let store = Arc::new(MemoryStore::default());
83    /// Session::new(None, store, None);
84    /// ```
85    pub fn new(
86        session_id: Option<Id>,
87        store: Arc<impl SessionStore>,
88        expiry: Option<Expiry>,
89    ) -> Self {
90        let inner = Inner {
91            session_id: parking_lot::Mutex::new(session_id),
92            record: Mutex::new(None), // `None` indicates we have not loaded from store.
93            expiry: parking_lot::Mutex::new(expiry),
94            is_modified: AtomicBool::new(false),
95        };
96
97        Self {
98            store,
99            inner: Arc::new(inner),
100        }
101    }
102
103    fn create_record(&self) -> Record {
104        Record::new(self.expiry_date())
105    }
106
107    #[tracing::instrument(skip(self), err)]
108    async fn get_record(&self) -> Result<MappedMutexGuard<'_, Record>> {
109        let mut record_guard = self.inner.record.lock().await;
110
111        // Lazily load the record since `None` here indicates we have no yet loaded it.
112        if record_guard.is_none() {
113            tracing::trace!("record not loaded from store; loading");
114
115            let session_id = *self.inner.session_id.lock();
116            *record_guard = Some(if let Some(session_id) = session_id {
117                match self.store.load(&session_id).await? {
118                    Some(loaded_record) => {
119                        tracing::trace!("record found in store");
120                        loaded_record
121                    }
122
123                    None => {
124                        // A well-behaved user agent should not send session cookies after
125                        // expiration. Even so it's possible for an expired session to be removed
126                        // from the store after a request was initiated. However, such a race should
127                        // be relatively uncommon and as such entering this branch could indicate
128                        // malicious behavior.
129                        tracing::warn!("possibly suspicious activity: record not found in store");
130                        *self.inner.session_id.lock() = None;
131                        self.create_record()
132                    }
133                }
134            } else {
135                tracing::trace!("session id not found");
136                self.create_record()
137            })
138        }
139
140        Ok(MutexGuard::map(record_guard, |opt| {
141            opt.as_mut()
142                .expect("Record should always be `Option::Some` at this point")
143        }))
144    }
145
146    /// Inserts a `impl Serialize` value into the session.
147    ///
148    /// # Examples
149    ///
150    /// ```rust
151    /// # tokio_test::block_on(async {
152    /// use std::sync::Arc;
153    ///
154    /// use tower_sessions::{MemoryStore, Session};
155    ///
156    /// let store = Arc::new(MemoryStore::default());
157    /// let session = Session::new(None, store, None);
158    ///
159    /// session.insert("foo", 42).await.unwrap();
160    ///
161    /// let value = session.get::<usize>("foo").await.unwrap();
162    /// assert_eq!(value, Some(42));
163    /// # });
164    /// ```
165    ///
166    /// # Errors
167    ///
168    /// - This method can fail when [`serde_json::to_value`] fails.
169    /// - If the session has not been hydrated and loading from the store fails,
170    ///   we fail with [`Error::Store`].
171    pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
172        self.insert_value(key, serde_json::to_value(&value)?)
173            .await?;
174        Ok(())
175    }
176
177    /// Inserts a `serde_json::Value` into the session.
178    ///
179    /// If the key was not present in the underlying map, `None` is returned and
180    /// `modified` is set to `true`.
181    ///
182    /// If the underlying map did have the key and its value is the same as the
183    /// provided value, `None` is returned and `modified` is not set.
184    ///
185    /// # Examples
186    ///
187    /// ```rust
188    /// # tokio_test::block_on(async {
189    /// use std::sync::Arc;
190    ///
191    /// use tower_sessions::{MemoryStore, Session};
192    ///
193    /// let store = Arc::new(MemoryStore::default());
194    /// let session = Session::new(None, store, None);
195    ///
196    /// let value = session
197    ///     .insert_value("foo", serde_json::json!(42))
198    ///     .await
199    ///     .unwrap();
200    /// assert!(value.is_none());
201    ///
202    /// let value = session
203    ///     .insert_value("foo", serde_json::json!(42))
204    ///     .await
205    ///     .unwrap();
206    /// assert!(value.is_none());
207    ///
208    /// let value = session
209    ///     .insert_value("foo", serde_json::json!("bar"))
210    ///     .await
211    ///     .unwrap();
212    /// assert_eq!(value, Some(serde_json::json!(42)));
213    /// # });
214    /// ```
215    ///
216    /// # Errors
217    ///
218    /// - If the session has not been hydrated and loading from the store fails,
219    ///   we fail with [`Error::Store`].
220    pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
221        let mut record_guard = self.get_record().await?;
222        Ok(if record_guard.data.get(key) != Some(&value) {
223            let previous = record_guard.data.insert(key.to_string(), value);
224            self.inner
225                .is_modified
226                .store(true, atomic::Ordering::Release);
227            previous
228        } else {
229            None
230        })
231    }
232
233    /// Gets a value from the store.
234    ///
235    /// # Examples
236    ///
237    /// ```rust
238    /// # tokio_test::block_on(async {
239    /// use std::sync::Arc;
240    ///
241    /// use tower_sessions::{MemoryStore, Session};
242    ///
243    /// let store = Arc::new(MemoryStore::default());
244    /// let session = Session::new(None, store, None);
245    ///
246    /// session.insert("foo", 42).await.unwrap();
247    ///
248    /// let value = session.get::<usize>("foo").await.unwrap();
249    /// assert_eq!(value, Some(42));
250    /// # });
251    /// ```
252    ///
253    /// # Errors
254    ///
255    /// - This method can fail when [`serde_json::from_value`] fails.
256    /// - If the session has not been hydrated and loading from the store fails,
257    ///   we fail with [`Error::Store`].
258    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
259        Ok(self
260            .get_value(key)
261            .await?
262            .map(serde_json::from_value)
263            .transpose()?)
264    }
265
266    /// Gets a `serde_json::Value` from the store.
267    ///
268    /// # Examples
269    ///
270    /// ```rust
271    /// # tokio_test::block_on(async {
272    /// use std::sync::Arc;
273    ///
274    /// use tower_sessions::{MemoryStore, Session};
275    ///
276    /// let store = Arc::new(MemoryStore::default());
277    /// let session = Session::new(None, store, None);
278    ///
279    /// session.insert("foo", 42).await.unwrap();
280    ///
281    /// let value = session.get_value("foo").await.unwrap().unwrap();
282    /// assert_eq!(value, serde_json::json!(42));
283    /// # });
284    /// ```
285    ///
286    /// # Errors
287    ///
288    /// - If the session has not been hydrated and loading from the store fails,
289    ///   we fail with [`Error::Store`].
290    pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
291        let record_guard = self.get_record().await?;
292        Ok(record_guard.data.get(key).cloned())
293    }
294
295    /// Removes a value from the store, retuning the value of the key if it was
296    /// present in the underlying map.
297    ///
298    /// # Examples
299    ///
300    /// ```rust
301    /// # tokio_test::block_on(async {
302    /// use std::sync::Arc;
303    ///
304    /// use tower_sessions::{MemoryStore, Session};
305    ///
306    /// let store = Arc::new(MemoryStore::default());
307    /// let session = Session::new(None, store, None);
308    ///
309    /// session.insert("foo", 42).await.unwrap();
310    ///
311    /// let value: Option<usize> = session.remove("foo").await.unwrap();
312    /// assert_eq!(value, Some(42));
313    ///
314    /// let value: Option<usize> = session.get("foo").await.unwrap();
315    /// assert!(value.is_none());
316    /// # });
317    /// ```
318    ///
319    /// # Errors
320    ///
321    /// - This method can fail when [`serde_json::from_value`] fails.
322    /// - If the session has not been hydrated and loading from the store fails,
323    ///   we fail with [`Error::Store`].
324    pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
325        Ok(self
326            .remove_value(key)
327            .await?
328            .map(serde_json::from_value)
329            .transpose()?)
330    }
331
332    /// Removes a `serde_json::Value` from the session.
333    ///
334    /// # Examples
335    ///
336    /// ```rust
337    /// # tokio_test::block_on(async {
338    /// use std::sync::Arc;
339    ///
340    /// use tower_sessions::{MemoryStore, Session};
341    ///
342    /// let store = Arc::new(MemoryStore::default());
343    /// let session = Session::new(None, store, None);
344    ///
345    /// session.insert("foo", 42).await.unwrap();
346    /// let value = session.remove_value("foo").await.unwrap().unwrap();
347    /// assert_eq!(value, serde_json::json!(42));
348    ///
349    /// let value: Option<usize> = session.get("foo").await.unwrap();
350    /// assert!(value.is_none());
351    /// # });
352    /// ```
353    ///
354    /// # Errors
355    ///
356    /// - If the session has not been hydrated and loading from the store fails,
357    ///   we fail with [`Error::Store`].
358    pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
359        let mut record_guard = self.get_record().await?;
360        let previous = record_guard.data.remove(key);
361        self.inner
362            .is_modified
363            .store(true, atomic::Ordering::Release);
364        Ok(previous)
365    }
366
367    /// Clears the session of all data but does not delete it from the store.
368    ///
369    /// # Examples
370    ///
371    /// ```rust
372    /// # tokio_test::block_on(async {
373    /// use std::sync::Arc;
374    ///
375    /// use tower_sessions::{MemoryStore, Session};
376    ///
377    /// let store = Arc::new(MemoryStore::default());
378    ///
379    /// let session = Session::new(None, store.clone(), None);
380    /// session.insert("foo", 42).await.unwrap();
381    /// assert!(!session.is_empty().await);
382    ///
383    /// session.save().await.unwrap();
384    ///
385    /// session.clear().await;
386    ///
387    /// // Not empty! (We have an ID still.)
388    /// assert!(!session.is_empty().await);
389    /// // Data is cleared...
390    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
391    ///
392    /// // ...data is cleared before loading from the backend...
393    /// let session = Session::new(session.id(), store.clone(), None);
394    /// session.clear().await;
395    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
396    ///
397    /// let session = Session::new(session.id(), store, None);
398    /// // ...but data is not deleted from the store.
399    /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
400    /// # });
401    /// ```
402    pub async fn clear(&self) {
403        let mut record_guard = self.inner.record.lock().await;
404        if let Some(record) = record_guard.as_mut() {
405            record.data.clear();
406        } else if let Some(session_id) = *self.inner.session_id.lock() {
407            let mut new_record = self.create_record();
408            new_record.id = session_id;
409            *record_guard = Some(new_record);
410        }
411
412        self.inner
413            .is_modified
414            .store(true, atomic::Ordering::Release);
415    }
416
417    /// Returns `true` if there is no session ID and the session is empty.
418    ///
419    /// # Examples
420    ///
421    /// ```rust
422    /// # tokio_test::block_on(async {
423    /// use std::sync::Arc;
424    ///
425    /// use tower_sessions::{session::Id, MemoryStore, Session};
426    ///
427    /// let store = Arc::new(MemoryStore::default());
428    ///
429    /// let session = Session::new(None, store.clone(), None);
430    /// // Empty if we have no ID and record is not loaded.
431    /// assert!(session.is_empty().await);
432    ///
433    /// let session = Session::new(Some(Id::default()), store.clone(), None);
434    /// // Not empty if we have an ID but no record. (Record is not loaded here.)
435    /// assert!(!session.is_empty().await);
436    ///
437    /// let session = Session::new(Some(Id::default()), store.clone(), None);
438    /// session.insert("foo", 42).await.unwrap();
439    /// // Not empty after inserting.
440    /// assert!(!session.is_empty().await);
441    /// session.save().await.unwrap();
442    /// // Not empty after saving.
443    /// assert!(!session.is_empty().await);
444    ///
445    /// let session = Session::new(session.id(), store.clone(), None);
446    /// session.load().await.unwrap();
447    /// // Not empty after loading from store...
448    /// assert!(!session.is_empty().await);
449    /// // ...and not empty after accessing the session.
450    /// session.get::<usize>("foo").await.unwrap();
451    /// assert!(!session.is_empty().await);
452    ///
453    /// let session = Session::new(session.id(), store.clone(), None);
454    /// session.delete().await.unwrap();
455    /// // Not empty after deleting from store...
456    /// assert!(!session.is_empty().await);
457    /// session.get::<usize>("foo").await.unwrap();
458    /// // ...but empty after trying to access the deleted session.
459    /// assert!(session.is_empty().await);
460    ///
461    /// let session = Session::new(None, store, None);
462    /// session.insert("foo", 42).await.unwrap();
463    /// session.flush().await.unwrap();
464    /// // Empty after flushing.
465    /// assert!(session.is_empty().await);
466    /// # });
467    /// ```
468    pub async fn is_empty(&self) -> bool {
469        let record_guard = self.inner.record.lock().await;
470
471        // N.B.: Session IDs are `None` if:
472        //
473        // 1. The cookie was not provided or otherwise could not be parsed,
474        // 2. Or the session could not be loaded from the store.
475        let session_id = self.inner.session_id.lock();
476
477        let Some(record) = record_guard.as_ref() else {
478            return session_id.is_none();
479        };
480
481        session_id.is_none() && record.data.is_empty()
482    }
483
484    /// Get the session ID.
485    ///
486    /// # Examples
487    ///
488    /// ```rust
489    /// use std::sync::Arc;
490    ///
491    /// use tower_sessions::{session::Id, MemoryStore, Session};
492    ///
493    /// let store = Arc::new(MemoryStore::default());
494    ///
495    /// let session = Session::new(None, store.clone(), None);
496    /// assert!(session.id().is_none());
497    ///
498    /// let id = Some(Id::default());
499    /// let session = Session::new(id, store, None);
500    /// assert_eq!(id, session.id());
501    /// ```
502    pub fn id(&self) -> Option<Id> {
503        *self.inner.session_id.lock()
504    }
505
506    /// Get the session expiry.
507    ///
508    /// # Examples
509    ///
510    /// ```rust
511    /// use std::sync::Arc;
512    ///
513    /// use tower_sessions::{session::Expiry, MemoryStore, Session};
514    ///
515    /// let store = Arc::new(MemoryStore::default());
516    /// let session = Session::new(None, store, None);
517    ///
518    /// assert_eq!(session.expiry(), None);
519    /// ```
520    pub fn expiry(&self) -> Option<Expiry> {
521        *self.inner.expiry.lock()
522    }
523
524    /// Set `expiry` to the given value.
525    ///
526    /// This may be used within applications directly to alter the session's
527    /// time to live.
528    ///
529    /// # Examples
530    ///
531    /// ```rust
532    /// use std::sync::Arc;
533    ///
534    /// use time::OffsetDateTime;
535    /// use tower_sessions::{session::Expiry, MemoryStore, Session};
536    ///
537    /// let store = Arc::new(MemoryStore::default());
538    /// let session = Session::new(None, store, None);
539    ///
540    /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
541    /// session.set_expiry(Some(expiry));
542    ///
543    /// assert_eq!(session.expiry(), Some(expiry));
544    /// ```
545    pub fn set_expiry(&self, expiry: Option<Expiry>) {
546        *self.inner.expiry.lock() = expiry;
547        self.inner
548            .is_modified
549            .store(true, atomic::Ordering::Release);
550    }
551
552    /// Get session expiry as `OffsetDateTime`.
553    ///
554    /// # Examples
555    ///
556    /// ```rust
557    /// use std::sync::Arc;
558    ///
559    /// use time::{Duration, OffsetDateTime};
560    /// use tower_sessions::{MemoryStore, Session};
561    ///
562    /// let store = Arc::new(MemoryStore::default());
563    /// let session = Session::new(None, store, None);
564    ///
565    /// // Our default duration is two weeks.
566    /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
567    ///
568    /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
569    /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
570    /// ```
571    pub fn expiry_date(&self) -> OffsetDateTime {
572        let expiry = self.inner.expiry.lock();
573        match *expiry {
574            Some(Expiry::OnInactivity(duration)) => {
575                OffsetDateTime::now_utc().saturating_add(duration)
576            }
577            Some(Expiry::AtDateTime(datetime)) => datetime,
578            Some(Expiry::OnSessionEnd) | None => {
579                OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION) // TODO: The default should probably be configurable.
580            }
581        }
582    }
583
584    /// Get session expiry as `Duration`.
585    ///
586    /// # Examples
587    ///
588    /// ```rust
589    /// use std::sync::Arc;
590    ///
591    /// use time::Duration;
592    /// use tower_sessions::{MemoryStore, Session};
593    ///
594    /// let store = Arc::new(MemoryStore::default());
595    /// let session = Session::new(None, store, None);
596    ///
597    /// let expected_duration = Duration::weeks(2);
598    ///
599    /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
600    /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
601    /// ```
602    pub fn expiry_age(&self) -> Duration {
603        std::cmp::max(
604            self.expiry_date() - OffsetDateTime::now_utc(),
605            Duration::ZERO,
606        )
607    }
608
609    /// Returns `true` if the session has been modified during the request.
610    ///
611    /// # Examples
612    ///
613    /// ```rust
614    /// # tokio_test::block_on(async {
615    /// use std::sync::Arc;
616    ///
617    /// use tower_sessions::{MemoryStore, Session};
618    ///
619    /// let store = Arc::new(MemoryStore::default());
620    /// let session = Session::new(None, store, None);
621    ///
622    /// // Not modified initially.
623    /// assert!(!session.is_modified());
624    ///
625    /// // Getting doesn't count as a modification.
626    /// session.get::<usize>("foo").await.unwrap();
627    /// assert!(!session.is_modified());
628    ///
629    /// // Insertions and removals do though.
630    /// session.insert("foo", 42).await.unwrap();
631    /// assert!(session.is_modified());
632    /// # });
633    /// ```
634    pub fn is_modified(&self) -> bool {
635        self.inner.is_modified.load(atomic::Ordering::Acquire)
636    }
637
638    /// Saves the session record to the store.
639    ///
640    /// Note that this method is generally not needed and is reserved for
641    /// situations where the session store must be updated during the
642    /// request.
643    ///
644    /// # Examples
645    ///
646    /// ```rust
647    /// # tokio_test::block_on(async {
648    /// use std::sync::Arc;
649    ///
650    /// use tower_sessions::{MemoryStore, Session};
651    ///
652    /// let store = Arc::new(MemoryStore::default());
653    /// let session = Session::new(None, store.clone(), None);
654    ///
655    /// session.insert("foo", 42).await.unwrap();
656    /// session.save().await.unwrap();
657    ///
658    /// let session = Session::new(session.id(), store, None);
659    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
660    /// # });
661    /// ```
662    ///
663    /// # Errors
664    ///
665    /// - If saving to the store fails, we fail with [`Error::Store`].
666    #[tracing::instrument(skip(self), err)]
667    pub async fn save(&self) -> Result<()> {
668        let mut record_guard = self.get_record().await?;
669        record_guard.expiry_date = self.expiry_date();
670
671        // Session ID is `None` if:
672        //
673        //  1. No valid cookie was found on the request or,
674        //  2. No valid session was found in the store.
675        //
676        // In either case, we must create a new session via the store interface.
677        //
678        // Potential ID collisions must be handled by session store implementers.
679        if self.inner.session_id.lock().is_none() {
680            self.store.create(&mut record_guard).await?;
681            *self.inner.session_id.lock() = Some(record_guard.id);
682        } else {
683            self.store.save(&record_guard).await?;
684        }
685        Ok(())
686    }
687
688    /// Loads the session record from the store.
689    ///
690    /// Note that this method is generally not needed and is reserved for
691    /// situations where the session must be updated during the request.
692    ///
693    /// # Examples
694    ///
695    /// ```rust
696    /// # tokio_test::block_on(async {
697    /// use std::sync::Arc;
698    ///
699    /// use tower_sessions::{session::Id, MemoryStore, Session};
700    ///
701    /// let store = Arc::new(MemoryStore::default());
702    /// let id = Some(Id::default());
703    /// let session = Session::new(id, store.clone(), None);
704    ///
705    /// session.insert("foo", 42).await.unwrap();
706    /// session.save().await.unwrap();
707    ///
708    /// let session = Session::new(session.id(), store, None);
709    /// session.load().await.unwrap();
710    ///
711    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
712    /// # });
713    /// ```
714    ///
715    /// # Errors
716    ///
717    /// - If loading from the store fails, we fail with [`Error::Store`].
718    #[tracing::instrument(skip(self), err)]
719    pub async fn load(&self) -> Result<()> {
720        let session_id = *self.inner.session_id.lock();
721        let Some(ref id) = session_id else {
722            tracing::warn!("called load with no session id");
723            return Ok(());
724        };
725        let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
726        let mut record_guard = self.inner.record.lock().await;
727        *record_guard = loaded_record;
728        Ok(())
729    }
730
731    /// Deletes the session from the store.
732    ///
733    /// # Examples
734    ///
735    /// ```rust
736    /// # tokio_test::block_on(async {
737    /// use std::sync::Arc;
738    ///
739    /// use tower_sessions::{session::Id, MemoryStore, Session, SessionStore};
740    ///
741    /// let store = Arc::new(MemoryStore::default());
742    /// let session = Session::new(Some(Id::default()), store.clone(), None);
743    ///
744    /// // Save before deleting.
745    /// session.save().await.unwrap();
746    ///
747    /// // Delete from the store.
748    /// session.delete().await.unwrap();
749    ///
750    /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
751    /// # });
752    /// ```
753    ///
754    /// # Errors
755    ///
756    /// - If deleting from the store fails, we fail with [`Error::Store`].
757    #[tracing::instrument(skip(self), err)]
758    pub async fn delete(&self) -> Result<()> {
759        let session_id = *self.inner.session_id.lock();
760        let Some(ref session_id) = session_id else {
761            tracing::warn!("called delete with no session id");
762            return Ok(());
763        };
764        self.store.delete(session_id).await.map_err(Error::Store)?;
765        Ok(())
766    }
767
768    /// Flushes the session by removing all data contained in the session and
769    /// then deleting it from the store.
770    ///
771    /// # Examples
772    ///
773    /// ```rust
774    /// # tokio_test::block_on(async {
775    /// use std::sync::Arc;
776    ///
777    /// use tower_sessions::{MemoryStore, Session, SessionStore};
778    ///
779    /// let store = Arc::new(MemoryStore::default());
780    /// let session = Session::new(None, store.clone(), None);
781    ///
782    /// session.insert("foo", "bar").await.unwrap();
783    /// session.save().await.unwrap();
784    ///
785    /// let id = session.id().unwrap();
786    ///
787    /// session.flush().await.unwrap();
788    ///
789    /// assert!(session.id().is_none());
790    /// assert!(session.is_empty().await);
791    /// assert!(store.load(&id).await.unwrap().is_none());
792    /// # });
793    /// ```
794    ///
795    /// # Errors
796    ///
797    /// - If deleting from the store fails, we fail with [`Error::Store`].
798    pub async fn flush(&self) -> Result<()> {
799        self.clear().await;
800        self.delete().await?;
801        *self.inner.session_id.lock() = None;
802        Ok(())
803    }
804
805    /// Cycles the session ID while retaining any data that was associated with
806    /// it.
807    ///
808    /// Using this method helps prevent session fixation attacks by ensuring a
809    /// new ID is assigned to the session.
810    ///
811    /// # Examples
812    ///
813    /// ```rust
814    /// # tokio_test::block_on(async {
815    /// use std::sync::Arc;
816    ///
817    /// use tower_sessions::{session::Id, MemoryStore, Session};
818    ///
819    /// let store = Arc::new(MemoryStore::default());
820    /// let session = Session::new(None, store.clone(), None);
821    ///
822    /// session.insert("foo", 42).await.unwrap();
823    /// session.save().await.unwrap();
824    /// let id = session.id();
825    ///
826    /// let session = Session::new(session.id(), store.clone(), None);
827    /// session.cycle_id().await.unwrap();
828    ///
829    /// assert!(!session.is_empty().await);
830    /// assert!(session.is_modified());
831    ///
832    /// session.save().await.unwrap();
833    ///
834    /// let session = Session::new(session.id(), store, None);
835    ///
836    /// assert_ne!(id, session.id());
837    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
838    /// # });
839    /// ```
840    ///
841    /// # Errors
842    ///
843    /// - If deleting from the store fails or saving to the store fails, we fail
844    ///   with [`Error::Store`].
845    pub async fn cycle_id(&self) -> Result<()> {
846        let mut record_guard = self.get_record().await?;
847
848        let old_session_id = record_guard.id;
849        record_guard.id = Id::default();
850        *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
851                                              // `create` method.
852
853        self.store
854            .delete(&old_session_id)
855            .await
856            .map_err(Error::Store)?;
857
858        self.inner
859            .is_modified
860            .store(true, atomic::Ordering::Release);
861
862        Ok(())
863    }
864}
865
866/// ID type for sessions.
867///
868/// Wraps an array of 16 bytes.
869///
870/// # Examples
871///
872/// ```rust
873/// use tower_sessions::session::Id;
874///
875/// Id::default();
876/// ```
877#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
878pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
879                         // session ID, which is undesirable.
880
881impl Default for Id {
882    fn default() -> Self {
883        use rand::prelude::*;
884
885        Self(rand::rng().random())
886    }
887}
888
889impl Display for Id {
890    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
891        let mut encoded = [0; 22];
892        URL_SAFE_NO_PAD
893            .encode_slice(self.0.to_le_bytes(), &mut encoded)
894            .expect("Encoded ID must be exactly 22 bytes");
895        let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
896
897        f.write_str(encoded)
898    }
899}
900
901impl FromStr for Id {
902    type Err = base64::DecodeSliceError;
903
904    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
905        let mut decoded = [0; 16];
906        let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
907        if bytes_decoded != 16 {
908            let err = DecodeError::InvalidLength(bytes_decoded);
909            return Err(base64::DecodeSliceError::DecodeError(err));
910        }
911
912        Ok(Self(i128::from_le_bytes(decoded)))
913    }
914}
915
916/// Record type that's appropriate for encoding and decoding sessions to and
917/// from session stores.
918#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
919pub struct Record {
920    pub id: Id,
921    pub data: Data,
922    pub expiry_date: OffsetDateTime,
923}
924
925impl Record {
926    fn new(expiry_date: OffsetDateTime) -> Self {
927        Self {
928            id: Id::default(),
929            data: Data::default(),
930            expiry_date,
931        }
932    }
933}
934
935/// Session expiry configuration.
936///
937/// # Examples
938///
939/// ```rust
940/// use time::{Duration, OffsetDateTime};
941/// use tower_sessions::Expiry;
942///
943/// // Will be expired on "session end".
944/// let expiry = Expiry::OnSessionEnd;
945///
946/// // Will be expired in five minutes from last acitve.
947/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
948///
949/// // Will be expired at the given timestamp.
950/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
951/// let expiry = Expiry::AtDateTime(expired_at);
952/// ```
953#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
954pub enum Expiry {
955    /// Expire on [current session end][current-session-end], as defined by the
956    /// browser.
957    ///
958    /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
959    OnSessionEnd,
960
961    /// Expire on inactivity.
962    ///
963    /// Reading a session is not considered activity for expiration purposes.
964    /// [`Session`] expiration is computed from the last time the session was
965    /// _modified_.
966    OnInactivity(Duration),
967
968    /// Expire at a specific date and time.
969    ///
970    /// This value may be extended manually with
971    /// [`set_expiry`](Session::set_expiry).
972    AtDateTime(OffsetDateTime),
973}
974
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    // Mock `SessionStore` so tests can assert exactly which store operations a
    // `Session` performs.
    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    #[tokio::test]
    async fn test_cycle_id() {
        let mut mock_store = MockStore::new();

        let initial_id = Id::default();
        let new_id = Id::default();

        // Set up expectations for the mock store
        mock_store
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));
        mock_store
            .expect_load()
            .with(predicate::eq(initial_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: initial_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });
        mock_store
            .expect_delete()
            .with(predicate::eq(initial_id))
            .times(1)
            .returning(|_| Ok(()));
        mock_store
            .expect_create()
            .times(1)
            .returning(move |record| {
                record.id = new_id;
                Ok(())
            });

        let store = Arc::new(mock_store);
        let session = Session::new(Some(initial_id), store.clone(), None);

        // Insert some data and save the session
        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        // Cycle the session ID
        session.cycle_id().await.unwrap();

        // Verify that the session ID has changed and the data is still present
        assert_ne!(session.id(), Some(initial_id));
        assert!(session.id().is_none()); // The session ID should be None
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Save the session to update the ID in the session object
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(new_id));
    }

    // `Display` and `FromStr` must be exact inverses; otherwise a session
    // cookie written by the server could not be read back.
    #[test]
    fn test_id_string_round_trip() {
        for value in [
            0,
            1,
            -1,
            i128::MIN,
            i128::MAX,
            0x0123_4567_89ab_cdef_i128,
        ] {
            let id = Id(value);
            let encoded = id.to_string();
            assert_eq!(encoded.len(), 22);
            assert_eq!(encoded.parse::<Id>().unwrap(), id);
        }
    }

    #[test]
    fn test_id_from_str_rejects_wrong_length() {
        // Valid base64 that decodes to only 3 bytes.
        assert!(matches!(
            "AAAA".parse::<Id>(),
            Err(base64::DecodeSliceError::DecodeError(
                base64::DecodeError::InvalidLength(3)
            ))
        ));
        // Decodes to 18 bytes, which cannot fit in a 16-byte ID.
        assert!("A".repeat(24).parse::<Id>().is_err());
        // Empty input decodes to zero bytes.
        assert!("".parse::<Id>().is_err());
    }
}