tower_sessions_rorm_store/
lib.rs

1//! # tower-sessions-rorm-store
2//!
3//! Implementation of [SessionStore] provided by [tower_sessions] for [rorm].
4//!
5//! In order to provide the possibility to use a user-defined [Model], this crate
6//! defines [SessionModel] which must be implemented to create a [RormStore].
7//!
8//! Look at our example crate for the usage.
9
10#![warn(missing_docs)]
11
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::fmt::Formatter;
15use std::marker::PhantomData;
16use std::str::FromStr;
17
18use async_trait::async_trait;
19use rorm::and;
20use rorm::delete;
21use rorm::fields::types::Json;
22use rorm::insert;
23use rorm::internal::field::Field;
24use rorm::internal::field::FieldProxy;
25use rorm::query;
26use rorm::update;
27use rorm::Database;
28use rorm::FieldAccess;
29use rorm::Model;
30use rorm::Patch;
31pub use serde_json::Value;
32use thiserror::Error;
33/// Export tower sessions
34pub use tower_sessions;
35use tower_sessions::cookie::time::OffsetDateTime;
36use tower_sessions::session::Id;
37use tower_sessions::session::Record;
38use tower_sessions::session_store::Error;
39use tower_sessions::session_store::Result;
40use tower_sessions::ExpiredDeletion;
41use tower_sessions::SessionStore;
42use tracing::debug;
43use tracing::instrument;
44
45/// Implement this trait on a [Model] that should be used
46/// to store sessions in a database.
47///
48/// Ths [Model] must not define relations that make it impossible to
49/// delete it using the [FieldProxy] retrieved by [SessionModel::get_expires_at_field].
50pub trait SessionModel
51where
52    Self: Model + Send + Sync + 'static,
53    Self::Primary: Field<Type = String>,
54{
55    /// Retrieve the primary field from the Model
56    fn get_primary_field() -> FieldProxy<Self::Primary, Self> {
57        FieldProxy::new()
58    }
59
60    /// Retrieve the expires_at field from the Model
61    fn get_expires_at_field() -> FieldProxy<impl Field<Type = OffsetDateTime, Model = Self>, Self>;
62
63    /// Retrieve the data field of the Model
64    fn get_data_field(
65    ) -> FieldProxy<impl Field<Type = Json<HashMap<String, Value>>, Model = Self>, Self>;
66
67    /// Retrieve an insert patch that should use the parameters
68    /// provided for construction
69    fn get_insert_patch(
70        id: String,
71        expires_at: OffsetDateTime,
72        data: Json<HashMap<String, Value>>,
73    ) -> impl Patch<Model = Self> + Send + Sync + 'static;
74
75    /// Get the session data from an instance of the session [Model]
76    fn get_session_data(&self) -> (String, OffsetDateTime, Json<HashMap<String, Value>>);
77}
78
/// The session store for rorm
///
/// Wraps the [Database] all session queries are executed on. The type
/// parameter `S` names the session model; it is only tracked through
/// [PhantomData], so no value of `S` is ever stored.
pub struct RormStore<S> {
    /// Database handle the store runs its queries against
    db: Database,
    /// Marker tying the store to the model type `S` without holding an instance
    marker: PhantomData<S>,
}
84
85impl<S> RormStore<S> {
86    /// Construct a new Store
87    pub fn new(db: Database) -> Self {
88        Self {
89            db,
90            marker: PhantomData,
91        }
92    }
93}
94
95impl<S> Debug for RormStore<S> {
96    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
97        write!(f, "")
98    }
99}
100
101impl<S> Clone for RormStore<S> {
102    fn clone(&self) -> Self {
103        Self {
104            db: self.db.clone(),
105            marker: PhantomData,
106        }
107    }
108}
109
110#[async_trait]
111impl<S> ExpiredDeletion for RormStore<S>
112where
113    S: Model + SessionModel + Debug,
114    <S as Model>::Primary: Field<Type = String>,
115    <S as Patch>::Decoder: Send + Sync + 'static,
116{
117    #[instrument(level = "trace")]
118    async fn delete_expired(&self) -> Result<()> {
119        let db = &self.db;
120
121        delete!(db, S)
122            .condition(S::get_expires_at_field().less_than(OffsetDateTime::now_utc()))
123            .await
124            .map_err(RormStoreError::from)?;
125
126        Ok(())
127    }
128}
129
130#[async_trait]
131impl<S> SessionStore for RormStore<S>
132where
133    S: Model + Send + Sync + SessionModel,
134    <S as Model>::Primary: Field<Type = String>,
135    <S as Patch>::Decoder: Send + Sync + 'static,
136{
137    #[instrument(level = "trace")]
138    async fn create(&self, session_record: &mut Record) -> Result<()> {
139        debug!("Creating new session");
140        let mut tx = self
141            .db
142            .start_transaction()
143            .await
144            .map_err(RormStoreError::from)?;
145        loop {
146            let existing = query!(&mut tx, S)
147                .condition(S::get_primary_field().equals(session_record.id.to_string()))
148                .optional()
149                .await
150                .map_err(RormStoreError::from)?;
151
152            if existing.is_none() {
153                insert!(&mut tx, S)
154                    .return_nothing()
155                    .single(&S::get_insert_patch(
156                        session_record.id.to_string(),
157                        session_record.expiry_date,
158                        Json(session_record.data.clone()),
159                    ))
160                    .await
161                    .map_err(RormStoreError::from)?;
162
163                break;
164            }
165
166            session_record.id = Id::default();
167        }
168
169        tx.commit().await.map_err(RormStoreError::from)?;
170
171        Ok(())
172    }
173
174    #[instrument(level = "trace")]
175    async fn save(&self, session_record: &Record) -> Result<()> {
176        let Record {
177            id,
178            data,
179            expiry_date,
180        } = session_record;
181
182        let mut tx = self
183            .db
184            .start_transaction()
185            .await
186            .map_err(RormStoreError::from)?;
187
188        let existing_session = query!(&mut tx, S)
189            .condition(S::get_primary_field().equals(id.to_string()))
190            .optional()
191            .await
192            .map_err(RormStoreError::from)?;
193
194        if existing_session.is_some() {
195            update!(&mut tx, S)
196                .condition(S::get_primary_field().equals(id.to_string()))
197                .set(S::get_expires_at_field(), *expiry_date)
198                .set(S::get_data_field(), Json(data.clone()))
199                .exec()
200                .await
201                .map_err(RormStoreError::from)?;
202        } else {
203            insert!(&mut tx, S)
204                .single(&S::get_insert_patch(
205                    id.to_string(),
206                    *expiry_date,
207                    Json(data.clone()),
208                ))
209                .await
210                .map_err(RormStoreError::from)?;
211        }
212
213        tx.commit().await.map_err(RormStoreError::from)?;
214
215        Ok(())
216    }
217
218    #[instrument(level = "trace")]
219    async fn load(&self, session_id: &Id) -> Result<Option<Record>> {
220        debug!("Loading session");
221        let db = &self.db;
222
223        let session = query!(db, S)
224            .condition(and!(
225                S::get_primary_field().equals(session_id.to_string()),
226                S::get_expires_at_field().greater_than(OffsetDateTime::now_utc())
227            ))
228            .optional()
229            .await
230            .map_err(RormStoreError::from)?;
231
232        Ok(match session {
233            None => None,
234            Some(session) => {
235                let (id, expiry, data) = session.get_session_data();
236
237                Some(Record {
238                    id: Id::from_str(id.as_str()).map_err(RormStoreError::from)?,
239                    data: data.into_inner(),
240                    expiry_date: expiry,
241                })
242            }
243        })
244    }
245
246    #[instrument(level = "trace")]
247    async fn delete(&self, session_id: &Id) -> Result<()> {
248        let db = &self.db;
249
250        delete!(db, S)
251            .condition(S::get_primary_field().equals(session_id.to_string()))
252            .await
253            .map_err(RormStoreError::from)?;
254
255        Ok(())
256    }
257}
258
259/// Error type that is used in the [SessionStore] trait
260#[derive(Debug, Error)]
261#[allow(missing_docs)]
262pub enum RormStoreError {
263    #[error("Database error: {0}")]
264    Database(#[from] rorm::Error),
265    #[error("Decoding of id failed: {0}")]
266    DecodingFailed(#[from] base64::DecodeSliceError),
267}
268
269impl From<RormStoreError> for Error {
270    fn from(value: RormStoreError) -> Self {
271        Self::Backend(value.to_string())
272    }
273}