tower_http/timeout/
body.rs1use crate::BoxError;
2use http_body::Body;
3use pin_project_lite::pin_project;
4use std::{
5 future::Future,
6 pin::Pin,
7 task::{ready, Context, Poll},
8 time::Duration,
9};
10use tokio::time::{sleep, Sleep};
11
12pin_project! {
13 pub struct TimeoutBody<B> {
53 timeout: Duration,
54 #[pin]
55 sleep: Option<Sleep>,
56 #[pin]
57 body: B,
58 }
59}
60
61impl<B> TimeoutBody<B> {
62 pub fn new(timeout: Duration, body: B) -> Self {
64 TimeoutBody {
65 timeout,
66 sleep: None,
67 body,
68 }
69 }
70}
71
72impl<B> Body for TimeoutBody<B>
73where
74 B: Body,
75 B::Error: Into<BoxError>,
76{
77 type Data = B::Data;
78 type Error = Box<dyn std::error::Error + Send + Sync>;
79
80 fn poll_frame(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
84 let mut this = self.project();
85
86 let sleep_pinned = if let Some(some) = this.sleep.as_mut().as_pin_mut() {
88 some
89 } else {
90 this.sleep.set(Some(sleep(*this.timeout)));
91 this.sleep.as_mut().as_pin_mut().unwrap()
92 };
93
94 if let Poll::Ready(()) = sleep_pinned.poll(cx) {
96 return Poll::Ready(Some(Err(Box::new(TimeoutError(())))));
97 }
98
99 let frame = ready!(this.body.poll_frame(cx));
101 this.sleep.set(None);
103
104 Poll::Ready(frame.transpose().map_err(Into::into).transpose())
105 }
106}
107
/// Error produced when the timeout elapses before the body yields data.
///
/// The private unit field keeps construction inside this module.
#[derive(Debug)]
pub struct TimeoutError(());

impl std::fmt::Display for TimeoutError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("data was not received within the designated timeout")
    }
}

impl std::error::Error for TimeoutError {}
#[cfg(test)]
mod tests {
    use super::*;

    use bytes::Bytes;
    use http_body::Frame;
    use http_body_util::BodyExt;
    use pin_project_lite::pin_project;
    use std::{error::Error, fmt::Display};

    // Error type required by the `Body` impl of `MockBody`; the mock never
    // actually returns it.
    #[derive(Debug)]
    struct MockError;

    impl Error for MockError {}

    impl Display for MockError {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("mock error")
        }
    }

    pin_project! {
        // Body that stays pending until its own timer fires, then yields an
        // empty data frame.
        struct MockBody {
            #[pin]
            sleep: Sleep
        }
    }

    impl Body for MockBody {
        type Data = Bytes;
        type Error = MockError;

        fn poll_frame(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<http_body::Frame<Self::Data>, Self::Error>>> {
            // Stay pending until the mock's delay has passed.
            ready!(self.project().sleep.poll(cx));
            Poll::Ready(Some(Ok(Frame::data(Bytes::new()))))
        }
    }

    // Builds a `TimeoutBody` around a `MockBody` that becomes ready after
    // `body_delay`, with `timeout` as the limit.
    fn delayed_body(body_delay: Duration, timeout: Duration) -> TimeoutBody<MockBody> {
        let inner = MockBody {
            sleep: sleep(body_delay),
        };
        TimeoutBody::new(timeout, inner)
    }

    #[tokio::test]
    async fn test_body_available_within_timeout() {
        // Body is ready after 1s, well inside the 2s timeout.
        let body = delayed_body(Duration::from_secs(1), Duration::from_secs(2));

        let frame = body.boxed().frame().await.expect("no frame");
        assert!(frame.is_ok());
    }

    #[tokio::test]
    async fn test_body_unavailable_within_timeout_error() {
        // Body needs 2s but the timeout is only 1s.
        let body = delayed_body(Duration::from_secs(2), Duration::from_secs(1));

        let frame = body.boxed().frame().await.unwrap();
        assert!(frame.is_err());
    }
}