Skip to main content

tower_sessions_ext_core/
session.rs

1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3    collections::HashMap,
4    fmt::{self, Display},
5    hash::Hash,
6    result,
7    str::{self, FromStr},
8    sync::{
9        Arc,
10        atomic::{self, AtomicBool},
11    },
12};
13
14use base64::{DecodeError, Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize, de::DeserializeOwned};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{SessionStore, session_store};
21
22pub const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24/// Callback invoked when a session is discovered to have expired (store returned `None` for a
25/// previously valid session id). Useful for cleanup (e.g. revoking tokens, logging out elsewhere).
26pub type OnExpireCallback = Arc<dyn Fn(Id) + Send + Sync>;
27
28type Result<T> = result::Result<T, Error>;
29
30type Data = HashMap<String, Value>;
31
32/// Session errors.
33#[derive(thiserror::Error, Debug)]
34pub enum Error {
35    /// Maps `serde_json` errors.
36    #[error(transparent)]
37    SerdeJson(#[from] serde_json::Error),
38
39    /// Maps `session_store::Error` errors.
40    #[error(transparent)]
41    Store(#[from] session_store::Error),
42}
43
44#[derive(Debug)]
45struct Inner {
46    // This will be `None` when:
47    //
48    // 1. We have not been provided a session cookie or have failed to parse it,
49    // 2. The store has not found the session.
50    //
51    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
52    session_id: parking_lot::Mutex<Option<Id>>,
53
54    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
55    // `None` value indicates we have not tried to access it yet. After access, it will always
56    // contain `Some(Record)`.
57    record: Mutex<Option<Record>>,
58
59    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
60    expiry: parking_lot::Mutex<Option<Expiry>>,
61
62    is_modified: AtomicBool,
63}
64
65/// A session which allows HTTP applications to associate key-value pairs with
66/// visitors.
67#[derive(Clone)]
68pub struct Session {
69    store: Arc<dyn SessionStore>,
70    inner: Arc<Inner>,
71    on_expire: Option<OnExpireCallback>,
72}
73
74impl fmt::Debug for Session {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_struct("Session")
77            .field("store", &self.store)
78            .field("inner", &self.inner)
79            .field("on_expire", &self.on_expire.as_ref().map(|_| "Some(_)"))
80            .finish()
81    }
82}
83
84impl Session {
85    /// Creates a new session with the session ID, store, and expiry.
86    ///
87    /// This method is lazy and does not invoke the overhead of talking to the
88    /// backing store.
89    ///
90    /// The optional `on_expire` callback is invoked when the store returns `None`
91    /// for a session id (e.g. session expired or was deleted).
92    ///
93    /// # Examples
94    ///
95    /// ```rust
96    /// use std::sync::Arc;
97    ///
98    /// use tower_sessions_ext::{MemoryStore, Session};
99    ///
100    /// let store = Arc::new(MemoryStore::default());
101    /// Session::new(None, store, None, None);
102    /// ```
103    pub fn new(
104        session_id: Option<Id>,
105        store: Arc<impl SessionStore>,
106        expiry: Option<Expiry>,
107        on_expire: Option<OnExpireCallback>,
108    ) -> Self {
109        let inner = Inner {
110            session_id: parking_lot::Mutex::new(session_id),
111            record: Mutex::new(None), // `None` indicates we have not loaded from store.
112            expiry: parking_lot::Mutex::new(expiry),
113            is_modified: AtomicBool::new(false),
114        };
115
116        Self {
117            store,
118            inner: Arc::new(inner),
119            on_expire,
120        }
121    }
122
123    fn create_record(&self) -> Record {
124        Record::new(self.expiry_date())
125    }
126
127    #[tracing::instrument(skip(self), err, level = "trace")]
128    async fn get_record(&'_ self) -> Result<MappedMutexGuard<'_, Record>> {
129        let mut record_guard = self.inner.record.lock().await;
130
131        // Lazily load the record since `None` here indicates we have no yet loaded it.
132        if record_guard.is_none() {
133            tracing::trace!("record not loaded from store; loading");
134
135            let session_id = *self.inner.session_id.lock();
136            *record_guard = Some(if let Some(session_id) = session_id {
137                match self.store.load(&session_id).await? {
138                    Some(loaded_record) => {
139                        tracing::trace!("record found in store");
140                        loaded_record
141                    }
142
143                    None => {
144                        // A well-behaved user agent should not send session cookies after
145                        // expiration. Even so it's possible for an expired session to be removed
146                        // from the store after a request was initiated. However, such a race should
147                        // be relatively uncommon and as such entering this branch could indicate
148                        // malicious behavior.
149                        tracing::warn!("possibly suspicious activity: record not found in store");
150                        if let Some(ref on_expire) = self.on_expire {
151                            on_expire(session_id);
152                        }
153                        *self.inner.session_id.lock() = None;
154                        self.create_record()
155                    }
156                }
157            } else {
158                tracing::trace!("session id not found");
159                self.create_record()
160            })
161        }
162
163        Ok(MutexGuard::map(record_guard, |opt| {
164            opt.as_mut()
165                .expect("Record should always be `Option::Some` at this point")
166        }))
167    }
168
169    /// Inserts a `impl Serialize` value into the session.
170    ///
171    /// # Examples
172    ///
173    /// ```rust
174    /// # tokio_test::block_on(async {
175    /// use std::sync::Arc;
176    ///
177    /// use tower_sessions_ext::{MemoryStore, Session};
178    ///
179    /// let store = Arc::new(MemoryStore::default());
180    /// let session = Session::new(None, store, None, None);
181    ///
182    /// session.insert("foo", 42).await.unwrap();
183    ///
184    /// let value = session.get::<usize>("foo").await.unwrap();
185    /// assert_eq!(value, Some(42));
186    /// # });
187    /// ```
188    ///
189    /// # Errors
190    ///
191    /// - This method can fail when [`serde_json::to_value`] fails.
192    /// - If the session has not been hydrated and loading from the store fails,
193    ///   we fail with [`Error::Store`].
194    pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
195        self.insert_value(key, serde_json::to_value(&value)?)
196            .await?;
197        Ok(())
198    }
199
200    /// Inserts a `serde_json::Value` into the session.
201    ///
202    /// If the key was not present in the underlying map, `None` is returned and
203    /// `modified` is set to `true`.
204    ///
205    /// If the underlying map did have the key and its value is the same as the
206    /// provided value, `None` is returned and `modified` is not set.
207    ///
208    /// # Examples
209    ///
210    /// ```rust
211    /// # tokio_test::block_on(async {
212    /// use std::sync::Arc;
213    ///
214    /// use tower_sessions_ext::{MemoryStore, Session};
215    ///
216    /// let store = Arc::new(MemoryStore::default());
217    /// let session = Session::new(None, store, None, None);
218    ///
219    /// let value = session
220    ///     .insert_value("foo", serde_json::json!(42))
221    ///     .await
222    ///     .unwrap();
223    /// assert!(value.is_none());
224    ///
225    /// let value = session
226    ///     .insert_value("foo", serde_json::json!(42))
227    ///     .await
228    ///     .unwrap();
229    /// assert!(value.is_none());
230    ///
231    /// let value = session
232    ///     .insert_value("foo", serde_json::json!("bar"))
233    ///     .await
234    ///     .unwrap();
235    /// assert_eq!(value, Some(serde_json::json!(42)));
236    /// # });
237    /// ```
238    ///
239    /// # Errors
240    ///
241    /// - If the session has not been hydrated and loading from the store fails,
242    ///   we fail with [`Error::Store`].
243    pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
244        let mut record_guard = self.get_record().await?;
245        Ok(if record_guard.data.get(key) != Some(&value) {
246            self.inner
247                .is_modified
248                .store(true, atomic::Ordering::Release);
249            record_guard.data.insert(key.to_string(), value)
250        } else {
251            None
252        })
253    }
254
255    /// Gets a value from the store.
256    ///
257    /// # Examples
258    ///
259    /// ```rust
260    /// # tokio_test::block_on(async {
261    /// use std::sync::Arc;
262    ///
263    /// use tower_sessions_ext::{MemoryStore, Session};
264    ///
265    /// let store = Arc::new(MemoryStore::default());
266    /// let session = Session::new(None, store, None, None);
267    ///
268    /// session.insert("foo", 42).await.unwrap();
269    ///
270    /// let value = session.get::<usize>("foo").await.unwrap();
271    /// assert_eq!(value, Some(42));
272    /// # });
273    /// ```
274    ///
275    /// # Errors
276    ///
277    /// - This method can fail when [`serde_json::from_value`] fails.
278    /// - If the session has not been hydrated and loading from the store fails,
279    ///   we fail with [`Error::Store`].
280    pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
281        Ok(self
282            .get_value(key)
283            .await?
284            .map(serde_json::from_value)
285            .transpose()?)
286    }
287
288    /// Gets a `serde_json::Value` from the store.
289    ///
290    /// # Examples
291    ///
292    /// ```rust
293    /// # tokio_test::block_on(async {
294    /// use std::sync::Arc;
295    ///
296    /// use tower_sessions_ext::{MemoryStore, Session};
297    ///
298    /// let store = Arc::new(MemoryStore::default());
299    /// let session = Session::new(None, store, None, None);
300    ///
301    /// session.insert("foo", 42).await.unwrap();
302    ///
303    /// let value = session.get_value("foo").await.unwrap().unwrap();
304    /// assert_eq!(value, serde_json::json!(42));
305    /// # });
306    /// ```
307    ///
308    /// # Errors
309    ///
310    /// - If the session has not been hydrated and loading from the store fails,
311    ///   we fail with [`Error::Store`].
312    pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
313        let record_guard = self.get_record().await?;
314        Ok(record_guard.data.get(key).cloned())
315    }
316
317    /// Removes a value from the store, retuning the value of the key if it was
318    /// present in the underlying map.
319    ///
320    /// # Examples
321    ///
322    /// ```rust
323    /// # tokio_test::block_on(async {
324    /// use std::sync::Arc;
325    ///
326    /// use tower_sessions_ext::{MemoryStore, Session};
327    ///
328    /// let store = Arc::new(MemoryStore::default());
329    /// let session = Session::new(None, store, None, None);
330    ///
331    /// session.insert("foo", 42).await.unwrap();
332    ///
333    /// let value: Option<usize> = session.remove("foo").await.unwrap();
334    /// assert_eq!(value, Some(42));
335    ///
336    /// let value: Option<usize> = session.get("foo").await.unwrap();
337    /// assert!(value.is_none());
338    /// # });
339    /// ```
340    ///
341    /// # Errors
342    ///
343    /// - This method can fail when [`serde_json::from_value`] fails.
344    /// - If the session has not been hydrated and loading from the store fails,
345    ///   we fail with [`Error::Store`].
346    pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
347        Ok(self
348            .remove_value(key)
349            .await?
350            .map(serde_json::from_value)
351            .transpose()?)
352    }
353
354    /// Removes a `serde_json::Value` from the session.
355    ///
356    /// # Examples
357    ///
358    /// ```rust
359    /// # tokio_test::block_on(async {
360    /// use std::sync::Arc;
361    ///
362    /// use tower_sessions_ext::{MemoryStore, Session};
363    ///
364    /// let store = Arc::new(MemoryStore::default());
365    /// let session = Session::new(None, store, None, None);
366    ///
367    /// session.insert("foo", 42).await.unwrap();
368    /// let value = session.remove_value("foo").await.unwrap().unwrap();
369    /// assert_eq!(value, serde_json::json!(42));
370    ///
371    /// let value: Option<usize> = session.get("foo").await.unwrap();
372    /// assert!(value.is_none());
373    /// # });
374    /// ```
375    ///
376    /// # Errors
377    ///
378    /// - If the session has not been hydrated and loading from the store fails,
379    ///   we fail with [`Error::Store`].
380    pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
381        let mut record_guard = self.get_record().await?;
382        self.inner
383            .is_modified
384            .store(true, atomic::Ordering::Release);
385        Ok(record_guard.data.remove(key))
386    }
387
388    /// Clears the session of all data but does not delete it from the store.
389    ///
390    /// # Examples
391    ///
392    /// ```rust
393    /// # tokio_test::block_on(async {
394    /// use std::sync::Arc;
395    ///
396    /// use tower_sessions_ext::{MemoryStore, Session};
397    ///
398    /// let store = Arc::new(MemoryStore::default());
399    ///
400    /// let session = Session::new(None, store.clone(), None, None);
401    /// session.insert("foo", 42).await.unwrap();
402    /// assert!(!session.is_empty().await);
403    ///
404    /// session.save().await.unwrap();
405    ///
406    /// session.clear().await;
407    ///
408    /// // Not empty! (We have an ID still.)
409    /// assert!(!session.is_empty().await);
410    /// // Data is cleared...
411    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
412    ///
413    /// // ...data is cleared before loading from the backend...
414    /// let session = Session::new(session.id(), store.clone(), None, None);
415    /// session.clear().await;
416    /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
417    ///
418    /// let session = Session::new(session.id(), store, None, None);
419    /// // ...but data is not deleted from the store.
420    /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
421    /// # });
422    /// ```
423    pub async fn clear(&self) {
424        let mut record_guard = self.inner.record.lock().await;
425        if let Some(record) = record_guard.as_mut() {
426            record.data.clear();
427        } else if let Some(session_id) = *self.inner.session_id.lock() {
428            let mut new_record = self.create_record();
429            new_record.id = session_id;
430            *record_guard = Some(new_record);
431        }
432
433        self.inner
434            .is_modified
435            .store(true, atomic::Ordering::Release);
436    }
437
438    /// Returns `true` if there is no session ID and the session is empty.
439    ///
440    /// # Examples
441    ///
442    /// ```rust
443    /// # tokio_test::block_on(async {
444    /// use std::sync::Arc;
445    ///
446    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
447    ///
448    /// let store = Arc::new(MemoryStore::default());
449    ///
450    /// let session = Session::new(None, store.clone(), None, None);
451    /// // Empty if we have no ID and record is not loaded.
452    /// assert!(session.is_empty().await);
453    ///
454    /// let session = Session::new(Some(Id::default()), store.clone(), None, None);
455    /// // Not empty if we have an ID but no record. (Record is not loaded here.)
456    /// assert!(!session.is_empty().await);
457    ///
458    /// let session = Session::new(Some(Id::default()), store.clone(), None, None);
459    /// session.insert("foo", 42).await.unwrap();
460    /// // Not empty after inserting.
461    /// assert!(!session.is_empty().await);
462    /// session.save().await.unwrap();
463    /// // Not empty after saving.
464    /// assert!(!session.is_empty().await);
465    ///
466    /// let session = Session::new(session.id(), store.clone(), None, None);
467    /// session.load().await.unwrap();
468    /// // Not empty after loading from store...
469    /// assert!(!session.is_empty().await);
470    /// // ...and not empty after accessing the session.
471    /// session.get::<usize>("foo").await.unwrap();
472    /// assert!(!session.is_empty().await);
473    ///
474    /// let session = Session::new(session.id(), store.clone(), None, None);
475    /// session.delete().await.unwrap();
476    /// // Not empty after deleting from store...
477    /// assert!(!session.is_empty().await);
478    /// session.get::<usize>("foo").await.unwrap();
479    /// // ...but empty after trying to access the deleted session.
480    /// assert!(session.is_empty().await);
481    ///
482    /// let session = Session::new(None, store, None, None);
483    /// session.insert("foo", 42).await.unwrap();
484    /// session.flush().await.unwrap();
485    /// // Empty after flushing.
486    /// assert!(session.is_empty().await);
487    /// # });
488    /// ```
489    pub async fn is_empty(&self) -> bool {
490        let record_guard = self.inner.record.lock().await;
491
492        // N.B.: Session IDs are `None` if:
493        //
494        // 1. The cookie was not provided or otherwise could not be parsed,
495        // 2. Or the session could not be loaded from the store.
496        let session_id = self.inner.session_id.lock();
497
498        let Some(record) = record_guard.as_ref() else {
499            return session_id.is_none();
500        };
501
502        session_id.is_none() && record.data.is_empty()
503    }
504
505    /// Get the session ID.
506    ///
507    /// # Examples
508    ///
509    /// ```rust
510    /// use std::sync::Arc;
511    ///
512    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
513    ///
514    /// let store = Arc::new(MemoryStore::default());
515    ///
516    /// let session = Session::new(None, store.clone(), None, None);
517    /// assert!(session.id().is_none());
518    ///
519    /// let id = Some(Id::default());
520    /// let session = Session::new(id, store, None, None);
521    /// assert_eq!(id, session.id());
522    /// ```
523    pub fn id(&self) -> Option<Id> {
524        *self.inner.session_id.lock()
525    }
526
527    /// Get the session expiry.
528    ///
529    /// # Examples
530    ///
531    /// ```rust
532    /// use std::sync::Arc;
533    ///
534    /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
535    ///
536    /// let store = Arc::new(MemoryStore::default());
537    /// let session = Session::new(None, store, None, None);
538    ///
539    /// assert_eq!(session.expiry(), None);
540    /// ```
541    pub fn expiry(&self) -> Option<Expiry> {
542        *self.inner.expiry.lock()
543    }
544
545    /// Set `expiry` to the given value.
546    ///
547    /// This may be used within applications directly to alter the session's
548    /// time to live.
549    ///
550    /// # Examples
551    ///
552    /// ```rust
553    /// use std::sync::Arc;
554    ///
555    /// use time::OffsetDateTime;
556    /// use tower_sessions_ext::{MemoryStore, Session, session::Expiry};
557    ///
558    /// let store = Arc::new(MemoryStore::default());
559    /// let session = Session::new(None, store, None, None);
560    ///
561    /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
562    /// session.set_expiry(Some(expiry));
563    ///
564    /// assert_eq!(session.expiry(), Some(expiry));
565    /// ```
566    pub fn set_expiry(&self, expiry: Option<Expiry>) {
567        *self.inner.expiry.lock() = expiry;
568        self.inner
569            .is_modified
570            .store(true, atomic::Ordering::Release);
571    }
572
573    /// Get session expiry as `OffsetDateTime`.
574    ///
575    /// # Examples
576    ///
577    /// ```rust
578    /// use std::sync::Arc;
579    ///
580    /// use time::{Duration, OffsetDateTime};
581    /// use tower_sessions_ext::{MemoryStore, Session};
582    ///
583    /// let store = Arc::new(MemoryStore::default());
584    /// let session = Session::new(None, store, None, None);
585    ///
586    /// // Our default duration is two weeks.
587    /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
588    ///
589    /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
590    /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
591    /// ```
592    pub fn expiry_date(&self) -> OffsetDateTime {
593        let expiry = self.inner.expiry.lock();
594        match *expiry {
595            Some(Expiry::OnInactivity(duration)) => {
596                OffsetDateTime::now_utc().saturating_add(duration)
597            }
598            Some(Expiry::AtDateTime(datetime)) => datetime,
599            Some(Expiry::OnSessionEnd(datetime)) => {
600                OffsetDateTime::now_utc().saturating_add(datetime)
601            },
602            None => 
603                OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION)
604        }
605    }
606
607    /// Get session expiry as `Duration`.
608    ///
609    /// # Examples
610    ///
611    /// ```rust
612    /// use std::sync::Arc;
613    ///
614    /// use time::Duration;
615    /// use tower_sessions_ext::{MemoryStore, Session};
616    ///
617    /// let store = Arc::new(MemoryStore::default());
618    /// let session = Session::new(None, store, None, None);
619    ///
620    /// let expected_duration = Duration::weeks(2);
621    ///
622    /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
623    /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
624    /// ```
625    pub fn expiry_age(&self) -> Duration {
626        std::cmp::max(
627            self.expiry_date() - OffsetDateTime::now_utc(),
628            Duration::ZERO,
629        )
630    }
631
632    /// Returns `true` if the session has been modified during the request.
633    ///
634    /// # Examples
635    ///
636    /// ```rust
637    /// # tokio_test::block_on(async {
638    /// use std::sync::Arc;
639    ///
640    /// use tower_sessions_ext::{MemoryStore, Session};
641    ///
642    /// let store = Arc::new(MemoryStore::default());
643    /// let session = Session::new(None, store, None, None);
644    ///
645    /// // Not modified initially.
646    /// assert!(!session.is_modified());
647    ///
648    /// // Getting doesn't count as a modification.
649    /// session.get::<usize>("foo").await.unwrap();
650    /// assert!(!session.is_modified());
651    ///
652    /// // Insertions and removals do though.
653    /// session.insert("foo", 42).await.unwrap();
654    /// assert!(session.is_modified());
655    /// # });
656    /// ```
657    pub fn is_modified(&self) -> bool {
658        self.inner.is_modified.load(atomic::Ordering::Acquire)
659    }
660
661    /// Saves the session record to the store.
662    ///
663    /// Note that this method is generally not needed and is reserved for
664    /// situations where the session store must be updated during the
665    /// request.
666    ///
667    /// # Examples
668    ///
669    /// ```rust
670    /// # tokio_test::block_on(async {
671    /// use std::sync::Arc;
672    ///
673    /// use tower_sessions_ext::{MemoryStore, Session};
674    ///
675    /// let store = Arc::new(MemoryStore::default());
676    /// let session = Session::new(None, store.clone(), None, None);
677    ///
678    /// session.insert("foo", 42).await.unwrap();
679    /// session.save().await.unwrap();
680    ///
681    /// let session = Session::new(session.id(), store, None, None);
682    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
683    /// # });
684    /// ```
685    ///
686    /// # Errors
687    ///
688    /// - If saving to the store fails, we fail with [`Error::Store`].
689    #[tracing::instrument(skip(self), err, level = "trace")]
690    pub async fn save(&self) -> Result<()> {
691        let mut record_guard = self.get_record().await?;
692        record_guard.expiry_date = self.expiry_date();
693
694        // Session ID is `None` if:
695        //
696        //  1. No valid cookie was found on the request or,
697        //  2. No valid session was found in the store.
698        //
699        // In either case, we must create a new session via the store interface.
700        //
701        // Potential ID collisions must be handled by session store implementers.
702        if self.inner.session_id.lock().is_none() {
703            self.store.create(&mut record_guard).await?;
704            *self.inner.session_id.lock() = Some(record_guard.id);
705        } else {
706            self.store.save(&record_guard).await?;
707        }
708        Ok(())
709    }
710
711    /// Loads the session record from the store.
712    ///
713    /// Note that this method is generally not needed and is reserved for
714    /// situations where the session must be updated during the request.
715    ///
716    /// # Examples
717    ///
718    /// ```rust
719    /// # tokio_test::block_on(async {
720    /// use std::sync::Arc;
721    ///
722    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
723    ///
724    /// let store = Arc::new(MemoryStore::default());
725    /// let id = Some(Id::default());
726    /// let session = Session::new(id, store.clone(), None, None);
727    ///
728    /// session.insert("foo", 42).await.unwrap();
729    /// session.save().await.unwrap();
730    ///
731    /// let session = Session::new(session.id(), store, None, None);
732    /// session.load().await.unwrap();
733    ///
734    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
735    /// # });
736    /// ```
737    ///
738    /// # Errors
739    ///
740    /// - If loading from the store fails, we fail with [`Error::Store`].
741    #[tracing::instrument(skip(self), err, level = "trace")]
742    pub async fn load(&self) -> Result<()> {
743        let session_id = *self.inner.session_id.lock();
744        let Some(ref id) = session_id else {
745            tracing::warn!("called load with no session id");
746            return Ok(());
747        };
748        let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
749        let mut record_guard = self.inner.record.lock().await;
750        *record_guard = loaded_record;
751        Ok(())
752    }
753
754    /// Deletes the session from the store.
755    ///
756    /// # Examples
757    ///
758    /// ```rust
759    /// # tokio_test::block_on(async {
760    /// use std::sync::Arc;
761    ///
762    /// use tower_sessions_ext::{MemoryStore, Session, SessionStore, session::Id};
763    ///
764    /// let store = Arc::new(MemoryStore::default());
765    /// let session = Session::new(Some(Id::default()), store.clone(), None, None);
766    ///
767    /// // Save before deleting.
768    /// session.save().await.unwrap();
769    ///
770    /// // Delete from the store.
771    /// session.delete().await.unwrap();
772    ///
773    /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
774    /// # });
775    /// ```
776    ///
777    /// # Errors
778    ///
779    /// - If deleting from the store fails, we fail with [`Error::Store`].
780    #[tracing::instrument(skip(self), err, level = "trace")]
781    pub async fn delete(&self) -> Result<()> {
782        let session_id = *self.inner.session_id.lock();
783        let Some(ref session_id) = session_id else {
784            tracing::warn!("called delete with no session id");
785            return Ok(());
786        };
787        self.store.delete(session_id).await.map_err(Error::Store)?;
788        Ok(())
789    }
790
791    /// Flushes the session by removing all data contained in the session and
792    /// then deleting it from the store.
793    ///
794    /// # Examples
795    ///
796    /// ```rust
797    /// # tokio_test::block_on(async {
798    /// use std::sync::Arc;
799    ///
800    /// use tower_sessions_ext::{MemoryStore, Session, SessionStore};
801    ///
802    /// let store = Arc::new(MemoryStore::default());
803    /// let session = Session::new(None, store.clone(), None, None);
804    ///
805    /// session.insert("foo", "bar").await.unwrap();
806    /// session.save().await.unwrap();
807    ///
808    /// let id = session.id().unwrap();
809    ///
810    /// session.flush().await.unwrap();
811    ///
812    /// assert!(session.id().is_none());
813    /// assert!(session.is_empty().await);
814    /// assert!(store.load(&id).await.unwrap().is_none());
815    /// # });
816    /// ```
817    ///
818    /// # Errors
819    ///
820    /// - If deleting from the store fails, we fail with [`Error::Store`].
821    pub async fn flush(&self) -> Result<()> {
822        self.clear().await;
823        self.delete().await?;
824        *self.inner.session_id.lock() = None;
825        Ok(())
826    }
827
828    /// Cycles the session ID while retaining any data that was associated with
829    /// it.
830    ///
831    /// Using this method helps prevent session fixation attacks by ensuring a
832    /// new ID is assigned to the session.
833    ///
834    /// # Examples
835    ///
836    /// ```rust
837    /// # tokio_test::block_on(async {
838    /// use std::sync::Arc;
839    ///
840    /// use tower_sessions_ext::{MemoryStore, Session, session::Id};
841    ///
842    /// let store = Arc::new(MemoryStore::default());
843    /// let session = Session::new(None, store.clone(), None, None);
844    ///
845    /// session.insert("foo", 42).await.unwrap();
846    /// session.save().await.unwrap();
847    /// let id = session.id();
848    ///
849    /// let session = Session::new(session.id(), store.clone(), None, None);
850    /// session.cycle_id().await.unwrap();
851    ///
852    /// assert!(!session.is_empty().await);
853    /// assert!(session.is_modified());
854    ///
855    /// session.save().await.unwrap();
856    ///
857    /// let session = Session::new(session.id(), store, None, None);
858    ///
859    /// assert_ne!(id, session.id());
860    /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
861    /// # });
862    /// ```
863    ///
864    /// # Errors
865    ///
866    /// - If deleting from the store fails or saving to the store fails, we fail
867    ///   with [`Error::Store`].
868    pub async fn cycle_id(&self) -> Result<()> {
869        let mut record_guard = self.get_record().await?;
870
871        let old_session_id = record_guard.id;
872        record_guard.id = Id::default();
873        *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
874        // `create` method.
875
876        self.store
877            .delete(&old_session_id)
878            .await
879            .map_err(Error::Store)?;
880
881        self.inner
882            .is_modified
883            .store(true, atomic::Ordering::Release);
884
885        Ok(())
886    }
887}
888
889/// ID type for sessions.
890///
891/// Wraps an array of 16 bytes.
892///
893/// # Examples
894///
895/// ```rust
896/// use tower_sessions_ext::session::Id;
897///
898/// Id::default();
899/// ```
900#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
901pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
902// session ID, which is undesirable.
903
904impl Default for Id {
905    fn default() -> Self {
906        use rand::prelude::*;
907
908        Self(rand::rng().random())
909    }
910}
911
912impl Display for Id {
913    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
914        let mut encoded = [0; 22];
915        URL_SAFE_NO_PAD
916            .encode_slice(self.0.to_le_bytes(), &mut encoded)
917            .expect("Encoded ID must be exactly 22 bytes");
918        let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
919
920        f.write_str(encoded)
921    }
922}
923
924impl FromStr for Id {
925    type Err = base64::DecodeSliceError;
926
927    fn from_str(s: &str) -> result::Result<Self, Self::Err> {
928        let mut decoded = [0; 16];
929        let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
930        if bytes_decoded != 16 {
931            let err = DecodeError::InvalidLength(bytes_decoded);
932            return Err(base64::DecodeSliceError::DecodeError(err));
933        }
934
935        Ok(Self(i128::from_le_bytes(decoded)))
936    }
937}
938
939/// Record type that's appropriate for encoding and decoding sessions to and
940/// from session stores.
941#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
942pub struct Record {
943    pub id: Id,
944    pub data: Data,
945    pub expiry_date: OffsetDateTime,
946}
947
948impl Record {
949    fn new(expiry_date: OffsetDateTime) -> Self {
950        Self {
951            id: Id::default(),
952            data: Data::default(),
953            expiry_date,
954        }
955    }
956}
957
958/// Session expiry configuration.
959///
960/// # Examples
961///
962/// ```rust
963/// use time::{Duration, OffsetDateTime};
964/// use tower_sessions_ext::Expiry;
965///
966/// // Will be expired on "session end".
967/// let expiry = Expiry::OnSessionEnd;
968///
969/// // Will be expired in five minutes from last acitve.
970/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
971///
972/// // Will be expired at the given timestamp.
973/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
974/// let expiry = Expiry::AtDateTime(expired_at);
975/// ```
976#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
977pub enum Expiry {
978    /// Expire on [current session end][current-session-end], as defined by the
979    /// browser.
980    ///
981    /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
982    OnSessionEnd(Duration),
983
984    /// Expire on inactivity.
985    ///
986    /// Reading a session is not considered activity for expiration purposes.
987    /// [`Session`] expiration is computed from the last time the session was
988    /// _modified_.
989    OnInactivity(Duration),
990
991    /// Expire at a specific date and time.
992    ///
993    /// This value may be extended manually with
994    /// [`set_expiry`](Session::set_expiry).
995    AtDateTime(OffsetDateTime),
996}
997
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    #[tokio::test]
    async fn test_cycle_id() {
        let mut mock_store = MockStore::new();

        let initial_id = Id::default();
        let new_id = Id::default();

        // Set up expectations for the mock store
        mock_store
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));
        mock_store
            .expect_load()
            .with(predicate::eq(initial_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: initial_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });
        mock_store
            .expect_delete()
            .with(predicate::eq(initial_id))
            .times(1)
            .returning(|_| Ok(()));
        mock_store
            .expect_create()
            .times(1)
            .returning(move |record| {
                // The record handed to `create` must already carry a fresh ID
                // and still hold the data inserted before the cycle.
                assert_ne!(record.id, initial_id);
                assert_eq!(record.data.get("foo"), Some(&Value::from(42)));
                record.id = new_id;
                Ok(())
            });

        let store = Arc::new(mock_store);
        let session = Session::new(Some(initial_id), store.clone(), None, None);

        // Insert some data and save the session
        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        // Cycle the session ID
        session.cycle_id().await.unwrap();

        // Verify that the session ID has changed and the data is still present
        assert_ne!(session.id(), Some(initial_id));
        assert!(session.id().is_none()); // The session ID should be None
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Save the session to update the ID in the session object
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(new_id));
    }

    #[test]
    fn test_id_display_is_canonical_url_safe_base64() {
        // 16 zero bytes encode to 22 'A's; 16 0xFF bytes to 21 '_' then 'w'.
        assert_eq!(Id(0).to_string(), "A".repeat(22));
        assert_eq!(Id(-1).to_string(), format!("{}w", "_".repeat(21)));
    }

    #[test]
    fn test_id_round_trip() {
        for id in [
            Id(0),
            Id(1),
            Id(-1),
            Id(i128::MIN),
            Id(i128::MAX),
            Id::default(),
        ] {
            let encoded = id.to_string();
            assert_eq!(encoded.len(), 22);
            assert_eq!(encoded.parse::<Id>().unwrap(), id);
        }
    }

    #[test]
    fn test_id_rejects_malformed_input() {
        assert!("".parse::<Id>().is_err());
        assert!("AAAA".parse::<Id>().is_err()); // decodes to 3 bytes
        assert!("A".repeat(24).parse::<Id>().is_err()); // decodes to 18 bytes
        assert!("not a valid id".parse::<Id>().is_err()); // invalid alphabet
    }
}