1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
64#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
65#![cfg_attr(docsrs, feature(doc_cfg))]
66
67pub use async_session::{CookieStore, MemoryStore, Session, SessionStore};
68
69use std::fmt::{self, Formatter};
70use std::time::Duration;
71
72use async_session::base64;
73use async_session::hmac::{Hmac, Mac, NewMac};
74use async_session::sha2::Sha256;
75use cookie::{Cookie, Key, SameSite};
76use salvo_core::http::uri::Scheme;
77use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
78
/// Depot key under which the current [`Session`] is kept while a request is handled.
pub const SESSION_KEY: &str = "::salvo::session";

/// Number of characters in a padded base64 encoding of an HMAC-SHA256 digest.
///
/// SHA-256 produces 32 bytes; base64 emits 4 characters per started 3-byte group,
/// so the encoded digest is `4 * ceil(32 / 3)` = 44 characters long.
const BASE64_DIGEST_LEN: usize = (32 + 2) / 3 * 4;
82
83pub trait SessionDepotExt {
85 fn set_session(&mut self, session: Session) -> &mut Self;
87 fn take_session(&mut self) -> Option<Session>;
89 fn session(&self) -> Option<&Session>;
91 fn session_mut(&mut self) -> Option<&mut Session>;
93}
94
95impl SessionDepotExt for Depot {
96 #[inline]
97 fn set_session(&mut self, session: Session) -> &mut Self {
98 self.insert(SESSION_KEY, session);
99 self
100 }
101 #[inline]
102 fn take_session(&mut self) -> Option<Session> {
103 self.remove(SESSION_KEY).ok()
104 }
105 #[inline]
106 fn session(&self) -> Option<&Session> {
107 self.get(SESSION_KEY).ok()
108 }
109 #[inline]
110 fn session_mut(&mut self) -> Option<&mut Session> {
111 self.get_mut(SESSION_KEY).ok()
112 }
113}
114
115pub struct HandlerBuilder<S> {
117 store: S,
118 cookie_path: String,
119 cookie_name: String,
120 cookie_domain: Option<String>,
121 session_ttl: Option<Duration>,
122 save_unchanged: bool,
123 same_site_policy: SameSite,
124 key: Key,
125 fallback_keys: Vec<Key>,
126}
127impl<S: SessionStore> fmt::Debug for HandlerBuilder<S> {
128 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
129 f.debug_struct("HandlerBuilder")
130 .field("store", &self.store)
131 .field("cookie_path", &self.cookie_path)
132 .field("cookie_name", &self.cookie_name)
133 .field("cookie_domain", &self.cookie_domain)
134 .field("session_ttl", &self.session_ttl)
135 .field("same_site_policy", &self.same_site_policy)
136 .field("key", &"..")
137 .field("fallback_keys", &"..")
138 .field("save_unchanged", &self.save_unchanged)
139 .finish()
140 }
141}
142
143impl<S> HandlerBuilder<S>
144where
145 S: SessionStore,
146{
147 #[inline]
149 pub fn new(store: S, secret: &[u8]) -> Self {
150 Self {
151 store,
152 save_unchanged: true,
153 cookie_path: "/".into(),
154 cookie_name: "salvo.session.id".into(),
155 cookie_domain: None,
156 same_site_policy: SameSite::Lax,
157 session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
158 key: Key::from(secret),
159 fallback_keys: vec![],
160 }
161 }
162
163 #[inline]
167 pub fn cookie_path(mut self, cookie_path: impl Into<String>) -> Self {
168 self.cookie_path = cookie_path.into();
169 self
170 }
171
172 #[inline]
178 pub fn session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
179 self.session_ttl = session_ttl;
180 self
181 }
182
183 #[inline]
189 pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
190 self.cookie_name = cookie_name.into();
191 self
192 }
193
194 #[inline]
204 pub fn save_unchanged(mut self, value: bool) -> Self {
205 self.save_unchanged = value;
206 self
207 }
208
209 #[inline]
214 pub fn same_site_policy(mut self, policy: SameSite) -> Self {
215 self.same_site_policy = policy;
216 self
217 }
218
219 #[inline]
221 pub fn cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
222 self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
223 self
224 }
225 #[inline]
227 pub fn fallback_keys(mut self, keys: Vec<impl Into<Key>>) -> Self {
228 self.fallback_keys = keys.into_iter().map(|s| s.into()).collect();
229 self
230 }
231
232 #[inline]
234 pub fn add_fallback_key(mut self, key: impl Into<Key>) -> Self {
235 self.fallback_keys.push(key.into());
236 self
237 }
238
239 pub fn build(self) -> Result<SessionHandler<S>, Error> {
241 let Self {
242 store,
243 save_unchanged,
244 cookie_path,
245 cookie_name,
246 cookie_domain,
247 session_ttl,
248 same_site_policy,
249 key,
250 fallback_keys,
251 } = self;
252 let hmac = Hmac::<Sha256>::new_from_slice(key.signing())
253 .map_err(|_| Error::Other("invalid key length".into()))?;
254 let fallback_hmacs = fallback_keys
255 .iter()
256 .map(|key| Hmac::<Sha256>::new_from_slice(key.signing()))
257 .collect::<Result<Vec<_>, _>>()
258 .map_err(|_| Error::Other("invalid key length".into()))?;
259 Ok(SessionHandler {
260 store,
261 save_unchanged,
262 cookie_path,
263 cookie_name,
264 cookie_domain,
265 session_ttl,
266 same_site_policy,
267 hmac,
268 fallback_hmacs,
269 })
270 }
271}
272
273pub struct SessionHandler<S> {
275 store: S,
276 cookie_path: String,
277 cookie_name: String,
278 cookie_domain: Option<String>,
279 session_ttl: Option<Duration>,
280 save_unchanged: bool,
281 same_site_policy: SameSite,
282 hmac: Hmac<Sha256>,
283 fallback_hmacs: Vec<Hmac<Sha256>>,
284}
285impl<S: SessionStore> fmt::Debug for SessionHandler<S> {
286 #[inline]
287 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
288 f.debug_struct("SessionHandler")
289 .field("store", &self.store)
290 .field("cookie_path", &self.cookie_path)
291 .field("cookie_name", &self.cookie_name)
292 .field("cookie_domain", &self.cookie_domain)
293 .field("session_ttl", &self.session_ttl)
294 .field("same_site_policy", &self.same_site_policy)
295 .field("key", &"..")
296 .field("fallback_keys", &"..")
297 .field("save_unchanged", &self.save_unchanged)
298 .finish()
299 }
300}
301#[async_trait]
302impl<S> Handler for SessionHandler<S>
303where
304 S: SessionStore,
305{
306 async fn handle(
307 &self,
308 req: &mut Request,
309 depot: &mut Depot,
310 res: &mut Response,
311 ctrl: &mut FlowCtrl,
312 ) {
313 let cookie = req.cookies().get(&self.cookie_name);
314 let cookie_value = cookie.and_then(|cookie| self.verify_signature(cookie.value()).ok());
315
316 let mut session = self.load_or_create(cookie_value).await;
317
318 if let Some(ttl) = self.session_ttl {
319 session.expire_in(ttl);
320 }
321
322 depot.set_session(session);
323
324 ctrl.call_next(req, depot, res).await;
325 if ctrl.is_ceased() {
326 return;
327 }
328
329 let session = depot.take_session().expect("session should exist in depot");
330 if session.is_destroyed() {
331 if let Err(e) = self.store.destroy_session(session).await {
332 tracing::error!(error = ?e, "unable to destroy session");
333 }
334 res.remove_cookie(&self.cookie_name);
335 } else if self.save_unchanged || session.data_changed() {
336 match self.store.store_session(session).await {
337 Ok(cookie_value) => {
338 if let Some(cookie_value) = cookie_value {
339 let secure_cookie = req.uri().scheme() == Some(&Scheme::HTTPS);
340 let cookie = self.build_cookie(secure_cookie, cookie_value);
341 res.add_cookie(cookie);
342 }
343 }
344 Err(e) => {
345 tracing::error!(error = ?e, "store session error");
346 }
347 }
348 }
349 }
350}
351
352impl<S> SessionHandler<S>
353where
354 S: SessionStore,
355{
356 pub fn builder(store: S, secret: &[u8]) -> HandlerBuilder<S> {
358 HandlerBuilder::new(store, secret)
359 }
360 #[inline]
361 async fn load_or_create(&self, cookie_value: Option<String>) -> Session {
362 let session = match cookie_value {
363 Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
364 None => None,
365 };
366
367 session
368 .and_then(|session| session.validate())
369 .unwrap_or_default()
370 }
371 fn verify_signature(&self, cookie_value: &str) -> Result<String, Error> {
377 if cookie_value.len() < BASE64_DIGEST_LEN {
378 return Err(Error::Other(
379 "length of value is <= BASE64_DIGEST_LEN".into(),
380 ));
381 }
382
383 let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
385 let digest =
386 base64::decode(digest_str).map_err(|_| Error::Other("bad base64 digest".into()))?;
387
388 let mut hmac = self.hmac.clone();
390 hmac.update(value.as_bytes());
391 if hmac.verify(&digest).is_ok() {
392 return Ok(value.to_string());
393 }
394 for hmac in &self.fallback_hmacs {
395 let mut hmac = hmac.clone();
396 hmac.update(value.as_bytes());
397 if hmac.verify(&digest).is_ok() {
398 return Ok(value.to_string());
399 }
400 }
401 Err(Error::Other("value did not verify".into()))
402 }
403 fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
404 let mut cookie = Cookie::build((self.cookie_name.clone(), cookie_value))
405 .http_only(true)
406 .same_site(self.same_site_policy)
407 .secure(secure)
408 .path(self.cookie_path.clone())
409 .build();
410
411 if let Some(ttl) = self.session_ttl {
412 cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
413 }
414
415 if let Some(cookie_domain) = self.cookie_domain.clone() {
416 cookie.set_domain(cookie_domain)
417 }
418
419 self.sign_cookie(&mut cookie);
420
421 cookie
422 }
423 fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
427 let mut mac = self.hmac.clone();
429 mac.update(cookie.value().as_bytes());
430
431 let mut new_value = base64::encode(mac.finalize().into_bytes());
433 new_value.push_str(cookie.value());
434 cookie.set_value(new_value);
435 }
436}
437
#[cfg(test)]
mod tests {
    use salvo_core::http::Method;
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// 64-byte secret (the minimum `cookie::Key::from` accepts).
    const SECRET: &[u8] = b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab";
    /// A second, unrelated 64-byte secret for key-rotation tests.
    const OTHER_SECRET: &[u8] = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

    #[test]
    fn test_session_data() {
        let builder = SessionHandler::builder(async_session::CookieStore, SECRET)
            .cookie_domain("test.domain")
            .cookie_name("test_cookie")
            .cookie_path("/abc")
            .same_site_policy(SameSite::Strict)
            .session_ttl(Some(Duration::from_secs(30)));
        assert!(format!("{:?}", builder).contains("test_cookie"));

        let handler = builder.build().unwrap();
        assert!(format!("{:?}", handler).contains("test_cookie"));
        assert_eq!(handler.cookie_domain, Some("test.domain".into()));
        assert_eq!(handler.cookie_name, "test_cookie");
        assert_eq!(handler.cookie_path, "/abc");
        assert_eq!(handler.same_site_policy, SameSite::Strict);
        assert_eq!(handler.session_ttl, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_signature_round_trip() {
        let handler = SessionHandler::builder(async_session::CookieStore, SECRET)
            .build()
            .unwrap();
        let cookie = handler.build_cookie(false, "value".to_owned());
        assert_eq!(handler.verify_signature(cookie.value()).unwrap(), "value");

        // Any change to the payload must invalidate the signature.
        let mut tampered = cookie.value().to_owned();
        tampered.push('x');
        assert!(handler.verify_signature(&tampered).is_err());

        // Values shorter than a digest are rejected outright.
        assert!(handler.verify_signature("too short").is_err());
    }

    #[test]
    fn test_fallback_key_verification() {
        let old = SessionHandler::builder(async_session::CookieStore, SECRET)
            .build()
            .unwrap();
        let cookie = old.build_cookie(false, "value".to_owned());

        // After rotating to a new secret, cookies signed with the old one still verify.
        let rotated = SessionHandler::builder(async_session::CookieStore, OTHER_SECRET)
            .add_fallback_key(Key::from(SECRET))
            .build()
            .unwrap();
        assert_eq!(rotated.verify_signature(cookie.value()).unwrap(), "value");

        // Without the fallback key the same cookie is rejected.
        let unrelated = SessionHandler::builder(async_session::CookieStore, OTHER_SECRET)
            .build()
            .unwrap();
        assert!(unrelated.verify_signature(cookie.value()).is_err());
    }

    #[tokio::test]
    async fn test_session_login() {
        #[handler]
        pub async fn login(req: &mut Request, depot: &mut Depot, res: &mut Response) {
            if req.method() == Method::POST {
                let mut session = Session::new();
                session
                    .insert("username", req.form::<String>("username").await.unwrap())
                    .unwrap();
                depot.set_session(session);
                res.render(Redirect::other("/"));
            } else {
                res.render(Text::Html("login page"));
            }
        }

        #[handler]
        pub async fn logout(depot: &mut Depot, res: &mut Response) {
            if let Some(session) = depot.session_mut() {
                session.remove("username");
            }
            res.render(Redirect::other("/"));
        }

        #[handler]
        pub async fn home(depot: &mut Depot, res: &mut Response) {
            let mut content = r#"home"#.into();
            if let Some(session) = depot.session_mut() {
                if let Some(username) = session.get::<String>("username") {
                    content = username;
                }
            }
            res.render(Text::Html(content));
        }

        let session_handler = SessionHandler::builder(MemoryStore::new(), SECRET)
            .build()
            .unwrap();
        let router = Router::new()
            .hoop(session_handler)
            .get(home)
            .push(Router::with_path("login").get(login).post(login))
            .push(Router::with_path("logout").get(logout));
        let service = Service::new(router);

        let response = TestClient::post("http://127.0.0.1:5800/login")
            .raw_form("username=salvo")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));
        let cookie = response.headers().get(SET_COOKIE).unwrap().clone();

        let mut response = TestClient::get("http://127.0.0.1:5800/")
            .add_header(COOKIE, &cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "salvo");

        // Log out while presenting the session cookie, so the stored session is
        // actually modified rather than a fresh anonymous one.
        let response = TestClient::get("http://127.0.0.1:5800/logout")
            .add_header(COOKIE, &cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        // The same cookie now refers to a session without a username.
        let mut response = TestClient::get("http://127.0.0.1:5800/")
            .add_header(COOKIE, &cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "home");

        // A request without any cookie gets a new anonymous session.
        let mut response = TestClient::get("http://127.0.0.1:5800/")
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "home");
    }
}