1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
70#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
71#![cfg_attr(docsrs, feature(doc_cfg))]
72
73use std::fmt::{self, Debug, Display, Formatter};
74use std::ops::Deref;
75
76use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
77use serde::{Deserialize, Serialize};
78
79#[macro_use]
80mod cfg;
81
82cfg_feature! {
83 #![feature = "cookie-store"]
84
85 mod cookie_store;
86 pub use cookie_store::CookieStore;
87
88 #[must_use] pub fn cookie_store() -> CookieStore {
90 CookieStore::new()
91 }
92}
93
94cfg_feature! {
95 #![feature = "session-store"]
96
97 mod session_store;
98 pub use session_store::SessionStore;
99
100 #[must_use]
102 pub fn session_store() -> SessionStore {
103 SessionStore::new()
104 }
105}
106
/// Depot key under which [`FlashHandler`] places the flash it loaded from the store
/// for the current request.
pub const INCOMING_FLASH_KEY: &str = "::salvo::flash::incoming_flash";

/// Depot key under which [`FlashHandler`] places the (initially empty) flash that
/// downstream handlers fill with messages to be saved.
pub const OUTGOING_FLASH_KEY: &str = "::salvo::flash::outgoing_flash";
112
113#[derive(Default, Serialize, Deserialize, Clone, Debug)]
115pub struct Flash(pub Vec<FlashMessage>);
116impl Flash {
117 #[inline]
119 pub fn debug(&mut self, message: impl Into<String>) -> &mut Self {
120 self.0.push(FlashMessage::debug(message));
121 self
122 }
123 #[inline]
125 pub fn info(&mut self, message: impl Into<String>) -> &mut Self {
126 self.0.push(FlashMessage::info(message));
127 self
128 }
129 #[inline]
131 pub fn success(&mut self, message: impl Into<String>) -> &mut Self {
132 self.0.push(FlashMessage::success(message));
133 self
134 }
135 #[inline]
137 pub fn warning(&mut self, message: impl Into<String>) -> &mut Self {
138 self.0.push(FlashMessage::warning(message));
139 self
140 }
141 #[inline]
143 pub fn error(&mut self, message: impl Into<String>) -> &mut Self {
144 self.0.push(FlashMessage::error(message));
145 self
146 }
147}
148
149impl Deref for Flash {
150 type Target = Vec<FlashMessage>;
151
152 fn deref(&self) -> &Self::Target {
153 &self.0
154 }
155}
156
157#[derive(Serialize, Deserialize, Clone, Debug)]
159#[non_exhaustive]
160pub struct FlashMessage {
161 pub level: FlashLevel,
163 pub value: String,
165}
166impl FlashMessage {
167 #[inline]
169 pub fn debug(message: impl Into<String>) -> Self {
170 Self {
171 level: FlashLevel::Debug,
172 value: message.into(),
173 }
174 }
175 #[inline]
177 pub fn info(message: impl Into<String>) -> Self {
178 Self {
179 level: FlashLevel::Info,
180 value: message.into(),
181 }
182 }
183 #[inline]
185 pub fn success(message: impl Into<String>) -> Self {
186 Self {
187 level: FlashLevel::Success,
188 value: message.into(),
189 }
190 }
191 #[inline]
193 pub fn warning(message: impl Into<String>) -> Self {
194 Self {
195 level: FlashLevel::Warning,
196 value: message.into(),
197 }
198 }
199 #[inline]
201 pub fn error(message: impl Into<String>) -> Self {
202 Self {
203 level: FlashLevel::Error,
204 value: message.into(),
205 }
206 }
207}
208
209#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
211pub enum FlashLevel {
212 #[allow(missing_docs)]
213 Debug = 0,
214 #[allow(missing_docs)]
215 Info = 1,
216 #[allow(missing_docs)]
217 Success = 2,
218 #[allow(missing_docs)]
219 Warning = 3,
220 #[allow(missing_docs)]
221 Error = 4,
222}
223impl FlashLevel {
224 #[must_use]
226 pub fn to_str(&self) -> &'static str {
227 match self {
228 Self::Debug => "debug",
229 Self::Info => "info",
230 Self::Success => "success",
231 Self::Warning => "warning",
232 Self::Error => "error",
233 }
234 }
235}
236impl Debug for FlashLevel {
237 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
238 write!(f, "{}", self.to_str())
239 }
240}
241
242impl Display for FlashLevel {
243 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
244 write!(f, "{}", self.to_str())
245 }
246}
247
248pub trait FlashStore: Debug + Send + Sync + 'static {
250 fn load_flash(
252 &self,
253 req: &mut Request,
254 depot: &mut Depot,
255 ) -> impl Future<Output = Option<Flash>> + Send;
256 fn save_flash(
258 &self,
259 req: &mut Request,
260 depot: &mut Depot,
261 res: &mut Response,
262 flash: Flash,
263 ) -> impl Future<Output = ()> + Send;
264 fn clear_flash(&self, depot: &mut Depot, res: &mut Response)
266 -> impl Future<Output = ()> + Send;
267}
268
269pub trait FlashDepotExt {
271 fn incoming_flash(&mut self) -> Option<&Flash>;
273 fn outgoing_flash(&self) -> &Flash;
275 fn outgoing_flash_mut(&mut self) -> &mut Flash;
277}
278
279impl FlashDepotExt for Depot {
280 #[inline]
281 fn incoming_flash(&mut self) -> Option<&Flash> {
282 self.get::<Flash>(INCOMING_FLASH_KEY).ok()
283 }
284
285 #[inline]
286 fn outgoing_flash(&self) -> &Flash {
287 self.get::<Flash>(OUTGOING_FLASH_KEY)
288 .expect("Flash should be initialized")
289 }
290
291 #[inline]
292 fn outgoing_flash_mut(&mut self) -> &mut Flash {
293 self.get_mut::<Flash>(OUTGOING_FLASH_KEY)
294 .expect("Flash should be initialized")
295 }
296}
297
298#[non_exhaustive]
300pub struct FlashHandler<S> {
301 store: S,
302 pub minimum_level: Option<FlashLevel>,
304}
305impl<S> FlashHandler<S> {
306 #[inline]
308 pub fn new(store: S) -> Self {
309 Self {
310 store,
311 minimum_level: None,
312 }
313 }
314
315 #[inline]
317 pub fn minimum_level(&mut self, level: impl Into<Option<FlashLevel>>) -> &mut Self {
318 self.minimum_level = level.into();
319 self
320 }
321}
322impl<S: FlashStore> fmt::Debug for FlashHandler<S> {
323 #[inline]
324 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
325 f.debug_struct("FlashHandler")
326 .field("store", &self.store)
327 .finish()
328 }
329}
330#[async_trait]
331impl<S> Handler for FlashHandler<S>
332where
333 S: FlashStore,
334{
335 async fn handle(
336 &self,
337 req: &mut Request,
338 depot: &mut Depot,
339 res: &mut Response,
340 ctrl: &mut FlowCtrl,
341 ) {
342 let mut has_incoming = false;
343 if let Some(flash) = self.store.load_flash(req, depot).await {
344 has_incoming = !flash.is_empty();
345 depot.insert(INCOMING_FLASH_KEY, flash);
346 }
347 depot.insert(OUTGOING_FLASH_KEY, Flash(vec![]));
348
349 ctrl.call_next(req, depot, res).await;
350 if ctrl.is_ceased() {
351 return;
352 }
353
354 let mut flash = depot
355 .remove::<Flash>(OUTGOING_FLASH_KEY)
356 .unwrap_or_default();
357 if let Some(min_level) = self.minimum_level {
358 flash.0.retain(|msg| msg.level >= min_level);
359 }
360 if !flash.is_empty() {
361 self.store.save_flash(req, depot, res, flash).await;
362 } else if has_incoming {
363 self.store.clear_flash(depot, res).await;
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use salvo_core::http::header::{COOKIE, SET_COOKIE};
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    /// Adds two outgoing messages and redirects to `/get`.
    #[handler]
    pub async fn set_flash(depot: &mut Depot, res: &mut Response) {
        let flash = depot.outgoing_flash_mut();
        flash.info("Hey there!").debug("How is it going?");
        res.render(Redirect::other("/get"));
    }

    /// Renders every incoming message as a `value - level` line.
    #[handler]
    pub async fn get_flash(depot: &mut Depot, _res: &mut Response) -> String {
        let mut body = String::new();
        if let Some(flash) = depot.incoming_flash() {
            for message in flash.iter() {
                writeln!(body, "{} - {}", message.value, message.level).unwrap();
            }
        }
        body
    }

    // ---- Store round trips ----

    #[cfg(feature = "cookie-store")]
    #[tokio::test]
    async fn test_cookie_store() {
        let cookie_name = "my-custom-cookie-name".to_owned();
        let router = Router::new()
            .hoop(CookieStore::new().name(&cookie_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        // First read sees the message that was set.
        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        // Second read, using the cookie from the first read, sees nothing.
        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }

    #[cfg(feature = "session-store")]
    #[tokio::test]
    async fn test_session_store() {
        let session_handler = salvo_session::SessionHandler::builder(
            salvo_session::MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();

        let session_name = "my-custom-session-name".to_owned();
        let router = Router::new()
            .hoop(session_handler)
            .hoop(SessionStore::new().name(&session_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }

    // ---- Flash ----

    #[test]
    fn test_flash_default() {
        let flash = Flash::default();
        assert!(flash.0.is_empty());
    }

    #[test]
    fn test_flash_debug() {
        let mut flash = Flash::default();
        flash.info("test message");
        let debug_str = format!("{flash:?}");
        assert!(debug_str.contains("Flash"));
        assert!(debug_str.contains("test message"));
    }

    #[test]
    fn test_flash_clone() {
        let mut flash = Flash::default();
        flash.info("test");
        let cloned = flash.clone();
        assert_eq!(flash.0.len(), cloned.0.len());
    }

    #[test]
    fn test_flash_add_debug() {
        let mut flash = Flash::default();
        flash.debug("debug message");
        assert_eq!(flash.0.len(), 1);
        assert_eq!(flash.0[0].level, FlashLevel::Debug);
        assert_eq!(flash.0[0].value, "debug message");
    }

    #[test]
    fn test_flash_add_info() {
        let mut flash = Flash::default();
        flash.info("info message");
        assert_eq!(flash.0.len(), 1);
        assert_eq!(flash.0[0].level, FlashLevel::Info);
        assert_eq!(flash.0[0].value, "info message");
    }

    #[test]
    fn test_flash_add_success() {
        let mut flash = Flash::default();
        flash.success("success message");
        assert_eq!(flash.0.len(), 1);
        assert_eq!(flash.0[0].level, FlashLevel::Success);
        assert_eq!(flash.0[0].value, "success message");
    }

    #[test]
    fn test_flash_add_warning() {
        let mut flash = Flash::default();
        flash.warning("warning message");
        assert_eq!(flash.0.len(), 1);
        assert_eq!(flash.0[0].level, FlashLevel::Warning);
        assert_eq!(flash.0[0].value, "warning message");
    }

    #[test]
    fn test_flash_add_error() {
        let mut flash = Flash::default();
        flash.error("error message");
        assert_eq!(flash.0.len(), 1);
        assert_eq!(flash.0[0].level, FlashLevel::Error);
        assert_eq!(flash.0[0].value, "error message");
    }

    #[test]
    fn test_flash_chain_messages() {
        let mut flash = Flash::default();
        flash
            .debug("debug")
            .info("info")
            .success("success")
            .warning("warning")
            .error("error");
        assert_eq!(flash.0.len(), 5);
    }

    #[test]
    fn test_flash_deref() {
        let mut flash = Flash::default();
        flash.info("test");
        // `len` and `iter` are reached through `Deref<Target = Vec<FlashMessage>>`.
        assert_eq!(flash.len(), 1);
        assert!(flash.iter().any(|m| m.value == "test"));
    }

    // ---- FlashMessage ----

    #[test]
    fn test_flash_message_debug() {
        let msg = FlashMessage::debug("debug msg");
        assert_eq!(msg.level, FlashLevel::Debug);
        assert_eq!(msg.value, "debug msg");
    }

    #[test]
    fn test_flash_message_info() {
        let msg = FlashMessage::info("info msg");
        assert_eq!(msg.level, FlashLevel::Info);
        assert_eq!(msg.value, "info msg");
    }

    #[test]
    fn test_flash_message_success() {
        let msg = FlashMessage::success("success msg");
        assert_eq!(msg.level, FlashLevel::Success);
        assert_eq!(msg.value, "success msg");
    }

    #[test]
    fn test_flash_message_warning() {
        let msg = FlashMessage::warning("warning msg");
        assert_eq!(msg.level, FlashLevel::Warning);
        assert_eq!(msg.value, "warning msg");
    }

    #[test]
    fn test_flash_message_error() {
        let msg = FlashMessage::error("error msg");
        assert_eq!(msg.level, FlashLevel::Error);
        assert_eq!(msg.value, "error msg");
    }

    #[test]
    fn test_flash_message_clone() {
        let msg = FlashMessage::info("test");
        let cloned = msg.clone();
        assert_eq!(msg.level, cloned.level);
        assert_eq!(msg.value, cloned.value);
    }

    #[test]
    fn test_flash_message_debug_trait() {
        let msg = FlashMessage::info("test");
        let debug_str = format!("{msg:?}");
        assert!(debug_str.contains("FlashMessage"));
        assert!(debug_str.contains("test"));
    }

    // ---- FlashLevel ----

    #[test]
    fn test_flash_level_to_str() {
        assert_eq!(FlashLevel::Debug.to_str(), "debug");
        assert_eq!(FlashLevel::Info.to_str(), "info");
        assert_eq!(FlashLevel::Success.to_str(), "success");
        assert_eq!(FlashLevel::Warning.to_str(), "warning");
        assert_eq!(FlashLevel::Error.to_str(), "error");
    }

    #[test]
    fn test_flash_level_debug_trait() {
        assert_eq!(format!("{:?}", FlashLevel::Debug), "debug");
        assert_eq!(format!("{:?}", FlashLevel::Info), "info");
        assert_eq!(format!("{:?}", FlashLevel::Success), "success");
        assert_eq!(format!("{:?}", FlashLevel::Warning), "warning");
        assert_eq!(format!("{:?}", FlashLevel::Error), "error");
    }

    #[test]
    fn test_flash_level_display() {
        assert_eq!(format!("{}", FlashLevel::Debug), "debug");
        assert_eq!(format!("{}", FlashLevel::Info), "info");
        assert_eq!(format!("{}", FlashLevel::Success), "success");
        assert_eq!(format!("{}", FlashLevel::Warning), "warning");
        assert_eq!(format!("{}", FlashLevel::Error), "error");
    }

    #[test]
    fn test_flash_level_ord() {
        assert!(FlashLevel::Debug < FlashLevel::Info);
        assert!(FlashLevel::Info < FlashLevel::Success);
        assert!(FlashLevel::Success < FlashLevel::Warning);
        assert!(FlashLevel::Warning < FlashLevel::Error);
    }

    #[test]
    fn test_flash_level_eq() {
        assert_eq!(FlashLevel::Debug, FlashLevel::Debug);
        assert_ne!(FlashLevel::Debug, FlashLevel::Info);
    }

    #[test]
    // `FlashLevel` is `Copy`; the explicit call is deliberate so the `Clone` impl itself
    // is exercised rather than repeating the copy test below.
    #[allow(clippy::clone_on_copy)]
    fn test_flash_level_clone() {
        let level = FlashLevel::Info;
        let cloned = level.clone();
        assert_eq!(level, cloned);
    }

    #[test]
    fn test_flash_level_copy() {
        let level = FlashLevel::Warning;
        let copied = level;
        assert_eq!(level, copied);
    }

    // ---- FlashHandler ----
    // These need `CookieStore`, so each test is gated as a whole: without the feature it
    // does not exist, instead of being reported as a pass with an empty body.

    #[cfg(feature = "cookie-store")]
    #[test]
    fn test_flash_handler_new() {
        let handler = FlashHandler::new(CookieStore::new());
        assert!(handler.minimum_level.is_none());
    }

    #[cfg(feature = "cookie-store")]
    #[test]
    fn test_flash_handler_minimum_level() {
        let mut handler = FlashHandler::new(CookieStore::new());
        handler.minimum_level(FlashLevel::Warning);
        assert_eq!(handler.minimum_level, Some(FlashLevel::Warning));
    }

    #[cfg(feature = "cookie-store")]
    #[test]
    fn test_flash_handler_minimum_level_none() {
        let mut handler = FlashHandler::new(CookieStore::new());
        handler.minimum_level(FlashLevel::Info);
        handler.minimum_level(None);
        assert!(handler.minimum_level.is_none());
    }

    #[cfg(feature = "cookie-store")]
    #[test]
    fn test_flash_handler_debug() {
        let handler = FlashHandler::new(CookieStore::new());
        let debug_str = format!("{handler:?}");
        assert!(debug_str.contains("FlashHandler"));
        assert!(debug_str.contains("store"));
    }

    // ---- Serialization ----

    #[test]
    fn test_flash_serialization() {
        let mut flash = Flash::default();
        flash.info("test message");

        let serialized = serde_json::to_string(&flash).unwrap();
        let deserialized: Flash = serde_json::from_str(&serialized).unwrap();

        assert_eq!(flash.0.len(), deserialized.0.len());
        assert_eq!(flash.0[0].value, deserialized.0[0].value);
        assert_eq!(flash.0[0].level, deserialized.0[0].level);
    }

    #[test]
    fn test_flash_message_serialization() {
        let msg = FlashMessage::warning("test");

        let serialized = serde_json::to_string(&msg).unwrap();
        let deserialized: FlashMessage = serde_json::from_str(&serialized).unwrap();

        assert_eq!(msg.value, deserialized.value);
        assert_eq!(msg.level, deserialized.level);
    }

    #[test]
    fn test_flash_level_serialization() {
        let level = FlashLevel::Error;

        let serialized = serde_json::to_string(&level).unwrap();
        let deserialized: FlashLevel = serde_json::from_str(&serialized).unwrap();

        assert_eq!(level, deserialized);
    }

    // ---- Minimum-level filtering ----

    #[cfg(feature = "cookie-store")]
    #[tokio::test]
    async fn test_flash_handler_filters_by_minimum_level() {
        #[handler]
        pub async fn set_all_levels(depot: &mut Depot, res: &mut Response) {
            let flash = depot.outgoing_flash_mut();
            flash
                .debug("debug msg")
                .info("info msg")
                .warning("warning msg")
                .error("error msg");
            res.render(Redirect::other("/get"));
        }

        let mut handler = FlashHandler::new(CookieStore::new());
        handler.minimum_level(FlashLevel::Warning);

        let router = Router::new()
            .hoop(handler)
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_all_levels));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;

        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;

        let body = response.take_string().await.unwrap();

        // Messages at or above `Warning` survive; lower levels are filtered out.
        assert!(body.contains("warning msg"));
        assert!(body.contains("error msg"));
        assert!(!body.contains("debug msg"));
        assert!(!body.contains("info msg"));
    }
}