tower_sessions_libsql_store/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use async_trait::async_trait;
4use libsql::params;
5use time::OffsetDateTime;
6use tower_sessions_core::{
7    session::{Id, Record},
8    session_store::{self, ExpiredDeletion},
9    SessionStore,
10};
11
12/// An error type for libSQL stores.
13#[derive(thiserror::Error, Debug)]
14pub enum LibsqlStoreError {
15    /// A variant to map `libsql` errors.
16    #[error(transparent)]
17    Libsql(#[from] libsql::Error),
18
19    /// A variant to map `rmp_serde` encode errors.
20    #[error(transparent)]
21    Encode(#[from] rmp_serde::encode::Error),
22
23    /// A variant to map `rmp_serde` decode errors.
24    #[error(transparent)]
25    Decode(#[from] rmp_serde::decode::Error),
26}
27
28impl From<LibsqlStoreError> for session_store::Error {
29    fn from(err: LibsqlStoreError) -> Self {
30        match err {
31            LibsqlStoreError::Libsql(inner) => session_store::Error::Backend(inner.to_string()),
32            LibsqlStoreError::Decode(inner) => session_store::Error::Decode(inner.to_string()),
33            LibsqlStoreError::Encode(inner) => session_store::Error::Encode(inner.to_string()),
34        }
35    }
36}
37
38/// A libSQL session store.
39#[derive(Clone)]
40pub struct LibsqlStore {
41    connection: libsql::Connection,
42    table_name: String,
43}
44
45// Need this since connection does not implement Debug
46impl std::fmt::Debug for LibsqlStore {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        f.debug_struct("LibsqlStore")
49            // Probably want to handle this differently
50            .field("connection", &std::any::type_name::<libsql::Connection>())
51            .field("table_name", &self.table_name)
52            .finish()
53    }
54}
55
56impl LibsqlStore {
57    /// Create a new libSQL store with the provided connection pool.
58    pub fn new(client: libsql::Connection) -> Self {
59        Self {
60            connection: client,
61            table_name: "tower_sessions".into(),
62        }
63    }
64
65    /// Set the session table name with the provided name.
66    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
67        let table_name = table_name.as_ref();
68        if !is_valid_table_name(table_name) {
69            return Err(format!(
70                "Invalid table name '{}'. Table names must be alphanumeric and may contain \
71                 hyphens or underscores.",
72                table_name
73            ));
74        }
75
76        table_name.clone_into(&mut self.table_name);
77        Ok(self)
78    }
79
80    /// Migrate the session schema.
81    pub async fn migrate(&self) -> libsql::Result<()> {
82        let query = format!(
83            r#"
84            create table if not exists {}
85            (
86                id text primary key not null,
87                data blob not null,
88                expiry_date integer not null
89            )
90            "#,
91            self.table_name
92        );
93        self.connection.execute(&query, ()).await?;
94
95        Ok(())
96    }
97
98    /// Checks exitence of the ID. Helps ensure unique values in edge cases.
99    async fn id_exists(&self, conn: &libsql::Connection, id: &Id) -> session_store::Result<bool> {
100        let query = format!(
101            r#"
102            select exists(select 1 from {table_name} where id = ?)
103            "#,
104            table_name = self.table_name
105        );
106
107        let res = conn
108            .query(&query, params![id.to_string()])
109            .await
110            .map_err(LibsqlStoreError::Libsql)
111            .unwrap()
112            .next()
113            .await
114            .unwrap()
115            .unwrap()
116            .get_value(0)
117            .unwrap();
118
119        Ok(res == libsql::Value::Integer(1))
120    }
121
122    /// Save results to DB
123    async fn save_with_conn(
124        &self,
125        conn: &libsql::Connection,
126        record: &Record,
127    ) -> session_store::Result<()> {
128        let query = format!(
129            r#"
130            insert into {}
131              (id, data, expiry_date) values (?, ?, ?)
132            on conflict(id) do update set
133              data = excluded.data,
134              expiry_date = excluded.expiry_date
135            "#,
136            self.table_name
137        );
138        conn.execute(
139            &query,
140            params![
141                record.id.to_string(),
142                rmp_serde::to_vec(record).map_err(LibsqlStoreError::Encode)?,
143                record.expiry_date.unix_timestamp()
144            ],
145        )
146        .await
147        .map_err(LibsqlStoreError::Libsql)?;
148
149        Ok(())
150    }
151}
152
153#[async_trait]
154impl ExpiredDeletion for LibsqlStore {
155    async fn delete_expired(&self) -> session_store::Result<()> {
156        let query = format!(
157            r#"
158            delete from {table_name}
159            where expiry_date < unixepoch('now')
160            "#,
161            table_name = self.table_name
162        );
163        self.connection
164            .execute(&query, ())
165            .await
166            .map_err(LibsqlStoreError::Libsql)?;
167        Ok(())
168    }
169}
170
171#[async_trait]
172impl SessionStore for LibsqlStore {
173    async fn create(&self, record: &mut Record) -> session_store::Result<()> {
174        while self.id_exists(&self.connection, &record.id).await? {
175            record.id = Id::default() // Generate a new id
176        }
177
178        let conn = self.connection.clone();
179        self.save_with_conn(&conn, record).await?;
180
181        Ok(())
182    }
183
184    async fn save(&self, record: &Record) -> session_store::Result<()> {
185        let conn = self.connection.clone();
186        self.save_with_conn(&conn, record).await
187    }
188
189    async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
190        let query = format!(
191            r#"
192            select data from {}
193            where id = ? and expiry_date > ?
194            "#,
195            self.table_name
196        );
197
198        let mut data = self
199            .connection
200            .query(
201                &query,
202                params![
203                    session_id.to_string(),
204                    OffsetDateTime::now_utc().unix_timestamp()
205                ],
206            )
207            .await
208            .map_err(LibsqlStoreError::Libsql)?;
209
210        if let Ok(Some(data)) = data.next().await {
211            Ok(Some(
212                rmp_serde::from_slice(
213                    data.get_value(0)
214                        .map_err(LibsqlStoreError::Libsql)
215                        .unwrap()
216                        .as_blob()
217                        .unwrap(),
218                )
219                .map_err(LibsqlStoreError::Decode)?,
220            ))
221        } else {
222            Ok(None)
223        }
224    }
225
226    async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
227        let query = format!(
228            r#"
229            delete from {} where id = ?
230            "#,
231            self.table_name
232        );
233
234        self.connection
235            .execute(&query, params![session_id.to_string()])
236            .await
237            .map_err(LibsqlStoreError::Libsql)?;
238
239        Ok(())
240    }
241}
242
/// Returns `true` when `name` is non-empty and consists solely of ASCII
/// letters, ASCII digits, hyphens and underscores.
fn is_valid_table_name(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    // Every byte of a non-ASCII character is >= 0x80 and therefore rejected,
    // so checking bytes gives the same answer as checking characters.
    name.bytes()
        .all(|byte| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_'))
}
249
#[cfg(test)]
mod libsql_store_tests {
    use std::collections::HashMap;

    use libsql::Builder;
    use serde_json::Value;
    use tower_sessions::cookie::time::{Duration, OffsetDateTime};

    use super::*;

    /// Opens an in-memory database, connects to it and returns a store whose
    /// schema has been migrated.
    ///
    /// The database handle is returned as well so that it stays alive for the
    /// whole test.
    async fn migrated_store() -> (libsql::Database, libsql::Connection, LibsqlStore) {
        let db = Builder::new_local(":memory:").build().await.unwrap();
        let conn = db.connect().unwrap();
        let store = LibsqlStore::new(conn.clone());
        store.migrate().await.unwrap();
        (db, conn, store)
    }

    /// Builds a record with a fresh id, a single `key`/`value` entry and an
    /// expiry one day from now.
    fn sample_record() -> Record {
        let data: HashMap<String, Value> =
            HashMap::from([("key".to_string(), Value::from("value"))]);

        Record {
            id: Id::default(),
            data,
            expiry_date: OffsetDateTime::now_utc()
                .checked_add(Duration::days(1))
                .expect("Overflow making expiry"),
        }
    }

    // Quick test to ensure that the db can be connected to, a migration can run,
    // and the table is queried, returning None.
    #[tokio::test]
    async fn basic_roundtrip() {
        let (_db, conn, _store) = migrated_store().await;

        let query = r#"
            select * from tower_sessions limit 1
        "#;

        let row = conn.query(query, ()).await.unwrap().next().await.unwrap();

        assert!(row.is_none());
    }

    // Test a create with conflict
    #[tokio::test]
    async fn create_with_conflict() {
        let (_db, _conn, store) = migrated_store().await;

        let mut session_record1 = sample_record();
        store
            .create(&mut session_record1)
            .await
            .expect("Error saving session");

        // Same id as the first record, so `create` has to pick a new one.
        let mut session_record2 = session_record1.clone();
        store
            .create(&mut session_record2)
            .await
            .expect("Error saving session");

        let loaded1 = store
            .load(&session_record1.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        let loaded2 = store
            .load(&session_record2.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        assert_eq!(
            loaded1.data, loaded2.data,
            "Session created with duplicate data"
        );
        assert_ne!(
            loaded1.id, loaded2.id,
            "Session conflict on id generates a new id"
        );
    }

    // Test a save and load
    #[tokio::test]
    async fn save_and_load() {
        let (_db, _conn, store) = migrated_store().await;

        let session_record = sample_record();

        store
            .save(&session_record)
            .await
            .expect("Error saving session");

        let loaded = store
            .load(&session_record.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        assert_eq!(session_record, loaded, "Save and load match");
    }

    // Test a delete
    #[tokio::test]
    async fn save_and_delete() {
        let (_db, _conn, store) = migrated_store().await;

        let session_record = sample_record();

        store
            .save(&session_record)
            .await
            .expect("Error saving session");

        let loaded = store
            .load(&session_record.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        assert_eq!(session_record, loaded, "Save and load match");

        store
            .delete(&session_record.id)
            .await
            .expect("Error deleting session record");

        let loaded = store.load(&session_record.id).await.expect("Error loading");

        assert!(loaded.is_none())
    }
}
418}