tower_sessions_sled_store/
lib.rs

1use async_trait::async_trait;
2use sled::IVec;
3use tower_sessions::{
4    cookie::time::OffsetDateTime,
5    session::{Id, Record},
6    session_store, ExpiredDeletion, SessionStore,
7};
8
9/// Session store backed by sled
10#[derive(Debug, Clone)]
11pub struct SledStore {
12    /// the sled tree which should be used for storage
13    sled: sled::Tree,
14}
15
16impl SledStore {
17    /// Create a new SledStore using a sled tree
18    ///
19    /// A [`sled::Tree`] can be acquired by taking one from a [`sled::Db`]
20    ///
21    /// ```rust,no_run
22    /// use tower_sessions_sled_store::SledStore;
23    /// let sled = sled::open("storage").unwrap();
24    /// let session_store = SledStore::new(sled.open_tree("session").expect("Error opening tree"));
25    /// ```
26    pub fn new(branch: sled::Tree) -> Self {
27        Self { sled: branch }
28    }
29}
30
31#[async_trait]
32impl SessionStore for SledStore {
33    async fn save(&self, record: &Record) -> session_store::Result<()> {
34        let encoded =
35            encode_record(record).map_err(|e| session_store::Error::Encode(e.to_string()))?;
36
37        self.sled
38            .insert(record.id.0.to_be_bytes(), encoded)
39            .map_err(|e| session_store::Error::Backend(e.to_string()))?;
40
41        Ok(())
42    }
43
44    async fn load(&self, id: &Id) -> session_store::Result<Option<Record>> {
45        let rec = self
46            .sled
47            .get(id.0.to_be_bytes())
48            .map_err(|e| session_store::Error::Backend(e.to_string()))?;
49
50        if let Some(sr) = rec {
51            let rec =
52                decode_record(&sr).map_err(|e| session_store::Error::Decode(e.to_string()))?;
53
54            return Ok(Some(rec));
55        }
56
57        Ok(None)
58    }
59
60    async fn delete(&self, id: &Id) -> session_store::Result<()> {
61        self.sled
62            .remove(id.0.to_be_bytes())
63            .map_err(|e| session_store::Error::Backend(e.to_string()))?;
64
65        Ok(())
66    }
67}
68
69/// Encode the data using rmp_serde for storage within the sled database
70fn encode_record(record: &Record) -> Result<IVec, rmp_serde::encode::Error> {
71    let serialized = rmp_serde::to_vec(record)?;
72
73    Ok(IVec::from(serialized))
74}
75
76/// Decode the data using rmp_serde from the sled database
77fn decode_record(data: &IVec) -> Result<Record, rmp_serde::decode::Error> {
78    let decoded = rmp_serde::from_slice(data)?;
79
80    Ok(decoded)
81}
82
83#[async_trait]
84impl ExpiredDeletion for SledStore {
85    /// Deletes expired sessions from the session store
86    ///
87    /// Note that running deletion may be expensive as this function has to iterate every session stored in the database.
88    /// This may become an issue as sled technically has no idea its running under async and it may block for a long time.
89    /// However to solve this this function automatically runs the deletion within [`tokio::task::spawn_blocking`]
90    async fn delete_expired(&self) -> session_store::Result<()> {
91        let sled = self.sled.clone();
92
93        // deletion is ran within a sync block as sled has no concept of async and acessing the whole database may block for a long time
94        tokio::task::spawn_blocking(move || -> session_store::Result<()> {
95            let now = OffsetDateTime::now_utc();
96
97            for (k, v) in sled.iter().flatten() {
98                let rec =
99                    decode_record(&v).map_err(|e| session_store::Error::Decode(e.to_string()))?;
100
101                if rec.expiry_date < now {
102                    sled.remove(k)
103                        .map_err(|e| session_store::Error::Backend(e.to_string()))?;
104                }
105            }
106
107            Ok(())
108        })
109        .await
110        .map_err(|e| session_store::Error::Backend(e.to_string()))?
111    }
112}