Skip to main content

xitca_web/middleware/
grpc_timeout.rs

1//! gRPC timeout middleware.
2//!
3//! Parses the `grpc-timeout` request header and enforces it as a deadline on the inner service call.
4//! If the deadline is exceeded, a trailers-only response with `grpc-status: 4` (DeadlineExceeded) is returned.
5//!
6//! The parsed deadline [`Instant`] is inserted into request extensions so downstream extractors and
7//! handlers can observe the remaining time.
8//!
9//! # Example
10//! ```rust
11//! # use xitca_web::{handler::handler_service, App, WebContext};
12//! use xitca_web::middleware::grpc_timeout::GrpcTimeout;
13//!
14//! App::new()
15//!     .at("/my.Service/Method", handler_service(handler))
16//!     .enclosed(GrpcTimeout);
17//!
18//! # async fn handler(_: &WebContext<'_>) -> &'static str { "" }
19//! ```
20
21use core::time::Duration;
22
23use tokio::time::Instant;
24
25use crate::{
26    body::ResponseBody,
27    context::WebContext,
28    error::{GrpcError, GrpcStatus},
29    http::{
30        WebResponse,
31        const_header_name::GRPC_TIMEOUT,
32        const_header_value::GRPC,
33        header::{CONTENT_TYPE, HeaderValue},
34    },
35    service::{Service, ready::ReadyService},
36};
37
38/// Middleware that enforces the `grpc-timeout` deadline on the inner service call.
39///
40/// If the header is absent, no timeout is applied.
41/// The deadline [`Instant`] is inserted into request extensions.
42pub struct GrpcTimeout;
43
44impl<S, E> Service<Result<S, E>> for GrpcTimeout {
45    type Response = GrpcTimeoutService<S>;
46    type Error = E;
47
48    async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
49        res.map(|service| GrpcTimeoutService { service })
50    }
51}
52
53pub struct GrpcTimeoutService<S> {
54    service: S,
55}
56
57impl<'r, S, C, B> Service<WebContext<'r, C, B>> for GrpcTimeoutService<S>
58where
59    S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse, Error = crate::error::Error>,
60{
61    type Response = WebResponse;
62    type Error = crate::error::Error;
63
64    async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
65        let timeout = ctx.req().headers().get(GRPC_TIMEOUT).and_then(parse_grpc_timeout);
66
67        match timeout {
68            Some(duration) => {
69                let deadline = Instant::now() + duration;
70
71                match tokio::time::timeout_at(deadline, self.service.call(ctx)).await {
72                    Ok(result) => result,
73                    Err(_elapsed) => {
74                        let err = GrpcError::new(GrpcStatus::DeadlineExceeded, "deadline exceeded");
75                        let mut res = WebResponse::new(ResponseBody::empty());
76                        res.headers_mut().insert(CONTENT_TYPE, GRPC);
77                        res.headers_mut().extend(err.trailers());
78                        Ok(res)
79                    }
80                }
81            }
82            None => self.service.call(ctx).await,
83        }
84    }
85}
86
87impl<S> ReadyService for GrpcTimeoutService<S>
88where
89    S: ReadyService,
90{
91    type Ready = S::Ready;
92
93    #[inline]
94    async fn ready(&self) -> Self::Ready {
95        self.service.ready().await
96    }
97}
98
99/// Parse the `grpc-timeout` header value into a [`Duration`].
100///
101/// Format: `{value}{unit}` where value is 1-8 ASCII digits and unit is one of:
102/// - `H` (hours), `M` (minutes), `S` (seconds)
103/// - `m` (milliseconds), `u` (microseconds), `n` (nanoseconds)
104fn parse_grpc_timeout(value: &HeaderValue) -> Option<Duration> {
105    let bytes = value.as_bytes();
106    if bytes.len() < 2 {
107        return None;
108    }
109
110    let (digits, unit) = bytes.split_at(bytes.len() - 1);
111
112    // spec says max 8 digits
113    if digits.is_empty() || digits.len() > 8 {
114        return None;
115    }
116
117    let mut val: u64 = 0;
118    for &b in digits {
119        if !b.is_ascii_digit() {
120            return None;
121        }
122        val = val * 10 + (b - b'0') as u64;
123    }
124
125    match unit[0] {
126        b'H' => Some(Duration::from_secs(val * 3600)),
127        b'M' => Some(Duration::from_secs(val * 60)),
128        b'S' => Some(Duration::from_secs(val)),
129        b'm' => Some(Duration::from_millis(val)),
130        b'u' => Some(Duration::from_micros(val)),
131        b'n' => Some(Duration::from_nanos(val)),
132        _ => None,
133    }
134}
135
#[cfg(test)]
mod test {
    use super::*;

    /// Parses a static header value through `parse_grpc_timeout`.
    fn parse(value: &'static str) -> Option<Duration> {
        parse_grpc_timeout(&HeaderValue::from_static(value))
    }

    #[test]
    fn parse_timeout_values() {
        let cases = [
            ("1H", Duration::from_secs(3600)),
            ("5M", Duration::from_secs(300)),
            ("10S", Duration::from_secs(10)),
            ("100m", Duration::from_millis(100)),
            ("5000u", Duration::from_micros(5000)),
            ("999n", Duration::from_nanos(999)),
        ];

        for (input, expected) in cases {
            assert_eq!(parse(input), Some(expected), "grpc-timeout {input:?}");
        }
    }

    #[test]
    fn parse_timeout_invalid() {
        let cases = [
            ("H", "no digits"),
            ("5", "no unit"),
            ("5x", "bad unit"),
            ("abc", "non-digit"),
            ("123456789S", "9 digits > max 8"),
        ];

        for (input, reason) in cases {
            assert_eq!(parse(input), None, "{reason}: {input:?}");
        }
    }

    #[test]
    fn parse_timeout_max_digits() {
        // Exactly 8 digits is the longest value the format allows.
        assert_eq!(parse("99999999S"), Some(Duration::from_secs(99_999_999)));
    }
}