viz_middleware/
request_id.rs1use std::{future::Future, pin::Pin};
2
3use viz_core::{http, Context, Error, Middleware, Response, Result};
4use viz_utils::tracing;
5
6fn generate_id() -> Result<String> {
7 cfg_if::cfg_if! {
8 if #[cfg(feature = "request-nanoid")] {
9 Ok(nano_id::base64::<21>())
10 } else if #[cfg(feature = "request-uuid")] {
11 Ok(uuid::Uuid::new_v4().to_string())
12 }
13 }
14}
15
16pub struct RequestID<F = fn() -> Result<String>> {
18 header: &'static str,
20 generator: F,
22}
23
24impl Default for RequestID {
25 fn default() -> Self {
26 Self::new(Self::HEADER, generate_id)
27 }
28}
29
30impl<F> RequestID<F>
31where
32 F: Fn() -> Result<String>,
33{
34 const HEADER: &'static str = "x-request-id";
35
36 pub fn new(header: &'static str, generator: F) -> Self {
38 Self { header, generator }
39 }
40
41 async fn run(&self, cx: &mut Context) -> Result<Response> {
42 let mut res: Response = cx.next().await.into();
43
44 let id = match cx.header_value(&self.header).cloned() {
45 Some(id) => id,
46 None => (self.generator)()
47 .and_then(|id| http::HeaderValue::from_str(&id).map_err(Error::new))?,
48 };
49
50 tracing::trace!(" {:>7?}", id);
51
52 res.headers_mut().insert(http::header::HeaderName::from_static(Self::HEADER), id);
53
54 Ok(res)
55 }
56}
57
58impl<'a, F> Middleware<'a, Context> for RequestID<F>
59where
60 F: Sync + Send + 'static + Fn() -> Result<String>,
61{
62 type Output = Result<Response>;
63
64 #[must_use]
65 fn call(
66 &'a self,
67 cx: &'a mut Context,
68 ) -> Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>> {
69 Box::pin(self.run(cx))
70 }
71}