tower_sessions_deadpool_sqlite_store/
lib.rs1use async_trait::async_trait;
2use rusqlite::{params, OptionalExtension};
3use thiserror::Error;
4use tower_sessions::{cookie::time::OffsetDateTime, session::{Id, Record}, session_store, ExpiredDeletion, SessionStore};
5use deadpool_sqlite::{Object, Pool};
6
/// Table used by [`DeadpoolSqliteStore::new`] when no table name is given.
const DEFAULT_TABLE_NAME: &str = "__tower_sessions";

9#[derive(Debug, Error)]
10pub enum DeadpoolSqliteStoreError {
11 #[error("Deadpool interact error: {0}")]
12 DeadpoolInteract(#[from] deadpool_sqlite::InteractError),
13 #[error("Deadpool pool error: {0}")]
14 DeadpoolPool(#[from] deadpool_sqlite::PoolError),
15 #[error("Rusqlite error: {0}")]
16 Rusqlite(#[from] rusqlite::Error),
17 #[error("Serde json decode error: {0}")]
18 JsonDecode(serde_json::Error),
19 #[error("Serde json encode error: {0}")]
20 JsonEncode(serde_json::Error),
21}
22
23impl From<DeadpoolSqliteStoreError> for session_store::Error {
24 fn from (err: DeadpoolSqliteStoreError) -> Self {
25 use DeadpoolSqliteStoreError::*;
26 use session_store::Error;
27
28 match err {
29 JsonEncode(inner) => Error::Encode(inner.to_string()),
30 JsonDecode(inner) => Error::Decode(inner.to_string()),
31 other => Error::Backend(other.to_string()),
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
37pub struct DeadpoolSqliteStore {
38 pool: Pool,
39 table_name: String,
40}
41impl DeadpoolSqliteStore {
42 pub fn new(pool: Pool) -> Self {
43 Self::new_with_table_name(pool, DEFAULT_TABLE_NAME).unwrap()
44 }
45
46 pub fn new_with_table_name<T: Into<String>>(pool: Pool, table_name: T) -> Result<Self, String> {
47 let table_name = table_name.into();
48
49 if table_name.is_empty() || !table_name.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_') {
50 Err("Table name is not valid. Can only contain ascii alphanumeric, - and _".to_string())?;
51 }
52
53 Ok(Self {
54 pool,
55 table_name,
56 })
57 }
58
59 pub async fn get_conn(&self) -> Result<Object, session_store::Error> {
60 Ok(self.pool.get().await
61 .map_err(DeadpoolSqliteStoreError::from)?)
62 }
63
64 pub async fn migrate(&self) -> Result<(), session_store::Error> {
65 let conn = self.get_conn().await?;
66
67 let sql = format!(r#"
68 CREATE TABLE IF NOT EXISTS {} (
69 id TEXT PRIMARY KEY NOT NULL,
70 data BLOB NOT NULL,
71 expiry_date INTEGER NOT NULL
72 );"#,
73 self.table_name);
74
75 conn.interact(move |conn| {
76 conn.execute(&sql, ())
77 }).await
78 .map_err(DeadpoolSqliteStoreError::from)?
79 .map_err(DeadpoolSqliteStoreError::from)?;
80
81 Ok(())
82 }
83}
84
85#[async_trait]
86impl ExpiredDeletion for DeadpoolSqliteStore {
87 async fn delete_expired(&self) -> Result<(), session_store::Error> {
88 let sql = format!(r#"
89 DELETE FROM {}
90 WHERE expiry_date < ?1
91 "#, self.table_name);
92 let now = OffsetDateTime::now_utc().unix_timestamp();
93
94 let conn = self.get_conn().await?;
95
96 conn.interact(move |conn|
97 conn.execute(&sql, params![now]))
98 .await
99 .map_err(DeadpoolSqliteStoreError::from)?
100 .map_err(DeadpoolSqliteStoreError::from)?;
101
102 Ok(())
103 }
104}
105
106#[async_trait]
107impl SessionStore for DeadpoolSqliteStore {
108 async fn create(&self, record: &mut Record) -> Result<(), session_store::Error> {
109 let exists_sql = format!(r#"SELECT 1 FROM {} WHERE id = ?1"#, self.table_name);
110 let insert_sql = format!(r#"
111 INSERT INTO {} (id, data, expiry_date)
112 VALUES (?1, ?2, ?3);
113 "#, self.table_name);
114
115 let mut id = record.id.clone();
116 let payload = serde_json::to_vec(&record)
117 .map_err(|e| DeadpoolSqliteStoreError::JsonEncode(e))?;
118 let expiry = record.expiry_date.unix_timestamp();
119
120 let conn = self.get_conn().await?;
121 let id = conn.interact(move |conn| {
122 let tx = conn.transaction()?;
123
124 {
125 let mut exists_stmd = tx.prepare_cached(&exists_sql)?;
126
127 while exists_stmd.exists(params![id.to_string()])? {
129 id = Id::default();
130 }
131 }
132
133 {
134 let mut insert_stmt = tx.prepare_cached(&insert_sql)?;
135
136 insert_stmt.execute(params![
137 id.to_string(),
138 payload,
139 expiry,
140 ])?;
141 }
142
143 tx.commit()?;
144
145 Ok::<_, DeadpoolSqliteStoreError>(id)
146 })
147 .await
148 .map_err(DeadpoolSqliteStoreError::from)?
149 .map_err(DeadpoolSqliteStoreError::from)?;
150
151 record.id = id;
152
153 Ok(())
154 }
155
156 async fn save(&self, record: &Record) -> Result<(), session_store::Error> {
157 let update_sql = format!(r#"
158 UPDATE {} SET
159 data = ?1,
160 expiry_date = ?2
161 WHERE
162 id = ?3;
163 "#, self.table_name);
164
165 let conn = self.get_conn().await?;
166
167
168 let id = record.id.clone();
169 let payload = serde_json::to_vec(&record)
170 .map_err(|e| DeadpoolSqliteStoreError::JsonEncode(e))?;
171 let expiry = record.expiry_date.unix_timestamp();
172
173 conn.interact(move |conn| {
174 let mut update_stmt = conn.prepare_cached(&update_sql)?;
175
176 update_stmt.execute(params![
177 payload,
178 expiry,
179 id.to_string(),
180 ])?;
181
182 Ok::<_, DeadpoolSqliteStoreError>(())
183 })
184 .await
185 .map_err(DeadpoolSqliteStoreError::from)?
186 .map_err(DeadpoolSqliteStoreError::from)?;
187
188 Ok(())
189 }
190
191 async fn load(&self, id: &Id) -> Result<Option<Record>, session_store::Error> {
192 let select_sql = format!(r#"
193 SELECT data
194 FROM {}
195 WHERE
196 id = ?1
197 AND expiry_date > ?2;
198 "#, self.table_name);
199
200 let conn = self.get_conn().await?;
201 let id_string = id.to_string();
202 let payload = conn.interact(move |conn| {
203 let now = OffsetDateTime::now_utc().unix_timestamp();
204
205 let mut select_stmt = conn.prepare_cached(&select_sql)?;
206
207 let data = select_stmt.query_row(params![id_string, now], |row|
208 row.get::<_, Vec<u8>>(0))
209 .optional()?;
210
211 Ok::<_, DeadpoolSqliteStoreError>(data)
212 })
213 .await
214 .map_err(DeadpoolSqliteStoreError::from)?
215 .map_err(DeadpoolSqliteStoreError::from)?;
216
217 let record = payload
218 .map(|data| serde_json::from_slice::<Record>(&data))
219 .transpose()
220 .map_err(|e| DeadpoolSqliteStoreError::JsonDecode(e))?
221 .map(|mut record| {
222 record.id = id.to_owned();
224 record
225 });
226
227 Ok(record)
228 }
229
230 async fn delete(&self, id: &Id) -> Result<(), session_store::Error> {
231 let delete_sql = format!(r#"
232 DELETE FROM {}
233 WHERE id = ?1
234 "#, self.table_name);
235
236 let conn = self.get_conn().await?;
237 let id_string = id.to_string();
238 conn.interact(move |conn| {
239 let mut delete_stmt = conn.prepare_cached(&delete_sql)?;
240
241 delete_stmt.execute(params![id_string])?;
242
243 Ok::<_, DeadpoolSqliteStoreError>(())
244 })
245 .await
246 .map_err(DeadpoolSqliteStoreError::from)?
247 .map_err(DeadpoolSqliteStoreError::from)?;
248
249 Ok(())
250 }
251}
252