tower_sessions_postgres_store/
lib.rs1use async_trait::async_trait;
2use deadpool_postgres::{GenericClient, Pool};
3use time::OffsetDateTime;
4use tower_sessions_core::{
5 session::{Id, Record},
6 session_store, ExpiredDeletion, SessionStore,
7};
8
9#[derive(Debug, thiserror::Error)]
10#[error("Pg session store error: {0}")]
11pub enum Error {
12 Pool(
13 #[from]
14 #[source]
15 deadpool_postgres::PoolError,
16 ),
17 Pg(
18 #[from]
19 #[source]
20 tokio_postgres::Error,
21 ),
22 Encode(
23 #[from]
24 #[source]
25 rmp_serde::encode::Error,
26 ),
27 Decode(
28 #[from]
29 #[source]
30 rmp_serde::decode::Error,
31 ),
32}
33
34impl From<Error> for session_store::Error {
35 fn from(e: Error) -> Self {
36 Self::Backend(e.to_string())
37 }
38}
39
40#[derive(Clone, Debug)]
42pub struct PostgresStore {
43 pool: Pool,
44 schema_name: String,
45 table_name: String,
46}
47
48impl PostgresStore {
49 pub fn new(pool: Pool) -> Self {
51 Self {
52 pool,
53 schema_name: "tower_sessions".to_string(),
54 table_name: "session".to_string(),
55 }
56 }
57
58 pub fn with_schema_name(mut self, schema_name: impl AsRef<str>) -> Result<Self, String> {
60 let schema_name = schema_name.as_ref();
61 if !is_valid_identifier(schema_name) {
62 return Err(format!(
63 "Invalid schema name '{}'. Schema names must start with a letter or underscore \
64 (including letters with diacritical marks and non-Latin letters).Subsequent \
65 characters can be letters, underscores, digits (0-9), or dollar signs ($).",
66 schema_name
67 ));
68 }
69
70 schema_name.clone_into(&mut self.schema_name);
71 Ok(self)
72 }
73
74 pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
76 let table_name = table_name.as_ref();
77 if !is_valid_identifier(table_name) {
78 return Err(format!(
79 "Invalid table name '{}'. Table names must start with a letter or underscore \
80 (including letters with diacritical marks and non-Latin letters).Subsequent \
81 characters can be letters, underscores, digits (0-9), or dollar signs ($).",
82 table_name
83 ));
84 }
85
86 table_name.clone_into(&mut self.table_name);
87 Ok(self)
88 }
89
90 pub async fn migrate(&self) -> Result<(), Error> {
92 let mut client = self.pool.get().await?;
93 let tx = client.transaction().await?;
94
95 let create_schema_query = format!(
96 r#"create schema if not exists "{schema_name}""#,
97 schema_name = self.schema_name,
98 );
99
100 if let Err(err) = tx.execute(&create_schema_query, &[]).await {
104 use tokio_postgres::error::SqlState;
105 if matches!(
106 err.code(),
107 Some(&SqlState::DUPLICATE_SCHEMA | &SqlState::UNIQUE_VIOLATION)
108 ) {
109 return Ok(());
110 }
111
112 return Err(err.into());
113 }
114
115 let create_table_query = format!(
116 r#"
117 create table if not exists "{schema_name}"."{table_name}"
118 (
119 id text primary key not null,
120 data bytea not null,
121 expiry_date timestamptz not null
122 )
123 "#,
124 schema_name = self.schema_name,
125 table_name = self.table_name
126 );
127 tx.execute(&create_table_query, &[]).await?;
128
129 tx.commit().await?;
130
131 Ok(())
132 }
133
134 async fn id_exists(&self, conn: &impl GenericClient, id: &Id) -> Result<bool, Error> {
135 let query = format!(
136 r#"
137 select exists(select 1 from "{schema_name}"."{table_name}" where id = $1)
138 "#,
139 schema_name = self.schema_name,
140 table_name = self.table_name
141 );
142
143 Ok(conn.query_one(&query, &[&id.to_string()]).await?.get(0))
144 }
145
146 async fn save_with_conn(
147 &self,
148 conn: &impl GenericClient,
149 record: &Record,
150 ) -> Result<(), Error> {
151 let query = format!(
152 r#"
153 insert into "{schema_name}"."{table_name}" (id, data, expiry_date)
154 values ($1, $2, $3)
155 on conflict (id) do update
156 set
157 data = excluded.data,
158 expiry_date = excluded.expiry_date
159 "#,
160 schema_name = self.schema_name,
161 table_name = self.table_name
162 );
163 conn.execute(
164 &query,
165 &[
166 &record.id.to_string(),
167 &rmp_serde::to_vec(&record).map_err(Error::Encode)?,
168 &record.expiry_date,
169 ],
170 )
171 .await?;
172
173 Ok(())
174 }
175}
176
177#[async_trait]
178impl ExpiredDeletion for PostgresStore {
179 async fn delete_expired(&self) -> session_store::Result<()> {
180 let query = format!(
181 r#"
182 delete from "{schema_name}"."{table_name}"
183 where expiry_date < (now() at time zone 'utc')
184 "#,
185 schema_name = self.schema_name,
186 table_name = self.table_name
187 );
188 let client = self.pool.get().await.map_err(Error::Pool)?;
189 client.execute(&query, &[]).await.map_err(Error::Pg)?;
190 Ok(())
191 }
192}
193
194#[async_trait]
195impl SessionStore for PostgresStore {
196 async fn create(&self, record: &mut Record) -> session_store::Result<()> {
197 let mut client = self.pool.get().await.map_err(Error::Pool)?;
198 let tx = client.transaction().await.map_err(Error::Pg)?;
199
200 while self.id_exists(&tx, &record.id).await? {
201 record.id = Id::default();
202 }
203
204 self.save_with_conn(&tx, record).await?;
205 tx.commit().await.map_err(Error::Pg)?;
206 Ok(())
207 }
208
209 async fn save(&self, record: &Record) -> session_store::Result<()> {
210 let mut client = self.pool.get().await.map_err(Error::Pool)?;
211 let tx = client.transaction().await.map_err(Error::Pg)?;
212 self.save_with_conn(&tx, record).await?;
213 tx.commit().await.map_err(Error::Pg)?;
214 Ok(())
215 }
216
217 async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
218 let query = format!(
219 r#"
220 select data from "{schema_name}"."{table_name}"
221 where id = $1 and expiry_date > $2
222 "#,
223 schema_name = self.schema_name,
224 table_name = self.table_name
225 );
226 let client = self.pool.get().await.map_err(Error::Pool)?;
227 let record_value: Option<Vec<u8>> = client
228 .query_opt(
229 &query,
230 &[&session_id.to_string(), &OffsetDateTime::now_utc()],
231 )
232 .await
233 .map_err(Error::Pg)?
234 .map(|row| row.get(0));
235
236 if let Some(data) = record_value {
237 Ok(Some(rmp_serde::from_slice(&data).map_err(Error::Decode)?))
238 } else {
239 Ok(None)
240 }
241 }
242
243 async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
244 let query = format!(
245 r#"delete from "{schema_name}"."{table_name}" where id = $1"#,
246 schema_name = self.schema_name,
247 table_name = self.table_name
248 );
249 let client = self.pool.get().await.map_err(Error::Pool)?;
250 client
251 .execute(&query, &[&session_id.to_string()])
252 .await
253 .map_err(Error::Pg)?;
254
255 Ok(())
256 }
257}
258
/// Checks whether `name` may be used as a schema or table name by this store.
///
/// A name is accepted when it is non-empty, its first character is alphabetic
/// or `_`, and every remaining character is alphanumeric, `_` or `$`. The
/// character classes are those of `char::is_alphabetic` and
/// `char::is_alphanumeric`, so non-ASCII letters are accepted.
fn is_valid_identifier(name: &str) -> bool {
    let mut rest = name.chars();

    match rest.next() {
        // A valid first character; the iterator now yields only the tail.
        Some(first) if first.is_alphabetic() || first == '_' => {
            rest.all(|c| c.is_alphanumeric() || matches!(c, '_' | '$'))
        }
        // Empty input or a disallowed first character.
        _ => false,
    }
}