Skip to main content

tower_sessions_ext_core/
session.rs

1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3    collections::HashMap,
4    fmt::{self, Display},
5    hash::Hash,
6    result,
7    str::{self, FromStr},
8    sync::{
9        Arc,
10        atomic::{self, AtomicBool},
11    },
12};
13
14use base64::{DecodeError, Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize, de::DeserializeOwned};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{SessionStore, session_store};
21
/// Server-side record lifetime used by [`Session::expiry_date`] when no
/// [`Expiry`] has been configured for the session.
pub const DEFAULT_DURATION: Duration = Duration::weeks(2);

// Shorthand result type for session operations that fail with `Error`.
type Result<T> = result::Result<T, Error>;

// In-memory shape of a session's key-value data: JSON values keyed by string.
type Data = HashMap<String, Value>;
27
/// Session errors.
///
/// Returned by the fallible [`Session`] methods. Both variants are
/// `#[error(transparent)]`, so their `Display` output is that of the wrapped
/// error.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Maps `serde_json` errors, e.g. when a value cannot be serialized on
    /// insert or deserialized into the requested type on get/remove.
    #[error(transparent)]
    SerdeJson(#[from] serde_json::Error),

    /// Maps `session_store::Error` errors raised by the backing store.
    #[error(transparent)]
    Store(#[from] session_store::Error),
}
39
/// Shared, mutable state behind a [`Session`].
///
/// `Session` stores this behind an `Arc`, so every clone of a `Session`
/// observes and mutates the same `Inner`.
#[derive(Debug)]
struct Inner {
    // This will be `None` when:
    //
    // 1. We have not been provided a session cookie or have failed to parse it,
    // 2. The store has not found the session (`get_record` clears it then),
    // 3. It was deliberately cleared by `flush` or `cycle_id`.
    //
    // `save` fills it in after the store has created a new record.
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    session_id: parking_lot::Mutex<Option<Id>>,

    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
    // `None` value indicates we have not tried to access it yet. Accessors fill it with
    // `Some(Record)`. Note that `Session::load` may reset it to `None` when the store has no
    // record for the current ID; the next access then re-runs lazy hydration.
    //
    // This is an async (tokio) lock because it is held across store calls while hydrating
    // and saving.
    record: Mutex<Option<Record>>,

    // Configured expiry; `None` falls back to `DEFAULT_DURATION` in `Session::expiry_date`.
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    expiry: parking_lot::Mutex<Option<Expiry>>,

    // Set by the mutating methods (`insert_value`, `remove_value`, `clear`, `set_expiry`,
    // `cycle_id`) and read by `Session::is_modified`.
    is_modified: AtomicBool,
}
60
/// A session which allows HTTP applications to associate key-value pairs with
/// visitors.
///
/// Cloning a `Session` is cheap: both fields are reference-counted, so clones
/// share the same store handle and the same underlying session state.
#[derive(Clone)]
pub struct Session {
    // Backend used to load, create, save, and delete session records.
    store: Arc<dyn SessionStore>,
    // Shared ID / record / expiry / modification state.
    inner: Arc<Inner>,
}
68
69impl fmt::Debug for Session {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("Session")
72            .field("store", &self.store)
73            .field("inner", &self.inner)
74            .finish()
75    }
76}
77
78impl Session {
79    /// Creates a new session with the session ID, store, and expiry.
80    ///
81    /// This method is lazy and does not invoke the overhead of talking to the
82    /// backing store.
83    ///
84    /// # Examples
85    ///
86    /// ```rust
87    /// use std::sync::Arc;
88    ///
89    /// use tower_sessions_ext::{MemoryStore, Session};
90    ///
91    /// let store = Arc::new(MemoryStore::default());
92    /// Session::new(None, store, None);
93    /// ```
94    pub fn new(
95        session_id: Option<Id>,
96        store: Arc<impl SessionStore>,
97        expiry: Option<Expiry>,
98    ) -> Self {
99        let inner = Inner {
100            session_id: parking_lot::Mutex::new(session_id),
101            record: Mutex::new(None), // `None` indicates we have not loaded from store.
102            expiry: parking_lot::Mutex::new(expiry),
103            is_modified: AtomicBool::new(false),
104        };
105
106        Self {
107            store,
108            inner: Arc::new(inner),
109        }
110    }
111
112    fn create_record(&self) -> Record {
113        Record::new(self.expiry_date())
114    }
115
116    #[tracing::instrument(skip(self), err, level = "trace")]
117    async fn get_record(&'_ self) -> Result<MappedMutexGuard<'_, Record>> {
118        let mut record_guard = self.inner.record.lock().await;
119
120        // Lazily load the record since `None` here indicates we have no yet loaded it.
121        if record_guard.is_none() {
122            tracing::trace!("record not loaded from store; loading");
123
124            let session_id = *self.inner.session_id.lock();
125            *record_guard = Some(if let Some(session_id) = session_id {
126                match self.store.load(&session_id).await? {
127                    Some(loaded_record) => {
128                        tracing::trace!("record found in store");
129                        loaded_record
130                    }
131
132                    None => {
133                        // A well-behaved user agent should not send session cookies after
134                        // expiration. Even so it's possible for an expired session to be removed
135                        // from the store after a request was initiated. However, such a race should
136                        // be relatively uncommon and as such entering this branch could indicate
137                        // malicious behavior.
138                        tracing::warn!("possibly suspicious activity: record not found in store");
139                        *self.inner.session_id.lock() = None;
140                        self.create_record()
141                    }
142                }
143            } else {
144                tracing::trace!("session id not found");
145                self.create_record()
146            })
147        }
148
149        Ok(MutexGuard::map(record_guard, |opt| {
150            opt.as_mut()
151                .expect("Record should always be `Option::Some` at this point")
152        }))
153    }
154
155    /// Inserts a `impl Serialize` value into the session.
156    ///
157    /// # Examples
158    ///
159    /// ```rust
160    /// # tokio_test::block_on(async {
161    /// use std::sync::Arc;
162    ///
163    /// use tower_sessions_ext::{MemoryStore, Session};
164    ///
165    /// let store = Arc::new(MemoryStore::default());
166    /// let session = Session::new(None, store, None);
167    ///
168    /// session.insert("foo", 42).await.unwrap();
169    ///
170    /// let value = session.get::<usize>("foo").await.unwrap();
171    /// assert_eq!(value, Some(42));
172    /// # });
173    /// ```
174    ///
175    /// # Errors
176    ///
177    /// - This method can fail when [`serde_json::to_value`] fails.
178    /// - If the session has not been hydrated and loading from the store fails,
179    ///   we fail with [`Error::Store`].
180    pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
181        self.insert_value(key, serde_json::to_value(&value)?)
182            .await?;
183        Ok(())
184    }
185
186    /// Inserts a `serde_json::Value` into the session.
187    ///
188    /// If the key was not present in the underlying map, `None` is returned and
189    /// `modified` is set to `true`.
190    ///
191    /// If the underlying map did have the key and its value is the same as the
192    /// provided value, `None` is returned and `modified` is not set.
193    ///
194    /// # Examples
195    ///
196    /// ```rust
197    /// # tokio_test::block_on(async {
198    /// use std::sync::Arc;
199    ///
200    /// use tower_sessions_ext::{MemoryStore, Session};
201    ///
202    /// let store = Arc::new(MemoryStore::default());
203    /// let session = Session::new(None, store, None);
204    ///
205    /// let value = session
206    ///     .insert_value("foo", serde_json::json!(42))
207    ///     .await
208    ///     .unwrap();
209    /// assert!(value.is_none());
210    ///
211    /// let value = session
212    ///     .insert_value("foo", serde_json::json!(42))
213    ///     .await
214    ///     .unwrap();
215    /// assert!(value.is_none());
216    ///
217    /// let value = session
218    ///     .insert_value("foo", serde_json::json!("bar"))
219    ///     .await
220    ///     .unwrap();
221    /// assert_eq!(value, Some(serde_json::json!(42)));
222    /// # });
223    /// ```
224    ///
225    /// # Errors
226    ///
227    /// - If the session has not been hydrated and loading from the store fails,
228    ///   we fail with [`Error::Store`].
229    pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
230        let mut record_guard = self.get_record().await?;
231        Ok(if record_guard.data.get(key) != Some(&value) {
232            self.inner
233                .is_modified
234                .store(true, atomic::Ordering::Release);
235            record_guard.data.insert(key.to_string(), value)
236        } else {
237            None
238        })
239    }
240
241    /// Gets a value from the store.
242    ///
243    /// # Examples
244    ///
245    /// ```rust
246    /// # tokio_test::block_on(async {
247    /// use std::sync::Arc;
248    ///
249    /// use tower_sessions_ext::{MemoryStore, Session};
250    ///
251    /// let store = Arc::new(MemoryStore::default());
252    /// let session = Session::new(None, store, None);
253    ///
254    /// session.insert("foo", 42).await.unwrap();
255    ///
256    /// let value = session.get::<usize>("foo").await.unwrap();
257    /// assert_eq!(value, Some(42));
258    /// # });
259    /// ```
260    ///
261    /// # Errors
262    ///
263    /// - This method can fail when [`serde_json::from_value`] fails.
264    /// - If the session has not been hydrated and loading from the store fails,
265    ///   we fail with [`Error::Store`].
266    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
267        Ok(self
268            .get_value(key)
269            .await?
270            .map(serde_json::from_value)
271            .transpose()?)
272    }
273
274    /// Gets a `serde_json::Value` from the store.
275    ///
276    /// # Examples
277    ///
278    /// ```rust
279    /// # tokio_test::block_on(async {
280    /// use std::sync::Arc;
281    ///
282    /// use tower_sessions_ext::{MemoryStore, Session};
283    ///
284    /// let store = Arc::new(MemoryStore::default());
285    /// let session = Session::new(None, store, None);
286    ///
287    /// session.insert("foo", 42).await.unwrap();
288    ///
289    /// let value = session.get_value("foo").await.unwrap().unwrap();
290    /// assert_eq!(value, serde_json::json!(42));
291    /// # });
292    /// ```
293    ///
294    /// # Errors
295    ///
296    /// - If the session has not been hydrated and loading from the store fails,
297    ///   we fail with [`Error::Store`].
298    pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
299        let record_guard = self.get_record().await?;
300        Ok(record_guard.data.get(key).cloned())
301    }
302
303    /// Removes a value from the store, retuning the value of the key if it was
304    /// present in the underlying map.
305    ///
306    /// # Examples
307    ///
308    /// ```rust
309    /// # tokio_test::block_on(async {
310    /// use std::sync::Arc;
311    ///
312    /// use tower_sessions_ext::{MemoryStore, Session};
313    ///
314    /// let store = Arc::new(MemoryStore::default());
315    /// let session = Session::new(None, store, None);
316    ///
317    /// session.insert("foo", 42).await.unwrap();
318    ///
319    /// let value: Option<usize> = session.remove("foo").await.unwrap();
320    /// assert_eq!(value, Some(42));
321    ///
322    /// let value: Option<usize> = session.get("foo").await.unwrap();
323    /// assert!(value.is_none());
324    /// # });
325    /// ```
326    ///
327    /// # Errors
328    ///
329    /// - This method can fail when [`serde_json::from_value`] fails.
330    /// - If the session has not been hydrated and loading from the store fails,
331    ///   we fail with [`Error::Store`].
332    pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
333        Ok(self
334            .remove_value(key)
335            .await?
336            .map(serde_json::from_value)
337            .transpose()?)
338    }
339
340    /// Removes a `serde_json::Value` from the session.
341    ///
342    /// # Examples
343    ///
344    /// ```rust
345    /// # tokio_test::block_on(async {
346    /// use std::sync::Arc;
347    ///
348    /// use tower_sessions_ext::{MemoryStore, Session};
349    ///
350    /// let store = Arc::new(MemoryStore::default());
351    /// let session = Session::new(None, store, None);
352    ///
353    /// session.insert("foo", 42).await.unwrap();
354    /// let value = session.remove_value("foo").await.unwrap().unwrap();
355    /// assert_eq!(value, serde_json::json!(42));
356    ///
357    /// let value: Option<usize> = session.get("foo").await.unwrap();
358    /// assert!(value.is_none());
359    /// # });
360    /// ```
361    ///
362    /// # Errors
363    ///
364    /// - If the session has not been hydrated and loading from the store fails,
365    ///   we fail with [`Error::Store`].
366    pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
367        let mut record_guard = self.get_record().await?;
368        self.inner
369            .is_modified
370            .store(true, atomic::Ordering::Release);
371        Ok(record_guard.data.remove(key))
372    }
373
374    /// Clears the session of all data but does not delete it from the store.
375    ///
376    /// # Examples
377    ///
378    /// ```rust
379    /// # tokio_test::block_on(async {
380    /// use std::sync::Arc;
381    ///
382    /// use tower_sessions_ext::{MemoryStore, Session};
383    ///
384    /// let store = Arc::new(MemoryStore::default());
385    ///
386    /// let session = Session::new(None, store.clone(), None);
387    /// session.insert("foo", 42).await.unwrap();
388    /// assert!(!session.is_empty().await);
389    ///
390    /// session.save().await.unwrap();
391    ///
392    /// session.clear().await;
393    ///
394    /// // Not empty! (We have an ID still.)
395    /// assert!(!session.is_empty().await);
396    /// // Data is cleared...
397    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
398    ///
399    /// // ...data is cleared before loading from the backend...
400    /// let session = Session::new(session.id(), store.clone(), None);
401    /// session.clear().await;
402    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
403    ///
404    /// let session = Session::new(session.id(), store, None);
405    /// // ...but data is not deleted from the store.
406    /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
407    /// # });
408    /// ```
409    pub async fn clear(&self) {
410        let mut record_guard = self.inner.record.lock().await;
411        if let Some(record) = record_guard.as_mut() {
412            record.data.clear();
413        } else if let Some(session_id) = *self.inner.session_id.lock() {
414            let mut new_record = self.create_record();
415            new_record.id = session_id;
416            *record_guard = Some(new_record);
417        }
418
419        self.inner
420            .is_modified
421            .store(true, atomic::Ordering::Release);
422    }
423
424    /// Returns `true` if there is no session ID and the session is empty.
425    ///
426    /// # Examples
427    ///
428    /// ```rust
429    /// # tokio_test::block_on(async {
430    /// use std::sync::Arc;
431    ///
432    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
433    ///
434    /// let store = Arc::new(MemoryStore::default());
435    ///
436    /// let session = Session::new(None, store.clone(), None);
437    /// // Empty if we have no ID and record is not loaded.
438    /// assert!(session.is_empty().await);
439    ///
440    /// let session = Session::new(Some(Id::default()), store.clone(), None);
441    /// // Not empty if we have an ID but no record. (Record is not loaded here.)
442    /// assert!(!session.is_empty().await);
443    ///
444    /// let session = Session::new(Some(Id::default()), store.clone(), None);
445    /// session.insert("foo", 42).await.unwrap();
446    /// // Not empty after inserting.
447    /// assert!(!session.is_empty().await);
448    /// session.save().await.unwrap();
449    /// // Not empty after saving.
450    /// assert!(!session.is_empty().await);
451    ///
452    /// let session = Session::new(session.id(), store.clone(), None);
453    /// session.load().await.unwrap();
454    /// // Not empty after loading from store...
455    /// assert!(!session.is_empty().await);
456    /// // ...and not empty after accessing the session.
457    /// session.get::<usize>("foo").await.unwrap();
458    /// assert!(!session.is_empty().await);
459    ///
460    /// let session = Session::new(session.id(), store.clone(), None);
461    /// session.delete().await.unwrap();
462    /// // Not empty after deleting from store...
463    /// assert!(!session.is_empty().await);
464    /// session.get::<usize>("foo").await.unwrap();
465    /// // ...but empty after trying to access the deleted session.
466    /// assert!(session.is_empty().await);
467    ///
468    /// let session = Session::new(None, store, None);
469    /// session.insert("foo", 42).await.unwrap();
470    /// session.flush().await.unwrap();
471    /// // Empty after flushing.
472    /// assert!(session.is_empty().await);
473    /// # });
474    /// ```
475    pub async fn is_empty(&self) -> bool {
476        let record_guard = self.inner.record.lock().await;
477
478        // N.B.: Session IDs are `None` if:
479        //
480        // 1. The cookie was not provided or otherwise could not be parsed,
481        // 2. Or the session could not be loaded from the store.
482        let session_id = self.inner.session_id.lock();
483
484        let Some(record) = record_guard.as_ref() else {
485            return session_id.is_none();
486        };
487
488        session_id.is_none() && record.data.is_empty()
489    }
490
491    /// Get the session ID.
492    ///
493    /// # Examples
494    ///
495    /// ```rust
496    /// use std::sync::Arc;
497    ///
498    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
499    ///
500    /// let store = Arc::new(MemoryStore::default());
501    ///
502    /// let session = Session::new(None, store.clone(), None);
503    /// assert!(session.id().is_none());
504    ///
505    /// let id = Some(Id::default());
506    /// let session = Session::new(id, store, None);
507    /// assert_eq!(id, session.id());
508    /// ```
509    pub fn id(&self) -> Option<Id> {
510        *self.inner.session_id.lock()
511    }
512
513    /// Get the session expiry.
514    ///
515    /// # Examples
516    ///
517    /// ```rust
518    /// use std::sync::Arc;
519    ///
520    /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
521    ///
522    /// let store = Arc::new(MemoryStore::default());
523    /// let session = Session::new(None, store, None);
524    ///
525    /// assert_eq!(session.expiry(), None);
526    /// ```
527    pub fn expiry(&self) -> Option<Expiry> {
528        *self.inner.expiry.lock()
529    }
530
531    /// Set `expiry` to the given value.
532    ///
533    /// This may be used within applications directly to alter the session's
534    /// time to live.
535    ///
536    /// # Examples
537    ///
538    /// ```rust
539    /// use std::sync::Arc;
540    ///
541    /// use time::OffsetDateTime;
542    /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
543    ///
544    /// let store = Arc::new(MemoryStore::default());
545    /// let session = Session::new(None, store, None);
546    ///
547    /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
548    /// session.set_expiry(Some(expiry));
549    ///
550    /// assert_eq!(session.expiry(), Some(expiry));
551    /// ```
552    pub fn set_expiry(&self, expiry: Option<Expiry>) {
553        *self.inner.expiry.lock() = expiry;
554        self.inner
555            .is_modified
556            .store(true, atomic::Ordering::Release);
557    }
558
559    /// Get session expiry as `OffsetDateTime`.
560    ///
561    /// # Examples
562    ///
563    /// ```rust
564    /// use std::sync::Arc;
565    ///
566    /// use time::{Duration, OffsetDateTime};
567    /// use tower_sessions_ext::{MemoryStore, Session};
568    ///
569    /// let store = Arc::new(MemoryStore::default());
570    /// let session = Session::new(None, store, None);
571    ///
572    /// // Our default duration is two weeks.
573    /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
574    ///
575    /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
576    /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
577    /// ```
578    pub fn expiry_date(&self) -> OffsetDateTime {
579        let expiry = self.inner.expiry.lock();
580        match *expiry {
581            Some(Expiry::OnInactivity(duration)) => {
582                OffsetDateTime::now_utc().saturating_add(duration)
583            }
584            Some(Expiry::AtDateTime(datetime)) => datetime,
585            Some(Expiry::OnSessionEnd(datetime)) => {
586                OffsetDateTime::now_utc().saturating_add(datetime)
587            }
588            None => OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION),
589        }
590    }
591
592    /// Get session expiry as `Duration`.
593    ///
594    /// # Examples
595    ///
596    /// ```rust
597    /// use std::sync::Arc;
598    ///
599    /// use time::Duration;
600    /// use tower_sessions_ext::{MemoryStore, Session};
601    ///
602    /// let store = Arc::new(MemoryStore::default());
603    /// let session = Session::new(None, store, None);
604    ///
605    /// let expected_duration = Duration::weeks(2);
606    ///
607    /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
608    /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
609    /// ```
610    pub fn expiry_age(&self) -> Duration {
611        std::cmp::max(
612            self.expiry_date() - OffsetDateTime::now_utc(),
613            Duration::ZERO,
614        )
615    }
616
617    /// Returns `true` if the session has been modified during the request.
618    ///
619    /// # Examples
620    ///
621    /// ```rust
622    /// # tokio_test::block_on(async {
623    /// use std::sync::Arc;
624    ///
625    /// use tower_sessions_ext::{MemoryStore, Session};
626    ///
627    /// let store = Arc::new(MemoryStore::default());
628    /// let session = Session::new(None, store, None);
629    ///
630    /// // Not modified initially.
631    /// assert!(!session.is_modified());
632    ///
633    /// // Getting doesn't count as a modification.
634    /// session.get::<usize>("foo").await.unwrap();
635    /// assert!(!session.is_modified());
636    ///
637    /// // Insertions and removals do though.
638    /// session.insert("foo", 42).await.unwrap();
639    /// assert!(session.is_modified());
640    /// # });
641    /// ```
642    pub fn is_modified(&self) -> bool {
643        self.inner.is_modified.load(atomic::Ordering::Acquire)
644    }
645
646    /// Saves the session record to the store.
647    ///
648    /// Note that this method is generally not needed and is reserved for
649    /// situations where the session store must be updated during the
650    /// request.
651    ///
652    /// # Examples
653    ///
654    /// ```rust
655    /// # tokio_test::block_on(async {
656    /// use std::sync::Arc;
657    ///
658    /// use tower_sessions_ext::{MemoryStore, Session};
659    ///
660    /// let store = Arc::new(MemoryStore::default());
661    /// let session = Session::new(None, store.clone(), None);
662    ///
663    /// session.insert("foo", 42).await.unwrap();
664    /// session.save().await.unwrap();
665    ///
666    /// let session = Session::new(session.id(), store, None);
667    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
668    /// # });
669    /// ```
670    ///
671    /// # Errors
672    ///
673    /// - If saving to the store fails, we fail with [`Error::Store`].
674    #[tracing::instrument(skip(self), err, level = "trace")]
675    pub async fn save(&self) -> Result<()> {
676        let mut record_guard = self.get_record().await?;
677        record_guard.expiry_date = self.expiry_date();
678
679        // Session ID is `None` if:
680        //
681        //  1. No valid cookie was found on the request or,
682        //  2. No valid session was found in the store.
683        //
684        // In either case, we must create a new session via the store interface.
685        //
686        // Potential ID collisions must be handled by session store implementers.
687        if self.inner.session_id.lock().is_none() {
688            self.store.create(&mut record_guard).await?;
689            *self.inner.session_id.lock() = Some(record_guard.id);
690        } else {
691            self.store.save(&record_guard).await?;
692        }
693        Ok(())
694    }
695
696    /// Loads the session record from the store.
697    ///
698    /// Note that this method is generally not needed and is reserved for
699    /// situations where the session must be updated during the request.
700    ///
701    /// # Examples
702    ///
703    /// ```rust
704    /// # tokio_test::block_on(async {
705    /// use std::sync::Arc;
706    ///
707    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
708    ///
709    /// let store = Arc::new(MemoryStore::default());
710    /// let id = Some(Id::default());
711    /// let session = Session::new(id, store.clone(), None);
712    ///
713    /// session.insert("foo", 42).await.unwrap();
714    /// session.save().await.unwrap();
715    ///
716    /// let session = Session::new(session.id(), store, None);
717    /// session.load().await.unwrap();
718    ///
719    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
720    /// # });
721    /// ```
722    ///
723    /// # Errors
724    ///
725    /// - If loading from the store fails, we fail with [`Error::Store`].
726    #[tracing::instrument(skip(self), err, level = "trace")]
727    pub async fn load(&self) -> Result<()> {
728        let session_id = *self.inner.session_id.lock();
729        let Some(ref id) = session_id else {
730            tracing::warn!("called load with no session id");
731            return Ok(());
732        };
733        let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
734        let mut record_guard = self.inner.record.lock().await;
735        *record_guard = loaded_record;
736        Ok(())
737    }
738
739    /// Deletes the session from the store.
740    ///
741    /// # Examples
742    ///
743    /// ```rust
744    /// # tokio_test::block_on(async {
745    /// use std::sync::Arc;
746    ///
747    /// use tower_sessions_ext::{MemoryStore, Session, SessionStore, session::Id};
748    ///
749    /// let store = Arc::new(MemoryStore::default());
750    /// let session = Session::new(Some(Id::default()), store.clone(), None);
751    ///
752    /// // Save before deleting.
753    /// session.save().await.unwrap();
754    ///
755    /// // Delete from the store.
756    /// session.delete().await.unwrap();
757    ///
758    /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
759    /// # });
760    /// ```
761    ///
762    /// # Errors
763    ///
764    /// - If deleting from the store fails, we fail with [`Error::Store`].
765    #[tracing::instrument(skip(self), err, level = "trace")]
766    pub async fn delete(&self) -> Result<()> {
767        let session_id = *self.inner.session_id.lock();
768        let Some(ref session_id) = session_id else {
769            tracing::warn!("called delete with no session id");
770            return Ok(());
771        };
772        self.store.delete(session_id).await.map_err(Error::Store)?;
773        Ok(())
774    }
775
776    /// Flushes the session by removing all data contained in the session and
777    /// then deleting it from the store.
778    ///
779    /// # Examples
780    ///
781    /// ```rust
782    /// # tokio_test::block_on(async {
783    /// use std::sync::Arc;
784    ///
785    /// use tower_sessions_ext::{MemoryStore, Session, SessionStore};
786    ///
787    /// let store = Arc::new(MemoryStore::default());
788    /// let session = Session::new(None, store.clone(), None);
789    ///
790    /// session.insert("foo", "bar").await.unwrap();
791    /// session.save().await.unwrap();
792    ///
793    /// let id = session.id().unwrap();
794    ///
795    /// session.flush().await.unwrap();
796    ///
797    /// assert!(session.id().is_none());
798    /// assert!(session.is_empty().await);
799    /// assert!(store.load(&id).await.unwrap().is_none());
800    /// # });
801    /// ```
802    ///
803    /// # Errors
804    ///
805    /// - If deleting from the store fails, we fail with [`Error::Store`].
806    pub async fn flush(&self) -> Result<()> {
807        self.clear().await;
808        self.delete().await?;
809        *self.inner.session_id.lock() = None;
810        Ok(())
811    }
812
813    /// Cycles the session ID while retaining any data that was associated with
814    /// it.
815    ///
816    /// Using this method helps prevent session fixation attacks by ensuring a
817    /// new ID is assigned to the session.
818    ///
819    /// # Examples
820    ///
821    /// ```rust
822    /// # tokio_test::block_on(async {
823    /// use std::sync::Arc;
824    ///
825    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
826    ///
827    /// let store = Arc::new(MemoryStore::default());
828    /// let session = Session::new(None, store.clone(), None);
829    ///
830    /// session.insert("foo", 42).await.unwrap();
831    /// session.save().await.unwrap();
832    /// let id = session.id();
833    ///
834    /// let session = Session::new(session.id(), store.clone(), None);
835    /// session.cycle_id().await.unwrap();
836    ///
837    /// assert!(!session.is_empty().await);
838    /// assert!(session.is_modified());
839    ///
840    /// session.save().await.unwrap();
841    ///
842    /// let session = Session::new(session.id(), store, None);
843    ///
844    /// assert_ne!(id, session.id());
845    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
846    /// # });
847    /// ```
848    ///
849    /// # Errors
850    ///
851    /// - If deleting from the store fails or saving to the store fails, we fail
852    ///   with [`Error::Store`].
853    pub async fn cycle_id(&self) -> Result<()> {
854        let mut record_guard = self.get_record().await?;
855
856        let old_session_id = record_guard.id;
857        record_guard.id = Id::default();
858        *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
859        // `create` method.
860
861        self.store
862            .delete(&old_session_id)
863            .await
864            .map_err(Error::Store)?;
865
866        self.inner
867            .is_modified
868            .store(true, atomic::Ordering::Release);
869
870        Ok(())
871    }
872}
873
874/// ID type for sessions.
875///
876/// Wraps an array of 16 bytes.
877///
878/// # Examples
879///
880/// ```rust
881/// use tower_sessions_ext::session::Id;
882///
883/// Id::default();
884/// ```
885#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
886pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
887// session ID, which is undesirable.
888
889impl Default for Id {
890    fn default() -> Self {
891        use rand::prelude::*;
892
893        Self(rand::rng().random())
894    }
895}
896
897impl Display for Id {
898    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899        let mut encoded = [0; 22];
900        URL_SAFE_NO_PAD
901            .encode_slice(self.0.to_le_bytes(), &mut encoded)
902            .expect("Encoded ID must be exactly 22 bytes");
903        let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
904
905        f.write_str(encoded)
906    }
907}
908
909impl FromStr for Id {
910    type Err = base64::DecodeSliceError;
911
912    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
913        let mut decoded = [0; 16];
914        let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
915        if bytes_decoded != 16 {
916            let err = DecodeError::InvalidLength(bytes_decoded);
917            return Err(base64::DecodeSliceError::DecodeError(err));
918        }
919
920        Ok(Self(i128::from_le_bytes(decoded)))
921    }
922}
923
/// Record type that's appropriate for encoding and decoding sessions to and
/// from session stores.
///
/// This is the value handed to [`SessionStore`] implementations when a
/// session is created, saved, or loaded.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Record {
    // Field order is part of the serialized representation (serde derives
    // serialize fields in declaration order); do not reorder.
    /// The session's ID. Stores receive the record mutably in `create`, so
    /// they may assign a different ID there.
    pub id: Id,
    /// The session's key-value data.
    pub data: Data,
    /// The point in time at which this record expires.
    pub expiry_date: OffsetDateTime,
}
932
933impl Record {
934    fn new(expiry_date: OffsetDateTime) -> Self {
935        Self {
936            id: Id::default(),
937            data: Data::default(),
938            expiry_date,
939        }
940    }
941}
942
943/// Session expiry configuration.
944///
945/// # Examples
946///
947/// ```rust
948/// use time::{Duration, OffsetDateTime};
949/// use tower_sessions_ext::Expiry;
950///
951/// // Will be expired on "session end".
952/// let expiry = Expiry::OnSessionEnd;
953///
954/// // Will be expired in five minutes from last acitve.
955/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
956///
957/// // Will be expired at the given timestamp.
958/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
959/// let expiry = Expiry::AtDateTime(expired_at);
960/// ```
961#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
962pub enum Expiry {
963    /// Expire on [current session end][current-session-end], as defined by the
964    /// browser.
965    ///
966    /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
967    OnSessionEnd(Duration),
968
969    /// Expire on inactivity.
970    ///
971    /// Reading a session is not considered activity for expiration purposes.
972    /// [`Session`] expiration is computed from the last time the session was
973    /// _modified_.
974    OnInactivity(Duration),
975
976    /// Expire at a specific date and time.
977    ///
978    /// This value may be extended manually with
979    /// [`set_expiry`](Session::set_expiry).
980    AtDateTime(OffsetDateTime),
981}
982
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    /// Cycling the ID deletes the old session from the store and keeps the session's data.
    /// It also leaves the session without an ID until the next `save` creates a fresh record.
    #[tokio::test]
    async fn test_cycle_id() {
        let original_id = Id::default();
        let store_assigned_id = Id::default();

        let mut store_mock = MockStore::new();

        // The existing record is loaded exactly once, under its original ID.
        store_mock
            .expect_load()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(move |_| {
                let record = Record {
                    id: original_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                };
                Ok(Some(record))
            });

        // The first explicit save happens while the session still has its original ID.
        store_mock
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));

        // Cycling removes the old record from the store...
        store_mock
            .expect_delete()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(|_| Ok(()));

        // ...and the following save goes through `create`, where the store assigns the new ID.
        store_mock
            .expect_create()
            .times(1)
            .returning(move |record| {
                record.id = store_assigned_id;
                Ok(())
            });

        let store = Arc::new(store_mock);
        let session = Session::new(Some(original_id), store.clone(), None);

        // Put some data in and persist it under the original ID.
        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        session.cycle_id().await.unwrap();

        // Until the next save the session has no ID, but its data survives the cycle.
        let id_after_cycle = session.id();
        assert_ne!(id_after_cycle, Some(original_id));
        assert!(id_after_cycle.is_none());
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Saving creates the new record and adopts the ID assigned by the store.
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(store_assigned_id));
    }
}
1062}