xitca_web/middleware/
grpc_timeout.rs1use core::time::Duration;
22
23use tokio::time::Instant;
24
25use crate::{
26 body::ResponseBody,
27 context::WebContext,
28 error::{GrpcError, GrpcStatus},
29 http::{
30 WebResponse,
31 const_header_name::GRPC_TIMEOUT,
32 const_header_value::GRPC,
33 header::{CONTENT_TYPE, HeaderValue},
34 },
35 service::{Service, ready::ReadyService},
36};
37
38pub struct GrpcTimeout;
43
44impl<S, E> Service<Result<S, E>> for GrpcTimeout {
45 type Response = GrpcTimeoutService<S>;
46 type Error = E;
47
48 async fn call(&self, res: Result<S, E>) -> Result<Self::Response, Self::Error> {
49 res.map(|service| GrpcTimeoutService { service })
50 }
51}
52
53pub struct GrpcTimeoutService<S> {
54 service: S,
55}
56
57impl<'r, S, C, B> Service<WebContext<'r, C, B>> for GrpcTimeoutService<S>
58where
59 S: for<'r2> Service<WebContext<'r2, C, B>, Response = WebResponse, Error = crate::error::Error>,
60{
61 type Response = WebResponse;
62 type Error = crate::error::Error;
63
64 async fn call(&self, ctx: WebContext<'r, C, B>) -> Result<Self::Response, Self::Error> {
65 let timeout = ctx.req().headers().get(GRPC_TIMEOUT).and_then(parse_grpc_timeout);
66
67 match timeout {
68 Some(duration) => {
69 let deadline = Instant::now() + duration;
70
71 match tokio::time::timeout_at(deadline, self.service.call(ctx)).await {
72 Ok(result) => result,
73 Err(_elapsed) => {
74 let err = GrpcError::new(GrpcStatus::DeadlineExceeded, "deadline exceeded");
75 let mut res = WebResponse::new(ResponseBody::empty());
76 res.headers_mut().insert(CONTENT_TYPE, GRPC);
77 res.headers_mut().extend(err.trailers());
78 Ok(res)
79 }
80 }
81 }
82 None => self.service.call(ctx).await,
83 }
84 }
85}
86
87impl<S> ReadyService for GrpcTimeoutService<S>
88where
89 S: ReadyService,
90{
91 type Ready = S::Ready;
92
93 #[inline]
94 async fn ready(&self) -> Self::Ready {
95 self.service.ready().await
96 }
97}
98
99fn parse_grpc_timeout(value: &HeaderValue) -> Option<Duration> {
105 let bytes = value.as_bytes();
106 if bytes.len() < 2 {
107 return None;
108 }
109
110 let (digits, unit) = bytes.split_at(bytes.len() - 1);
111
112 if digits.is_empty() || digits.len() > 8 {
114 return None;
115 }
116
117 let mut val: u64 = 0;
118 for &b in digits {
119 if !b.is_ascii_digit() {
120 return None;
121 }
122 val = val * 10 + (b - b'0') as u64;
123 }
124
125 match unit[0] {
126 b'H' => Some(Duration::from_secs(val * 3600)),
127 b'M' => Some(Duration::from_secs(val * 60)),
128 b'S' => Some(Duration::from_secs(val)),
129 b'm' => Some(Duration::from_millis(val)),
130 b'u' => Some(Duration::from_micros(val)),
131 b'n' => Some(Duration::from_nanos(val)),
132 _ => None,
133 }
134}
135
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn parse_timeout_values() {
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("1H")),
            Some(Duration::from_secs(3600))
        );
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("5M")),
            Some(Duration::from_secs(300))
        );
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("10S")),
            Some(Duration::from_secs(10))
        );
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("100m")),
            Some(Duration::from_millis(100))
        );
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("5000u")),
            Some(Duration::from_micros(5000))
        );
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("999n")),
            Some(Duration::from_nanos(999))
        );
    }

    #[test]
    fn parse_timeout_invalid() {
        // Missing value and/or unit.
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("")), None);
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("H")), None);
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("5")), None);

        // Unknown unit; units are case sensitive.
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("5x")), None);
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("5s")), None);

        // Non-digit bytes anywhere in the value part.
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("abc")), None);
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("1a2S")), None);
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("-1S")), None);
        assert_eq!(parse_grpc_timeout(&HeaderValue::from_static("1 S")), None);

        // More than 8 digits.
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("123456789S")),
            None
        );
    }

    #[test]
    fn parse_timeout_max_digits() {
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("99999999S")),
            Some(Duration::from_secs(99999999))
        );
        // The largest accepted value with the largest unit must not overflow
        // while being converted to seconds.
        assert_eq!(
            parse_grpc_timeout(&HeaderValue::from_static("99999999H")),
            Some(Duration::from_secs(99_999_999 * 3600))
        );
    }
}