1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
64#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
65#![cfg_attr(docsrs, feature(doc_cfg))]
66
67pub use saysion::{CookieStore, MemoryStore, Session, SessionStore};
68
69use std::fmt::{self, Formatter};
70use std::time::Duration;
71
72use cookie::{Cookie, Key, SameSite};
73use salvo_core::http::uri::Scheme;
74use salvo_core::{Depot, Error, FlowCtrl, Handler, Request, Response, async_trait};
75use saysion::base64::{Engine as _, engine::general_purpose};
76use saysion::hmac::{Hmac, Mac};
77use saysion::sha2::Sha256;
78
/// Key under which the current request's [`Session`] is stored in the [`Depot`].
pub const SESSION_KEY: &str = "::salvo::session";
/// Number of leading bytes of a signed cookie value that hold the base64-encoded
/// HMAC digest; the remainder of the value is the signed payload.
const BASE64_DIGEST_LEN: usize = 44;
82
/// Typed accessors for a [`Session`] kept in a request-scoped container.
///
/// This file implements the trait for [`Depot`].
pub trait SessionDepotExt {
    /// Stores `session` and returns `self` so calls can be chained.
    fn set_session(&mut self, session: Session) -> &mut Self;
    /// Removes the stored session and returns it, or `None` if there is none.
    fn take_session(&mut self) -> Option<Session>;
    /// Returns a shared reference to the stored session, if any.
    fn session(&self) -> Option<&Session>;
    /// Returns a mutable reference to the stored session, if any.
    fn session_mut(&mut self) -> Option<&mut Session>;
}
94
95impl SessionDepotExt for Depot {
96 #[inline]
97 fn set_session(&mut self, session: Session) -> &mut Self {
98 self.insert(SESSION_KEY, session);
99 self
100 }
101 #[inline]
102 fn take_session(&mut self) -> Option<Session> {
103 self.remove(SESSION_KEY).ok()
104 }
105 #[inline]
106 fn session(&self) -> Option<&Session> {
107 self.get(SESSION_KEY).ok()
108 }
109 #[inline]
110 fn session_mut(&mut self) -> Option<&mut Session> {
111 self.get_mut(SESSION_KEY).ok()
112 }
113}
114
/// Builder that collects store and cookie settings before producing a
/// [`SessionHandler`] via [`HandlerBuilder::build`].
pub struct HandlerBuilder<S> {
    /// Backend used to load, persist and destroy sessions.
    store: S,
    /// `Path` attribute set on the session cookie.
    cookie_path: String,
    /// Name of the cookie that carries the signed session value.
    cookie_name: String,
    /// Optional `Domain` attribute for the session cookie.
    cookie_domain: Option<String>,
    /// Session lifetime; when set it also drives the cookie's `Expires` attribute.
    session_ttl: Option<Duration>,
    /// Whether the session is persisted even when its data did not change.
    save_unchanged: bool,
    /// `SameSite` attribute applied to the session cookie.
    same_site_policy: SameSite,
    /// Key whose signing half seeds the HMAC used to sign and verify cookies.
    key: Key,
    /// Extra keys whose signing halves are also accepted when verifying cookies.
    fallback_keys: Vec<Key>,
}
127impl<S> fmt::Debug for HandlerBuilder<S>
128where
129 S: SessionStore + fmt::Debug,
130{
131 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
132 f.debug_struct("HandlerBuilder")
133 .field("store", &self.store)
134 .field("cookie_path", &self.cookie_path)
135 .field("cookie_name", &self.cookie_name)
136 .field("cookie_domain", &self.cookie_domain)
137 .field("session_ttl", &self.session_ttl)
138 .field("same_site_policy", &self.same_site_policy)
139 .field("key", &"..")
140 .field("fallback_keys", &"..")
141 .field("save_unchanged", &self.save_unchanged)
142 .finish()
143 }
144}
145
146impl<S> HandlerBuilder<S>
147where
148 S: SessionStore,
149{
150 #[inline]
152 #[must_use]
153 pub fn new(store: S, secret: &[u8]) -> Self {
154 Self {
155 store,
156 save_unchanged: true,
157 cookie_path: "/".into(),
158 cookie_name: "salvo.session.id".into(),
159 cookie_domain: None,
160 same_site_policy: SameSite::Lax,
161 session_ttl: Some(Duration::from_secs(24 * 60 * 60)),
162 key: Key::from(secret),
163 fallback_keys: vec![],
164 }
165 }
166
167 #[inline]
171 #[must_use]
172 pub fn cookie_path(mut self, cookie_path: impl Into<String>) -> Self {
173 self.cookie_path = cookie_path.into();
174 self
175 }
176
177 #[inline]
183 #[must_use]
184 pub fn session_ttl(mut self, session_ttl: Option<Duration>) -> Self {
185 self.session_ttl = session_ttl;
186 self
187 }
188
189 #[inline]
195 #[must_use]
196 pub fn cookie_name(mut self, cookie_name: impl Into<String>) -> Self {
197 self.cookie_name = cookie_name.into();
198 self
199 }
200
201 #[inline]
211 #[must_use]
212 pub fn save_unchanged(mut self, value: bool) -> Self {
213 self.save_unchanged = value;
214 self
215 }
216
217 #[inline]
222 #[must_use]
223 pub fn same_site_policy(mut self, policy: SameSite) -> Self {
224 self.same_site_policy = policy;
225 self
226 }
227
228 #[inline]
230 #[must_use]
231 pub fn cookie_domain(mut self, cookie_domain: impl AsRef<str>) -> Self {
232 self.cookie_domain = Some(cookie_domain.as_ref().to_owned());
233 self
234 }
235 #[inline]
237 #[must_use]
238 pub fn fallback_keys(mut self, keys: Vec<impl Into<Key>>) -> Self {
239 self.fallback_keys = keys.into_iter().map(|s| s.into()).collect();
240 self
241 }
242
243 #[inline]
245 #[must_use]
246 pub fn add_fallback_key(mut self, key: impl Into<Key>) -> Self {
247 self.fallback_keys.push(key.into());
248 self
249 }
250
251 pub fn build(self) -> Result<SessionHandler<S>, Error> {
253 let Self {
254 store,
255 save_unchanged,
256 cookie_path,
257 cookie_name,
258 cookie_domain,
259 session_ttl,
260 same_site_policy,
261 key,
262 fallback_keys,
263 } = self;
264 let hmac = Hmac::<Sha256>::new_from_slice(key.signing())
265 .map_err(|_| Error::Other("invalid key length".into()))?;
266 let fallback_hmacs = fallback_keys
267 .iter()
268 .map(|key| Hmac::<Sha256>::new_from_slice(key.signing()))
269 .collect::<Result<Vec<_>, _>>()
270 .map_err(|_| Error::Other("invalid key length".into()))?;
271 Ok(SessionHandler {
272 store,
273 save_unchanged,
274 cookie_path,
275 cookie_name,
276 cookie_domain,
277 session_ttl,
278 same_site_policy,
279 hmac,
280 fallback_hmacs,
281 })
282 }
283}
284
/// Middleware that attaches a [`Session`] to each request and persists it afterwards.
///
/// Instances are produced by [`HandlerBuilder::build`].
pub struct SessionHandler<S> {
    /// Backend used to load, persist and destroy sessions.
    store: S,
    /// `Path` attribute set on the session cookie.
    cookie_path: String,
    /// Name of the cookie that carries the signed session value.
    cookie_name: String,
    /// Optional `Domain` attribute for the session cookie.
    cookie_domain: Option<String>,
    /// Session lifetime; when set it also drives the cookie's `Expires` attribute.
    session_ttl: Option<Duration>,
    /// Whether the session is persisted even when its data did not change.
    save_unchanged: bool,
    /// `SameSite` attribute applied to the session cookie.
    same_site_policy: SameSite,
    /// Keyed HMAC used to sign outgoing cookies and to verify incoming ones.
    hmac: Hmac<Sha256>,
    /// Additional keyed HMACs that are accepted when verifying incoming cookies.
    fallback_hmacs: Vec<Hmac<Sha256>>,
}
297impl<S> fmt::Debug for SessionHandler<S>
298where
299 S: SessionStore + fmt::Debug,
300{
301 #[inline]
302 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
303 f.debug_struct("SessionHandler")
304 .field("store", &self.store)
305 .field("cookie_path", &self.cookie_path)
306 .field("cookie_name", &self.cookie_name)
307 .field("cookie_domain", &self.cookie_domain)
308 .field("session_ttl", &self.session_ttl)
309 .field("same_site_policy", &self.same_site_policy)
310 .field("key", &"..")
311 .field("fallback_keys", &"..")
312 .field("save_unchanged", &self.save_unchanged)
313 .finish()
314 }
315}
316#[async_trait]
317impl<S> Handler for SessionHandler<S>
318where
319 S: SessionStore + Send + Sync + 'static,
320{
321 async fn handle(
322 &self,
323 req: &mut Request,
324 depot: &mut Depot,
325 res: &mut Response,
326 ctrl: &mut FlowCtrl,
327 ) {
328 let cookie = req.cookies().get(&self.cookie_name);
329 let cookie_value = cookie.and_then(|cookie| self.verify_signature(cookie.value()).ok());
330
331 let mut session = self.load_or_create(cookie_value).await;
332
333 if let Some(ttl) = self.session_ttl {
334 session.expire_in(ttl);
335 }
336
337 depot.set_session(session);
338
339 ctrl.call_next(req, depot, res).await;
340 if ctrl.is_ceased() {
341 return;
342 }
343
344 let session = depot.take_session().expect("session should exist in depot");
345 if session.is_destroyed() {
346 if let Err(e) = self.store.destroy_session(session).await {
347 tracing::error!(error = ?e, "unable to destroy session");
348 }
349 res.remove_cookie(&self.cookie_name);
350 } else if self.save_unchanged || session.data_changed() {
351 match self.store.store_session(session).await {
352 Ok(cookie_value) => {
353 if let Some(cookie_value) = cookie_value {
354 let secure_cookie = req.uri().scheme() == Some(&Scheme::HTTPS);
355 let cookie = self.build_cookie(secure_cookie, cookie_value);
356 res.add_cookie(cookie);
357 }
358 }
359 Err(e) => {
360 tracing::error!(error = ?e, "store session error");
361 }
362 }
363 }
364 }
365}
366
367impl<S> SessionHandler<S>
368where
369 S: SessionStore + Send + Sync + 'static,
370{
371 pub fn builder(store: S, secret: &[u8]) -> HandlerBuilder<S> {
373 HandlerBuilder::new(store, secret)
374 }
375 #[inline]
376 async fn load_or_create(&self, cookie_value: Option<String>) -> Session {
377 let session = match cookie_value {
378 Some(cookie_value) => self.store.load_session(cookie_value).await.ok().flatten(),
379 None => None,
380 };
381
382 session
383 .and_then(|session| session.validate())
384 .unwrap_or_default()
385 }
386 fn verify_signature(&self, cookie_value: &str) -> Result<String, Error> {
392 if cookie_value.len() < BASE64_DIGEST_LEN {
393 return Err(Error::Other(
394 "length of value is <= BASE64_DIGEST_LEN".into(),
395 ));
396 }
397
398 let (digest_str, value) = cookie_value.split_at(BASE64_DIGEST_LEN);
400 let digest = general_purpose::STANDARD
401 .decode(digest_str)
402 .map_err(|_| Error::Other("bad base64 digest".into()))?;
403
404 let mut hmac = self.hmac.clone();
406 hmac.update(value.as_bytes());
407 if hmac.verify_slice(&digest).is_ok() {
408 return Ok(value.to_owned());
409 }
410 for hmac in &self.fallback_hmacs {
411 let mut hmac = hmac.clone();
412 hmac.update(value.as_bytes());
413 if hmac.verify_slice(&digest).is_ok() {
414 return Ok(value.to_owned());
415 }
416 }
417 Err(Error::Other("value did not verify".into()))
418 }
419 fn build_cookie(&self, secure: bool, cookie_value: String) -> Cookie<'static> {
420 let mut cookie = Cookie::build((self.cookie_name.clone(), cookie_value))
421 .http_only(true)
422 .same_site(self.same_site_policy)
423 .secure(secure)
424 .path(self.cookie_path.clone())
425 .build();
426
427 if let Some(ttl) = self.session_ttl {
428 cookie.set_expires(Some((std::time::SystemTime::now() + ttl).into()));
429 }
430
431 if let Some(cookie_domain) = self.cookie_domain.clone() {
432 cookie.set_domain(cookie_domain)
433 }
434
435 self.sign_cookie(&mut cookie);
436
437 cookie
438 }
439 fn sign_cookie(&self, cookie: &mut Cookie<'_>) {
443 let mut mac = self.hmac.clone();
445 mac.update(cookie.value().as_bytes());
446
447 let mut new_value = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
449 new_value.push_str(cookie.value());
450 cookie.set_value(new_value);
451 }
452}
453
#[cfg(test)]
mod tests {
    use salvo_core::http::Method;
    use salvo_core::http::header::*;
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Checks that builder settings show up in `Debug` output and are carried
    /// over unchanged into the built handler.
    #[test]
    fn test_session_data() {
        let builder = SessionHandler::builder(
            saysion::CookieStore,
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .cookie_domain("test.domain")
        .cookie_name("test_cookie")
        .cookie_path("/abc")
        .same_site_policy(SameSite::Strict)
        .session_ttl(Some(Duration::from_secs(30)));
        assert!(format!("{builder:?}").contains("test_cookie"));

        let handler = builder.build().unwrap();
        assert!(format!("{handler:?}").contains("test_cookie"));
        assert_eq!(handler.cookie_domain, Some("test.domain".into()));
        assert_eq!(handler.cookie_name, "test_cookie");
        assert_eq!(handler.cookie_path, "/abc");
        assert_eq!(handler.same_site_policy, SameSite::Strict);
        assert_eq!(handler.session_ttl, Some(Duration::from_secs(30)));
    }

    /// End-to-end check: logging in issues a session cookie, and presenting that
    /// cookie makes the home page render the stored username.
    #[tokio::test]
    async fn test_session_login() {
        /// On POST, stores the submitted `username` in a new session and
        /// redirects to `/`; otherwise renders the login page.
        #[handler]
        pub async fn login(req: &mut Request, depot: &mut Depot, res: &mut Response) {
            if req.method() == Method::POST {
                let mut session = Session::new();
                session
                    .insert("username", req.form::<String>("username").await.unwrap())
                    .unwrap();
                depot.set_session(session);
                res.render(Redirect::other("/"));
            } else {
                res.render(Text::Html("login page"));
            }
        }

        /// Removes `username` from the current session and redirects to `/`.
        #[handler]
        pub async fn logout(depot: &mut Depot, res: &mut Response) {
            if let Some(session) = depot.session_mut() {
                session.remove("username");
            }
            res.render(Redirect::other("/"));
        }

        /// Renders the session's `username` if present, otherwise `home`.
        #[handler]
        pub async fn home(depot: &mut Depot, res: &mut Response) {
            let mut content = r#"home"#.into();
            if let Some(session) = depot.session_mut() {
                if let Some(username) = session.get::<String>("username") {
                    content = username;
                }
            }
            res.render(Text::Html(content));
        }

        let session_handler = SessionHandler::builder(
            MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();
        let router = Router::new()
            .hoop(session_handler)
            .get(home)
            .push(Router::with_path("login").get(login).post(login))
            .push(Router::with_path("logout").get(logout));
        let service = Service::new(router);

        // Log in and capture the Set-Cookie header carrying the signed session.
        let response = TestClient::post("http://127.0.0.1:8698/login")
            .raw_form("username=salvo")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));
        let cookie = response.headers().get(SET_COOKIE).unwrap();

        // Replaying the cookie must resolve to the logged-in session.
        let mut response = TestClient::get("http://127.0.0.1:8698/")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "salvo");

        // NOTE(review): the next two requests send no session cookie, so each
        // presumably runs against a fresh session and the logout handler's effect
        // on the logged-in session is not actually verified — confirm whether the
        // cookie should be forwarded here.
        let response = TestClient::get("http://127.0.0.1:8698/logout")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let mut response = TestClient::get("http://127.0.0.1:8698/")
            .send(&service)
            .await;
        assert_eq!(response.take_string().await.unwrap(), "home");
    }
}