tower_reqwest/adapters/
reqwest.rs1use std::{future::Future, task::Poll};
6
7use pin_project::pin_project;
8use tower_service::Service;
9
10use crate::HttpClientService;
11
12impl<S> Service<http::Request<reqwest::Body>> for HttpClientService<S>
13where
14 S: Service<reqwest::Request, Error = reqwest::Error>,
15 http::Response<reqwest::Body>: From<S::Response>,
16{
17 type Response = http::Response<reqwest::Body>;
18 type Error = S::Error;
19 type Future = ExecuteRequestFuture<S>;
20
21 fn poll_ready(
22 &mut self,
23 _cx: &mut std::task::Context<'_>,
24 ) -> std::task::Poll<Result<(), Self::Error>> {
25 Poll::Ready(Ok(()))
26 }
27
28 fn call(&mut self, req: http::Request<reqwest::Body>) -> Self::Future {
29 let future = reqwest::Request::try_from(req).map(|reqw| self.0.call(reqw));
30 ExecuteRequestFuture::new(future)
31 }
32}
33
34#[pin_project(project = ExecuteRequestFutureProj)]
36#[derive(Debug)]
37pub enum ExecuteRequestFuture<S>
38where
39 S: Service<reqwest::Request>,
40{
41 Future {
42 #[pin]
43 fut: S::Future,
44 },
45 Error {
46 error: Option<S::Error>,
47 },
48}
49
50impl<S> ExecuteRequestFuture<S>
51where
52 S: Service<reqwest::Request>,
53{
54 fn new(future: Result<S::Future, S::Error>) -> Self {
55 match future {
56 Ok(fut) => Self::Future { fut },
57 Err(error) => Self::Error { error: Some(error) },
58 }
59 }
60}
61
62impl<S> Future for ExecuteRequestFuture<S>
63where
64 S: Service<reqwest::Request>,
65 http::Response<reqwest::Body>: From<S::Response>,
66{
67 type Output = Result<http::Response<reqwest::Body>, S::Error>;
68
69 fn poll(
70 self: std::pin::Pin<&mut Self>,
71 cx: &mut std::task::Context<'_>,
72 ) -> std::task::Poll<Self::Output> {
73 match self.project() {
74 ExecuteRequestFutureProj::Future { fut } => fut.poll(cx).map_ok(From::from),
75 ExecuteRequestFutureProj::Error { error } => {
76 let error = error.take().expect("Polled after ready");
77 Poll::Ready(Err(error))
78 }
79 }
80 }
81}
82
#[cfg(test)]
mod tests {
    use futures_util::{FutureExt, TryFutureExt as _, future::BoxFuture};
    use http::{HeaderName, HeaderValue, header::USER_AGENT};
    use http_body_util::BodyExt;
    use pretty_assertions::assert_eq;
    use reqwest::Client;
    use serde::{Deserialize, Serialize};
    use tower::{Layer, Service, ServiceBuilder, ServiceExt};
    use tower_http::{ServiceBuilderExt, request_id::MakeRequestUuid};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::HttpClientLayer;

    /// JSON payload served by the mock `/hello` endpoint.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct Info {
        student: String,
        answer: u32,
        /// Value of the `x-request-id` request header seen by the mock, if any.
        request_id: Option<String>,
    }

    impl Info {
        /// Collects `body` and deserializes it as JSON.
        async fn from_body(body: reqwest::Body) -> anyhow::Result<Self> {
            let body_bytes = body.collect().await?.to_bytes();
            let info: Info = serde_json::from_slice(&body_bytes)?;
            Ok(info)
        }
    }

    #[tokio::test]
    async fn test_http_client_layer() -> anyhow::Result<()> {
        let mock_server = MockServer::start().await;
        let mock_uri = mock_server.uri();

        // `GET /hello` answers with a fixed `Info`, echoing back the `x-request-id` header.
        Mock::given(method("GET"))
            .and(path("/hello"))
            .respond_with(|req: &wiremock::Request| {
                let request_id = req
                    .headers
                    .get(HeaderName::from_static("x-request-id"))
                    .map(|value| value.to_str().unwrap().to_owned());

                ResponseTemplate::new(200).set_body_json(Info {
                    student: "Vasya Pupkin".to_owned(),
                    answer: 42,
                    request_id,
                })
            })
            .mount(&mock_server)
            .await;
        let client = Client::new();

        // Bare adapter without middleware: no request id header is sent.
        let request = http::request::Builder::new()
            .method(http::Method::GET)
            .uri(format!("{mock_uri}/hello"))
            .body(reqwest::Body::default())?;

        // `oneshot` drives `poll_ready` before `call`, as the `Service` contract requires.
        let response = ServiceBuilder::new()
            .layer(HttpClientLayer)
            .service(client.clone())
            .oneshot(request)
            .await?;
        assert!(response.status().is_success());
        let info = Info::from_body(response.into_body()).await?;
        assert_eq!(
            info,
            Info {
                student: "Vasya Pupkin".to_owned(),
                answer: 42,
                request_id: None,
            }
        );

        // Full stack: tower-http middleware over a custom error-mapping layer and the adapter.
        let service = ServiceBuilder::new()
            .override_response_header(USER_AGENT, HeaderValue::from_static("tower-reqwest"))
            .set_x_request_id(MakeRequestUuid)
            .map_err(|err: FailableServiceError<reqwest::Error>| anyhow::Error::from(err))
            .layer(FailableServiceLayer)
            .layer(HttpClientLayer)
            .service(client)
            .boxed_clone();
        let request = http::request::Builder::new()
            .method(http::Method::GET)
            .uri(format!("{mock_uri}/hello"))
            .body(reqwest::Body::default())?;
        let response = service.oneshot(request).await?;

        assert!(response.status().is_success());
        assert_eq!(
            response.headers().get(USER_AGENT).unwrap(),
            HeaderValue::from_static("tower-reqwest")
        );

        let info = Info::from_body(response.into_body()).await?;
        assert_eq!(info.student, "Vasya Pupkin");
        assert_eq!(info.answer, 42);
        assert!(info.request_id.is_some());

        Ok(())
    }

    /// Layer that wraps a service in `FailableService`.
    #[derive(Debug, Clone)]
    struct FailableServiceLayer;

    impl<S> Layer<S> for FailableServiceLayer {
        type Service = FailableService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            FailableService { inner }
        }
    }

    /// Test middleware that forwards to `inner`, wrapping its errors in
    /// `FailableServiceError::Inner` and boxing the response future.
    #[derive(Debug, Clone)]
    struct FailableService<S> {
        inner: S,
    }

    impl<S, B> Service<http::Request<B>> for FailableService<S>
    where
        S: Service<http::Request<B>>,
        S::Future: Send + 'static,
        S::Error: 'static,
    {
        type Response = S::Response;
        type Error = FailableServiceError<S::Error>;
        type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

        fn poll_ready(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            self.inner
                .poll_ready(cx)
                .map_err(FailableServiceError::Inner)
        }

        fn call(&mut self, req: http::Request<B>) -> Self::Future {
            self.inner
                .call(req)
                .map_err(FailableServiceError::Inner)
                .boxed()
        }
    }

    /// Error type of `FailableService`; it always displays as "i'm failed".
    #[derive(Debug, Clone, thiserror::Error)]
    #[error("i'm failed")]
    // `Other` is never constructed and the `Inner` payload is not read outside derived
    // impls, so the dead-code lints are silenced for this test-only type.
    #[allow(unused)]
    enum FailableServiceError<E> {
        Inner(E),
        Other,
    }
}