tower_sessions_rorm_store/
lib.rs1#![warn(missing_docs)]
11
12use std::collections::HashMap;
13use std::fmt::Debug;
14use std::fmt::Formatter;
15use std::marker::PhantomData;
16use std::str::FromStr;
17
18use async_trait::async_trait;
19use rorm::and;
20use rorm::delete;
21use rorm::fields::types::Json;
22use rorm::insert;
23use rorm::internal::field::Field;
24use rorm::internal::field::FieldProxy;
25use rorm::query;
26use rorm::update;
27use rorm::Database;
28use rorm::FieldAccess;
29use rorm::Model;
30use rorm::Patch;
31pub use serde_json::Value;
32use thiserror::Error;
33pub use tower_sessions;
35use tower_sessions::cookie::time::OffsetDateTime;
36use tower_sessions::session::Id;
37use tower_sessions::session::Record;
38use tower_sessions::session_store::Error;
39use tower_sessions::session_store::Result;
40use tower_sessions::ExpiredDeletion;
41use tower_sessions::SessionStore;
42use tracing::debug;
43use tracing::instrument;
44
45pub trait SessionModel
51where
52 Self: Model + Send + Sync + 'static,
53 Self::Primary: Field<Type = String>,
54{
55 fn get_primary_field() -> FieldProxy<Self::Primary, Self> {
57 FieldProxy::new()
58 }
59
60 fn get_expires_at_field() -> FieldProxy<impl Field<Type = OffsetDateTime, Model = Self>, Self>;
62
63 fn get_data_field(
65 ) -> FieldProxy<impl Field<Type = Json<HashMap<String, Value>>, Model = Self>, Self>;
66
67 fn get_insert_patch(
70 id: String,
71 expires_at: OffsetDateTime,
72 data: Json<HashMap<String, Value>>,
73 ) -> impl Patch<Model = Self> + Send + Sync + 'static;
74
75 fn get_session_data(&self) -> (String, OffsetDateTime, Json<HashMap<String, Value>>);
77}
78
79pub struct RormStore<S> {
81 db: Database,
82 marker: PhantomData<S>,
83}
84
85impl<S> RormStore<S> {
86 pub fn new(db: Database) -> Self {
88 Self {
89 db,
90 marker: PhantomData,
91 }
92 }
93}
94
95impl<S> Debug for RormStore<S> {
96 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
97 write!(f, "")
98 }
99}
100
101impl<S> Clone for RormStore<S> {
102 fn clone(&self) -> Self {
103 Self {
104 db: self.db.clone(),
105 marker: PhantomData,
106 }
107 }
108}
109
110#[async_trait]
111impl<S> ExpiredDeletion for RormStore<S>
112where
113 S: Model + SessionModel + Debug,
114 <S as Model>::Primary: Field<Type = String>,
115 <S as Patch>::Decoder: Send + Sync + 'static,
116{
117 #[instrument(level = "trace")]
118 async fn delete_expired(&self) -> Result<()> {
119 let db = &self.db;
120
121 delete!(db, S)
122 .condition(S::get_expires_at_field().less_than(OffsetDateTime::now_utc()))
123 .await
124 .map_err(RormStoreError::from)?;
125
126 Ok(())
127 }
128}
129
130#[async_trait]
131impl<S> SessionStore for RormStore<S>
132where
133 S: Model + Send + Sync + SessionModel,
134 <S as Model>::Primary: Field<Type = String>,
135 <S as Patch>::Decoder: Send + Sync + 'static,
136{
137 #[instrument(level = "trace")]
138 async fn create(&self, session_record: &mut Record) -> Result<()> {
139 debug!("Creating new session");
140 let mut tx = self
141 .db
142 .start_transaction()
143 .await
144 .map_err(RormStoreError::from)?;
145 loop {
146 let existing = query!(&mut tx, S)
147 .condition(S::get_primary_field().equals(session_record.id.to_string()))
148 .optional()
149 .await
150 .map_err(RormStoreError::from)?;
151
152 if existing.is_none() {
153 insert!(&mut tx, S)
154 .return_nothing()
155 .single(&S::get_insert_patch(
156 session_record.id.to_string(),
157 session_record.expiry_date,
158 Json(session_record.data.clone()),
159 ))
160 .await
161 .map_err(RormStoreError::from)?;
162
163 break;
164 }
165
166 session_record.id = Id::default();
167 }
168
169 tx.commit().await.map_err(RormStoreError::from)?;
170
171 Ok(())
172 }
173
174 #[instrument(level = "trace")]
175 async fn save(&self, session_record: &Record) -> Result<()> {
176 let Record {
177 id,
178 data,
179 expiry_date,
180 } = session_record;
181
182 let mut tx = self
183 .db
184 .start_transaction()
185 .await
186 .map_err(RormStoreError::from)?;
187
188 let existing_session = query!(&mut tx, S)
189 .condition(S::get_primary_field().equals(id.to_string()))
190 .optional()
191 .await
192 .map_err(RormStoreError::from)?;
193
194 if existing_session.is_some() {
195 update!(&mut tx, S)
196 .condition(S::get_primary_field().equals(id.to_string()))
197 .set(S::get_expires_at_field(), *expiry_date)
198 .set(S::get_data_field(), Json(data.clone()))
199 .exec()
200 .await
201 .map_err(RormStoreError::from)?;
202 } else {
203 insert!(&mut tx, S)
204 .single(&S::get_insert_patch(
205 id.to_string(),
206 *expiry_date,
207 Json(data.clone()),
208 ))
209 .await
210 .map_err(RormStoreError::from)?;
211 }
212
213 tx.commit().await.map_err(RormStoreError::from)?;
214
215 Ok(())
216 }
217
218 #[instrument(level = "trace")]
219 async fn load(&self, session_id: &Id) -> Result<Option<Record>> {
220 debug!("Loading session");
221 let db = &self.db;
222
223 let session = query!(db, S)
224 .condition(and!(
225 S::get_primary_field().equals(session_id.to_string()),
226 S::get_expires_at_field().greater_than(OffsetDateTime::now_utc())
227 ))
228 .optional()
229 .await
230 .map_err(RormStoreError::from)?;
231
232 Ok(match session {
233 None => None,
234 Some(session) => {
235 let (id, expiry, data) = session.get_session_data();
236
237 Some(Record {
238 id: Id::from_str(id.as_str()).map_err(RormStoreError::from)?,
239 data: data.into_inner(),
240 expiry_date: expiry,
241 })
242 }
243 })
244 }
245
246 #[instrument(level = "trace")]
247 async fn delete(&self, session_id: &Id) -> Result<()> {
248 let db = &self.db;
249
250 delete!(db, S)
251 .condition(S::get_primary_field().equals(session_id.to_string()))
252 .await
253 .map_err(RormStoreError::from)?;
254
255 Ok(())
256 }
257}
258
259#[derive(Debug, Error)]
261#[allow(missing_docs)]
262pub enum RormStoreError {
263 #[error("Database error: {0}")]
264 Database(#[from] rorm::Error),
265 #[error("Decoding of id failed: {0}")]
266 DecodingFailed(#[from] base64::DecodeSliceError),
267}
268
269impl From<RormStoreError> for Error {
270 fn from(value: RormStoreError) -> Self {
271 Self::Backend(value.to_string())
272 }
273}