Skip to main content

svanill_vault_server/
rusoto_extra.rs

1use chrono::Datelike;
2use chrono::{DateTime, Utc};
3use rusoto_signature::{region::Region, signature::SignedRequest};
4use serde::ser::{Serialize, SerializeSeq, Serializer};
5use std::collections::HashMap;
6use time::Date;
7
8// Policy explanation:
9// http://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-HTTPPOSTConstructPolicy.html
10
11#[derive(Default)]
12pub struct PostPolicy<'a> {
13    expiration: Option<&'a DateTime<Utc>>,
14    content_length_range: Option<(u64, u64)>,
15    conditions: Vec<Condition<'a>>,
16    form_data: HashMap<String, String>,
17    bucket_name: Option<&'a str>,
18    key: Option<&'a str>,
19    region: Option<&'a Region>,
20    access_key_id: Option<&'a str>,
21    secret_access_key: Option<&'a str>,
22}
23
24#[derive(Serialize)]
25pub struct SerializablePolicy<'a> {
26    expiration: &'a str,
27    conditions: &'a Vec<Condition<'a>>,
28}
29
30struct Condition<'a>((&'a str, &'a str, &'a str));
31
32impl<'a> Serialize for Condition<'a> {
33    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34    where
35        S: Serializer,
36    {
37        let mut seq = serializer.serialize_seq(Some(3))?;
38        let v = &self.0;
39        seq.serialize_element(v.0)?;
40
41        if v.0 == "content-length-range" {
42            seq.serialize_element(&v.1.parse::<u64>().map_err(|_| {
43                serde::ser::Error::custom("expected u64 value, the minimum content length")
44            })?)?;
45            seq.serialize_element(&v.2.parse::<u64>().map_err(|_| {
46                serde::ser::Error::custom("expected u64 value, the maximum content length")
47            })?)?;
48        } else {
49            seq.serialize_element(v.1)?;
50            seq.serialize_element(v.2)?;
51        }
52
53        seq.end()
54    }
55}
56
57impl<'a> PostPolicy<'a> {
58    /// Set expiration time
59    pub fn set_expiration(mut self, t: &'a DateTime<Utc>) -> Self {
60        self.expiration = Some(t);
61        self
62    }
63
64    /// Set key policy condition
65    pub fn set_key(mut self, key: &'a str) -> Self {
66        if key.is_empty() {
67            return self;
68        }
69
70        self = self.append_policy("eq", "$key", &key);
71        self.key = Some(key);
72        self.form_data.insert("key".to_string(), key.to_string());
73        self
74    }
75
76    /// Set key startswith policy condition
77    #[allow(dead_code)]
78    pub fn set_key_startswith(mut self, key: &'a str) -> Self {
79        if key.is_empty() {
80            return self;
81        }
82
83        self.key = Some(key);
84
85        self = self.append_policy("starts-with", "$key", &key);
86        self.form_data.insert("key".to_string(), key.to_string());
87        self
88    }
89
90    /// Set bucket name
91    pub fn set_bucket_name(mut self, bucket_name: &'a str) -> Self {
92        self.form_data
93            .insert("bucket".to_string(), bucket_name.to_string());
94        self = self.append_policy("eq", "$bucket", bucket_name);
95        self.bucket_name = Some(bucket_name);
96        self
97    }
98
99    /// Set region
100    pub fn set_region(mut self, region: &'a Region) -> Self {
101        self.region = Some(region);
102        self
103    }
104
105    /// Set access key id
106    pub fn set_access_key_id(mut self, access_key_id: &'a str) -> Self {
107        if access_key_id.is_empty() {
108            return self;
109        }
110
111        self.access_key_id = Some(access_key_id);
112        self
113    }
114
115    /// Set secret access key
116    pub fn set_secret_access_key(mut self, secret_access_key: &'a str) -> Self {
117        if secret_access_key.is_empty() {
118            return self;
119        }
120
121        self.secret_access_key = Some(secret_access_key);
122        self
123    }
124
125    /// Set content-type policy condition
126    #[allow(dead_code)]
127    pub fn set_content_type(mut self, content_type: &'a str) -> Self {
128        self.form_data
129            .insert("Content-Type".to_string(), content_type.to_string());
130        self = self.append_policy("eq", "$Content-Type", content_type);
131        self
132    }
133
134    /// Set content length range policy condition
135    pub fn set_content_length_range(mut self, min_length: u64, max_length: u64) -> Self {
136        self.content_length_range = Some((min_length, max_length));
137        // We should append the policy here, but ownership it's tricky,
138        // so we'll do it inside build_form_data()
139        self
140    }
141
142    /// Append policy condition
143    pub fn append_policy(mut self, match_type: &'a str, target: &'a str, value: &'a str) -> Self {
144        self.conditions.push(Condition((match_type, target, value)));
145        self
146    }
147
148    /// Create the form data using the policy
149    pub fn build_form_data(mut self) -> Result<(String, HashMap<String, String>), String> {
150        match self.content_length_range {
151            Some((min_length, max_length)) if min_length > max_length => {
152                return Err(format!(
153                    "Min-length ({}) must be <= Max-length ({})",
154                    min_length, max_length
155                ));
156            }
157            _ => (),
158        }
159
160        if self.expiration.is_none() {
161            return Err("Expiration date must be specified".to_string());
162        }
163
164        if self.key.is_none() {
165            return Err("Object key must be specified".to_string());
166        }
167
168        if self.bucket_name.is_none() {
169            return Err("Bucket name must be specified".to_string());
170        }
171
172        if self.region.is_none() {
173            return Err("Region must be specified".to_string());
174        }
175
176        if self.access_key_id.is_none() {
177            return Err("Access key id must be specified".to_string());
178        }
179
180        if self.secret_access_key.is_none() {
181            return Err("Secret access key must be specified".to_string());
182        }
183
184        let bucket_name = self.bucket_name.unwrap();
185        let secret_access_key = self.secret_access_key.unwrap();
186
187        let expiration = self
188            .expiration
189            .unwrap()
190            .format("%Y-%m-%dT%H:%M:%S.000Z")
191            .to_string();
192
193        let current_time = if cfg!(test) {
194            use chrono::TimeZone;
195            Utc.ymd(2020, 1, 1).and_hms(0, 0, 0)
196        } else {
197            Utc::now()
198        };
199        let current_time_fmted = current_time.format("%Y%m%dT%H%M%SZ").to_string();
200        let current_date = current_time.format("%Y%m%d").to_string();
201
202        let access_key_id = self.access_key_id.unwrap();
203        let region = self.region.unwrap();
204        let region_name = region.name();
205
206        let x_amz_credential = format!(
207            "{}/{}/{}/{}/aws4_request",
208            &access_key_id, &current_date, &region_name, "s3",
209        );
210
211        let mut conditions: Vec<Condition> = self.conditions.into_iter().collect();
212
213        conditions.push(Condition(("eq", "$x-amz-date", &current_time_fmted)));
214        conditions.push(Condition(("eq", "$x-amz-algorithm", "AWS4-HMAC-SHA256")));
215        conditions.push(Condition(("eq", "$x-amz-credential", &x_amz_credential)));
216
217        let min_length_as_string: String;
218        let max_length_as_string: String;
219        if let Some((min, max)) = self.content_length_range {
220            min_length_as_string = min.to_string();
221            max_length_as_string = max.to_string();
222            conditions.push(Condition((
223                "content-length-range",
224                &min_length_as_string,
225                &max_length_as_string,
226            )))
227        }
228
229        let policy_to_serialize = SerializablePolicy {
230            expiration: &expiration,
231            conditions: &conditions,
232        };
233
234        let policy_as_json =
235            serde_json::to_string(&policy_to_serialize).map_err(|e| format!("{:?}", e))?;
236
237        let policy_as_base64 = base64::encode(policy_as_json);
238
239        let signature_date = Date::try_from_ymd(
240            current_time.date().year() as i32,
241            current_time.date().month() as u8,
242            current_time.date().day() as u8,
243        )
244        .unwrap();
245
246        let x_amz_signature = signature::sign_string(
247            &policy_as_base64,
248            &secret_access_key,
249            signature_date,
250            &region_name,
251            "s3",
252        );
253
254        self.form_data
255            .insert("policy".to_string(), policy_as_base64);
256        self.form_data
257            .insert("x-amz-date".to_string(), current_time_fmted);
258        self.form_data.insert(
259            "x-amz-algorithm".to_string(),
260            "AWS4-HMAC-SHA256".to_string(),
261        );
262        self.form_data
263            .insert("x-amz-credential".to_string(), x_amz_credential);
264        self.form_data
265            .insert("x-amz-signature".to_string(), x_amz_signature);
266
267        let signed_request = SignedRequest::new("GET", "s3", &region, "/");
268
269        let upload_url = format!(
270            "{}://{}.{}",
271            signed_request.scheme(),
272            bucket_name,
273            signed_request.hostname()
274        );
275
276        Ok((upload_url, self.form_data))
277    }
278}
279
280// Copied from rusoto/signature/src/signature.rs
281// because `sign_string` was not public and I wanted to
282// implement generate_presigned_post_policy in a way that
283// could be easily implemented in rusoto_signature
284mod signature {
285    use hmac::{Hmac, Mac};
286    use sha2::Sha256;
287    use time::Date;
288
289    #[inline]
290    fn hmac(secret: &[u8], message: &[u8]) -> Hmac<Sha256> {
291        let mut hmac = Hmac::<Sha256>::new_varkey(secret).expect("failed to create hmac");
292        hmac.input(message);
293        hmac
294    }
295
296    /// Takes a message and signs it using AWS secret, time, region keys and service keys.
297    pub fn sign_string(
298        string_to_sign: &str,
299        secret: &str,
300        date: Date,
301        region: &str,
302        service: &str,
303    ) -> String {
304        let date_str = date.format("%Y%m%d");
305        let date_hmac = hmac(format!("AWS4{}", secret).as_bytes(), date_str.as_bytes())
306            .result()
307            .code();
308        let region_hmac = hmac(date_hmac.as_ref(), region.as_bytes()).result().code();
309        let service_hmac = hmac(region_hmac.as_ref(), service.as_bytes())
310            .result()
311            .code();
312        let signing_hmac = hmac(service_hmac.as_ref(), b"aws4_request").result().code();
313        hex::encode(
314            hmac(signing_hmac.as_ref(), string_to_sign.as_bytes())
315                .result()
316                .code()
317                .as_ref(),
318        )
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::prelude::*;

    const BUCKET: &str = "the-bucket";
    const REGION: Region = Region::EuCentral1;
    const ACCESS_KEY_ID: &str = "foo_access_key";
    const SECRET_ACCESS_KEY: &str = "foo_secret_key";
    const OBJECT_KEY: &str = "the-object-key";

    /// Decode the base64 `policy` form field back into a JSON value.
    fn decode_policy(form_data: &HashMap<String, String>) -> serde_json::Value {
        let policy_as_base64 = form_data.get("policy").expect("policy form field is set");
        let policy_as_vec_u8 = base64::decode(policy_as_base64).expect("policy is base64");
        serde_json::from_slice(&policy_as_vec_u8).expect("policy is JSON")
    }

    #[test]
    fn bucket_name_is_required() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let res = PostPolicy::default()
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Bucket name must be specified".to_string()));
    }

    #[test]
    fn region_is_required() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Region must be specified".to_string()));
    }

    #[test]
    fn access_key_id_is_required() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Access key id must be specified".to_string()));
    }

    #[test]
    fn secret_access_key_is_required() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Secret access key must be specified".to_string()));
    }

    #[test]
    fn expiration_is_required() {
        // Every other required field is set, so only the expiration check can fail.
        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .build_form_data();

        assert_eq!(res, Err("Expiration date must be specified".to_string()));
    }

    #[test]
    fn key_is_required() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Object key must be specified".to_string()));
    }

    #[test]
    fn content_length_range_must_be_ordered() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .set_content_length_range(456, 123)
            .build_form_data();

        assert_eq!(
            res,
            Err("Min-length (456) must be <= Max-length (123)".to_string())
        );
    }

    #[test]
    fn build_successfully() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let (upload_url, form_data) = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .set_content_length_range(123, 456)
            .build_form_data()
            .expect("all required fields are set");

        assert_eq!(
            upload_url,
            "https://the-bucket.s3.eu-central-1.amazonaws.com"
        );
        assert_eq!(form_data.get("key").unwrap(), "the-object-key");
        assert_eq!(form_data.get("bucket").unwrap(), "the-bucket");
        assert_eq!(
            form_data.get("x-amz-algorithm").unwrap(),
            "AWS4-HMAC-SHA256"
        );
        assert_eq!(
            form_data.get("x-amz-credential").unwrap(),
            "foo_access_key/20200101/eu-central-1/s3/aws4_request"
        );
        assert_eq!(form_data.get("x-amz-date").unwrap(), "20200101T000000Z");
        // Hex-encoded HMAC-SHA256: 32 bytes -> 64 characters.
        assert_eq!(
            form_data.get("x-amz-signature").map(String::len),
            Some(64)
        );

        let expected_policy = serde_json::json!({
            "expiration": "2020-01-01T01:02:03.000Z",
            "conditions": [
                ["eq", "$bucket", "the-bucket"],
                ["eq", "$key", "the-object-key"],
                ["eq", "$x-amz-date", "20200101T000000Z"],
                ["eq", "$x-amz-algorithm", "AWS4-HMAC-SHA256"],
                ["eq", "$x-amz-credential", "foo_access_key/20200101/eu-central-1/s3/aws4_request"],
                ["content-length-range", 123, 456]
            ]
        });
        assert_eq!(decode_policy(&form_data), expected_policy);
    }

    #[test]
    fn set_content_type() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let (_, form_data) = PostPolicy::default()
            .set_content_type("some/type")
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data()
            .expect("all required fields are set");

        assert_eq!(form_data.get("Content-Type").unwrap(), "some/type");

        let policy = decode_policy(&form_data);
        let conditions = policy["conditions"].as_array().unwrap();
        assert!(conditions.contains(&serde_json::json!(["eq", "$Content-Type", "some/type"])));
    }

    #[test]
    fn append_policy() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let (_, form_data) = PostPolicy::default()
            .append_policy("a", "b", "c")
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data()
            .expect("all required fields are set");

        // Arbitrary conditions only affect the policy, never the form fields.
        assert_eq!(form_data.get("a"), None);

        let policy = decode_policy(&form_data);
        let conditions = policy["conditions"].as_array().unwrap();
        assert!(conditions.contains(&serde_json::json!(["a", "b", "c"])));
    }

    #[test]
    fn set_key_startswith() {
        let expiration_date = Utc.ymd(2020, 1, 1).and_hms(1, 2, 3);

        let (_, form_data) = PostPolicy::default()
            .set_key_startswith("foo")
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_expiration(&expiration_date)
            .build_form_data()
            .expect("all required fields are set");

        assert_eq!(form_data.get("key").unwrap(), "foo");

        let policy = decode_policy(&form_data);
        let conditions = policy["conditions"].as_array().unwrap();
        assert!(conditions.contains(&serde_json::json!(["starts-with", "$key", "foo"])));
    }
}