1use base64::engine::general_purpose;
2use base64::Engine;
3use hmac::Mac;
4use std::collections::HashMap;
5use time::format_description::well_known::Rfc2822;
6use time::OffsetDateTime;
7use url::Url;
8
9use crate::bucket::Bucket;
10use crate::command::Command;
11use crate::error::S3Error;
12use crate::signing;
13use crate::LONG_DATETIME;
14use bytes::Bytes;
15use http::header::{
16 HeaderName, ACCEPT, AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, DATE, HOST, RANGE,
17};
18use http::HeaderMap;
19use std::fmt::Write as _;
20
21#[derive(Debug)]
22
23pub struct ResponseData {
24 bytes: Bytes,
25 status_code: u16,
26 headers: HashMap<String, String>,
27}
28
29impl From<ResponseData> for Vec<u8> {
30 fn from(data: ResponseData) -> Vec<u8> {
31 data.to_vec()
32 }
33}
34
35impl ResponseData {
36 pub fn new(bytes: Bytes, status_code: u16, headers: HashMap<String, String>) -> ResponseData {
37 ResponseData {
38 bytes,
39 status_code,
40 headers,
41 }
42 }
43
44 pub fn as_slice(&self) -> &[u8] {
45 &self.bytes
46 }
47
48 pub fn to_vec(self) -> Vec<u8> {
49 self.bytes.to_vec()
50 }
51
52 pub fn bytes(&self) -> &Bytes {
53 &self.bytes
54 }
55
56 pub fn status_code(&self) -> u16 {
57 self.status_code
58 }
59
60 pub fn as_str(&self) -> Result<&str, std::str::Utf8Error> {
61 std::str::from_utf8(self.as_slice())
62 }
63
64 pub fn to_string(&self) -> Result<String, std::str::Utf8Error> {
65 std::str::from_utf8(self.as_slice()).map(|s| s.to_string())
66 }
67
68 pub fn headers(&self) -> HashMap<String, String> {
69 self.headers.clone()
70 }
71}
72
73use std::fmt;
74
75impl fmt::Display for ResponseData {
76 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
77 write!(
78 f,
79 "Status code: {}\n Data: {}",
80 self.status_code(),
81 self.to_string()
82 .unwrap_or_else(|_| "Data could not be cast to UTF string".to_string())
83 )
84 }
85}
86
87use std::pin::Pin;
88
89pub type DataStream = Pin<Box<dyn futures::Stream<Item = StreamItem> + Send>>;
90pub type StreamItem = Result<bytes::Bytes, crate::error::S3Error>;
91
92pub struct ResponseDataStream {
93 pub bytes: DataStream,
94 pub status_code: u16,
95}
96
97impl ResponseDataStream {
98 pub fn bytes(&mut self) -> &mut DataStream {
99 &mut self.bytes
100 }
101}
102
103#[async_trait::async_trait]
104pub trait Request {
105 type Response;
106 type HeaderMap;
107
108 async fn response(&self) -> Result<Self::Response, S3Error>;
109 async fn response_data(&self, etag: bool) -> Result<ResponseData, S3Error>;
110 async fn response_data_to_writer<T: tokio::io::AsyncWrite + Send + Unpin>(
111 &self,
112 writer: &mut T,
113 ) -> Result<u16, S3Error>;
114 async fn response_data_to_stream(&self) -> Result<ResponseDataStream, S3Error>;
115 async fn response_header(&self) -> Result<(Self::HeaderMap, u16), S3Error>;
116 fn datetime(&self) -> OffsetDateTime;
117 fn bucket(&self) -> Bucket;
118 fn command(&self) -> Command;
119 fn path(&self) -> String;
120
121 fn signing_key(&self) -> Result<Vec<u8>, S3Error> {
122 signing::signing_key(
123 &self.datetime(),
124 &self
125 .bucket()
126 .secret_key()?
127 .expect("Secret key must be provided to sign headers, found None"),
128 &self.bucket().region(),
129 "s3",
130 )
131 }
132
133 fn request_body(&self) -> Vec<u8> {
134 match self.command() {
135 Command::PutObject { content, .. } => Vec::from(content),
136 Command::PutObjectTagging { tags } => Vec::from(tags),
137 Command::UploadPart { content, .. } => Vec::from(content),
138 Command::CompleteMultipartUpload { data, .. } => data.to_string().as_bytes().to_vec(),
139 Command::CreateBucket { config } => config
140 .location_constraint_payload()
141 .map(Vec::from)
142 .unwrap_or_default(),
143 _ => vec![],
144 }
145 }
146
147 fn long_date(&self) -> Result<String, S3Error> {
148 Ok(self.datetime().format(LONG_DATETIME)?)
149 }
150
151 fn string_to_sign(&self, request: &str) -> Result<String, S3Error> {
152 match self.command() {
153 Command::PresignPost { post_policy, .. } => Ok(post_policy),
154 _ => Ok(signing::string_to_sign(
155 &self.datetime(),
156 &self.bucket().region(),
157 request,
158 )?),
159 }
160 }
161
162 fn host_header(&self) -> String {
163 self.bucket().host()
164 }
165
166 fn presigned(&self) -> Result<String, S3Error> {
167 let (expiry, custom_headers, custom_queries) = match self.command() {
168 Command::PresignGet {
169 expiry_secs,
170 custom_queries,
171 } => (expiry_secs, None, custom_queries),
172 Command::PresignPut {
173 expiry_secs,
174 custom_headers,
175 } => (expiry_secs, custom_headers, None),
176 Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
177 _ => unreachable!(),
178 };
179
180 Ok(format!(
181 "{}&X-Amz-Signature={}",
182 self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?,
183 self.presigned_authorization(custom_headers.as_ref())?
184 ))
185 }
186
187 fn presigned_authorization(
188 &self,
189 custom_headers: Option<&HeaderMap>,
190 ) -> Result<String, S3Error> {
191 let mut headers = HeaderMap::new();
192 let host_header = self.host_header();
193 headers.insert(HOST, host_header.parse()?);
194 if let Some(custom_headers) = custom_headers {
195 for (k, v) in custom_headers.iter() {
196 headers.insert(k.clone(), v.clone());
197 }
198 }
199 let canonical_request = self.presigned_canonical_request(&headers)?;
200 let string_to_sign = self.string_to_sign(&canonical_request)?;
201 let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key()?)?;
202 hmac.update(string_to_sign.as_bytes());
203 let signature = hex::encode(hmac.finalize().into_bytes());
204 Ok(signature)
206 }
207
208 fn presigned_canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
209 let (expiry, custom_headers, custom_queries) = match self.command() {
210 Command::PresignGet {
211 expiry_secs,
212 custom_queries,
213 } => (expiry_secs, None, custom_queries),
214 Command::PresignPut {
215 expiry_secs,
216 custom_headers,
217 } => (expiry_secs, custom_headers, None),
218 Command::PresignDelete { expiry_secs } => (expiry_secs, None, None),
219 _ => unreachable!(),
220 };
221
222 signing::canonical_request(
223 &self.command().http_verb().to_string(),
224 &self.presigned_url_no_sig(expiry, custom_headers.as_ref(), custom_queries.as_ref())?,
225 headers,
226 "UNSIGNED-PAYLOAD",
227 )
228 }
229
230 fn presigned_url_no_sig(
231 &self,
232 expiry: u32,
233 custom_headers: Option<&HeaderMap>,
234 custom_queries: Option<&HashMap<String, String>>,
235 ) -> Result<Url, S3Error> {
236 let bucket = self.bucket();
237 let token = if let Some(security_token) = bucket.security_token()? {
238 Some(security_token)
239 } else {
240 bucket.session_token()?
241 };
242 let url = Url::parse(&format!(
243 "{}{}{}",
244 self.url()?,
245 &signing::authorization_query_params_no_sig(
246 &self.bucket().access_key()?.unwrap_or_default(),
247 &self.datetime(),
248 &self.bucket().region(),
249 expiry,
250 custom_headers,
251 token.as_ref()
252 )?,
253 &signing::flatten_queries(custom_queries)?,
254 ))?;
255
256 Ok(url)
257 }
258
259 fn url(&self) -> Result<Url, S3Error> {
260 let mut url_str = self.bucket().url();
261
262 if let Command::ListBuckets { .. } = self.command() {
263 return Ok(Url::parse(&url_str)?);
264 }
265
266 if let Command::CreateBucket { .. } = self.command() {
267 return Ok(Url::parse(&url_str)?);
268 }
269
270 let path = if self.path().starts_with('/') {
271 self.path()[1..].to_string()
272 } else {
273 self.path()[..].to_string()
274 };
275
276 url_str.push('/');
277 url_str.push_str(&signing::uri_encode(&path, false));
278
279 #[allow(clippy::collapsible_match)]
281 match self.command() {
282 Command::InitiateMultipartUpload { .. } | Command::ListMultipartUploads { .. } => {
283 url_str.push_str("?uploads")
284 }
285 Command::AbortMultipartUpload { upload_id } => {
286 write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
287 }
288 Command::CompleteMultipartUpload { upload_id, .. } => {
289 write!(url_str, "?uploadId={}", upload_id).expect("Could not write to url_str");
290 }
291 Command::GetObjectTorrent => url_str.push_str("?torrent"),
292 Command::PutObject { multipart, .. } => {
293 if let Some(multipart) = multipart {
294 url_str.push_str(&multipart.query_string())
295 }
296 }
297 _ => {}
298 }
299
300 let mut url = Url::parse(&url_str)?;
301
302 for (key, value) in &self.bucket().extra_query {
303 url.query_pairs_mut().append_pair(key, value);
304 }
305
306 if let Command::ListObjectsV2 {
307 prefix,
308 delimiter,
309 continuation_token,
310 start_after,
311 max_keys,
312 } = self.command().clone()
313 {
314 let mut query_pairs = url.query_pairs_mut();
315 delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
316
317 query_pairs.append_pair("prefix", &prefix);
318 query_pairs.append_pair("list-type", "2");
319 if let Some(token) = continuation_token {
320 query_pairs.append_pair("continuation-token", &token);
321 }
322 if let Some(start_after) = start_after {
323 query_pairs.append_pair("start-after", &start_after);
324 }
325 if let Some(max_keys) = max_keys {
326 query_pairs.append_pair("max-keys", &max_keys.to_string());
327 }
328 }
329
330 if let Command::ListObjects {
331 prefix,
332 delimiter,
333 marker,
334 max_keys,
335 } = self.command().clone()
336 {
337 let mut query_pairs = url.query_pairs_mut();
338 delimiter.map(|d| query_pairs.append_pair("delimiter", &d));
339
340 query_pairs.append_pair("prefix", &prefix);
341 if let Some(marker) = marker {
342 query_pairs.append_pair("marker", &marker);
343 }
344 if let Some(max_keys) = max_keys {
345 query_pairs.append_pair("max-keys", &max_keys.to_string());
346 }
347 }
348
349 match self.command() {
350 Command::ListMultipartUploads {
351 prefix,
352 delimiter,
353 key_marker,
354 max_uploads,
355 } => {
356 let mut query_pairs = url.query_pairs_mut();
357 delimiter.map(|d| query_pairs.append_pair("delimiter", d));
358 if let Some(prefix) = prefix {
359 query_pairs.append_pair("prefix", prefix);
360 }
361 if let Some(key_marker) = key_marker {
362 query_pairs.append_pair("key-marker", &key_marker);
363 }
364 if let Some(max_uploads) = max_uploads {
365 query_pairs.append_pair("max-uploads", max_uploads.to_string().as_str());
366 }
367 }
368 Command::PutObjectTagging { .. }
369 | Command::GetObjectTagging
370 | Command::DeleteObjectTagging => {
371 url.query_pairs_mut().append_pair("tagging", "");
372 }
373 _ => {}
374 }
375
376 Ok(url)
377 }
378
379 fn canonical_request(&self, headers: &HeaderMap) -> Result<String, S3Error> {
380 signing::canonical_request(
381 &self.command().http_verb().to_string(),
382 &self.url()?,
383 headers,
384 &self.command().sha256(),
385 )
386 }
387
388 fn authorization(&self, headers: &HeaderMap) -> Result<String, S3Error> {
389 let canonical_request = self.canonical_request(headers)?;
390 let string_to_sign = self.string_to_sign(&canonical_request)?;
391 let mut hmac = signing::HmacSha256::new_from_slice(&self.signing_key()?)?;
392 hmac.update(string_to_sign.as_bytes());
393 let signature = hex::encode(hmac.finalize().into_bytes());
394 let signed_header = signing::signed_header_string(headers);
395 signing::authorization_header(
396 &self.bucket().access_key()?.expect("No access_key provided"),
397 &self.datetime(),
398 &self.bucket().region(),
399 &signed_header,
400 &signature,
401 )
402 }
403
404 fn headers(&self) -> Result<HeaderMap, S3Error> {
405 let sha256 = self.command().sha256();
407
408 let mut headers = HeaderMap::new();
412
413 for (k, v) in self.bucket().extra_headers.iter() {
414 headers.insert(k.clone(), v.clone());
415 }
416
417 let host_header = self.host_header();
418
419 headers.insert(HOST, host_header.parse()?);
420
421 match self.command() {
422 Command::CopyObject { from } => {
423 headers.insert(HeaderName::from_static("x-amz-copy-source"), from.parse()?);
424 }
425 Command::ListObjects { .. } => {}
426 Command::ListObjectsV2 { .. } => {}
427 Command::GetObject => {}
428 Command::GetObjectTagging => {}
429 Command::GetBucketLocation => {}
430 _ => {
431 headers.insert(
432 CONTENT_LENGTH,
433 self.command().content_length().to_string().parse()?,
434 );
435 headers.insert(CONTENT_TYPE, self.command().content_type().parse()?);
436 }
437 }
438 headers.insert(
439 HeaderName::from_static("x-amz-content-sha256"),
440 sha256.parse()?,
441 );
442 headers.insert(
443 HeaderName::from_static("x-amz-date"),
444 self.long_date()?.parse()?,
445 );
446
447 if let Some(session_token) = self.bucket().session_token()? {
448 headers.insert(
449 HeaderName::from_static("x-amz-security-token"),
450 session_token.parse()?,
451 );
452 } else if let Some(security_token) = self.bucket().security_token()? {
453 headers.insert(
454 HeaderName::from_static("x-amz-security-token"),
455 security_token.parse()?,
456 );
457 }
458
459 if let Command::PutObjectTagging { tags } = self.command() {
460 let digest = md5::compute(tags);
461 let hash = general_purpose::STANDARD.encode(digest.as_ref());
462 headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
463 } else if let Command::PutObject { content, .. } = self.command() {
464 let digest = md5::compute(content);
465 let hash = general_purpose::STANDARD.encode(digest.as_ref());
466 headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
467 } else if let Command::UploadPart { content, .. } = self.command() {
468 let digest = md5::compute(content);
469 let hash = general_purpose::STANDARD.encode(digest.as_ref());
470 headers.insert(HeaderName::from_static("content-md5"), hash.parse()?);
471 } else if let Command::GetObject {} = self.command() {
472 headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
473 } else if let Command::GetObjectRange { start, end } = self.command() {
475 headers.insert(ACCEPT, "application/octet-stream".to_string().parse()?);
476
477 let mut range = format!("bytes={}-", start);
478
479 if let Some(end) = end {
480 range.push_str(&end.to_string());
481 }
482
483 headers.insert(RANGE, range.parse()?);
484 } else if let Command::CreateBucket { ref config } = self.command() {
485 config.add_headers(&mut headers)?;
486 }
487
488 if self.bucket().secret_key()?.is_some() {
490 let authorization = self.authorization(&headers)?;
491 headers.insert(AUTHORIZATION, authorization.parse()?);
492 }
493
494 headers.insert(DATE, self.datetime().format(&Rfc2822)?.parse()?);
501
502 Ok(headers)
503 }
504}