tower_sessions_libsql_store/
lib.rs1#![doc = include_str!("../README.md")]
2
3use async_trait::async_trait;
4use libsql::params;
5use time::OffsetDateTime;
6use tower_sessions_core::{
7 session::{Id, Record},
8 session_store::{self, ExpiredDeletion},
9 SessionStore,
10};
11
12#[derive(thiserror::Error, Debug)]
14pub enum LibsqlStoreError {
15 #[error(transparent)]
17 Libsql(#[from] libsql::Error),
18
19 #[error(transparent)]
21 Encode(#[from] rmp_serde::encode::Error),
22
23 #[error(transparent)]
25 Decode(#[from] rmp_serde::decode::Error),
26}
27
28impl From<LibsqlStoreError> for session_store::Error {
29 fn from(err: LibsqlStoreError) -> Self {
30 match err {
31 LibsqlStoreError::Libsql(inner) => session_store::Error::Backend(inner.to_string()),
32 LibsqlStoreError::Decode(inner) => session_store::Error::Decode(inner.to_string()),
33 LibsqlStoreError::Encode(inner) => session_store::Error::Encode(inner.to_string()),
34 }
35 }
36}
37
38#[derive(Clone)]
40pub struct LibsqlStore {
41 connection: libsql::Connection,
42 table_name: String,
43}
44
45impl std::fmt::Debug for LibsqlStore {
47 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48 f.debug_struct("LibsqlStore")
49 .field("connection", &std::any::type_name::<libsql::Connection>())
51 .field("table_name", &self.table_name)
52 .finish()
53 }
54}
55
56impl LibsqlStore {
57 pub fn new(client: libsql::Connection) -> Self {
59 Self {
60 connection: client,
61 table_name: "tower_sessions".into(),
62 }
63 }
64
65 pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
67 let table_name = table_name.as_ref();
68 if !is_valid_table_name(table_name) {
69 return Err(format!(
70 "Invalid table name '{}'. Table names must be alphanumeric and may contain \
71 hyphens or underscores.",
72 table_name
73 ));
74 }
75
76 table_name.clone_into(&mut self.table_name);
77 Ok(self)
78 }
79
80 pub async fn migrate(&self) -> libsql::Result<()> {
82 let query = format!(
83 r#"
84 create table if not exists {}
85 (
86 id text primary key not null,
87 data blob not null,
88 expiry_date integer not null
89 )
90 "#,
91 self.table_name
92 );
93 self.connection.execute(&query, ()).await?;
94
95 Ok(())
96 }
97
98 async fn id_exists(&self, conn: &libsql::Connection, id: &Id) -> session_store::Result<bool> {
100 let query = format!(
101 r#"
102 select exists(select 1 from {table_name} where id = ?)
103 "#,
104 table_name = self.table_name
105 );
106
107 let res = conn
108 .query(&query, params![id.to_string()])
109 .await
110 .map_err(LibsqlStoreError::Libsql)
111 .unwrap()
112 .next()
113 .await
114 .unwrap()
115 .unwrap()
116 .get_value(0)
117 .unwrap();
118
119 Ok(res == libsql::Value::Integer(1))
120 }
121
122 async fn save_with_conn(
124 &self,
125 conn: &libsql::Connection,
126 record: &Record,
127 ) -> session_store::Result<()> {
128 let query = format!(
129 r#"
130 insert into {}
131 (id, data, expiry_date) values (?, ?, ?)
132 on conflict(id) do update set
133 data = excluded.data,
134 expiry_date = excluded.expiry_date
135 "#,
136 self.table_name
137 );
138 conn.execute(
139 &query,
140 params![
141 record.id.to_string(),
142 rmp_serde::to_vec(record).map_err(LibsqlStoreError::Encode)?,
143 record.expiry_date.unix_timestamp()
144 ],
145 )
146 .await
147 .map_err(LibsqlStoreError::Libsql)?;
148
149 Ok(())
150 }
151}
152
153#[async_trait]
154impl ExpiredDeletion for LibsqlStore {
155 async fn delete_expired(&self) -> session_store::Result<()> {
156 let query = format!(
157 r#"
158 delete from {table_name}
159 where expiry_date < unixepoch('now')
160 "#,
161 table_name = self.table_name
162 );
163 self.connection
164 .execute(&query, ())
165 .await
166 .map_err(LibsqlStoreError::Libsql)?;
167 Ok(())
168 }
169}
170
171#[async_trait]
172impl SessionStore for LibsqlStore {
173 async fn create(&self, record: &mut Record) -> session_store::Result<()> {
174 while self.id_exists(&self.connection, &record.id).await? {
175 record.id = Id::default() }
177
178 let conn = self.connection.clone();
179 self.save_with_conn(&conn, record).await?;
180
181 Ok(())
182 }
183
184 async fn save(&self, record: &Record) -> session_store::Result<()> {
185 let conn = self.connection.clone();
186 self.save_with_conn(&conn, record).await
187 }
188
189 async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
190 let query = format!(
191 r#"
192 select data from {}
193 where id = ? and expiry_date > ?
194 "#,
195 self.table_name
196 );
197
198 let mut data = self
199 .connection
200 .query(
201 &query,
202 params![
203 session_id.to_string(),
204 OffsetDateTime::now_utc().unix_timestamp()
205 ],
206 )
207 .await
208 .map_err(LibsqlStoreError::Libsql)?;
209
210 if let Ok(Some(data)) = data.next().await {
211 Ok(Some(
212 rmp_serde::from_slice(
213 data.get_value(0)
214 .map_err(LibsqlStoreError::Libsql)
215 .unwrap()
216 .as_blob()
217 .unwrap(),
218 )
219 .map_err(LibsqlStoreError::Decode)?,
220 ))
221 } else {
222 Ok(None)
223 }
224 }
225
226 async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
227 let query = format!(
228 r#"
229 delete from {} where id = ?
230 "#,
231 self.table_name
232 );
233
234 self.connection
235 .execute(&query, params![session_id.to_string()])
236 .await
237 .map_err(LibsqlStoreError::Libsql)?;
238
239 Ok(())
240 }
241}
242
/// Return `true` when `name` is non-empty and consists solely of ASCII
/// letters, ASCII digits, `-` and `_`.
fn is_valid_table_name(name: &str) -> bool {
    let allowed = |ch: char| ch.is_ascii_alphanumeric() || matches!(ch, '-' | '_');

    match name {
        "" => false,
        _ => name.chars().all(allowed),
    }
}
249
#[cfg(test)]
mod libsql_store_tests {
    use std::collections::HashMap;

    use libsql::Builder;
    use serde_json::Value;
    use tower_sessions::cookie::time::{Duration, OffsetDateTime};

    use super::*;

    /// Open a fresh in-memory database and migrate a default store on it.
    ///
    /// The `Database` is returned too so callers keep it alive for the whole
    /// test; the extra connection handle is for raw verification queries.
    async fn migrated_store() -> (libsql::Database, libsql::Connection, LibsqlStore) {
        let db = Builder::new_local(":memory:").build().await.unwrap();
        let conn = db.connect().unwrap();
        let store = LibsqlStore::new(conn.clone());
        store.migrate().await.unwrap();
        (db, conn, store)
    }

    /// Build a record holding `{"key": "value"}` that expires `ttl` from now
    /// (a negative `ttl` yields an already-expired record).
    fn sample_record(ttl: Duration) -> Record {
        let data: HashMap<String, Value> =
            HashMap::from([("key".to_string(), Value::from("value"))]);
        Record {
            id: Id::default(),
            data,
            expiry_date: OffsetDateTime::now_utc()
                .checked_add(ttl)
                .expect("Overflow making expiry"),
        }
    }

    #[tokio::test]
    async fn basic_roundtrip() {
        let (_db, conn, _store) = migrated_store().await;

        let query = r#"
            select * from tower_sessions limit 1
        "#;

        let row = conn.query(query, ()).await.unwrap().next().await.unwrap();

        assert!(row.is_none());
    }

    #[tokio::test]
    async fn create_with_conflict() {
        let (_db, _conn, store) = migrated_store().await;

        let mut session_record1 = sample_record(Duration::days(1));
        store
            .create(&mut session_record1)
            .await
            .expect("Error saving session");

        // Same id as the first record: `create` must pick a new one.
        let mut session_record2 = session_record1.clone();
        store
            .create(&mut session_record2)
            .await
            .expect("Error saving session");

        let loaded1 = store
            .load(&session_record1.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        let loaded2 = store
            .load(&session_record2.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        assert_eq!(
            loaded1.data, loaded2.data,
            "Session created with duplicate data"
        );
        assert_ne!(
            loaded1.id, loaded2.id,
            "Session conflict on id generates a new id"
        );
    }

    #[tokio::test]
    async fn save_and_load() {
        let (_db, _conn, store) = migrated_store().await;

        let session_record = sample_record(Duration::days(1));

        store
            .save(&session_record)
            .await
            .expect("Error saving session");

        let loaded = store
            .load(&session_record.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        assert_eq!(session_record, loaded, "Save and load match");
    }

    #[tokio::test]
    async fn save_and_delete() {
        let (_db, _conn, store) = migrated_store().await;

        let session_record = sample_record(Duration::days(1));

        store
            .save(&session_record)
            .await
            .expect("Error saving session");

        let loaded = store
            .load(&session_record.id)
            .await
            .expect("Error loading")
            .expect("Value missing");

        assert_eq!(session_record, loaded, "Save and load match");

        store
            .delete(&session_record.id)
            .await
            .expect("Error deleting session record");

        let loaded = store.load(&session_record.id).await.expect("Error loading");

        assert!(loaded.is_none())
    }

    #[tokio::test]
    async fn with_table_name_rejects_invalid_names() {
        let (_db, conn, _store) = migrated_store().await;

        assert!(LibsqlStore::new(conn.clone()).with_table_name("").is_err());
        assert!(LibsqlStore::new(conn.clone())
            .with_table_name("bad name")
            .is_err());
        assert!(LibsqlStore::new(conn.clone())
            .with_table_name("x; drop table y")
            .is_err());
        assert!(LibsqlStore::new(conn)
            .with_table_name("custom_sessions")
            .is_ok());
    }

    #[tokio::test]
    async fn delete_expired_removes_only_expired() {
        let (_db, conn, store) = migrated_store().await;

        let live = sample_record(Duration::days(1));
        let expired = sample_record(Duration::days(-1));
        store.save(&live).await.expect("Error saving session");
        store.save(&expired).await.expect("Error saving session");

        store
            .delete_expired()
            .await
            .expect("Error deleting expired sessions");

        let remaining = conn
            .query("select count(*) from tower_sessions", ())
            .await
            .unwrap()
            .next()
            .await
            .unwrap()
            .expect("count(*) returns a row")
            .get_value(0)
            .unwrap();
        assert_eq!(
            remaining,
            libsql::Value::Integer(1),
            "Only the live session remains"
        );

        assert!(store
            .load(&live.id)
            .await
            .expect("Error loading")
            .is_some());
    }
}