1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120
use std::error::Error; use std::io::Read; use std::marker::PhantomData; use std::time::Duration; use rusoto_s3::{GetObjectRequest, S3}; use rusoto_sqs::Message as SqsMessage; use tokio::prelude::*; use tracing::info; use async_trait::async_trait; use crate::event_decoder::PayloadDecoder; use std::collections::HashMap; #[async_trait] pub trait PayloadRetriever<T> { type Message; async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<T>, Box<dyn Error>>; } #[derive(Clone)] pub struct S3PayloadRetriever<S, SInit, D, E> where S: S3 + Clone + Send + Sync + 'static, SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static, D: PayloadDecoder<E> + Clone + Send + 'static, E: Send + 'static, { s3_init: SInit, s3_clients: HashMap<String, S>, decoder: D, phantom: PhantomData<E>, } impl<S, SInit, D, E> S3PayloadRetriever<S, SInit, D, E> where S: S3 + Clone + Send + Sync + 'static, SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static, D: PayloadDecoder<E> + Clone + Send + 'static, E: Send + 'static, { pub fn new(s3: SInit, decoder: D) -> Self { Self { s3_init: s3, s3_clients: HashMap::new(), decoder, phantom: PhantomData, } } pub fn get_client(&mut self, region: String) -> S { match self.s3_clients.get(®ion) { Some(s3) => s3.clone(), None => { let client = (self.s3_init)(region.clone()); self.s3_clients.insert(region.to_string(), client.clone()); client } } } } #[async_trait] impl<S, SInit, D, E> PayloadRetriever<E> for S3PayloadRetriever<S, SInit, D, E> where S: S3 + Clone + Send + Sync + 'static, SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static, D: PayloadDecoder<E> + Clone + Send + 'static, E: Send + 'static, { type Message = SqsMessage; #[tracing::instrument(skip(self, msg))] async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<E>, Box<dyn Error>> { let body = msg.body.as_ref().unwrap(); info!("Got body from message: {}", body); let event: serde_json::Value = serde_json::from_str(body)?; if let Some(Some(event_str)) = event.get("Event").map(serde_json::Value::as_str) { if event_str == "s3:TestEvent" { return Ok(None); } } let record = &event["Records"][0]["s3"]; let bucket = record["bucket"]["name"].as_str().expect("bucket name"); let key = record["object"]["key"].as_str().expect("object key"); let region = &event["Records"][0]["awsRegion"].as_str().expect("region"); let s3 = self.get_client(region.to_string()); let s3_data = s3.get_object(GetObjectRequest { bucket: bucket.to_string(), key: key.to_string(), ..Default::default() }); let s3_data = tokio::time::timeout(Duration::from_secs(5), s3_data).await??; let object_size = record["object"]["size"].as_u64().unwrap_or_default(); let prealloc = if object_size < 1024 { 1024 } else { object_size as usize }; info!("Retrieved s3 payload with size : {:?}", prealloc); let mut body = Vec::with_capacity(prealloc); s3_data .body .expect("Missing S3 body") .into_async_read() .read_to_end(&mut body) .await?; info!("Read s3 payload body"); self.decoder.decode(body).map(Option::from) } }