1use std::{
3 borrow::Cow,
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use http::{Request, Response};
11use time::OffsetDateTime;
12#[cfg(any(feature = "signed", feature = "private"))]
13use tower_cookies::Key;
14use tower_cookies::{cookie::SameSite, Cookie, CookieManager, Cookies};
15use tower_layer::Layer;
16use tower_service::Service;
17use tracing::Instrument;
18
19use crate::{
20 session::{self, Expiry},
21 Session, SessionStore,
22};
23
24#[doc(hidden)]
25pub trait CookieController: Clone + Send + 'static {
26 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
27 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
28 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
29}
30
31#[doc(hidden)]
32#[derive(Debug, Clone)]
33pub struct PlaintextCookie;
34
35impl CookieController for PlaintextCookie {
36 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
37 cookies.get(name).map(Cookie::into_owned)
38 }
39
40 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
41 cookies.add(cookie)
42 }
43
44 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
45 cookies.remove(cookie)
46 }
47}
48
49#[doc(hidden)]
50#[cfg(feature = "signed")]
51#[derive(Debug, Clone)]
52pub struct SignedCookie {
53 key: Key,
54}
55
56#[cfg(feature = "signed")]
57impl CookieController for SignedCookie {
58 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
59 cookies.signed(&self.key).get(name).map(Cookie::into_owned)
60 }
61
62 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
63 cookies.signed(&self.key).add(cookie)
64 }
65
66 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
67 cookies.signed(&self.key).remove(cookie)
68 }
69}
70
71#[doc(hidden)]
72#[cfg(feature = "private")]
73#[derive(Debug, Clone)]
74pub struct PrivateCookie {
75 key: Key,
76}
77
78#[cfg(feature = "private")]
79impl CookieController for PrivateCookie {
80 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
81 cookies.private(&self.key).get(name).map(Cookie::into_owned)
82 }
83
84 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
85 cookies.private(&self.key).add(cookie)
86 }
87
88 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
89 cookies.private(&self.key).remove(cookie)
90 }
91}
92
93#[derive(Debug, Clone)]
94struct SessionConfig<'a> {
95 name: Cow<'a, str>,
96 http_only: bool,
97 same_site: SameSite,
98 expiry: Option<Expiry>,
99 secure: bool,
100 path: Cow<'a, str>,
101 domain: Option<Cow<'a, str>>,
102 always_save: bool,
103}
104
105impl<'a> SessionConfig<'a> {
106 fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
107 let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
108 .http_only(self.http_only)
109 .same_site(self.same_site)
110 .secure(self.secure)
111 .path(self.path);
112
113 cookie_builder = match expiry {
114 Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
115 Some(Expiry::AtDateTime(datetime)) => {
116 cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
117 }
118 Some(Expiry::OnSessionEnd) | None => cookie_builder,
119 };
120
121 if let Some(domain) = self.domain {
122 cookie_builder = cookie_builder.domain(domain);
123 }
124
125 cookie_builder.build()
126 }
127}
128
129impl Default for SessionConfig<'_> {
130 fn default() -> Self {
131 Self {
132 name: "id".into(), http_only: true,
134 same_site: SameSite::Strict,
135 expiry: None, secure: true,
137 path: "/".into(),
138 domain: None,
139 always_save: false,
140 }
141 }
142}
143
144#[derive(Debug, Clone)]
146pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
147 inner: S,
148 session_store: Arc<Store>,
149 session_config: SessionConfig<'static>,
150 cookie_controller: C,
151}
152
153impl<S, Store: SessionStore> SessionManager<S, Store> {
154 pub fn new(inner: S, session_store: Store) -> Self {
156 Self {
157 inner,
158 session_store: Arc::new(session_store),
159 session_config: Default::default(),
160 cookie_controller: PlaintextCookie,
161 }
162 }
163}
164
165impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
166 for SessionManager<S, Store, C>
167where
168 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
169 S::Future: Send,
170 ReqBody: Send + 'static,
171 ResBody: Default + Send,
172{
173 type Response = S::Response;
174 type Error = S::Error;
175 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
176
177 #[inline]
178 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
179 self.inner.poll_ready(cx)
180 }
181
182 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
183 let span = tracing::info_span!("call");
184
185 let session_store = self.session_store.clone();
186 let session_config = self.session_config.clone();
187 let cookie_controller = self.cookie_controller.clone();
188
189 let clone = self.inner.clone();
194 let mut inner = std::mem::replace(&mut self.inner, clone);
195
196 Box::pin(
197 async move {
198 let Some(cookies) = req.extensions().get::<_>().cloned() else {
199 tracing::error!("missing cookies request extension");
202 return Ok(Response::default());
203 };
204
205 let session_cookie = cookie_controller.get(&cookies, &session_config.name);
206 let session_id = session_cookie.as_ref().and_then(|cookie| {
207 cookie
208 .value()
209 .parse::<session::Id>()
210 .map_err(|err| {
211 tracing::warn!(
212 err = %err,
213 "possibly suspicious activity: malformed session id"
214 )
215 })
216 .ok()
217 });
218
219 let session = Session::new(session_id, session_store, session_config.expiry);
220
221 req.extensions_mut().insert(session.clone());
222
223 let res = inner.call(req).await?;
224
225 let modified = session.is_modified();
226 let empty = session.is_empty().await;
227
228 tracing::trace!(
229 modified = modified,
230 empty = empty,
231 always_save = session_config.always_save,
232 "session response state",
233 );
234
235 match session_cookie {
236 Some(mut cookie) if empty => {
237 tracing::debug!("removing session cookie");
238
239 cookie.set_path(session_config.path);
244 if let Some(domain) = session_config.domain {
245 cookie.set_domain(domain);
246 }
247
248 cookie_controller.remove(&cookies, cookie);
249 }
250
251 _ if (modified || session_config.always_save)
252 && !empty
253 && !res.status().is_server_error() =>
254 {
255 tracing::debug!("saving session");
256 if let Err(err) = session.save().await {
257 tracing::error!(err = %err, "failed to save session");
258
259 let mut res = Response::default();
260 *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
261 return Ok(res);
262 }
263
264 let Some(session_id) = session.id() else {
265 tracing::error!("missing session id");
266
267 let mut res = Response::default();
268 *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
269 return Ok(res);
270 };
271
272 let expiry = session.expiry();
273 let session_cookie = session_config.build_cookie(session_id, expiry);
274
275 tracing::debug!("adding session cookie");
276 cookie_controller.add(&cookies, session_cookie);
277 }
278
279 _ => (),
280 };
281
282 Ok(res)
283 }
284 .instrument(span),
285 )
286 }
287}
288
289#[derive(Debug, Clone)]
291pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
292 session_store: Arc<Store>,
293 session_config: SessionConfig<'static>,
294 cookie_controller: C,
295}
296
297impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
298 pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
310 self.session_config.name = name.into();
311 self
312 }
313
314 pub fn with_http_only(mut self, http_only: bool) -> Self {
332 self.session_config.http_only = http_only;
333 self
334 }
335
336 pub fn with_same_site(mut self, same_site: SameSite) -> Self {
349 self.session_config.same_site = same_site;
350 self
351 }
352
353 pub fn with_expiry(mut self, expiry: Expiry) -> Self {
367 self.session_config.expiry = Some(expiry);
368 self
369 }
370
371 pub fn with_secure(mut self, secure: bool) -> Self {
383 self.session_config.secure = secure;
384 self
385 }
386
387 pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
399 self.session_config.path = path.into();
400 self
401 }
402
403 pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
415 self.session_config.domain = Some(domain.into());
416 self
417 }
418
419 pub fn with_always_save(mut self, always_save: bool) -> Self {
446 self.session_config.always_save = always_save;
447 self
448 }
449
450 #[cfg(feature = "signed")]
468 pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
469 SessionManagerLayer::<Store, SignedCookie> {
470 session_store: self.session_store,
471 session_config: self.session_config,
472 cookie_controller: SignedCookie { key },
473 }
474 }
475
476 #[cfg(feature = "private")]
494 pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
495 SessionManagerLayer::<Store, PrivateCookie> {
496 session_store: self.session_store,
497 session_config: self.session_config,
498 cookie_controller: PrivateCookie { key },
499 }
500 }
501}
502
503impl<Store: SessionStore> SessionManagerLayer<Store> {
504 pub fn new(session_store: Store) -> Self {
516 let session_config = SessionConfig::default();
517
518 Self {
519 session_store: Arc::new(session_store),
520 session_config,
521 cookie_controller: PlaintextCookie,
522 }
523 }
524}
525
526impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
527 type Service = CookieManager<SessionManager<S, Store, C>>;
528
529 fn layer(&self, inner: S) -> Self::Service {
530 let session_manager = SessionManager {
531 inner,
532 session_store: self.session_store.clone(),
533 session_config: self.session_config.clone(),
534 cookie_controller: self.cookie_controller.clone(),
535 };
536
537 CookieManager::new(session_manager)
538 }
539}
540
541#[cfg(test)]
542mod tests {
543 use std::str::FromStr;
544
545 use anyhow::anyhow;
546 use axum::body::Body;
547 use tower::{ServiceBuilder, ServiceExt};
548 use tower_sessions_memory_store::MemoryStore;
549
550 use super::*;
551 use crate::session::{Id, Record};
552
553 async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
554 let session = req
555 .extensions()
556 .get::<Session>()
557 .ok_or(anyhow!("Missing session"))?;
558
559 session.insert("foo", 42).await?;
560
561 Ok(Response::new(Body::empty()))
562 }
563
564 async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
565 Ok(Response::new(Body::empty()))
566 }
567
568 #[tokio::test]
569 async fn basic_service_test() -> anyhow::Result<()> {
570 let session_store = MemoryStore::default();
571 let session_layer = SessionManagerLayer::new(session_store);
572 let svc = ServiceBuilder::new()
573 .layer(session_layer)
574 .service_fn(handler);
575
576 let req = Request::builder().body(Body::empty())?;
577 let res = svc.clone().oneshot(req).await?;
578
579 let session = res.headers().get(http::header::SET_COOKIE);
580 assert!(session.is_some());
581
582 let req = Request::builder()
583 .header(http::header::COOKIE, session.unwrap())
584 .body(Body::empty())?;
585 let res = svc.oneshot(req).await?;
586
587 assert!(res.headers().get(http::header::SET_COOKIE).is_none());
588
589 Ok(())
590 }
591
592 #[tokio::test]
593 async fn bogus_cookie_test() -> anyhow::Result<()> {
594 let session_store = MemoryStore::default();
595 let session_layer = SessionManagerLayer::new(session_store);
596 let svc = ServiceBuilder::new()
597 .layer(session_layer)
598 .service_fn(handler);
599
600 let req = Request::builder().body(Body::empty())?;
601 let res = svc.clone().oneshot(req).await?;
602
603 assert!(res.headers().get(http::header::SET_COOKIE).is_some());
604
605 let req = Request::builder()
606 .header(http::header::COOKIE, "id=bogus")
607 .body(Body::empty())?;
608 let res = svc.oneshot(req).await?;
609
610 assert!(res.headers().get(http::header::SET_COOKIE).is_some());
611
612 Ok(())
613 }
614
615 #[tokio::test]
616 async fn no_set_cookie_test() -> anyhow::Result<()> {
617 let session_store = MemoryStore::default();
618 let session_layer = SessionManagerLayer::new(session_store);
619 let svc = ServiceBuilder::new()
620 .layer(session_layer)
621 .service_fn(noop_handler);
622
623 let req = Request::builder().body(Body::empty())?;
624 let res = svc.oneshot(req).await?;
625
626 assert!(res.headers().get(http::header::SET_COOKIE).is_none());
627
628 Ok(())
629 }
630
631 #[tokio::test]
632 async fn name_test() -> anyhow::Result<()> {
633 let session_store = MemoryStore::default();
634 let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
635 let svc = ServiceBuilder::new()
636 .layer(session_layer)
637 .service_fn(handler);
638
639 let req = Request::builder().body(Body::empty())?;
640 let res = svc.oneshot(req).await?;
641
642 assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));
643
644 Ok(())
645 }
646
647 #[tokio::test]
648 async fn http_only_test() -> anyhow::Result<()> {
649 let session_store = MemoryStore::default();
650 let session_layer = SessionManagerLayer::new(session_store);
651 let svc = ServiceBuilder::new()
652 .layer(session_layer)
653 .service_fn(handler);
654
655 let req = Request::builder().body(Body::empty())?;
656 let res = svc.oneshot(req).await?;
657
658 assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));
659
660 let session_store = MemoryStore::default();
661 let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
662 let svc = ServiceBuilder::new()
663 .layer(session_layer)
664 .service_fn(handler);
665
666 let req = Request::builder().body(Body::empty())?;
667 let res = svc.oneshot(req).await?;
668
669 assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));
670
671 Ok(())
672 }
673
674 #[tokio::test]
675 async fn same_site_strict_test() -> anyhow::Result<()> {
676 let session_store = MemoryStore::default();
677 let session_layer =
678 SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
679 let svc = ServiceBuilder::new()
680 .layer(session_layer)
681 .service_fn(handler);
682
683 let req = Request::builder().body(Body::empty())?;
684 let res = svc.oneshot(req).await?;
685
686 assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));
687
688 Ok(())
689 }
690
691 #[tokio::test]
692 async fn same_site_lax_test() -> anyhow::Result<()> {
693 let session_store = MemoryStore::default();
694 let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
695 let svc = ServiceBuilder::new()
696 .layer(session_layer)
697 .service_fn(handler);
698
699 let req = Request::builder().body(Body::empty())?;
700 let res = svc.oneshot(req).await?;
701
702 assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));
703
704 Ok(())
705 }
706
707 #[tokio::test]
708 async fn same_site_none_test() -> anyhow::Result<()> {
709 let session_store = MemoryStore::default();
710 let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
711 let svc = ServiceBuilder::new()
712 .layer(session_layer)
713 .service_fn(handler);
714
715 let req = Request::builder().body(Body::empty())?;
716 let res = svc.oneshot(req).await?;
717
718 assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));
719
720 Ok(())
721 }
722
723 #[tokio::test]
724 async fn expiry_on_session_end_test() -> anyhow::Result<()> {
725 let session_store = MemoryStore::default();
726 let session_layer =
727 SessionManagerLayer::new(session_store).with_expiry(Expiry::OnSessionEnd);
728 let svc = ServiceBuilder::new()
729 .layer(session_layer)
730 .service_fn(handler);
731
732 let req = Request::builder().body(Body::empty())?;
733 let res = svc.oneshot(req).await?;
734
735 assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));
736
737 Ok(())
738 }
739
740 #[tokio::test]
741 async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
742 let session_store = MemoryStore::default();
743 let inactivity_duration = time::Duration::hours(2);
744 let session_layer = SessionManagerLayer::new(session_store)
745 .with_expiry(Expiry::OnInactivity(inactivity_duration));
746 let svc = ServiceBuilder::new()
747 .layer(session_layer)
748 .service_fn(handler);
749
750 let req = Request::builder().body(Body::empty())?;
751 let res = svc.oneshot(req).await?;
752
753 let expected_max_age = inactivity_duration.whole_seconds();
754 assert!(cookie_has_expected_max_age(&res, expected_max_age));
755
756 Ok(())
757 }
758
759 #[tokio::test]
760 async fn expiry_at_date_time_test() -> anyhow::Result<()> {
761 let session_store = MemoryStore::default();
762 let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
763 let session_layer =
764 SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
765 let svc = ServiceBuilder::new()
766 .layer(session_layer)
767 .service_fn(handler);
768
769 let req = Request::builder().body(Body::empty())?;
770 let res = svc.oneshot(req).await?;
771
772 let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
773 assert!(cookie_has_expected_max_age(&res, expected_max_age));
774
775 Ok(())
776 }
777
778 #[tokio::test]
779 async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
780 let session_store = MemoryStore::default();
781 let session_layer = SessionManagerLayer::new(session_store.clone())
782 .with_expiry(Expiry::OnSessionEnd)
783 .with_always_save(true);
784 let mut svc = ServiceBuilder::new()
785 .layer(session_layer)
786 .service_fn(handler);
787
788 let req1 = Request::builder().body(Body::empty())?;
789 let res1 = svc.call(req1).await?;
790 let sid1 = get_session_id(&res1);
791 let rec1 = get_record(&session_store, &sid1).await;
792 let req2 = Request::builder()
793 .header(http::header::COOKIE, format!("id={}", sid1))
794 .body(Body::empty())?;
795 let res2 = svc.call(req2).await?;
796 let sid2 = get_session_id(&res2);
797 let rec2 = get_record(&session_store, &sid2).await;
798
799 assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
800 assert!(sid1 == sid2);
801 assert!(rec1.expiry_date < rec2.expiry_date);
802
803 Ok(())
804 }
805
806 #[tokio::test]
807 async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
808 let session_store = MemoryStore::default();
809 let inactivity_duration = time::Duration::hours(2);
810 let session_layer = SessionManagerLayer::new(session_store.clone())
811 .with_expiry(Expiry::OnInactivity(inactivity_duration))
812 .with_always_save(true);
813 let mut svc = ServiceBuilder::new()
814 .layer(session_layer)
815 .service_fn(handler);
816
817 let req1 = Request::builder().body(Body::empty())?;
818 let res1 = svc.call(req1).await?;
819 let sid1 = get_session_id(&res1);
820 let rec1 = get_record(&session_store, &sid1).await;
821 let req2 = Request::builder()
822 .header(http::header::COOKIE, format!("id={}", sid1))
823 .body(Body::empty())?;
824 let res2 = svc.call(req2).await?;
825 let sid2 = get_session_id(&res2);
826 let rec2 = get_record(&session_store, &sid2).await;
827
828 let expected_max_age = inactivity_duration.whole_seconds();
829 assert!(cookie_has_expected_max_age(&res2, expected_max_age));
830 assert!(sid1 == sid2);
831 assert!(rec1.expiry_date < rec2.expiry_date);
832
833 Ok(())
834 }
835
836 #[tokio::test]
837 async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
838 let session_store = MemoryStore::default();
839 let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
840 let session_layer = SessionManagerLayer::new(session_store.clone())
841 .with_expiry(Expiry::AtDateTime(expiry_time))
842 .with_always_save(true);
843 let mut svc = ServiceBuilder::new()
844 .layer(session_layer)
845 .service_fn(handler);
846
847 let req1 = Request::builder().body(Body::empty())?;
848 let res1 = svc.call(req1).await?;
849 let sid1 = get_session_id(&res1);
850 let rec1 = get_record(&session_store, &sid1).await;
851 let req2 = Request::builder()
852 .header(http::header::COOKIE, format!("id={}", sid1))
853 .body(Body::empty())?;
854 let res2 = svc.call(req2).await?;
855 let sid2 = get_session_id(&res2);
856 let rec2 = get_record(&session_store, &sid2).await;
857
858 let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
859 assert!(cookie_has_expected_max_age(&res2, expected_max_age));
860 assert!(sid1 == sid2);
861 assert!(rec1.expiry_date == rec2.expiry_date);
862
863 Ok(())
864 }
865
866 #[tokio::test]
867 async fn secure_test() -> anyhow::Result<()> {
868 let session_store = MemoryStore::default();
869 let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
870 let svc = ServiceBuilder::new()
871 .layer(session_layer)
872 .service_fn(handler);
873
874 let req = Request::builder().body(Body::empty())?;
875 let res = svc.oneshot(req).await?;
876
877 assert!(cookie_value_matches(&res, |s| s.contains("Secure")));
878
879 let session_store = MemoryStore::default();
880 let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
881 let svc = ServiceBuilder::new()
882 .layer(session_layer)
883 .service_fn(handler);
884
885 let req = Request::builder().body(Body::empty())?;
886 let res = svc.oneshot(req).await?;
887
888 assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));
889
890 Ok(())
891 }
892
893 #[tokio::test]
894 async fn path_test() -> anyhow::Result<()> {
895 let session_store = MemoryStore::default();
896 let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
897 let svc = ServiceBuilder::new()
898 .layer(session_layer)
899 .service_fn(handler);
900
901 let req = Request::builder().body(Body::empty())?;
902 let res = svc.oneshot(req).await?;
903
904 assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));
905
906 Ok(())
907 }
908
909 #[tokio::test]
910 async fn domain_test() -> anyhow::Result<()> {
911 let session_store = MemoryStore::default();
912 let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
913 let svc = ServiceBuilder::new()
914 .layer(session_layer)
915 .service_fn(handler);
916
917 let req = Request::builder().body(Body::empty())?;
918 let res = svc.oneshot(req).await?;
919
920 assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));
921
922 Ok(())
923 }
924
925 #[cfg(feature = "signed")]
926 #[tokio::test]
927 async fn signed_test() -> anyhow::Result<()> {
928 let key = Key::generate();
929 let session_store = MemoryStore::default();
930 let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
931 let svc = ServiceBuilder::new()
932 .layer(session_layer)
933 .service_fn(handler);
934
935 let req = Request::builder().body(Body::empty())?;
936 let res = svc.oneshot(req).await?;
937
938 assert!(res.headers().get(http::header::SET_COOKIE).is_some());
939
940 Ok(())
941 }
942
943 #[cfg(feature = "private")]
944 #[tokio::test]
945 async fn private_test() -> anyhow::Result<()> {
946 let key = Key::generate();
947 let session_store = MemoryStore::default();
948 let session_layer = SessionManagerLayer::new(session_store).with_private(key);
949 let svc = ServiceBuilder::new()
950 .layer(session_layer)
951 .service_fn(handler);
952
953 let req = Request::builder().body(Body::empty())?;
954 let res = svc.oneshot(req).await?;
955
956 assert!(res.headers().get(http::header::SET_COOKIE).is_some());
957
958 Ok(())
959 }
960
961 fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
962 where
963 F: FnOnce(&str) -> bool,
964 {
965 res.headers()
966 .get(http::header::SET_COOKIE)
967 .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
968 }
969
970 fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
971 res.headers()
972 .get(http::header::SET_COOKIE)
973 .is_some_and(|set_cookie| {
974 set_cookie.to_str().is_ok_and(|s| {
975 let max_age_value = s
976 .split("Max-Age=")
977 .nth(1)
978 .unwrap_or_default()
979 .split(';')
980 .next()
981 .unwrap_or_default()
982 .parse::<i64>()
983 .unwrap_or_default();
984 (max_age_value - expected_value).abs() <= 1
985 })
986 })
987 }
988
989 fn get_session_id(res: &Response<Body>) -> String {
990 res.headers()
991 .get(http::header::SET_COOKIE)
992 .unwrap()
993 .to_str()
994 .unwrap()
995 .split("id=")
996 .nth(1)
997 .unwrap()
998 .split(";")
999 .next()
1000 .unwrap()
1001 .to_string()
1002 }
1003
1004 async fn get_record(store: &impl SessionStore, id: &str) -> Record {
1005 store
1006 .load(&Id::from_str(id).unwrap())
1007 .await
1008 .unwrap()
1009 .unwrap()
1010 }
1011}