1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
5#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::ops::Deref;
10
11use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
12use serde::{Deserialize, Serialize};
13
14#[macro_use]
15mod cfg;
16
cfg_feature! {
    #![feature = "cookie-store"]
    // Items in this block are gated on the `cookie-store` cargo feature via the
    // `cfg_feature!` macro (defined in the `cfg` module).

    mod cookie_store;
    pub use cookie_store::CookieStore;

    /// Helper function to create a new [`CookieStore`]; equivalent to `CookieStore::new()`.
    pub fn cookie_store() -> CookieStore {
        CookieStore::new()
    }
}
28
cfg_feature! {
    #![feature = "session-store"]
    // Items in this block are gated on the `session-store` cargo feature via the
    // `cfg_feature!` macro (defined in the `cfg` module).

    mod session_store;
    pub use session_store::SessionStore;

    /// Helper function to create a new [`SessionStore`]; equivalent to `SessionStore::new()`.
    pub fn session_store() -> SessionStore {
        SessionStore::new()
    }
}
40
/// [`Depot`] key under which [`FlashHandler`] stores the flash loaded from its store for the
/// current request (read back by [`FlashDepotExt::incoming_flash`]).
pub const INCOMING_FLASH_KEY: &str = "::salvo::flash::incoming_flash";
43
/// [`Depot`] key under which [`FlashHandler`] keeps the flash that handlers fill in during this
/// request and that is handed to the store once the handler chain returns.
pub const OUTGOING_FLASH_KEY: &str = "::salvo::flash::outgoing_flash";
46
47#[derive(Default, Serialize, Deserialize, Clone, Debug)]
49pub struct Flash(pub Vec<FlashMessage>);
50impl Flash {
51 #[inline]
53 pub fn debug(&mut self, message: impl Into<String>) -> &mut Self {
54 self.0.push(FlashMessage::debug(message));
55 self
56 }
57 #[inline]
59 pub fn info(&mut self, message: impl Into<String>) -> &mut Self {
60 self.0.push(FlashMessage::info(message));
61 self
62 }
63 #[inline]
65 pub fn success(&mut self, message: impl Into<String>) -> &mut Self {
66 self.0.push(FlashMessage::success(message));
67 self
68 }
69 #[inline]
71 pub fn warning(&mut self, message: impl Into<String>) -> &mut Self {
72 self.0.push(FlashMessage::warning(message));
73 self
74 }
75 #[inline]
77 pub fn error(&mut self, message: impl Into<String>) -> &mut Self {
78 self.0.push(FlashMessage::error(message));
79 self
80 }
81}
82
83impl Deref for Flash {
84 type Target = Vec<FlashMessage>;
85
86 fn deref(&self) -> &Self::Target {
87 &self.0
88 }
89}
90
/// A single flash message: a severity level plus its text.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[non_exhaustive]
pub struct FlashMessage {
    /// Severity of the message.
    pub level: FlashLevel,
    /// Text of the message.
    pub value: String,
}
100impl FlashMessage {
101 #[inline]
103 pub fn debug(message: impl Into<String>) -> Self {
104 Self {
105 level: FlashLevel::Debug,
106 value: message.into(),
107 }
108 }
109 #[inline]
111 pub fn info(message: impl Into<String>) -> Self {
112 Self {
113 level: FlashLevel::Info,
114 value: message.into(),
115 }
116 }
117 #[inline]
119 pub fn success(message: impl Into<String>) -> Self {
120 Self {
121 level: FlashLevel::Success,
122 value: message.into(),
123 }
124 }
125 #[inline]
127 pub fn warning(message: impl Into<String>) -> Self {
128 Self {
129 level: FlashLevel::Warning,
130 value: message.into(),
131 }
132 }
133 #[inline]
135 pub fn error(message: impl Into<String>) -> Self {
136 Self {
137 level: FlashLevel::Error,
138 value: message.into(),
139 }
140 }
141}
142
/// Severity of a [`FlashMessage`].
///
/// The derived ordering follows declaration order (`Debug < Info < Success < Warning < Error`);
/// [`FlashHandler`] compares outgoing messages against its `minimum_level` using this order.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum FlashLevel {
    /// The `debug` level; the lowest.
    #[allow(missing_docs)]
    Debug = 0,
    /// The `info` level.
    #[allow(missing_docs)]
    Info = 1,
    /// The `success` level.
    #[allow(missing_docs)]
    Success = 2,
    /// The `warning` level.
    #[allow(missing_docs)]
    Warning = 3,
    /// The `error` level; the highest.
    #[allow(missing_docs)]
    Error = 4,
}
157impl FlashLevel {
158 pub fn to_str(&self) -> &'static str {
160 match self {
161 FlashLevel::Debug => "debug",
162 FlashLevel::Info => "info",
163 FlashLevel::Success => "success",
164 FlashLevel::Warning => "warning",
165 FlashLevel::Error => "error",
166 }
167 }
168}
169impl Debug for FlashLevel {
170 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
171 write!(f, "{}", self.to_str())
172 }
173}
174
175impl Display for FlashLevel {
176 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
177 write!(f, "{}", self.to_str())
178 }
179}
180
/// Backend that persists flash messages between requests.
///
/// [`FlashHandler`] drives a store as follows:
/// - [`load_flash`](FlashStore::load_flash) is awaited before the rest of the handler chain
///   runs; a returned flash is placed in the depot under [`INCOMING_FLASH_KEY`].
/// - After the chain returns (and the flow was not ceased), the outgoing flash is filtered by
///   the handler's `minimum_level`. If messages remain, [`save_flash`](FlashStore::save_flash)
///   is awaited with them. Otherwise, if a non-empty incoming flash had been loaded,
///   [`clear_flash`](FlashStore::clear_flash) is awaited instead.
pub trait FlashStore: Debug + Send + Sync + 'static {
    /// Loads the flash stored for the current request, or `None` if there is none.
    fn load_flash(
        &self,
        req: &mut Request,
        depot: &mut Depot,
    ) -> impl Future<Output = Option<Flash>> + Send;
    /// Persists `flash` so that a later request can load it; only called with a non-empty flash.
    fn save_flash(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        flash: Flash,
    ) -> impl Future<Output = ()> + Send;
    /// Removes the previously stored flash after it has been consumed.
    fn clear_flash(&self, depot: &mut Depot, res: &mut Response)
        -> impl Future<Output = ()> + Send;
}
201
/// Convenience accessors for the flash messages that [`FlashHandler`] places in a [`Depot`].
pub trait FlashDepotExt {
    /// Returns the flash loaded for the current request, if any.
    fn incoming_flash(&mut self) -> Option<&Flash>;
    /// Returns the flash that [`FlashHandler`] will try to save after the handler chain returns.
    fn outgoing_flash(&self) -> &Flash;
    /// Returns the outgoing flash mutably so handlers can queue messages on it.
    fn outgoing_flash_mut(&mut self) -> &mut Flash;
}
211
212impl FlashDepotExt for Depot {
213 #[inline]
214 fn incoming_flash(&mut self) -> Option<&Flash> {
215 self.get::<Flash>(INCOMING_FLASH_KEY).ok()
216 }
217
218 #[inline]
219 fn outgoing_flash(&self) -> &Flash {
220 self.get::<Flash>(OUTGOING_FLASH_KEY)
221 .expect("Flash should be initialized")
222 }
223
224 #[inline]
225 fn outgoing_flash_mut(&mut self) -> &mut Flash {
226 self.get_mut::<Flash>(OUTGOING_FLASH_KEY)
227 .expect("Flash should be initialized")
228 }
229}
230
231#[non_exhaustive]
233pub struct FlashHandler<S> {
234 store: S,
235 pub minimum_level: Option<FlashLevel>,
237}
238impl<S> FlashHandler<S> {
239 #[inline]
241 pub fn new(store: S) -> Self {
242 Self {
243 store,
244 minimum_level: None,
245 }
246 }
247
248 #[inline]
250 pub fn minimum_level(&mut self, level: impl Into<Option<FlashLevel>>) -> &mut Self {
251 self.minimum_level = level.into();
252 self
253 }
254}
255impl<S: FlashStore> fmt::Debug for FlashHandler<S> {
256 #[inline]
257 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
258 f.debug_struct("FlashHandler")
259 .field("store", &self.store)
260 .finish()
261 }
262}
263#[async_trait]
264impl<S> Handler for FlashHandler<S>
265where
266 S: FlashStore,
267{
268 async fn handle(
269 &self,
270 req: &mut Request,
271 depot: &mut Depot,
272 res: &mut Response,
273 ctrl: &mut FlowCtrl,
274 ) {
275 let mut has_incoming = false;
276 if let Some(flash) = self.store.load_flash(req, depot).await {
277 has_incoming = !flash.is_empty();
278 depot.insert(INCOMING_FLASH_KEY, flash);
279 }
280 depot.insert(OUTGOING_FLASH_KEY, Flash(vec![]));
281
282 ctrl.call_next(req, depot, res).await;
283 if ctrl.is_ceased() {
284 return;
285 }
286
287 let mut flash = depot
288 .remove::<Flash>(OUTGOING_FLASH_KEY)
289 .unwrap_or_default();
290 if let Some(min_level) = self.minimum_level {
291 flash.0.retain(|msg| msg.level >= min_level);
292 }
293 if !flash.is_empty() {
294 self.store.save_flash(req, depot, res, flash).await;
295 } else if has_incoming {
296 self.store.clear_flash(depot, res).await;
297 }
298 }
299}
300
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use salvo_core::http::header::{COOKIE, SET_COOKIE};
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    // Queues two outgoing messages (info, then debug) and redirects to `/get`; the tests
    // assert that this response is `303 See Other`.
    #[handler]
    pub async fn set_flash(depot: &mut Depot, res: &mut Response) {
        let flash = depot.outgoing_flash_mut();
        flash.info("Hey there!").debug("How is it going?");
        res.render(Redirect::other("/get"));
    }

    // Renders each incoming message as a "<value> - <level>" line; the body is empty when
    // there is no incoming flash.
    #[handler]
    pub async fn get_flash(depot: &mut Depot, _res: &mut Response) -> String {
        let mut body = String::new();
        if let Some(flash) = depot.incoming_flash() {
            for message in flash.iter() {
                writeln!(body, "{} - {}", message.value, message.level).unwrap();
            }
        }
        body
    }

    // Round trip through the cookie store: `/set` flashes messages into a cookie, the next
    // `/get` shows them, and a further `/get` shows nothing because the flash was consumed.
    #[cfg(feature = "cookie-store")]
    #[tokio::test]
    async fn test_cookie_store() {
        let cookie_name = "my-custom-cookie-name".to_string();
        let router = Router::new()
            .hoop(CookieStore::new().name(&cookie_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:5800/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        // The flash cookie uses the configured custom name.
        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        // Sending that cookie back exposes the messages as the incoming flash.
        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        // `/get` flashes nothing new, so the handler calls `clear_flash`, which answers with
        // another cookie of the same name (presumably an emptied one).
        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        // With the cleared cookie there is no incoming flash left to render.
        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }

    // Same round trip through the session store. The session handler is hooped before the
    // flash handler, presumably so the session is available when the flash is loaded/saved.
    #[cfg(feature = "session-store")]
    #[tokio::test]
    async fn test_session_store() {
        let session_handler = salvo_session::SessionHandler::builder(
            salvo_session::MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();

        let session_name = "my-custom-session-name".to_string();
        let router = Router::new()
            .hoop(session_handler)
            .hoop(SessionStore::new().name(&session_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:5800/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        // The cookie identifies the session; the same cookie is reused for both `/get` calls.
        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        // The previous `/get` consumed the flash (via `clear_flash`), so nothing is rendered now.
        let mut response = TestClient::get("http://127.0.0.1:5800/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }
}