tower_sessions_rusqlite_store/
lib.rs

1use async_trait::async_trait;
2use rusqlite::OptionalExtension;
3use std::error::Error;
4use time::OffsetDateTime;
5pub use tokio_rusqlite;
6use tokio_rusqlite::{params, Connection, Result as SqlResult};
7use tower_sessions_core::{
8    session::{Id, Record},
9    session_store, ExpiredDeletion, SessionStore,
10};
11
12/// An error type for Rusqlite stores.
13#[derive(thiserror::Error, Debug)]
14pub enum RusqliteStoreError {
15    /// A variant to map `tokio_rusqlite` errors.
16    #[error(transparent)]
17    TokioRusqlite(#[from] tokio_rusqlite::Error),
18
19    /// A variant to map `rmp_serde` encode errors.
20    #[error(transparent)]
21    Encode(#[from] rmp_serde::encode::Error),
22
23    /// A variant to map `rmp_serde` decode errors.
24    #[error(transparent)]
25    Decode(#[from] rmp_serde::decode::Error),
26
27    /// A variant for other backend errors.
28    #[error("Backend error: {0}")]
29    Other(String),
30}
31
32impl From<RusqliteStoreError> for session_store::Error {
33    fn from(err: RusqliteStoreError) -> Self {
34        match err {
35            RusqliteStoreError::TokioRusqlite(inner) => {
36                session_store::Error::Backend(inner.to_string())
37            }
38            RusqliteStoreError::Decode(inner) => session_store::Error::Decode(inner.to_string()),
39            RusqliteStoreError::Encode(inner) => session_store::Error::Encode(inner.to_string()),
40            RusqliteStoreError::Other(inner) => session_store::Error::Backend(inner),
41        }
42    }
43}
44
45#[derive(Clone, Debug)]
46pub struct RusqliteStore {
47    conn: Connection,
48    table_name: String,
49}
50
51impl RusqliteStore {
52    /// Create a new SQLite store with the provided connection.
53    ///
54    /// # Examples
55    ///
56    /// ```rust,no_run
57    /// use tower_sessions_rusqlite_store::{tokio_rusqlite::Connection, RusqliteStore};
58    ///
59    /// # tokio_test::block_on(async {
60    /// let conn = Connection::open_in_memory().await.unwrap();
61    /// let session_store = RusqliteStore::new(conn);
62    /// # })
63    /// ```
64    pub fn new(conn: Connection) -> Self {
65        Self {
66            conn,
67            table_name: "tower_sessions".into(),
68        }
69    }
70
71    /// Set the session table name with the provided name.
72    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
73        let table_name = table_name.as_ref();
74        if !is_valid_table_name(table_name) {
75            return Err(format!(
76                "Invalid table name '{}'. Table names must be alphanumeric and may contain \
77                 hyphens or underscores.",
78                table_name
79            ));
80        }
81
82        self.table_name = table_name.to_owned();
83        Ok(self)
84    }
85
86    /// Migrate the session schema.
87    pub async fn migrate(&self) -> SqlResult<()> {
88        let conn = self.conn.clone();
89        let query = format!(
90            r#"
91            create table if not exists {}
92            (
93                id text primary key not null,
94                data blob not null,
95                expiry_date integer not null
96            )
97            "#,
98            self.table_name
99        );
100        conn.call(move |conn| conn.execute(&query, [])).await?;
101
102        Ok(())
103    }
104}
105
106fn id_exists_with_conn(
107    conn: &rusqlite::Connection,
108    table_name: &str,
109    id: &Id,
110) -> rusqlite::Result<bool> {
111    let query = format!(
112        r#"
113        select exists(select 1 from {} where id = ?1)
114        "#,
115        table_name
116    );
117    let mut stmt = conn.prepare(&query)?;
118    stmt.query_row(params![id.to_string()], |row| row.get(0))
119}
120
121fn save_with_conn(
122    conn: &rusqlite::Connection,
123    table_name: &str,
124    record: &Record,
125    record_data: &[u8],
126) -> rusqlite::Result<usize> {
127    let query = format!(
128        r#"
129        insert into {}
130            (id, data, expiry_date) values (?1, ?2, ?3)
131        on conflict(id) do update set
132            data = excluded.data,
133            expiry_date = excluded.expiry_date
134        "#,
135        table_name
136    );
137    conn.execute(
138        &query,
139        params![
140            record.id.to_string(),
141            record_data,
142            record.expiry_date.unix_timestamp()
143        ],
144    )
145}
146
147#[async_trait]
148impl ExpiredDeletion for RusqliteStore {
149    async fn delete_expired(&self) -> session_store::Result<()> {
150        let conn = self.conn.clone();
151        let query = format!(
152            r#"
153            delete from {table_name}
154            where expiry_date < ?1
155            "#,
156            table_name = self.table_name
157        );
158        conn.call(move |conn| conn.execute(&query, [OffsetDateTime::now_utc().unix_timestamp()]))
159            .await
160            .map_err(|e| {
161                // printing the error here because this usually runs in the background
162                // and thus the error is only received shortly before the process exits
163                eprintln!("Error deleting expired sessions: {:?}", e);
164                RusqliteStoreError::TokioRusqlite(e)
165            })?;
166
167        Ok(())
168    }
169}
170
171#[async_trait]
172impl SessionStore for RusqliteStore {
173    async fn create(&self, record: &mut Record) -> session_store::Result<()> {
174        let conn = self.conn.clone();
175
176        let new_id = conn
177            .call({
178                let mut record = record.clone();
179                let table_name = self.table_name.clone();
180
181                move |conn| {
182                    let tx = conn.transaction()?;
183
184                    while id_exists_with_conn(&tx, &table_name, &record.id)? {
185                        record.id = Id::default();
186                    }
187
188                    let record_data = rmp_serde::to_vec(&record).map_err(Box::new)?;
189
190                    save_with_conn(&tx, &table_name, &record, &record_data)?;
191
192                    tx.commit()?;
193
194                    Ok(record.id)
195                }
196            })
197            .await
198            .map_err(
199                |e: tokio_rusqlite::Error<Box<dyn Error + Send + Sync>>| match e {
200                    tokio_rusqlite::Error::Error(boxed_err) => {
201                        match boxed_err.downcast::<rmp_serde::encode::Error>() {
202                            Ok(encode_error) => RusqliteStoreError::Encode(*encode_error),
203                            Err(original_box) => {
204                                RusqliteStoreError::Other(original_box.to_string())
205                            }
206                        }
207                    }
208                    other => RusqliteStoreError::Other(other.to_string()),
209                },
210            )?;
211
212        record.id = new_id;
213
214        Ok(())
215    }
216
217    async fn save(&self, record: &Record) -> session_store::Result<()> {
218        let conn = self.conn.clone();
219        let table_name = self.table_name.clone();
220        let record = record.clone();
221        let record_data = rmp_serde::to_vec(&record).map_err(RusqliteStoreError::Encode)?;
222
223        conn.call(move |conn| save_with_conn(conn, &table_name, &record, &record_data))
224            .await
225            .map_err(RusqliteStoreError::TokioRusqlite)?;
226
227        Ok(())
228    }
229
230    async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
231        let conn = self.conn.clone();
232
233        let data = conn
234            .call({
235                let table_name = self.table_name.clone();
236                let session_id = session_id.to_string();
237                move |conn| {
238                    let query = format!(
239                        r#"
240                        select data from {}
241                        where id = ?1 and expiry_date > ?2
242                        "#,
243                        table_name
244                    );
245                    let mut stmt = conn.prepare(&query)?;
246                    stmt.query_row(
247                        params![session_id, OffsetDateTime::now_utc().unix_timestamp()],
248                        |row| {
249                            let data: Vec<u8> = row.get(0)?;
250                            Ok(data)
251                        },
252                    )
253                    .optional()
254                }
255            })
256            .await
257            .map_err(RusqliteStoreError::TokioRusqlite)?;
258
259        match data {
260            Some(data) => {
261                let record: Record =
262                    rmp_serde::from_slice(&data).map_err(RusqliteStoreError::Decode)?;
263                Ok(Some(record))
264            }
265            None => Ok(None),
266        }
267    }
268
269    async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
270        let conn = self.conn.clone();
271
272        conn.call({
273            let table_name = self.table_name.clone();
274            let session_id = session_id.to_string();
275            move |conn| {
276                let query = format!(
277                    r#"
278                    delete from {} where id = ?1
279                    "#,
280                    table_name
281                );
282                conn.execute(&query, params![session_id])
283            }
284        })
285        .await
286        .map_err(RusqliteStoreError::TokioRusqlite)?;
287
288        Ok(())
289    }
290}
291
/// Return `true` if `name` is non-empty and consists only of ASCII letters, digits, `-` or `_`.
///
/// Table names are spliced into SQL text with `format!`, so this whitelist is what keeps a
/// caller-supplied name from smuggling SQL into those statements.
fn is_valid_table_name(name: &str) -> bool {
    let allowed = |b: u8| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_');
    !name.is_empty() && name.bytes().all(allowed)
}
298
// unit tests from https://github.com/maxcountryman/tower-sessions/blob/6ad8933b4f5e71f3202f0c1a28f194f3db5234c8/memory-store/src/lib.rs#L62
#[cfg(test)]
mod rusqlite_store_tests {
    use time::Duration;

    use super::*;

    /// Build a migrated store backed by a fresh in-memory database.
    async fn create_store() -> RusqliteStore {
        let conn = Connection::open_in_memory().await.unwrap();
        let store = RusqliteStore::new(conn);
        store.migrate().await.unwrap();
        store
    }

    /// Build a record with default id and data that expires `offset` from now.
    fn record_expiring_in(offset: Duration) -> Record {
        Record {
            id: Default::default(),
            data: Default::default(),
            expiry_date: OffsetDateTime::now_utc() + offset,
        }
    }

    #[tokio::test]
    async fn test_create() {
        let store = create_store().await;
        let mut record = record_expiring_in(Duration::minutes(30));
        assert!(store.create(&mut record).await.is_ok());
    }

    #[tokio::test]
    async fn test_save() {
        let store = create_store().await;
        let record = record_expiring_in(Duration::minutes(30));
        assert!(store.save(&record).await.is_ok());
    }

    #[tokio::test]
    async fn test_load() {
        let store = create_store().await;
        let mut record = record_expiring_in(Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = create_store().await;
        let mut record = record_expiring_in(Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        assert!(store.delete(&record.id).await.is_ok());
        assert_eq!(None, store.load(&record.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_create_id_collision() {
        let store = create_store().await;
        let mut record1 = record_expiring_in(Duration::minutes(30));
        let mut record2 = record_expiring_in(Duration::minutes(30));
        store.create(&mut record1).await.unwrap();
        record2.id = record1.id; // Set the same ID for record2
        store.create(&mut record2).await.unwrap();
        assert_ne!(record1.id, record2.id); // IDs should be different
    }

    #[tokio::test]
    async fn test_delete_expired() {
        let store = create_store().await;
        let mut record = record_expiring_in(-Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        store.delete_expired().await.unwrap();
        assert_eq!(None, store.load(&record.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_with_table_name_validation() {
        let conn = Connection::open_in_memory().await.unwrap();
        let store = RusqliteStore::new(conn);
        assert!(store.clone().with_table_name("").is_err());
        assert!(store.clone().with_table_name("bad name").is_err());
        assert!(store.clone().with_table_name("bad;name").is_err());
        assert!(store.with_table_name("custom_sessions").is_ok());
    }

    #[tokio::test]
    async fn test_custom_table_name_round_trip() {
        let conn = Connection::open_in_memory().await.unwrap();
        let store = RusqliteStore::new(conn)
            .with_table_name("custom_sessions")
            .unwrap();
        store.migrate().await.unwrap();
        let mut record = record_expiring_in(Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }
}