s3/utils/
mod.rs

1mod time_utils;
2
3pub use time_utils::*;
4
5use std::str::FromStr;
6
7use crate::error::S3Error;
8use crate::request::ResponseData;
9use crate::{bucket::CHUNK_SIZE, serde_types::HeadObjectResult};
10
11use std::fs::File;
12
13use std::io::Read;
14use std::path::Path;
15
16#[cfg(feature = "with-tokio")]
17use tokio::io::{AsyncRead, AsyncReadExt};
18
19#[cfg(feature = "with-async-std")]
20use async_std::io::Read as AsyncRead;
21
22#[cfg(feature = "with-async-std")]
23use async_std::io::ReadExt as AsyncReadExt;
24
/// Outcome of a streaming upload: the HTTP status code together with the
/// number of bytes that were uploaded.
#[derive(Debug)]
pub struct PutStreamResponse {
    /// HTTP status code associated with the upload.
    status_code: u16,
    /// Number of bytes uploaded.
    uploaded_bytes: usize,
}

impl PutStreamResponse {
    /// Builds a response from a status code and an uploaded byte count.
    pub fn new(status_code: u16, uploaded_bytes: usize) -> Self {
        PutStreamResponse {
            status_code,
            uploaded_bytes,
        }
    }

    /// Returns the stored HTTP status code.
    pub fn status_code(&self) -> u16 {
        let PutStreamResponse { status_code, .. } = self;
        *status_code
    }

    /// Returns the stored number of uploaded bytes.
    pub fn uploaded_bytes(&self) -> usize {
        let PutStreamResponse { uploaded_bytes, .. } = self;
        *uploaded_bytes
    }
}
46
47/// # Example
48/// ```rust,no_run
49/// use s3::utils::etag_for_path;
50///
51/// let path = "test_etag";
52/// let etag = etag_for_path(path).unwrap();
53/// println!("{}", etag);
54/// ```
55pub fn etag_for_path(path: impl AsRef<Path>) -> Result<String, S3Error> {
56    let mut file = File::open(path)?;
57    let mut last_digest: [u8; 16];
58    let mut digests = Vec::new();
59    let mut chunks = 0;
60    loop {
61        let chunk = read_chunk(&mut file)?;
62        last_digest = md5::compute(&chunk).into();
63        digests.extend_from_slice(&last_digest);
64        chunks += 1;
65        if chunk.len() < CHUNK_SIZE {
66            break;
67        }
68    }
69    let etag = if chunks <= 1 {
70        format!("{:x}", md5::Digest(last_digest))
71    } else {
72        let digest = format!("{:x}", md5::compute(digests));
73        format!("{}-{}", digest, chunks)
74    };
75    Ok(etag)
76}
77
78pub fn read_chunk<R: Read + ?Sized>(reader: &mut R) -> Result<Vec<u8>, S3Error> {
79    let mut chunk = Vec::with_capacity(CHUNK_SIZE);
80    let mut take = reader.take(CHUNK_SIZE as u64);
81    take.read_to_end(&mut chunk)?;
82
83    Ok(chunk)
84}
85
/// Asynchronously reads at most `CHUNK_SIZE` bytes from `reader` into a
/// freshly allocated buffer and returns it.
///
/// Reading continues until either `CHUNK_SIZE` bytes have been collected or
/// the reader reports end of input, so a result shorter than `CHUNK_SIZE`
/// means the reader was exhausted.
///
/// # Errors
///
/// Returns an error if the underlying read fails.
#[cfg(any(feature = "with-tokio", feature = "with-async-std"))]
pub async fn read_chunk_async<R: AsyncRead + Unpin + ?Sized>(
    reader: &mut R,
) -> Result<Vec<u8>, S3Error> {
    let mut buffer = Vec::with_capacity(CHUNK_SIZE);
    reader
        .take(CHUNK_SIZE as u64)
        .read_to_end(&mut buffer)
        .await?;
    Ok(buffer)
}
96
/// Convenience lookups on a header collection.
pub trait GetAndConvertHeaders {
    /// Looks up `header` and parses its value into `T`.
    ///
    /// Returns `None` when no usable value can be produced.
    fn get_and_convert<T>(&self, header: &str) -> Option<T>
    where
        T: FromStr;

    /// Looks up `header` and returns its value as an owned `String`.
    ///
    /// Returns `None` when no usable value can be produced.
    fn get_string(&self, header: &str) -> Option<String>;
}
101
102impl GetAndConvertHeaders for http::header::HeaderMap {
103    fn get_and_convert<T: FromStr>(&self, header: &str) -> Option<T> {
104        self.get(header)?.to_str().ok()?.parse::<T>().ok()
105    }
106    fn get_string(&self, header: &str) -> Option<String> {
107        Some(self.get(header)?.to_str().ok()?.to_owned())
108    }
109}
110
#[cfg(feature = "with-async-std")]
impl From<&http::HeaderMap> for HeadObjectResult {
    /// Builds a `HeadObjectResult` from response headers.
    ///
    /// Every field is read through `GetAndConvertHeaders`, so a header that is
    /// missing, not representable as a string, or (for converted fields) not
    /// parseable leaves the corresponding field as `None`. `metadata` is
    /// always `Some` and holds each `x-amz-meta-*` header with the prefix
    /// removed from its name.
    fn from(headers: &http::HeaderMap) -> Self {
        // User metadata: keep only `x-amz-meta-` headers whose value is a
        // valid string, keyed by the part of the name after the prefix.
        let metadata: ::std::collections::HashMap<String, String> = headers
            .iter()
            .filter_map(|(name, value)| {
                let suffix = name.as_str().strip_prefix("x-amz-meta-")?;
                let value = value.to_str().ok()?;
                Some((suffix.to_owned(), value.to_owned()))
            })
            .collect();

        HeadObjectResult {
            accept_ranges: headers.get_string("accept-ranges"),
            cache_control: headers.get_string("Cache-Control"),
            content_disposition: headers.get_string("Content-Disposition"),
            content_encoding: headers.get_string("Content-Encoding"),
            content_language: headers.get_string("Content-Language"),
            content_length: headers.get_and_convert("Content-Length"),
            content_type: headers.get_string("Content-Type"),
            delete_marker: headers.get_and_convert("x-amz-delete-marker"),
            e_tag: headers.get_string("ETag"),
            expiration: headers.get_string("x-amz-expiration"),
            expires: headers.get_string("Expires"),
            last_modified: headers.get_string("Last-Modified"),
            metadata: Some(metadata),
            missing_meta: headers.get_and_convert("x-amz-missing-meta"),
            object_lock_legal_hold_status: headers.get_string("x-amz-object-lock-legal-hold"),
            object_lock_mode: headers.get_string("x-amz-object-lock-mode"),
            object_lock_retain_until_date: headers
                .get_string("x-amz-object-lock-retain-until-date"),
            parts_count: headers.get_and_convert("x-amz-mp-parts-count"),
            replication_status: headers.get_string("x-amz-replication-status"),
            request_charged: headers.get_string("x-amz-request-charged"),
            restore: headers.get_string("x-amz-restore"),
            sse_customer_algorithm: headers
                .get_string("x-amz-server-side-encryption-customer-algorithm"),
            sse_customer_key_md5: headers
                .get_string("x-amz-server-side-encryption-customer-key-MD5"),
            ssekms_key_id: headers.get_string("x-amz-server-side-encryption-aws-kms-key-id"),
            server_side_encryption: headers.get_string("x-amz-server-side-encryption"),
            storage_class: headers.get_string("x-amz-storage-class"),
            version_id: headers.get_string("x-amz-version-id"),
            website_redirect_location: headers.get_string("x-amz-website-redirect-location"),
            ..Default::default()
        }
    }
}
162
#[cfg(feature = "with-tokio")]
impl From<&reqwest::header::HeaderMap> for HeadObjectResult {
    /// Builds a `HeadObjectResult` from response headers.
    ///
    /// A field is `None` only when its header is absent. When the header is
    /// present but `to_str()` fails, string fields become `Some("")`; when the
    /// value cannot be parsed, converted fields become `Some` of the target
    /// type's default. `metadata` is always `Some` and holds each
    /// `x-amz-meta-*` header with the prefix removed from its name.
    fn from(headers: &reqwest::header::HeaderMap) -> Self {
        /// Header value parsed into `T`, falling back to `T::default()` when
        /// the value is not a valid string or does not parse.
        fn parsed<T>(headers: &reqwest::header::HeaderMap, name: &str) -> Option<T>
        where
            T: std::str::FromStr + Default,
        {
            headers
                .get(name)
                .map(|value| value.to_str().unwrap_or_default().parse().unwrap_or_default())
        }

        // Header value as an owned string, falling back to "" when the value
        // is not a valid string.
        let text = |name: &str| -> Option<String> {
            headers
                .get(name)
                .map(|value| value.to_str().unwrap_or_default().to_string())
        };

        // User metadata: keep only `x-amz-meta-` headers whose value is a
        // valid string, keyed by the part of the name after the prefix.
        let metadata: ::std::collections::HashMap<String, String> = headers
            .iter()
            .filter_map(|(name, value)| {
                let suffix = name.as_str().strip_prefix("x-amz-meta-")?;
                let value = value.to_str().ok()?;
                Some((suffix.to_owned(), value.to_owned()))
            })
            .collect();

        HeadObjectResult {
            accept_ranges: text("accept-ranges"),
            cache_control: text("Cache-Control"),
            content_disposition: text("Content-Disposition"),
            content_encoding: text("Content-Encoding"),
            content_language: text("Content-Language"),
            content_length: parsed(headers, "Content-Length"),
            content_type: text("Content-Type"),
            delete_marker: parsed(headers, "x-amz-delete-marker"),
            e_tag: text("ETag"),
            expiration: text("x-amz-expiration"),
            expires: text("Expires"),
            last_modified: text("Last-Modified"),
            metadata: Some(metadata),
            missing_meta: parsed(headers, "x-amz-missing-meta"),
            object_lock_legal_hold_status: text("x-amz-object-lock-legal-hold"),
            object_lock_mode: text("x-amz-object-lock-mode"),
            object_lock_retain_until_date: text("x-amz-object-lock-retain-until-date"),
            parts_count: parsed(headers, "x-amz-mp-parts-count"),
            replication_status: text("x-amz-replication-status"),
            request_charged: text("x-amz-request-charged"),
            restore: text("x-amz-restore"),
            sse_customer_algorithm: text("x-amz-server-side-encryption-customer-algorithm"),
            sse_customer_key_md5: text("x-amz-server-side-encryption-customer-key-MD5"),
            ssekms_key_id: text("x-amz-server-side-encryption-aws-kms-key-id"),
            server_side_encryption: text("x-amz-server-side-encryption"),
            storage_class: text("x-amz-storage-class"),
            version_id: text("x-amz-version-id"),
            website_redirect_location: text("x-amz-website-redirect-location"),
            ..Default::default()
        }
    }
}
265
#[cfg(feature = "sync")]
impl From<&attohttpc::header::HeaderMap> for HeadObjectResult {
    /// Builds a `HeadObjectResult` from response headers.
    ///
    /// A field is `None` only when its header is absent. When the header is
    /// present but `to_str()` fails, string fields become `Some("")`; when the
    /// value cannot be parsed, converted fields become `Some` of the target
    /// type's default. `metadata` is always `Some` and holds each
    /// `x-amz-meta-*` header with the prefix removed from its name.
    fn from(headers: &attohttpc::header::HeaderMap) -> Self {
        /// Header value parsed into `T`, falling back to `T::default()` when
        /// the value is not a valid string or does not parse.
        fn parsed<T>(headers: &attohttpc::header::HeaderMap, name: &str) -> Option<T>
        where
            T: std::str::FromStr + Default,
        {
            headers
                .get(name)
                .map(|value| value.to_str().unwrap_or_default().parse().unwrap_or_default())
        }

        // Header value as an owned string, falling back to "" when the value
        // is not a valid string.
        let text = |name: &str| -> Option<String> {
            headers
                .get(name)
                .map(|value| value.to_str().unwrap_or_default().to_string())
        };

        // User metadata: keep only `x-amz-meta-` headers whose value is a
        // valid string, keyed by the part of the name after the prefix.
        let metadata: ::std::collections::HashMap<String, String> = headers
            .iter()
            .filter_map(|(name, value)| {
                let suffix = name.as_str().strip_prefix("x-amz-meta-")?;
                let value = value.to_str().ok()?;
                Some((suffix.to_owned(), value.to_owned()))
            })
            .collect();

        HeadObjectResult {
            accept_ranges: text("accept-ranges"),
            cache_control: text("Cache-Control"),
            content_disposition: text("Content-Disposition"),
            content_encoding: text("Content-Encoding"),
            content_language: text("Content-Language"),
            content_length: parsed(headers, "Content-Length"),
            // Must be "Content-Type": any other header name leaves this field
            // permanently `None` for the sync backend.
            content_type: text("Content-Type"),
            delete_marker: parsed(headers, "x-amz-delete-marker"),
            e_tag: text("ETag"),
            expiration: text("x-amz-expiration"),
            expires: text("Expires"),
            last_modified: text("Last-Modified"),
            metadata: Some(metadata),
            missing_meta: parsed(headers, "x-amz-missing-meta"),
            object_lock_legal_hold_status: text("x-amz-object-lock-legal-hold"),
            object_lock_mode: text("x-amz-object-lock-mode"),
            object_lock_retain_until_date: text("x-amz-object-lock-retain-until-date"),
            parts_count: parsed(headers, "x-amz-mp-parts-count"),
            replication_status: text("x-amz-replication-status"),
            request_charged: text("x-amz-request-charged"),
            restore: text("x-amz-restore"),
            sse_customer_algorithm: text("x-amz-server-side-encryption-customer-algorithm"),
            sse_customer_key_md5: text("x-amz-server-side-encryption-customer-key-MD5"),
            ssekms_key_id: text("x-amz-server-side-encryption-aws-kms-key-id"),
            server_side_encryption: text("x-amz-server-side-encryption"),
            storage_class: text("x-amz-storage-class"),
            version_id: text("x-amz-version-id"),
            website_redirect_location: text("x-amz-website-redirect-location"),
            ..Default::default()
        }
    }
}
368
369pub(crate) fn error_from_response_data(response_data: ResponseData) -> Result<S3Error, S3Error> {
370    let utf8_content = String::from_utf8(response_data.as_slice().to_vec())?;
371    Err(S3Error::HttpFailWithBody(
372        response_data.status_code(),
373        utf8_content,
374    ))
375}
376
/// Evaluates an expression, retrying with a growing delay while it fails.
///
/// The expression is evaluated once; each time it yields `Err`, a warning is
/// logged and, as long as fewer than `$crate::get_retries()` retries have been
/// made, the macro sleeps and evaluates the expression again. The first `Ok`
/// value ends the loop and becomes the result of the macro.
///
/// The delay before retry number `k` (counting from 1) is `k.pow(2)` seconds,
/// i.e. 1s, 4s, 9s, ...
///
/// This macro supports both asynchronous and synchronous contexts:
/// - For `tokio` users, it uses `tokio::time::sleep`.
/// - For `async-std` users, it uses `async_std::task::sleep`.
/// - For synchronous contexts, it uses `std::thread::sleep`.
///
/// # Features
///
/// - `with-tokio`: Uses `tokio::time::sleep` for async retries.
/// - `with-async-std`: Uses `async_std::task::sleep` for async retries.
/// - `sync`: Uses `std::thread::sleep` for sync retries.
///
/// # Errors
///
/// If all attempts fail, the error from the last attempt is returned.
#[macro_export]
macro_rules! retry {
    ($e:expr) => {{
        let retry_limit = $crate::get_retries() as u64;
        let mut retries_done: u64 = 0;

        loop {
            // `$e` is re-evaluated on every pass through the loop.
            match $e {
                Ok(value) => break Ok(value),
                Err(e) => {
                    log::warn!("Retrying {e}");
                    if retries_done >= retry_limit {
                        break Err(e);
                    }
                    retries_done += 1;
                    let pause = std::time::Duration::from_secs(retries_done.pow(2));
                    #[cfg(feature = "with-tokio")]
                    tokio::time::sleep(pause).await;
                    #[cfg(feature = "with-async-std")]
                    async_std::task::sleep(pause).await;
                    #[cfg(feature = "sync")]
                    std::thread::sleep(pause);
                }
            }
        }
    }};
}
429
#[cfg(test)]
mod test {
    use crate::utils::etag_for_path;
    use std::fs::File;
    use std::io::Cursor;
    use std::io::prelude::*;
    use std::path::PathBuf;

    /// Returns `size` bytes, all set to 33.
    fn object(size: u32) -> Vec<u8> {
        (0..size).map(|_| 33).collect()
    }

    /// Scratch file location that is distinct per `label` and per process, so
    /// tests running in parallel never write to or delete each other's file.
    fn scratch_path(label: &str) -> PathBuf {
        std::env::temp_dir().join(format!("s3_utils_{}_{}", label, std::process::id()))
    }

    /// Writes an object of `size` bytes to a scratch file, computes its ETag
    /// and removes the file again before reporting the result.
    fn etag_of_object(label: &str, size: u32) -> String {
        let path = scratch_path(label);
        std::fs::remove_file(&path).unwrap_or(());

        {
            // Scoped so the handle is closed before the file is hashed.
            let mut file = File::create(&path).unwrap();
            file.write_all(&object(size)).unwrap();
        }

        let etag = etag_for_path(&path);

        // Clean up before unwrapping so a failure does not leave the file behind.
        std::fs::remove_file(&path).unwrap_or(());

        etag.unwrap()
    }

    #[test]
    fn test_etag_large_file() {
        let etag = etag_of_object("test_etag_large_file", 10_000_000);

        assert_eq!(etag, "e438487f09f09c042b2de097765e5ac2-2");
    }

    #[test]
    fn test_etag_small_file() {
        let etag = etag_of_object("test_etag_small_file", 1000);

        assert_eq!(etag, "8122ef1c2b2331f7986349560248cf56");
    }

    #[test]
    fn test_read_chunk_all_zero() {
        let blob = vec![0u8; 10_000_000];
        let mut blob = Cursor::new(blob);

        let result = super::read_chunk(&mut blob).unwrap();

        assert_eq!(result.len(), crate::bucket::CHUNK_SIZE);
    }

    #[test]
    fn test_read_chunk_multi_chunk() {
        let blob = vec![1u8; 10_000_000];
        let mut blob = Cursor::new(blob);

        let result = super::read_chunk(&mut blob).unwrap();
        assert_eq!(result.len(), crate::bucket::CHUNK_SIZE);

        let result = super::read_chunk(&mut blob).unwrap();
        assert_eq!(result.len(), 1_611_392);
    }
}
494}