s3/request/
tokio_backend.rs

1extern crate base64;
2extern crate md5;
3
4use bytes::Bytes;
5use futures_util::TryStreamExt;
6use maybe_async::maybe_async;
7use std::collections::HashMap;
8use std::str::FromStr as _;
9use time::OffsetDateTime;
10
11use super::request_trait::{Request, ResponseData, ResponseDataStream};
12use crate::bucket::Bucket;
13use crate::command::Command;
14use crate::command::HttpMethod;
15use crate::error::S3Error;
16use crate::retry;
17use crate::utils::now_utc;
18
19use tokio_stream::StreamExt;
20
21#[derive(Clone, Debug, Default)]
22pub(crate) struct ClientOptions {
23    pub request_timeout: Option<std::time::Duration>,
24    pub proxy: Option<reqwest::Proxy>,
25    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
26    pub accept_invalid_certs: bool,
27    #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))]
28    pub accept_invalid_hostnames: bool,
29}
30
/// Builds a `reqwest::Client` configured from `options`.
///
/// The request timeout and proxy are applied only when they are set. When one
/// of the tokio TLS features is enabled, the certificate and hostname
/// verification overrides are forwarded to the builder as well.
///
/// # Errors
///
/// Returns the builder's failure, converted into `S3Error`, if the client
/// cannot be constructed.
#[cfg(feature = "with-tokio")]
pub(crate) fn client(options: &ClientOptions) -> Result<reqwest::Client, S3Error> {
    let mut builder = reqwest::Client::builder();

    if let Some(timeout) = options.request_timeout {
        builder = builder.timeout(timeout);
    }

    if let Some(proxy) = options.proxy.as_ref() {
        builder = builder.proxy(proxy.clone());
    }

    cfg_if::cfg_if! {
        if #[cfg(any(feature = "tokio-native-tls", feature = "tokio-rustls-tls"))] {
            let builder = builder
                .danger_accept_invalid_certs(options.accept_invalid_certs)
                .danger_accept_invalid_hostnames(options.accept_invalid_hostnames);
        }
    }

    let built = builder.build()?;
    Ok(built)
}
61// Temporary structure for making a request
62pub struct ReqwestRequest<'a> {
63    pub bucket: &'a Bucket,
64    pub path: &'a str,
65    pub command: Command<'a>,
66    pub datetime: OffsetDateTime,
67    pub sync: bool,
68}
69
70#[maybe_async]
71impl<'a> Request for ReqwestRequest<'a> {
72    type Response = reqwest::Response;
73    type HeaderMap = reqwest::header::HeaderMap;
74
75    async fn response(&self) -> Result<Self::Response, S3Error> {
76        let headers = self
77            .headers()
78            .await?
79            .iter()
80            .map(|(k, v)| {
81                (
82                    reqwest::header::HeaderName::from_str(k.as_str()),
83                    reqwest::header::HeaderValue::from_str(v.to_str().unwrap_or_default()),
84                )
85            })
86            .filter(|(k, v)| k.is_ok() && v.is_ok())
87            .map(|(k, v)| (k.unwrap(), v.unwrap()))
88            .collect();
89
90        let client = self.bucket.http_client();
91
92        let method = match self.command.http_verb() {
93            HttpMethod::Delete => reqwest::Method::DELETE,
94            HttpMethod::Get => reqwest::Method::GET,
95            HttpMethod::Post => reqwest::Method::POST,
96            HttpMethod::Put => reqwest::Method::PUT,
97            HttpMethod::Head => reqwest::Method::HEAD,
98        };
99
100        let request = client
101            .request(method, self.url()?.as_str())
102            .headers(headers)
103            .body(self.request_body()?);
104
105        let request = request.build()?;
106
107        // println!("Request: {:?}", request);
108
109        let response = client.execute(request).await?;
110
111        if cfg!(feature = "fail-on-err") && !response.status().is_success() {
112            let status = response.status().as_u16();
113            let text = response.text().await?;
114            return Err(S3Error::HttpFailWithBody(status, text));
115        }
116
117        Ok(response)
118    }
119
120    async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error> {
121        let response = retry! {self.response().await }?;
122        let status_code = response.status().as_u16();
123        let mut headers = response.headers().clone();
124        let response_headers = headers
125            .clone()
126            .iter()
127            .map(|(k, v)| {
128                (
129                    k.to_string(),
130                    v.to_str()
131                        .unwrap_or("could-not-decode-header-value")
132                        .to_string(),
133                )
134            })
135            .collect::<HashMap<String, String>>();
136        // When etag=true, we extract the ETag header and return it as the body.
137        // This is used for PUT operations (regular puts, multipart chunks) where:
138        // 1. S3 returns an empty or non-useful response body
139        // 2. The ETag header contains the essential information we need
140        // 3. The calling code expects to get the ETag via response_data.as_str()
141        //
142        // Note: This approach means we discard any actual response body when etag=true,
143        // but for the operations that use this (PUTs), the body is typically empty
144        // or contains redundant information already available in headers.
145        //
146        // TODO: Refactor this to properly return the response body and access ETag
147        // from headers instead of replacing the body. This would be a breaking change.
148        let body_vec = if etag {
149            if let Some(etag) = headers.remove("ETag") {
150                Bytes::from(etag.to_str()?.to_string())
151            } else {
152                Bytes::from("")
153            }
154        } else {
155            response.bytes().await?
156        };
157        Ok(ResponseData::new(body_vec, status_code, response_headers))
158    }
159
160    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin + ?Sized>(
161        &self,
162        writer: &mut T,
163    ) -> Result<u16, S3Error> {
164        use tokio::io::AsyncWriteExt;
165        let response = retry! {self.response().await}?;
166
167        let status_code = response.status();
168        let mut stream = response.bytes_stream();
169
170        while let Some(item) = stream.next().await {
171            writer.write_all(&item?).await?;
172        }
173
174        Ok(status_code.as_u16())
175    }
176
177    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error> {
178        let response = retry! {self.response().await}?;
179        let status_code = response.status();
180        let stream = response.bytes_stream().map_err(S3Error::Reqwest);
181
182        Ok(ResponseDataStream {
183            bytes: Box::pin(stream),
184            status_code: status_code.as_u16(),
185        })
186    }
187
188    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error> {
189        let response = retry! {self.response().await}?;
190        let status_code = response.status().as_u16();
191        let headers = response.headers().clone();
192        Ok((headers, status_code))
193    }
194
195    fn datetime(&self) -> OffsetDateTime {
196        self.datetime
197    }
198
199    fn bucket(&self) -> Bucket {
200        self.bucket.clone()
201    }
202
203    fn command(&self) -> Command<'_> {
204        self.command.clone()
205    }
206
207    fn path(&self) -> String {
208        self.path.to_string()
209    }
210}
211
212impl<'a> ReqwestRequest<'a> {
213    pub async fn new(
214        bucket: &'a Bucket,
215        path: &'a str,
216        command: Command<'a>,
217    ) -> Result<ReqwestRequest<'a>, S3Error> {
218        bucket.credentials_refresh().await?;
219        Ok(Self {
220            bucket,
221            path,
222            command,
223            datetime: now_utc(),
224            sync: false,
225        })
226    }
227}
228
#[cfg(test)]
mod tests {
    use crate::bucket::Bucket;
    use crate::command::Command;
    use crate::request::Request;
    use crate::request::tokio_backend::ReqwestRequest;
    use awscreds::Credentials;
    use http::header::{HOST, RANGE};

    // Fake keys - otherwise using Credentials::default will use actual user
    // credentials if they exist.
    fn fake_credentials() -> Credentials {
        let access_key = "AKIAIOSFODNN7EXAMPLE";
        let secret_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY";
        Credentials::new(Some(access_key), Some(secret_key), None, None, None).unwrap()
    }

    // Builds a GetObject request for `path` and checks the URL scheme and the
    // signed Host header.
    async fn assert_scheme_and_host(bucket: &Bucket, path: &str, scheme: &str, host: &str) {
        let request = ReqwestRequest::new(bucket, path, Command::GetObject)
            .await
            .unwrap();

        assert_eq!(request.url().unwrap().scheme(), scheme);

        let headers = request.headers().await.unwrap();
        let host_header = headers.get(HOST).unwrap();
        assert_eq!(host_header, host);
    }

    // Builds a request for `command` and checks the signed Range header.
    async fn assert_range_header(bucket: &Bucket, path: &str, command: Command<'_>, expected: &str) {
        let request = ReqwestRequest::new(bucket, path, command).await.unwrap();
        let headers = request.headers().await.unwrap();
        let range = headers.get(RANGE).unwrap();
        assert_eq!(range, expected);
    }

    #[tokio::test]
    async fn url_uses_https_by_default() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials()).unwrap();

        assert_scheme_and_host(
            &bucket,
            "/my-first/path",
            "https",
            "my-first-bucket.custom-region",
        )
        .await;
    }

    #[tokio::test]
    async fn url_uses_https_by_default_path_style() {
        let region = "custom-region".parse().unwrap();
        let bucket = Bucket::new("my-first-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();

        assert_scheme_and_host(&bucket, "/my-first/path", "https", "custom-region").await;
    }

    #[tokio::test]
    async fn url_uses_scheme_from_custom_region_if_defined() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials()).unwrap();

        assert_scheme_and_host(
            &bucket,
            "/my-second/path",
            "http",
            "my-second-bucket.custom-region",
        )
        .await;
    }

    #[tokio::test]
    async fn url_uses_scheme_from_custom_region_if_defined_with_path_style() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();

        assert_scheme_and_host(&bucket, "/my-second/path", "http", "custom-region").await;
    }

    #[tokio::test]
    async fn test_get_object_range_header() {
        let region = "http://custom-region".parse().unwrap();
        let bucket = Bucket::new("my-second-bucket", region, fake_credentials())
            .unwrap()
            .with_path_style();
        let path = "/my-second/path";

        // Open-ended range.
        assert_range_header(
            &bucket,
            path,
            Command::GetObjectRange {
                start: 0,
                end: None,
            },
            "bytes=0-",
        )
        .await;

        // Range with an explicit end.
        assert_range_header(
            &bucket,
            path,
            Command::GetObjectRange {
                start: 0,
                end: Some(1),
            },
            "bytes=0-1",
        )
        .await;
    }
}