1use crate::{Session, SessionConfig, SessionStore};
6use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
7use http::{Request, Response};
8use http_body::Body;
9use rand::Rng;
10use std::{marker::PhantomData, sync::Arc};
11use tower::{Layer, Service};
12
13#[derive(Debug, Clone)]
17pub struct SessionLayer<S, ReqBody, ResBody>
18where
19 S: SessionStore,
20 ReqBody: Body + Send + 'static,
21 ResBody: Body + Send + 'static,
22{
23 store: S,
25 config: SessionConfig,
27 _phantom: PhantomData<(ReqBody, ResBody)>,
29}
30
31impl<S, ReqBody, ResBody> SessionLayer<S, ReqBody, ResBody>
32where
33 S: SessionStore,
34 ReqBody: Body + Send + 'static,
35 ResBody: Body + Send + 'static,
36{
37 pub fn new(store: S, config: SessionConfig) -> Self {
39 Self { store, config, _phantom: PhantomData }
40 }
41
42 pub fn with_store(store: S) -> Self {
44 Self::new(store, SessionConfig::default())
45 }
46}
47
48impl<S, T, ReqBody, ResBody> Layer<T> for SessionLayer<S, ReqBody, ResBody>
49where
50 S: SessionStore,
51 T: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
52 T::Future: Send,
53 ReqBody: Body + Send + 'static,
54 ResBody: Body + Send + 'static,
55{
56 type Service = SessionService<T, S, ReqBody, ResBody>;
57
58 fn layer(&self, inner: T) -> Self::Service {
59 SessionService { inner, store: self.store.clone(), config: self.config.clone(), _phantom: PhantomData }
60 }
61}
62
63#[derive(Debug, Clone)]
67pub struct SessionService<T, S, ReqBody, ResBody>
68where
69 S: SessionStore,
70 ReqBody: Body + Send + 'static,
71 ResBody: Body + Send + 'static,
72{
73 inner: T,
75 store: S,
77 config: SessionConfig,
79 _phantom: PhantomData<(ReqBody, ResBody)>,
81}
82
83impl<T, S, ReqBody, ResBody> SessionService<T, S, ReqBody, ResBody>
84where
85 S: SessionStore,
86 ReqBody: Body + Send + 'static,
87 ResBody: Body + Send + 'static,
88{
89 fn extract_session_id<B>(&self, request: &Request<B>) -> Option<String> {
91 let cookie_header = request.headers().get("cookie")?.to_str().ok()?;
92
93 for cookie in cookie_header.split(';') {
94 let cookie = cookie.trim();
95 if let Some(value) = cookie.strip_prefix(&format!("{}=", self.config.cookie_name)) {
96 return Some(value.to_string());
97 }
98 }
99
100 None
101 }
102}
103
104impl<T, S, ReqBody, ResBody> Service<Request<ReqBody>> for SessionService<T, S, ReqBody, ResBody>
105where
106 S: SessionStore,
107 T: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
108 T::Future: Send,
109 ReqBody: Body + Send + 'static,
110 ResBody: Body + Send + 'static,
111{
112 type Response = Response<ResBody>;
113 type Error = T::Error;
114 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
115
116 fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<Result<(), Self::Error>> {
117 self.inner.poll_ready(cx)
118 }
119
120 fn call(&mut self, mut request: Request<ReqBody>) -> Self::Future {
121 let mut inner = self.inner.clone();
122 let session_id = self.extract_session_id(&request);
123 let config = self.config.clone();
124 let store = self.store.clone();
125
126 Box::pin(async move {
127 let generate_session_id = || -> String {
128 let mut bytes = vec![0u8; config.id_length];
129 rand::rng().fill_bytes(&mut bytes);
130 URL_SAFE_NO_PAD.encode(&bytes)
131 };
132
133 let session = if let Some(id) = session_id {
134 if let Some(data) = store.get(&id).await {
135 if let Ok(data_map) = serde_json::from_str(&data) {
136 Session::from_data(id.to_string(), data_map).await
137 }
138 else {
139 Session::new(generate_session_id())
140 }
141 }
142 else {
143 Session::new(generate_session_id())
144 }
145 }
146 else {
147 Session::new(generate_session_id())
148 };
149
150 let session = Arc::new(session);
151 request.extensions_mut().insert(session.clone());
152
153 let mut response = inner.call(request).await?;
154
155 if session.is_dirty().await || session.is_new() {
156 let data = session.to_json().await;
157 store.set(session.id(), &data, config.ttl).await;
158
159 let cookie_value = config.build_cookie_header(session.id());
160 if let Ok(header_value) = cookie_value.parse() {
161 response.headers_mut().append(http::header::SET_COOKIE, header_value);
162 }
163 }
164
165 Ok(response)
166 })
167 }
168}