salvo_flash/
lib.rs

//! Flash message support for the Salvo web framework.
2//!
3//! Read more: <https://salvo.rs>
4#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
5#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::ops::Deref;
10
11use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
12use serde::{Deserialize, Serialize};
13
14#[macro_use]
15mod cfg;
16
// Items gated on the `cookie-store` Cargo feature via `cfg_feature!` (from the
// `cfg` module). NOTE(review): the macro's expansion is not visible here.
cfg_feature! {
    #![feature = "cookie-store"]

    mod cookie_store;
    pub use cookie_store::CookieStore;

    /// Shorthand for [`CookieStore::new`]: create a `CookieStore` with its
    /// default configuration.
    pub fn cookie_store() -> CookieStore {
        CookieStore::new()
    }
}
28
// Items gated on the `session-store` Cargo feature via `cfg_feature!` (from the
// `cfg` module). NOTE(review): the macro's expansion is not visible here.
cfg_feature! {
    #![feature = "session-store"]

    mod session_store;
    pub use session_store::SessionStore;

    /// Shorthand for [`SessionStore::new`]: create a `SessionStore` with its
    /// default configuration.
    pub fn session_store() -> SessionStore {
        SessionStore::new()
    }
}
40
/// Depot key under which [`FlashHandler`] stores the incoming [`Flash`], i.e.
/// the messages its [`FlashStore`] loaded for the current request.
///
/// Prefer [`FlashDepotExt::incoming_flash`] over reading this key directly.
pub const INCOMING_FLASH_KEY: &str = "::salvo::flash::incoming_flash";

/// Depot key under which [`FlashHandler`] stores the outgoing [`Flash`], i.e.
/// the messages queued during this request. The handler removes it after the
/// downstream handlers have run and hands it to the store.
///
/// Prefer [`FlashDepotExt::outgoing_flash`] / [`FlashDepotExt::outgoing_flash_mut`]
/// over accessing this key directly.
pub const OUTGOING_FLASH_KEY: &str = "::salvo::flash::outgoing_flash";
46
47/// A flash is a list of messages.
48#[derive(Default, Serialize, Deserialize, Clone, Debug)]
49pub struct Flash(pub Vec<FlashMessage>);
50impl Flash {
51    /// Add a new message with level `Debug`.
52    #[inline]
53    pub fn debug(&mut self, message: impl Into<String>) -> &mut Self {
54        self.0.push(FlashMessage::debug(message));
55        self
56    }
57    /// Add a new message with level `Info`.
58    #[inline]
59    pub fn info(&mut self, message: impl Into<String>) -> &mut Self {
60        self.0.push(FlashMessage::info(message));
61        self
62    }
63    /// Add a new message with level `Success`.
64    #[inline]
65    pub fn success(&mut self, message: impl Into<String>) -> &mut Self {
66        self.0.push(FlashMessage::success(message));
67        self
68    }
69    /// Add a new message with level `Waring`.
70    #[inline]
71    pub fn warning(&mut self, message: impl Into<String>) -> &mut Self {
72        self.0.push(FlashMessage::warning(message));
73        self
74    }
75    /// Add a new message with level `Error`.
76    #[inline]
77    pub fn error(&mut self, message: impl Into<String>) -> &mut Self {
78        self.0.push(FlashMessage::error(message));
79        self
80    }
81}
82
83impl Deref for Flash {
84    type Target = Vec<FlashMessage>;
85
86    fn deref(&self) -> &Self::Target {
87        &self.0
88    }
89}
90
91/// A flash message.
92#[derive(Serialize, Deserialize, Clone, Debug)]
93#[non_exhaustive]
94pub struct FlashMessage {
95    /// Flash message level.
96    pub level: FlashLevel,
97    /// Flash message content.
98    pub value: String,
99}
100impl FlashMessage {
101    /// Create a new `FlashMessage` with `FlashLevel::Debug`.
102    #[inline]
103    pub fn debug(message: impl Into<String>) -> Self {
104        Self {
105            level: FlashLevel::Debug,
106            value: message.into(),
107        }
108    }
109    /// Create a new `FlashMessage` with `FlashLevel::Info`.
110    #[inline]
111    pub fn info(message: impl Into<String>) -> Self {
112        Self {
113            level: FlashLevel::Info,
114            value: message.into(),
115        }
116    }
117    /// Create a new `FlashMessage` with `FlashLevel::Success`.
118    #[inline]
119    pub fn success(message: impl Into<String>) -> Self {
120        Self {
121            level: FlashLevel::Success,
122            value: message.into(),
123        }
124    }
125    /// Create a new `FlashMessage` with `FlashLevel::Warning`.
126    #[inline]
127    pub fn warning(message: impl Into<String>) -> Self {
128        Self {
129            level: FlashLevel::Warning,
130            value: message.into(),
131        }
132    }
133    /// create a new `FlashMessage` with `FlashLevel::Error`.
134    #[inline]
135    pub fn error(message: impl Into<String>) -> Self {
136        Self {
137            level: FlashLevel::Error,
138            value: message.into(),
139        }
140    }
141}
142
143/// Verbosity level of a flash message.
144#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
145pub enum FlashLevel {
146    #[allow(missing_docs)]
147    Debug = 0,
148    #[allow(missing_docs)]
149    Info = 1,
150    #[allow(missing_docs)]
151    Success = 2,
152    #[allow(missing_docs)]
153    Warning = 3,
154    #[allow(missing_docs)]
155    Error = 4,
156}
157impl FlashLevel {
158    /// Convert a `FlashLevel` to a `&str`.
159    pub fn to_str(&self) -> &'static str {
160        match self {
161            FlashLevel::Debug => "debug",
162            FlashLevel::Info => "info",
163            FlashLevel::Success => "success",
164            FlashLevel::Warning => "warning",
165            FlashLevel::Error => "error",
166        }
167    }
168}
169impl Debug for FlashLevel {
170    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
171        write!(f, "{}", self.to_str())
172    }
173}
174
175impl Display for FlashLevel {
176    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
177        write!(f, "{}", self.to_str())
178    }
179}
180
/// Storage backend that [`FlashHandler`] uses to load, save and clear flash messages.
///
/// [`FlashHandler`] calls [`load_flash`](FlashStore::load_flash) before running
/// the downstream handlers. Afterwards it calls either
/// [`save_flash`](FlashStore::save_flash), when there are outgoing messages, or
/// [`clear_flash`](FlashStore::clear_flash), when there are none but the request
/// carried incoming ones. If a downstream handler ceased the flow, it calls
/// neither.
pub trait FlashStore: Debug + Send + Sync + 'static {
    /// Get the flash messages from the store.
    ///
    /// If this returns `None`, no incoming flash is placed in the depot.
    fn load_flash(
        &self,
        req: &mut Request,
        depot: &mut Depot,
    ) -> impl Future<Output = Option<Flash>> + Send;
    /// Save the flash messages to the store.
    ///
    /// [`FlashHandler`] only calls this with a non-empty `flash`, after
    /// applying its `minimum_level` filter.
    fn save_flash(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        flash: Flash,
    ) -> impl Future<Output = ()> + Send;
    /// Clear the flash store.
    ///
    /// [`FlashHandler`] calls this when the request had a non-empty incoming
    /// flash and nothing new was queued, so already-read messages are not
    /// returned again.
    fn clear_flash(&self, depot: &mut Depot, res: &mut Response)
    -> impl Future<Output = ()> + Send;
}
201
/// Extension methods on `Depot` for reading and writing flash messages.
///
/// These methods access the values that [`FlashHandler`] inserts into the depot
/// under [`INCOMING_FLASH_KEY`] and [`OUTGOING_FLASH_KEY`]. They are only
/// meaningful in handlers that run after it.
pub trait FlashDepotExt {
    /// Get incoming flash.
    ///
    /// Returns the messages loaded by the [`FlashStore`] for this request. It
    /// returns `None` if nothing was stored under [`INCOMING_FLASH_KEY`].
    // NOTE(review): this only reads, so `&self` would suffice; changing it now
    // would break external implementors of the trait.
    fn incoming_flash(&mut self) -> Option<&Flash>;
    /// Get outgoing flash.
    ///
    /// # Panics
    ///
    /// The `Depot` implementation panics if no [`Flash`] is stored under
    /// [`OUTGOING_FLASH_KEY`], e.g. when no [`FlashHandler`] ran before this call.
    fn outgoing_flash(&self) -> &Flash;
    /// Get mutable outgoing flash, to which new messages can be added.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as
    /// [`outgoing_flash`](FlashDepotExt::outgoing_flash).
    fn outgoing_flash_mut(&mut self) -> &mut Flash;
}
211
212impl FlashDepotExt for Depot {
213    #[inline]
214    fn incoming_flash(&mut self) -> Option<&Flash> {
215        self.get::<Flash>(INCOMING_FLASH_KEY).ok()
216    }
217
218    #[inline]
219    fn outgoing_flash(&self) -> &Flash {
220        self.get::<Flash>(OUTGOING_FLASH_KEY)
221            .expect("Flash should be initialized")
222    }
223
224    #[inline]
225    fn outgoing_flash_mut(&mut self) -> &mut Flash {
226        self.get_mut::<Flash>(OUTGOING_FLASH_KEY)
227            .expect("Flash should be initialized")
228    }
229}
230
231/// `FlashHandler` is a middleware for flash messages.
232#[non_exhaustive]
233pub struct FlashHandler<S> {
234    store: S,
235    /// Minimum level of messages to be displayed.
236    pub minimum_level: Option<FlashLevel>,
237}
238impl<S> FlashHandler<S> {
239    /// Create a new `FlashHandler` with the given `FlashStore`.
240    #[inline]
241    pub fn new(store: S) -> Self {
242        Self {
243            store,
244            minimum_level: None,
245        }
246    }
247
248    /// Sets the minimum level of messages to be displayed.
249    #[inline]
250    pub fn minimum_level(&mut self, level: impl Into<Option<FlashLevel>>) -> &mut Self {
251        self.minimum_level = level.into();
252        self
253    }
254}
255impl<S: FlashStore> fmt::Debug for FlashHandler<S> {
256    #[inline]
257    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
258        f.debug_struct("FlashHandler")
259            .field("store", &self.store)
260            .finish()
261    }
262}
263#[async_trait]
264impl<S> Handler for FlashHandler<S>
265where
266    S: FlashStore,
267{
268    async fn handle(
269        &self,
270        req: &mut Request,
271        depot: &mut Depot,
272        res: &mut Response,
273        ctrl: &mut FlowCtrl,
274    ) {
275        let mut has_incoming = false;
276        if let Some(flash) = self.store.load_flash(req, depot).await {
277            has_incoming = !flash.is_empty();
278            depot.insert(INCOMING_FLASH_KEY, flash);
279        }
280        depot.insert(OUTGOING_FLASH_KEY, Flash(vec![]));
281
282        ctrl.call_next(req, depot, res).await;
283        if ctrl.is_ceased() {
284            return;
285        }
286
287        let mut flash = depot
288            .remove::<Flash>(OUTGOING_FLASH_KEY)
289            .unwrap_or_default();
290        if let Some(min_level) = self.minimum_level {
291            flash.0.retain(|msg| msg.level >= min_level);
292        }
293        if !flash.is_empty() {
294            self.store.save_flash(req, depot, res, flash).await;
295        } else if has_incoming {
296            self.store.clear_flash(depot, res).await;
297        }
298    }
299}
300
// End-to-end tests: queue messages on one request, then read them back on the
// following requests through each store backend.
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use salvo_core::http::header::{COOKIE, SET_COOKIE};
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    // Queues two outgoing messages (`info`, then `debug`) and redirects to `/get`.
    #[handler]
    pub async fn set_flash(depot: &mut Depot, res: &mut Response) {
        let flash = depot.outgoing_flash_mut();
        flash.info("Hey there!").debug("How is it going?");
        res.render(Redirect::other("/get"));
    }

    // Renders each incoming message as a "<value> - <level>" line. The body is
    // empty when there is no incoming flash.
    #[handler]
    pub async fn get_flash(depot: &mut Depot, _res: &mut Response) -> String {
        let mut body = String::new();
        if let Some(flash) = depot.incoming_flash() {
            for message in flash.iter() {
                // `fmt::Write` for `String` never returns an error.
                writeln!(body, "{} - {}", message.value, message.level).unwrap();
            }
        }
        body
    }

    // Cookie-store round trip:
    // - `/set` must redirect and emit a Set-Cookie carrying the custom name.
    // - The first `/get` with that cookie sees the queued messages and gets a
    //   new Set-Cookie back.
    // - The second `/get`, sent with the cookie from the first `/get`, must
    //   render nothing.
    #[cfg(feature = "cookie-store")]
    #[tokio::test]
    async fn test_cookie_store() {
        let cookie_name = "my-custom-cookie-name".to_string();
        let router = Router::new()
            .hoop(CookieStore::new().name(&cookie_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:5800/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }

    // Session-store round trip. The session handler must be hooped before the
    // flash handler. The session cookie from `/set` is reused for both `/get`
    // requests: the first sees the messages, and the second must render
    // nothing because the flash handler cleared them after the first read.
    #[cfg(feature = "session-store")]
    #[tokio::test]
    async fn test_session_store() {
        // In-memory session backend with a test-only secret key.
        let session_handler = salvo_session::SessionHandler::builder(
            salvo_session::MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();

        let session_name = "my-custom-session-name".to_string();
        let router = Router::new()
            .hoop(session_handler)
            .hoop(SessionStore::new().name(&session_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:5800/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }
}