// wae_https/middleware/request_id.rs

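//! Request ID middleware.
//!
//! Assigns a tracing ID to every request: an incoming `x-request-id` header is reused when
//! present and readable, otherwise a fresh UUID v4 is generated.
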
use http::{HeaderValue, Request, Response};
use pin_project_lite::pin_project;
use std::{
    future::Future,
    pin::Pin,
    task::{Context, Poll},
};
use tower::{Layer, Service};
use uuid::Uuid;

/// HTTP header used to carry the per-request tracing ID.
///
/// Kept lowercase so it can be passed directly wherever `http` expects a static header name.
pub const X_REQUEST_ID: &str = "x-request-id";

18/// 请求 ID 中间件层
19#[derive(Debug, Clone)]
20pub struct RequestIdLayer;
21
22impl<S> Layer<S> for RequestIdLayer {
23    type Service = RequestIdService<S>;
24
25    fn layer(&self, inner: S) -> Self::Service {
26        RequestIdService { inner }
27    }
28}
29
/// Service produced by [`RequestIdLayer`].
///
/// Each call is forwarded to the wrapped `inner` service.
#[derive(Clone, Debug)]
pub struct RequestIdService<S> {
    // The wrapped service that actually handles the request.
    inner: S,
}

36impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for RequestIdService<S>
37where
38    S: Service<Request<ReqBody>, Response = Response<ResBody>>,
39    S::Future: Send + 'static,
40{
41    type Response = S::Response;
42    type Error = S::Error;
43    type Future = RequestIdFuture<S::Future>;
44
45    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
46        self.inner.poll_ready(cx)
47    }
48
49    fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
50        let request_id = req
51            .headers()
52            .get(X_REQUEST_ID)
53            .and_then(|v| v.to_str().ok())
54            .map(|s| s.to_string())
55            .unwrap_or_else(|| Uuid::new_v4().to_string());
56
57        let future = self.inner.call(req);
58        RequestIdFuture { inner: future, request_id }
59    }
60}
61
62pin_project! {
63    /// 请求 ID 未来
64    pub struct RequestIdFuture<F> {
65        #[pin]
66        inner: F,
67        request_id: String,
68    }
69}
70
71impl<F, Res, E> Future for RequestIdFuture<F>
72where
73    F: Future<Output = Result<Res, E>>,
74{
75    type Output = Result<Res, E>;
76
77    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78        let this = self.project();
79        match this.inner.poll(cx) {
80            Poll::Ready(result) => Poll::Ready(result),
81            Poll::Pending => Poll::Pending,
82        }
83    }
84}