Skip to main content

tower_sessions_ext_core/
session.rs

1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3    collections::HashMap,
4    fmt::{self, Display},
5    hash::Hash,
6    result,
7    str::{self, FromStr},
8    sync::{
9        Arc,
10        atomic::{self, AtomicBool},
11    },
12};
13
14use base64::{DecodeError, Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize, de::DeserializeOwned};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{SessionStore, session_store};
21
22pub const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24type Result<T> = result::Result<T, Error>;
25
26type Data = HashMap<String, Value>;
27
28/// Session errors.
29#[derive(thiserror::Error, Debug)]
30pub enum Error {
31    /// Maps `serde_json` errors.
32    #[error(transparent)]
33    SerdeJson(#[from] serde_json::Error),
34
35    /// Maps `session_store::Error` errors.
36    #[error(transparent)]
37    Store(#[from] session_store::Error),
38}
39
40#[derive(Debug)]
41struct Inner {
42    // This will be `None` when:
43    //
44    // 1. We have not been provided a session cookie or have failed to parse it,
45    // 2. The store has not found the session.
46    //
47    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
48    session_id: parking_lot::Mutex<Option<Id>>,
49
50    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
51    // `None` value indicates we have not tried to access it yet. After access, it will always
52    // contain `Some(Record)`.
53    record: Mutex<Option<Record>>,
54
55    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
56    expiry: parking_lot::Mutex<Option<Expiry>>,
57
58    is_modified: AtomicBool,
59}
60
61/// A session which allows HTTP applications to associate key-value pairs with
62/// visitors.
63#[derive(Clone)]
64pub struct Session {
65    store: Arc<dyn SessionStore>,
66    inner: Arc<Inner>,
67}
68
69impl fmt::Debug for Session {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("Session")
72            .field("store", &self.store)
73            .field("inner", &self.inner)
74            .finish()
75    }
76}
77
78impl Session {
79    /// Creates a new session with the session ID, store, and expiry.
80    ///
81    /// This method is lazy and does not invoke the overhead of talking to the
82    /// backing store.
83    ///
84    /// # Examples
85    ///
86    /// ```rust
87    /// use std::sync::Arc;
88    ///
89    /// use tower_sessions_ext::{MemoryStore, Session};
90    ///
91    /// let store = Arc::new(MemoryStore::default());
92    /// Session::new(None, store, None);
93    /// ```
94    pub fn new(
95        session_id: Option<Id>,
96        store: Arc<impl SessionStore>,
97        expiry: Option<Expiry>,
98    ) -> Self {
99        let inner = Inner {
100            session_id: parking_lot::Mutex::new(session_id),
101            record: Mutex::new(None), // `None` indicates we have not loaded from store.
102            expiry: parking_lot::Mutex::new(expiry),
103            is_modified: AtomicBool::new(false),
104        };
105
106        Self {
107            store,
108            inner: Arc::new(inner),
109        }
110    }
111
112    fn create_record(&self) -> Record {
113        Record::new(self.expiry_date())
114    }
115
116    #[tracing::instrument(skip(self), err, level = "trace")]
117    async fn get_record(&'_ self) -> Result<MappedMutexGuard<'_, Record>> {
118        let mut record_guard = self.inner.record.lock().await;
119
120        // Lazily load the record since `None` here indicates we have no yet loaded it.
121        if record_guard.is_none() {
122            tracing::trace!("record not loaded from store; loading");
123
124            let session_id = *self.inner.session_id.lock();
125            *record_guard = Some(if let Some(session_id) = session_id {
126                match self.store.load(&session_id).await? {
127                    Some(loaded_record) => {
128                        tracing::trace!("record found in store");
129                        loaded_record
130                    }
131
132                    None => {
133                        // A well-behaved user agent should not send session cookies after
134                        // expiration. Even so it's possible for an expired session to be removed
135                        // from the store after a request was initiated. However, such a race should
136                        // be relatively uncommon and as such entering this branch could indicate
137                        // malicious behavior.
138                        tracing::warn!("possibly suspicious activity: record not found in store");
139                        *self.inner.session_id.lock() = None;
140                        self.create_record()
141                    }
142                }
143            } else {
144                tracing::trace!("session id not found");
145                self.create_record()
146            })
147        }
148
149        Ok(MutexGuard::map(record_guard, |opt| {
150            opt.as_mut()
151                .expect("Record should always be `Option::Some` at this point")
152        }))
153    }
154
155    /// Inserts a `impl Serialize` value into the session.
156    ///
157    /// # Examples
158    ///
159    /// ```rust
160    /// # tokio_test::block_on(async {
161    /// use std::sync::Arc;
162    ///
163    /// use tower_sessions_ext::{MemoryStore, Session};
164    ///
165    /// let store = Arc::new(MemoryStore::default());
166    /// let session = Session::new(None, store, None);
167    ///
168    /// session.insert("foo", 42).await.unwrap();
169    ///
170    /// let value = session.get::<usize>("foo").await.unwrap();
171    /// assert_eq!(value, Some(42));
172    /// # });
173    /// ```
174    ///
175    /// # Errors
176    ///
177    /// - This method can fail when [`serde_json::to_value`] fails.
178    /// - If the session has not been hydrated and loading from the store fails,
179    ///   we fail with [`Error::Store`].
180    pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
181        self.insert_value(key, serde_json::to_value(&value)?)
182            .await?;
183        Ok(())
184    }
185
186    /// Inserts a `serde_json::Value` into the session.
187    ///
188    /// If the key was not present in the underlying map, `None` is returned and
189    /// `modified` is set to `true`.
190    ///
191    /// If the underlying map did have the key and its value is the same as the
192    /// provided value, `None` is returned and `modified` is not set.
193    ///
194    /// # Examples
195    ///
196    /// ```rust
197    /// # tokio_test::block_on(async {
198    /// use std::sync::Arc;
199    ///
200    /// use tower_sessions_ext::{MemoryStore, Session};
201    ///
202    /// let store = Arc::new(MemoryStore::default());
203    /// let session = Session::new(None, store, None);
204    ///
205    /// let value = session
206    ///     .insert_value("foo", serde_json::json!(42))
207    ///     .await
208    ///     .unwrap();
209    /// assert!(value.is_none());
210    ///
211    /// let value = session
212    ///     .insert_value("foo", serde_json::json!(42))
213    ///     .await
214    ///     .unwrap();
215    /// assert!(value.is_none());
216    ///
217    /// let value = session
218    ///     .insert_value("foo", serde_json::json!("bar"))
219    ///     .await
220    ///     .unwrap();
221    /// assert_eq!(value, Some(serde_json::json!(42)));
222    /// # });
223    /// ```
224    ///
225    /// # Errors
226    ///
227    /// - If the session has not been hydrated and loading from the store fails,
228    ///   we fail with [`Error::Store`].
229    pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
230        let mut record_guard = self.get_record().await?;
231        Ok(if record_guard.data.get(key) != Some(&value) {
232            self.inner
233                .is_modified
234                .store(true, atomic::Ordering::Release);
235            record_guard.data.insert(key.to_string(), value)
236        } else {
237            None
238        })
239    }
240
241    /// Gets a value from the store.
242    ///
243    /// # Examples
244    ///
245    /// ```rust
246    /// # tokio_test::block_on(async {
247    /// use std::sync::Arc;
248    ///
249    /// use tower_sessions_ext::{MemoryStore, Session};
250    ///
251    /// let store = Arc::new(MemoryStore::default());
252    /// let session = Session::new(None, store, None);
253    ///
254    /// session.insert("foo", 42).await.unwrap();
255    ///
256    /// let value = session.get::<usize>("foo").await.unwrap();
257    /// assert_eq!(value, Some(42));
258    /// # });
259    /// ```
260    ///
261    /// # Errors
262    ///
263    /// - This method can fail when [`serde_json::from_value`] fails.
264    /// - If the session has not been hydrated and loading from the store fails,
265    ///   we fail with [`Error::Store`].
266    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
267        Ok(self
268            .get_value(key)
269            .await?
270            .map(serde_json::from_value)
271            .transpose()?)
272    }
273
274    /// Gets a `serde_json::Value` from the store.
275    ///
276    /// # Examples
277    ///
278    /// ```rust
279    /// # tokio_test::block_on(async {
280    /// use std::sync::Arc;
281    ///
282    /// use tower_sessions_ext::{MemoryStore, Session};
283    ///
284    /// let store = Arc::new(MemoryStore::default());
285    /// let session = Session::new(None, store, None);
286    ///
287    /// session.insert("foo", 42).await.unwrap();
288    ///
289    /// let value = session.get_value("foo").await.unwrap().unwrap();
290    /// assert_eq!(value, serde_json::json!(42));
291    /// # });
292    /// ```
293    ///
294    /// # Errors
295    ///
296    /// - If the session has not been hydrated and loading from the store fails,
297    ///   we fail with [`Error::Store`].
298    pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
299        let record_guard = self.get_record().await?;
300        Ok(record_guard.data.get(key).cloned())
301    }
302
303    /// Removes a value from the store, retuning the value of the key if it was
304    /// present in the underlying map.
305    ///
306    /// # Examples
307    ///
308    /// ```rust
309    /// # tokio_test::block_on(async {
310    /// use std::sync::Arc;
311    ///
312    /// use tower_sessions_ext::{MemoryStore, Session};
313    ///
314    /// let store = Arc::new(MemoryStore::default());
315    /// let session = Session::new(None, store, None);
316    ///
317    /// session.insert("foo", 42).await.unwrap();
318    ///
319    /// let value: Option<usize> = session.remove("foo").await.unwrap();
320    /// assert_eq!(value, Some(42));
321    ///
322    /// let value: Option<usize> = session.get("foo").await.unwrap();
323    /// assert!(value.is_none());
324    /// # });
325    /// ```
326    ///
327    /// # Errors
328    ///
329    /// - This method can fail when [`serde_json::from_value`] fails.
330    /// - If the session has not been hydrated and loading from the store fails,
331    ///   we fail with [`Error::Store`].
332    pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
333        Ok(self
334            .remove_value(key)
335            .await?
336            .map(serde_json::from_value)
337            .transpose()?)
338    }
339
340    /// Removes a `serde_json::Value` from the session.
341    ///
342    /// # Examples
343    ///
344    /// ```rust
345    /// # tokio_test::block_on(async {
346    /// use std::sync::Arc;
347    ///
348    /// use tower_sessions_ext::{MemoryStore, Session};
349    ///
350    /// let store = Arc::new(MemoryStore::default());
351    /// let session = Session::new(None, store, None);
352    ///
353    /// session.insert("foo", 42).await.unwrap();
354    /// let value = session.remove_value("foo").await.unwrap().unwrap();
355    /// assert_eq!(value, serde_json::json!(42));
356    ///
357    /// let value: Option<usize> = session.get("foo").await.unwrap();
358    /// assert!(value.is_none());
359    /// # });
360    /// ```
361    ///
362    /// # Errors
363    ///
364    /// - If the session has not been hydrated and loading from the store fails,
365    ///   we fail with [`Error::Store`].
366    pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
367        let mut record_guard = self.get_record().await?;
368        self.inner
369            .is_modified
370            .store(true, atomic::Ordering::Release);
371        Ok(record_guard.data.remove(key))
372    }
373
374    /// Clears the session of all data but does not delete it from the store.
375    ///
376    /// # Examples
377    ///
378    /// ```rust
379    /// # tokio_test::block_on(async {
380    /// use std::sync::Arc;
381    ///
382    /// use tower_sessions_ext::{MemoryStore, Session};
383    ///
384    /// let store = Arc::new(MemoryStore::default());
385    ///
386    /// let session = Session::new(None, store.clone(), None);
387    /// session.insert("foo", 42).await.unwrap();
388    /// assert!(!session.is_empty().await);
389    ///
390    /// session.save().await.unwrap();
391    ///
392    /// session.clear().await;
393    ///
394    /// // Not empty! (We have an ID still.)
395    /// assert!(!session.is_empty().await);
396    /// // Data is cleared...
397    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
398    ///
399    /// // ...data is cleared before loading from the backend...
400    /// let session = Session::new(session.id(), store.clone(), None);
401    /// session.clear().await;
402    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
403    ///
404    /// let session = Session::new(session.id(), store, None);
405    /// // ...but data is not deleted from the store.
406    /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
407    /// # });
408    /// ```
409    pub async fn clear(&self) {
410        let mut record_guard = self.inner.record.lock().await;
411        if let Some(record) = record_guard.as_mut() {
412            record.data.clear();
413        } else if let Some(session_id) = *self.inner.session_id.lock() {
414            let mut new_record = self.create_record();
415            new_record.id = session_id;
416            *record_guard = Some(new_record);
417        }
418
419        self.inner
420            .is_modified
421            .store(true, atomic::Ordering::Release);
422    }
423
424    /// Returns `true` if there is no session ID and the session is empty.
425    ///
426    /// # Examples
427    ///
428    /// ```rust
429    /// # tokio_test::block_on(async {
430    /// use std::sync::Arc;
431    ///
432    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
433    ///
434    /// let store = Arc::new(MemoryStore::default());
435    ///
436    /// let session = Session::new(None, store.clone(), None);
437    /// // Empty if we have no ID and record is not loaded.
438    /// assert!(session.is_empty().await);
439    ///
440    /// let session = Session::new(Some(Id::default()), store.clone(), None);
441    /// // Not empty if we have an ID but no record. (Record is not loaded here.)
442    /// assert!(!session.is_empty().await);
443    ///
444    /// let session = Session::new(Some(Id::default()), store.clone(), None);
445    /// session.insert("foo", 42).await.unwrap();
446    /// // Not empty after inserting.
447    /// assert!(!session.is_empty().await);
448    /// session.save().await.unwrap();
449    /// // Not empty after saving.
450    /// assert!(!session.is_empty().await);
451    ///
452    /// let session = Session::new(session.id(), store.clone(), None);
453    /// session.load().await.unwrap();
454    /// // Not empty after loading from store...
455    /// assert!(!session.is_empty().await);
456    /// // ...and not empty after accessing the session.
457    /// session.get::<usize>("foo").await.unwrap();
458    /// assert!(!session.is_empty().await);
459    ///
460    /// let session = Session::new(session.id(), store.clone(), None);
461    /// session.delete().await.unwrap();
462    /// // Not empty after deleting from store...
463    /// assert!(!session.is_empty().await);
464    /// session.get::<usize>("foo").await.unwrap();
465    /// // ...but empty after trying to access the deleted session.
466    /// assert!(session.is_empty().await);
467    ///
468    /// let session = Session::new(None, store, None);
469    /// session.insert("foo", 42).await.unwrap();
470    /// session.flush().await.unwrap();
471    /// // Empty after flushing.
472    /// assert!(session.is_empty().await);
473    /// # });
474    /// ```
475    pub async fn is_empty(&self) -> bool {
476        let record_guard = self.inner.record.lock().await;
477
478        // N.B.: Session IDs are `None` if:
479        //
480        // 1. The cookie was not provided or otherwise could not be parsed,
481        // 2. Or the session could not be loaded from the store.
482        let session_id = self.inner.session_id.lock();
483
484        let Some(record) = record_guard.as_ref() else {
485            return session_id.is_none();
486        };
487
488        session_id.is_none() && record.data.is_empty()
489    }
490
491    /// Get the session ID.
492    ///
493    /// # Examples
494    ///
495    /// ```rust
496    /// use std::sync::Arc;
497    ///
498    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
499    ///
500    /// let store = Arc::new(MemoryStore::default());
501    ///
502    /// let session = Session::new(None, store.clone(), None);
503    /// assert!(session.id().is_none());
504    ///
505    /// let id = Some(Id::default());
506    /// let session = Session::new(id, store, None);
507    /// assert_eq!(id, session.id());
508    /// ```
509    pub fn id(&self) -> Option<Id> {
510        *self.inner.session_id.lock()
511    }
512
513    /// Get the session expiry.
514    ///
515    /// # Examples
516    ///
517    /// ```rust
518    /// use std::sync::Arc;
519    ///
520    /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
521    ///
522    /// let store = Arc::new(MemoryStore::default());
523    /// let session = Session::new(None, store, None);
524    ///
525    /// assert_eq!(session.expiry(), None);
526    /// ```
527    pub fn expiry(&self) -> Option<Expiry> {
528        *self.inner.expiry.lock()
529    }
530
531    /// Set `expiry` to the given value.
532    ///
533    /// This may be used within applications directly to alter the session's
534    /// time to live.
535    ///
536    /// # Examples
537    ///
538    /// ```rust
539    /// use std::sync::Arc;
540    ///
541    /// use time::OffsetDateTime;
542    /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
543    ///
544    /// let store = Arc::new(MemoryStore::default());
545    /// let session = Session::new(None, store, None);
546    ///
547    /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
548    /// session.set_expiry(Some(expiry));
549    ///
550    /// assert_eq!(session.expiry(), Some(expiry));
551    /// ```
552    pub fn set_expiry(&self, expiry: Option<Expiry>) {
553        *self.inner.expiry.lock() = expiry;
554        self.inner
555            .is_modified
556            .store(true, atomic::Ordering::Release);
557    }
558
559    /// Get session expiry as `OffsetDateTime`.
560    ///
561    /// # Examples
562    ///
563    /// ```rust
564    /// use std::sync::Arc;
565    ///
566    /// use time::{Duration, OffsetDateTime};
567    /// use tower_sessions_ext::{MemoryStore, Session};
568    ///
569    /// let store = Arc::new(MemoryStore::default());
570    /// let session = Session::new(None, store, None);
571    ///
572    /// // Our default duration is two weeks.
573    /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
574    ///
575    /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
576    /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
577    /// ```
578    pub fn expiry_date(&self) -> OffsetDateTime {
579        let expiry = self.inner.expiry.lock();
580        match *expiry {
581            Some(Expiry::OnInactivity(duration)) => {
582                OffsetDateTime::now_utc().saturating_add(duration)
583            }
584            Some(Expiry::AtDateTime(datetime)) => datetime,
585            Some(Expiry::OnSessionEnd(datetime)) => {
586                OffsetDateTime::now_utc().saturating_add(datetime)
587            },
588            None => 
589                OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION)
590        }
591    }
592
593    /// Get session expiry as `Duration`.
594    ///
595    /// # Examples
596    ///
597    /// ```rust
598    /// use std::sync::Arc;
599    ///
600    /// use time::Duration;
601    /// use tower_sessions_ext::{MemoryStore, Session};
602    ///
603    /// let store = Arc::new(MemoryStore::default());
604    /// let session = Session::new(None, store, None);
605    ///
606    /// let expected_duration = Duration::weeks(2);
607    ///
608    /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
609    /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
610    /// ```
611    pub fn expiry_age(&self) -> Duration {
612        std::cmp::max(
613            self.expiry_date() - OffsetDateTime::now_utc(),
614            Duration::ZERO,
615        )
616    }
617
618    /// Returns `true` if the session has been modified during the request.
619    ///
620    /// # Examples
621    ///
622    /// ```rust
623    /// # tokio_test::block_on(async {
624    /// use std::sync::Arc;
625    ///
626    /// use tower_sessions_ext::{MemoryStore, Session};
627    ///
628    /// let store = Arc::new(MemoryStore::default());
629    /// let session = Session::new(None, store, None);
630    ///
631    /// // Not modified initially.
632    /// assert!(!session.is_modified());
633    ///
634    /// // Getting doesn't count as a modification.
635    /// session.get::<usize>("foo").await.unwrap();
636    /// assert!(!session.is_modified());
637    ///
638    /// // Insertions and removals do though.
639    /// session.insert("foo", 42).await.unwrap();
640    /// assert!(session.is_modified());
641    /// # });
642    /// ```
643    pub fn is_modified(&self) -> bool {
644        self.inner.is_modified.load(atomic::Ordering::Acquire)
645    }
646
647    /// Saves the session record to the store.
648    ///
649    /// Note that this method is generally not needed and is reserved for
650    /// situations where the session store must be updated during the
651    /// request.
652    ///
653    /// # Examples
654    ///
655    /// ```rust
656    /// # tokio_test::block_on(async {
657    /// use std::sync::Arc;
658    ///
659    /// use tower_sessions_ext::{MemoryStore, Session};
660    ///
661    /// let store = Arc::new(MemoryStore::default());
662    /// let session = Session::new(None, store.clone(), None);
663    ///
664    /// session.insert("foo", 42).await.unwrap();
665    /// session.save().await.unwrap();
666    ///
667    /// let session = Session::new(session.id(), store, None);
668    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
669    /// # });
670    /// ```
671    ///
672    /// # Errors
673    ///
674    /// - If saving to the store fails, we fail with [`Error::Store`].
675    #[tracing::instrument(skip(self), err, level = "trace")]
676    pub async fn save(&self) -> Result<()> {
677        let mut record_guard = self.get_record().await?;
678        record_guard.expiry_date = self.expiry_date();
679
680        // Session ID is `None` if:
681        //
682        //  1. No valid cookie was found on the request or,
683        //  2. No valid session was found in the store.
684        //
685        // In either case, we must create a new session via the store interface.
686        //
687        // Potential ID collisions must be handled by session store implementers.
688        if self.inner.session_id.lock().is_none() {
689            self.store.create(&mut record_guard).await?;
690            *self.inner.session_id.lock() = Some(record_guard.id);
691        } else {
692            self.store.save(&record_guard).await?;
693        }
694        Ok(())
695    }
696
697    /// Loads the session record from the store.
698    ///
699    /// Note that this method is generally not needed and is reserved for
700    /// situations where the session must be updated during the request.
701    ///
702    /// # Examples
703    ///
704    /// ```rust
705    /// # tokio_test::block_on(async {
706    /// use std::sync::Arc;
707    ///
708    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
709    ///
710    /// let store = Arc::new(MemoryStore::default());
711    /// let id = Some(Id::default());
712    /// let session = Session::new(id, store.clone(), None);
713    ///
714    /// session.insert("foo", 42).await.unwrap();
715    /// session.save().await.unwrap();
716    ///
717    /// let session = Session::new(session.id(), store, None);
718    /// session.load().await.unwrap();
719    ///
720    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
721    /// # });
722    /// ```
723    ///
724    /// # Errors
725    ///
726    /// - If loading from the store fails, we fail with [`Error::Store`].
727    #[tracing::instrument(skip(self), err, level = "trace")]
728    pub async fn load(&self) -> Result<()> {
729        let session_id = *self.inner.session_id.lock();
730        let Some(ref id) = session_id else {
731            tracing::warn!("called load with no session id");
732            return Ok(());
733        };
734        let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
735        let mut record_guard = self.inner.record.lock().await;
736        *record_guard = loaded_record;
737        Ok(())
738    }
739
740    /// Deletes the session from the store.
741    ///
742    /// # Examples
743    ///
744    /// ```rust
745    /// # tokio_test::block_on(async {
746    /// use std::sync::Arc;
747    ///
748    /// use tower_sessions_ext::{MemoryStore, Session, SessionStore, session::Id};
749    ///
750    /// let store = Arc::new(MemoryStore::default());
751    /// let session = Session::new(Some(Id::default()), store.clone(), None);
752    ///
753    /// // Save before deleting.
754    /// session.save().await.unwrap();
755    ///
756    /// // Delete from the store.
757    /// session.delete().await.unwrap();
758    ///
759    /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
760    /// # });
761    /// ```
762    ///
763    /// # Errors
764    ///
765    /// - If deleting from the store fails, we fail with [`Error::Store`].
766    #[tracing::instrument(skip(self), err, level = "trace")]
767    pub async fn delete(&self) -> Result<()> {
768        let session_id = *self.inner.session_id.lock();
769        let Some(ref session_id) = session_id else {
770            tracing::warn!("called delete with no session id");
771            return Ok(());
772        };
773        self.store.delete(session_id).await.map_err(Error::Store)?;
774        Ok(())
775    }
776
777    /// Flushes the session by removing all data contained in the session and
778    /// then deleting it from the store.
779    ///
780    /// # Examples
781    ///
782    /// ```rust
783    /// # tokio_test::block_on(async {
784    /// use std::sync::Arc;
785    ///
786    /// use tower_sessions_ext::{MemoryStore, Session, SessionStore};
787    ///
788    /// let store = Arc::new(MemoryStore::default());
789    /// let session = Session::new(None, store.clone(), None);
790    ///
791    /// session.insert("foo", "bar").await.unwrap();
792    /// session.save().await.unwrap();
793    ///
794    /// let id = session.id().unwrap();
795    ///
796    /// session.flush().await.unwrap();
797    ///
798    /// assert!(session.id().is_none());
799    /// assert!(session.is_empty().await);
800    /// assert!(store.load(&id).await.unwrap().is_none());
801    /// # });
802    /// ```
803    ///
804    /// # Errors
805    ///
806    /// - If deleting from the store fails, we fail with [`Error::Store`].
807    pub async fn flush(&self) -> Result<()> {
808        self.clear().await;
809        self.delete().await?;
810        *self.inner.session_id.lock() = None;
811        Ok(())
812    }
813
814    /// Cycles the session ID while retaining any data that was associated with
815    /// it.
816    ///
817    /// Using this method helps prevent session fixation attacks by ensuring a
818    /// new ID is assigned to the session.
819    ///
820    /// # Examples
821    ///
822    /// ```rust
823    /// # tokio_test::block_on(async {
824    /// use std::sync::Arc;
825    ///
826    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
827    ///
828    /// let store = Arc::new(MemoryStore::default());
829    /// let session = Session::new(None, store.clone(), None);
830    ///
831    /// session.insert("foo", 42).await.unwrap();
832    /// session.save().await.unwrap();
833    /// let id = session.id();
834    ///
835    /// let session = Session::new(session.id(), store.clone(), None);
836    /// session.cycle_id().await.unwrap();
837    ///
838    /// assert!(!session.is_empty().await);
839    /// assert!(session.is_modified());
840    ///
841    /// session.save().await.unwrap();
842    ///
843    /// let session = Session::new(session.id(), store, None);
844    ///
845    /// assert_ne!(id, session.id());
846    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
847    /// # });
848    /// ```
849    ///
850    /// # Errors
851    ///
852    /// - If deleting from the store fails or saving to the store fails, we fail
853    ///   with [`Error::Store`].
854    pub async fn cycle_id(&self) -> Result<()> {
855        let mut record_guard = self.get_record().await?;
856
857        let old_session_id = record_guard.id;
858        record_guard.id = Id::default();
859        *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
860        // `create` method.
861
862        self.store
863            .delete(&old_session_id)
864            .await
865            .map_err(Error::Store)?;
866
867        self.inner
868            .is_modified
869            .store(true, atomic::Ordering::Release);
870
871        Ok(())
872    }
873}
874
875/// ID type for sessions.
876///
877/// Wraps an array of 16 bytes.
878///
879/// # Examples
880///
881/// ```rust
882/// use tower_sessions_ext::session::Id;
883///
884/// Id::default();
885/// ```
886#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
887pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
888// session ID, which is undesirable.
889
890impl Default for Id {
891    fn default() -> Self {
892        use rand::prelude::*;
893
894        Self(rand::rng().random())
895    }
896}
897
898impl Display for Id {
899    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
900        let mut encoded = [0; 22];
901        URL_SAFE_NO_PAD
902            .encode_slice(self.0.to_le_bytes(), &mut encoded)
903            .expect("Encoded ID must be exactly 22 bytes");
904        let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
905
906        f.write_str(encoded)
907    }
908}
909
910impl FromStr for Id {
911    type Err = base64::DecodeSliceError;
912
913    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
914        let mut decoded = [0; 16];
915        let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
916        if bytes_decoded != 16 {
917            let err = DecodeError::InvalidLength(bytes_decoded);
918            return Err(base64::DecodeSliceError::DecodeError(err));
919        }
920
921        Ok(Self(i128::from_le_bytes(decoded)))
922    }
923}
924
925/// Record type that's appropriate for encoding and decoding sessions to and
926/// from session stores.
927#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
928pub struct Record {
929    pub id: Id,
930    pub data: Data,
931    pub expiry_date: OffsetDateTime,
932}
933
934impl Record {
935    fn new(expiry_date: OffsetDateTime) -> Self {
936        Self {
937            id: Id::default(),
938            data: Data::default(),
939            expiry_date,
940        }
941    }
942}
943
944/// Session expiry configuration.
945///
946/// # Examples
947///
948/// ```rust
949/// use time::{Duration, OffsetDateTime};
950/// use tower_sessions_ext::Expiry;
951///
952/// // Will be expired on "session end".
953/// let expiry = Expiry::OnSessionEnd;
954///
955/// // Will be expired in five minutes from last acitve.
956/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
957///
958/// // Will be expired at the given timestamp.
959/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
960/// let expiry = Expiry::AtDateTime(expired_at);
961/// ```
962#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
963pub enum Expiry {
964    /// Expire on [current session end][current-session-end], as defined by the
965    /// browser.
966    ///
967    /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
968    OnSessionEnd(Duration),
969
970    /// Expire on inactivity.
971    ///
972    /// Reading a session is not considered activity for expiration purposes.
973    /// [`Session`] expiration is computed from the last time the session was
974    /// _modified_.
975    OnInactivity(Duration),
976
977    /// Expire at a specific date and time.
978    ///
979    /// This value may be extended manually with
980    /// [`set_expiry`](Session::set_expiry).
981    AtDateTime(OffsetDateTime),
982}
983
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    /// Cycling the ID must delete the old record and keep the in-memory data.
    /// It must leave the session without an ID until the next save obtains a
    /// new one through the store's `create`.
    #[tokio::test]
    async fn test_cycle_id() {
        let original_id = Id::default();
        let replacement_id = Id::default();

        let mut store_mock = MockStore::new();

        // Hydrating the session looks up the original ID exactly once.
        store_mock
            .expect_load()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(move |_| {
                let record = Record {
                    id: original_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                };
                Ok(Some(record))
            });

        // The first save, made while the original ID is still set, goes
        // through the store's `save`.
        store_mock
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));

        // Cycling removes the record stored under the original ID.
        store_mock
            .expect_delete()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(|_| Ok(()));

        // The save after cycling goes through `create`, which hands out the new ID.
        store_mock
            .expect_create()
            .times(1)
            .returning(move |record| {
                record.id = replacement_id;
                Ok(())
            });

        let shared_store = Arc::new(store_mock);
        let session = Session::new(Some(original_id), shared_store.clone(), None);

        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        session.cycle_id().await.unwrap();

        // No ID is assigned between cycling and the next save, but the data survives.
        assert_eq!(session.id(), None);
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Saving assigns the ID produced by the store's `create`.
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(replacement_id));
    }
}