salvo_extra/
request_id.rs

1//! Request id middleware.
2//!
3//! # Example
4//!
5//! ```no_run
6//! use salvo_core::prelude::*;
7//! use salvo_extra::request_id::RequestId;
8//!
9//! #[handler]
10//! async fn hello(req: &mut Request) -> String {
11//!     format!("Request id: {:?}", req.header::<String>("x-request-id"))
12//! }
13//!
14//! #[tokio::main]
15//! async fn main() {
16//!     let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
17//!     let router = Router::new().hoop(RequestId::new()).get(hello);
18//!     Server::new(acceptor).serve(router).await;
19//! }
20//! ```
21use ulid::Ulid;
22
23use salvo_core::http::{header::HeaderName, Request, Response};
24use salvo_core::{async_trait, Depot, FlowCtrl, Handler};
25
/// Key under which the request id is stored in the depot.
27pub const REQUEST_ID_KEY: &str = "::salvo::request_id";
28
29/// Extension for Depot.
30pub trait RequestIdDepotExt {
31    /// Get request id reference from depot.
32    fn csrf_token(&self) -> Option<&str>;
33}
34
35impl RequestIdDepotExt for Depot {
36    #[inline]
37    fn csrf_token(&self) -> Option<&str> {
38        self.get::<String>(REQUEST_ID_KEY).map(|v|&**v).ok()
39    }
40}
41
/// A middleware that generates a request id.
///
/// The generated id is written to the request header named by
/// `header_name` and is also stored in the depot under `REQUEST_ID_KEY`.
#[non_exhaustive]
pub struct RequestId {
    /// The header name for request id.
    pub header_name: HeaderName,
    /// Whether to overwrite an existing request id. Default is `true`.
    pub overwrite: bool,
    /// The generator for request id.
    pub generator: Box<dyn IdGenerator + Send + Sync>,
}
52
53impl RequestId {
54    /// Create new `CatchPanic` middleware.
55    pub fn new() -> Self {
56        Self {
57            header_name: HeaderName::from_static("x-request-id"),
58            overwrite: true,
59            generator: Box::new(UlidGenerator::new()),
60        }
61    }
62
63    /// Set the header name for request id.
64    pub fn header_name(mut self, name: HeaderName) -> Self {
65        self.header_name = name;
66        self
67    }
68
69    /// Set whether overwrite exists request id. Default is `true`.
70    pub fn overwrite(mut self, overwrite: bool) -> Self {
71        self.overwrite = overwrite;
72        self
73    }
74
75    /// Set the generator for request id.
76    pub fn generator(mut self, generator: impl IdGenerator + Send + Sync + 'static) -> Self {
77        self.generator = Box::new(generator);
78        self
79    }
80}
81
82impl Default for RequestId {
83    fn default() -> Self {
84        Self::new()
85    }
86}
87
/// A trait for generating request ids.
///
/// Implementors receive mutable access to the current request and depot and
/// return the id as a `String`.
pub trait IdGenerator {
    /// Generate a new request id.
    fn generate(&self, req: &mut Request, depot: &mut Depot) -> String;
}
93
94impl<F> IdGenerator for F
95where
96    F: Fn() -> String + Send + Sync,
97{
98    fn generate(&self, _req: &mut Request, _depot: &mut Depot) -> String {
99        self()
100    }
101}
102
/// A request id generator backed by ULIDs.
#[derive(Debug, Default)]
pub struct UlidGenerator {}

impl UlidGenerator {
    /// Create a new `UlidGenerator`.
    pub fn new() -> Self {
        UlidGenerator {}
    }
}
112impl IdGenerator for UlidGenerator {
113    fn generate(&self, _req: &mut Request, _depot: &mut Depot) -> String {
114        Ulid::new().to_string()
115    }
116}
117
118#[async_trait]
119impl Handler for RequestId {
120    async fn handle(&self, req: &mut Request, depot: &mut Depot, _res: &mut Response, _ctrl: &mut FlowCtrl) {
121        if !self.overwrite && req.headers().contains_key(&self.header_name) {
122            return;
123        }
124        let id = self.generator.generate(req, depot);
125        let _ = req.add_header(self.header_name.clone(), &id, true);
126        depot.insert(REQUEST_ID_KEY, id);
127    }
128}