Skip to main content

smithy_transport_reqwest/
lib.rs

1//! A [`reqwest`] transport for smithy-generated clients.
2//!
3//! [`ReqwestHttpClient`] implements
4//! [`aws_smithy_runtime_api::client::http::HttpClient`] and can be installed into
5//! AWS SDK for Rust or other smithy-runtime client configurations that accept a
6//! smithy HTTP client.
7//!
8//! # TLS features
9//!
10//! This crate disables reqwest's default features and forwards TLS selection to
11//! reqwest. No TLS backend is enabled by default; enable one of `rustls`,
12//! `rustls-no-provider`, `native-tls`, `native-tls-vendored`, or `default-tls`
13//! when HTTPS support is required.
14//!
15//! ```toml
16//! smithy-transport-reqwest = { version = "0.1", features = ["native-tls"] }
17//! ```
18//!
19//! # Using with AWS SDK for Rust
20//!
21//! ```rust,ignore
22//! use aws_config::BehaviorVersion;
23//! use smithy_transport_reqwest::ReqwestHttpClient;
24//!
25//! # async fn example() {
26//! let sdk_config = aws_config::defaults(BehaviorVersion::latest())
27//!     .http_client(ReqwestHttpClient::new())
28//!     .load()
29//!     .await;
30//!
31//! let s3 = aws_sdk_s3::Client::new(&sdk_config);
32//! # let _ = s3;
33//! # }
34//! ```
35
36use std::borrow::Cow;
37use std::collections::HashMap;
38use std::error::Error;
39use std::fmt;
40use std::sync::Arc;
41use std::time::Duration;
42
43use aws_smithy_runtime_api::client::connector_metadata::ConnectorMetadata;
44use aws_smithy_runtime_api::client::http::{
45    HttpClient, HttpConnector, HttpConnectorFuture, HttpConnectorSettings, SharedHttpConnector,
46};
47use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, HttpResponse};
48use aws_smithy_runtime_api::client::result::ConnectorError;
49use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
50use aws_smithy_types::body::SdkBody;
51use parking_lot::Mutex;
52
53/// A reqwest-backed smithy HTTP client.
54///
55/// The client lazily creates and caches reqwest clients for each distinct
56/// smithy connector timeout configuration. Reqwest's automatic redirect policy
57/// is disabled so smithy callers observe service responses directly.
58#[derive(Debug)]
59pub struct ReqwestHttpClient {
60    connector_cache: Mutex<HashMap<CacheKey, SharedHttpConnector>>,
61}
62
63impl ReqwestHttpClient {
64    /// Creates a new reqwest-backed smithy HTTP client.
65    pub fn new() -> Self {
66        Self::default()
67    }
68}
69
70impl Default for ReqwestHttpClient {
71    fn default() -> Self {
72        Self {
73            connector_cache: Mutex::new(HashMap::new()),
74        }
75    }
76}
77
78impl HttpClient for ReqwestHttpClient {
79    fn http_connector(
80        &self,
81        settings: &HttpConnectorSettings,
82        _: &RuntimeComponents,
83    ) -> SharedHttpConnector {
84        let key = CacheKey::from(settings);
85        self.connector_cache
86            .lock()
87            .entry(key)
88            .or_insert_with(|| SharedHttpConnector::new(ReqwestConnector::new(settings)))
89            .clone()
90    }
91
92    fn connector_metadata(&self) -> Option<ConnectorMetadata> {
93        Some(ConnectorMetadata::new(
94            "reqwest",
95            Some(Cow::Borrowed("0.13.x")),
96        ))
97    }
98}
99
100#[derive(Clone, Debug, Eq, PartialEq, Hash)]
101struct CacheKey {
102    connect_timeout: Option<Duration>,
103    read_timeout: Option<Duration>,
104}
105
106impl From<&HttpConnectorSettings> for CacheKey {
107    fn from(value: &HttpConnectorSettings) -> Self {
108        Self {
109            connect_timeout: value.connect_timeout(),
110            read_timeout: value.read_timeout(),
111        }
112    }
113}
114
115#[derive(Clone, Debug)]
116struct ReqwestConnector {
117    client: Result<reqwest::Client, Arc<ClientBuildError>>,
118}
119
120impl ReqwestConnector {
121    fn new(settings: &HttpConnectorSettings) -> Self {
122        let mut builder = reqwest::Client::builder().redirect(reqwest::redirect::Policy::none());
123        if let Some(timeout) = settings.connect_timeout() {
124            builder = builder.connect_timeout(timeout);
125        }
126        if let Some(timeout) = settings.read_timeout() {
127            builder = builder.read_timeout(timeout);
128        }
129
130        Self {
131            client: builder
132                .build()
133                .map_err(|err| Arc::new(ClientBuildError(err.to_string()))),
134        }
135    }
136}
137
138impl HttpConnector for ReqwestConnector {
139    fn call(&self, request: HttpRequest) -> HttpConnectorFuture {
140        let client = match &self.client {
141            Ok(client) => client.clone(),
142            Err(err) => {
143                let err = ClientBuildError(err.0.clone());
144                return HttpConnectorFuture::ready(Err(ConnectorError::other(Box::new(err), None)));
145            }
146        };
147
148        let request = match request.try_into_http1x() {
149            Ok(request) => request.map(reqwest::Body::wrap),
150            Err(err) => {
151                return HttpConnectorFuture::ready(Err(ConnectorError::user(Box::new(err))));
152            }
153        };
154
155        let request = match reqwest::Request::try_from(request) {
156            Ok(request) => request,
157            Err(err) => {
158                return HttpConnectorFuture::ready(Err(map_reqwest_error(err)));
159            }
160        };
161
162        HttpConnectorFuture::new(async move {
163            let response: http::Response<reqwest::Body> = client
164                .execute(request)
165                .await
166                .map_err(map_reqwest_error)?
167                .into();
168            let response = response.map(SdkBody::from_body_1_x);
169
170            HttpResponse::try_from(response)
171                .map_err(|err| ConnectorError::other(Box::new(err), None))
172        })
173    }
174}
175
/// Error reported when the reqwest client for a connector could not be built.
///
/// It holds a pre-rendered message, so it has no `source()` of its own.
#[derive(Debug)]
struct ClientBuildError(String);

impl fmt::Display for ClientBuildError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(message) = self;
        write!(f, "{message}")
    }
}

impl Error for ClientBuildError {}
186
187fn map_reqwest_error(err: reqwest::Error) -> ConnectorError {
188    if err.is_timeout() {
189        ConnectorError::timeout(Box::new(err))
190    } else if err.is_request() || err.is_builder() {
191        ConnectorError::user(Box::new(err))
192    } else if err.is_connect() {
193        ConnectorError::io(Box::new(err)).never_connected()
194    } else if err.is_body() || err.is_decode() {
195        ConnectorError::io(Box::new(err))
196    } else {
197        ConnectorError::other(Box::new(err), None)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;

    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use http_body_util::BodyExt;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Raw response written back by the one-shot test server.
    const CANNED_RESPONSE: &[u8] =
        b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nx-answer: yes\r\n\r\nworld";

    #[tokio::test]
    async fn sends_request_and_streams_response_body() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();

        // One-shot HTTP/1.1 server: check what reqwest put on the wire, then reply.
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let raw = read_request(&mut socket).await;
            let wire = String::from_utf8_lossy(&raw);

            assert!(wire.starts_with("POST /hello?x=1 HTTP/1.1"));
            assert!(wire.contains("x-test: ok"));
            assert!(wire.contains("\r\n\r\nping"));

            socket.write_all(CANNED_RESPONSE).await.unwrap();
        });

        let client = ReqwestHttpClient::new();
        let settings = HttpConnectorSettings::builder()
            .connect_timeout(Duration::from_secs(1))
            .read_timeout(Duration::from_secs(1))
            .build();
        let components = RuntimeComponentsBuilder::for_tests().build().unwrap();
        let connector = client.http_connector(&settings, &components);

        let outgoing = http::Request::builder()
            .method("POST")
            .uri(format!("http://{address}/hello?x=1"))
            .header("x-test", "ok")
            .body(SdkBody::from("ping"))
            .unwrap();
        let outgoing = HttpRequest::try_from(outgoing).unwrap();

        let response = connector.call(outgoing).await.unwrap();
        assert_eq!(200, response.status().as_u16());
        assert_eq!("yes", response.headers().get("x-answer").unwrap());

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!("world", body);

        server.await.unwrap();
    }

    #[test]
    fn connector_metadata_identifies_reqwest() {
        let client = ReqwestHttpClient::new();
        let metadata = client.connector_metadata().unwrap();
        assert_eq!("reqwest", metadata.name());
        assert_eq!(Some(Cow::Borrowed("0.13.x")), metadata.version());
    }

    /// Reads from `socket` until a full request head plus the number of body
    /// bytes announced by `Content-Length` (0 when absent) has arrived.
    ///
    /// Panics if the peer closes the connection before that point.
    async fn read_request<S>(socket: &mut S) -> Vec<u8>
    where
        S: AsyncReadExt + Unpin,
    {
        let mut received = Vec::new();
        let mut scratch = [0; 1024];

        loop {
            let n = socket.read(&mut scratch).await.unwrap();
            assert_ne!(0, n);
            received.extend_from_slice(&scratch[..n]);

            let Some(head_len) = find_subsequence(&received, b"\r\n\r\n") else {
                continue;
            };
            let body_len = content_length(&received[..head_len]);
            if received.len() >= head_len + 4 + body_len {
                return received;
            }
        }
    }

    /// Returns the first parseable `Content-Length` value in a raw request head,
    /// or 0 when there is none.
    fn content_length(head: &[u8]) -> usize {
        String::from_utf8_lossy(head)
            .lines()
            .filter_map(|line| line.split_once(':'))
            .find_map(|(name, value)| {
                if name.eq_ignore_ascii_case("content-length") {
                    value.trim().parse::<usize>().ok()
                } else {
                    None
                }
            })
            .unwrap_or_default()
    }

    /// Returns the index where `needle` first occurs in `haystack`.
    fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
        haystack
            .windows(needle.len())
            .position(|candidate| candidate == needle)
    }
}