smithy_transport_reqwest/
lib.rs1use std::borrow::Cow;
37use std::collections::HashMap;
38use std::error::Error;
39use std::fmt;
40use std::sync::Arc;
41use std::time::Duration;
42
43use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
44use aws_smithy_runtime_api::client::http::{
45 HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
46};
47use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
48use aws_smithy_runtime_api::client::result::ConnectorError;
49use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
50use aws_smithy_types::body::SdkBody;
51use parking_lot::Mutex;
52
53#[derive(Debug)]
59pub struct ReqwestHttpClient {
60 connector_cache: Mutex<HashMap<CacheKey, SharedHttpConnector>>,
61}
62
63impl ReqwestHttpClient {
64 pub fn new() -> Self {
66 Self::default()
67 }
68}
69
70impl Default for ReqwestHttpClient {
71 fn default() -> Self {
72 Self {
73 connector_cache: Mutex::new(HashMap::new()),
74 }
75 }
76}
77
78impl HttpClient for ReqwestHttpClient {
79 fn http_connector(
80 &self,
81 settings: &HttpConnectorSettings,
82 _: &RuntimeComponents,
83 ) -> SharedHttpConnector {
84 let key = CacheKey::from(settings);
85 self.connector_cache
86 .lock()
87 .entry(key)
88 .or_insert_with(|| SharedHttpConnector::new(ReqwestConnector::new(settings)))
89 .clone()
90 }
91
92 fn connector_metadata(&self) -> Option<ConnectorMetadata> {
93 Some(ConnectorMetadata::new(
94 "reqwest",
95 Some(Cow::Borrowed("0.13.x")),
96 ))
97 }
98}
99
100#[derive(Clone, Debug, Eq, PartialEq, Hash)]
101struct CacheKey {
102 connect_timeout: Option<Duration>,
103 read_timeout: Option<Duration>,
104}
105
106impl From<&HttpConnectorSettings> for CacheKey {
107 fn from(value: &HttpConnectorSettings) -> Self {
108 Self {
109 connect_timeout: value.connect_timeout(),
110 read_timeout: value.read_timeout(),
111 }
112 }
113}
114
115#[derive(Clone, Debug)]
116struct ReqwestConnector {
117 client: Result<reqwest::Client, Arc<ClientBuildError>>,
118}
119
120impl ReqwestConnector {
121 fn new(settings: &HttpConnectorSettings) -> Self {
122 let mut builder = reqwest::Client::builder().redirect(reqwest::redirect::Policy::none());
123 if let Some(timeout) = settings.connect_timeout() {
124 builder = builder.connect_timeout(timeout);
125 }
126 if let Some(timeout) = settings.read_timeout() {
127 builder = builder.read_timeout(timeout);
128 }
129
130 Self {
131 client: builder
132 .build()
133 .map_err(|err| Arc::new(ClientBuildError(err.to_string()))),
134 }
135 }
136}
137
138impl HttpConnector for ReqwestConnector {
139 fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
140 let client = match &self.client {
141 Ok(client) => client.clone(),
142 Err(err) => {
143 let err = ClientBuildError(err.0.clone());
144 return HttpConnectorFuture::ready(Err(ConnectorError::other(Box::new(err), None)));
145 }
146 };
147
148 let request = match request.try_into_http1x() {
149 Ok(request) => request.map(reqwest::Body::wrap),
150 Err(err) => {
151 return HttpConnectorFuture::ready(Err(ConnectorError::user(Box::new(err))));
152 }
153 };
154
155 let request = match reqwest::Request::try_from(request) {
156 Ok(request) => request,
157 Err(err) => {
158 return HttpConnectorFuture::ready(Err(map_reqwest_error(err)));
159 }
160 };
161
162 HttpConnectorFuture::new(async move {
163 let response: http::Response<reqwest::Body> = client
164 .execute(request)
165 .await
166 .map_err(map_reqwest_error)?
167 .into();
168 let response = response.map(SdkBody::from_body_1_x);
169
170 HttpResponse::try_from(response)
171 .map_err(|err| ConnectorError::other(Box::new(err), None))
172 })
173 }
174}
175
/// Message-only error used when building a `reqwest::Client` fails.
///
/// Only the rendered text of the original error is kept, so this type has no `source()`.
#[derive(Debug)]
struct ClientBuildError(String);

impl fmt::Display for ClientBuildError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` deliberately ignores width/fill flags, matching the message verbatim.
        let Self(message) = self;
        f.write_str(message)
    }
}

impl Error for ClientBuildError {}
186
187fn map_reqwest_error(err: reqwest::Error) -> ConnectorError {
188 if err.is_timeout() {
189 ConnectorError::timeout(Box::new(err))
190 } else if err.is_request() || err.is_builder() {
191 ConnectorError::user(Box::new(err))
192 } else if err.is_connect() {
193 ConnectorError::io(Box::new(err)).never_connected()
194 } else if err.is_body() || err.is_decode() {
195 ConnectorError::io(Box::new(err))
196 } else {
197 ConnectorError::other(Box::new(err), None)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use http_body_util::BodyExt;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Separator between the HTTP/1.1 header section and the body.
    const HEADER_END: &[u8] = b"\r\n\r\n";

    #[tokio::test]
    async fn sends_request_and_streams_response_body() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        // One-shot HTTP server: read a single request, check it, then send a fixed reply.
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let request = read_request(&mut socket).await;

            assert!(request.starts_with("POST /hello?x=1 HTTP/1.1"));
            assert!(request.contains("x-test: ok"));
            assert!(request.contains("\r\n\r\nping"));

            socket
                .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nx-answer: yes\r\n\r\nworld")
                .await
                .unwrap();
        });

        let client = ReqwestHttpClient::new();
        let settings = HttpConnectorSettings::builder()
            .connect_timeout(Duration::from_secs(1))
            .read_timeout(Duration::from_secs(1))
            .build();
        let runtime_components = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let connector = client.http_connector(&settings, &runtime_components);

        let http_request = http::Request::builder()
            .method("POST")
            .uri(format!("http://{address}/hello?x=1"))
            .header("x-test", "ok")
            .body(SdkBody::from("ping"))
            .unwrap();
        let response = connector
            .call(HttpRequest::try_from(http_request).unwrap())
            .await
            .unwrap();
        assert_eq!(200, response.status().as_u16());
        assert_eq!("yes", response.headers().get("x-answer").unwrap());

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!("world", body);

        server.await.unwrap();
    }

    #[test]
    fn connector_metadata_identifies_reqwest() {
        let metadata = ReqwestHttpClient::new()
            .connector_metadata()
            .expect("ReqwestHttpClient always reports connector metadata");
        assert_eq!("reqwest", metadata.name());
        assert_eq!(Some(Cow::Borrowed("0.13.x")), metadata.version());
    }

    /// Reads one request (headers plus a `Content-Length`-sized body) from `socket` and returns
    /// it as lossily decoded UTF-8. Panics if the peer closes the connection before that.
    async fn read_request<S>(socket: &mut S) -> String
    where
        S: AsyncReadExt + Unpin,
    {
        let mut received = Vec::new();
        let mut chunk = [0; 1024];
        loop {
            let count = socket.read(&mut chunk).await.unwrap();
            assert_ne!(0, count);
            received.extend_from_slice(&chunk[..count]);

            if let Some(header_end) = find_subsequence(&received, HEADER_END) {
                let body_len = content_length(&received[..header_end]);
                if received.len() >= header_end + HEADER_END.len() + body_len {
                    return String::from_utf8_lossy(&received).into_owned();
                }
            }
        }
    }

    /// Returns the first parsable `Content-Length` value in `header_bytes`, or 0 if none.
    fn content_length(header_bytes: &[u8]) -> usize {
        String::from_utf8_lossy(header_bytes)
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find_map(|(name, value)| {
                if name.eq_ignore_ascii_case("content-length") {
                    value.trim().parse::<usize>().ok()
                } else {
                    None
                }
            })
            .unwrap_or_default()
    }

    /// Returns the index of the first occurrence of `needle` in `haystack`.
    fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|window| window == needle)
    }
}