use std::fmt::Debug;
use std::time::{Duration, Instant};
use futures::compat::Future01CompatExt;
use log::*;
use rusoto_sqs::{DeleteMessageBatchRequest, DeleteMessageBatchRequestEntry, Sqs, SqsClient};
use rusoto_sqs::Message as SqsMessage;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use async_trait::async_trait;
use crate::cache::Cache;
use crate::completion_event_serializer::CompletionEventSerializer;
use crate::event_emitter::EventEmitter;
use crate::event_handler::{OutputEvent, Completion};
use aktors::actor::Actor;
use crate::completion_handler::CompletionHandler;
pub struct CompletionPolicy {
max_messages: u16,
max_time_between_flushes: Duration,
last_flush: Instant,
}
impl CompletionPolicy {
pub fn new(max_messages: u16, max_time_between_flushes: Duration) -> Self {
Self {
max_messages,
max_time_between_flushes,
last_flush: Instant::now(),
}
}
pub fn should_flush(&self, cur_messages: u16) -> bool {
cur_messages >= self.max_messages ||
Instant::now().checked_duration_since(self.last_flush).unwrap() >= self.max_time_between_flushes
}
pub fn set_last_flush(&mut self) {
self.last_flush = Instant::now();
}
}
pub struct SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
where
SqsT: Sqs + Clone + Send + Sync + 'static,
CPE: Debug + Send + Sync + 'static,
CP: CompletionEventSerializer<CompletedEvent=CE, Output=Payload, Error=CPE> + Send + Sync + 'static,
Payload: Send + Sync + 'static,
CE: Send + Sync + Clone + 'static,
EE: EventEmitter<Event=Payload> + Send + Sync + 'static,
OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>) + Send + Sync + 'static,
CacheT: Cache<ProcErr> + Send + Sync + Clone + 'static,
ProcErr: Debug + Clone + Send + Sync + 'static,
{
sqs_client: SqsT,
queue_url: String,
completed_events: Vec<CE>,
identities: Vec<Vec<u8>>,
completed_messages: Vec<SqsMessage>,
completion_serializer: CP,
event_emitter: EE,
completion_policy: CompletionPolicy,
on_ack: OA,
self_actor: Option<SqsCompletionHandlerActor<CE, ProcErr, SqsT>>,
cache: CacheT,
_p: std::marker::PhantomData<(ProcErr)>
}
impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr> SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
where
SqsT: Sqs + Clone + Send + Sync + 'static,
CPE: Debug + Send + Sync + 'static,
CP: CompletionEventSerializer<CompletedEvent=CE, Output=Payload, Error=CPE> + Send + Sync + 'static,
Payload: Send + Sync + 'static,
CE: Send + Sync + Clone + 'static,
EE: EventEmitter<Event=Payload> + Send + Sync + 'static,
OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>) + Send + Sync + 'static,
CacheT: Cache<ProcErr> + Send + Sync + Clone + 'static,
ProcErr: Debug + Clone + Send + Sync + 'static,
{
pub fn new(
sqs_client: SqsT,
queue_url: String,
completion_serializer: CP,
event_emitter: EE,
completion_policy: CompletionPolicy,
on_ack: OA,
cache: CacheT,
) -> Self {
Self {
sqs_client,
queue_url,
completed_events: Vec::with_capacity(completion_policy.max_messages as usize),
identities: Vec::with_capacity(completion_policy.max_messages as usize),
completed_messages: Vec::with_capacity(completion_policy.max_messages as usize),
completion_serializer,
event_emitter,
completion_policy,
on_ack,
self_actor: None,
cache,
_p: std::marker::PhantomData,
}
}
}
impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr> SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
where
SqsT: Sqs + Clone + Send + Sync + 'static,
CPE: Debug + Send + Sync + 'static,
CP: CompletionEventSerializer<CompletedEvent=CE, Output=Payload, Error=CPE> + Send + Sync + 'static,
Payload: Send + Sync + 'static,
CE: Send + Sync + Clone + 'static,
EE: EventEmitter<Event=Payload> + Send + Sync + 'static,
OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>) + Send + Sync + 'static,
CacheT: Cache<ProcErr> + Send + Sync + Clone + 'static,
ProcErr: Debug + Clone + Send + Sync + 'static
{
pub async fn mark_complete(
&mut self,
sqs_message: SqsMessage,
completed: OutputEvent<CE, ProcErr>
) {
match completed.completed_event {
Completion::Total(ce) => {
info!("Marking all events complete - total success");
self.completed_events.push(ce);
self.completed_messages.push(sqs_message);
self.identities.extend(completed.identities);
},
Completion::Partial((ce, err)) => {
warn!("EventHandler was only partially successful: {:?}", err);
self.completed_events.push(ce);
self.identities.extend(completed.identities);
},
Completion::Error(e) => {
warn!("Event handler failed: {:?}", e);
}
};
info!(
"Marked event complete. {} completed events, {} completed messages",
self.completed_events.len(),
self.completed_messages.len(),
);
if self.completion_policy.should_flush(self.completed_events.len() as u16) {
self.ack_all(None).await;
self.completion_policy.set_last_flush();
}
}
pub async fn ack_all(&mut self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
info!("Flushing completed events");
let serialized_event = self
.completion_serializer
.serialize_completed_events(&self.completed_events[..]);
let serialized_event = match serialized_event {
Ok(serialized_event) => serialized_event,
Err(e) => {
self.completed_events.clear();
self.completed_messages.clear();
panic!("Serializing events failed: {:?}", e);
}
};
info!("Emitting events");
self.event_emitter.emit_event(serialized_event).await.expect(
"Failed to emit event"
);
for identity in self.identities.drain(..) {
if let Err(e) = self.cache.store(identity).await {
warn!("Failed to cache with: {:?}", e);
}
}
let mut acks = vec![];
for chunk in self.completed_messages.chunks(10) {
let msg_ids: Vec<String> = chunk.iter().map(|msg| {
msg.message_id.clone().unwrap()
}).collect();
let entries = chunk.iter().map(|msg| {
DeleteMessageBatchRequestEntry {
id: msg.message_id.clone().unwrap(),
receipt_handle: msg.receipt_handle.clone().expect("Message missing receipt"),
}
}).collect();
let dmb = self.sqs_client.delete_message_batch(
DeleteMessageBatchRequest {
entries,
queue_url: self.queue_url.clone(),
}
);
let ack = (
tokio::time::timeout(Duration::from_millis(10), dmb).await.expect("timed out"),
msg_ids
);
acks.push(ack);
};
info!("Acking all messages");
for (result, msg_ids) in acks {
match result {
Ok(batch_result) => {
for success in batch_result.successful {
(self.on_ack)(self.self_actor.clone().unwrap(), Ok(success.id))
}
for failure in batch_result.failed {
(self.on_ack)(self.self_actor.clone().unwrap(), Err(failure.id))
}
}
Err(e) => {
for msg_id in msg_ids {
(self.on_ack)(self.self_actor.clone().unwrap(), Err(msg_id))
}
warn!("Failed to acknowledge event: {:?}", e);
}
}
}
info!("Acked");
self.completed_events.clear();
self.completed_messages.clear();
if let Some(notify) = notify {
let _ = notify.send(());
}
}
}
#[allow(non_camel_case_types)]
pub enum SqsCompletionHandlerMessage<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr : Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
mark_complete { msg: SqsMessage, completed: OutputEvent<CE, ProcErr> },
ack_all { notify: Option<tokio::sync::oneshot::Sender<()>> },
_p {
_p: std::marker::PhantomData<(SqsT)>,
},
release,
}
#[async_trait]
impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr> Actor<SqsCompletionHandlerMessage<CE, ProcErr, SqsT>> for SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
where
SqsT: Sqs + Clone + Send + Sync + 'static,
CPE: Debug + Send + Sync + 'static,
CP: CompletionEventSerializer<CompletedEvent=CE, Output=Payload, Error=CPE> + Send + Sync + 'static,
Payload: Send + Sync + 'static,
CE: Send + Sync + Clone + 'static,
EE: EventEmitter<Event=Payload> + Send + Sync + 'static,
OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>) + Send + Sync + 'static,
CacheT: Cache<ProcErr> + Send + Sync + Clone + 'static,
ProcErr: Debug + Clone + Send + Sync + 'static,
{
async fn route_message(&mut self, msg: SqsCompletionHandlerMessage<CE, ProcErr, SqsT>) {
match msg {
SqsCompletionHandlerMessage::mark_complete { msg, completed} => {
self.mark_complete(msg, completed).await
}
SqsCompletionHandlerMessage::ack_all { notify } => self.ack_all(notify).await,
SqsCompletionHandlerMessage::release { } => (),
SqsCompletionHandlerMessage::_p { .. } => (),
};
}
fn close(&mut self) {
self.self_actor = None;
}
fn get_actor_name(&self) -> &str {
&self.self_actor.as_ref().unwrap().actor_name
}
}
pub struct SqsCompletionHandlerActor<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr: Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
sender: Sender<SqsCompletionHandlerMessage<CE, ProcErr, SqsT>>,
inner_rc: std::sync::Arc<std::sync::atomic::AtomicUsize>,
queue_len: std::sync::Arc<std::sync::atomic::AtomicUsize>,
actor_name: String,
actor_uuid: uuid::Uuid,
actor_num: u32,
}
impl<CE, ProcErr, SqsT> Clone for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr: Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
fn clone(&self) -> Self {
self.inner_rc.clone().fetch_add(1, std::sync::atomic::Ordering::SeqCst);
Self {
sender: self.sender.clone(),
inner_rc: self.inner_rc.clone(),
queue_len: self.queue_len.clone(),
actor_name: format!(
"{} {} {}",
stringify!(SqsCompletionHandlerActor),
self.actor_uuid,
self.actor_num + 1,
),
actor_uuid: self.actor_uuid,
actor_num: self.actor_num + 1,
}
}
}
impl<CE, ProcErr, SqsT> SqsCompletionHandlerActor<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr: Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
pub fn new<CPE, CP, Payload, EE, OA, CacheT>(mut actor_impl: SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>) -> (Self, tokio::task::JoinHandle<()>)
where
SqsT: Sqs + Clone + Send + Sync + 'static,
CPE: Debug + Send + Sync + 'static,
CP: CompletionEventSerializer<CompletedEvent=CE, Output=Payload, Error=CPE>
+ Send
+ Sync
+ 'static,
Payload: Send + Sync + 'static,
EE: EventEmitter<Event=Payload> + Send + Sync + 'static,
OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>) + Send + Sync + 'static,
CacheT: Cache<ProcErr> + Send + Sync + Clone + 'static,
{
let (sender, receiver) = channel(1);
let inner_rc = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(1));
let queue_len = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
let actor_uuid = uuid::Uuid::new_v4();
let actor_name = format!(
"{} {} {}",
stringify!(#actor_ty),
actor_uuid,
0,
);
let mut self_actor = Self {
sender,
inner_rc: inner_rc.clone(),
queue_len: queue_len.clone(),
actor_name,
actor_uuid,
actor_num: 0,
};
actor_impl.self_actor = Some(self_actor.clone());
let handle = tokio::task::spawn(
aktors::actor::route_wrapper(
aktors::actor::Router::new(
actor_impl,
receiver,
inner_rc,
queue_len
)
)
);
(self_actor, handle)
}
pub async fn release(self) {
let msg = SqsCompletionHandlerMessage::release;
let mut sender = self.sender.clone();
let queue_len = self.queue_len.clone();
queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
tokio::task::spawn(
async move {
if let Err(e) = sender.send(msg).await {
panic!(
concat!(
"Receiver has failed with {}, propagating error. ",
stringify!(#ident)
),
e
)
}
}
);
}
pub async fn mark_complete(&self, msg: SqsMessage, completed: OutputEvent<CE, ProcErr>) {
let msg = SqsCompletionHandlerMessage::mark_complete { msg, completed};
let mut sender = self.sender.clone();
let queue_len = self.queue_len.clone();
queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
tokio::task::spawn(
async move {
if let Err(e) = sender.send(msg).await {
panic!(
concat!(
"Receiver has failed with {}, propagating error. ",
stringify!(#ident)
),
e
)
}
}
);
}
async fn ack_all(&self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
let msg = SqsCompletionHandlerMessage::ack_all { notify };
let mut sender = self.sender.clone();
let queue_len = self.queue_len.clone();
queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
tokio::task::spawn(
async move {
if let Err(e) = sender.send(msg).await {
panic!(
concat!(
"Receiver has failed with {}, propagating error. ",
stringify!(#ident)
),
e
)
}
}
);
}
async fn _p(&self, _p: std::marker::PhantomData<(SqsT)>) {
panic!("Invalid to call p");
let msg = SqsCompletionHandlerMessage::_p { _p };
if let Err(_e) = self.sender.clone().send(msg).await {
panic!("Receiver has failed, propagating error. _p")
}
self.queue_len.clone().fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
impl<CE, ProcErr, SqsT> Drop for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr: Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
fn drop(&mut self) {
self.inner_rc.clone().fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
}
}
#[async_trait]
impl<CE, ProcErr, SqsT> CompletionHandler for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr: Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
type Message = SqsMessage;
type CompletedEvent = OutputEvent<CE, ProcErr>;
async fn mark_complete(&self, msg: Self::Message, completed_event: Self::CompletedEvent) {
SqsCompletionHandlerActor::mark_complete(
self, msg, completed_event
).await
}
async fn ack_all(&self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
SqsCompletionHandlerActor::ack_all(
self, notify,
).await
}
async fn release(self) {
SqsCompletionHandlerActor::release(
self
).await
}
}
impl<CE, ProcErr, SqsT> aktors::actor::Message for SqsCompletionHandlerMessage<CE, ProcErr, SqsT>
where
CE: Send + Sync + Clone + 'static,
ProcErr : Debug + Send + Sync + Clone + 'static,
SqsT: Sqs + Clone + Send + Sync + 'static,
{
fn is_release(&self) -> bool {
if let Self::release = self {
true
} else {
false
}
}
}