1#![doc(html_favicon_url = "https://salvo.rs/favicon-32x32.png")]
5#![doc(html_logo_url = "https://salvo.rs/images/logo.svg")]
6#![cfg_attr(docsrs, feature(doc_cfg))]
7
8use std::fmt::{self, Debug, Display, Formatter};
9use std::ops::Deref;
10
11use salvo_core::{Depot, FlowCtrl, Handler, Request, Response, async_trait};
12use serde::{Deserialize, Serialize};
13
14#[macro_use]
15mod cfg;
16
// Cookie-backed flash storage. The items below are gated on the `cookie-store`
// Cargo feature via the `cfg_feature!` macro from the `cfg` module.
cfg_feature! {
    #![feature = "cookie-store"]

    mod cookie_store;
    pub use cookie_store::CookieStore;

    /// Shorthand for [`CookieStore::new`].
    #[must_use] pub fn cookie_store() -> CookieStore {
        CookieStore::new()
    }
}
28
// Session-backed flash storage. The items below are gated on the `session-store`
// Cargo feature via the `cfg_feature!` macro from the `cfg` module.
cfg_feature! {
    #![feature = "session-store"]

    mod session_store;
    pub use session_store::SessionStore;

    /// Shorthand for [`SessionStore::new`].
    #[must_use]
    pub fn session_store() -> SessionStore {
        SessionStore::new()
    }
}
41
/// Depot key under which [`FlashHandler`] stores the [`Flash`] loaded from its
/// [`FlashStore`] for the current request.
///
/// Read it with [`FlashDepotExt::incoming_flash`].
pub const INCOMING_FLASH_KEY: &str = "::salvo::flash::incoming_flash";

/// Depot key under which [`FlashHandler`] stores the (initially empty) [`Flash`]
/// that downstream handlers fill in.
///
/// After the downstream handlers run, the remaining messages (filtered by the
/// handler's `minimum_level`) are passed to [`FlashStore::save_flash`] if any are
/// left. Access it with [`FlashDepotExt::outgoing_flash_mut`].
pub const OUTGOING_FLASH_KEY: &str = "::salvo::flash::outgoing_flash";
47
/// An ordered list of [`FlashMessage`]s.
///
/// The builder-style methods (`debug`, `info`, `success`, `warning`, `error`)
/// append to the end, so messages keep their insertion order. Read-only `Vec`
/// methods such as `len`, `is_empty` and `iter` are available through [`Deref`].
#[derive(Default, Serialize, Deserialize, Clone, Debug)]
pub struct Flash(pub Vec<FlashMessage>);
51impl Flash {
52 #[inline]
54 pub fn debug(&mut self, message: impl Into<String>) -> &mut Self {
55 self.0.push(FlashMessage::debug(message));
56 self
57 }
58 #[inline]
60 pub fn info(&mut self, message: impl Into<String>) -> &mut Self {
61 self.0.push(FlashMessage::info(message));
62 self
63 }
64 #[inline]
66 pub fn success(&mut self, message: impl Into<String>) -> &mut Self {
67 self.0.push(FlashMessage::success(message));
68 self
69 }
70 #[inline]
72 pub fn warning(&mut self, message: impl Into<String>) -> &mut Self {
73 self.0.push(FlashMessage::warning(message));
74 self
75 }
76 #[inline]
78 pub fn error(&mut self, message: impl Into<String>) -> &mut Self {
79 self.0.push(FlashMessage::error(message));
80 self
81 }
82}
83
84impl Deref for Flash {
85 type Target = Vec<FlashMessage>;
86
87 fn deref(&self) -> &Self::Target {
88 &self.0
89 }
90}
91
92#[derive(Serialize, Deserialize, Clone, Debug)]
94#[non_exhaustive]
95pub struct FlashMessage {
96 pub level: FlashLevel,
98 pub value: String,
100}
101impl FlashMessage {
102 #[inline]
104 pub fn debug(message: impl Into<String>) -> Self {
105 Self {
106 level: FlashLevel::Debug,
107 value: message.into(),
108 }
109 }
110 #[inline]
112 pub fn info(message: impl Into<String>) -> Self {
113 Self {
114 level: FlashLevel::Info,
115 value: message.into(),
116 }
117 }
118 #[inline]
120 pub fn success(message: impl Into<String>) -> Self {
121 Self {
122 level: FlashLevel::Success,
123 value: message.into(),
124 }
125 }
126 #[inline]
128 pub fn warning(message: impl Into<String>) -> Self {
129 Self {
130 level: FlashLevel::Warning,
131 value: message.into(),
132 }
133 }
134 #[inline]
136 pub fn error(message: impl Into<String>) -> Self {
137 Self {
138 level: FlashLevel::Error,
139 value: message.into(),
140 }
141 }
142}
143
144#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
146pub enum FlashLevel {
147 #[allow(missing_docs)]
148 Debug = 0,
149 #[allow(missing_docs)]
150 Info = 1,
151 #[allow(missing_docs)]
152 Success = 2,
153 #[allow(missing_docs)]
154 Warning = 3,
155 #[allow(missing_docs)]
156 Error = 4,
157}
158impl FlashLevel {
159 #[must_use]
161 pub fn to_str(&self) -> &'static str {
162 match self {
163 Self::Debug => "debug",
164 Self::Info => "info",
165 Self::Success => "success",
166 Self::Warning => "warning",
167 Self::Error => "error",
168 }
169 }
170}
171impl Debug for FlashLevel {
172 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
173 write!(f, "{}", self.to_str())
174 }
175}
176
177impl Display for FlashLevel {
178 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
179 write!(f, "{}", self.to_str())
180 }
181}
182
/// Backend that carries a [`Flash`] from one request to a later one.
///
/// [`FlashHandler`] calls [`load_flash`](FlashStore::load_flash) before the
/// downstream handlers run. Afterwards it calls
/// [`save_flash`](FlashStore::save_flash) when outgoing messages remain, or
/// [`clear_flash`](FlashStore::clear_flash) when none remain but a non-empty
/// incoming flash was loaded.
pub trait FlashStore: Debug + Send + Sync + 'static {
    /// Loads the flash stored by a previous request, if any.
    ///
    /// A returned flash is inserted into the depot under [`INCOMING_FLASH_KEY`].
    fn load_flash(
        &self,
        req: &mut Request,
        depot: &mut Depot,
    ) -> impl Future<Output = Option<Flash>> + Send;
    /// Stores `flash` so that a later [`load_flash`](FlashStore::load_flash)
    /// can return it.
    ///
    /// `FlashHandler` only calls this with a non-empty flash, after filtering
    /// by its `minimum_level`.
    fn save_flash(
        &self,
        req: &mut Request,
        depot: &mut Depot,
        res: &mut Response,
        flash: Flash,
    ) -> impl Future<Output = ()> + Send;
    /// Removes any stored flash, so that it is not delivered again.
    fn clear_flash(&self, depot: &mut Depot, res: &mut Response)
        -> impl Future<Output = ()> + Send;
}
203
/// Flash accessors for [`Depot`], backed by the entries [`FlashHandler`] inserts
/// under [`INCOMING_FLASH_KEY`] and [`OUTGOING_FLASH_KEY`].
pub trait FlashDepotExt {
    /// Returns the flash loaded for the current request.
    ///
    /// Returns `None` when no incoming flash is available in the depot.
    fn incoming_flash(&mut self) -> Option<&Flash>;
    /// Returns the flash being built for a later request.
    ///
    /// # Panics
    ///
    /// The [`Depot`] implementation panics if the outgoing flash has not been
    /// initialized, i.e. when no [`FlashHandler`] ran earlier in the chain.
    fn outgoing_flash(&self) -> &Flash;
    /// Mutable access to the flash being built for a later request. Add
    /// messages with its builder methods, e.g. `.info("...")`.
    ///
    /// # Panics
    ///
    /// The [`Depot`] implementation panics if the outgoing flash has not been
    /// initialized, i.e. when no [`FlashHandler`] ran earlier in the chain.
    fn outgoing_flash_mut(&mut self) -> &mut Flash;
}
213
214impl FlashDepotExt for Depot {
215 #[inline]
216 fn incoming_flash(&mut self) -> Option<&Flash> {
217 self.get::<Flash>(INCOMING_FLASH_KEY).ok()
218 }
219
220 #[inline]
221 fn outgoing_flash(&self) -> &Flash {
222 self.get::<Flash>(OUTGOING_FLASH_KEY)
223 .expect("Flash should be initialized")
224 }
225
226 #[inline]
227 fn outgoing_flash_mut(&mut self) -> &mut Flash {
228 self.get_mut::<Flash>(OUTGOING_FLASH_KEY)
229 .expect("Flash should be initialized")
230 }
231}
232
/// Middleware that loads the incoming flash from a [`FlashStore`] before the
/// downstream handlers run, and saves or clears the flash afterwards.
#[non_exhaustive]
pub struct FlashHandler<S> {
    // Backend used to load, save and clear flashes.
    store: S,
    /// When set, outgoing messages whose level is below this are dropped before
    /// saving. `None` keeps every message.
    pub minimum_level: Option<FlashLevel>,
}
240impl<S> FlashHandler<S> {
241 #[inline]
243 pub fn new(store: S) -> Self {
244 Self {
245 store,
246 minimum_level: None,
247 }
248 }
249
250 #[inline]
252 pub fn minimum_level(&mut self, level: impl Into<Option<FlashLevel>>) -> &mut Self {
253 self.minimum_level = level.into();
254 self
255 }
256}
257impl<S: FlashStore> fmt::Debug for FlashHandler<S> {
258 #[inline]
259 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
260 f.debug_struct("FlashHandler")
261 .field("store", &self.store)
262 .finish()
263 }
264}
265#[async_trait]
266impl<S> Handler for FlashHandler<S>
267where
268 S: FlashStore,
269{
270 async fn handle(
271 &self,
272 req: &mut Request,
273 depot: &mut Depot,
274 res: &mut Response,
275 ctrl: &mut FlowCtrl,
276 ) {
277 let mut has_incoming = false;
278 if let Some(flash) = self.store.load_flash(req, depot).await {
279 has_incoming = !flash.is_empty();
280 depot.insert(INCOMING_FLASH_KEY, flash);
281 }
282 depot.insert(OUTGOING_FLASH_KEY, Flash(vec![]));
283
284 ctrl.call_next(req, depot, res).await;
285 if ctrl.is_ceased() {
286 return;
287 }
288
289 let mut flash = depot
290 .remove::<Flash>(OUTGOING_FLASH_KEY)
291 .unwrap_or_default();
292 if let Some(min_level) = self.minimum_level {
293 flash.0.retain(|msg| msg.level >= min_level);
294 }
295 if !flash.is_empty() {
296 self.store.save_flash(req, depot, res, flash).await;
297 } else if has_incoming {
298 self.store.clear_flash(depot, res).await;
299 }
300 }
301}
302
// Integration tests: a request to /set queues flash messages, the next request to
// /get renders them, and a later /get sees none.
#[cfg(test)]
mod tests {
    use std::fmt::Write;

    use salvo_core::http::header::{COOKIE, SET_COOKIE};
    use salvo_core::prelude::*;
    use salvo_core::test::{ResponseExt, TestClient};

    use super::*;

    // Queues two outgoing messages (info, then debug) and redirects to /get.
    #[handler]
    pub async fn set_flash(depot: &mut Depot, res: &mut Response) {
        let flash = depot.outgoing_flash_mut();
        flash.info("Hey there!").debug("How is it going?");
        res.render(Redirect::other("/get"));
    }

    // Renders each incoming message as a "<value> - <level>" line. The body is
    // empty when there is no incoming flash.
    #[handler]
    pub async fn get_flash(depot: &mut Depot, _res: &mut Response) -> String {
        let mut body = String::new();
        if let Some(flash) = depot.incoming_flash() {
            for message in flash.iter() {
                writeln!(body, "{} - {}", message.value, message.level).unwrap();
            }
        }
        body
    }

    #[cfg(feature = "cookie-store")]
    #[tokio::test]
    async fn test_cookie_store() {
        let cookie_name = "my-custom-cookie-name".to_owned();
        let router = Router::new()
            .hoop(CookieStore::new().name(&cookie_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        // /set saves the outgoing flash in a cookie that uses the custom name.
        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        // Sending that cookie back makes the message visible as incoming flash.
        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        // /get queued nothing new while an incoming flash existed, so the store is
        // asked to clear it. Presumably this Set-Cookie is the clearing cookie.
        let cookie = response.headers().get(SET_COOKIE).unwrap();
        assert!(cookie.to_str().unwrap().contains(&cookie_name));

        // After clearing, no messages are delivered again.
        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }

    #[cfg(feature = "session-store")]
    #[tokio::test]
    async fn test_session_store() {
        let session_handler = salvo_session::SessionHandler::builder(
            salvo_session::MemoryStore::new(),
            b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
        )
        .build()
        .unwrap();

        let session_name = "my-custom-session-name".to_string();
        // The session middleware must run before the flash handler that uses it.
        let router = Router::new()
            .hoop(session_handler)
            .hoop(SessionStore::new().name(&session_name).into_handler())
            .push(Router::with_path("get").get(get_flash))
            .push(Router::with_path("set").get(set_flash));
        let service = Service::new(router);

        let response = TestClient::get("http://127.0.0.1:8698/set")
            .send(&service)
            .await;
        assert_eq!(response.status_code, Some(StatusCode::SEE_OTHER));

        // The session cookie identifies the session that holds the flash.
        let cookie = response.headers().get(SET_COOKIE).unwrap();

        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().contains("Hey there!"));

        // Same session cookie again: the flash was cleared, so nothing is shown.
        let mut response = TestClient::get("http://127.0.0.1:8698/get")
            .add_header(COOKIE, cookie, true)
            .send(&service)
            .await;
        assert!(response.take_string().await.unwrap().is_empty());
    }
}