tower_sessions_postgres_store/
lib.rs

1use async_trait::async_trait;
2use deadpool_postgres::{GenericClient, Pool};
3use time::OffsetDateTime;
4use tower_sessions_core::{
5    session::{Id, Record},
6    session_store, ExpiredDeletion, SessionStore,
7};
8
9#[derive(Debug, thiserror::Error)]
10#[error("Pg session store error: {0}")]
11pub enum Error {
12    Pool(
13        #[from]
14        #[source]
15        deadpool_postgres::PoolError,
16    ),
17    Pg(
18        #[from]
19        #[source]
20        tokio_postgres::Error,
21    ),
22    Encode(
23        #[from]
24        #[source]
25        rmp_serde::encode::Error,
26    ),
27    Decode(
28        #[from]
29        #[source]
30        rmp_serde::decode::Error,
31    ),
32}
33
34impl From<Error> for session_store::Error {
35    fn from(e: Error) -> Self {
36        Self::Backend(e.to_string())
37    }
38}
39
40/// A PostgreSQL session store.
41#[derive(Clone, Debug)]
42pub struct PostgresStore {
43    pool: Pool,
44    schema_name: String,
45    table_name: String,
46}
47
48impl PostgresStore {
49    /// Create a new PostgreSQL store with the provided connection pool.
50    pub fn new(pool: Pool) -> Self {
51        Self {
52            pool,
53            schema_name: "tower_sessions".to_string(),
54            table_name: "session".to_string(),
55        }
56    }
57
58    /// Set the session table schema name with the provided name.
59    pub fn with_schema_name(mut self, schema_name: impl AsRef<str>) -> Result<Self, String> {
60        let schema_name = schema_name.as_ref();
61        if !is_valid_identifier(schema_name) {
62            return Err(format!(
63                "Invalid schema name '{}'. Schema names must start with a letter or underscore \
64                 (including letters with diacritical marks and non-Latin letters).Subsequent \
65                 characters can be letters, underscores, digits (0-9), or dollar signs ($).",
66                schema_name
67            ));
68        }
69
70        schema_name.clone_into(&mut self.schema_name);
71        Ok(self)
72    }
73
74    /// Set the session table name with the provided name.
75    pub fn with_table_name(mut self, table_name: impl AsRef<str>) -> Result<Self, String> {
76        let table_name = table_name.as_ref();
77        if !is_valid_identifier(table_name) {
78            return Err(format!(
79                "Invalid table name '{}'. Table names must start with a letter or underscore \
80                 (including letters with diacritical marks and non-Latin letters).Subsequent \
81                 characters can be letters, underscores, digits (0-9), or dollar signs ($).",
82                table_name
83            ));
84        }
85
86        table_name.clone_into(&mut self.table_name);
87        Ok(self)
88    }
89
90    /// Migrate the session schema.
91    pub async fn migrate(&self) -> Result<(), Error> {
92        let mut client = self.pool.get().await?;
93        let tx = client.transaction().await?;
94
95        let create_schema_query = format!(
96            r#"create schema if not exists "{schema_name}""#,
97            schema_name = self.schema_name,
98        );
99
100        // Concurrent create schema may fail due to duplicate key violations.
101        //
102        // This works around that by assuming the schema must exist on such an error.
103        if let Err(err) = tx.execute(&create_schema_query, &[]).await {
104            use tokio_postgres::error::SqlState;
105            if matches!(
106                err.code(),
107                Some(&SqlState::DUPLICATE_SCHEMA | &SqlState::UNIQUE_VIOLATION)
108            ) {
109                return Ok(());
110            }
111
112            return Err(err.into());
113        }
114
115        let create_table_query = format!(
116            r#"
117            create table if not exists "{schema_name}"."{table_name}"
118            (
119                id text primary key not null,
120                data bytea not null,
121                expiry_date timestamptz not null
122            )
123            "#,
124            schema_name = self.schema_name,
125            table_name = self.table_name
126        );
127        tx.execute(&create_table_query, &[]).await?;
128
129        tx.commit().await?;
130
131        Ok(())
132    }
133
134    async fn id_exists(&self, conn: &impl GenericClient, id: &Id) -> Result<bool, Error> {
135        let query = format!(
136            r#"
137            select exists(select 1 from "{schema_name}"."{table_name}" where id = $1)
138            "#,
139            schema_name = self.schema_name,
140            table_name = self.table_name
141        );
142
143        Ok(conn.query_one(&query, &[&id.to_string()]).await?.get(0))
144    }
145
146    async fn save_with_conn(
147        &self,
148        conn: &impl GenericClient,
149        record: &Record,
150    ) -> Result<(), Error> {
151        let query = format!(
152            r#"
153            insert into "{schema_name}"."{table_name}" (id, data, expiry_date)
154            values ($1, $2, $3)
155            on conflict (id) do update
156            set
157              data = excluded.data,
158              expiry_date = excluded.expiry_date
159            "#,
160            schema_name = self.schema_name,
161            table_name = self.table_name
162        );
163        conn.execute(
164            &query,
165            &[
166                &record.id.to_string(),
167                &rmp_serde::to_vec(&record).map_err(Error::Encode)?,
168                &record.expiry_date,
169            ],
170        )
171        .await?;
172
173        Ok(())
174    }
175}
176
177#[async_trait]
178impl ExpiredDeletion for PostgresStore {
179    async fn delete_expired(&self) -> session_store::Result<()> {
180        let query = format!(
181            r#"
182            delete from "{schema_name}"."{table_name}"
183            where expiry_date < (now() at time zone 'utc')
184            "#,
185            schema_name = self.schema_name,
186            table_name = self.table_name
187        );
188        let client = self.pool.get().await.map_err(Error::Pool)?;
189        client.execute(&query, &[]).await.map_err(Error::Pg)?;
190        Ok(())
191    }
192}
193
194#[async_trait]
195impl SessionStore for PostgresStore {
196    async fn create(&self, record: &mut Record) -> session_store::Result<()> {
197        let mut client = self.pool.get().await.map_err(Error::Pool)?;
198        let tx = client.transaction().await.map_err(Error::Pg)?;
199
200        while self.id_exists(&tx, &record.id).await? {
201            record.id = Id::default();
202        }
203
204        self.save_with_conn(&tx, record).await?;
205        tx.commit().await.map_err(Error::Pg)?;
206        Ok(())
207    }
208
209    async fn save(&self, record: &Record) -> session_store::Result<()> {
210        let mut client = self.pool.get().await.map_err(Error::Pool)?;
211        let tx = client.transaction().await.map_err(Error::Pg)?;
212        self.save_with_conn(&tx, record).await?;
213        tx.commit().await.map_err(Error::Pg)?;
214        Ok(())
215    }
216
217    async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
218        let query = format!(
219            r#"
220            select data from "{schema_name}"."{table_name}"
221            where id = $1 and expiry_date > $2
222            "#,
223            schema_name = self.schema_name,
224            table_name = self.table_name
225        );
226        let client = self.pool.get().await.map_err(Error::Pool)?;
227        let record_value: Option<Vec<u8>> = client
228            .query_opt(
229                &query,
230                &[&session_id.to_string(), &OffsetDateTime::now_utc()],
231            )
232            .await
233            .map_err(Error::Pg)?
234            .map(|row| row.get(0));
235
236        if let Some(data) = record_value {
237            Ok(Some(rmp_serde::from_slice(&data).map_err(Error::Decode)?))
238        } else {
239            Ok(None)
240        }
241    }
242
243    async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
244        let query = format!(
245            r#"delete from "{schema_name}"."{table_name}" where id = $1"#,
246            schema_name = self.schema_name,
247            table_name = self.table_name
248        );
249        let client = self.pool.get().await.map_err(Error::Pool)?;
250        client
251            .execute(&query, &[&session_id.to_string()])
252            .await
253            .map_err(Error::Pg)?;
254
255        Ok(())
256    }
257}
258
/// Returns `true` if `name` is acceptable as a schema or table name for this store.
///
/// The first character must be alphabetic (any Unicode letter, so letters with diacritical
/// marks and non-Latin letters qualify) or an underscore. Every following character must be
/// alphanumeric (Unicode-aware, so this is broader than `0-9` for digits), an underscore, or a
/// dollar sign (`$`). The empty string is rejected. See
/// <https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS>
/// for PostgreSQL's identifier rules.
fn is_valid_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        // A valid first character also satisfies the "subsequent character" rule, so only the
        // remaining characters need checking.
        Some(first) if first.is_alphabetic() || first == '_' => {
            chars.all(|c| c.is_alphanumeric() || c == '_' || c == '$')
        }
        _ => false,
    }
}