Skip to main content

tower_sessions_ext/
service.rs

1//! A middleware that provides [`Session`] as a request extension.
2use std::{
3    borrow::Cow,
4    fmt,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use http::{Request, Response};
12use time::OffsetDateTime;
13#[cfg(any(feature = "signed", feature = "private"))]
14use tower_cookies::Key;
15use tower_cookies::{Cookie, CookieManager, Cookies, cookie::SameSite};
16use tower_layer::Layer;
17use tower_service::Service;
18use tracing::Instrument;
19
20use crate::{
21    Session, SessionStore,
22    session::{self, Expiry, OnExpireCallback},
23};
24
25#[doc(hidden)]
26pub trait CookieController: Clone + Send + 'static {
27    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
28    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
29    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
30}
31
32#[doc(hidden)]
33#[derive(Debug, Clone)]
34pub struct PlaintextCookie;
35
36impl CookieController for PlaintextCookie {
37    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
38        cookies.get(name).map(Cookie::into_owned)
39    }
40
41    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
42        cookies.add(cookie)
43    }
44
45    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
46        cookies.remove(cookie)
47    }
48}
49
/// Cookie controller that routes every operation through the jar's signed
/// view (`Cookies::signed`) using `key`.
#[doc(hidden)]
#[cfg(feature = "signed")]
#[derive(Debug, Clone)]
pub struct SignedCookie {
    key: Key,
}

#[cfg(feature = "signed")]
impl CookieController for SignedCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.signed(&self.key);
        let found = jar.get(name)?;
        Some(found.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.signed(&self.key);
        jar.remove(cookie);
    }
}
71
/// Cookie controller that routes every operation through the jar's private
/// (encrypted) view (`Cookies::private`) using `key`.
#[doc(hidden)]
#[cfg(feature = "private")]
#[derive(Debug, Clone)]
pub struct PrivateCookie {
    key: Key,
}

#[cfg(feature = "private")]
impl CookieController for PrivateCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let jar = cookies.private(&self.key);
        let found = jar.get(name)?;
        Some(found.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        let jar = cookies.private(&self.key);
        jar.remove(cookie);
    }
}
93
94#[derive(Clone)]
95struct SessionConfig<'a> {
96    name: Cow<'a, str>,
97    http_only: bool,
98    same_site: SameSite,
99    expiry: Option<Expiry>,
100    secure: bool,
101    path: Cow<'a, str>,
102    domain: Option<Cow<'a, str>>,
103    always_save: bool,
104    on_expire: Option<OnExpireCallback>,
105}
106
107impl fmt::Debug for SessionConfig<'_> {
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        f.debug_struct("SessionConfig")
110            .field("name", &self.name)
111            .field("http_only", &self.http_only)
112            .field("same_site", &self.same_site)
113            .field("expiry", &self.expiry)
114            .field("secure", &self.secure)
115            .field("path", &self.path)
116            .field("domain", &self.domain)
117            .field("always_save", &self.always_save)
118            .field("on_expire", &self.on_expire.as_ref().map(|_| "Some(_)"))
119            .finish()
120    }
121}
122
123impl<'a> SessionConfig<'a> {
124    fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
125        let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
126            .http_only(self.http_only)
127            .same_site(self.same_site)
128            .secure(self.secure)
129            .path(self.path);
130
131        cookie_builder = match expiry {
132            Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
133            Some(Expiry::AtDateTime(datetime)) => {
134                cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
135            }
136            // Session cookie: no Max-Age so the browser treats it as ending when the session ends.
137            Some(Expiry::OnSessionEnd(_)) | None => cookie_builder,
138        };
139
140        if let Some(domain) = self.domain {
141            cookie_builder = cookie_builder.domain(domain);
142        }
143
144        cookie_builder.build()
145    }
146}
147
148impl Default for SessionConfig<'_> {
149    fn default() -> Self {
150        Self {
151            name: "id".into(), /* See: https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html#session-id-name-fingerprinting */
152            http_only: true,
153            same_site: SameSite::Strict,
154            expiry: None, // TODO: Is `Max-Age: "Session"` the right default?
155            secure: true,
156            path: "/".into(),
157            domain: None,
158            always_save: false,
159            on_expire: None,
160        }
161    }
162}
163
164/// A middleware that provides [`Session`] as a request extension.
165#[derive(Debug, Clone)]
166pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
167    inner: S,
168    session_store: Arc<Store>,
169    session_config: SessionConfig<'static>,
170    cookie_controller: C,
171}
172
173impl<S, Store: SessionStore> SessionManager<S, Store> {
174    /// Create a new [`SessionManager`].
175    pub fn new(inner: S, session_store: Store) -> Self {
176        Self {
177            inner,
178            session_store: Arc::new(session_store),
179            session_config: Default::default(),
180            cookie_controller: PlaintextCookie,
181        }
182    }
183}
184
185impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
186    for SessionManager<S, Store, C>
187where
188    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
189    S::Future: Send,
190    ReqBody: Send + 'static,
191    ResBody: Default + Send,
192{
193    type Response = S::Response;
194    type Error = S::Error;
195    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
196
197    #[inline]
198    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
199        self.inner.poll_ready(cx)
200    }
201
202    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
203        let span = tracing::debug_span!("call");
204
205        let session_store = self.session_store.clone();
206        let session_config = self.session_config.clone();
207        let cookie_controller = self.cookie_controller.clone();
208
209        // Because the inner service can panic until ready, we need to ensure we only
210        // use the ready service.
211        //
212        // See: https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services
213        let clone = self.inner.clone();
214        let mut inner = std::mem::replace(&mut self.inner, clone);
215
216        Box::pin(
217            async move {
218                let Some(cookies) = req.extensions().get::<_>().cloned() else {
219                    // In practice this should never happen because we wrap `CookieManager`
220                    // directly.
221                    tracing::error!("missing cookies request extension");
222                    return Ok(Response::default());
223                };
224
225                let session_cookie = cookie_controller.get(&cookies, &session_config.name);
226                let session_id = session_cookie.as_ref().and_then(|cookie| {
227                    cookie
228                        .value()
229                        .parse::<session::Id>()
230                        .map_err(|err| {
231                            tracing::warn!(
232                                err = %err,
233                                "possibly suspicious activity: malformed session id"
234                            )
235                        })
236                        .ok()
237                });
238
239                let session = Session::new(
240                    session_id,
241                    session_store,
242                    session_config.expiry,
243                    session_config.on_expire.clone(),
244                );
245
246                req.extensions_mut().insert(session.clone());
247
248                let res = inner.call(req).await?;
249
250                let modified = session.is_modified();
251                let empty = session.is_empty().await;
252
253                tracing::trace!(
254                    modified = modified,
255                    empty = empty,
256                    always_save = session_config.always_save,
257                    "session response state",
258                );
259
260                match session_cookie {
261                    Some(mut cookie) if empty => {
262                        tracing::debug!("removing session cookie");
263
264                        // Path and domain must be manually set to ensure a proper removal cookie is
265                        // constructed.
266                        //
267                        // See: https://docs.rs/cookie/latest/cookie/struct.CookieJar.html#method.remove
268                        cookie.set_path(session_config.path);
269                        if let Some(domain) = session_config.domain {
270                            cookie.set_domain(domain);
271                        }
272
273                        cookie_controller.remove(&cookies, cookie);
274                    }
275
276                    _ if (modified || session_config.always_save)
277                        && !empty
278                        && !res.status().is_server_error() =>
279                    {
280                        tracing::debug!("saving session");
281                        if let Err(err) = session.save().await {
282                            tracing::error!(err = %err, "failed to save session");
283
284                            let mut res = Response::default();
285                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
286                            return Ok(res);
287                        }
288
289                        let Some(session_id) = session.id() else {
290                            tracing::error!("missing session id");
291
292                            let mut res = Response::default();
293                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
294                            return Ok(res);
295                        };
296
297                        let expiry = session.expiry();
298                        let session_cookie = session_config.build_cookie(session_id, expiry);
299
300                        tracing::debug!("adding session cookie");
301                        cookie_controller.add(&cookies, session_cookie);
302                    }
303
304                    _ => (),
305                };
306
307                Ok(res)
308            }
309            .instrument(span),
310        )
311    }
312}
313
314/// A layer for providing [`Session`] as a request extension.
315#[derive(Debug, Clone)]
316pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
317    session_store: Arc<Store>,
318    session_config: SessionConfig<'static>,
319    cookie_controller: C,
320}
321
322impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
323    /// Configures the name of the cookie used for the session.
324    /// The default value is `"id"`.
325    ///
326    /// # Examples
327    ///
328    /// ```rust
329    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
330    ///
331    /// let session_store = MemoryStore::default();
332    /// let session_service = SessionManagerLayer::new(session_store).with_name("my.sid");
333    /// ```
334    pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
335        self.session_config.name = name.into();
336        self
337    }
338
339    /// Configures the `"HttpOnly"` attribute of the cookie used for the
340    /// session.
341    ///
342    /// # ⚠️ **Warning: Cross-site scripting risk**
343    ///
344    /// Applications should generally **not** override the default value of
345    /// `true`. If you do, you are exposing your application to increased risk
346    /// of cookie theft via techniques like cross-site scripting.
347    ///
348    /// # Examples
349    ///
350    /// ```rust
351    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
352    ///
353    /// let session_store = MemoryStore::default();
354    /// let session_service = SessionManagerLayer::new(session_store).with_http_only(true);
355    /// ```
356    pub fn with_http_only(mut self, http_only: bool) -> Self {
357        self.session_config.http_only = http_only;
358        self
359    }
360
361    /// Configures the `"SameSite"` attribute of the cookie used for the
362    /// session.
363    /// The default value is [`SameSite::Strict`].
364    ///
365    /// # Examples
366    ///
367    /// ```rust
368    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::SameSite};
369    ///
370    /// let session_store = MemoryStore::default();
371    /// let session_service = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
372    /// ```
373    pub fn with_same_site(mut self, same_site: SameSite) -> Self {
374        self.session_config.same_site = same_site;
375        self
376    }
377
378    /// Configures the `"Max-Age"` attribute of the cookie used for the session.
379    /// The default value is `None`.
380    ///
381    /// # Examples
382    ///
383    /// ```rust
384    /// use time::Duration;
385    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
386    ///
387    /// let session_store = MemoryStore::default();
388    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
389    /// let session_service = SessionManagerLayer::new(session_store).with_expiry(session_expiry);
390    /// ```
391    pub fn with_expiry(mut self, expiry: Expiry) -> Self {
392        self.session_config.expiry = Some(expiry);
393        self
394    }
395
396    /// Configures the `"Secure"` attribute of the cookie used for the session.
397    /// The default value is `true`.
398    ///
399    /// # Examples
400    ///
401    /// ```rust
402    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
403    ///
404    /// let session_store = MemoryStore::default();
405    /// let session_service = SessionManagerLayer::new(session_store).with_secure(true);
406    /// ```
407    pub fn with_secure(mut self, secure: bool) -> Self {
408        self.session_config.secure = secure;
409        self
410    }
411
412    /// Configures the `"Path"` attribute of the cookie used for the session.
413    /// The default value is `"/"`.
414    ///
415    /// # Examples
416    ///
417    /// ```rust
418    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
419    ///
420    /// let session_store = MemoryStore::default();
421    /// let session_service = SessionManagerLayer::new(session_store).with_path("/some/path");
422    /// ```
423    pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
424        self.session_config.path = path.into();
425        self
426    }
427
428    /// Configures the `"Domain"` attribute of the cookie used for the session.
429    /// The default value is `None`.
430    ///
431    /// # Examples
432    ///
433    /// ```rust
434    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
435    ///
436    /// let session_store = MemoryStore::default();
437    /// let session_service = SessionManagerLayer::new(session_store).with_domain("localhost");
438    /// ```
439    pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
440        self.session_config.domain = Some(domain.into());
441        self
442    }
443
444    /// Configures whether unmodified session should be saved on read or not.
445    /// When the value is `true`, the session will be saved even if it was not
446    /// changed.
447    ///
448    /// This is useful when you want to reset [`Session`] expiration time
449    /// on any valid request at the cost of higher [`SessionStore`] write
450    /// activity and transmitting `set-cookie` header with each response.
451    ///
452    /// It makes sense to use this setting with relative session expiration
453    /// values, such as `Expiry::OnInactivity(Duration)`. This setting will
454    /// _not_ cause session id to be cycled on save.
455    ///
456    /// The default value is `false`.
457    ///
458    /// # Examples
459    ///
460    /// ```rust
461    /// use time::Duration;
462    /// use tower_sessions_ext::{Expiry, MemoryStore, SessionManagerLayer};
463    ///
464    /// let session_store = MemoryStore::default();
465    /// let session_expiry = Expiry::OnInactivity(Duration::hours(1));
466    /// let session_service = SessionManagerLayer::new(session_store)
467    ///     .with_expiry(session_expiry)
468    ///     .with_always_save(true);
469    /// ```
470    pub fn with_always_save(mut self, always_save: bool) -> Self {
471        self.session_config.always_save = always_save;
472        self
473    }
474
475    /// Registers a callback that is invoked when a session is discovered to have expired.
476    ///
477    /// The callback runs when the store returns `None` for a session id that was sent by the
478    /// client (e.g. the session was removed by a background cleanup or expired). Use this for
479    /// cleanup such as revoking tokens or logging out the user in other systems.
480    ///
481    /// The callback is invoked synchronously during request handling; for heavy work consider
482    /// spawning a task from within the callback.
483    ///
484    /// # Examples
485    ///
486    /// ```rust
487    /// use std::sync::Arc;
488    ///
489    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
490    ///
491    /// let session_store = MemoryStore::default();
492    /// let session_service = SessionManagerLayer::new(session_store)
493    ///     .with_on_expire(|session_id| {
494    ///         tracing::info!(%session_id, "session expired");
495    ///     });
496    /// ```
497    pub fn with_on_expire<F>(mut self, f: F) -> Self
498    where
499        F: Fn(session::Id) + Send + Sync + 'static,
500    {
501        self.session_config.on_expire = Some(Arc::new(f));
502        self
503    }
504
505    /// Manages the session cookie via a signed interface.
506    ///
507    /// See [`SignedCookies`](tower_cookies::SignedCookies).
508    ///
509    /// ```rust
510    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
511    ///
512    /// # /*
513    /// let key = { /* a cryptographically random key >= 64 bytes */ };
514    /// # */
515    /// # let key: &Vec<u8> = &(0..64).collect();
516    /// # let key: &[u8] = &key[..];
517    /// # let key = Key::try_from(key).unwrap();
518    ///
519    /// let session_store = MemoryStore::default();
520    /// let session_service = SessionManagerLayer::new(session_store).with_signed(key);
521    /// ```
522    #[cfg(feature = "signed")]
523    pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
524        SessionManagerLayer::<Store, SignedCookie> {
525            session_store: self.session_store,
526            session_config: self.session_config,
527            cookie_controller: SignedCookie { key },
528        }
529    }
530
531    /// Manages the session cookie via an encrypted interface.
532    ///
533    /// See [`PrivateCookies`](tower_cookies::PrivateCookies).
534    ///
535    /// ```rust
536    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer, cookie::Key};
537    ///
538    /// # /*
539    /// let key = { /* a cryptographically random key >= 64 bytes */ };
540    /// # */
541    /// # let key: &Vec<u8> = &(0..64).collect();
542    /// # let key: &[u8] = &key[..];
543    /// # let key = Key::try_from(key).unwrap();
544    ///
545    /// let session_store = MemoryStore::default();
546    /// let session_service = SessionManagerLayer::new(session_store).with_private(key);
547    /// ```
548    #[cfg(feature = "private")]
549    pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
550        SessionManagerLayer::<Store, PrivateCookie> {
551            session_store: self.session_store,
552            session_config: self.session_config,
553            cookie_controller: PrivateCookie { key },
554        }
555    }
556}
557
558impl<Store: SessionStore> SessionManagerLayer<Store> {
559    /// Create a new [`SessionManagerLayer`] with the provided session store
560    /// and default cookie configuration.
561    ///
562    /// # Examples
563    ///
564    /// ```rust
565    /// use tower_sessions_ext::{MemoryStore, SessionManagerLayer};
566    ///
567    /// let session_store = MemoryStore::default();
568    /// let session_service = SessionManagerLayer::new(session_store);
569    /// ```
570    pub fn new(session_store: Store) -> Self {
571        let session_config = SessionConfig::default();
572
573        Self {
574            session_store: Arc::new(session_store),
575            session_config,
576            cookie_controller: PlaintextCookie,
577        }
578    }
579}
580
581impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
582    type Service = CookieManager<SessionManager<S, Store, C>>;
583
584    fn layer(&self, inner: S) -> Self::Service {
585        let session_manager = SessionManager {
586            inner,
587            session_store: self.session_store.clone(),
588            session_config: self.session_config.clone(),
589            cookie_controller: self.cookie_controller.clone(),
590        };
591
592        CookieManager::new(session_manager)
593    }
594}
595
596#[cfg(test)]
597mod tests {
598    use std::str::FromStr;
599
600    use anyhow::anyhow;
601    use axum::body::Body;
602    use tower::{ServiceBuilder, ServiceExt};
603    use tower_sessions_ext_core::session::DEFAULT_DURATION;
604    use tower_sessions_ext_memory_store::MemoryStore;
605
606    use super::*;
607    use crate::session::{Id, Record};
608
609    async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
610        let session = req
611            .extensions()
612            .get::<Session>()
613            .ok_or(anyhow!("Missing session"))?;
614
615        session.insert("foo", 42).await?;
616
617        Ok(Response::new(Body::empty()))
618    }
619
620    async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
621        Ok(Response::new(Body::empty()))
622    }
623
624    #[tokio::test]
625    async fn basic_service_test() -> anyhow::Result<()> {
626        let session_store = MemoryStore::default();
627        let session_layer = SessionManagerLayer::new(session_store);
628        let svc = ServiceBuilder::new()
629            .layer(session_layer)
630            .service_fn(handler);
631
632        let req = Request::builder().body(Body::empty())?;
633        let res = svc.clone().oneshot(req).await?;
634
635        let session = res.headers().get(http::header::SET_COOKIE);
636        assert!(session.is_some());
637
638        let req = Request::builder()
639            .header(http::header::COOKIE, session.unwrap())
640            .body(Body::empty())?;
641        let res = svc.oneshot(req).await?;
642
643        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
644
645        Ok(())
646    }
647
648    #[tokio::test]
649    async fn bogus_cookie_test() -> anyhow::Result<()> {
650        let session_store = MemoryStore::default();
651        let session_layer = SessionManagerLayer::new(session_store);
652        let svc = ServiceBuilder::new()
653            .layer(session_layer)
654            .service_fn(handler);
655
656        let req = Request::builder().body(Body::empty())?;
657        let res = svc.clone().oneshot(req).await?;
658
659        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
660
661        let req = Request::builder()
662            .header(http::header::COOKIE, "id=bogus")
663            .body(Body::empty())?;
664        let res = svc.oneshot(req).await?;
665
666        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
667
668        Ok(())
669    }
670
671    #[tokio::test]
672    async fn no_set_cookie_test() -> anyhow::Result<()> {
673        let session_store = MemoryStore::default();
674        let session_layer = SessionManagerLayer::new(session_store);
675        let svc = ServiceBuilder::new()
676            .layer(session_layer)
677            .service_fn(noop_handler);
678
679        let req = Request::builder().body(Body::empty())?;
680        let res = svc.oneshot(req).await?;
681
682        assert!(res.headers().get(http::header::SET_COOKIE).is_none());
683
684        Ok(())
685    }
686
687    #[tokio::test]
688    async fn name_test() -> anyhow::Result<()> {
689        let session_store = MemoryStore::default();
690        let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
691        let svc = ServiceBuilder::new()
692            .layer(session_layer)
693            .service_fn(handler);
694
695        let req = Request::builder().body(Body::empty())?;
696        let res = svc.oneshot(req).await?;
697
698        assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));
699
700        Ok(())
701    }
702
703    #[tokio::test]
704    async fn http_only_test() -> anyhow::Result<()> {
705        let session_store = MemoryStore::default();
706        let session_layer = SessionManagerLayer::new(session_store);
707        let svc = ServiceBuilder::new()
708            .layer(session_layer)
709            .service_fn(handler);
710
711        let req = Request::builder().body(Body::empty())?;
712        let res = svc.oneshot(req).await?;
713
714        assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));
715
716        let session_store = MemoryStore::default();
717        let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
718        let svc = ServiceBuilder::new()
719            .layer(session_layer)
720            .service_fn(handler);
721
722        let req = Request::builder().body(Body::empty())?;
723        let res = svc.oneshot(req).await?;
724
725        assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));
726
727        Ok(())
728    }
729
730    #[tokio::test]
731    async fn same_site_strict_test() -> anyhow::Result<()> {
732        let session_store = MemoryStore::default();
733        let session_layer =
734            SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
735        let svc = ServiceBuilder::new()
736            .layer(session_layer)
737            .service_fn(handler);
738
739        let req = Request::builder().body(Body::empty())?;
740        let res = svc.oneshot(req).await?;
741
742        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));
743
744        Ok(())
745    }
746
747    #[tokio::test]
748    async fn same_site_lax_test() -> anyhow::Result<()> {
749        let session_store = MemoryStore::default();
750        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
751        let svc = ServiceBuilder::new()
752            .layer(session_layer)
753            .service_fn(handler);
754
755        let req = Request::builder().body(Body::empty())?;
756        let res = svc.oneshot(req).await?;
757
758        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));
759
760        Ok(())
761    }
762
763    #[tokio::test]
764    async fn same_site_none_test() -> anyhow::Result<()> {
765        let session_store = MemoryStore::default();
766        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
767        let svc = ServiceBuilder::new()
768            .layer(session_layer)
769            .service_fn(handler);
770
771        let req = Request::builder().body(Body::empty())?;
772        let res = svc.oneshot(req).await?;
773
774        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));
775
776        Ok(())
777    }
778
779    #[tokio::test]
780    async fn expiry_on_session_end_test() -> anyhow::Result<()> {
781        let session_store = MemoryStore::default();
782        let session_layer = SessionManagerLayer::new(session_store)
783            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION));
784        let svc = ServiceBuilder::new()
785            .layer(session_layer)
786            .service_fn(handler);
787
788        let req = Request::builder().body(Body::empty())?;
789        let res = svc.oneshot(req).await?;
790
791        assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));
792
793        Ok(())
794    }
795
796    #[tokio::test]
797    async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
798        let session_store = MemoryStore::default();
799        let inactivity_duration = time::Duration::hours(2);
800        let session_layer = SessionManagerLayer::new(session_store)
801            .with_expiry(Expiry::OnInactivity(inactivity_duration));
802        let svc = ServiceBuilder::new()
803            .layer(session_layer)
804            .service_fn(handler);
805
806        let req = Request::builder().body(Body::empty())?;
807        let res = svc.oneshot(req).await?;
808
809        let expected_max_age = inactivity_duration.whole_seconds();
810        assert!(cookie_has_expected_max_age(&res, expected_max_age));
811
812        Ok(())
813    }
814
815    #[tokio::test]
816    async fn expiry_at_date_time_test() -> anyhow::Result<()> {
817        let session_store = MemoryStore::default();
818        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
819        let session_layer =
820            SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
821        let svc = ServiceBuilder::new()
822            .layer(session_layer)
823            .service_fn(handler);
824
825        let req = Request::builder().body(Body::empty())?;
826        let res = svc.oneshot(req).await?;
827
828        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
829        assert!(cookie_has_expected_max_age(&res, expected_max_age));
830
831        Ok(())
832    }
833
834    #[tokio::test]
835    async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
836        let session_store = MemoryStore::default();
837        let session_layer = SessionManagerLayer::new(session_store.clone())
838            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION))
839            .with_always_save(true);
840        let mut svc = ServiceBuilder::new()
841            .layer(session_layer)
842            .service_fn(handler);
843
844        let req1 = Request::builder().body(Body::empty())?;
845        let res1 = svc.call(req1).await?;
846        let sid1 = get_session_id(&res1);
847        let rec1 = get_record(&session_store, &sid1).await;
848        let req2 = Request::builder()
849            .header(http::header::COOKIE, format!("id={}", sid1))
850            .body(Body::empty())?;
851        let res2 = svc.call(req2).await?;
852        let sid2 = get_session_id(&res2);
853        let rec2 = get_record(&session_store, &sid2).await;
854
855        assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
856        assert!(sid1 == sid2);
857        assert!(rec1.expiry_date < rec2.expiry_date);
858
859        Ok(())
860    }
861
862    #[tokio::test]
863    async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
864        let session_store = MemoryStore::default();
865        let inactivity_duration = time::Duration::hours(2);
866        let session_layer = SessionManagerLayer::new(session_store.clone())
867            .with_expiry(Expiry::OnInactivity(inactivity_duration))
868            .with_always_save(true);
869        let mut svc = ServiceBuilder::new()
870            .layer(session_layer)
871            .service_fn(handler);
872
873        let req1 = Request::builder().body(Body::empty())?;
874        let res1 = svc.call(req1).await?;
875        let sid1 = get_session_id(&res1);
876        let rec1 = get_record(&session_store, &sid1).await;
877        let req2 = Request::builder()
878            .header(http::header::COOKIE, format!("id={}", sid1))
879            .body(Body::empty())?;
880        let res2 = svc.call(req2).await?;
881        let sid2 = get_session_id(&res2);
882        let rec2 = get_record(&session_store, &sid2).await;
883
884        let expected_max_age = inactivity_duration.whole_seconds();
885        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
886        assert!(sid1 == sid2);
887        assert!(rec1.expiry_date < rec2.expiry_date);
888
889        Ok(())
890    }
891
892    #[tokio::test]
893    async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
894        let session_store = MemoryStore::default();
895        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
896        let session_layer = SessionManagerLayer::new(session_store.clone())
897            .with_expiry(Expiry::AtDateTime(expiry_time))
898            .with_always_save(true);
899        let mut svc = ServiceBuilder::new()
900            .layer(session_layer)
901            .service_fn(handler);
902
903        let req1 = Request::builder().body(Body::empty())?;
904        let res1 = svc.call(req1).await?;
905        let sid1 = get_session_id(&res1);
906        let rec1 = get_record(&session_store, &sid1).await;
907        let req2 = Request::builder()
908            .header(http::header::COOKIE, format!("id={}", sid1))
909            .body(Body::empty())?;
910        let res2 = svc.call(req2).await?;
911        let sid2 = get_session_id(&res2);
912        let rec2 = get_record(&session_store, &sid2).await;
913
914        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
915        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
916        assert!(sid1 == sid2);
917        assert!(rec1.expiry_date == rec2.expiry_date);
918
919        Ok(())
920    }
921
922    #[tokio::test]
923    async fn secure_test() -> anyhow::Result<()> {
924        let session_store = MemoryStore::default();
925        let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
926        let svc = ServiceBuilder::new()
927            .layer(session_layer)
928            .service_fn(handler);
929
930        let req = Request::builder().body(Body::empty())?;
931        let res = svc.oneshot(req).await?;
932
933        assert!(cookie_value_matches(&res, |s| s.contains("Secure")));
934
935        let session_store = MemoryStore::default();
936        let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
937        let svc = ServiceBuilder::new()
938            .layer(session_layer)
939            .service_fn(handler);
940
941        let req = Request::builder().body(Body::empty())?;
942        let res = svc.oneshot(req).await?;
943
944        assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));
945
946        Ok(())
947    }
948
949    #[tokio::test]
950    async fn path_test() -> anyhow::Result<()> {
951        let session_store = MemoryStore::default();
952        let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
953        let svc = ServiceBuilder::new()
954            .layer(session_layer)
955            .service_fn(handler);
956
957        let req = Request::builder().body(Body::empty())?;
958        let res = svc.oneshot(req).await?;
959
960        assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));
961
962        Ok(())
963    }
964
965    #[tokio::test]
966    async fn domain_test() -> anyhow::Result<()> {
967        let session_store = MemoryStore::default();
968        let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
969        let svc = ServiceBuilder::new()
970            .layer(session_layer)
971            .service_fn(handler);
972
973        let req = Request::builder().body(Body::empty())?;
974        let res = svc.oneshot(req).await?;
975
976        assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));
977
978        Ok(())
979    }
980
981    #[cfg(feature = "signed")]
982    #[tokio::test]
983    async fn signed_test() -> anyhow::Result<()> {
984        let key = Key::generate();
985        let session_store = MemoryStore::default();
986        let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
987        let svc = ServiceBuilder::new()
988            .layer(session_layer)
989            .service_fn(handler);
990
991        let req = Request::builder().body(Body::empty())?;
992        let res = svc.oneshot(req).await?;
993
994        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
995
996        Ok(())
997    }
998
999    #[cfg(feature = "private")]
1000    #[tokio::test]
1001    async fn private_test() -> anyhow::Result<()> {
1002        let key = Key::generate();
1003        let session_store = MemoryStore::default();
1004        let session_layer = SessionManagerLayer::new(session_store).with_private(key);
1005        let svc = ServiceBuilder::new()
1006            .layer(session_layer)
1007            .service_fn(handler);
1008
1009        let req = Request::builder().body(Body::empty())?;
1010        let res = svc.oneshot(req).await?;
1011
1012        assert!(res.headers().get(http::header::SET_COOKIE).is_some());
1013
1014        Ok(())
1015    }
1016
1017    fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
1018    where
1019        F: FnOnce(&str) -> bool,
1020    {
1021        res.headers()
1022            .get(http::header::SET_COOKIE)
1023            .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
1024    }
1025
1026    fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
1027        res.headers()
1028            .get(http::header::SET_COOKIE)
1029            .is_some_and(|set_cookie| {
1030                set_cookie.to_str().is_ok_and(|s| {
1031                    let max_age_value = s
1032                        .split("Max-Age=")
1033                        .nth(1)
1034                        .unwrap_or_default()
1035                        .split(';')
1036                        .next()
1037                        .unwrap_or_default()
1038                        .parse::<i64>()
1039                        .unwrap_or_default();
1040                    (max_age_value - expected_value).abs() <= 1
1041                })
1042            })
1043    }
1044
1045    fn get_session_id(res: &Response<Body>) -> String {
1046        res.headers()
1047            .get(http::header::SET_COOKIE)
1048            .unwrap()
1049            .to_str()
1050            .unwrap()
1051            .split("id=")
1052            .nth(1)
1053            .unwrap()
1054            .split(";")
1055            .next()
1056            .unwrap()
1057            .to_string()
1058    }
1059
1060    async fn get_record(store: &impl SessionStore, id: &str) -> Record {
1061        store
1062            .load(&Id::from_str(id).unwrap())
1063            .await
1064            .unwrap()
1065            .unwrap()
1066    }
1067}