s3_presign/
lib.rs

1use chrono::{DateTime, Utc};
2use hmac::{Hmac, Mac};
3use percent_encoding::percent_encode;
4use sha2::{Digest, Sha256};
5use url::Url;
6
/// HMAC keyed with SHA-256; used for each step of the signing-key derivation
/// and for the final request signature.
type HmacSha256 = Hmac<Sha256>;

/// chrono format string for the full timestamp (`X-Amz-Date` and the string to sign).
const LONG_DATETIME_FMT: &str = "%Y%m%dT%H%M%SZ";
/// chrono format string for the date-only value used in the credential scope
/// and as the first input of the signing-key derivation.
const SHORT_DATE_FMT: &str = "%Y%m%d";
11
12const PERCENT_ENCODING_CHARSET: percent_encoding::AsciiSet = percent_encoding::CONTROLS.add(b'/').add(b':').add(b'+');
/// Bytes that `escape_key` percent-encodes in an object key: every
/// non-alphanumeric byte except the ones removed below (`/` is kept raw so it
/// keeps acting as the path separator).
// Safe characters: https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html
const S3_KEY_PERCENT_ENCODING_CHARSET: percent_encoding::AsciiSet = percent_encoding::NON_ALPHANUMERIC
    .remove(b'/')
    .remove(b'-')
    .remove(b'!')
    .remove(b'_')
    .remove(b'.')
    .remove(b'*')
    .remove(b'\'')
    // Parentheses are listed as safe by AWS but are deliberately left encoded.
    //.remove(b'(') // OCI can't handle this
    //.remove(b')')
    .remove(b'~');
25
/// AWS Credentials
///
/// `Debug` is implemented by hand so that the secret key and the session
/// token are never written to logs or panic messages; only the access key id
/// is shown.
#[derive(Clone, PartialEq)]
pub struct Credentials {
    /// AWS_ACCESS_KEY_ID,
    /// The access key applications use for authentication
    access_key: String,
    /// AWS_SECRET_ACCESS_KEY
    /// The secret key applications use for authentication
    secret_key: String,
    /// AWS_SESSION_TOKEN
    // ref: https://docs.aws.amazon.com/STS/latest/APIReference/CommonParameters.html
    /// The session token applications use for authentication, temporary credentials
    session_token: Option<String>,
}

impl std::fmt::Debug for Credentials {
    /// Formats the credentials with the secret key and session token redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("access_key", &self.access_key)
            .field("secret_key", &"<redacted>")
            // Keep the Some/None distinction visible without revealing the token.
            .field("session_token", &self.session_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl Credentials {
    /// Create credentials from an access key, a secret key and an optional
    /// session token (pass `None` for long-lived credentials).
    pub fn new(access_key: &str, secret_key: &str, session_token: Option<&str>) -> Self {
        Self {
            access_key: access_key.to_string(),
            secret_key: secret_key.to_string(),
            session_token: session_token.map(str::to_string),
        }
    }

    /// Create temporary credentials, i.e. credentials that always carry a
    /// session token.
    pub fn new_temporary(access_key: &str, secret_key: &str, session_token: &str) -> Self {
        Self::new(access_key, secret_key, Some(session_token))
    }
}
58
/// S3 Bucket
///
/// Bundles the bucket name with its region and the root domain that the
/// bucket name is prefixed to when building URLs.
#[derive(Debug, Clone)]
pub struct Bucket {
    /// AWS_DEFAULT_REGION, AWS_REGION
    region: String,
    /// Bucket name
    bucket: String,

    /// Root domain of the service, `s3.amazonaws.com` unless overridden
    root: String,
}

impl Bucket {
    /// Create a bucket on the default root domain `s3.amazonaws.com`.
    pub fn new(region: &str, bucket: &str) -> Self {
        Self::new_with_root(region, bucket, "s3.amazonaws.com")
    }

    /// Create a bucket on a custom root domain.
    pub fn new_with_root(region: &str, bucket: &str, root: &str) -> Self {
        Self {
            region: region.to_owned(),
            bucket: bucket.to_owned(),
            root: root.to_owned(),
        }
    }

    /// Parse a `region:bucket` or plain `bucket` specification.
    ///
    /// Only the first `:` separates region from bucket; without a `:` the
    /// region defaults to `us-east-1`.
    pub fn from_with_root(s: &str, root: &str) -> Self {
        let (region, bucket) = match s.split_once(':') {
            Some(parts) => parts,
            None => ("us-east-1", s),
        };
        Self::new_with_root(region, bucket, root)
    }
}

impl From<&str> for Bucket {
    /// Parse a `region:bucket` or `bucket` specification on the default root
    /// domain `s3.amazonaws.com`.
    fn from(s: &str) -> Self {
        Self::from_with_root(s, "s3.amazonaws.com")
    }
}
111
/// How the bucket name is placed in generated URLs.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AddressingStyle {
    /// The bucket is part of the host name (`https://{bucket}.{root}/{key}`).
    Virtual,
    /// The bucket is the first path segment (`{endpoint}/{bucket}/{key}`).
    Path,
}
117
/// Generate a presigned URL
///
/// Holds the credentials, bucket and endpoint configuration needed to build
/// and sign URLs for objects of one bucket.
#[derive(Debug)]
pub struct Presigner {
    /// Credentials used for signing
    credentials: Credentials,
    /// Bucket name
    bucket: String,
    /// Root domain the default endpoint is derived from
    root: String,
    /// Region used in the credential scope
    region: String,
    /// Base URL object URLs are built from
    endpoint: Url,
    /// Whether the bucket goes into the host name or into the path
    addressing_style: AddressingStyle,
}
128
129impl Presigner {
130    pub fn new(cred: Credentials, bucket: &str, region: &str) -> Self {
131        Self::new_with_root(cred, bucket, region, "s3.amazonaws.com")
132    }
133
134    pub fn new_with_root(cred: Credentials, bucket: &str, region: &str, root: &str) -> Self {
135        Self {
136            credentials: cred,
137            bucket: bucket.to_string(),
138            root: root.to_string(),
139            region: region.to_string(),
140            endpoint: Url::parse(&format!("https://{}.{}", bucket, root)).unwrap(),
141            addressing_style: AddressingStyle::Virtual,
142        }
143    }
144
145    pub fn from_bucket(credentials: Credentials, bucket: &Bucket) -> Self {
146        Self::new_with_root(credentials, bucket.bucket.as_str(), bucket.region.as_str(), bucket.root.as_str())
147    }
148
149    /// Set the endpoint to use for presigned URLs, also enables path style
150    pub fn endpoint<U: TryInto<Url>>(&mut self, url: U) -> &mut Self
151    where
152        <U as TryInto<Url>>::Error: core::fmt::Debug,
153    {
154        self.endpoint = url.try_into().unwrap();
155        self.addressing_style = AddressingStyle::Path;
156        self
157    }
158
159    pub fn use_path_style(&mut self) -> &mut Self {
160        self.addressing_style = AddressingStyle::Path;
161        if self.endpoint == Url::parse(&format!("https://{}.{}", self.bucket, self.root)).unwrap() {
162            self.endpoint = Url::parse(&format!("https://{}/{}", self.root, self.bucket)).unwrap();
163        }
164        self
165    }
166
167    /// Convert from s3://bucket/key URL to http url
168    pub fn url_for_s3_url(&self, url: &Url) -> Option<Url> {
169        if url.scheme() != "s3" {
170            return None;
171        }
172        let bucket = url.host_str()?;
173        let key = url.path().trim_start_matches('/');
174
175        // S3 has special percent encoding rules for keys
176        // let key = percent_encoding::percent_decode_str(&key);
177        // let key = escape_key(&key.decode_utf8().ok()?);
178
179        match self.addressing_style {
180            AddressingStyle::Virtual => {
181                if bucket != self.bucket {
182                    return None;
183                }
184                self.endpoint.join(&key).ok()
185            }
186            AddressingStyle::Path => {
187                let endpoint = self.endpoint.clone();
188                endpoint.join(&(bucket.to_owned() + "/")).unwrap().join(key).ok()
189            }
190        }
191    }
192
193    pub fn url_for_key(&self, key: &str) -> Option<Url> {
194        if self.bucket.is_empty() {
195            return None;
196        }
197        match self.addressing_style {
198            AddressingStyle::Virtual => self.endpoint.join(key).ok(),
199            AddressingStyle::Path => {
200                let mut endpoint = self.endpoint.clone();
201                endpoint.set_path(&format!("{}/{}", self.bucket, key));
202                Some(endpoint)
203            }
204        }
205    }
206
207    pub fn get(&self, key: &str, expires: i64) -> Option<String> {
208        let url = self.url_for_key(key)?;
209        let now = Utc::now();
210        presigned_url(
211            &self.credentials,
212            expires as _,
213            &url,
214            "GET",
215            "UNSIGNED-PAYLOAD",
216            &self.region,
217            &now,
218            "s3",
219            vec![],
220        )
221    }
222
223    pub fn put(&self, key: &str, expires: i64) -> Option<String> {
224        let url = self.url_for_key(key)?;
225        let now = Utc::now();
226        presigned_url(
227            &self.credentials,
228            expires as _,
229            &url,
230            "PUT",
231            "UNSIGNED-PAYLOAD",
232            &self.region,
233            &now,
234            "s3",
235            vec![],
236        )
237    }
238
239    pub fn url_join(&self, key: &str) -> Option<Url> {
240        self.url_for_key(key)
241    }
242
243    pub fn sign_request(
244        &self,
245        method: &str,
246        url: &Url,
247        expiration: u64,
248        extra_headers: Vec<(String, String)>,
249    ) -> Option<String> {
250        let now = Utc::now();
251        presigned_url(
252            &self.credentials,
253            expiration,
254            url,
255            method,
256            "UNSIGNED-PAYLOAD",
257            &self.region,
258            &now,
259            "s3",
260            extra_headers,
261        )
262    }
263}
264
265/// Generate a presigned GET URL for downloading
266pub fn get(credentials: &Credentials, bucket: &Bucket, key: &str, expires: i64) -> Option<String> {
267    let url = format!("https://{}.{}/{}", bucket.bucket, bucket.root, escape_key(key));
268    let now = Utc::now();
269
270    presigned_url(
271        &credentials,
272        expires as _,
273        &url.parse().unwrap(),
274        "GET",
275        "UNSIGNED-PAYLOAD",
276        &bucket.region,
277        &now,
278        "s3",
279        vec![],
280    )
281}
282
283/// Generate a presigned PUT URL for uploading
284pub fn put(credentials: &Credentials, bucket: &Bucket, key: &str, expires: i64) -> Option<String> {
285    let url = format!("https://{}.{}/{}", bucket.bucket, bucket.root, escape_key(key));
286    /*let url = format!(
287        "https://s3.amazonaws.com/{}/{}",
288        bucket.bucket,
289        escape_key(key)
290    );*/
291    let now = Utc::now();
292
293    presigned_url(
294        credentials,
295        expires as _,
296        &url.parse().unwrap(),
297        "PUT",
298        "UNSIGNED-PAYLOAD",
299        &bucket.region,
300        &now,
301        "s3",
302        vec![],
303    )
304}
305
306fn escape_key(key: &str) -> String {
307    let mut encoded = true;
308    for (i, &c) in key.as_bytes().iter().enumerate() {
309        if c == b'%' {
310            if i + 2 >= key.len() {
311                encoded = false;
312                break;
313            }
314            let c1 = key.as_bytes()[i + 1];
315            let c2 = key.as_bytes()[i + 2];
316            if !matches!(c1, b'a'..=b'f' | b'A'..=b'F' | b'0'..=b'9') {
317                encoded = false;
318                break;
319            }
320            if !matches!(c2, b'a'..=b'f' | b'A'..=b'F' | b'0'..=b'9') {
321                encoded = false;
322                break;
323            }
324        }
325        if !matches!(c, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'/' | b',') {
326            encoded = false;
327            break;
328        }
329    }
330    if encoded {
331        key.to_string() // assume esacped
332    } else {
333        percent_encode(key.as_bytes(), &S3_KEY_PERCENT_ENCODING_CHARSET).to_string()
334    }
335}
336
337/// Generate pre-signed s3 URL
338pub fn presigned_url(
339    credentials: &Credentials,
340    expiration: u64,
341    url: &Url,
342    method: &str,
343    payload_hash: &str,
344    region: &str,
345    date_time: &DateTime<Utc>,
346    service: &str,
347    extra_headers: Vec<(String, String)>,
348) -> Option<String> {
349    let access_key = &credentials.access_key;
350    let secret_key = &credentials.secret_key;
351    let session_token = credentials.session_token.as_ref();
352
353    let date_time_txt = date_time.format(LONG_DATETIME_FMT).to_string();
354    let short_date_time_txt = date_time.format(SHORT_DATE_FMT).to_string();
355    let credentials = format!(
356        "{}/{}/{}/{}/aws4_request",
357        access_key, short_date_time_txt, region, service
358    );
359    let mut params = vec![
360        ("X-Amz-Algorithm".to_string(), "AWS4-HMAC-SHA256".to_string()),
361        ("X-Amz-Credential".to_string(), credentials),
362        ("X-Amz-Date".to_string(), date_time_txt),
363        // only relevant for the S3 service
364        // Ref: https://github.com/aws/aws-sdk-go/issues/2167#issuecomment-430792002
365        ("X-Amz-Expires".to_string(), expiration.to_string()),
366        ("X-Amz-SignedHeaders".to_string(), "host".to_string()),
367    ];
368    for (k, v) in extra_headers {
369        params.push((k, v));
370    }
371    if let Some(session_token) = session_token {
372        params.push(("X-Amz-Security-Token".to_string(), session_token.to_string()));
373    }
374
375    url.query_pairs().for_each(|(k, v)| {
376        params.push((k.to_string(), v.to_string()));
377    });
378
379    params.sort();
380
381    let canonical_query_string = params
382        .iter()
383        .map(|(k, v)| {
384            format!(
385                "{}={}",
386                percent_encode(k.as_bytes(), &PERCENT_ENCODING_CHARSET),
387                percent_encode(v.as_bytes(), &PERCENT_ENCODING_CHARSET)
388            )
389        })
390        .collect::<Vec<_>>()
391        .join("&");
392
393    // NOTE: this is not the same as the canonical query string
394    let query_keys = url.query_pairs().map(|(k, _)| k.to_string()).collect::<Vec<_>>();
395    let query_string = if query_keys.is_empty() {
396        canonical_query_string.clone()
397    } else {
398        params
399            .iter()
400            .filter(|(k, _)| !query_keys.contains(k))
401            .map(|(k, v)| {
402                format!(
403                    "{}={}",
404                    percent_encode(k.as_bytes(), &PERCENT_ENCODING_CHARSET),
405                    percent_encode(v.as_bytes(), &PERCENT_ENCODING_CHARSET)
406                )
407            })
408            .collect::<Vec<_>>()
409            .join("&")
410    };
411
412    let canonical_resource = url.path();
413
414    let mut host = url.host_str().unwrap().to_owned();
415    if let Some(port) = url.port() {
416        host.push(':');
417        host.push_str(&port.to_string());
418    }
419
420    let canonical_headers = format!("host:{}", host);
421    let signed_headers = "host";
422    let canonical_request = format!(
423        "{}\n{}\n{}\n{}\n\n{}\n{}",
424        method.to_uppercase(),
425        canonical_resource,
426        canonical_query_string,
427        canonical_headers,
428        signed_headers,
429        payload_hash
430    );
431
432    let string_to_sign = string_to_sign(&date_time, &region, &canonical_request, service);
433    let signing_key = signing_key(&date_time, secret_key, region, service)?;
434
435    let mut hmac = HmacSha256::new_from_slice(&signing_key).ok()?;
436    hmac.update(string_to_sign.as_bytes());
437    let signature = format!("{:x}", hmac.finalize().into_bytes());
438
439    let request_url = if url.query().is_some() {
440        url.to_string() + "&" + &query_string + "&X-Amz-Signature=" + &signature
441    } else {
442        url.to_string() + "?" + &query_string + "&X-Amz-Signature=" + &signature
443    };
444
445    Some(request_url)
446}
447
448/// Generate the "string to sign" - the value to which the HMAC signing is
449/// applied to sign requests.
450fn string_to_sign(date_time: &DateTime<Utc>, region: &str, canonical_req: &str, service: &str) -> String {
451    let mut hasher = Sha256::default();
452    hasher.update(canonical_req.as_bytes());
453    format!(
454        "AWS4-HMAC-SHA256\n{timestamp}\n{scope}\n{hash}",
455        timestamp = date_time.format(LONG_DATETIME_FMT),
456        scope = scope_string(date_time, region, service),
457        hash = format!("{:x}", hasher.finalize())
458    )
459}
460
461/// Generate the AWS signing key, derived from the secret key, date, region,
462/// and service name.
463fn signing_key(date_time: &DateTime<Utc>, secret_key: &str, region: &str, service: &str) -> Option<Vec<u8>> {
464    let secret = format!("AWS4{}", secret_key);
465    let mut date_hmac = HmacSha256::new_from_slice(secret.as_bytes()).ok()?;
466    date_hmac.update(date_time.format(SHORT_DATE_FMT).to_string().as_bytes());
467    let mut region_hmac = HmacSha256::new_from_slice(&date_hmac.finalize().into_bytes()).ok()?;
468    region_hmac.update(region.to_string().as_bytes());
469    let mut service_hmac = HmacSha256::new_from_slice(&region_hmac.finalize().into_bytes()).ok()?;
470    service_hmac.update(service.as_bytes());
471    let mut signing_hmac = HmacSha256::new_from_slice(&service_hmac.finalize().into_bytes()).ok()?;
472    signing_hmac.update(b"aws4_request");
473    Some(signing_hmac.finalize().into_bytes().to_vec())
474}
475
476/// Generate an AWS scope string.
477fn scope_string(date_time: &DateTime<Utc>, region: &str, service: &str) -> String {
478    format!(
479        "{date}/{region}/{service}/aws4_request",
480        date = date_time.format(SHORT_DATE_FMT),
481        region = region,
482        service = service
483    )
484}
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Presigns a PUT with temporary credentials and checks the shape of the
    /// result: target URL, the expected query parameters and a 64-digit hex
    /// signature. The signature value itself depends on the current time.
    #[test]
    fn test_generate() {
        let credentials = Credentials::new_temporary(
            "ASIAAAAAABBBBBCCCCCDDDDDD",
            "AAAAAAA+BBBBBBBB/CCCCCCC/DDDDDDDDDD",
            "xxxxxxxxx",
        );
        let bucket = Bucket::new("us-east-1", "the-bucket");
        let key = "5e4ed04f-1d37-4cef-8210-eea624f2aef5/f219644fdfb";

        let url = put(&credentials, &bucket, key, 600).expect("presigning should succeed");
        println!("=> {:?}", url);

        assert!(url.starts_with(&format!("https://the-bucket.s3.amazonaws.com/{}?", key)));
        assert!(url.contains("X-Amz-Algorithm=AWS4-HMAC-SHA256"));
        assert!(url.contains("X-Amz-Expires=600"));
        assert!(url.contains("X-Amz-Security-Token=xxxxxxxxx"));
        assert!(url.contains("X-Amz-SignedHeaders=host"));

        let signature = url
            .rsplit("X-Amz-Signature=")
            .next()
            .expect("rsplit always yields at least one item");
        assert_eq!(signature.len(), 64);
        assert!(signature.bytes().all(|b| b.is_ascii_hexdigit()));
    }
}