tower_sessions_deadpool_sqlite_store/
lib.rs

1use async_trait::async_trait;
2use rusqlite::{params, OptionalExtension};
3use thiserror::Error;
4use tower_sessions::{cookie::time::OffsetDateTime, session::{Id, Record}, session_store, ExpiredDeletion, SessionStore};
5use deadpool_sqlite::{Object, Pool};
6
/// Name of the session table used by `DeadpoolSqliteStore::new`, i.e. when
/// the caller does not pick a table name explicitly.
const DEFAULT_TABLE_NAME: &str = "__tower_sessions";
8
9#[derive(Debug, Error)]
10pub enum DeadpoolSqliteStoreError {
11    #[error("Deadpool interact error: {0}")]
12    DeadpoolInteract(#[from] deadpool_sqlite::InteractError),
13    #[error("Deadpool pool error: {0}")]
14    DeadpoolPool(#[from] deadpool_sqlite::PoolError),
15    #[error("Rusqlite error: {0}")]
16    Rusqlite(#[from] rusqlite::Error),
17    #[error("Serde json decode error: {0}")]
18    JsonDecode(serde_json::Error),
19    #[error("Serde json encode error: {0}")]
20    JsonEncode(serde_json::Error),
21}
22
23impl From<DeadpoolSqliteStoreError> for session_store::Error {
24    fn from (err: DeadpoolSqliteStoreError) -> Self {
25        use DeadpoolSqliteStoreError::*;
26        use session_store::Error;
27
28        match err {
29            JsonEncode(inner) => Error::Encode(inner.to_string()),
30            JsonDecode(inner) => Error::Decode(inner.to_string()),
31            other => Error::Backend(other.to_string()),
32        }
33    }
34}
35
36#[derive(Debug, Clone)]
37pub struct DeadpoolSqliteStore {
38    pool: Pool,
39    table_name: String,
40}
41impl DeadpoolSqliteStore {
42    pub fn new(pool: Pool) -> Self {
43        Self::new_with_table_name(pool, DEFAULT_TABLE_NAME).unwrap()
44    }
45
46    pub fn new_with_table_name<T: Into<String>>(pool: Pool, table_name: T) -> Result<Self, String> {
47        let table_name = table_name.into();
48
49        if table_name.is_empty() || !table_name.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') {
50            Err("Table name is not valid. Can only contain ascii alphanumeric, - and _".to_string())?;
51        }
52        
53        Ok(Self {
54            pool,
55            table_name,
56        })
57    }
58
59    pub async fn get_conn(&self) -> Result<Object, session_store::Error> {
60        Ok(self.pool.get().await
61            .map_err(DeadpoolSqliteStoreError::from)?)
62    }
63
64    pub async fn migrate(&self) -> Result<(), session_store::Error> {
65        let conn = self.get_conn().await?;
66
67        let sql = format!(r#"
68            CREATE TABLE IF NOT EXISTS {} (
69                id TEXT PRIMARY KEY NOT NULL,
70                data BLOB NOT NULL,
71                expiry_date INTEGER NOT NULL
72            );"#,
73            self.table_name);
74
75        conn.interact(move |conn| {
76            conn.execute(&sql, ())
77        }).await
78        .map_err(DeadpoolSqliteStoreError::from)?
79        .map_err(DeadpoolSqliteStoreError::from)?;
80        
81        Ok(())
82    }
83}
84
85#[async_trait]
86impl ExpiredDeletion for DeadpoolSqliteStore {
87    async fn delete_expired(&self) -> Result<(), session_store::Error> {
88        let sql = format!(r#"
89            DELETE FROM {}
90            WHERE expiry_date < ?1 
91        "#, self.table_name);
92        let now = OffsetDateTime::now_utc().unix_timestamp();
93        
94        let conn = self.get_conn().await?;
95        
96        conn.interact(move |conn| 
97            conn.execute(&sql, params![now]))
98            .await
99            .map_err(DeadpoolSqliteStoreError::from)?
100            .map_err(DeadpoolSqliteStoreError::from)?;
101        
102        Ok(())
103    }
104}
105
106#[async_trait]
107impl SessionStore for DeadpoolSqliteStore {
108    async fn create(&self, record: &mut Record) -> Result<(), session_store::Error> {
109        let exists_sql = format!(r#"SELECT 1 FROM {} WHERE id = ?1"#, self.table_name);
110        let insert_sql = format!(r#"
111            INSERT INTO {} (id, data, expiry_date)
112            VALUES (?1, ?2, ?3);
113        "#, self.table_name);
114
115        let mut id = record.id.clone();
116        let payload = serde_json::to_vec(&record)
117            .map_err(|e| DeadpoolSqliteStoreError::JsonEncode(e))?;
118        let expiry = record.expiry_date.unix_timestamp();
119    
120        let conn = self.get_conn().await?;
121        let id = conn.interact(move |conn| {
122            let tx = conn.transaction()?;
123
124            {
125                let mut exists_stmd = tx.prepare_cached(&exists_sql)?;
126
127                // Re-key the record until we successfully find a unique ID
128                while exists_stmd.exists(params![id.to_string()])? {
129                    id = Id::default();
130                }
131            }
132
133            {
134                let mut insert_stmt = tx.prepare_cached(&insert_sql)?;
135
136                insert_stmt.execute(params![
137                    id.to_string(),
138                    payload,
139                    expiry,
140                ])?;
141            }
142
143            tx.commit()?;
144
145            Ok::<_, DeadpoolSqliteStoreError>(id)
146        })
147        .await
148        .map_err(DeadpoolSqliteStoreError::from)?
149        .map_err(DeadpoolSqliteStoreError::from)?;
150
151        record.id = id;
152
153        Ok(())
154    }
155
156    async fn save(&self, record: &Record) -> Result<(), session_store::Error> {
157        let update_sql = format!(r#"
158            UPDATE {} SET
159                data = ?1, 
160                expiry_date = ?2
161            WHERE
162                id = ?3;
163        "#, self.table_name);
164
165        let conn = self.get_conn().await?;
166        
167
168        let id = record.id.clone();
169        let payload = serde_json::to_vec(&record)
170                .map_err(|e| DeadpoolSqliteStoreError::JsonEncode(e))?;
171        let expiry = record.expiry_date.unix_timestamp();
172
173        conn.interact(move |conn| {
174            let mut update_stmt = conn.prepare_cached(&update_sql)?;
175
176            update_stmt.execute(params![
177                payload,
178                expiry,
179                id.to_string(),
180            ])?;
181
182            Ok::<_, DeadpoolSqliteStoreError>(())
183        })
184        .await
185        .map_err(DeadpoolSqliteStoreError::from)?
186        .map_err(DeadpoolSqliteStoreError::from)?;
187
188        Ok(())
189    }
190
191    async fn load(&self, id: &Id) -> Result<Option<Record>, session_store::Error> {
192        let select_sql = format!(r#"
193            SELECT data 
194            FROM {} 
195            WHERE 
196                id = ?1
197                AND expiry_date > ?2;
198        "#, self.table_name);
199
200        let conn = self.get_conn().await?;
201        let id_string = id.to_string();
202        let payload = conn.interact(move |conn| {
203            let now = OffsetDateTime::now_utc().unix_timestamp();
204
205            let mut select_stmt = conn.prepare_cached(&select_sql)?;
206
207            let data = select_stmt.query_row(params![id_string, now], |row| 
208                row.get::<_, Vec<u8>>(0))
209                .optional()?;
210
211            Ok::<_, DeadpoolSqliteStoreError>(data)
212        })
213        .await
214        .map_err(DeadpoolSqliteStoreError::from)?
215        .map_err(DeadpoolSqliteStoreError::from)?;
216
217        let record = payload
218            .map(|data| serde_json::from_slice::<Record>(&data))
219            .transpose()
220            .map_err(|e| DeadpoolSqliteStoreError::JsonDecode(e))?
221            .map(|mut record| {
222                // Make sure the id is updated after the re-keying done during insert
223                record.id = id.to_owned();
224                record
225            });
226
227        Ok(record)
228    }
229
230    async fn delete(&self, id: &Id) -> Result<(), session_store::Error> {
231        let delete_sql = format!(r#"
232            DELETE FROM {}
233            WHERE id = ?1
234        "#, self.table_name);
235
236        let conn = self.get_conn().await?;
237        let id_string = id.to_string();
238        conn.interact(move |conn| {
239            let mut delete_stmt = conn.prepare_cached(&delete_sql)?;
240
241            delete_stmt.execute(params![id_string])?;
242
243            Ok::<_, DeadpoolSqliteStoreError>(())
244        })
245        .await
246        .map_err(DeadpoolSqliteStoreError::from)?
247        .map_err(DeadpoolSqliteStoreError::from)?;
248
249        Ok(())
250    }
251}
252