1use chrono::Datelike;
2use chrono::{DateTime, Utc};
3use rusoto_signature::{region::Region, signature::SignedRequest};
4use serde::ser::{Serialize, SerializeSeq, Serializer};
5use std::collections::HashMap;
6use time::Date;
7
8#[derive(Default)]
12pub struct PostPolicy<'a> {
13 expiration: Option<&'a DateTime<Utc>>,
14 content_length_range: Option<(u64, u64)>,
15 conditions: Vec<Condition<'a>>,
16 form_data: HashMap<String, String>,
17 bucket_name: Option<&'a str>,
18 key: Option<&'a str>,
19 region: Option<&'a Region>,
20 access_key_id: Option<&'a str>,
21 secret_access_key: Option<&'a str>,
22}
23
24#[derive(Serialize)]
25pub struct SerializablePolicy<'a> {
26 expiration: &'a str,
27 conditions: &'a Vec<Condition<'a>>,
28}
29
30struct Condition<'a>((&'a str, &'a str, &'a str));
31
32impl<'a> Serialize for Condition<'a> {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: Serializer,
36 {
37 let mut seq = serializer.serialize_seq(Some(3))?;
38 let v = &self.0;
39 seq.serialize_element(v.0)?;
40
41 if v.0 == "content-length-range" {
42 seq.serialize_element(&v.1.parse::<u64>().map_err(|_| {
43 serde::ser::Error::custom("expected u64 value, the minimum content length")
44 })?)?;
45 seq.serialize_element(&v.2.parse::<u64>().map_err(|_| {
46 serde::ser::Error::custom("expected u64 value, the maximum content length")
47 })?)?;
48 } else {
49 seq.serialize_element(v.1)?;
50 seq.serialize_element(v.2)?;
51 }
52
53 seq.end()
54 }
55}
56
57impl<'a> PostPolicy<'a> {
58 pub fn set_expiration(mut self, t: &'a DateTime<Utc>) -> Self {
60 self.expiration = Some(t);
61 self
62 }
63
64 pub fn set_key(mut self, key: &'a str) -> Self {
66 if key.is_empty() {
67 return self;
68 }
69
70 self = self.append_policy("eq", "$key", &key);
71 self.key = Some(key);
72 self.form_data.insert("key".to_string(), key.to_string());
73 self
74 }
75
76 #[allow(dead_code)]
78 pub fn set_key_startswith(mut self, key: &'a str) -> Self {
79 if key.is_empty() {
80 return self;
81 }
82
83 self.key = Some(key);
84
85 self = self.append_policy("starts-with", "$key", &key);
86 self.form_data.insert("key".to_string(), key.to_string());
87 self
88 }
89
90 pub fn set_bucket_name(mut self, bucket_name: &'a str) -> Self {
92 self.form_data
93 .insert("bucket".to_string(), bucket_name.to_string());
94 self = self.append_policy("eq", "$bucket", bucket_name);
95 self.bucket_name = Some(bucket_name);
96 self
97 }
98
99 pub fn set_region(mut self, region: &'a Region) -> Self {
101 self.region = Some(region);
102 self
103 }
104
105 pub fn set_access_key_id(mut self, access_key_id: &'a str) -> Self {
107 if access_key_id.is_empty() {
108 return self;
109 }
110
111 self.access_key_id = Some(access_key_id);
112 self
113 }
114
115 pub fn set_secret_access_key(mut self, secret_access_key: &'a str) -> Self {
117 if secret_access_key.is_empty() {
118 return self;
119 }
120
121 self.secret_access_key = Some(secret_access_key);
122 self
123 }
124
125 #[allow(dead_code)]
127 pub fn set_content_type(mut self, content_type: &'a str) -> Self {
128 self.form_data
129 .insert("Content-Type".to_string(), content_type.to_string());
130 self = self.append_policy("eq", "$Content-Type", content_type);
131 self
132 }
133
134 pub fn set_content_length_range(mut self, min_length: u64, max_length: u64) -> Self {
136 self.content_length_range = Some((min_length, max_length));
137 self
140 }
141
142 pub fn append_policy(mut self, match_type: &'a str, target: &'a str, value: &'a str) -> Self {
144 self.conditions.push(Condition((match_type, target, value)));
145 self
146 }
147
148 pub fn build_form_data(mut self) -> Result<(String, HashMap<String, String>), String> {
150 match self.content_length_range {
151 Some((min_length, max_length)) if min_length > max_length => {
152 return Err(format!(
153 "Min-length ({}) must be <= Max-length ({})",
154 min_length, max_length
155 ));
156 }
157 _ => (),
158 }
159
160 if self.expiration.is_none() {
161 return Err("Expiration date must be specified".to_string());
162 }
163
164 if self.key.is_none() {
165 return Err("Object key must be specified".to_string());
166 }
167
168 if self.bucket_name.is_none() {
169 return Err("Bucket name must be specified".to_string());
170 }
171
172 if self.region.is_none() {
173 return Err("Region must be specified".to_string());
174 }
175
176 if self.access_key_id.is_none() {
177 return Err("Access key id must be specified".to_string());
178 }
179
180 if self.secret_access_key.is_none() {
181 return Err("Secret access key must be specified".to_string());
182 }
183
184 let bucket_name = self.bucket_name.unwrap();
185 let secret_access_key = self.secret_access_key.unwrap();
186
187 let expiration = self
188 .expiration
189 .unwrap()
190 .format("%Y-%m-%dT%H:%M:%S.000Z")
191 .to_string();
192
193 let current_time = if cfg!(test) {
194 use chrono::TimeZone;
195 Utc.ymd(2020, 1, 1).and_hms(0, 0, 0)
196 } else {
197 Utc::now()
198 };
199 let current_time_fmted = current_time.format("%Y%m%dT%H%M%SZ").to_string();
200 let current_date = current_time.format("%Y%m%d").to_string();
201
202 let access_key_id = self.access_key_id.unwrap();
203 let region = self.region.unwrap();
204 let region_name = region.name();
205
206 let x_amz_credential = format!(
207 "{}/{}/{}/{}/aws4_request",
208 &access_key_id, ¤t_date, ®ion_name, "s3",
209 );
210
211 let mut conditions: Vec<Condition> = self.conditions.into_iter().collect();
212
213 conditions.push(Condition(("eq", "$x-amz-date", ¤t_time_fmted)));
214 conditions.push(Condition(("eq", "$x-amz-algorithm", "AWS4-HMAC-SHA256")));
215 conditions.push(Condition(("eq", "$x-amz-credential", &x_amz_credential)));
216
217 let min_length_as_string: String;
218 let max_length_as_string: String;
219 if let Some((min, max)) = self.content_length_range {
220 min_length_as_string = min.to_string();
221 max_length_as_string = max.to_string();
222 conditions.push(Condition((
223 "content-length-range",
224 &min_length_as_string,
225 &max_length_as_string,
226 )))
227 }
228
229 let policy_to_serialize = SerializablePolicy {
230 expiration: &expiration,
231 conditions: &conditions,
232 };
233
234 let policy_as_json =
235 serde_json::to_string(&policy_to_serialize).map_err(|e| format!("{:?}", e))?;
236
237 let policy_as_base64 = base64::encode(policy_as_json);
238
239 let signature_date = Date::try_from_ymd(
240 current_time.date().year() as i32,
241 current_time.date().month() as u8,
242 current_time.date().day() as u8,
243 )
244 .unwrap();
245
246 let x_amz_signature = signature::sign_string(
247 &policy_as_base64,
248 &secret_access_key,
249 signature_date,
250 ®ion_name,
251 "s3",
252 );
253
254 self.form_data
255 .insert("policy".to_string(), policy_as_base64);
256 self.form_data
257 .insert("x-amz-date".to_string(), current_time_fmted);
258 self.form_data.insert(
259 "x-amz-algorithm".to_string(),
260 "AWS4-HMAC-SHA256".to_string(),
261 );
262 self.form_data
263 .insert("x-amz-credential".to_string(), x_amz_credential);
264 self.form_data
265 .insert("x-amz-signature".to_string(), x_amz_signature);
266
267 let signed_request = SignedRequest::new("GET", "s3", ®ion, "/");
268
269 let upload_url = format!(
270 "{}://{}.{}",
271 signed_request.scheme(),
272 bucket_name,
273 signed_request.hostname()
274 );
275
276 Ok((upload_url, self.form_data))
277 }
278}
279
280mod signature {
285 use hmac::{Hmac, Mac};
286 use sha2::Sha256;
287 use time::Date;
288
289 #[inline]
290 fn hmac(secret: &[u8], message: &[u8]) -> Hmac<Sha256> {
291 let mut hmac = Hmac::<Sha256>::new_varkey(secret).expect("failed to create hmac");
292 hmac.input(message);
293 hmac
294 }
295
296 pub fn sign_string(
298 string_to_sign: &str,
299 secret: &str,
300 date: Date,
301 region: &str,
302 service: &str,
303 ) -> String {
304 let date_str = date.format("%Y%m%d");
305 let date_hmac = hmac(format!("AWS4{}", secret).as_bytes(), date_str.as_bytes())
306 .result()
307 .code();
308 let region_hmac = hmac(date_hmac.as_ref(), region.as_bytes()).result().code();
309 let service_hmac = hmac(region_hmac.as_ref(), service.as_bytes())
310 .result()
311 .code();
312 let signing_hmac = hmac(service_hmac.as_ref(), b"aws4_request").result().code();
313 hex::encode(
314 hmac(signing_hmac.as_ref(), string_to_sign.as_bytes())
315 .result()
316 .code()
317 .as_ref(),
318 )
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::prelude::*;
    use std::collections::HashMap;

    const BUCKET: &str = "the-bucket";
    const REGION: Region = Region::EuCentral1;
    const ACCESS_KEY_ID: &str = "foo_access_key";
    const SECRET_ACCESS_KEY: &str = "foo_secret_key";
    const OBJECT_KEY: &str = "the-object-key";

    /// Expiration timestamp shared by every test.
    fn sample_expiration() -> DateTime<Utc> {
        Utc.ymd(2020, 1, 1).and_hms(1, 2, 3)
    }

    /// Decodes the base64 `policy` form field back into JSON.
    fn decode_policy(form_data: &HashMap<String, String>) -> serde_json::Value {
        let policy_as_base64 = form_data.get("policy").expect("missing policy field");
        let policy_as_vec_u8 = base64::decode(policy_as_base64).expect("policy is not base64");
        serde_json::from_slice(&policy_as_vec_u8).expect("policy is not valid JSON")
    }

    /// Returns the `conditions` array of the encoded policy.
    fn policy_conditions(form_data: &HashMap<String, String>) -> Vec<serde_json::Value> {
        decode_policy(form_data)["conditions"]
            .as_array()
            .expect("policy conditions is not an array")
            .clone()
    }

    #[test]
    fn bucket_name_is_required() {
        let expiration_date = sample_expiration();

        let res = PostPolicy::default()
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Bucket name must be specified".to_string()));
    }

    #[test]
    fn region_is_required() {
        let expiration_date = sample_expiration();

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Region must be specified".to_string()));
    }

    #[test]
    fn access_key_id_is_required() {
        let expiration_date = sample_expiration();

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Access key id must be specified".to_string()));
    }

    #[test]
    fn secret_access_key_is_required() {
        let expiration_date = sample_expiration();

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Secret access key must be specified".to_string()));
    }

    #[test]
    fn expiration_is_required() {
        // Everything except the expiration is set, so the result does not
        // depend on the order in which missing settings are checked.
        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .build_form_data();

        assert_eq!(res, Err("Expiration date must be specified".to_string()));
    }

    #[test]
    fn key_is_required() {
        let expiration_date = sample_expiration();

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_expiration(&expiration_date)
            .build_form_data();

        assert_eq!(res, Err("Object key must be specified".to_string()));
    }

    #[test]
    fn content_length_range_must_be_ordered() {
        let expiration_date = sample_expiration();

        let res = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .set_content_length_range(456, 123)
            .build_form_data();

        assert_eq!(
            res,
            Err("Min-length (456) must be <= Max-length (123)".to_string())
        );
    }

    #[test]
    fn build_successfully() {
        let expiration_date = sample_expiration();

        let (upload_url, form_data) = PostPolicy::default()
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .set_content_length_range(123, 456)
            .build_form_data()
            .unwrap();

        assert_eq!(
            upload_url,
            "https://the-bucket.s3.eu-central-1.amazonaws.com"
        );
        assert_eq!(form_data.get("key").unwrap(), "the-object-key");
        assert_eq!(form_data.get("bucket").unwrap(), "the-bucket");
        assert_eq!(
            form_data.get("x-amz-algorithm").unwrap(),
            "AWS4-HMAC-SHA256"
        );
        assert_eq!(
            form_data.get("x-amz-credential").unwrap(),
            "foo_access_key/20200101/eu-central-1/s3/aws4_request"
        );
        assert_eq!(form_data.get("x-amz-date").unwrap(), "20200101T000000Z");

        // HMAC-SHA256 digests are 32 bytes, i.e. 64 hex characters.
        let signature = form_data.get("x-amz-signature").unwrap();
        assert_eq!(signature.len(), 64);
        assert!(signature.chars().all(|c| c.is_ascii_hexdigit()));

        let expected_policy = serde_json::json!({
            "expiration": "2020-01-01T01:02:03.000Z",
            "conditions": [
                ["eq", "$bucket", "the-bucket"],
                ["eq", "$key", "the-object-key"],
                ["eq", "$x-amz-date", "20200101T000000Z"],
                ["eq", "$x-amz-algorithm", "AWS4-HMAC-SHA256"],
                ["eq", "$x-amz-credential", "foo_access_key/20200101/eu-central-1/s3/aws4_request"],
                ["content-length-range", 123, 456]
            ]
        });

        assert_eq!(decode_policy(&form_data), expected_policy);
    }

    #[test]
    fn set_content_type() {
        let expiration_date = sample_expiration();

        let (_, form_data) = PostPolicy::default()
            .set_content_type("some/type")
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data()
            .unwrap();

        assert_eq!(form_data.get("Content-Type").unwrap(), "some/type");

        let conditions = policy_conditions(&form_data);
        assert!(conditions.contains(&serde_json::json!(["eq", "$Content-Type", "some/type"])));
    }

    #[test]
    fn append_policy() {
        let expiration_date = sample_expiration();

        let (_, form_data) = PostPolicy::default()
            .append_policy("a", "b", "c")
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_key(OBJECT_KEY)
            .set_expiration(&expiration_date)
            .build_form_data()
            .unwrap();

        assert_eq!(form_data.get("a"), None);

        let conditions = policy_conditions(&form_data);
        assert!(conditions.contains(&serde_json::json!(["a", "b", "c"])));
    }

    #[test]
    fn set_key_startswith() {
        let expiration_date = sample_expiration();

        let (_, form_data) = PostPolicy::default()
            .set_key_startswith("foo")
            .set_bucket_name(BUCKET)
            .set_region(&REGION)
            .set_access_key_id(ACCESS_KEY_ID)
            .set_secret_access_key(SECRET_ACCESS_KEY)
            .set_expiration(&expiration_date)
            .build_form_data()
            .unwrap();

        assert_eq!(form_data.get("key").unwrap(), "foo");

        let conditions = policy_conditions(&form_data);
        assert!(conditions.contains(&serde_json::json!(["starts-with", "$key", "foo"])));
    }
}