1use std::{
3 borrow::Cow,
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use http::{Request, Response};
11use time::OffsetDateTime;
12#[cfg(any(feature = "signed", feature = "private"))]
13use tower_cookies::Key;
14use tower_cookies::{Cookie, CookieManager, Cookies, cookie::SameSite};
15use tower_layer::Layer;
16use tower_service::Service;
17use tracing::Instrument;
18
19use crate::{
20 Session, SessionStore,
21 session::{self, Expiry},
22};
23
24#[doc(hidden)]
25pub trait CookieController: Clone + Send + 'static {
26 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>>;
27 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>);
28 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>);
29}
30
/// Cookie controller that reads and writes the session cookie unmodified (no signing, no
/// encryption). This is the default controller of `SessionManager` and `SessionManagerLayer`.
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct PlaintextCookie;
34
35impl CookieController for PlaintextCookie {
36 fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
37 cookies.get(name).map(Cookie::into_owned)
38 }
39
40 fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
41 cookies.add(cookie)
42 }
43
44 fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
45 cookies.remove(cookie)
46 }
47}
48
/// Cookie controller that accesses the session cookie through the signed view of the jar
/// (`Cookies::signed`), using `key`.
#[doc(hidden)]
#[cfg(feature = "signed")]
#[derive(Clone, Debug)]
pub struct SignedCookie {
    /// Key handed to `Cookies::signed` for every jar operation.
    key: Key,
}
55
/// Every operation goes through `cookies.signed(&self.key)`.
#[cfg(feature = "signed")]
impl CookieController for SignedCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let signed_jar = cookies.signed(&self.key);
        signed_jar.get(name).map(|cookie| cookie.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        cookies.signed(&self.key).add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        cookies.signed(&self.key).remove(cookie);
    }
}
70
/// Cookie controller that accesses the session cookie through the private view of the jar
/// (`Cookies::private`), using `key`.
#[doc(hidden)]
#[cfg(feature = "private")]
#[derive(Clone, Debug)]
pub struct PrivateCookie {
    /// Key handed to `Cookies::private` for every jar operation.
    key: Key,
}
77
/// Every operation goes through `cookies.private(&self.key)`.
#[cfg(feature = "private")]
impl CookieController for PrivateCookie {
    fn get(&self, cookies: &Cookies, name: &str) -> Option<Cookie<'static>> {
        let private_jar = cookies.private(&self.key);
        private_jar.get(name).map(|cookie| cookie.into_owned())
    }

    fn add(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        cookies.private(&self.key).add(cookie);
    }

    fn remove(&self, cookies: &Cookies, cookie: Cookie<'static>) {
        cookies.private(&self.key).remove(cookie);
    }
}
92
93#[derive(Debug, Clone)]
94struct SessionConfig<'a> {
95 name: Cow<'a, str>,
96 http_only: bool,
97 same_site: SameSite,
98 expiry: Option<Expiry>,
99 secure: bool,
100 path: Cow<'a, str>,
101 domain: Option<Cow<'a, str>>,
102 always_save: bool,
103}
104
105impl<'a> SessionConfig<'a> {
106 fn build_cookie(self, session_id: session::Id, expiry: Option<Expiry>) -> Cookie<'a> {
107 let mut cookie_builder = Cookie::build((self.name, session_id.to_string()))
108 .http_only(self.http_only)
109 .same_site(self.same_site)
110 .secure(self.secure)
111 .path(self.path);
112
113 cookie_builder = match expiry {
114 Some(Expiry::OnInactivity(duration)) => cookie_builder.max_age(duration),
115 Some(Expiry::AtDateTime(datetime)) => {
116 cookie_builder.max_age(datetime - OffsetDateTime::now_utc())
117 }
118 Some(Expiry::OnSessionEnd(duration)) => cookie_builder.max_age(duration),
119 None => cookie_builder,
120 };
121
122 if let Some(domain) = self.domain {
123 cookie_builder = cookie_builder.domain(domain);
124 }
125
126 cookie_builder.build()
127 }
128}
129
130impl Default for SessionConfig<'_> {
131 fn default() -> Self {
132 Self {
133 name: "id".into(), http_only: true,
135 same_site: SameSite::Strict,
136 expiry: None, secure: true,
138 path: "/".into(),
139 domain: None,
140 always_save: false,
141 }
142 }
143}
144
145#[derive(Debug, Clone)]
147pub struct SessionManager<S, Store: SessionStore, C: CookieController = PlaintextCookie> {
148 inner: S,
149 session_store: Arc<Store>,
150 session_config: SessionConfig<'static>,
151 cookie_controller: C,
152}
153
154impl<S, Store: SessionStore> SessionManager<S, Store> {
155 pub fn new(inner: S, session_store: Store) -> Self {
157 Self {
158 inner,
159 session_store: Arc::new(session_store),
160 session_config: Default::default(),
161 cookie_controller: PlaintextCookie,
162 }
163 }
164}
165
/// Runs the wrapped service with a [`Session`] in the request extensions. Afterwards, based on
/// the session's final state, it either saves the session and sets the cookie, removes the
/// cookie, or does nothing.
impl<ReqBody, ResBody, S, Store: SessionStore, C: CookieController> Service<Request<ReqBody>>
    for SessionManager<S, Store, C>
where
    S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
    S::Future: Send,
    ReqBody: Send + 'static,
    // `Default` lets this service synthesize bare responses on its error paths below.
    ResBody: Default + Send,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    #[inline]
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
        let span = tracing::debug_span!("call");

        // Clone shared state up front so the boxed future owns everything it uses and does
        // not borrow `self`.
        let session_store = self.session_store.clone();
        let session_config = self.session_config.clone();
        let cookie_controller = self.cookie_controller.clone();

        // `self.inner` is the instance `poll_ready` was called on. Move that one into the
        // future and leave a fresh clone in its place for the next `poll_ready`/`call` pair.
        let clone = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, clone);

        Box::pin(
            async move {
                // `CookieManager` provides the `Cookies` extension, and `SessionManagerLayer`
                // always wraps this service in one. If the extension is absent, the inner
                // service is never called and `Response::default()` is returned.
                // NOTE(review): that default response is not an error status; confirm that this
                // is intended for `SessionManager::new` used without a `CookieManager`.
                let Some(cookies) = req.extensions().get::<_>().cloned() else {
                    tracing::error!("missing cookies request extension");
                    return Ok(Response::default());
                };

                // An unparsable cookie value yields `session_id == None`, so a fresh session is
                // created. The raw cookie is still kept in `session_cookie` so that the
                // empty-session branch below can remove it.
                let session_cookie = cookie_controller.get(&cookies, &session_config.name);
                let session_id = session_cookie.as_ref().and_then(|cookie| {
                    cookie
                        .value()
                        .parse::<session::Id>()
                        .map_err(|err| {
                            tracing::warn!(
                                err = %err,
                                "possibly suspicious activity: malformed session id"
                            )
                        })
                        .ok()
                });

                let session = Session::new(session_id, session_store, session_config.expiry);

                // The handler gets a clone, and `session` is inspected after the inner call.
                // This presumably relies on `Session` clones sharing state — TODO confirm.
                req.extensions_mut().insert(session.clone());

                let res = inner.call(req).await?;

                let modified = session.is_modified();
                let empty = session.is_empty().await;

                tracing::trace!(
                    modified = modified,
                    empty = empty,
                    always_save = session_config.always_save,
                    "session response state",
                );

                match session_cookie {
                    // The client sent a cookie, but the session now holds no data: remove the
                    // cookie. Path and domain are copied from the config so the removal targets
                    // the cookie that was originally set. See `cookie::CookieJar::remove`, whose
                    // removal cookie needs a matching path/domain.
                    Some(mut cookie) if empty => {
                        tracing::debug!("removing session cookie");

                        cookie.set_path(session_config.path);
                        if let Some(domain) = session_config.domain {
                            cookie.set_domain(domain);
                        }

                        cookie_controller.remove(&cookies, cookie);
                    }

                    // Persist when the session changed or `always_save` is set. This happens only
                    // if the session has data and the inner response is not a 5xx. If the store
                    // fails or no id is available, the handler's response is discarded and a
                    // bare 500 is returned instead.
                    _ if (modified || session_config.always_save)
                        && !empty
                        && !res.status().is_server_error() =>
                    {
                        tracing::debug!("saving session");
                        if let Err(err) = session.save().await {
                            tracing::error!(err = %err, "failed to save session");

                            let mut res = Response::default();
                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        }

                        let Some(session_id) = session.id() else {
                            tracing::error!("missing session id");

                            let mut res = Response::default();
                            *res.status_mut() = http::StatusCode::INTERNAL_SERVER_ERROR;
                            return Ok(res);
                        };

                        // Use the session's current expiry rather than the configured one, so
                        // the cookie reflects whatever the session reports after the handler ran.
                        let expiry = session.expiry();
                        let session_cookie = session_config.build_cookie(session_id, expiry);

                        tracing::debug!("adding session cookie");
                        cookie_controller.add(&cookies, session_cookie);
                    }

                    // Nothing to persist or remove: leave the cookie jar untouched.
                    _ => (),
                };

                Ok(res)
            }
            .instrument(span),
        )
    }
}
289
290#[derive(Debug, Clone)]
292pub struct SessionManagerLayer<Store: SessionStore, C: CookieController = PlaintextCookie> {
293 session_store: Arc<Store>,
294 session_config: SessionConfig<'static>,
295 cookie_controller: C,
296}
297
298impl<Store: SessionStore, C: CookieController> SessionManagerLayer<Store, C> {
299 pub fn with_name<N: Into<Cow<'static, str>>>(mut self, name: N) -> Self {
311 self.session_config.name = name.into();
312 self
313 }
314
315 pub fn with_http_only(mut self, http_only: bool) -> Self {
333 self.session_config.http_only = http_only;
334 self
335 }
336
337 pub fn with_same_site(mut self, same_site: SameSite) -> Self {
350 self.session_config.same_site = same_site;
351 self
352 }
353
354 pub fn with_expiry(mut self, expiry: Expiry) -> Self {
368 self.session_config.expiry = Some(expiry);
369 self
370 }
371
372 pub fn with_secure(mut self, secure: bool) -> Self {
384 self.session_config.secure = secure;
385 self
386 }
387
388 pub fn with_path<P: Into<Cow<'static, str>>>(mut self, path: P) -> Self {
400 self.session_config.path = path.into();
401 self
402 }
403
404 pub fn with_domain<D: Into<Cow<'static, str>>>(mut self, domain: D) -> Self {
416 self.session_config.domain = Some(domain.into());
417 self
418 }
419
420 pub fn with_always_save(mut self, always_save: bool) -> Self {
447 self.session_config.always_save = always_save;
448 self
449 }
450
451 #[cfg(feature = "signed")]
469 pub fn with_signed(self, key: Key) -> SessionManagerLayer<Store, SignedCookie> {
470 SessionManagerLayer::<Store, SignedCookie> {
471 session_store: self.session_store,
472 session_config: self.session_config,
473 cookie_controller: SignedCookie { key },
474 }
475 }
476
477 #[cfg(feature = "private")]
495 pub fn with_private(self, key: Key) -> SessionManagerLayer<Store, PrivateCookie> {
496 SessionManagerLayer::<Store, PrivateCookie> {
497 session_store: self.session_store,
498 session_config: self.session_config,
499 cookie_controller: PrivateCookie { key },
500 }
501 }
502}
503
504impl<Store: SessionStore> SessionManagerLayer<Store> {
505 pub fn new(session_store: Store) -> Self {
517 let session_config = SessionConfig::default();
518
519 Self {
520 session_store: Arc::new(session_store),
521 session_config,
522 cookie_controller: PlaintextCookie,
523 }
524 }
525}
526
527impl<S, Store: SessionStore, C: CookieController> Layer<S> for SessionManagerLayer<Store, C> {
528 type Service = CookieManager<SessionManager<S, Store, C>>;
529
530 fn layer(&self, inner: S) -> Self::Service {
531 let session_manager = SessionManager {
532 inner,
533 session_store: self.session_store.clone(),
534 session_config: self.session_config.clone(),
535 cookie_controller: self.cookie_controller.clone(),
536 };
537
538 CookieManager::new(session_manager)
539 }
540}
541
// End-to-end tests that drive `SessionManagerLayer` over a `MemoryStore` and inspect the
// `Set-Cookie` header and the stored records.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use anyhow::anyhow;
    use axum::body::Body;
    use tower::{ServiceBuilder, ServiceExt};
    use tower_sessions_ext_core::session::DEFAULT_DURATION;
    use tower_sessions_ext_memory_store::MemoryStore;

    use super::*;
    use crate::session::{Id, Record};

    // Handler that writes to the session, so the manager should save it and issue a cookie.
    async fn handler(req: Request<Body>) -> anyhow::Result<Response<Body>> {
        let session = req
            .extensions()
            .get::<Session>()
            .ok_or(anyhow!("Missing session"))?;

        session.insert("foo", 42).await?;

        Ok(Response::new(Body::empty()))
    }

    // Handler that never touches the session.
    async fn noop_handler(_: Request<Body>) -> anyhow::Result<Response<Body>> {
        Ok(Response::new(Body::empty()))
    }

    #[tokio::test]
    async fn basic_service_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.clone().oneshot(req).await?;

        let session = res.headers().get(http::header::SET_COOKIE);
        assert!(session.is_some());

        // Replaying the cookie: re-inserting the same value is expected to leave the
        // session unmodified, so no new Set-Cookie is issued.
        let req = Request::builder()
            .header(http::header::COOKIE, session.unwrap())
            .body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_none());

        Ok(())
    }

    #[tokio::test]
    async fn bogus_cookie_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.clone().oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        // A malformed id is ignored, so the handler writes into a fresh session and a new
        // cookie is issued.
        let req = Request::builder()
            .header(http::header::COOKIE, "id=bogus")
            .body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        Ok(())
    }

    #[tokio::test]
    async fn no_set_cookie_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(noop_handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        // An untouched, empty session must not produce a cookie.
        assert!(res.headers().get(http::header::SET_COOKIE).is_none());

        Ok(())
    }

    #[tokio::test]
    async fn name_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_name("my.sid");
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.starts_with("my.sid=")));

        Ok(())
    }

    #[tokio::test]
    async fn http_only_test() -> anyhow::Result<()> {
        // Default configuration: HttpOnly is present.
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("HttpOnly")));

        // Explicitly disabled: HttpOnly is absent.
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_http_only(false);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| !s.contains("HttpOnly")));

        Ok(())
    }

    #[tokio::test]
    async fn same_site_strict_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer =
            SessionManagerLayer::new(session_store).with_same_site(SameSite::Strict);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Strict")));

        Ok(())
    }

    #[tokio::test]
    async fn same_site_lax_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::Lax);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=Lax")));

        Ok(())
    }

    #[tokio::test]
    async fn same_site_none_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_same_site(SameSite::None);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("SameSite=None")));

        Ok(())
    }

    // A session-end expiry must yield a browser-session cookie (no Max-Age).
    #[tokio::test]
    async fn expiry_on_session_end_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store)
            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION));
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| !s.contains("Max-Age")));

        Ok(())
    }

    #[tokio::test]
    async fn expiry_on_inactivity_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let inactivity_duration = time::Duration::hours(2);
        let session_layer = SessionManagerLayer::new(session_store)
            .with_expiry(Expiry::OnInactivity(inactivity_duration));
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        let expected_max_age = inactivity_duration.whole_seconds();
        assert!(cookie_has_expected_max_age(&res, expected_max_age));

        Ok(())
    }

    #[tokio::test]
    async fn expiry_at_date_time_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
        let session_layer =
            SessionManagerLayer::new(session_store).with_expiry(Expiry::AtDateTime(expiry_time));
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        // Computed after the request; `cookie_has_expected_max_age` tolerates 1 s of drift.
        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
        assert!(cookie_has_expected_max_age(&res, expected_max_age));

        Ok(())
    }

    // With `always_save`, a second request re-saves the record (pushing its expiry forward),
    // but the cookie stays a browser-session cookie.
    #[tokio::test]
    async fn expiry_on_session_end_always_save_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store.clone())
            .with_expiry(Expiry::OnSessionEnd(DEFAULT_DURATION))
            .with_always_save(true);
        let mut svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req1 = Request::builder().body(Body::empty())?;
        let res1 = svc.call(req1).await?;
        let sid1 = get_session_id(&res1);
        let rec1 = get_record(&session_store, &sid1).await;
        let req2 = Request::builder()
            .header(http::header::COOKIE, format!("id={}", sid1))
            .body(Body::empty())?;
        let res2 = svc.call(req2).await?;
        let sid2 = get_session_id(&res2);
        let rec2 = get_record(&session_store, &sid2).await;

        assert!(cookie_value_matches(&res2, |s| !s.contains("Max-Age")));
        assert!(sid1 == sid2);
        assert!(rec1.expiry_date < rec2.expiry_date);

        Ok(())
    }

    // With `always_save`, an inactivity expiry is refreshed on each request.
    #[tokio::test]
    async fn expiry_on_inactivity_always_save_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let inactivity_duration = time::Duration::hours(2);
        let session_layer = SessionManagerLayer::new(session_store.clone())
            .with_expiry(Expiry::OnInactivity(inactivity_duration))
            .with_always_save(true);
        let mut svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req1 = Request::builder().body(Body::empty())?;
        let res1 = svc.call(req1).await?;
        let sid1 = get_session_id(&res1);
        let rec1 = get_record(&session_store, &sid1).await;
        let req2 = Request::builder()
            .header(http::header::COOKIE, format!("id={}", sid1))
            .body(Body::empty())?;
        let res2 = svc.call(req2).await?;
        let sid2 = get_session_id(&res2);
        let rec2 = get_record(&session_store, &sid2).await;

        let expected_max_age = inactivity_duration.whole_seconds();
        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
        assert!(sid1 == sid2);
        assert!(rec1.expiry_date < rec2.expiry_date);

        Ok(())
    }

    // An absolute expiry does not move when the session is re-saved.
    #[tokio::test]
    async fn expiry_at_date_time_always_save_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let expiry_time = time::OffsetDateTime::now_utc() + time::Duration::weeks(1);
        let session_layer = SessionManagerLayer::new(session_store.clone())
            .with_expiry(Expiry::AtDateTime(expiry_time))
            .with_always_save(true);
        let mut svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req1 = Request::builder().body(Body::empty())?;
        let res1 = svc.call(req1).await?;
        let sid1 = get_session_id(&res1);
        let rec1 = get_record(&session_store, &sid1).await;
        let req2 = Request::builder()
            .header(http::header::COOKIE, format!("id={}", sid1))
            .body(Body::empty())?;
        let res2 = svc.call(req2).await?;
        let sid2 = get_session_id(&res2);
        let rec2 = get_record(&session_store, &sid2).await;

        let expected_max_age = (expiry_time - time::OffsetDateTime::now_utc()).whole_seconds();
        assert!(cookie_has_expected_max_age(&res2, expected_max_age));
        assert!(sid1 == sid2);
        assert!(rec1.expiry_date == rec2.expiry_date);

        Ok(())
    }

    #[tokio::test]
    async fn secure_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_secure(true);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("Secure")));

        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_secure(false);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| !s.contains("Secure")));

        Ok(())
    }

    #[tokio::test]
    async fn path_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_path("/foo/bar");
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("Path=/foo/bar")));

        Ok(())
    }

    #[tokio::test]
    async fn domain_test() -> anyhow::Result<()> {
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_domain("example.com");
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(cookie_value_matches(&res, |s| s.contains("Domain=example.com")));

        Ok(())
    }

    // Smoke test only: checks that a cookie is emitted, not that it is actually signed.
    #[cfg(feature = "signed")]
    #[tokio::test]
    async fn signed_test() -> anyhow::Result<()> {
        let key = Key::generate();
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_signed(key);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        Ok(())
    }

    // Smoke test only: checks that a cookie is emitted, not that it is actually encrypted.
    #[cfg(feature = "private")]
    #[tokio::test]
    async fn private_test() -> anyhow::Result<()> {
        let key = Key::generate();
        let session_store = MemoryStore::default();
        let session_layer = SessionManagerLayer::new(session_store).with_private(key);
        let svc = ServiceBuilder::new()
            .layer(session_layer)
            .service_fn(handler);

        let req = Request::builder().body(Body::empty())?;
        let res = svc.oneshot(req).await?;

        assert!(res.headers().get(http::header::SET_COOKIE).is_some());

        Ok(())
    }

    // True when the first Set-Cookie header exists, is valid ASCII/visible text, and
    // satisfies `matcher`.
    fn cookie_value_matches<F>(res: &Response<Body>, matcher: F) -> bool
    where
        F: FnOnce(&str) -> bool,
    {
        res.headers()
            .get(http::header::SET_COOKIE)
            .is_some_and(|set_cookie| set_cookie.to_str().is_ok_and(matcher))
    }

    // True when the first Set-Cookie's Max-Age is within 1 second of `expected_value`.
    // A missing or unparsable Max-Age is read as 0.
    fn cookie_has_expected_max_age(res: &Response<Body>, expected_value: i64) -> bool {
        res.headers()
            .get(http::header::SET_COOKIE)
            .is_some_and(|set_cookie| {
                set_cookie.to_str().is_ok_and(|s| {
                    let max_age_value = s
                        .split("Max-Age=")
                        .nth(1)
                        .unwrap_or_default()
                        .split(';')
                        .next()
                        .unwrap_or_default()
                        .parse::<i64>()
                        .unwrap_or_default();
                    (max_age_value - expected_value).abs() <= 1
                })
            })
    }

    // Extracts the value of the `id` cookie from the first Set-Cookie header.
    // Panics if the header is absent. Assumes the default cookie name `id`.
    fn get_session_id(res: &Response<Body>) -> String {
        res.headers()
            .get(http::header::SET_COOKIE)
            .unwrap()
            .to_str()
            .unwrap()
            .split("id=")
            .nth(1)
            .unwrap()
            .split(";")
            .next()
            .unwrap()
            .to_string()
    }

    // Loads the stored record for `id`, panicking if the id is invalid, the load fails, or
    // no record exists.
    async fn get_record(store: &impl SessionStore, id: &str) -> Record {
        store
            .load(&Id::from_str(id).unwrap())
            .await
            .unwrap()
            .unwrap()
    }
}