1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
use std::error::Error;
use std::io::Read;
use std::marker::PhantomData;
use std::time::Duration;

use rusoto_s3::{GetObjectRequest, S3};
use rusoto_sqs::Message as SqsMessage;
use tokio::prelude::*;
use tracing::info;

use async_trait::async_trait;

use crate::event_decoder::PayloadDecoder;
use std::collections::HashMap;

/// Source of decoded payloads that are referenced by incoming queue messages.
///
/// `T` is the decoded payload type an implementation produces.
#[async_trait]
pub trait PayloadRetriever<T> {
    /// Queue message type this retriever knows how to interpret.
    type Message;

    /// Fetches and decodes the payload referenced by `message`.
    ///
    /// Failures at any step are returned boxed. `Ok(None)` is a non-error
    /// "nothing to decode" outcome; the S3 implementation in this file returns
    /// it for S3 test events (other implementations: confirm their meaning).
    async fn retrieve_event(
        &mut self,
        message: &Self::Message,
    ) -> Result<Option<T>, Box<dyn Error>>;
}

/// [`PayloadRetriever`] that resolves S3 event notifications delivered through
/// SQS, downloads the referenced object and decodes it with a [`PayloadDecoder`].
///
/// S3 clients are created lazily, one per region, and cached for reuse.
///
/// `Clone` is implemented by hand rather than derived: `E` is only held through
/// `PhantomData`, but `#[derive(Clone)]` would still require `E: Clone`, which
/// needlessly prevents cloning retrievers for non-`Clone` event types.
pub struct S3PayloadRetriever<S, SInit, D, E>
where
    S: S3 + Clone + Send + Sync + 'static,
    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
    D: PayloadDecoder<E> + Clone + Send + 'static,
    E: Send + 'static,
{
    /// Factory invoked with a region name to build a client for that region.
    s3_init: SInit,
    /// Clients built so far, keyed by region name.
    s3_clients: HashMap<String, S>,
    /// Decoder applied to the downloaded object bytes.
    decoder: D,
    /// Ties the decoded event type `E` to the struct without storing a value.
    phantom: PhantomData<E>,
}

impl<S, SInit, D, E> Clone for S3PayloadRetriever<S, SInit, D, E>
where
    S: S3 + Clone + Send + Sync + 'static,
    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
    D: PayloadDecoder<E> + Clone + Send + 'static,
    E: Send + 'static,
{
    /// Clones the factory, the client cache and the decoder; no `E: Clone` needed.
    fn clone(&self) -> Self {
        Self {
            s3_init: self.s3_init.clone(),
            s3_clients: self.s3_clients.clone(),
            decoder: self.decoder.clone(),
            phantom: PhantomData,
        }
    }
}

impl<S, SInit, D, E> S3PayloadRetriever<S, SInit, D, E>
where
    S: S3 + Clone + Send + Sync + 'static,
    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
    D: PayloadDecoder<E> + Clone + Send + 'static,
    E: Send + 'static,
{
    /// Creates a retriever with an empty client cache.
    ///
    /// `s3` is the client factory: it is called with a region name the first
    /// time a client for that region is requested. `decoder` turns downloaded
    /// object bytes into events.
    pub fn new(s3: SInit, decoder: D) -> Self {
        Self {
            s3_init: s3,
            s3_clients: HashMap::new(),
            decoder,
            phantom: PhantomData,
        }
    }

    /// Returns the S3 client for `region`, building and caching it on first use.
    ///
    /// The returned value is a clone of the cached client.
    pub fn get_client(&mut self, region: String) -> S {
        // One `entry` lookup serves both the hit and the miss path. On a miss,
        // `region` itself is moved into the cache, so only the factory argument
        // needs a copy of the key.
        match self.s3_clients.entry(region) {
            std::collections::hash_map::Entry::Occupied(cached) => cached.get().clone(),
            std::collections::hash_map::Entry::Vacant(slot) => {
                let client = (self.s3_init)(slot.key().clone());
                slot.insert(client).clone()
            }
        }
    }
}

#[async_trait]
impl<S, SInit, D, E> PayloadRetriever<E> for S3PayloadRetriever<S, SInit, D, E>
where
    S: S3 + Clone + Send + Sync + 'static,
    SInit: (Fn(String) -> S) + Clone + Send + Sync + 'static,
    D: PayloadDecoder<E> + Clone + Send + 'static,
    E: Send + 'static,
{
    type Message = SqsMessage;

    /// Resolves the S3 event notification carried in an SQS message, downloads
    /// the referenced object and decodes it with the configured decoder.
    ///
    /// Only the first entry of the notification's `Records` array is processed.
    ///
    /// # Returns
    /// * `Ok(None)` when the notification's `Event` field is `s3:TestEvent`.
    /// * `Ok(Some(event))` with the decoded payload otherwise.
    ///
    /// # Errors
    /// Returns an error (instead of panicking) when:
    /// * the message has no body, or the body is not JSON;
    /// * the bucket name, object key or `awsRegion` field is missing;
    /// * the object key is not valid UTF-8 after URL-decoding;
    /// * `GetObject` fails or does not respond within the timeout;
    /// * the response has no body, or reading the body fails;
    /// * the decoder rejects the payload.
    #[tracing::instrument(skip(self, msg))]
    async fn retrieve_event(&mut self, msg: &Self::Message) -> Result<Option<E>, Box<dyn Error>> {
        // Bounds the wait for the GetObject response; the body download below is
        // not covered by this timeout.
        const GET_OBJECT_TIMEOUT: Duration = Duration::from_secs(5);
        // Floor for the initial capacity of the download buffer.
        const MIN_PREALLOC: usize = 1024;

        let body = msg.body.as_deref().ok_or("SQS message has no body")?;
        info!("Got body from message: {}", body);
        let event: serde_json::Value = serde_json::from_str(body)?;

        if event.get("Event").and_then(serde_json::Value::as_str) == Some("s3:TestEvent") {
            return Ok(None);
        }

        // Indexing a serde_json::Value yields Null for missing entries, so the
        // `as_str` checks below also cover an empty or absent `Records` array.
        let first_record = &event["Records"][0];
        let record = &first_record["s3"];

        let bucket = record["bucket"]["name"]
            .as_str()
            .ok_or("S3 event record is missing bucket name")?;
        let raw_key = record["object"]["key"]
            .as_str()
            .ok_or("S3 event record is missing object key")?;
        let key = decode_s3_key(raw_key)?;
        let region = first_record["awsRegion"]
            .as_str()
            .ok_or("S3 event record is missing awsRegion")?;

        let s3 = self.get_client(region.to_string());
        let s3_data = s3.get_object(GetObjectRequest {
            bucket: bucket.to_string(),
            key,
            ..Default::default()
        });

        let s3_data = tokio::time::timeout(GET_OBJECT_TIMEOUT, s3_data).await??;

        let object_size = record["object"]["size"].as_u64().unwrap_or_default();
        // Saturate rather than truncate on targets where usize is narrower than u64.
        let prealloc = (object_size.min(usize::MAX as u64) as usize).max(MIN_PREALLOC);

        info!("Retrieved s3 payload with size : {:?}", object_size);

        let mut body = Vec::with_capacity(prealloc);

        s3_data
            .body
            .ok_or("Missing S3 body")?
            .into_async_read()
            .read_to_end(&mut body)
            .await?;

        info!("Read s3 payload body");
        self.decoder.decode(body).map(Option::from)
    }
}

/// Decodes an object key taken from an S3 event notification.
///
/// S3 URL-encodes keys in notifications (`+` for a space, `%XX` for other
/// bytes), so the raw value must be decoded before it is passed to `GetObject`.
/// A `%` not followed by two hex digits is kept verbatim rather than rejected.
///
/// # Errors
/// Returns an error if the decoded bytes are not valid UTF-8.
fn decode_s3_key(raw: &str) -> Result<String, std::string::FromUtf8Error> {
    let bytes = raw.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let hex_digit = |idx: usize| bytes.get(idx).and_then(|&c| (c as char).to_digit(16));
                match (hex_digit(i + 1), hex_digit(i + 2)) {
                    (Some(hi), Some(lo)) => {
                        // hi, lo < 16, so the value fits in a byte.
                        decoded.push((hi * 16 + lo) as u8);
                        i += 3;
                    }
                    _ => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8(decoded)
}