Skip to main content

tower_sessions_ext/
service.rs

1//! A middleware that provides [`Session`] as a request extension.
2use std::{
3    borrow::Cow,
4    fmt,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use http::{Request, Response};
12use time::OffsetDateTime;
13#[cfg(any(feature = "signed", feature = "private"))]
14use tower_cookies::Key;
15use tower_cookies::{Cookie, CookieManager, Cookies, cookie::SameSite};
16use tower_layer::Layer;
17use tower_service::Service;
18use tracing::Instrument;
19
20use crate::{
21    Session, SessionStore,
22    session::{self, Expiry},
23};
24
25#[doc(hidden)]
26pub trait CookieController: Clone + Send + 'static {
27    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
28    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
29    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
30}
31
32#[doc(hidden)]
33#[derive(Debug, Clone)]
34pub struct PlaintextCookie;
35
36impl CookieController for PlaintextCookie {
37    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
38        cookies.get(name).map(Cookie::into_owned)
39    }
40
41    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
42        cookies.add(cookie)
43    }
44
45    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
46        cookies.remove(cookie)
47    }
48}
49
/// Cookie controller that routes every session-cookie operation through the
/// jar returned by `Cookies::signed`.
#[doc(hidden)]
#[cfg(feature = "signed")]
#[derive(Debug, Clone)]
pub struct SignedCookie {
    /// Key passed to `Cookies::signed` for each operation.
    key: Key,
}

#[cfg(feature = "signed")]
impl CookieController for SignedCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.signed(&self.key);
        let found = jar.get(name)?;
        Some(found.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.remove(cookie);
    }
}
71
/// Cookie controller that routes every session-cookie operation through the
/// jar returned by `Cookies::private`.
#[doc(hidden)]
#[cfg(feature = "private")]
#[derive(Debug, Clone)]
pub struct PrivateCookie {
    /// Key passed to `Cookies::private` for each operation.
    key: Key,
}

#[cfg(feature = "private")]
impl CookieController for PrivateCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.private(&self.key);
        let found = jar.get(name)?;
        Some(found.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.remove(cookie);
    }
}
93
94#[derive(Clone)]
95struct SessionConfig<'a> {
96    name: Cow<'a, str>,
97    http_only: bool,
98    same_site: SameSite,
99    expiry: Option<Expiry>,
100    secure: bool,
101    path: Cow<'a, str>,
102    domain: Option<Cow<'a, str>>,
103    always_save: bool,
104}
105
106impl fmt::Debug for SessionConfig<'_> {
107    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
108        f.debug_struct("SessionConfig")
109            .field("name", &self.name)
110            .field("http_only", &self.http_only)
111            .field("same_site", &self.same_site)
112            .field("expiry", &self.expiry)
113            .field("secure", &self.secure)
114            .field("path", &self.path)
115            .field("domain", &self.domain)
116            .field("always_save", &self.always_save)
117            .finish()
118    }
119}
120
121impl<'a> SessionConfig<'a> {
122    fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
123        let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
124            .http_only(self.http_only)
125            .same_site(self.same_site)
126            .secure(self.secure)
127            .path(self.path);
128
129        cookie_builder = match expiry {
130            Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
131            Some(Expiry::AtDateTime(datetime)) => {
132                cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
133            }
134            // Session cookie: no Max-Age so the browser treats it as ending when the session ends.
135            Some(Expiry::OnSessionEnd(_)) | None => cookie_builder,
136        };
137
138        if let Some(domain) = self.domain {
139            cookie_builder = cookie_builder.domain(domain);
140        }
141
142        cookie_builder.build()
143    }
144}
145
146impl Default for SessionConfig<'_> {
147    fn default() -> Self {
148        Self {
149            name: "id".into(), /* See: https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting */
150            http_only: true,
151            same_site: SameSite::Strict,
152            expiry: None, // TODO: Is `Max-Age: "Session"` the right default?
153            secure: true,
154            path: "/".into(),
155            domain: None,
156            always_save: false,
157        }
158    }
159}
160
161/// A middleware that provides [`Session`] as a request extension.
162#[derive(Debug, Clone)]
163pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
164    inner: S,
165    session_store: Arc<Store>,
166    session_config: SessionConfig<'static>,
167    cookie_controller: C,
168}
169
170impl<S, Store: SessionStore> SessionManager<S, Store> {
171    /// Create a new [`SessionManager`].
172    pub fn new(inner: S, session_store: Store) -> Self {
173        Self {
174            inner,
175            session_store: Arc::new(session_store),
176            session_config: Default::default(),
177            cookie_controller: PlaintextCookie,
178        }
179    }
180}
181
182impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
183    for SessionManager<S, Store, C>
184where
185    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
186    S::Future: Send,
187    ReqBody: Send + 'static,
188    ResBody: Default + Send,
189{
190    type Response = S::Response;
191    type Error = S::Error;
192    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
193
194    #[inline]
195    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
196        self.inner.poll_ready(cx)
197    }
198
199    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
200        let span = tracing::debug_span!("call");
201
202        let session_store = self.session_store.clone();
203        let session_config = self.session_config.clone();
204        let cookie_controller = self.cookie_controller.clone();
205
206        // Because the inner service can panic until ready, we need to ensure we only
207        // use the ready service.
208        //
209        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
210        let clone = self.inner.clone();
211        let mut inner = std::mem::replace(&mut self.inner, clone);
212
213        Box::pin(
214            async move {
215                let Some(cookies) = req.extensions().get::<_>().cloned() else {
216                    // In practice this should never happen because we wrap `CookieManager`
217                    // directly.
218                    tracing::error!("missing cookies request extension");
219                    return Ok(Response::default());
220                };
221
222                let session_cookie = cookie_controller.get(&cookies, &session_config.name);
223                let session_id = session_cookie.as_ref().and_then(|cookie| {
224                    cookie
225                        .value()
226                        .parse::<session::Id>()
227                        .map_err(|err| {
228                            tracing::warn!(
229                                err = %err,
230                                "possibly suspicious activity: malformed session id"
231                            )
232                        })
233                        .ok()
234                });
235
236                let session = Session::new(
237                    session_id,
238                    session_store,
239                    session_config.expiry,
240                );
241
242                req.extensions_mut().insert(session.clone());
243
244                let res = inner.call(req).await?;
245
246                let modified = session.is_modified();
247                let empty = session.is_empty().await;
248
249                tracing::trace!(
250                    modified = modified,
251                    empty = empty,
252                    always_save = session_config.always_save,
253                    "session response state",
254                );
255
256                match session_cookie {
257                    Some(mut cookie) if empty => {
258                        tracing::debug!("removing session cookie");
259
260                        // Path and domain must be manually set to ensure a proper removal cookie is
261                        // constructed.
262                        //
263                        // See: https://docs.rs/cookie/latest/cookie/struct.CookieJar.html#method.remove
264                        cookie.set_path(session_config.path);
265                        if let Some(domain) = session_config.domain {
266                            cookie.set_domain(domain);
267                        }
268
269                        cookie_controller.remove(&cookies, cookie);
270                    }
271
272                    _ if (modified || session_config.always_save)
273                        && !empty
274                        && !res.status().is_server_error() =>
275                    {
276                        tracing::debug!("saving session");
277                        if let Err(err) = session.save().await {
278                            tracing::error!(err = %err, "failed to save session");
279
280                            let mut res = Response::default();
281                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
282                            return Ok(res);
283                        }
284
285                        let Some(session_id) = session.id() else {
286                            tracing::error!("missing session id");
287
288                            let mut res = Response::default();
289                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
290                            return Ok(res);
291                        };
292
293                        let expiry = session.expiry();
294                        let session_cookie = session_config.build_cookie(session_id, expiry);
295
296                        tracing::debug!("adding session cookie");
297                        cookie_controller.add(&cookies, session_cookie);
298                    }
299
300                    _ => (),
301                };
302
303                Ok(res)
304            }
305            .instrument(span),
306        )
307    }
308}
309
310/// A layer for providing [`Session`] as a request extension.
311#[derive(Debug, Clone)]
312pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
313    session_store: Arc<Store>,
314    session_config: SessionConfig<'static>,
315    cookie_controller: C,
316}
317
318impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
319    /// Configures the name of the cookie used for the session.
320    /// The default value is `"id"`.
321    ///
322    /// # Examples
323    ///
324    /// ```rust
325    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
326    ///
327    /// let session_store = MemoryStore::default();
328    /// let session_service = SessionManagerLayer::new(session_store).with_name("my.sid");
329    /// ```
330    pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
331        self.session_config.name = name.into();
332        self
333    }
334
335    /// Configures the `"HttpOnly"` attribute of the cookie used for the
336    /// session.
337    ///
338    /// # ⚠️ **Warning: Cross-site scripting risk**
339    ///
340    /// Applications should generally **not** override the default value of
341    /// `true`. If you do, you are exposing your application to increased risk
342    /// of cookie theft via techniques like cross-site scripting.
343    ///
344    /// # Examples
345    ///
346    /// ```rust
347    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
348    ///
349    /// let session_store = MemoryStore::default();
350    /// let session_service = SessionManagerLayer::new(session_store).with_http_only(true);
351    /// ```
352    pub fn with_http_only(mut self, http_only: bool) -> Self {
353        self.session_config.http_only = http_only;
354        self
355    }
356
357    /// Configures the `"SameSite"` attribute of the cookie used for the
358    /// session.
359    /// The default value is [`SameSite::Strict`].
360    ///
361    /// # Examples
362    ///
363    /// ```rust
364    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::SameSite};
365    ///
366    /// let session_store = MemoryStore::default();
367    /// let session_service = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
368    /// ```
369    pub fn with_same_site(mut self, same_site: SameSite) -> Self {
370        self.session_config.same_site = same_site;
371        self
372    }
373
374    /// Configures the `"Max-Age"` attribute of the cookie used for the session.
375    /// The default value is `None`.
376    ///
377    /// # Examples
378    ///
379    /// ```rust
380    /// use time::Duration;
381    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
382    ///
383    /// let session_store = MemoryStore::default();
384    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
385    /// let session_service = SessionManagerLayer::new(session_store).with_expiry(session_expiry);
386    /// ```
387    pub fn with_expiry(mut self, expiry: Expiry) -> Self {
388        self.session_config.expiry = Some(expiry);
389        self
390    }
391
392    /// Configures the `"Secure"` attribute of the cookie used for the session.
393    /// The default value is `true`.
394    ///
395    /// # Examples
396    ///
397    /// ```rust
398    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
399    ///
400    /// let session_store = MemoryStore::default();
401    /// let session_service = SessionManagerLayer::new(session_store).with_secure(true);
402    /// ```
403    pub fn with_secure(mut self, secure: bool) -> Self {
404        self.session_config.secure = secure;
405        self
406    }
407
408    /// Configures the `"Path"` attribute of the cookie used for the session.
409    /// The default value is `"/"`.
410    ///
411    /// # Examples
412    ///
413    /// ```rust
414    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
415    ///
416    /// let session_store = MemoryStore::default();
417    /// let session_service = SessionManagerLayer::new(session_store).with_path("/some/path");
418    /// ```
419    pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
420        self.session_config.path = path.into();
421        self
422    }
423
424    /// Configures the `"Domain"` attribute of the cookie used for the session.
425    /// The default value is `None`.
426    ///
427    /// # Examples
428    ///
429    /// ```rust
430    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
431    ///
432    /// let session_store = MemoryStore::default();
433    /// let session_service = SessionManagerLayer::new(session_store).with_domain("localhost");
434    /// ```
435    pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
436        self.session_config.domain = Some(domain.into());
437        self
438    }
439
440    /// Configures whether unmodified session should be saved on read or not.
441    /// When the value is `true`, the session will be saved even if it was not
442    /// changed.
443    ///
444    /// This is useful when you want to reset [`Session`] expiration time
445    /// on any valid request at the cost of higher [`SessionStore`] write
446    /// activity and transmitting `set-cookie` header with each response.
447    ///
448    /// It makes sense to use this setting with relative session expiration
449    /// values, such as `Expiry::OnInactivity(Duration)`. This setting will
450    /// _not_ cause session id to be cycled on save.
451    ///
452    /// The default value is `false`.
453    ///
454    /// # Examples
455    ///
456    /// ```rust
457    /// use time::Duration;
458    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
459    ///
460    /// let session_store = MemoryStore::default();
461    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
462    /// let session_service = SessionManagerLayer::new(session_store)
463    ///     .with_expiry(session_expiry)
464    ///     .with_always_save(true);
465    /// ```
466    pub fn with_always_save(mut self, always_save: bool) -> Self {
467        self.session_config.always_save = always_save;
468        self
469    }
470
471    /// Manages the session cookie via a signed interface.
472    ///
473    /// See [`SignedCookies`](tower_cookies::SignedCookies).
474    ///
475    /// ```rust
476    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
477    ///
478    /// # /*
479    /// let key = { /* a cryptographically random key >= 64 bytes */ };
480    /// # */
481    /// # let key: &Vec<u8> = &(0..64).collect();
482    /// # let key: &[u8] = &key[..];
483    /// # let key = Key::try_from(key).unwrap();
484    ///
485    /// let session_store = MemoryStore::default();
486    /// let session_service = SessionManagerLayer::new(session_store).with_signed(key);
487    /// ```
488    #[cfg(feature = "signed")]
489    pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
490        SessionManagerLayer::<Store, SignedCookie> {
491            session_store: self.session_store,
492            session_config: self.session_config,
493            cookie_controller: SignedCookie { key },
494        }
495    }
496
497    /// Manages the session cookie via an encrypted interface.
498    ///
499    /// See [`PrivateCookies`](tower_cookies::PrivateCookies).
500    ///
501    /// ```rust
502    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
503    ///
504    /// # /*
505    /// let key = { /* a cryptographically random key >= 64 bytes */ };
506    /// # */
507    /// # let key: &Vec<u8> = &(0..64).collect();
508    /// # let key: &[u8] = &key[..];
509    /// # let key = Key::try_from(key).unwrap();
510    ///
511    /// let session_store = MemoryStore::default();
512    /// let session_service = SessionManagerLayer::new(session_store).with_private(key);
513    /// ```
514    #[cfg(feature = "private")]
515    pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
516        SessionManagerLayer::<Store, PrivateCookie> {
517            session_store: self.session_store,
518            session_config: self.session_config,
519            cookie_controller: PrivateCookie { key },
520        }
521    }
522}
523
524impl<Store: SessionStore> SessionManagerLayer<Store> {
525    /// Create a new [`SessionManagerLayer`] with the provided session store
526    /// and default cookie configuration.
527    ///
528    /// # Examples
529    ///
530    /// ```rust
531    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
532    ///
533    /// let session_store = MemoryStore::default();
534    /// let session_service = SessionManagerLayer::new(session_store);
535    /// ```
536    pub fn new(session_store: Store) -> Self {
537        let session_config = SessionConfig::default();
538
539        Self {
540            session_store: Arc::new(session_store),
541            session_config,
542            cookie_controller: PlaintextCookie,
543        }
544    }
545}
546
547impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
548    type Service = CookieManager<SessionManager<S, Store, C>>;
549
550    fn layer(&self, inner: S) -> Self::Service {
551        let session_manager = SessionManager {
552            inner,
553            session_store: self.session_store.clone(),
554            session_config: self.session_config.clone(),
555            cookie_controller: self.cookie_controller.clone(),
556        };
557
558        CookieManager::new(session_manager)
559    }
560}
561
562#[cfg(test)]
563mod tests {
564    use std::str::FromStr;
565
566    use anyhow::anyhow;
567    use axum::body::Body;
568    use tower::{ServiceBuilder, ServiceExt};
569    use tower_sessions_ext_core::session::DEFAULT_DURATION;
570    use tower_sessions_ext_memory_store::MemoryStore;
571
572    use super::*;
573    use crate::session::{Id, Record};
574
575    async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
576        let session = req
577            .extensions()
578            .get::<Session>()
579            .ok_or(anyhow!("Missing session"))?;
580
581        session.insert("foo", 42).await?;
582
583        Ok(Response::new(Body::empty()))
584    }
585
586    async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
587        Ok(Response::new(Body::empty()))
588    }
589
590    #[tokio::test]
591    async fn basic_service_test() -> anyhow::Result<()> {
592        let session_store = MemoryStore::default();
593        let session_layer = SessionManagerLayer::new(session_store);
594        let svc = ServiceBuilder::new()
595            .layer(session_layer)
596            .service_fn(handler);
597
598        let req = Request::builder().body(Body::empty())?;
599        let res = svc.clone().oneshot(req).await?;
600
601        let session = res.headers().get(http::header::SET_COOKIE);
602        assert!(session.is_some());
603
604        let req = Request::builder()
605            .header(http::header::COOKIE, session.unwrap())
606            .body(Body::empty())?;
607        let res = svc.oneshot(req).await?;
608
609        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
610
611        Ok(())
612    }
613
614    #[tokio::test]
615    async fn bogus_cookie_test() -> anyhow::Result<()> {
616        let session_store = MemoryStore::default();
617        let session_layer = SessionManagerLayer::new(session_store);
618        let svc = ServiceBuilder::new()
619            .layer(session_layer)
620            .service_fn(handler);
621
622        let req = Request::builder().body(Body::empty())?;
623        let res = svc.clone().oneshot(req).await?;
624
625        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
626
627        let req = Request::builder()
628            .header(http::header::COOKIE, "id=bogus")
629            .body(Body::empty())?;
630        let res = svc.oneshot(req).await?;
631
632        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
633
634        Ok(())
635    }
636
637    #[tokio::test]
638    async fn no_set_cookie_test() -> anyhow::Result<()> {
639        let session_store = MemoryStore::default();
640        let session_layer = SessionManagerLayer::new(session_store);
641        let svc = ServiceBuilder::new()
642            .layer(session_layer)
643            .service_fn(noop_handler);
644
645        let req = Request::builder().body(Body::empty())?;
646        let res = svc.oneshot(req).await?;
647
648        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
649
650        Ok(())
651    }
652
653    #[tokio::test]
654    async fn name_test() -> anyhow::Result<()> {
655        let session_store = MemoryStore::default();
656        let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
657        let svc = ServiceBuilder::new()
658            .layer(session_layer)
659            .service_fn(handler);
660
661        let req = Request::builder().body(Body::empty())?;
662        let res = svc.oneshot(req).await?;
663
664        assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));
665
666        Ok(())
667    }
668
669    #[tokio::test]
670    async fn http_only_test() -> anyhow::Result<()> {
671        let session_store = MemoryStore::default();
672        let session_layer = SessionManagerLayer::new(session_store);
673        let svc = ServiceBuilder::new()
674            .layer(session_layer)
675            .service_fn(handler);
676
677        let req = Request::builder().body(Body::empty())?;
678        let res = svc.oneshot(req).await?;
679
680        assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));
681
682        let session_store = MemoryStore::default();
683        let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
684        let svc = ServiceBuilder::new()
685            .layer(session_layer)
686            .service_fn(handler);
687
688        let req = Request::builder().body(Body::empty())?;
689        let res = svc.oneshot(req).await?;
690
691        assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));
692
693        Ok(())
694    }
695
696    #[tokio::test]
697    async fn same_site_strict_test() -> anyhow::Result<()> {
698        let session_store = MemoryStore::default();
699        let session_layer =
700            SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
701        let svc = ServiceBuilder::new()
702            .layer(session_layer)
703            .service_fn(handler);
704
705        let req = Request::builder().body(Body::empty())?;
706        let res = svc.oneshot(req).await?;
707
708        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));
709
710        Ok(())
711    }
712
713    #[tokio::test]
714    async fn same_site_lax_test() -> anyhow::Result<()> {
715        let session_store = MemoryStore::default();
716        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
717        let svc = ServiceBuilder::new()
718            .layer(session_layer)
719            .service_fn(handler);
720
721        let req = Request::builder().body(Body::empty())?;
722        let res = svc.oneshot(req).await?;
723
724        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));
725
726        Ok(())
727    }
728
729    #[tokio::test]
730    async fn same_site_none_test() -> anyhow::Result<()> {
731        let session_store = MemoryStore::default();
732        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
733        let svc = ServiceBuilder::new()
734            .layer(session_layer)
735            .service_fn(handler);
736
737        let req = Request::builder().body(Body::empty())?;
738        let res = svc.oneshot(req).await?;
739
740        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));
741
742        Ok(())
743    }
744
745    #[tokio::test]
746    async fn expiry_on_session_end_test() -> anyhow::Result<()> {
747        let session_store = MemoryStore::default();
748        let session_layer = SessionManagerLayer::new(session_store)
749            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION));
750        let svc = ServiceBuilder::new()
751            .layer(session_layer)
752            .service_fn(handler);
753
754        let req = Request::builder().body(Body::empty())?;
755        let res = svc.oneshot(req).await?;
756
757        assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));
758
759        Ok(())
760    }
761
762    #[tokio::test]
763    async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
764        let session_store = MemoryStore::default();
765        let inactivity_duration = time::Duration::hours(2);
766        let session_layer = SessionManagerLayer::new(session_store)
767            .with_expiry(Expiry::OnInactivity(inactivity_duration));
768        let svc = ServiceBuilder::new()
769            .layer(session_layer)
770            .service_fn(handler);
771
772        let req = Request::builder().body(Body::empty())?;
773        let res = svc.oneshot(req).await?;
774
775        let expected_max_age = inactivity_duration.whole_seconds();
776        assert!(cookie_has_expected_max_age(&res, expected_max_age));
777
778        Ok(())
779    }
780
781    #[tokio::test]
782    async fn expiry_at_date_time_test() -> anyhow::Result<()> {
783        let session_store = MemoryStore::default();
784        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
785        let session_layer =
786            SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
787        let svc = ServiceBuilder::new()
788            .layer(session_layer)
789            .service_fn(handler);
790
791        let req = Request::builder().body(Body::empty())?;
792        let res = svc.oneshot(req).await?;
793
794        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
795        assert!(cookie_has_expected_max_age(&res, expected_max_age));
796
797        Ok(())
798    }
799
800    #[tokio::test]
801    async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
802        let session_store = MemoryStore::default();
803        let session_layer = SessionManagerLayer::new(session_store.clone())
804            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION))
805            .with_always_save(true);
806        let mut svc = ServiceBuilder::new()
807            .layer(session_layer)
808            .service_fn(handler);
809
810        let req1 = Request::builder().body(Body::empty())?;
811        let res1 = svc.call(req1).await?;
812        let sid1 = get_session_id(&res1);
813        let rec1 = get_record(&session_store, &sid1).await;
814        let req2 = Request::builder()
815            .header(http::header::COOKIE, format!("id={}", sid1))
816            .body(Body::empty())?;
817        let res2 = svc.call(req2).await?;
818        let sid2 = get_session_id(&res2);
819        let rec2 = get_record(&session_store, &sid2).await;
820
821        assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
822        assert!(sid1 == sid2);
823        assert!(rec1.expiry_date < rec2.expiry_date);
824
825        Ok(())
826    }
827
828    #[tokio::test]
829    async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
830        let session_store = MemoryStore::default();
831        let inactivity_duration = time::Duration::hours(2);
832        let session_layer = SessionManagerLayer::new(session_store.clone())
833            .with_expiry(Expiry::OnInactivity(inactivity_duration))
834            .with_always_save(true);
835        let mut svc = ServiceBuilder::new()
836            .layer(session_layer)
837            .service_fn(handler);
838
839        let req1 = Request::builder().body(Body::empty())?;
840        let res1 = svc.call(req1).await?;
841        let sid1 = get_session_id(&res1);
842        let rec1 = get_record(&session_store, &sid1).await;
843        let req2 = Request::builder()
844            .header(http::header::COOKIE, format!("id={}", sid1))
845            .body(Body::empty())?;
846        let res2 = svc.call(req2).await?;
847        let sid2 = get_session_id(&res2);
848        let rec2 = get_record(&session_store, &sid2).await;
849
850        let expected_max_age = inactivity_duration.whole_seconds();
851        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
852        assert!(sid1 == sid2);
853        assert!(rec1.expiry_date < rec2.expiry_date);
854
855        Ok(())
856    }
857
858    #[tokio::test]
859    async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
860        let session_store = MemoryStore::default();
861        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
862        let session_layer = SessionManagerLayer::new(session_store.clone())
863            .with_expiry(Expiry::AtDateTime(expiry_time))
864            .with_always_save(true);
865        let mut svc = ServiceBuilder::new()
866            .layer(session_layer)
867            .service_fn(handler);
868
869        let req1 = Request::builder().body(Body::empty())?;
870        let res1 = svc.call(req1).await?;
871        let sid1 = get_session_id(&res1);
872        let rec1 = get_record(&session_store, &sid1).await;
873        let req2 = Request::builder()
874            .header(http::header::COOKIE, format!("id={}", sid1))
875            .body(Body::empty())?;
876        let res2 = svc.call(req2).await?;
877        let sid2 = get_session_id(&res2);
878        let rec2 = get_record(&session_store, &sid2).await;
879
880        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
881        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
882        assert!(sid1 == sid2);
883        assert!(rec1.expiry_date == rec2.expiry_date);
884
885        Ok(())
886    }
887
888    #[tokio::test]
889    async fn secure_test() -> anyhow::Result<()> {
890        let session_store = MemoryStore::default();
891        let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
892        let svc = ServiceBuilder::new()
893            .layer(session_layer)
894            .service_fn(handler);
895
896        let req = Request::builder().body(Body::empty())?;
897        let res = svc.oneshot(req).await?;
898
899        assert!(cookie_value_matches(&res, |s| s.contains("Secure")));
900
901        let session_store = MemoryStore::default();
902        let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
903        let svc = ServiceBuilder::new()
904            .layer(session_layer)
905            .service_fn(handler);
906
907        let req = Request::builder().body(Body::empty())?;
908        let res = svc.oneshot(req).await?;
909
910        assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));
911
912        Ok(())
913    }
914
915    #[tokio::test]
916    async fn path_test() -> anyhow::Result<()> {
917        let session_store = MemoryStore::default();
918        let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
919        let svc = ServiceBuilder::new()
920            .layer(session_layer)
921            .service_fn(handler);
922
923        let req = Request::builder().body(Body::empty())?;
924        let res = svc.oneshot(req).await?;
925
926        assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));
927
928        Ok(())
929    }
930
931    #[tokio::test]
932    async fn domain_test() -> anyhow::Result<()> {
933        let session_store = MemoryStore::default();
934        let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
935        let svc = ServiceBuilder::new()
936            .layer(session_layer)
937            .service_fn(handler);
938
939        let req = Request::builder().body(Body::empty())?;
940        let res = svc.oneshot(req).await?;
941
942        assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));
943
944        Ok(())
945    }
946
947    #[cfg(feature = "signed")]
948    #[tokio::test]
949    async fn signed_test() -> anyhow::Result<()> {
950        let key = Key::generate();
951        let session_store = MemoryStore::default();
952        let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
953        let svc = ServiceBuilder::new()
954            .layer(session_layer)
955            .service_fn(handler);
956
957        let req = Request::builder().body(Body::empty())?;
958        let res = svc.oneshot(req).await?;
959
960        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
961
962        Ok(())
963    }
964
965    #[cfg(feature = "private")]
966    #[tokio::test]
967    async fn private_test() -> anyhow::Result<()> {
968        let key = Key::generate();
969        let session_store = MemoryStore::default();
970        let session_layer = SessionManagerLayer::new(session_store).with_private(key);
971        let svc = ServiceBuilder::new()
972            .layer(session_layer)
973            .service_fn(handler);
974
975        let req = Request::builder().body(Body::empty())?;
976        let res = svc.oneshot(req).await?;
977
978        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
979
980        Ok(())
981    }
982
983    fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
984    where
985        F: FnOnce(&str) -> bool,
986    {
987        res.headers()
988            .get(http::header::SET_COOKIE)
989            .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
990    }
991
992    fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
993        res.headers()
994            .get(http::header::SET_COOKIE)
995            .is_some_and(|set_cookie| {
996                set_cookie.to_str().is_ok_and(|s| {
997                    let max_age_value = s
998                        .split("Max-Age=")
999                        .nth(1)
1000                        .unwrap_or_default()
1001                        .split(';')
1002                        .next()
1003                        .unwrap_or_default()
1004                        .parse::<i64>()
1005                        .unwrap_or_default();
1006                    (max_age_value - expected_value).abs() <= 1
1007                })
1008            })
1009    }
1010
1011    fn get_session_id(res: &Response<Body>) -> String {
1012        res.headers()
1013            .get(http::header::SET_COOKIE)
1014            .unwrap()
1015            .to_str()
1016            .unwrap()
1017            .split("id=")
1018            .nth(1)
1019            .unwrap()
1020            .split(";")
1021            .next()
1022            .unwrap()
1023            .to_string()
1024    }
1025
1026    async fn get_record(store: &impl SessionStore, id: &str) -> Record {
1027        store
1028            .load(&Id::from_str(id).unwrap())
1029            .await
1030            .unwrap()
1031            .unwrap()
1032    }
1033}