tower_reqwest/adapters/
reqwest.rs1use std::{future::Future, task::Poll};
6
7use pin_project::pin_project;
8use tower_service::Service;
9
10use crate::HttpClientService;
11
12impl<S> Service<http::Request<reqwest::Body>> for HttpClientService<S>
13where
14 S: Service<reqwest::Request>,
15 S::Future: Send + 'static,
16 S::Error: 'static,
17 http::Response<reqwest::Body>: From<S::Response>,
18 crate::Error: From<S::Error>,
19{
20 type Response = http::Response<reqwest::Body>;
21 type Error = crate::Error;
22 type Future = ExecuteRequestFuture<S>;
23
24 fn poll_ready(
25 &mut self,
26 _cx: &mut std::task::Context<'_>,
27 ) -> std::task::Poll<Result<(), Self::Error>> {
28 Poll::Ready(Ok(()))
29 }
30
31 fn call(&mut self, req: http::Request<reqwest::Body>) -> Self::Future {
32 let future = reqwest::Request::try_from(req).map(|reqw| self.0.call(reqw));
33 ExecuteRequestFuture::new(future)
34 }
35}
36
37#[pin_project]
38#[derive(Debug)]
40pub struct ExecuteRequestFuture<S>
41where
42 S: Service<reqwest::Request>,
43{
44 #[pin]
45 inner: Inner<S::Future>,
46}
47
/// Internal state of an `ExecuteRequestFuture`.
///
/// `InnerProj` is the pinned projection generated by `pin_project`, used to reach the
/// variant fields from a `Pin<&mut Inner<F>>`.
#[pin_project(project = InnerProj)]
#[derive(Debug)]
enum Inner<F> {
    /// A request was dispatched; `fut` is the wrapped service's future.
    Future {
        // Structurally pinned: the wrapped future is polled through `Pin<&mut F>`.
        #[pin]
        fut: F,
    },
    /// The request could not be dispatched; holds the error to report.
    Error {
        // `Option` so the error can be moved out with `take()` when it is returned,
        // leaving `None` behind.
        error: Option<crate::Error>,
    },
}
59
60impl<S> ExecuteRequestFuture<S>
61where
62 S: Service<reqwest::Request>,
63{
64 fn new(future: Result<S::Future, reqwest::Error>) -> Self {
65 let inner = match future {
66 Ok(fut) => Inner::Future { fut },
67 Err(error) => Inner::Error {
68 error: Some(error.into()),
69 },
70 };
71 Self { inner }
72 }
73}
74
75impl<S> Future for ExecuteRequestFuture<S>
76where
77 S: Service<reqwest::Request>,
78 http::Response<reqwest::Body>: From<S::Response>,
79 crate::Error: From<S::Error>,
80{
81 type Output = crate::Result<http::Response<reqwest::Body>>;
82
83 fn poll(
84 self: std::pin::Pin<&mut Self>,
85 cx: &mut std::task::Context<'_>,
86 ) -> std::task::Poll<Self::Output> {
87 let this = self.project();
88 match this.inner.project() {
89 InnerProj::Future { fut } => {
90 fut.poll(cx).map_ok(From::from).map_err(crate::Error::from)
91 }
92 InnerProj::Error { error } => {
93 let error = error.take().expect("Polled after ready");
94 Poll::Ready(Err(error))
95 }
96 }
97 }
98}
99
#[cfg(test)]
mod tests {
    use http::{HeaderName, HeaderValue, header::USER_AGENT};
    use http_body_util::BodyExt;
    use pretty_assertions::assert_eq;
    use reqwest::Client;
    use serde::{Deserialize, Serialize};
    use tower::{ServiceBuilder, ServiceExt};
    use tower_http::{ServiceBuilderExt, request_id::MakeRequestUuid};
    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use crate::HttpClientLayer;

    /// JSON payload served by the mock endpoint; `request_id` echoes the
    /// `x-request-id` header the server received, if any.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct Info {
        student: String,
        answer: u32,
        request_id: Option<String>,
    }

    impl Info {
        /// Collects `body` and deserializes it as JSON.
        async fn from_body(body: reqwest::Body) -> anyhow::Result<Self> {
            let body_bytes = body.collect().await?.to_bytes();
            let info: Info = serde_json::from_slice(&body_bytes)?;
            Ok(info)
        }
    }

    #[tokio::test]
    async fn test_http_client_layer() -> anyhow::Result<()> {
        let mock_server = MockServer::start().await;
        let mock_uri = mock_server.uri();

        Mock::given(method("GET"))
            .and(path("/hello"))
            .respond_with(|req: &wiremock::Request| {
                let request_id = req
                    .headers
                    .get(HeaderName::from_static("x-request-id"))
                    .map(|value| value.to_str().unwrap().to_owned());

                ResponseTemplate::new(200).set_body_json(Info {
                    student: "Vasya Pupkin".to_owned(),
                    answer: 42,
                    request_id,
                })
            })
            .mount(&mock_server)
            .await;
        let client = Client::new();

        // Bare adapter: no request-id middleware, so the server sees no `x-request-id`.
        let request = http::request::Builder::new()
            .method(http::Method::GET)
            .uri(format!("{mock_uri}/hello"))
            .body(reqwest::Body::default())?;

        // `oneshot` drives `poll_ready` before `call`, as the Tower contract requires.
        let response = ServiceBuilder::new()
            .layer(HttpClientLayer)
            .service(client.clone())
            .oneshot(request)
            .await?;
        assert!(response.status().is_success());
        let info = Info::from_body(response.into_body()).await?;
        assert_eq!(
            info,
            Info {
                student: "Vasya Pupkin".to_owned(),
                answer: 42,
                request_id: None,
            }
        );

        // Full stack: response-header override plus a generated request id.
        let service = ServiceBuilder::new()
            .override_response_header(USER_AGENT, HeaderValue::from_static("tower-reqwest"))
            .set_x_request_id(MakeRequestUuid)
            .layer(HttpClientLayer)
            .service(client)
            .boxed_clone();
        let request = http::request::Builder::new()
            .method(http::Method::GET)
            .uri(format!("{mock_uri}/hello"))
            .body(reqwest::Body::default())?;
        let response = service.oneshot(request).await?;

        assert!(response.status().is_success());
        assert_eq!(
            response.headers().get(USER_AGENT).unwrap(),
            HeaderValue::from_static("tower-reqwest")
        );

        let info = Info::from_body(response.into_body()).await?;
        assert_eq!(info.student, "Vasya Pupkin");
        assert_eq!(info.answer, 42);
        assert!(info.request_id.is_some());

        Ok(())
    }
}