tower_sessions_rusqlite_store/
lib.rs1use async_trait::async_trait;
2use rusqlite::OptionalExtension;
3use std::error::Error;
4use time::OffsetDateTime;
5pub use tokio_rusqlite;
6use tokio_rusqlite::{params, Connection, Result as SqlResult};
7use tower_sessions_core::{
8 session::{Id, Record},
9 session_store, ExpiredDeletion, SessionStore,
10};
11
12#[derive(thiserror::Error, Debug)]
14pub enum RusqliteStoreError {
15 #[error(transparent)]
17 TokioRusqlite(#[from] tokio_rusqlite::Error),
18
19 #[error(transparent)]
21 Encode(#[from] rmp_serde::encode::Error),
22
23 #[error(transparent)]
25 Decode(#[from] rmp_serde::decode::Error),
26
27 #[error("Backend error: {0}")]
29 Other(String),
30}
31
32impl From<RusqliteStoreError> for session_store::Error {
33 fn from(err: RusqliteStoreError) -> Self {
34 match err {
35 RusqliteStoreError::TokioRusqlite(inner) => {
36 session_store::Error::Backend(inner.to_string())
37 }
38 RusqliteStoreError::Decode(inner) => session_store::Error::Decode(inner.to_string()),
39 RusqliteStoreError::Encode(inner) => session_store::Error::Encode(inner.to_string()),
40 RusqliteStoreError::Other(inner) => session_store::Error::Backend(inner),
41 }
42 }
43}
44
45#[derive(Clone, Debug)]
46pub struct RusqliteStore {
47 conn: Connection,
48 table_name: String,
49}
50
51impl RusqliteStore {
52 pub fn new(conn: Connection) -> Self {
65 Self {
66 conn,
67 table_name: "tower_sessions".into(),
68 }
69 }
70
71 pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
73 let table_name = table_name.as_ref();
74 if !is_valid_table_name(table_name) {
75 return Err(format!(
76 "Invalid table name '{}'. Table names must be alphanumeric and may contain \
77 hyphens or underscores.",
78 table_name
79 ));
80 }
81
82 self.table_name = table_name.to_owned();
83 Ok(self)
84 }
85
86 pub async fn migrate(&self) -> SqlResult<()> {
88 let conn = self.conn.clone();
89 let query = format!(
90 r#"
91 create table if not exists {}
92 (
93 id text primary key not null,
94 data blob not null,
95 expiry_date integer not null
96 )
97 "#,
98 self.table_name
99 );
100 conn.call(move |conn| conn.execute(&query, [])).await?;
101
102 Ok(())
103 }
104}
105
106fn id_exists_with_conn(
107 conn: &rusqlite::Connection,
108 table_name: &str,
109 id: &Id,
110) -> rusqlite::Result<bool> {
111 let query = format!(
112 r#"
113 select exists(select 1 from {} where id = ?1)
114 "#,
115 table_name
116 );
117 let mut stmt = conn.prepare(&query)?;
118 stmt.query_row(params![id.to_string()], |row| row.get(0))
119}
120
121fn save_with_conn(
122 conn: &rusqlite::Connection,
123 table_name: &str,
124 record: &Record,
125 record_data: &[u8],
126) -> rusqlite::Result<usize> {
127 let query = format!(
128 r#"
129 insert into {}
130 (id, data, expiry_date) values (?1, ?2, ?3)
131 on conflict(id) do update set
132 data = excluded.data,
133 expiry_date = excluded.expiry_date
134 "#,
135 table_name
136 );
137 conn.execute(
138 &query,
139 params![
140 record.id.to_string(),
141 record_data,
142 record.expiry_date.unix_timestamp()
143 ],
144 )
145}
146
147#[async_trait]
148impl ExpiredDeletion for RusqliteStore {
149 async fn delete_expired(&self) -> session_store::Result<()> {
150 let conn = self.conn.clone();
151 let query = format!(
152 r#"
153 delete from {table_name}
154 where expiry_date < ?1
155 "#,
156 table_name = self.table_name
157 );
158 conn.call(move |conn| conn.execute(&query, [OffsetDateTime::now_utc().unix_timestamp()]))
159 .await
160 .map_err(|e| {
161 eprintln!("Error deleting expired sessions: {:?}", e);
164 RusqliteStoreError::TokioRusqlite(e)
165 })?;
166
167 Ok(())
168 }
169}
170
171#[async_trait]
172impl SessionStore for RusqliteStore {
173 async fn create(&self, record: &mut Record) -> session_store::Result<()> {
174 let conn = self.conn.clone();
175
176 let new_id = conn
177 .call({
178 let mut record = record.clone();
179 let table_name = self.table_name.clone();
180
181 move |conn| {
182 let tx = conn.transaction()?;
183
184 while id_exists_with_conn(&tx, &table_name, &record.id)? {
185 record.id = Id::default();
186 }
187
188 let record_data = rmp_serde::to_vec(&record).map_err(Box::new)?;
189
190 save_with_conn(&tx, &table_name, &record, &record_data)?;
191
192 tx.commit()?;
193
194 Ok(record.id)
195 }
196 })
197 .await
198 .map_err(
199 |e: tokio_rusqlite::Error<Box<dyn Error + Send + Sync>>| match e {
200 tokio_rusqlite::Error::Error(boxed_err) => {
201 match boxed_err.downcast::<rmp_serde::encode::Error>() {
202 Ok(encode_error) => RusqliteStoreError::Encode(*encode_error),
203 Err(original_box) => {
204 RusqliteStoreError::Other(original_box.to_string())
205 }
206 }
207 }
208 other => RusqliteStoreError::Other(other.to_string()),
209 },
210 )?;
211
212 record.id = new_id;
213
214 Ok(())
215 }
216
217 async fn save(&self, record: &Record) -> session_store::Result<()> {
218 let conn = self.conn.clone();
219 let table_name = self.table_name.clone();
220 let record = record.clone();
221 let record_data = rmp_serde::to_vec(&record).map_err(RusqliteStoreError::Encode)?;
222
223 conn.call(move |conn| save_with_conn(conn, &table_name, &record, &record_data))
224 .await
225 .map_err(RusqliteStoreError::TokioRusqlite)?;
226
227 Ok(())
228 }
229
230 async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
231 let conn = self.conn.clone();
232
233 let data = conn
234 .call({
235 let table_name = self.table_name.clone();
236 let session_id = session_id.to_string();
237 move |conn| {
238 let query = format!(
239 r#"
240 select data from {}
241 where id = ?1 and expiry_date > ?2
242 "#,
243 table_name
244 );
245 let mut stmt = conn.prepare(&query)?;
246 stmt.query_row(
247 params![session_id, OffsetDateTime::now_utc().unix_timestamp()],
248 |row| {
249 let data: Vec<u8> = row.get(0)?;
250 Ok(data)
251 },
252 )
253 .optional()
254 }
255 })
256 .await
257 .map_err(RusqliteStoreError::TokioRusqlite)?;
258
259 match data {
260 Some(data) => {
261 let record: Record =
262 rmp_serde::from_slice(&data).map_err(RusqliteStoreError::Decode)?;
263 Ok(Some(record))
264 }
265 None => Ok(None),
266 }
267 }
268
269 async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
270 let conn = self.conn.clone();
271
272 conn.call({
273 let table_name = self.table_name.clone();
274 let session_id = session_id.to_string();
275 move |conn| {
276 let query = format!(
277 r#"
278 delete from {} where id = ?1
279 "#,
280 table_name
281 );
282 conn.execute(&query, params![session_id])
283 }
284 })
285 .await
286 .map_err(RusqliteStoreError::TokioRusqlite)?;
287
288 Ok(())
289 }
290}
291
/// Returns `true` when `name` is non-empty and consists solely of ASCII
/// letters, ASCII digits, hyphens and underscores.
fn is_valid_table_name(name: &str) -> bool {
    if name.is_empty() {
        return false;
    }

    // Every byte of a multi-byte UTF-8 character is >= 0x80, so checking bytes
    // rejects non-ASCII characters just like a per-`char` check would.
    name.bytes()
        .all(|byte| matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_'))
}
298
#[cfg(test)]
mod rusqlite_store_tests {
    use time::Duration;

    use super::*;

    /// Builds a migrated store backed by a fresh in-memory database.
    async fn create_store() -> RusqliteStore {
        let conn = Connection::open_in_memory().await.unwrap();
        let store = RusqliteStore::new(conn);
        store.migrate().await.unwrap();
        store
    }

    /// Builds a record with a default id, default data, and an expiry that is
    /// `offset` away from the current UTC time.
    fn record_expiring_in(offset: Duration) -> Record {
        Record {
            id: Default::default(),
            data: Default::default(),
            expiry_date: OffsetDateTime::now_utc() + offset,
        }
    }

    /// Counts the rows stored in the default session table, bypassing the
    /// expiry filter that `load` applies.
    async fn row_count(store: &RusqliteStore) -> i64 {
        store
            .conn
            .call(|conn| {
                conn.query_row("select count(*) from tower_sessions", [], |row| {
                    row.get::<_, i64>(0)
                })
            })
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_create() {
        let store = create_store().await;
        let mut record = record_expiring_in(Duration::minutes(30));
        assert!(store.create(&mut record).await.is_ok());
    }

    #[tokio::test]
    async fn test_save() {
        let store = create_store().await;
        let record = record_expiring_in(Duration::minutes(30));
        assert!(store.save(&record).await.is_ok());

        // A saved record must be readable again.
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_load() {
        let store = create_store().await;
        let mut record = record_expiring_in(Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }

    #[tokio::test]
    async fn test_delete() {
        let store = create_store().await;
        let mut record = record_expiring_in(Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        assert!(store.delete(&record.id).await.is_ok());
        assert_eq!(None, store.load(&record.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_create_id_collision() {
        let store = create_store().await;
        let mut record1 = record_expiring_in(Duration::minutes(30));
        let mut record2 = record_expiring_in(Duration::minutes(30));
        store.create(&mut record1).await.unwrap();

        // Force a collision: `create` has to pick a different id.
        record2.id = record1.id;
        store.create(&mut record2).await.unwrap();
        assert_ne!(record1.id, record2.id);
    }

    #[tokio::test]
    async fn test_delete_expired() {
        let store = create_store().await;
        let mut record = record_expiring_in(-Duration::minutes(30));
        store.create(&mut record).await.unwrap();

        // `load` hides expired rows by itself, so check the table directly to
        // make sure the row is really removed.
        assert_eq!(1, row_count(&store).await);
        store.delete_expired().await.unwrap();
        assert_eq!(0, row_count(&store).await);
        assert_eq!(None, store.load(&record.id).await.unwrap());
    }

    #[tokio::test]
    async fn test_with_table_name_validation() {
        let conn = Connection::open_in_memory().await.unwrap();
        let store = RusqliteStore::new(conn);

        assert!(store.clone().with_table_name("").is_err());
        assert!(store
            .clone()
            .with_table_name("sessions; drop table users")
            .is_err());

        let store = store.with_table_name("custom_sessions").unwrap();
        store.migrate().await.unwrap();

        let mut record = record_expiring_in(Duration::minutes(30));
        store.create(&mut record).await.unwrap();
        let loaded_record = store.load(&record.id).await.unwrap();
        assert_eq!(Some(record), loaded_record);
    }
}