tower_sessions_ext/
service.rs

1//! A middleware that provides [`Session`] as a request extension.
2use std::{
3    borrow::Cow,
4    future::Future,
5    pin::Pin,
6    sync::Arc,
7    task::{Context, Poll},
8};
9
10use http::{Request, Response};
11use time::OffsetDateTime;
12#[cfg(any(feature = "signed", feature = "private"))]
13use tower_cookies::Key;
14use tower_cookies::{Cookie, CookieManager, Cookies, cookie::SameSite};
15use tower_layer::Layer;
16use tower_service::Service;
17use tracing::Instrument;
18
19use crate::{
20    Session, SessionStore,
21    session::{self, Expiry},
22};
23
24#[doc(hidden)]
25pub trait CookieController: Clone + Send + 'static {
26    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
27    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
28    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
29}
30
31#[doc(hidden)]
32#[derive(Debug, Clone)]
33pub struct PlaintextCookie;
34
35impl CookieController for PlaintextCookie {
36    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
37        cookies.get(name).map(Cookie::into_owned)
38    }
39
40    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
41        cookies.add(cookie)
42    }
43
44    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
45        cookies.remove(cookie)
46    }
47}
48
/// Cookie controller that routes the session cookie through the jar's
/// signed view, using `key`.
#[doc(hidden)]
#[cfg(feature = "signed")]
#[derive(Clone, Debug)]
pub struct SignedCookie {
    /// Key passed to `Cookies::signed`.
    key: Key,
}

#[cfg(feature = "signed")]
impl CookieController for SignedCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.signed(&self.key);
        jar.get(name).map(|cookie| cookie.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.remove(cookie);
    }
}
70
/// Cookie controller that routes the session cookie through the jar's
/// private (encrypted) view, using `key`.
#[doc(hidden)]
#[cfg(feature = "private")]
#[derive(Clone, Debug)]
pub struct PrivateCookie {
    /// Key passed to `Cookies::private`.
    key: Key,
}

#[cfg(feature = "private")]
impl CookieController for PrivateCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.private(&self.key);
        jar.get(name).map(|cookie| cookie.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.remove(cookie);
    }
}
92
93#[derive(Debug, Clone)]
94struct SessionConfig<'a> {
95    name: Cow<'a, str>,
96    http_only: bool,
97    same_site: SameSite,
98    expiry: Option<Expiry>,
99    secure: bool,
100    path: Cow<'a, str>,
101    domain: Option<Cow<'a, str>>,
102    always_save: bool,
103}
104
105impl<'a> SessionConfig<'a> {
106    fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
107        let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
108            .http_only(self.http_only)
109            .same_site(self.same_site)
110            .secure(self.secure)
111            .path(self.path);
112
113        cookie_builder = match expiry {
114            Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
115            Some(Expiry::AtDateTime(datetime)) => {
116                cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
117            }
118            Some(Expiry::OnSessionEnd(duration)) => cookie_builder.max_age(duration),
119            None => cookie_builder,
120        };
121
122        if let Some(domain) = self.domain {
123            cookie_builder = cookie_builder.domain(domain);
124        }
125
126        cookie_builder.build()
127    }
128}
129
130impl Default for SessionConfig<'_> {
131    fn default() -> Self {
132        Self {
133            name: "id".into(), /* See: https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting */
134            http_only: true,
135            same_site: SameSite::Strict,
136            expiry: None, // TODO: Is `Max-Age: "Session"` the right default?
137            secure: true,
138            path: "/".into(),
139            domain: None,
140            always_save: false,
141        }
142    }
143}
144
145/// A middleware that provides [`Session`] as a request extension.
146#[derive(Debug, Clone)]
147pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
148    inner: S,
149    session_store: Arc<Store>,
150    session_config: SessionConfig<'static>,
151    cookie_controller: C,
152}
153
154impl<S, Store: SessionStore> SessionManager<S, Store> {
155    /// Create a new [`SessionManager`].
156    pub fn new(inner: S, session_store: Store) -> Self {
157        Self {
158            inner,
159            session_store: Arc::new(session_store),
160            session_config: Default::default(),
161            cookie_controller: PlaintextCookie,
162        }
163    }
164}
165
/// Attaches a [`Session`] to each request, forwards the request to the inner
/// service, and then persists the session and updates the session cookie.
///
/// The [`Cookies`] request extension must already be present. The layer built
/// by [`SessionManagerLayer`] ensures this by wrapping the service in a
/// [`CookieManager`].
impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
    for SessionManager<S, Store, C>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    ReqBody: Send + 'static,
    ResBody: Default + Send,
{
    type Response = S::Response;
    type Error = S::Error;
    // Boxed because the `async` block's future type cannot be named.
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the inner service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let span = tracing::debug_span!("call");

        // Take owned copies of the shared state so the returned future does not
        // borrow `self`.
        let session_store = self.session_store.clone();
        let session_config = self.session_config.clone();
        let cookie_controller = self.cookie_controller.clone();

        // Because the inner service can panic until ready, we need to ensure we only
        // use the ready service.
        //
        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(
            async move {
                // Without the cookie jar there is no way to read or write the session
                // cookie, so the inner service is not called and `Response::default()`
                // is returned instead.
                let Some(cookies) = req.extensions().get::<_>().cloned() else {
                    // In practice this should never happen because we wrap `CookieManager`
                    // directly.
                    tracing::error!("missing cookies request extension");
                    return Ok(Response::default());
                };

                // A cookie whose value does not parse as a session id is logged and
                // treated as "no session id", so a fresh session is started.
                let session_cookie = cookie_controller.get(&cookies, &session_config.name);
                let session_id = session_cookie.as_ref().and_then(|cookie| {
                    cookie
                        .value()
                        .parse::<session::Id>()
                        .map_err(|err| {
                            tracing::warn!(
                                err = %err,
                                "possibly suspicious activity: malformed session id"
                            )
                        })
                        .ok()
                });

                let session = Session::new(session_id, session_store, session_config.expiry);

                // The handler gets a clone; `session` is kept here so its state can be
                // inspected after the inner service returns.
                req.extensions_mut().insert(session.clone());

                let res = inner.call(req).await?;

                let modified = session.is_modified();
                let empty = session.is_empty().await;

                tracing::trace!(
                    modified = modified,
                    empty = empty,
                    always_save = session_config.always_save,
                    "session response state",
                );

                match session_cookie {
                    // The client sent a session cookie (well-formed or not) but the
                    // session is now empty, so tell the client to drop the cookie.
                    Some(mut cookie) if empty => {
                        tracing::debug!("removing session cookie");

                        // Path and domain must be manually set to ensure a proper removal cookie is
                        // constructed.
                        //
                        // See: https://docs.rs/cookie/latest/cookie/struct.CookieJar.html#method.remove
                        cookie.set_path(session_config.path);
                        if let Some(domain) = session_config.domain {
                            cookie.set_domain(domain);
                        }

                        cookie_controller.remove(&cookies, cookie);
                    }

                    // Persist non-empty sessions that changed (or all of them when
                    // `always_save` is set), but never after a 5xx response.
                    _ if (modified || session_config.always_save)
                        && !empty
                        && !res.status().is_server_error() =>
                    {
                        tracing::debug!("saving session");
                        // A failed save discards the handler's response and replaces it
                        // with an empty 500.
                        if let Err(err) = session.save().await {
                            tracing::error!(err = %err, "failed to save session");

                            let mut res = Response::default();
                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        }

                        let Some(session_id) = session.id() else {
                            tracing::error!("missing session id");

                            let mut res = Response::default();
                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        };

                        // The cookie's lifetime comes from the session's current expiry,
                        // not directly from the configured default.
                        let expiry = session.expiry();
                        let session_cookie = session_config.build_cookie(session_id, expiry);

                        tracing::debug!("adding session cookie");
                        cookie_controller.add(&cookies, session_cookie);
                    }

                    _ => (),
                };

                Ok(res)
            }
            .instrument(span),
        )
    }
}
289
290/// A layer for providing [`Session`] as a request extension.
291#[derive(Debug, Clone)]
292pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
293    session_store: Arc<Store>,
294    session_config: SessionConfig<'static>,
295    cookie_controller: C,
296}
297
298impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
299    /// Configures the name of the cookie used for the session.
300    /// The default value is `"id"`.
301    ///
302    /// # Examples
303    ///
304    /// ```rust
305    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
306    ///
307    /// let session_store = MemoryStore::default();
308    /// let session_service = SessionManagerLayer::new(session_store).with_name("my.sid");
309    /// ```
310    pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
311        self.session_config.name = name.into();
312        self
313    }
314
315    /// Configures the `"HttpOnly"` attribute of the cookie used for the
316    /// session.
317    ///
318    /// # ⚠️ **Warning: Cross-site scripting risk**
319    ///
320    /// Applications should generally **not** override the default value of
321    /// `true`. If you do, you are exposing your application to increased risk
322    /// of cookie theft via techniques like cross-site scripting.
323    ///
324    /// # Examples
325    ///
326    /// ```rust
327    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
328    ///
329    /// let session_store = MemoryStore::default();
330    /// let session_service = SessionManagerLayer::new(session_store).with_http_only(true);
331    /// ```
332    pub fn with_http_only(mut self, http_only: bool) -> Self {
333        self.session_config.http_only = http_only;
334        self
335    }
336
337    /// Configures the `"SameSite"` attribute of the cookie used for the
338    /// session.
339    /// The default value is [`SameSite::Strict`].
340    ///
341    /// # Examples
342    ///
343    /// ```rust
344    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::SameSite};
345    ///
346    /// let session_store = MemoryStore::default();
347    /// let session_service = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
348    /// ```
349    pub fn with_same_site(mut self, same_site: SameSite) -> Self {
350        self.session_config.same_site = same_site;
351        self
352    }
353
354    /// Configures the `"Max-Age"` attribute of the cookie used for the session.
355    /// The default value is `None`.
356    ///
357    /// # Examples
358    ///
359    /// ```rust
360    /// use time::Duration;
361    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
362    ///
363    /// let session_store = MemoryStore::default();
364    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
365    /// let session_service = SessionManagerLayer::new(session_store).with_expiry(session_expiry);
366    /// ```
367    pub fn with_expiry(mut self, expiry: Expiry) -> Self {
368        self.session_config.expiry = Some(expiry);
369        self
370    }
371
372    /// Configures the `"Secure"` attribute of the cookie used for the session.
373    /// The default value is `true`.
374    ///
375    /// # Examples
376    ///
377    /// ```rust
378    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
379    ///
380    /// let session_store = MemoryStore::default();
381    /// let session_service = SessionManagerLayer::new(session_store).with_secure(true);
382    /// ```
383    pub fn with_secure(mut self, secure: bool) -> Self {
384        self.session_config.secure = secure;
385        self
386    }
387
388    /// Configures the `"Path"` attribute of the cookie used for the session.
389    /// The default value is `"/"`.
390    ///
391    /// # Examples
392    ///
393    /// ```rust
394    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
395    ///
396    /// let session_store = MemoryStore::default();
397    /// let session_service = SessionManagerLayer::new(session_store).with_path("/some/path");
398    /// ```
399    pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
400        self.session_config.path = path.into();
401        self
402    }
403
404    /// Configures the `"Domain"` attribute of the cookie used for the session.
405    /// The default value is `None`.
406    ///
407    /// # Examples
408    ///
409    /// ```rust
410    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
411    ///
412    /// let session_store = MemoryStore::default();
413    /// let session_service = SessionManagerLayer::new(session_store).with_domain("localhost");
414    /// ```
415    pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
416        self.session_config.domain = Some(domain.into());
417        self
418    }
419
420    /// Configures whether unmodified session should be saved on read or not.
421    /// When the value is `true`, the session will be saved even if it was not
422    /// changed.
423    ///
424    /// This is useful when you want to reset [`Session`] expiration time
425    /// on any valid request at the cost of higher [`SessionStore`] write
426    /// activity and transmitting `set-cookie` header with each response.
427    ///
428    /// It makes sense to use this setting with relative session expiration
429    /// values, such as `Expiry::OnInactivity(Duration)`. This setting will
430    /// _not_ cause session id to be cycled on save.
431    ///
432    /// The default value is `false`.
433    ///
434    /// # Examples
435    ///
436    /// ```rust
437    /// use time::Duration;
438    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
439    ///
440    /// let session_store = MemoryStore::default();
441    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
442    /// let session_service = SessionManagerLayer::new(session_store)
443    ///     .with_expiry(session_expiry)
444    ///     .with_always_save(true);
445    /// ```
446    pub fn with_always_save(mut self, always_save: bool) -> Self {
447        self.session_config.always_save = always_save;
448        self
449    }
450
451    /// Manages the session cookie via a signed interface.
452    ///
453    /// See [`SignedCookies`](tower_cookies::SignedCookies).
454    ///
455    /// ```rust
456    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
457    ///
458    /// # /*
459    /// let key = { /* a cryptographically random key >= 64 bytes */ };
460    /// # */
461    /// # let key: &Vec<u8> = &(0..64).collect();
462    /// # let key: &[u8] = &key[..];
463    /// # let key = Key::try_from(key).unwrap();
464    ///
465    /// let session_store = MemoryStore::default();
466    /// let session_service = SessionManagerLayer::new(session_store).with_signed(key);
467    /// ```
468    #[cfg(feature = "signed")]
469    pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
470        SessionManagerLayer::<Store, SignedCookie> {
471            session_store: self.session_store,
472            session_config: self.session_config,
473            cookie_controller: SignedCookie { key },
474        }
475    }
476
477    /// Manages the session cookie via an encrypted interface.
478    ///
479    /// See [`PrivateCookies`](tower_cookies::PrivateCookies).
480    ///
481    /// ```rust
482    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
483    ///
484    /// # /*
485    /// let key = { /* a cryptographically random key >= 64 bytes */ };
486    /// # */
487    /// # let key: &Vec<u8> = &(0..64).collect();
488    /// # let key: &[u8] = &key[..];
489    /// # let key = Key::try_from(key).unwrap();
490    ///
491    /// let session_store = MemoryStore::default();
492    /// let session_service = SessionManagerLayer::new(session_store).with_private(key);
493    /// ```
494    #[cfg(feature = "private")]
495    pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
496        SessionManagerLayer::<Store, PrivateCookie> {
497            session_store: self.session_store,
498            session_config: self.session_config,
499            cookie_controller: PrivateCookie { key },
500        }
501    }
502}
503
504impl<Store: SessionStore> SessionManagerLayer<Store> {
505    /// Create a new [`SessionManagerLayer`] with the provided session store
506    /// and default cookie configuration.
507    ///
508    /// # Examples
509    ///
510    /// ```rust
511    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
512    ///
513    /// let session_store = MemoryStore::default();
514    /// let session_service = SessionManagerLayer::new(session_store);
515    /// ```
516    pub fn new(session_store: Store) -> Self {
517        let session_config = SessionConfig::default();
518
519        Self {
520            session_store: Arc::new(session_store),
521            session_config,
522            cookie_controller: PlaintextCookie,
523        }
524    }
525}
526
527impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
528    type Service = CookieManager<SessionManager<S, Store, C>>;
529
530    fn layer(&self, inner: S) -> Self::Service {
531        let session_manager = SessionManager {
532            inner,
533            session_store: self.session_store.clone(),
534            session_config: self.session_config.clone(),
535            cookie_controller: self.cookie_controller.clone(),
536        };
537
538        CookieManager::new(session_manager)
539    }
540}
541
542#[cfg(test)]
543mod tests {
544    use std::str::FromStr;
545
546    use anyhow::anyhow;
547    use axum::body::Body;
548    use tower::{ServiceBuilder, ServiceExt};
549    use tower_sessions_ext_core::session::DEFAULT_DURATION;
550    use tower_sessions_ext_memory_store::MemoryStore;
551
552    use super::*;
553    use crate::session::{Id, Record};
554
555    async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
556        let session = req
557            .extensions()
558            .get::<Session>()
559            .ok_or(anyhow!("Missing session"))?;
560
561        session.insert("foo", 42).await?;
562
563        Ok(Response::new(Body::empty()))
564    }
565
566    async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
567        Ok(Response::new(Body::empty()))
568    }
569
570    #[tokio::test]
571    async fn basic_service_test() -> anyhow::Result<()> {
572        let session_store = MemoryStore::default();
573        let session_layer = SessionManagerLayer::new(session_store);
574        let svc = ServiceBuilder::new()
575            .layer(session_layer)
576            .service_fn(handler);
577
578        let req = Request::builder().body(Body::empty())?;
579        let res = svc.clone().oneshot(req).await?;
580
581        let session = res.headers().get(http::header::SET_COOKIE);
582        assert!(session.is_some());
583
584        let req = Request::builder()
585            .header(http::header::COOKIE, session.unwrap())
586            .body(Body::empty())?;
587        let res = svc.oneshot(req).await?;
588
589        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
590
591        Ok(())
592    }
593
594    #[tokio::test]
595    async fn bogus_cookie_test() -> anyhow::Result<()> {
596        let session_store = MemoryStore::default();
597        let session_layer = SessionManagerLayer::new(session_store);
598        let svc = ServiceBuilder::new()
599            .layer(session_layer)
600            .service_fn(handler);
601
602        let req = Request::builder().body(Body::empty())?;
603        let res = svc.clone().oneshot(req).await?;
604
605        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
606
607        let req = Request::builder()
608            .header(http::header::COOKIE, "id=bogus")
609            .body(Body::empty())?;
610        let res = svc.oneshot(req).await?;
611
612        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
613
614        Ok(())
615    }
616
617    #[tokio::test]
618    async fn no_set_cookie_test() -> anyhow::Result<()> {
619        let session_store = MemoryStore::default();
620        let session_layer = SessionManagerLayer::new(session_store);
621        let svc = ServiceBuilder::new()
622            .layer(session_layer)
623            .service_fn(noop_handler);
624
625        let req = Request::builder().body(Body::empty())?;
626        let res = svc.oneshot(req).await?;
627
628        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
629
630        Ok(())
631    }
632
633    #[tokio::test]
634    async fn name_test() -> anyhow::Result<()> {
635        let session_store = MemoryStore::default();
636        let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
637        let svc = ServiceBuilder::new()
638            .layer(session_layer)
639            .service_fn(handler);
640
641        let req = Request::builder().body(Body::empty())?;
642        let res = svc.oneshot(req).await?;
643
644        assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));
645
646        Ok(())
647    }
648
649    #[tokio::test]
650    async fn http_only_test() -> anyhow::Result<()> {
651        let session_store = MemoryStore::default();
652        let session_layer = SessionManagerLayer::new(session_store);
653        let svc = ServiceBuilder::new()
654            .layer(session_layer)
655            .service_fn(handler);
656
657        let req = Request::builder().body(Body::empty())?;
658        let res = svc.oneshot(req).await?;
659
660        assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));
661
662        let session_store = MemoryStore::default();
663        let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
664        let svc = ServiceBuilder::new()
665            .layer(session_layer)
666            .service_fn(handler);
667
668        let req = Request::builder().body(Body::empty())?;
669        let res = svc.oneshot(req).await?;
670
671        assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));
672
673        Ok(())
674    }
675
676    #[tokio::test]
677    async fn same_site_strict_test() -> anyhow::Result<()> {
678        let session_store = MemoryStore::default();
679        let session_layer =
680            SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
681        let svc = ServiceBuilder::new()
682            .layer(session_layer)
683            .service_fn(handler);
684
685        let req = Request::builder().body(Body::empty())?;
686        let res = svc.oneshot(req).await?;
687
688        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));
689
690        Ok(())
691    }
692
693    #[tokio::test]
694    async fn same_site_lax_test() -> anyhow::Result<()> {
695        let session_store = MemoryStore::default();
696        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
697        let svc = ServiceBuilder::new()
698            .layer(session_layer)
699            .service_fn(handler);
700
701        let req = Request::builder().body(Body::empty())?;
702        let res = svc.oneshot(req).await?;
703
704        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));
705
706        Ok(())
707    }
708
709    #[tokio::test]
710    async fn same_site_none_test() -> anyhow::Result<()> {
711        let session_store = MemoryStore::default();
712        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
713        let svc = ServiceBuilder::new()
714            .layer(session_layer)
715            .service_fn(handler);
716
717        let req = Request::builder().body(Body::empty())?;
718        let res = svc.oneshot(req).await?;
719
720        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));
721
722        Ok(())
723    }
724
725    #[tokio::test]
726    async fn expiry_on_session_end_test() -> anyhow::Result<()> {
727        let session_store = MemoryStore::default();
728        let session_layer = SessionManagerLayer::new(session_store)
729            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION));
730        let svc = ServiceBuilder::new()
731            .layer(session_layer)
732            .service_fn(handler);
733
734        let req = Request::builder().body(Body::empty())?;
735        let res = svc.oneshot(req).await?;
736
737        assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));
738
739        Ok(())
740    }
741
742    #[tokio::test]
743    async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
744        let session_store = MemoryStore::default();
745        let inactivity_duration = time::Duration::hours(2);
746        let session_layer = SessionManagerLayer::new(session_store)
747            .with_expiry(Expiry::OnInactivity(inactivity_duration));
748        let svc = ServiceBuilder::new()
749            .layer(session_layer)
750            .service_fn(handler);
751
752        let req = Request::builder().body(Body::empty())?;
753        let res = svc.oneshot(req).await?;
754
755        let expected_max_age = inactivity_duration.whole_seconds();
756        assert!(cookie_has_expected_max_age(&res, expected_max_age));
757
758        Ok(())
759    }
760
761    #[tokio::test]
762    async fn expiry_at_date_time_test() -> anyhow::Result<()> {
763        let session_store = MemoryStore::default();
764        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
765        let session_layer =
766            SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
767        let svc = ServiceBuilder::new()
768            .layer(session_layer)
769            .service_fn(handler);
770
771        let req = Request::builder().body(Body::empty())?;
772        let res = svc.oneshot(req).await?;
773
774        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
775        assert!(cookie_has_expected_max_age(&res, expected_max_age));
776
777        Ok(())
778    }
779
780    #[tokio::test]
781    async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
782        let session_store = MemoryStore::default();
783        let session_layer = SessionManagerLayer::new(session_store.clone())
784            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION))
785            .with_always_save(true);
786        let mut svc = ServiceBuilder::new()
787            .layer(session_layer)
788            .service_fn(handler);
789
790        let req1 = Request::builder().body(Body::empty())?;
791        let res1 = svc.call(req1).await?;
792        let sid1 = get_session_id(&res1);
793        let rec1 = get_record(&session_store, &sid1).await;
794        let req2 = Request::builder()
795            .header(http::header::COOKIE, format!("id={}", sid1))
796            .body(Body::empty())?;
797        let res2 = svc.call(req2).await?;
798        let sid2 = get_session_id(&res2);
799        let rec2 = get_record(&session_store, &sid2).await;
800
801        assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
802        assert!(sid1 == sid2);
803        assert!(rec1.expiry_date < rec2.expiry_date);
804
805        Ok(())
806    }
807
808    #[tokio::test]
809    async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
810        let session_store = MemoryStore::default();
811        let inactivity_duration = time::Duration::hours(2);
812        let session_layer = SessionManagerLayer::new(session_store.clone())
813            .with_expiry(Expiry::OnInactivity(inactivity_duration))
814            .with_always_save(true);
815        let mut svc = ServiceBuilder::new()
816            .layer(session_layer)
817            .service_fn(handler);
818
819        let req1 = Request::builder().body(Body::empty())?;
820        let res1 = svc.call(req1).await?;
821        let sid1 = get_session_id(&res1);
822        let rec1 = get_record(&session_store, &sid1).await;
823        let req2 = Request::builder()
824            .header(http::header::COOKIE, format!("id={}", sid1))
825            .body(Body::empty())?;
826        let res2 = svc.call(req2).await?;
827        let sid2 = get_session_id(&res2);
828        let rec2 = get_record(&session_store, &sid2).await;
829
830        let expected_max_age = inactivity_duration.whole_seconds();
831        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
832        assert!(sid1 == sid2);
833        assert!(rec1.expiry_date < rec2.expiry_date);
834
835        Ok(())
836    }
837
838    #[tokio::test]
839    async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
840        let session_store = MemoryStore::default();
841        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
842        let session_layer = SessionManagerLayer::new(session_store.clone())
843            .with_expiry(Expiry::AtDateTime(expiry_time))
844            .with_always_save(true);
845        let mut svc = ServiceBuilder::new()
846            .layer(session_layer)
847            .service_fn(handler);
848
849        let req1 = Request::builder().body(Body::empty())?;
850        let res1 = svc.call(req1).await?;
851        let sid1 = get_session_id(&res1);
852        let rec1 = get_record(&session_store, &sid1).await;
853        let req2 = Request::builder()
854            .header(http::header::COOKIE, format!("id={}", sid1))
855            .body(Body::empty())?;
856        let res2 = svc.call(req2).await?;
857        let sid2 = get_session_id(&res2);
858        let rec2 = get_record(&session_store, &sid2).await;
859
860        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
861        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
862        assert!(sid1 == sid2);
863        assert!(rec1.expiry_date == rec2.expiry_date);
864
865        Ok(())
866    }
867
868    #[tokio::test]
869    async fn secure_test() -> anyhow::Result<()> {
870        let session_store = MemoryStore::default();
871        let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
872        let svc = ServiceBuilder::new()
873            .layer(session_layer)
874            .service_fn(handler);
875
876        let req = Request::builder().body(Body::empty())?;
877        let res = svc.oneshot(req).await?;
878
879        assert!(cookie_value_matches(&res, |s| s.contains("Secure")));
880
881        let session_store = MemoryStore::default();
882        let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
883        let svc = ServiceBuilder::new()
884            .layer(session_layer)
885            .service_fn(handler);
886
887        let req = Request::builder().body(Body::empty())?;
888        let res = svc.oneshot(req).await?;
889
890        assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));
891
892        Ok(())
893    }
894
895    #[tokio::test]
896    async fn path_test() -> anyhow::Result<()> {
897        let session_store = MemoryStore::default();
898        let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
899        let svc = ServiceBuilder::new()
900            .layer(session_layer)
901            .service_fn(handler);
902
903        let req = Request::builder().body(Body::empty())?;
904        let res = svc.oneshot(req).await?;
905
906        assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));
907
908        Ok(())
909    }
910
911    #[tokio::test]
912    async fn domain_test() -> anyhow::Result<()> {
913        let session_store = MemoryStore::default();
914        let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
915        let svc = ServiceBuilder::new()
916            .layer(session_layer)
917            .service_fn(handler);
918
919        let req = Request::builder().body(Body::empty())?;
920        let res = svc.oneshot(req).await?;
921
922        assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));
923
924        Ok(())
925    }
926
927    #[cfg(feature = "signed")]
928    #[tokio::test]
929    async fn signed_test() -> anyhow::Result<()> {
930        let key = Key::generate();
931        let session_store = MemoryStore::default();
932        let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
933        let svc = ServiceBuilder::new()
934            .layer(session_layer)
935            .service_fn(handler);
936
937        let req = Request::builder().body(Body::empty())?;
938        let res = svc.oneshot(req).await?;
939
940        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
941
942        Ok(())
943    }
944
945    #[cfg(feature = "private")]
946    #[tokio::test]
947    async fn private_test() -> anyhow::Result<()> {
948        let key = Key::generate();
949        let session_store = MemoryStore::default();
950        let session_layer = SessionManagerLayer::new(session_store).with_private(key);
951        let svc = ServiceBuilder::new()
952            .layer(session_layer)
953            .service_fn(handler);
954
955        let req = Request::builder().body(Body::empty())?;
956        let res = svc.oneshot(req).await?;
957
958        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
959
960        Ok(())
961    }
962
963    fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
964    where
965        F: FnOnce(&str) -> bool,
966    {
967        res.headers()
968            .get(http::header::SET_COOKIE)
969            .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
970    }
971
972    fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
973        res.headers()
974            .get(http::header::SET_COOKIE)
975            .is_some_and(|set_cookie| {
976                set_cookie.to_str().is_ok_and(|s| {
977                    let max_age_value = s
978                        .split("Max-Age=")
979                        .nth(1)
980                        .unwrap_or_default()
981                        .split(';')
982                        .next()
983                        .unwrap_or_default()
984                        .parse::<i64>()
985                        .unwrap_or_default();
986                    (max_age_value - expected_value).abs() <= 1
987                })
988            })
989    }
990
991    fn get_session_id(res: &Response<Body>) -> String {
992        res.headers()
993            .get(http::header::SET_COOKIE)
994            .unwrap()
995            .to_str()
996            .unwrap()
997            .split("id=")
998            .nth(1)
999            .unwrap()
1000            .split(";")
1001            .next()
1002            .unwrap()
1003            .to_string()
1004    }
1005
1006    async fn get_record(store: &impl SessionStore, id: &str) -> Record {
1007        store
1008            .load(&Id::from_str(id).unwrap())
1009            .await
1010            .unwrap()
1011            .unwrap()
1012    }
1013}