viz_middleware/
request_id.rs

1use std::{future::Future, pin::Pin};
2
3use viz_core::{http, Context, Error, Middleware, Response, Result};
4use viz_utils::tracing;
5
6fn generate_id() -> Result<String> {
7    cfg_if::cfg_if! {
8        if #[cfg(feature = "request-nanoid")] {
9            Ok(nano_id::base64::<21>())
10        }  else if #[cfg(feature = "request-uuid")] {
11            Ok(uuid::Uuid::new_v4().to_string())
12        }
13    }
14}
15
16/// RequestID Middleware
17pub struct RequestID<F = fn() -> Result<String>> {
18    /// Header Name is must be lower-case.
19    header: &'static str,
20    /// Generates request id
21    generator: F,
22}
23
24impl Default for RequestID {
25    fn default() -> Self {
26        Self::new(Self::HEADER, generate_id)
27    }
28}
29
30impl<F> RequestID<F>
31where
32    F: Fn() -> Result<String>,
33{
34    const HEADER: &'static str = "x-request-id";
35
36    /// Creates new `RequestID` Middleware.
37    pub fn new(header: &'static str, generator: F) -> Self {
38        Self { header, generator }
39    }
40
41    async fn run(&self, cx: &mut Context) -> Result<Response> {
42        let mut res: Response = cx.next().await.into();
43
44        let id = match cx.header_value(&self.header).cloned() {
45            Some(id) => id,
46            None => (self.generator)()
47                .and_then(|id| http::HeaderValue::from_str(&id).map_err(Error::new))?,
48        };
49
50        tracing::trace!(" {:>7?}", id);
51
52        res.headers_mut().insert(http::header::HeaderName::from_static(Self::HEADER), id);
53
54        Ok(res)
55    }
56}
57
58impl<'a, F> Middleware<'a, Context> for RequestID<F>
59where
60    F: Sync + Send + 'static + Fn() -> Result<String>,
61{
62    type Output = Result<Response>;
63
64    #[must_use]
65    fn call(
66        &'a self,
67        cx: &'a mut Context,
68    ) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>> {
69        Box::pin(self.run(cx))
70    }
71}