Skip to main content

tower_reqwest/adapters/
reqwest.rs

1//! Adapter for [`reqwest`] client.
2//!
3//! [`reqwest`]: https://crates.io/crates/reqwest
4
5use std::{future::Future, task::Poll};
6
7use pin_project::pin_project;
8use tower_service::Service;
9
10use crate::HttpClientService;
11
12impl<S> Service<http::Request<reqwest::Body>> for HttpClientService<S>
13where
14    S: Service<reqwest::Request, Error = reqwest::Error>,
15    http::Response<reqwest::Body>: From<S::Response>,
16{
17    type Response = http::Response<reqwest::Body>;
18    type Error = S::Error;
19    type Future = ExecuteRequestFuture<S>;
20
21    fn poll_ready(
22        &mut self,
23        _cx: &mut std::task::Context<'_>,
24    ) -> std::task::Poll<Result<(), Self::Error>> {
25        Poll::Ready(Ok(()))
26    }
27
28    fn call(&mut self, req: http::Request<reqwest::Body>) -> Self::Future {
29        let future = reqwest::Request::try_from(req).map(|reqw| self.0.call(reqw));
30        ExecuteRequestFuture::new(future)
31    }
32}
33
34/// Future that resolves to the response or failure to connect.
35#[pin_project(project = ExecuteRequestFutureProj)]
36#[derive(Debug)]
37pub enum ExecuteRequestFuture<S>
38where
39    S: Service<reqwest::Request>,
40{
41    Future {
42        #[pin]
43        fut: S::Future,
44    },
45    Error {
46        error: Option<S::Error>,
47    },
48}
49
50impl<S> ExecuteRequestFuture<S>
51where
52    S: Service<reqwest::Request>,
53{
54    fn new(future: Result<S::Future, S::Error>) -> Self {
55        match future {
56            Ok(fut) => Self::Future { fut },
57            Err(error) => Self::Error { error: Some(error) },
58        }
59    }
60}
61
62impl<S> Future for ExecuteRequestFuture<S>
63where
64    S: Service<reqwest::Request>,
65    http::Response<reqwest::Body>: From<S::Response>,
66{
67    type Output = Result<http::Response<reqwest::Body>, S::Error>;
68
69    fn poll(
70        self: std::pin::Pin<&mut Self>,
71        cx: &mut std::task::Context<'_>,
72    ) -> std::task::Poll<Self::Output> {
73        match self.project() {
74            ExecuteRequestFutureProj::Future { fut } => fut.poll(cx).map_ok(From::from),
75            ExecuteRequestFutureProj::Error { error } => {
76                let error = error.take().expect("Polled after ready");
77                Poll::Ready(Err(error))
78            }
79        }
80    }
81}
82
#[cfg(test)]
mod tests {
    use futures_util::{FutureExt, TryFutureExt as _, future::BoxFuture};
    use http::{HeaderName, HeaderValue, header::USER_AGENT};
    use http_body_util::BodyExt;
    use pretty_assertions::assert_eq;
    use reqwest::Client;
    use serde::{Deserialize, Serialize};
    use tower::{Layer, Service, ServiceBuilder, ServiceExt};
    use tower_http::{ServiceBuilderExt, request_id::MakeRequestUuid};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::HttpClientLayer;

    /// JSON payload returned by the mock server.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct Info {
        student: String,
        answer: u32,
        /// Value of the `x-request-id` request header, if the request had one.
        request_id: Option<String>,
    }

    impl Info {
        /// Collects the whole response body and deserializes it from JSON.
        async fn from_body(body: reqwest::Body) -> anyhow::Result<Self> {
            let body_bytes = body.collect().await?.to_bytes();
            let info: Info = serde_json::from_slice(&body_bytes)?;
            Ok(info)
        }
    }

    /// Builds an empty-bodied `GET {base_uri}/hello` request.
    fn hello_request(base_uri: &str) -> anyhow::Result<http::Request<reqwest::Body>> {
        let request = http::request::Builder::new()
            .method(http::Method::GET)
            .uri(format!("{base_uri}/hello"))
            .body(reqwest::Body::default())?;
        Ok(request)
    }

    #[tokio::test]
    async fn test_http_client_layer() -> anyhow::Result<()> {
        // Start a background HTTP server on a random local port.
        let mock_server = MockServer::start().await;
        // Get the mock server base URI.
        let mock_uri = mock_server.uri();

        // Arrange the behaviour of the mock server: when it receives a GET
        // request on '/hello' it responds with a 200 and echoes the request id
        // (if any) back in the JSON body.
        Mock::given(method("GET"))
            .and(path("/hello"))
            .respond_with(|req: &wiremock::Request| {
                let request_id = req
                    .headers
                    .get(HeaderName::from_static("x-request-id"))
                    .map(|value| value.to_str().unwrap().to_owned());

                ResponseTemplate::new(200).set_body_json(Info {
                    student: "Vasya Pupkin".to_owned(),
                    answer: 42,
                    request_id,
                })
            })
            // Mounting the mock on the mock server - it's now effective!
            .mount(&mock_server)
            .await;
        // Create the HTTP client.
        let client = Client::new();

        // Execute a request without any additional layers.
        //
        // `oneshot` waits for the service to become ready before calling it,
        // as the `Service` contract requires.
        let response = ServiceBuilder::new()
            .layer(HttpClientLayer)
            .service(client.clone())
            .oneshot(hello_request(&mock_uri)?)
            .await?;
        assert!(response.status().is_success());
        // Try to read the body.
        let info = Info::from_body(response.into_body()).await?;
        assert!(info.request_id.is_none());

        let service = ServiceBuilder::new()
            .override_response_header(USER_AGENT, HeaderValue::from_static("tower-reqwest"))
            .set_x_request_id(MakeRequestUuid)
            .map_err(|err: FailableServiceError<reqwest::Error>| anyhow::Error::from(err))
            .layer(FailableServiceLayer)
            .layer(HttpClientLayer)
            .service(client)
            .boxed_clone();
        // Execute a request with several layers from tower-http.
        let response = service
            .clone()
            .oneshot(hello_request(&mock_uri)?)
            .await
            // No-op whose closure signature pins the error type of the stack
            // to `anyhow::Error`.
            .inspect_err(|_: &anyhow::Error| {})?;

        assert!(response.status().is_success());
        assert_eq!(
            response.headers().get(USER_AGENT).unwrap(),
            HeaderValue::from_static("tower-reqwest")
        );

        // Try to read the body again.
        let info = Info::from_body(response.into_body()).await?;
        assert_eq!(info.student, "Vasya Pupkin");
        assert_eq!(info.answer, 42);
        assert!(info.request_id.is_some());

        Ok(())
    }

    /// Layer that wraps a service into [`FailableService`].
    #[derive(Debug, Clone)]
    struct FailableServiceLayer;

    impl<S> Layer<S> for FailableServiceLayer {
        type Service = FailableService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            FailableService { inner }
        }
    }

    /// Service that forwards everything to `inner` and wraps its errors into
    /// [`FailableServiceError`], giving the stack a custom error type.
    #[derive(Debug, Clone)]
    struct FailableService<S> {
        inner: S,
    }

    impl<S, B> Service<http::Request<B>> for FailableService<S>
    where
        S: Service<http::Request<B>>,
        S::Future: Send + 'static,
        S::Error: 'static,
    {
        type Response = S::Response;
        type Error = FailableServiceError<S::Error>;
        type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

        fn poll_ready(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            self.inner
                .poll_ready(cx)
                .map_err(FailableServiceError::Inner)
        }

        fn call(&mut self, req: http::Request<B>) -> Self::Future {
            self.inner
                .call(req)
                .map_err(FailableServiceError::Inner)
                .boxed()
        }
    }

    /// Error type produced by [`FailableService`].
    #[derive(Debug, Clone, thiserror::Error)]
    #[error("i'm failed")]
    // `Other` is never constructed in this module; it only makes the error
    // type more than a plain wrapper around the inner error.
    #[allow(unused)]
    enum FailableServiceError<E> {
        Inner(E),
        Other,
    }
}