// s3/request/request_trait.rs

1use base64::engine::general_purpose;
2use base64::Engine;
3use hmac::Mac;
4use std::collections::HashMap;
5use time::format_description::well_known::Rfc2822;
6use time::OffsetDateTime;
7use url::Url;
8
9use crate::bucket::Bucket;
10use crate::command::Command;
11use crate::error::S3Error;
12use crate::signing;
13use crate::LONG_DATETIME;
14use bytes::Bytes;
15use http::header::{
16    HeaderName, ACCEPT, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, DATE, HOST, RANGE,
17};
18use http::HeaderMap;
19use std::fmt::Write as _;
20
21#[derive(Debug)]
22
23pub struct ResponseData {
24    bytes: Bytes,
25    status_code: u16,
26    headers: HashMap<String, String>,
27}
28
29impl From<ResponseData> for Vec<u8> {
30    fn from(data: ResponseData) -> Vec<u8> {
31        data.to_vec()
32    }
33}
34
35impl ResponseData {
36    pub fn new(bytes: Bytes, status_code: u16, headers: HashMap<String, String>) -> ResponseData {
37        ResponseData {
38            bytes,
39            status_code,
40            headers,
41        }
42    }
43
44    pub fn as_slice(&self) -> &[u8] {
45        &self.bytes
46    }
47
48    pub fn to_vec(self) -> Vec<u8> {
49        self.bytes.to_vec()
50    }
51
52    pub fn bytes(&self) -> &Bytes {
53        &self.bytes
54    }
55
56    pub fn status_code(&self) -> u16 {
57        self.status_code
58    }
59
60    pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> {
61        std::str::from_utf8(self.as_slice())
62    }
63
64    pub fn to_string(&self) -> Result<String, std::str::Utf8Error> {
65        std::str::from_utf8(self.as_slice()).map(|s| s.to_string())
66    }
67
68    pub fn headers(&self) -> HashMap<String, String> {
69        self.headers.clone()
70    }
71}
72
73use std::fmt;
74
75impl fmt::Display for ResponseData {
76    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77        write!(
78            f,
79            "Status code: {}\n Data: {}",
80            self.status_code(),
81            self.to_string()
82                .unwrap_or_else(|_| "Data could not be cast to UTF string".to_string())
83        )
84    }
85}
86
87use std::pin::Pin;
88
89pub type DataStream = Pin<Box<dyn futures::Stream<Item = StreamItem> + Send>>;
90pub type StreamItem = Result<bytes::Bytes, crate::error::S3Error>;
91
92pub struct ResponseDataStream {
93    pub bytes: DataStream,
94    pub status_code: u16,
95}
96
97impl ResponseDataStream {
98    pub fn bytes(&mut self) -> &mut DataStream {
99        &mut self.bytes
100    }
101}
102
103#[async_trait::async_trait]
104pub trait Request {
105    type Response;
106    type HeaderMap;
107
108    async fn response(&self) -> Result<Self::Response, S3Error>;
109    async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error>;
110    async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin>(
111        &self,
112        writer: &mut T,
113    ) -> Result<u16, S3Error>;
114    async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error>;
115    async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>;
116    fn datetime(&self) -> OffsetDateTime;
117    fn bucket(&self) -> Bucket;
118    fn command(&self) -> Command;
119    fn path(&self) -> String;
120
121    fn signing_key(&self) -> Result<Vec<u8>, S3Error> {
122        signing::signing_key(
123            &self.datetime(),
124            &self
125                .bucket()
126                .secret_key()?
127                .expect("Secret key must be provided to sign headers, found None"),
128            &self.bucket().region(),
129            "s3",
130        )
131    }
132
133    fn request_body(&self) -> Vec<u8> {
134        match self.command() {
135            Command::PutObject { content, .. } => Vec::from(content),
136            Command::PutObjectTagging { tags } => Vec::from(tags),
137            Command::UploadPart { content, .. } => Vec::from(content),
138            Command::CompleteMultipartUpload { data, .. } => data.to_string().as_bytes().to_vec(),
139            Command::CreateBucket { config } => config
140                .location_constraint_payload()
141                .map(Vec::from)
142                .unwrap_or_default(),
143            _ => vec![],
144        }
145    }
146
147    fn long_date(&self) -> Result<String, S3Error> {
148        Ok(self.datetime().format(LONG_DATETIME)?)
149    }
150
151    fn string_to_sign(&self, request: &str) -> Result<String, S3Error> {
152        match self.command() {
153            Command::PresignPost { post_policy, .. } => Ok(post_policy),
154            _ => Ok(signing::string_to_sign(
155                &self.datetime(),
156                &self.bucket().region(),
157                request,
158            )?),
159        }
160    }
161
162    fn host_header(&self) -> String {
163        self.bucket().host()
164    }
165
166    fn presigned(&self) -> Result<String, S3Error> {
167        let (expiry, custom_headers, custom_queries) = match self.command() {
168            Command::PresignGet {
169                expiry_secs,
170                custom_queries,
171            } => (expiry_secs, None, custom_queries),
172            Command::PresignPut {
173                expiry_secs,
174                custom_headers,
175            } => (expiry_secs, custom_headers, None),
176            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
177            _ => unreachable!(),
178        };
179
180        Ok(format!(
181            "{}&X-Amz-Signature={}",
182            self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?,
183            self.presigned_authorization(custom_headers.as_ref())?
184        ))
185    }
186
187    fn presigned_authorization(
188        &self,
189        custom_headers: Option<&HeaderMap>,
190    ) -> Result<String, S3Error> {
191        let mut headers = HeaderMap::new();
192        let host_header = self.host_header();
193        headers.insert(HOST, host_header.parse()?);
194        if let Some(custom_headers) = custom_headers {
195            for (k, v) in custom_headers.iter() {
196                headers.insert(k.clone(), v.clone());
197            }
198        }
199        let canonical_request = self.presigned_canonical_request(&headers)?;
200        let string_to_sign = self.string_to_sign(&canonical_request)?;
201        let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key()?)?;
202        hmac.update(string_to_sign.as_bytes());
203        let signature = hex::encode(hmac.finalize().into_bytes());
204        // let signed_header = signing::signed_header_string(&headers);
205        Ok(signature)
206    }
207
208    fn presigned_canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
209        let (expiry, custom_headers, custom_queries) = match self.command() {
210            Command::PresignGet {
211                expiry_secs,
212                custom_queries,
213            } => (expiry_secs, None, custom_queries),
214            Command::PresignPut {
215                expiry_secs,
216                custom_headers,
217            } => (expiry_secs, custom_headers, None),
218            Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
219            _ => unreachable!(),
220        };
221
222        signing::canonical_request(
223            &self.command().http_verb().to_string(),
224            &self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?,
225            headers,
226            "UNSIGNED-PAYLOAD",
227        )
228    }
229
230    fn presigned_url_no_sig(
231        &self,
232        expiry: u32,
233        custom_headers: Option<&HeaderMap>,
234        custom_queries: Option<&HashMap<String, String>>,
235    ) -> Result<Url, S3Error> {
236        let bucket = self.bucket();
237        let token = if let Some(security_token) = bucket.security_token()? {
238            Some(security_token)
239        } else {
240            bucket.session_token()?
241        };
242        let url = Url::parse(&format!(
243            "{}{}{}",
244            self.url()?,
245            &signing::authorization_query_params_no_sig(
246                &self.bucket().access_key()?.unwrap_or_default(),
247                &self.datetime(),
248                &self.bucket().region(),
249                expiry,
250                custom_headers,
251                token.as_ref()
252            )?,
253            &signing::flatten_queries(custom_queries)?,
254        ))?;
255
256        Ok(url)
257    }
258
259    fn url(&self) -> Result<Url, S3Error> {
260        let mut url_str = self.bucket().url();
261
262        if let Command::ListBuckets { .. } = self.command() {
263            return Ok(Url::parse(&url_str)?);
264        }
265
266        if let Command::CreateBucket { .. } = self.command() {
267            return Ok(Url::parse(&url_str)?);
268        }
269
270        let path = if self.path().starts_with('/') {
271            self.path()[1..].to_string()
272        } else {
273            self.path()[..].to_string()
274        };
275
276        url_str.push('/');
277        url_str.push_str(&signing::uri_encode(&path, false));
278
279        // Append to url_path
280        #[allow(clippy::collapsible_match)]
281        match self.command() {
282            Command::InitiateMultipartUpload { .. } | Command::ListMultipartUploads { .. } => {
283                url_str.push_str("?uploads")
284            }
285            Command::AbortMultipartUpload { upload_id } => {
286                write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
287            }
288            Command::CompleteMultipartUpload { upload_id, .. } => {
289                write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
290            }
291            Command::GetObjectTorrent => url_str.push_str("?torrent"),
292            Command::PutObject { multipart, .. } => {
293                if let Some(multipart) = multipart {
294                    url_str.push_str(&multipart.query_string())
295                }
296            }
297            _ => {}
298        }
299
300        let mut url = Url::parse(&url_str)?;
301
302        for (key, value) in &self.bucket().extra_query {
303            url.query_pairs_mut().append_pair(key, value);
304        }
305
306        if let Command::ListObjectsV2 {
307            prefix,
308            delimiter,
309            continuation_token,
310            start_after,
311            max_keys,
312        } = self.command().clone()
313        {
314            let mut query_pairs = url.query_pairs_mut();
315            delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
316
317            query_pairs.append_pair("prefix", &prefix);
318            query_pairs.append_pair("list-type", "2");
319            if let Some(token) = continuation_token {
320                query_pairs.append_pair("continuation-token", &token);
321            }
322            if let Some(start_after) = start_after {
323                query_pairs.append_pair("start-after", &start_after);
324            }
325            if let Some(max_keys) = max_keys {
326                query_pairs.append_pair("max-keys", &max_keys.to_string());
327            }
328        }
329
330        if let Command::ListObjects {
331            prefix,
332            delimiter,
333            marker,
334            max_keys,
335        } = self.command().clone()
336        {
337            let mut query_pairs = url.query_pairs_mut();
338            delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
339
340            query_pairs.append_pair("prefix", &prefix);
341            if let Some(marker) = marker {
342                query_pairs.append_pair("marker", &marker);
343            }
344            if let Some(max_keys) = max_keys {
345                query_pairs.append_pair("max-keys", &max_keys.to_string());
346            }
347        }
348
349        match self.command() {
350            Command::ListMultipartUploads {
351                prefix,
352                delimiter,
353                key_marker,
354                max_uploads,
355            } => {
356                let mut query_pairs = url.query_pairs_mut();
357                delimiter.map(|d| query_pairs.append_pair("delimiter", d));
358                if let Some(prefix) = prefix {
359                    query_pairs.append_pair("prefix", prefix);
360                }
361                if let Some(key_marker) = key_marker {
362                    query_pairs.append_pair("key-marker", &key_marker);
363                }
364                if let Some(max_uploads) = max_uploads {
365                    query_pairs.append_pair("max-uploads", max_uploads.to_string().as_str());
366                }
367            }
368            Command::PutObjectTagging { .. }
369            | Command::GetObjectTagging
370            | Command::DeleteObjectTagging => {
371                url.query_pairs_mut().append_pair("tagging", "");
372            }
373            _ => {}
374        }
375
376        Ok(url)
377    }
378
379    fn canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
380        signing::canonical_request(
381            &self.command().http_verb().to_string(),
382            &self.url()?,
383            headers,
384            &self.command().sha256(),
385        )
386    }
387
388    fn authorization(&self, headers: &HeaderMap) -> Result<String, S3Error> {
389        let canonical_request = self.canonical_request(headers)?;
390        let string_to_sign = self.string_to_sign(&canonical_request)?;
391        let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key()?)?;
392        hmac.update(string_to_sign.as_bytes());
393        let signature = hex::encode(hmac.finalize().into_bytes());
394        let signed_header = signing::signed_header_string(headers);
395        signing::authorization_header(
396            &self.bucket().access_key()?.expect("No access_key provided"),
397            &self.datetime(),
398            &self.bucket().region(),
399            &signed_header,
400            &signature,
401        )
402    }
403
404    fn headers(&self) -> Result<HeaderMap, S3Error> {
405        // Generate this once, but it's used in more than one place.
406        let sha256 = self.command().sha256();
407
408        // Start with extra_headers, that way our headers replace anything with
409        // the same name.
410
411        let mut headers = HeaderMap::new();
412
413        for (k, v) in self.bucket().extra_headers.iter() {
414            headers.insert(k.clone(), v.clone());
415        }
416
417        let host_header = self.host_header();
418
419        headers.insert(HOST, host_header.parse()?);
420
421        match self.command() {
422            Command::CopyObject { from } => {
423                headers.insert(HeaderName::from_static("x-amz-copy-source"), from.parse()?);
424            }
425            Command::ListObjects { .. } => {}
426            Command::ListObjectsV2 { .. } => {}
427            Command::GetObject => {}
428            Command::GetObjectTagging => {}
429            Command::GetBucketLocation => {}
430            _ => {
431                headers.insert(
432                    CONTENT_LENGTH,
433                    self.command().content_length().to_string().parse()?,
434                );
435                headers.insert(CONTENT_TYPE, self.command().content_type().parse()?);
436            }
437        }
438        headers.insert(
439            HeaderName::from_static("x-amz-content-sha256"),
440            sha256.parse()?,
441        );
442        headers.insert(
443            HeaderName::from_static("x-amz-date"),
444            self.long_date()?.parse()?,
445        );
446
447        if let Some(session_token) = self.bucket().session_token()? {
448            headers.insert(
449                HeaderName::from_static("x-amz-security-token"),
450                session_token.parse()?,
451            );
452        } else if let Some(security_token) = self.bucket().security_token()? {
453            headers.insert(
454                HeaderName::from_static("x-amz-security-token"),
455                security_token.parse()?,
456            );
457        }
458
459        if let Command::PutObjectTagging { tags } = self.command() {
460            let digest = md5::compute(tags);
461            let hash = general_purpose::STANDARD.encode(digest.as_ref());
462            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
463        } else if let Command::PutObject { content, .. } = self.command() {
464            let digest = md5::compute(content);
465            let hash = general_purpose::STANDARD.encode(digest.as_ref());
466            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
467        } else if let Command::UploadPart { content, .. } = self.command() {
468            let digest = md5::compute(content);
469            let hash = general_purpose::STANDARD.encode(digest.as_ref());
470            headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
471        } else if let Command::GetObject {} = self.command() {
472            headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
473        // headers.insert(header::ACCEPT_CHARSET, HeaderValue::from_str("UTF-8")?);
474        } else if let Command::GetObjectRange { start, end } = self.command() {
475            headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
476
477            let mut range = format!("bytes={}-", start);
478
479            if let Some(end) = end {
480                range.push_str(&end.to_string());
481            }
482
483            headers.insert(RANGE, range.parse()?);
484        } else if let Command::CreateBucket { ref config } = self.command() {
485            config.add_headers(&mut headers)?;
486        }
487
488        // This must be last, as it signs the other headers, omitted if no secret key is provided
489        if self.bucket().secret_key()?.is_some() {
490            let authorization = self.authorization(&headers)?;
491            headers.insert(AUTHORIZATION, authorization.parse()?);
492        }
493
494        // The format of RFC2822 is somewhat malleable, so including it in
495        // signed headers can cause signature mismatches. We do include the
496        // X-Amz-Date header, so requests are still properly limited to a date
497        // range and can't be used again e.g. reply attacks. Adding this header
498        // after the generation of the Authorization header leaves it out of
499        // the signed headers.
500        headers.insert(DATE, self.datetime().format(&Rfc2822)?.parse()?);
501
502        Ok(headers)
503    }
504}