1use std::{marker::PhantomData, str::FromStr};
2
3use conditional_headers::ConditionalHeaders;
4use content_headers::ContentHeaders;
5use x_amz_headers::{XAmzHeaders, XAmzStorageClass, storage_class_from_str};
6
7use anyhow::{Result, anyhow};
8use chrono::{DateTime, Utc};
9use hmac::{Hmac, Mac};
10use http::{StatusCode, response::Parts};
11use percent_encoding::{AsciiSet, CONTROLS};
12use sha2::{Digest, Sha256};
13use wstd::http::{Body, HeaderName, HeaderValue, Method, Request, Response, Scheme, Uri};
14use xml::{EventReader, reader::XmlEvent};
15
16use crate::AWS_SERVICE;
17
18pub mod get_object;
19pub mod head_object;
20pub mod list_buckets;
21pub mod list_objects_v2;
22pub mod put_object;
23
24pub mod conditional_headers;
25pub mod content_headers;
26pub mod x_amz_headers;
27
/// Bytes of the `UNSIGNED-PAYLOAD` marker that `S3RequestBuilder::build`
/// falls back to when a request carries no body.
const AWS_SERVICE_EMPTY_PAYLOAD: &[u8] = b"UNSIGNED-PAYLOAD";
/// Algorithm identifier placed at the start of the string-to-sign and of the
/// `authorization` header value.
const AWS_SIGN_ALGORITHM: &str = "AWS4-HMAC-SHA256";
30const QUERY_SET: &AsciiSet = &CONTROLS
31 .add(b' ')
32 .add(b'/')
33 .add(b':') .add(b',') .add(b'?')
36 .add(b'#')
37 .add(b'[')
38 .add(b']')
39 .add(b'{')
40 .add(b'}')
41 .add(b'|')
42 .add(b'@')
43 .add(b'!')
44 .add(b'$')
45 .add(b'&')
46 .add(b'\'')
47 .add(b'(')
48 .add(b')')
49 .add(b'*')
50 .add(b'+')
51 .add(b';')
52 .add(b'=')
53 .add(b'%')
54 .add(b'<')
55 .add(b'>')
56 .add(b'"')
57 .add(b'^')
58 .add(b'`')
59 .add(b'\\');
/// Bytes that are percent-encoded in the request path: the same set as
/// `QUERY_SET` except that `/` is left as-is so path separators survive.
const PATH_SET: &AsciiSet = &QUERY_SET.remove(b'/');
61
/// Checksum algorithm name reported for an object.
pub enum ChecksumAlgorithm {
    CRC32,
    CRC32C,
    SHA1,
    SHA256,
    CRC64NVME,
    /// Any algorithm name that is not one of the variants above.
    /// (The variant name is misspelled; it is part of the public API.)
    Alogrithm(String),
}
70pub(crate) fn checksum_algorithm_from_str(algo: String) -> ChecksumAlgorithm {
71 match algo.to_lowercase() {
72 a if a == "crc32" => ChecksumAlgorithm::CRC32,
73 a if a == "crc32c" => ChecksumAlgorithm::CRC32C,
74 a if a == "sha1" => ChecksumAlgorithm::SHA1,
75 a if a == "sha256" => ChecksumAlgorithm::SHA256,
76 a if a == "crc64nvme" => ChecksumAlgorithm::CRC64NVME,
77
78 a => ChecksumAlgorithm::Alogrithm(a),
79 }
80}
81
82pub(crate) fn parse_xml_string(parser: &mut EventReader<&[u8]>, field: &str) -> Result<String> {
83 if let XmlEvent::Characters(value) = parser.next()? {
84 Ok(value)
85 } else {
86 Err(anyhow!("Invalid response object, {field} has no value"))
87 }
88}
89
90pub(crate) fn parse_xml_bool(parser: &mut EventReader<&[u8]>, field: &str) -> Result<bool> {
91 if let XmlEvent::Characters(value) = parser.next()? {
92 match value.to_lowercase() {
93 v if v == "true" => Ok(true),
94 v if v == "false" => Ok(false),
95 _ => {
96 Err(anyhow!(
97 "Invalid response object, {field} is not a boolean, value: {value}"
98 ))
99 }
100 }
101 } else {
102 Err(anyhow!(
103 "Invalid response object, {field} element has no value"
104 ))
105 }
106}
107
108pub(crate) fn parse_xml_value<T>(parser: &mut EventReader<&[u8]>, field: &str) -> Result<T>
109where
110 T: FromStr,
111{
112 if let XmlEvent::Characters(value) = parser.next()? {
113 match value.parse::<T>() {
114 Ok(v) => Ok(v),
115 Err(_) => Err(anyhow!(
116 "Unable to parse value for field {field}, value {value}"
117 )),
118 }
119 } else {
120 Err(anyhow!("Invalid response object, {field} has no value"))
121 }
122}
123
/// Checksum type reported for an object in a `ChecksumType` element.
pub enum ApiChecksumType {
    /// Parsed from the text `COMPOSITE`.
    Composite,
    /// Parsed from the text `FULL_OBJECT`.
    FullObject,
}
128
/// Contents of a `RestoreStatus` element attached to an object.
pub struct ApiRestoreStatus {
    /// Value of the `IsRestoreInProgress` child element.
    pub is_restore_in_progress: bool,
    /// Value of the `RestoreExpiryDate` child element, converted to UTC.
    pub restore_expiry_date: DateTime<Utc>,
}
133
/// One object entry, as read from a `Contents` element by [`ApiObject::parse`].
pub struct ApiObject {
    /// From `ChecksumAlgorithm`; `None` when the element is absent.
    pub checksum_algorithm: Option<ChecksumAlgorithm>,
    /// From `ChecksumType`; `None` when the element is absent.
    pub checksum_type: Option<ApiChecksumType>,
    /// Text of the `ETag` element (empty when absent).
    pub etag: String,
    /// Text of the `Key` element (empty when absent).
    pub key: String,
    /// From `LastModified`, converted to UTC.
    pub last_modified: DateTime<Utc>,
    /// From the nested `Owner` element; `None` when the element is absent.
    pub owner: Option<ApiOwner>,
    /// From the nested `RestoreStatus` element; `None` when the element is absent.
    pub restore_status: Option<ApiRestoreStatus>,
    /// From `Size`.
    pub size: usize,
    /// From `StorageClass`.
    pub storage_class: XAmzStorageClass,
}
145
146impl ApiObject {
147 pub fn parse(parser: &mut EventReader<&[u8]>) -> Result<Self> {
148 let mut api_object = ApiObject {
149 checksum_algorithm: None,
150 checksum_type: None,
151 etag: String::new(),
152 key: String::new(),
153 last_modified: Utc::now(),
154 owner: None,
155 restore_status: None,
156 size: 0,
157 storage_class: XAmzStorageClass::Standard,
158 };
159 loop {
160 match parser.next()? {
161 XmlEvent::EndElement { name } if name.local_name == "Contents" => break,
162
163 XmlEvent::StartElement { name, .. } if name.local_name == "ChecksumAlgorithm" => {
164 api_object.checksum_algorithm = Some(checksum_algorithm_from_str(
165 parse_xml_string(parser, "ChecksumAlgorithm")?,
166 ));
167 }
168 XmlEvent::StartElement { name, .. } if name.local_name == "ChecksumType" => {
169 let checksum_type = match parse_xml_string(parser, "ChecksumType")? {
170 v if v == "COMPOSITE" => ApiChecksumType::Composite,
171 v if v == "FULL_OBJECT" => ApiChecksumType::FullObject,
172
173 _ => {
174 return Err(anyhow!(
175 "Invalid response object, ChecksumType has an invalid type"
176 ));
177 }
178 };
179 api_object.checksum_type = Some(checksum_type);
180 }
181 XmlEvent::StartElement { name, .. } if name.local_name == "ETag" => {
182 api_object.etag = parse_xml_string(parser, "ETag")?;
183 }
184 XmlEvent::StartElement { name, .. } if name.local_name == "Key" => {
185 api_object.key = parse_xml_string(parser, "Key")?;
186 }
187 XmlEvent::StartElement { name, .. } if name.local_name == "LastModified" => {
188 if let XmlEvent::Characters(value) = &parser.next()? {
189 let datetime = DateTime::parse_from_rfc3339(value)?.to_utc();
190 api_object.last_modified = datetime;
191 } else {
192 return Err(anyhow!(
193 "Invalid response object, LastModified has no value"
194 ));
195 }
196 }
197 XmlEvent::StartElement { name, .. } if name.local_name == "Size" => {
198 api_object.size = parse_xml_value::<usize>(parser, "Size")?;
199 }
200 XmlEvent::StartElement { name, .. } if name.local_name == "StorageClass" => {
201 api_object.storage_class =
202 storage_class_from_str(parse_xml_string(parser, "StorageClass")?);
203 }
204
205 XmlEvent::StartElement { name, .. } if name.local_name == "Owner" => {
206 api_object.owner = Some(ApiOwner::parse(parser)?);
207 }
208 XmlEvent::StartElement { name, .. } if name.local_name == "RestoreStatus" => {
209 let mut restore_status = ApiRestoreStatus {
210 is_restore_in_progress: false,
211 restore_expiry_date: Utc::now(),
212 };
213
214 loop {
215 match parser.next()? {
216 XmlEvent::StartElement { name, .. } => {
217 if name.local_name == "IsRestoreInProgress" {
218 restore_status.is_restore_in_progress =
219 parse_xml_bool(parser, "IsRestoreInProgress")?;
220 } else if name.local_name == "RestoreExpiryDate" {
221 let datetime = DateTime::parse_from_rfc3339(
222 &parse_xml_string(parser, "RestoreExpiryDate")?,
223 )?
224 .to_utc();
225 restore_status.restore_expiry_date = datetime;
226 }
227 }
228 XmlEvent::EndElement { name } if name.local_name == "Owner" => break,
229 _ => {}
230 }
231 }
232
233 api_object.restore_status = Some(restore_status)
234 }
235
236 _ => {}
237 }
238 }
239
240 Ok(api_object)
241 }
242}
243
/// One bucket entry, as read from a `Bucket` element by [`ApiBucket::parse`].
pub struct ApiBucket {
    /// Text of the `Name` element (empty when absent).
    pub name: String,
    /// From `CreationDate`, converted to UTC; `None` when the element is absent.
    pub creation_date: Option<DateTime<Utc>>,
    /// Text of the `BucketRegion` element (empty when absent).
    pub region: String,
}
249
250impl ApiBucket {
251 pub fn parse(parser: &mut EventReader<&[u8]>) -> Result<Self> {
252 let mut bucket = Self {
253 name: String::new(),
254 creation_date: None,
255 region: String::new(),
256 };
257 loop {
258 match parser.next()? {
259 XmlEvent::StartElement { name, .. } if name.local_name == "BucketRegion" => {
260 bucket.region = parse_xml_string(parser, "BucketRegion")?;
261 }
262 XmlEvent::StartElement { name, .. } if name.local_name == "CreationDate" => {
263 let datetime =
264 DateTime::parse_from_rfc3339(&parse_xml_string(parser, "CreationDate")?)?
265 .to_utc();
266 bucket.creation_date = Some(datetime);
267 }
268 XmlEvent::StartElement { name, .. } if name.local_name == "Name" => {
269 bucket.name = parse_xml_string(parser, "")?;
270 }
271 XmlEvent::EndElement { name } if name.local_name == "Bucket" => break,
272 _ => {}
273 }
274 }
275 Ok(bucket)
276 }
277}
278
/// Owner entry, as read from an `Owner` element by [`ApiOwner::parse`].
pub struct ApiOwner {
    /// Text of the `DisplayName` element; `None` when the element is absent.
    pub display_name: Option<String>,
    /// Text of the `ID` element (empty when absent).
    pub id: String,
}
283
284impl ApiOwner {
285 pub fn parse(parser: &mut EventReader<&[u8]>) -> Result<Self> {
286 let mut api_owner = Self {
287 display_name: None,
288 id: String::new(),
289 };
290 loop {
291 match parser.next()? {
292 XmlEvent::StartElement { name, .. } => {
293 if let XmlEvent::Characters(value) = parser.next()? {
294 if name.local_name == "DisplayName" {
295 api_owner.display_name = Some(value);
296 } else if name.local_name == "ID" {
297 api_owner.id = value;
298 }
299 } else {
300 return Err(anyhow!(
301 "Invalid response object, {name} element has no value"
302 ));
303 }
304 }
305 XmlEvent::EndElement { name } if name.local_name == "Owner" => break,
306 _ => {}
307 }
308 }
309
310 Ok(api_owner)
311 }
312}
313
/// Implemented by types that describe the parameters of one S3 operation.
pub trait S3RequestData {
    /// Type the response body of this operation is parsed into.
    type ResponseType;
    /// Produces an [`S3RequestBuilder`] for this operation from the given
    /// credentials, region and endpoint.
    fn into_builder(
        &self,
        access_key: &str,
        secret_key: &str,
        region: &str,
        endpoint: &str,
    ) -> Result<S3RequestBuilder<Self::ResponseType>>
    where
        <Self as S3RequestData>::ResponseType: S3ResponseData;
}
327
/// A built HTTP request, tagged with the type `T` its response should be
/// parsed into.
pub struct S3Request<T>
where
    T: S3ResponseData,
{
    /// The underlying HTTP request.
    pub request: Request<Body>,
    // Carries `T` at the type level only; no value of `T` is stored.
    phantom: PhantomData<T>,
}
335
/// Implemented by types that can be built from an S3 response body.
pub trait S3ResponseData {
    /// Reads `response` and parses it into `Self`.
    // The `async_fn_in_trait` lint is silenced on purpose.
    // NOTE(review): this means callers cannot require the returned future to
    // be `Send` -- confirm that is acceptable for every executor in use.
    #[allow(async_fn_in_trait)]
    async fn parse_body(response: &mut Body) -> Result<Self>
    where
        Self: Sized;
}
343
/// An HTTP response split into head and body, tagged with the type `T` the
/// body is parsed into.
pub struct S3Response<T>
where
    T: S3ResponseData,
{
    // Status line and headers of the response.
    head: Parts,
    // Response body, handed to `T::parse_body`.
    body: Body,
    // Carries `T` at the type level only; no value of `T` is stored.
    phantom: PhantomData<T>,
}
352
353impl<T> S3Response<T>
354where
355 T: S3ResponseData,
356{
357 pub fn from_response(response: Response<Body>) -> Result<Self> {
358 let (head, body) = response.into_parts();
359 Ok(Self {
360 head,
361 body,
362 phantom: PhantomData,
363 })
364 }
365
366 pub fn status(&self) -> StatusCode {
367 self.head.status
368 }
369
370 pub fn into_parts(self) -> (Parts, Body) {
371 (self.head, self.body)
372 }
373
374 pub async fn into_response_data(&mut self) -> Result<T> {
376 T::parse_body(&mut self.body).await
377 }
378
379 pub async fn into_response_data_parts(&mut self) -> Result<(Parts, T)> {
381 let body = T::parse_body(&mut self.body).await?;
382 Ok((self.head.clone(), body))
383 }
384}
385
386fn get_signature_key(secret_key: &str, date: &str, region: &str, service: &str) -> Result<Vec<u8>> {
387 let k_secret = format!("AWS4{}", secret_key);
388 let k_date = hmac_sha256(k_secret.as_bytes(), date.as_bytes())?;
389 let k_region = hmac_sha256(&k_date, region.as_bytes())?;
390 let k_service = hmac_sha256(&k_region, service.as_bytes())?;
391 hmac_sha256(&k_service, b"aws4_request")
392}
393
394fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
395 let mut mac = Hmac::<Sha256>::new_from_slice(key)?;
396 mac.update(data);
397 Ok(mac.finalize().into_bytes().to_vec())
398}
399
400fn percent_encode_query<T: AsRef<str>>(value: T) -> String {
401 percent_encoding::utf8_percent_encode(value.as_ref(), QUERY_SET).to_string()
402}
403fn percent_encode_path<T: AsRef<str>>(value: T) -> String {
404 percent_encoding::utf8_percent_encode(value.as_ref(), PATH_SET).to_string()
405}
406
/// Collects the pieces of an S3 request and turns them into a signed
/// [`S3Request`] whose response is parsed as `T`.
pub struct S3RequestBuilder<T: S3ResponseData> {
    /// HTTP method of the request.
    pub(crate) method: Method,
    /// Request path without the leading `/`.
    pub(crate) action: String,
    /// Query parameters as (key, value) pairs, already percent-encoded when
    /// added through `query`.
    pub(crate) query: Vec<(String, String)>,
    /// Headers that are sent but not included in the signature.
    pub(crate) headers: Vec<(String, String)>,

    /// `x-amz-*` headers; these are included in the signature.
    pub(crate) x_amz_headers: Vec<(String, String)>,

    pub(crate) access_key: String,
    pub(crate) secret_key: String,
    pub(crate) region: String,
    /// Endpoint to send the request to; parsed as a URI when building.
    pub(crate) endpoint: String,

    /// Scheme used when `endpoint` does not carry one.
    pub(crate) scheme: Scheme,

    /// Request body; `None` sends an empty body.
    pub(crate) body: Option<Vec<u8>>,

    // Carries `T` at the type level only; no value of `T` is stored.
    phantom: PhantomData<T>,
}
427
428impl<T> S3RequestBuilder<T>
429where
430 T: S3ResponseData,
431{
432 pub fn new(
436 method: Method,
437 action: &str,
438 access_key: &str,
439 secret_key: &str,
440 region: &str,
441 endpoint: &str,
442 ) -> Self {
443 Self {
444 method,
445 action: action.to_owned(),
446 query: Vec::new(),
447 headers: Vec::new(),
448 x_amz_headers: Vec::new(),
449 access_key: access_key.to_owned(),
450 secret_key: secret_key.to_owned(),
451 region: region.to_owned(),
452 endpoint: endpoint.to_owned(),
453 scheme: Scheme::HTTPS,
454 body: None,
455 phantom: PhantomData,
456 }
457 }
458
459 pub fn method(&mut self, method: Method) -> &mut Self {
460 self.method = method;
461 self
462 }
463 pub fn action(&mut self, action: &str) -> &mut Self {
464 self.action = percent_encode_path(action);
465 self
466 }
467
468 pub fn query(&mut self, key: &str, value: Option<&str>) -> &mut Self {
470 let str_value = match value {
471 Some(v) => percent_encode_query(v),
472 None => percent_encode_query(""),
473 };
474 self.query.push((percent_encode_query(key), str_value));
475 self
476 }
477 pub fn header(&mut self, key: &str, value: &str) -> &mut Self {
479 if key.starts_with("x-amz") {
480 self.x_amz_headers.push((key.to_owned(), value.to_owned()));
481 self
482 } else {
483 self.headers.push((key.to_owned(), value.to_owned()));
484 self
485 }
486 }
487 pub fn headers(&mut self, headers: Vec<(String, String)>) -> &mut Self {
489 for (k, v) in headers {
490 self.header(&k, &v);
491 }
492
493 self
494 }
495 pub fn body<B>(&mut self, body: B) -> &mut Self
497 where
498 B: AsRef<[u8]>,
499 {
500 let b = body.as_ref().to_vec();
501 self.body = Some(b);
502 self
503 }
504 pub fn scheme(&mut self, scheme: Scheme) -> &mut Self {
506 self.scheme = scheme;
507 self
508 }
509
510 pub fn set_content_headers(&mut self, content: &ContentHeaders) -> &mut Self {
515 let mut content_headers = content.get_headers();
516 self.headers.append(&mut content_headers);
517 self
518 }
519 pub fn set_content_query(&mut self, content: &ContentHeaders) -> &mut Self {
525 let query = content.get_query();
526 for (key, value) in query {
527 self.query(&key, Some(&value));
528 }
529 self
530 }
531 pub fn set_conditional_headers(&mut self, conds: &ConditionalHeaders) -> &mut Self {
536 let mut conditional_headers = conds.get_headers();
537 self.headers.append(&mut conditional_headers);
538 self
539 }
540 pub fn set_x_amz_headers(&mut self, xamz: &XAmzHeaders) -> &mut Self {
545 let mut xamz_headers = xamz.headers();
546 self.x_amz_headers.append(&mut xamz_headers);
547 self
548 }
549
550 pub fn set_auth(
552 &mut self,
553 access_key: &str,
554 secret_key: &str,
555 region: &str,
556 endpoint: &str,
557 ) -> &mut Self {
558 self.access_key = access_key.to_owned();
559 self.secret_key = secret_key.to_owned();
560 self.region = region.to_owned();
561 self.endpoint = endpoint.to_owned();
562 self
563 }
564
565 pub fn build(&mut self) -> Result<S3Request<T>> {
567 let now = Utc::now();
569 let date_stamp = now.format("%Y%m%d").to_string();
570 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
571
572 let query = match self.query.is_empty() {
574 true => "".to_string(),
575 false => {
576 self.query.sort();
577 self.query
578 .iter()
579 .map(|(k, v)| format!("{k}={v}"))
580 .collect::<Vec<String>>()
581 .join("&")
582 }
583 };
584
585 let payload_hash = match &self.body {
587 Some(b) => hex::encode(Sha256::digest(b)),
588 None => hex::encode(Sha256::digest(AWS_SERVICE_EMPTY_PAYLOAD)),
589 };
590
591 let host_uri = Uri::from_str(&self.endpoint)?;
605 let (scheme, host) = match (host_uri.scheme(), host_uri.authority().map(|a| a.as_str())) {
606 (None, Some(host)) => (&self.scheme, host),
607 (Some(scheme), Some(host)) => (scheme, host),
608 (_, None) => {
609 return Err(anyhow!("No host defined"));
610 }
611 };
612
613 let mut canonical_headers_vec = match self.x_amz_headers.is_empty() {
615 true => Vec::new(),
616 false => self.x_amz_headers.clone(),
617 };
618 canonical_headers_vec.push(("host".to_string(), host.to_string()));
619 canonical_headers_vec.push(("x-amz-content-sha256".to_string(), payload_hash.clone()));
620 canonical_headers_vec.push(("x-amz-date".to_string(), amz_date.clone()));
621 canonical_headers_vec.sort();
622 let mut canonical_headers = canonical_headers_vec
623 .iter()
624 .map(|(k, v)| format!("{k}:{v}"))
625 .collect::<Vec<String>>()
626 .join("\n");
627 canonical_headers.push('\n');
628 let signed_headers = canonical_headers_vec
629 .iter()
630 .map(|(k, _)| k.to_owned())
631 .collect::<Vec<String>>()
632 .join(";");
633
634 let method = self.method.as_str();
635 let canonical_request = format!(
636 "{method}\n/{action}\n{query}\n{canonical_headers}\n{signed_headers}\n{payload_hash}",
637 action = self.action
638 );
639 let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
640
641 let credential_scope = format!("{date_stamp}/{}/{AWS_SERVICE}/aws4_request", self.region);
643 let string_to_sign = format!(
644 "{AWS_SIGN_ALGORITHM}\n{amz_date}\n{credential_scope}\n{canonical_request_hash}"
645 );
646
647 let signing_key =
648 get_signature_key(&self.secret_key, &date_stamp, &self.region, AWS_SERVICE)?;
649
650 let mut mac = Hmac::<Sha256>::new_from_slice(&signing_key)?;
652 mac.update(string_to_sign.as_bytes());
653 let signature = hex::encode(mac.finalize().into_bytes());
654
655 let authorization_header = format!(
657 "{AWS_SIGN_ALGORITHM} Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
658 self.access_key
659 );
660
661 let body = match &self.body {
662 Some(b) => b,
663 None => "".as_bytes(),
664 };
665
666 let uri = match self.query.is_empty() {
667 true => format!("{scheme}://{host}/{}", self.action),
668 false => format!("{scheme}://{host}/{}?{query}", self.action),
669 };
670 let mut builder = Request::builder()
671 .uri(uri)
672 .method(&self.method)
673 .header("x-amz-content-sha256", payload_hash)
674 .header("x-amz-date", amz_date)
675 .header("authorization", authorization_header)
676 .header("content-length", body.len().to_string());
677
678 if let Some(headers) = builder.headers_mut() {
679 for (key, value) in &self.headers {
680 headers.insert(HeaderName::from_str(key)?, HeaderValue::from_str(value)?);
681 }
682 };
683
684 let request = S3Request::<T> {
685 request: builder.body(body.into())?,
686 phantom: PhantomData,
687 };
688
689 Ok(request)
690 }
691}