1use std::{
2 ops::{Deref, DerefMut},
3 sync::Arc,
4};
5
6use async_trait::async_trait;
7use parking_lot::{Mutex, MutexGuard};
8use tower_sesh_core::{store::Ttl, Record, SessionKey};
9
/// Handle to the current request's session data of type `T`.
///
/// Cloning is cheap and every clone refers to the same state: the data lives
/// behind an `Arc<Mutex<..>>`, and is accessed through the lock-holding guards
/// returned by `get`, `insert` and the `get_or_insert*` methods.
pub struct Session<T>(Arc<Mutex<Inner<T>>>);
22
/// Exclusive access to a session whose data is known to be present.
///
/// Holds the session lock until dropped and dereferences to `T`; mutable
/// access marks the session as changed. Invariant: the wrapped `data` is
/// always `Some` (upheld by callers of the `unsafe` constructor).
pub struct SessionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
38
/// Exclusive access to a session's data, which may be absent.
///
/// Holds the session lock until dropped and dereferences to `Option<T>`;
/// mutable access marks the session as changed.
pub struct OptionSessionGuard<'a, T>(MutexGuard<'a, Inner<T>>);
45
/// The state shared by all clones of a `Session`, guarded by its mutex.
struct Inner<T> {
    /// Key the session was loaded (or requested) under; `None` for a session
    /// created from scratch.
    session_key: Option<SessionKey>,
    /// The session data; `None` when nothing is stored (including when the
    /// stored data was ignored).
    data: Option<T>,
    /// Expiry copied from the loaded record's `ttl`; `None` when no record was
    /// loaded.
    expires_at: Option<Ttl>,
    /// How the session has been modified during this request.
    status: Status,
}
52
/// Tracks what has happened to a session during the current request.
enum Status {
    /// Initial state of every session constructed in this module.
    Unchanged,
    /// Not assigned in this module — presumably set elsewhere when only the
    /// expiry is extended (TODO confirm).
    Renewed,
    /// The data was (or may have been) modified; set by `Inner::changed`.
    Changed,
    /// `Inner::changed` never overrides this state. Not assigned in this
    /// module — presumably set when the session is deleted (TODO confirm).
    Purged,
}
65use Status::*;
66
67impl<T> Inner<T> {
68 fn changed(&mut self) {
69 if !matches!(self.status, Purged) {
70 self.status = Changed;
71 }
72 }
73}
74
75impl<T> Session<T> {
76 fn new(session_key: SessionKey, record: Record<T>) -> Session<T> {
77 let inner = Inner {
78 session_key: Some(session_key),
79 data: Some(record.data),
80 expires_at: Some(record.ttl),
81 status: Unchanged,
82 };
83 Session(Arc::new(Mutex::new(inner)))
84 }
85
86 fn empty() -> Session<T> {
87 let inner = Inner {
88 session_key: None,
89 data: None,
90 expires_at: None,
91 status: Unchanged,
92 };
93 Session(Arc::new(Mutex::new(inner)))
94 }
95
96 fn ignored(session_key: SessionKey) -> Session<T> {
97 let inner = Inner {
98 session_key: Some(session_key),
99 data: None,
100 expires_at: None,
101 status: Unchanged,
102 };
103 Session(Arc::new(Mutex::new(inner)))
104 }
105
106 #[must_use]
107 pub fn get(&self) -> OptionSessionGuard<'_, T> {
108 let lock = self.0.lock();
109
110 OptionSessionGuard::new(lock)
111 }
112
113 pub fn insert(&self, value: T) -> SessionGuard<'_, T> {
114 let mut lock = self.0.lock();
115
116 lock.data = Some(value);
117 lock.changed();
118
119 unsafe { SessionGuard::new(lock) }
122 }
123
124 pub fn get_or_insert(&self, value: T) -> SessionGuard<'_, T> {
125 let mut lock = self.0.lock();
126
127 if lock.data.is_none() {
128 lock.data = Some(value);
129 lock.changed();
130 }
131
132 unsafe { SessionGuard::new(lock) }
135 }
136
137 pub fn get_or_insert_with<F>(&self, f: F) -> SessionGuard<'_, T>
138 where
139 F: FnOnce() -> T,
140 {
141 let mut lock = self.0.lock();
142
143 if lock.data.is_none() {
144 lock.data = Some(f());
145 lock.changed();
146 }
147
148 unsafe { SessionGuard::new(lock) }
151 }
152
153 #[inline]
154 pub fn get_or_insert_default(&self) -> SessionGuard<'_, T>
155 where
156 T: Default,
157 {
158 self.get_or_insert_with(T::default)
159 }
160}
161
162impl<T> Clone for Session<T> {
163 fn clone(&self) -> Self {
164 Session(Arc::clone(&self.0))
165 }
166}
167
// Rejection produced by the `Session` extractor when the session could not be
// loaded from the store. The `status`/`body` attributes presumably configure
// the HTTP response (500 Internal Server Error, "Failed to load session") —
// confirm against the `define_rejection!` macro definition.
define_rejection! {
    #[status = INTERNAL_SERVER_ERROR]
    #[body = "Failed to load session"]
    pub struct SessionRejection;
}
175
#[cfg(feature = "axum")]
#[async_trait]
impl<S, T> axum::extract::FromRequestParts<S> for Session<T>
where
    T: 'static + Send + Sync,
{
    type Rejection = SessionRejection;

    /// Fetches the request's session, loading it on first use.
    ///
    /// Rejects with `SessionRejection` when the session failed to load.
    ///
    /// # Panics
    ///
    /// Panics when the lazily-loaded session extension for this `T` is
    /// missing, i.e. `SessionLayer` did not run first or was configured with
    /// a different session type.
    async fn from_request_parts(
        parts: &mut http::request::Parts,
        _state: &S,
    ) -> Result<Self, Self::Rejection> {
        let lookup = lazy::get_or_init::<T>(&mut parts.extensions).await;
        match lookup {
            Ok(maybe_session) => maybe_session.ok_or(SessionRejection),
            Err(_) => panic!(
                "Missing request extension. `SessionLayer` must be called \
                 before the `Session` extractor is run. Also, check that the \
                 generic type for `Session<T>` is correct."
            ),
        }
    }
}
201
202impl<'a, T> SessionGuard<'a, T> {
203 #[track_caller]
208 unsafe fn new(guard: MutexGuard<'a, Inner<T>>) -> Self {
209 debug_assert!(guard.data.is_some());
210 SessionGuard(guard)
211 }
212}
213
214impl<T> Deref for SessionGuard<'_, T> {
215 type Target = T;
216
217 fn deref(&self) -> &Self::Target {
218 unsafe { self.0.data.as_ref().unwrap_unchecked() }
221 }
222}
223
224impl<T> DerefMut for SessionGuard<'_, T> {
225 fn deref_mut(&mut self) -> &mut Self::Target {
226 self.0.changed();
227
228 unsafe { self.0.data.as_mut().unwrap_unchecked() }
231 }
232}
233
234impl<'a, T> OptionSessionGuard<'a, T> {
235 fn new(guard: MutexGuard<'a, Inner<T>>) -> Self {
236 OptionSessionGuard(guard)
237 }
238}
239
240impl<T> Deref for OptionSessionGuard<'_, T> {
241 type Target = Option<T>;
242
243 fn deref(&self) -> &Self::Target {
244 &self.0.data
245 }
246}
247
248impl<T> DerefMut for OptionSessionGuard<'_, T> {
249 fn deref_mut(&mut self) -> &mut Self::Target {
250 self.0.changed();
251
252 &mut self.0.data
253 }
254}
255
256pub(crate) mod lazy {
257 use std::{error::Error as StdError, fmt, sync::Arc};
258
259 use async_once_cell::OnceCell;
260 use cookie::Cookie;
261 use http::Extensions;
262 use tower_sesh_core::{store::ErrorKind, SessionKey, SessionStore};
263
264 use crate::{middleware::SessionConfig, util::ErrorExt};
265
266 use super::Session;
267
268 pub(crate) fn insert<T>(
269 cookie: Option<Cookie<'static>>,
270 store: &Arc<impl SessionStore<T>>,
271 extensions: &mut Extensions,
272 session_config: SessionConfig,
273 ) where
274 T: 'static + Send,
275 {
276 debug_assert!(
277 extensions.get::<LazySession<T>>().is_none(),
278 "`session::lazy::insert` was called more than once!"
279 );
280
281 let lazy_session = match cookie {
282 Some(cookie) => LazySession::new(cookie, Arc::clone(store), session_config),
283 None => LazySession::empty(),
284 };
285 extensions.insert::<LazySession<T>>(lazy_session);
286 }
287
288 pub(super) async fn get_or_init<T>(
289 extensions: &mut Extensions,
290 ) -> Result<Option<Session<T>>, Error>
291 where
292 T: 'static + Send,
293 {
294 match extensions.get::<LazySession<T>>() {
295 Some(lazy_session) => Ok(lazy_session.get_or_init().await.cloned()),
296 None => Err(Error),
297 }
298 }
299
300 pub(crate) fn take<T>(extensions: &mut Extensions) -> Result<Option<Session<T>>, Error>
301 where
302 T: 'static + Send,
303 {
304 match extensions.remove::<LazySession<T>>() {
305 Some(lazy_session) => Ok(lazy_session.get().cloned()),
306 None => Err(Error),
307 }
308 }
309
310 enum LazySession<T> {
311 Empty(Arc<OnceCell<Session<T>>>),
312 Init {
313 cookie: Cookie<'static>,
314 store: Arc<dyn SessionStore<T> + 'static>,
315 session: Arc<OnceCell<Option<Session<T>>>>,
316 config: SessionConfig,
317 },
318 }
319
320 impl<T> Clone for LazySession<T> {
321 fn clone(&self) -> Self {
322 match self {
323 LazySession::Empty(session) => LazySession::Empty(Arc::clone(session)),
324 LazySession::Init {
325 cookie,
326 store,
327 session,
328 config,
329 } => LazySession::Init {
330 cookie: cookie.clone(),
331 store: Arc::clone(store),
332 session: Arc::clone(session),
333 config: config.clone(),
334 },
335 }
336 }
337 }
338
339 impl<T> LazySession<T>
340 where
341 T: 'static,
342 {
343 fn new(
344 cookie: Cookie<'static>,
345 store: Arc<impl SessionStore<T>>,
346 config: SessionConfig,
347 ) -> LazySession<T> {
348 LazySession::Init {
349 cookie,
350 store,
351 session: Arc::new(OnceCell::new()),
352 config,
353 }
354 }
355
356 fn empty() -> LazySession<T> {
357 LazySession::Empty(Arc::new(OnceCell::new()))
358 }
359
360 async fn get_or_init(&self) -> Option<&Session<T>> {
361 match self {
362 LazySession::Empty(session) => {
363 Some(session.get_or_init(async { Session::empty() }).await)
364 }
365 LazySession::Init {
366 cookie,
367 store,
368 session,
369 config,
370 } => session
371 .get_or_init(init_session(cookie, store.as_ref(), config))
372 .await
373 .as_ref(),
374 }
375 }
376
377 fn get(&self) -> Option<&Session<T>> {
378 match self {
379 LazySession::Empty(session) => session.get(),
380 LazySession::Init { session, .. } => session.get().and_then(Option::as_ref),
381 }
382 }
383 }
384
385 async fn init_session<T>(
386 cookie: &Cookie<'static>,
387 store: &dyn SessionStore<T>,
388 config: &SessionConfig,
389 ) -> Option<Session<T>>
390 where
391 T: 'static,
392 {
393 let session_key = match SessionKey::decode(cookie.value()) {
394 Ok(session_key) => session_key,
395 Err(_) => return Some(Session::empty()),
396 };
397
398 match store.load(&session_key).await {
399 Ok(Some(record)) => Some(Session::new(session_key, record)),
400 Ok(None) => Some(Session::empty()),
401 Err(err) => {
402 match err.kind() {
403 ErrorKind::Serde(_) if config.ignore_invalid_session => {
404 Some(Session::ignored(session_key))
405 }
406 _ => {
407 error!(message = %err.display_chain());
409 None
410 }
411 }
412 }
413 }
414 }
415
416 pub(crate) struct Error;
417
418 impl StdError for Error {
419 fn source(&self) -> Option<&(dyn StdError + 'static)> {
420 None
421 }
422 }
423
424 impl fmt::Display for Error {
425 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
426 f.write_str("missing request extension")
427 }
428 }
429
430 impl fmt::Debug for Error {
431 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432 write!(f, "Error({:?})", self.to_string())
433 }
434 }
435}