Skip to main content

s3_wasi_http/api/
mod.rs

1use std::{marker::PhantomData, str::FromStr};
2
3use conditional_headers::ConditionalHeaders;
4use content_headers::ContentHeaders;
5use x_amz_headers::{XAmzHeaders, XAmzStorageClass, storage_class_from_str};
6
7use anyhow::{Result, anyhow};
8use chrono::{DateTime, Utc};
9use hmac::{Hmac, Mac};
10use http::{StatusCode, response::Parts};
11use percent_encoding::{AsciiSet, CONTROLS};
12use sha2::{Digest, Sha256};
13use wstd::http::{Body, HeaderName, HeaderValue, Method, Request, Response, Scheme, Uri};
14use xml::{EventReader, reader::XmlEvent};
15
16use crate::AWS_SERVICE;
17
18pub mod get_object;
19pub mod head_object;
20pub mod list_buckets;
21pub mod list_objects_v2;
22pub mod put_object;
23
24pub mod conditional_headers;
25pub mod content_headers;
26pub mod x_amz_headers;
27
/// Bytes that are SHA-256 hashed in place of a body when a request has no
/// body set (see [`S3RequestBuilder::build`]).
// NOTE(review): the builder hashes these bytes rather than sending the literal
// string or the hash of an empty payload — confirm the target services accept it.
const AWS_SERVICE_EMPTY_PAYLOAD: &[u8] = "UNSIGNED-PAYLOAD".as_bytes();
/// Algorithm identifier placed at the start of the string-to-sign and of the
/// `authorization` header value.
const AWS_SIGN_ALGORITHM: &str = "AWS4-HMAC-SHA256";
/// Characters that are percent encoded in query keys and values: every
/// control character plus the punctuation listed below.
const QUERY_SET: &AsciiSet = &CONTROLS
    .add(b' ')
    .add(b'/')
    .add(b':') // Required to be percent encoded to function with aws services
    .add(b',') // Required to be percent encoded to function with aws services
    .add(b'?')
    .add(b'#')
    .add(b'[')
    .add(b']')
    .add(b'{')
    .add(b'}')
    .add(b'|')
    .add(b'@')
    .add(b'!')
    .add(b'$')
    .add(b'&')
    .add(b'\'')
    .add(b'(')
    .add(b')')
    .add(b'*')
    .add(b'+')
    .add(b';')
    .add(b'=')
    .add(b'%')
    .add(b'<')
    .add(b'>')
    .add(b'"')
    .add(b'^')
    .add(b'`')
    .add(b'\\');
/// Characters that are percent encoded in request paths: the same set as
/// [`QUERY_SET`] except that `/` is left as-is so path separators survive.
const PATH_SET: &AsciiSet = &QUERY_SET.remove(b'/');
61
/// Checksum algorithm name as produced by [`checksum_algorithm_from_str`].
pub enum ChecksumAlgorithm {
    /// Matched from the name `crc32` (case-insensitive).
    CRC32,
    /// Matched from the name `crc32c` (case-insensitive).
    CRC32C,
    /// Matched from the name `sha1` (case-insensitive).
    SHA1,
    /// Matched from the name `sha256` (case-insensitive).
    SHA256,
    /// Matched from the name `crc64nvme` (case-insensitive).
    CRC64NVME,
    /// Any other algorithm name, stored lower-cased.
    ///
    /// The spelling of this variant (sic) is part of the public API.
    Alogrithm(String),
}
70pub(crate) fn checksum_algorithm_from_str(algo: String) -> ChecksumAlgorithm {
71    match algo.to_lowercase() {
72        a if a == "crc32" => ChecksumAlgorithm::CRC32,
73        a if a == "crc32c" => ChecksumAlgorithm::CRC32C,
74        a if a == "sha1" => ChecksumAlgorithm::SHA1,
75        a if a == "sha256" => ChecksumAlgorithm::SHA256,
76        a if a == "crc64nvme" => ChecksumAlgorithm::CRC64NVME,
77
78        a => ChecksumAlgorithm::Alogrithm(a),
79    }
80}
81
82pub(crate) fn parse_xml_string(parser: &mut EventReader<&[u8]>, field: &str) -> Result<String> {
83    if let XmlEvent::Characters(value) = parser.next()? {
84        Ok(value)
85    } else {
86        Err(anyhow!("Invalid response object, {field} has no value"))
87    }
88}
89
90pub(crate) fn parse_xml_bool(parser: &mut EventReader<&[u8]>, field: &str) -> Result<bool> {
91    if let XmlEvent::Characters(value) = parser.next()? {
92        match value.to_lowercase() {
93            v if v == "true" => Ok(true),
94            v if v == "false" => Ok(false),
95            _ => {
96                Err(anyhow!(
97                    "Invalid response object, {field} is not a boolean, value: {value}"
98                ))
99            }
100        }
101    } else {
102        Err(anyhow!(
103            "Invalid response object, {field} element has no value"
104        ))
105    }
106}
107
108pub(crate) fn parse_xml_value<T>(parser: &mut EventReader<&[u8]>, field: &str) -> Result<T>
109where
110    T: FromStr,
111{
112    if let XmlEvent::Characters(value) = parser.next()? {
113        match value.parse::<T>() {
114            Ok(v) => Ok(v),
115            Err(_) => Err(anyhow!(
116                "Unable to parse value for field {field}, value {value}"
117            )),
118        }
119    } else {
120        Err(anyhow!("Invalid response object, {field} has no value"))
121    }
122}
123
/// Checksum type of an object, as read from the `ChecksumType` element by
/// [`ApiObject::parse`].
pub enum ApiChecksumType {
    /// Parsed from the value `COMPOSITE`.
    Composite,
    /// Parsed from the value `FULL_OBJECT`.
    FullObject,
}
128
/// Contents of a `RestoreStatus` element, as read by [`ApiObject::parse`].
pub struct ApiRestoreStatus {
    /// Value of the `IsRestoreInProgress` child; `false` when the child is absent.
    pub is_restore_in_progress: bool,
    /// Value of the `RestoreExpiryDate` child converted to UTC; when the child
    /// is absent this holds the time at which parsing happened.
    pub restore_expiry_date: DateTime<Utc>,
}
133
/// One `Contents` entry of a listing response, as read by [`ApiObject::parse`].
///
/// Fields whose element is missing keep the placeholder noted on each field.
pub struct ApiObject {
    /// From `ChecksumAlgorithm`; `None` when absent.
    pub checksum_algorithm: Option<ChecksumAlgorithm>,
    /// From `ChecksumType`; `None` when absent.
    pub checksum_type: Option<ApiChecksumType>,
    /// From `ETag`; empty when absent.
    pub etag: String,
    /// From `Key`; empty when absent.
    pub key: String,
    /// From `LastModified` (RFC 3339, converted to UTC); the parse time when absent.
    pub last_modified: DateTime<Utc>,
    /// From `Owner`; `None` when absent.
    pub owner: Option<ApiOwner>,
    /// From `RestoreStatus`; `None` when absent.
    pub restore_status: Option<ApiRestoreStatus>,
    /// From `Size`; `0` when absent.
    pub size: usize,
    /// From `StorageClass`; `XAmzStorageClass::Standard` when absent.
    pub storage_class: XAmzStorageClass,
}
145
146impl ApiObject {
147    pub fn parse(parser: &mut EventReader<&[u8]>) -> Result<Self> {
148        let mut api_object = ApiObject {
149            checksum_algorithm: None,
150            checksum_type: None,
151            etag: String::new(),
152            key: String::new(),
153            last_modified: Utc::now(),
154            owner: None,
155            restore_status: None,
156            size: 0,
157            storage_class: XAmzStorageClass::Standard,
158        };
159        loop {
160            match parser.next()? {
161                XmlEvent::EndElement { name } if name.local_name == "Contents" => break,
162
163                XmlEvent::StartElement { name, .. } if name.local_name == "ChecksumAlgorithm" => {
164                    api_object.checksum_algorithm = Some(checksum_algorithm_from_str(
165                        parse_xml_string(parser, "ChecksumAlgorithm")?,
166                    ));
167                }
168                XmlEvent::StartElement { name, .. } if name.local_name == "ChecksumType" => {
169                    let checksum_type = match parse_xml_string(parser, "ChecksumType")? {
170                        v if v == "COMPOSITE" => ApiChecksumType::Composite,
171                        v if v == "FULL_OBJECT" => ApiChecksumType::FullObject,
172
173                        _ => {
174                            return Err(anyhow!(
175                                "Invalid response object, ChecksumType has an invalid type"
176                            ));
177                        }
178                    };
179                    api_object.checksum_type = Some(checksum_type);
180                }
181                XmlEvent::StartElement { name, .. } if name.local_name == "ETag" => {
182                    api_object.etag = parse_xml_string(parser, "ETag")?;
183                }
184                XmlEvent::StartElement { name, .. } if name.local_name == "Key" => {
185                    api_object.key = parse_xml_string(parser, "Key")?;
186                }
187                XmlEvent::StartElement { name, .. } if name.local_name == "LastModified" => {
188                    if let XmlEvent::Characters(value) = &parser.next()? {
189                        let datetime = DateTime::parse_from_rfc3339(value)?.to_utc();
190                        api_object.last_modified = datetime;
191                    } else {
192                        return Err(anyhow!(
193                            "Invalid response object, LastModified has no value"
194                        ));
195                    }
196                }
197                XmlEvent::StartElement { name, .. } if name.local_name == "Size" => {
198                    api_object.size = parse_xml_value::<usize>(parser, "Size")?;
199                }
200                XmlEvent::StartElement { name, .. } if name.local_name == "StorageClass" => {
201                    api_object.storage_class =
202                        storage_class_from_str(parse_xml_string(parser, "StorageClass")?);
203                }
204
205                XmlEvent::StartElement { name, .. } if name.local_name == "Owner" => {
206                    api_object.owner = Some(ApiOwner::parse(parser)?);
207                }
208                XmlEvent::StartElement { name, .. } if name.local_name == "RestoreStatus" => {
209                    let mut restore_status = ApiRestoreStatus {
210                        is_restore_in_progress: false,
211                        restore_expiry_date: Utc::now(),
212                    };
213
214                    loop {
215                        match parser.next()? {
216                            XmlEvent::StartElement { name, .. } => {
217                                if name.local_name == "IsRestoreInProgress" {
218                                    restore_status.is_restore_in_progress =
219                                        parse_xml_bool(parser, "IsRestoreInProgress")?;
220                                } else if name.local_name == "RestoreExpiryDate" {
221                                    let datetime = DateTime::parse_from_rfc3339(
222                                        &parse_xml_string(parser, "RestoreExpiryDate")?,
223                                    )?
224                                    .to_utc();
225                                    restore_status.restore_expiry_date = datetime;
226                                }
227                            }
228                            XmlEvent::EndElement { name } if name.local_name == "Owner" => break,
229                            _ => {}
230                        }
231                    }
232
233                    api_object.restore_status = Some(restore_status)
234                }
235
236                _ => {}
237            }
238        }
239
240        Ok(api_object)
241    }
242}
243
/// One `Bucket` entry of a response, as read by [`ApiBucket::parse`].
pub struct ApiBucket {
    /// From `Name`; empty when absent.
    pub name: String,
    /// From `CreationDate` (RFC 3339, converted to UTC); `None` when absent.
    pub creation_date: Option<DateTime<Utc>>,
    /// From `BucketRegion`; empty when absent.
    pub region: String,
}
249
250impl ApiBucket {
251    pub fn parse(parser: &mut EventReader<&[u8]>) -> Result<Self> {
252        let mut bucket = Self {
253            name: String::new(),
254            creation_date: None,
255            region: String::new(),
256        };
257        loop {
258            match parser.next()? {
259                XmlEvent::StartElement { name, .. } if name.local_name == "BucketRegion" => {
260                    bucket.region = parse_xml_string(parser, "BucketRegion")?;
261                }
262                XmlEvent::StartElement { name, .. } if name.local_name == "CreationDate" => {
263                    let datetime =
264                        DateTime::parse_from_rfc3339(&parse_xml_string(parser, "CreationDate")?)?
265                            .to_utc();
266                    bucket.creation_date = Some(datetime);
267                }
268                XmlEvent::StartElement { name, .. } if name.local_name == "Name" => {
269                    bucket.name = parse_xml_string(parser, "")?;
270                }
271                XmlEvent::EndElement { name } if name.local_name == "Bucket" => break,
272                _ => {}
273            }
274        }
275        Ok(bucket)
276    }
277}
278
/// Contents of an `Owner` element, as read by [`ApiOwner::parse`].
pub struct ApiOwner {
    /// From `DisplayName`; `None` when absent.
    pub display_name: Option<String>,
    /// From `ID`; empty when absent.
    pub id: String,
}
283
284impl ApiOwner {
285    pub fn parse(parser: &mut EventReader<&[u8]>) -> Result<Self> {
286        let mut api_owner = Self {
287            display_name: None,
288            id: String::new(),
289        };
290        loop {
291            match parser.next()? {
292                XmlEvent::StartElement { name, .. } => {
293                    if let XmlEvent::Characters(value) = parser.next()? {
294                        if name.local_name == "DisplayName" {
295                            api_owner.display_name = Some(value);
296                        } else if name.local_name == "ID" {
297                            api_owner.id = value;
298                        }
299                    } else {
300                        return Err(anyhow!(
301                            "Invalid response object, {name} element has no value"
302                        ));
303                    }
304                }
305                XmlEvent::EndElement { name } if name.local_name == "Owner" => break,
306                _ => {}
307            }
308        }
309
310        Ok(api_owner)
311    }
312}
313
/// Request description that can be turned into a signed-request builder.
pub trait S3RequestData {
    /// Response type that requests built from this data are tagged with.
    type ResponseType;
    /// Creates an S3RequestBuilder from the S3RequestData object
    ///
    /// The credentials, region and endpoint are passed through to the builder.
    /// Returns an error if the builder cannot be created.
    fn into_builder(
        &self,
        access_key: &str,
        secret_key: &str,
        region: &str,
        endpoint: &str,
    ) -> Result<S3RequestBuilder<Self::ResponseType>>
    where
        <Self as S3RequestData>::ResponseType: S3ResponseData;
}
327
/// A built request, tagged with the response type `T` it is expected to produce.
pub struct S3Request<T>
where
    T: S3ResponseData,
{
    /// The underlying HTTP request.
    pub request: Request<Body>,
    // Carries `T` at the type level only; no value of `T` is stored.
    phantom: PhantomData<T>,
}
335
/// Types that can be built from a response body.
pub trait S3ResponseData {
    /// Parse the response body into a S3ResponseData struct
    ///
    /// Returns an error if the body cannot be parsed into `Self`.
    #[allow(async_fn_in_trait)]
    async fn parse_body(response: &mut Body) -> Result<Self>
    where
        Self: Sized;
}
343
/// A response split into head and body, tagged with the data type `T` its
/// body can be parsed into.
pub struct S3Response<T>
where
    T: S3ResponseData,
{
    // Status, headers and the other non-body parts of the response.
    head: Parts,
    // Body, handed to `T::parse_body` when the data is requested.
    body: Body,
    // Carries `T` at the type level only; no value of `T` is stored.
    phantom: PhantomData<T>,
}
352
353impl<T> S3Response<T>
354where
355    T: S3ResponseData,
356{
357    pub fn from_response(response: Response<Body>) -> Result<Self> {
358        let (head, body) = response.into_parts();
359        Ok(Self {
360            head,
361            body,
362            phantom: PhantomData,
363        })
364    }
365
366    pub fn status(&self) -> StatusCode {
367        self.head.status
368    }
369
370    pub fn into_parts(self) -> (Parts, Body) {
371        (self.head, self.body)
372    }
373
374    /// Parse response body into an S3ResponseData struct
375    pub async fn into_response_data(&mut self) -> Result<T> {
376        T::parse_body(&mut self.body).await
377    }
378
379    /// Parse response body into an S3ResponseData struct and get headers
380    pub async fn into_response_data_parts(&mut self) -> Result<(Parts, T)> {
381        let body = T::parse_body(&mut self.body).await?;
382        Ok((self.head.clone(), body))
383    }
384}
385
386fn get_signature_key(secret_key: &str, date: &str, region: &str, service: &str) -> Result<Vec<u8>> {
387    let k_secret = format!("AWS4{}", secret_key);
388    let k_date = hmac_sha256(k_secret.as_bytes(), date.as_bytes())?;
389    let k_region = hmac_sha256(&k_date, region.as_bytes())?;
390    let k_service = hmac_sha256(&k_region, service.as_bytes())?;
391    hmac_sha256(&k_service, b"aws4_request")
392}
393
394fn hmac_sha256(key: &[u8], data: &[u8]) -> Result<Vec<u8>> {
395    let mut mac = Hmac::<Sha256>::new_from_slice(key)?;
396    mac.update(data);
397    Ok(mac.finalize().into_bytes().to_vec())
398}
399
400fn percent_encode_query<T: AsRef<str>>(value: T) -> String {
401    percent_encoding::utf8_percent_encode(value.as_ref(), QUERY_SET).to_string()
402}
403fn percent_encode_path<T: AsRef<str>>(value: T) -> String {
404    percent_encoding::utf8_percent_encode(value.as_ref(), PATH_SET).to_string()
405}
406
/// Build and sign an s3 request
///
/// `T` is the response data type the built request is tagged with.
pub struct S3RequestBuilder<T: S3ResponseData> {
    /// HTTP method of the request.
    pub(crate) method: Method,
    /// Request path without the leading `/`; `build` inserts the slash.
    pub(crate) action: String,
    /// Query pairs, already percent encoded when added through `query`.
    pub(crate) query: Vec<(String, String)>,
    /// Headers that are not part of the signature.
    pub(crate) headers: Vec<(String, String)>,

    /// Headers whose name starts with `x-amz`; `build` includes them in the
    /// canonical (signed) header list.
    pub(crate) x_amz_headers: Vec<(String, String)>,

    /// Access key placed in the credential of the `authorization` header.
    pub(crate) access_key: String,
    /// Secret key the signing key is derived from.
    pub(crate) secret_key: String,
    /// Region used in the credential scope and the signing key.
    pub(crate) region: String,
    /// Endpoint URI; its authority becomes the request host.
    pub(crate) endpoint: String,

    /// Scheme used when `endpoint` does not carry one.
    pub(crate) scheme: Scheme,

    /// Request body, if any.
    pub(crate) body: Option<Vec<u8>>,

    // Carries `T` at the type level only; no value of `T` is stored.
    phantom: PhantomData<T>,
}
427
428impl<T> S3RequestBuilder<T>
429where
430    T: S3ResponseData,
431{
432    /// Create a new S3RequestBuilder
433    ///
434    /// See [crate::S3Client::new_request_builder]
435    pub fn new(
436        method: Method,
437        action: &str,
438        access_key: &str,
439        secret_key: &str,
440        region: &str,
441        endpoint: &str,
442    ) -> Self {
443        Self {
444            method,
445            action: action.to_owned(),
446            query: Vec::new(),
447            headers: Vec::new(),
448            x_amz_headers: Vec::new(),
449            access_key: access_key.to_owned(),
450            secret_key: secret_key.to_owned(),
451            region: region.to_owned(),
452            endpoint: endpoint.to_owned(),
453            scheme: Scheme::HTTPS,
454            body: None,
455            phantom: PhantomData,
456        }
457    }
458
459    pub fn method(&mut self, method: Method) -> &mut Self {
460        self.method = method;
461        self
462    }
463    pub fn action(&mut self, action: &str) -> &mut Self {
464        self.action = percent_encode_path(action);
465        self
466    }
467
468    /// Add a query string
469    pub fn query(&mut self, key: &str, value: Option<&str>) -> &mut Self {
470        let str_value = match value {
471            Some(v) => percent_encode_query(v),
472            None => percent_encode_query(""),
473        };
474        self.query.push((percent_encode_query(key), str_value));
475        self
476    }
477    /// Add a header
478    pub fn header(&mut self, key: &str, value: &str) -> &mut Self {
479        if key.starts_with("x-amz") {
480            self.x_amz_headers.push((key.to_owned(), value.to_owned()));
481            self
482        } else {
483            self.headers.push((key.to_owned(), value.to_owned()));
484            self
485        }
486    }
487    /// Add a headers
488    pub fn headers(&mut self, headers: Vec<(String, String)>) -> &mut Self {
489        for (k, v) in headers {
490            self.header(&k, &v);
491        }
492
493        self
494    }
495    /// Set the request body
496    pub fn body<B>(&mut self, body: B) -> &mut Self
497    where
498        B: AsRef<[u8]>,
499    {
500        let b = body.as_ref().to_vec();
501        self.body = Some(b);
502        self
503    }
504    /// Set request scheme
505    pub fn scheme(&mut self, scheme: Scheme) -> &mut Self {
506        self.scheme = scheme;
507        self
508    }
509
510    /// Set the request content headers
511    ///
512    /// see [ContentHeaders]
513    /// [S3RequestBuilder::headers] can be easier if adding a small amount of headers
514    pub fn set_content_headers(&mut self, content: &ContentHeaders) -> &mut Self {
515        let mut content_headers = content.get_headers();
516        self.headers.append(&mut content_headers);
517        self
518    }
519    /// Set the request content query string will
520    /// also set the range header if set
521    ///
522    /// see [ContentHeaders]
523    /// [S3RequestBuilder::query] can be easier if adding a small amount of queries
524    pub fn set_content_query(&mut self, content: &ContentHeaders) -> &mut Self {
525        let query = content.get_query();
526        for (key, value) in query {
527            self.query(&key, Some(&value));
528        }
529        self
530    }
531    /// Set the request conditional headers
532    ///
533    /// see [ConditionalHeaders]
534    /// [S3RequestBuilder::headers] can be easier if adding a small amount of headers
535    pub fn set_conditional_headers(&mut self, conds: &ConditionalHeaders) -> &mut Self {
536        let mut conditional_headers = conds.get_headers();
537        self.headers.append(&mut conditional_headers);
538        self
539    }
540    /// Set the request x-amz headers
541    ///
542    /// See [XAmzHeaders] and [x_amz_headers::XAmzHeadersBuilder]
543    /// [S3RequestBuilder::headers] can be easier if adding a small amount of headers
544    pub fn set_x_amz_headers(&mut self, xamz: &XAmzHeaders) -> &mut Self {
545        let mut xamz_headers = xamz.headers();
546        self.x_amz_headers.append(&mut xamz_headers);
547        self
548    }
549
550    /// Set authentication values
551    pub fn set_auth(
552        &mut self,
553        access_key: &str,
554        secret_key: &str,
555        region: &str,
556        endpoint: &str,
557    ) -> &mut Self {
558        self.access_key = access_key.to_owned();
559        self.secret_key = secret_key.to_owned();
560        self.region = region.to_owned();
561        self.endpoint = endpoint.to_owned();
562        self
563    }
564
565    /// Build and sign the request
566    pub fn build(&mut self) -> Result<S3Request<T>> {
567        // Get current time in AWS format
568        let now = Utc::now();
569        let date_stamp = now.format("%Y%m%d").to_string();
570        let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
571
572        // Query string for the canonical request
573        let query = match self.query.is_empty() {
574            true => "".to_string(),
575            false => {
576                self.query.sort();
577                self.query
578                    .iter()
579                    .map(|(k, v)| format!("{k}={v}"))
580                    .collect::<Vec<String>>()
581                    .join("&")
582            }
583        };
584
585        // SHA-256 hash of the payload
586        let payload_hash = match &self.body {
587            Some(b) => hex::encode(Sha256::digest(b)),
588            None => hex::encode(Sha256::digest(AWS_SERVICE_EMPTY_PAYLOAD)),
589        };
590
591        // Get host from the uri.
592        //
593        // `Uri::host()` strips the port — e.g. `http://127.0.0.1:9000`
594        // returns `127.0.0.1`. That breaks SigV4 against any S3-compatible
595        // endpoint exposed on a non-default port (MinIO during local dev,
596        // Cloudflare R2 behind a custom port, in-cluster routes, …): the
597        // outbound request is sent to the bare host with no port, the
598        // signature is computed with no port in the canonical host header,
599        // and the request fails.
600        //
601        // `Uri::authority()` keeps the `host:port` form intact, which is
602        // what SigV4's canonical host header expects when a non-default
603        // port is in play.
604        let host_uri = Uri::from_str(&self.endpoint)?;
605        let (scheme, host) = match (host_uri.scheme(), host_uri.authority().map(|a| a.as_str())) {
606            (None, Some(host)) => (&self.scheme, host),
607            (Some(scheme), Some(host)) => (scheme, host),
608            (_, None) => {
609                return Err(anyhow!("No host defined"));
610            }
611        };
612
613        // Canonical Request
614        let mut canonical_headers_vec = match self.x_amz_headers.is_empty() {
615            true => Vec::new(),
616            false => self.x_amz_headers.clone(),
617        };
618        canonical_headers_vec.push(("host".to_string(), host.to_string()));
619        canonical_headers_vec.push(("x-amz-content-sha256".to_string(), payload_hash.clone()));
620        canonical_headers_vec.push(("x-amz-date".to_string(), amz_date.clone()));
621        canonical_headers_vec.sort();
622        let mut canonical_headers = canonical_headers_vec
623            .iter()
624            .map(|(k, v)| format!("{k}:{v}"))
625            .collect::<Vec<String>>()
626            .join("\n");
627        canonical_headers.push('\n');
628        let signed_headers = canonical_headers_vec
629            .iter()
630            .map(|(k, _)| k.to_owned())
631            .collect::<Vec<String>>()
632            .join(";");
633
634        let method = self.method.as_str();
635        let canonical_request = format!(
636            "{method}\n/{action}\n{query}\n{canonical_headers}\n{signed_headers}\n{payload_hash}",
637            action = self.action
638        );
639        let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
640
641        // String-to-Sign
642        let credential_scope = format!("{date_stamp}/{}/{AWS_SERVICE}/aws4_request", self.region);
643        let string_to_sign = format!(
644            "{AWS_SIGN_ALGORITHM}\n{amz_date}\n{credential_scope}\n{canonical_request_hash}"
645        );
646
647        let signing_key =
648            get_signature_key(&self.secret_key, &date_stamp, &self.region, AWS_SERVICE)?;
649
650        // Compute the Signature
651        let mut mac = Hmac::<Sha256>::new_from_slice(&signing_key)?;
652        mac.update(string_to_sign.as_bytes());
653        let signature = hex::encode(mac.finalize().into_bytes());
654
655        // Authorization Header
656        let authorization_header = format!(
657            "{AWS_SIGN_ALGORITHM} Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
658            self.access_key
659        );
660
661        let body = match &self.body {
662            Some(b) => b,
663            None => "".as_bytes(),
664        };
665
666        let uri = match self.query.is_empty() {
667            true => format!("{scheme}://{host}/{}", self.action),
668            false => format!("{scheme}://{host}/{}?{query}", self.action),
669        };
670        let mut builder = Request::builder()
671            .uri(uri)
672            .method(&self.method)
673            .header("x-amz-content-sha256", payload_hash)
674            .header("x-amz-date", amz_date)
675            .header("authorization", authorization_header)
676            .header("content-length", body.len().to_string());
677
678        if let Some(headers) = builder.headers_mut() {
679            for (key, value) in &self.headers {
680                headers.insert(HeaderName::from_str(key)?, HeaderValue::from_str(value)?);
681            }
682        };
683
684        let request = S3Request::<T> {
685            request: builder.body(body.into())?,
686            phantom: PhantomData,
687        };
688
689        Ok(request)
690    }
691}