// sqs_lambda/sqs_completion_handler.rs

1use std::fmt::Debug;
2use std::time::{Duration, Instant};
3
4use log::*;
5use rusoto_sqs::Message as SqsMessage;
6use rusoto_sqs::{DeleteMessageBatchRequest, DeleteMessageBatchRequestEntry, Sqs};
7use tokio::sync::mpsc::{channel, Sender};
8
9use crate::cache::Cache;
10use crate::completion_event_serializer::CompletionEventSerializer;
11use crate::event_emitter::EventEmitter;
12use crate::event_handler::{Completion, OutputEvent};
13use aktors::actor::Actor;
14use async_trait::async_trait;
15
16use crate::completion_handler::CompletionHandler;
17use color_eyre::Help;
18
/// Decides when buffered completions should be flushed.
///
/// A flush is due once `max_messages` items are buffered, or once at least
/// `max_time_between_flushes` has elapsed since the last recorded flush.
#[derive(Debug, Clone)]
pub struct CompletionPolicy {
    /// Buffered-item count at which a flush is forced.
    max_messages: u16,
    /// Longest allowed interval between two flushes.
    max_time_between_flushes: Duration,
    /// Instant of the most recent flush (construction time until the first one).
    last_flush: Instant,
}

impl CompletionPolicy {
    /// Creates a policy whose flush timer starts now.
    pub fn new(max_messages: u16, max_time_between_flushes: Duration) -> Self {
        Self {
            max_messages,
            max_time_between_flushes,
            last_flush: Instant::now(),
        }
    }

    /// Returns `true` when `cur_messages >= max_messages`, or when at least
    /// `max_time_between_flushes` has passed since the last recorded flush.
    pub fn should_flush(&self, cur_messages: u16) -> bool {
        // `saturating_duration_since` yields zero rather than panicking if the
        // clock ever reports `now` earlier than `last_flush`.
        cur_messages >= self.max_messages
            || Instant::now().saturating_duration_since(self.last_flush)
                >= self.max_time_between_flushes
    }

    /// Records the current instant as the time of the most recent flush.
    pub fn set_last_flush(&mut self) {
        self.last_flush = Instant::now();
    }
}
47
48pub struct SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
49where
50    SqsT: Sqs + Clone + Send + Sync + 'static,
51    CPE: Debug + Send + Sync + 'static,
52    CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
53        + Send
54        + Sync
55        + 'static,
56    Payload: Send + Sync + 'static,
57    CE: Send + Sync + Clone + 'static,
58    EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
59    OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
60        + Send
61        + Sync
62        + 'static,
63    CacheT: Cache + Send + Sync + Clone + 'static,
64    ProcErr: Debug + Send + Sync + 'static,
65{
66    sqs_client: SqsT,
67    queue_url: String,
68    completed_events: Vec<CE>,
69    identities: Vec<Vec<u8>>,
70    completed_messages: Vec<SqsMessage>,
71    completion_serializer: CP,
72    event_emitter: EE,
73    completion_policy: CompletionPolicy,
74    on_ack: OA,
75    self_actor: Option<SqsCompletionHandlerActor<CE, ProcErr, SqsT>>,
76    cache: CacheT,
77    _p: std::marker::PhantomData<ProcErr>,
78}
79
80impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
81    SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
82where
83    SqsT: Sqs + Clone + Send + Sync + 'static,
84    CPE: Debug + Send + Sync + 'static,
85    CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
86        + Send
87        + Sync
88        + 'static,
89    Payload: Send + Sync + 'static,
90    CE: Send + Sync + Clone + 'static,
91    EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
92    OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
93        + Send
94        + Sync
95        + 'static,
96    CacheT: Cache + Send + Sync + Clone + 'static,
97    ProcErr: Debug + Send + Sync + 'static,
98{
99    pub fn new(
100        sqs_client: SqsT,
101        queue_url: String,
102        completion_serializer: CP,
103        event_emitter: EE,
104        completion_policy: CompletionPolicy,
105        on_ack: OA,
106        cache: CacheT,
107    ) -> Self {
108        Self {
109            sqs_client,
110            queue_url,
111            completed_events: Vec::with_capacity(completion_policy.max_messages as usize),
112            identities: Vec::with_capacity(completion_policy.max_messages as usize),
113            completed_messages: Vec::with_capacity(completion_policy.max_messages as usize),
114            completion_serializer,
115            event_emitter,
116            completion_policy,
117            on_ack,
118            self_actor: None,
119            cache,
120            _p: std::marker::PhantomData,
121        }
122    }
123}
124
125async fn retry<F, T, E>(max_tries: u32, f: impl Fn() -> F) -> color_eyre::Result<T>
126where
127    T: Send,
128    F: std::future::Future<Output = Result<T, E>>,
129    E: std::error::Error + Send + Sync + 'static,
130{
131    let mut backoff = 2;
132    let mut errs: Result<T, _> = Err(eyre::eyre!("wait_loop failed"));
133    for i in 0..max_tries {
134        match (f)().await {
135            Ok(t) => return Ok(t),
136            Err(e) => {
137                errs = errs.error(e);
138            }
139        };
140
141        tokio::time::delay_for(Duration::from_millis(backoff)).await;
142        backoff *= i as u64;
143    }
144
145    errs
146}
147
148impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
149    SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
150where
151    SqsT: Sqs + Clone + Send + Sync + 'static,
152    CPE: Debug + Send + Sync + 'static,
153    CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
154        + Send
155        + Sync
156        + 'static,
157    Payload: Send + Sync + 'static,
158    CE: Send + Sync + Clone + 'static,
159    EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
160    OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
161        + Send
162        + Sync
163        + 'static,
164    CacheT: Cache + Send + Sync + Clone + 'static,
165    ProcErr: Debug + Send + Sync + 'static,
166{
167    #[tracing::instrument(skip(self))]
168    pub async fn ack_message(&mut self, sqs_message: SqsMessage) {
169        self.completed_messages.push(sqs_message);
170        if self
171            .completion_policy
172            .should_flush(self.completed_events.len() as u16)
173        {
174            self.ack_all(None).await;
175            self.completion_policy.set_last_flush();
176        }
177    }
178
179    #[tracing::instrument(skip(self, completed))]
180    pub async fn mark_complete(
181        &mut self,
182        sqs_message: SqsMessage,
183        completed: OutputEvent<CE, ProcErr>,
184    ) {
185        match completed.completed_event {
186            Completion::Total(ce) => {
187                info!("Marking all events complete - total success");
188                self.completed_events.push(ce);
189                self.completed_messages.push(sqs_message);
190                self.identities.extend(completed.identities);
191            }
192            Completion::Partial((ce, err)) => {
193                warn!("EventHandler was only partially successful: {:?}", err);
194                self.completed_events.push(ce);
195                self.identities.extend(completed.identities);
196            }
197            Completion::Error(e) => {
198                warn!("Event handler failed: {:?}", e);
199            }
200        };
201
202        info!(
203            "Marked event complete. {} completed events, {} completed messages",
204            self.completed_events.len(),
205            self.completed_messages.len(),
206        );
207
208        if self
209            .completion_policy
210            .should_flush(self.completed_events.len() as u16)
211        {
212            self.ack_all(None).await;
213            self.completion_policy.set_last_flush();
214        }
215    }
216
217    #[tracing::instrument(skip(self, notify))]
218    pub async fn ack_all(&mut self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
219        debug!("Flushing completed events");
220
221        let serialized_event = self
222            .completion_serializer
223            .serialize_completed_events(&self.completed_events[..]);
224
225        let serialized_event = match serialized_event {
226            Ok(serialized_event) => serialized_event,
227            Err(e) => {
228                // We should emit a failure, but ultimately we just have to not ack these messages
229                self.completed_events.clear();
230                self.completed_messages.clear();
231
232                panic!("Serializing events failed: {:?}", e);
233            }
234        };
235
236        debug!("Emitting events");
237        self.event_emitter
238            .emit_event(serialized_event)
239            .await
240            .expect("Failed to emit event");
241
242        for identity in self.identities.drain(..) {
243            if let Err(e) = self.cache.store(identity).await {
244                warn!("Failed to cache with: {:?}", e);
245            }
246        }
247
248        let mut acks = vec![];
249
250        for chunk in self.completed_messages.chunks(10) {
251            let msg_ids: Vec<String> = chunk
252                .iter()
253                .map(|msg| msg.message_id.clone().unwrap())
254                .collect();
255
256            let entries: Vec<_> = chunk
257                .iter()
258                .map(|msg| DeleteMessageBatchRequestEntry {
259                    id: msg.message_id.clone().unwrap(),
260                    receipt_handle: msg.receipt_handle.clone().expect("Message missing receipt"),
261                })
262                .collect();
263
264            match retry(10, || async {
265                let dmb = self
266                    .sqs_client
267                    .delete_message_batch(DeleteMessageBatchRequest {
268                        entries: entries.clone(),
269                        queue_url: self.queue_url.clone(),
270                    });
271
272                tokio::time::timeout(Duration::from_millis(250), dmb).await
273            })
274            .await
275            {
276                Ok(dmb) => acks.push((dmb, msg_ids)),
277                Err(e) => warn!("Failed to delete message, timed out: {:?}", e),
278            };
279        }
280
281        debug!("Acking all messages");
282
283        for (result, msg_ids) in acks {
284            match result {
285                Ok(batch_result) => {
286                    for success in batch_result.successful {
287                        (self.on_ack)(self.self_actor.clone().unwrap(), Ok(success.id))
288                    }
289
290                    for failure in batch_result.failed {
291                        (self.on_ack)(self.self_actor.clone().unwrap(), Err(failure.id))
292                    }
293                }
294                Err(e) => {
295                    for msg_id in msg_ids {
296                        (self.on_ack)(self.self_actor.clone().unwrap(), Err(msg_id))
297                    }
298                    warn!("Failed to acknowledge event: {:?}", e);
299                }
300            }
301            // (self.on_ack)(result, message_id);
302        }
303        debug!("Acked");
304
305        self.completed_events.clear();
306        self.completed_messages.clear();
307
308        if let Some(notify) = notify {
309            let _ = notify.send(());
310        }
311    }
312}
313
314#[allow(non_camel_case_types)]
315pub enum SqsCompletionHandlerMessage<CE, ProcErr, SqsT>
316where
317    CE: Send + Sync + Clone + 'static,
318    ProcErr: Debug + Send + Sync + 'static,
319    SqsT: Sqs + Clone + Send + Sync + 'static,
320{
321    mark_complete {
322        msg: SqsMessage,
323        completed: OutputEvent<CE, ProcErr>,
324    },
325    ack_message {
326        msg: SqsMessage,
327    },
328    ack_all {
329        notify: Option<tokio::sync::oneshot::Sender<()>>,
330    },
331    _p {
332        _p: std::marker::PhantomData<SqsT>,
333    },
334}
335
336#[async_trait]
337impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
338    Actor<SqsCompletionHandlerMessage<CE, ProcErr, SqsT>>
339    for SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
340where
341    SqsT: Sqs + Clone + Send + Sync + 'static,
342    CPE: Debug + Send + Sync + 'static,
343    CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
344        + Send
345        + Sync
346        + 'static,
347    Payload: Send + Sync + 'static,
348    CE: Send + Sync + Clone + 'static,
349    EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
350    OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
351        + Send
352        + Sync
353        + 'static,
354    CacheT: Cache + Send + Sync + Clone + 'static,
355    ProcErr: Debug + Send + Sync + 'static,
356{
357    #[tracing::instrument(skip(self, msg))]
358    async fn route_message(&mut self, msg: SqsCompletionHandlerMessage<CE, ProcErr, SqsT>) {
359        match msg {
360            SqsCompletionHandlerMessage::mark_complete { msg, completed } => {
361                self.mark_complete(msg, completed).await
362            }
363            SqsCompletionHandlerMessage::ack_all { notify } => self.ack_all(notify).await,
364            SqsCompletionHandlerMessage::ack_message { msg } => self.ack_message(msg).await,
365            SqsCompletionHandlerMessage::_p { .. } => (),
366        };
367    }
368
369    fn close(&mut self) {
370        self.self_actor = None;
371    }
372
373    fn get_actor_name(&self) -> &str {
374        &self.self_actor.as_ref().unwrap().actor_name
375    }
376}
377
378pub struct SqsCompletionHandlerActor<CE, ProcErr, SqsT>
379where
380    CE: Send + Sync + Clone + 'static,
381    ProcErr: Debug + Send + Sync + 'static,
382    SqsT: Sqs + Clone + Send + Sync + 'static,
383{
384    sender: Sender<SqsCompletionHandlerMessage<CE, ProcErr, SqsT>>,
385    inner_rc: std::sync::Arc<std::sync::atomic::AtomicUsize>,
386    queue_len: std::sync::Arc<std::sync::atomic::AtomicUsize>,
387    actor_name: String,
388    actor_uuid: uuid::Uuid,
389    actor_num: u32,
390}
391
392impl<CE, ProcErr, SqsT> Clone for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
393where
394    CE: Send + Sync + Clone + 'static,
395    ProcErr: Debug + Send + Sync + 'static,
396    SqsT: Sqs + Clone + Send + Sync + 'static,
397{
398    fn clone(&self) -> Self {
399        self.inner_rc
400            .clone()
401            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
402
403        Self {
404            sender: self.sender.clone(),
405            inner_rc: self.inner_rc.clone(),
406            queue_len: self.queue_len.clone(),
407            actor_name: format!(
408                "{} {} {}",
409                stringify!(SqsCompletionHandlerActor),
410                self.actor_uuid,
411                self.actor_num + 1,
412            ),
413            actor_uuid: self.actor_uuid,
414            actor_num: self.actor_num + 1,
415        }
416    }
417}
418
419impl<CE, ProcErr, SqsT> SqsCompletionHandlerActor<CE, ProcErr, SqsT>
420where
421    CE: Send + Sync + Clone + 'static,
422    ProcErr: Debug + Send + Sync + 'static,
423    SqsT: Sqs + Clone + Send + Sync + 'static,
424{
425    pub fn new<CPE, CP, Payload, EE, OA, CacheT>(
426        mut actor_impl: SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>,
427    ) -> (Self, tokio::task::JoinHandle<()>)
428    where
429        SqsT: Sqs + Clone + Send + Sync + 'static,
430        CPE: Debug + Send + Sync + 'static,
431        CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
432            + Send
433            + Sync
434            + 'static,
435        Payload: Send + Sync + 'static,
436        EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
437        OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
438            + Send
439            + Sync
440            + 'static,
441        CacheT: Cache + Send + Sync + Clone + 'static,
442    {
443        let (sender, receiver) = channel(1);
444        let inner_rc = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(1));
445
446        let queue_len = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
447
448        let actor_uuid = uuid::Uuid::new_v4();
449        let actor_name = format!("{} {} {}", stringify!(#actor_ty), actor_uuid, 0,);
450        let self_actor = Self {
451            sender,
452            inner_rc: inner_rc.clone(),
453            queue_len: queue_len.clone(),
454            actor_name,
455            actor_uuid,
456            actor_num: 0,
457        };
458
459        actor_impl.self_actor = Some(self_actor.clone());
460
461        let handle = tokio::task::spawn(aktors::actor::route_wrapper(aktors::actor::Router::new(
462            actor_impl, receiver, inner_rc, queue_len,
463        )));
464
465        (self_actor, handle)
466    }
467
468    pub async fn mark_complete(&self, msg: SqsMessage, completed: OutputEvent<CE, ProcErr>) {
469        let msg = SqsCompletionHandlerMessage::mark_complete { msg, completed };
470        let mut sender = self.sender.clone();
471
472        let queue_len = self.queue_len.clone();
473        queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
474
475        tokio::task::spawn(async move {
476            if let Err(e) = sender.send(msg).await {
477                panic!(
478                    "Receiver has failed with {}, propagating error. SqsCompletionHandler",
479                    e
480                )
481            }
482        });
483    }
484
485    pub async fn ack_message(&self, msg: SqsMessage) {
486        let msg = SqsCompletionHandlerMessage::ack_message { msg };
487        let mut sender = self.sender.clone();
488
489        let queue_len = self.queue_len.clone();
490        queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
491
492        tokio::task::spawn(async move {
493            if let Err(e) = sender.send(msg).await {
494                panic!(
495                    concat!(
496                        "Receiver has failed with {}, propagating error. ",
497                        "SqsCompletionHandler"
498                    ),
499                    e
500                )
501            }
502        });
503    }
504
505    async fn ack_all(&self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
506        let msg = SqsCompletionHandlerMessage::ack_all { notify };
507        let mut sender = self.sender.clone();
508
509        let queue_len = self.queue_len.clone();
510        queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
511
512        tokio::task::spawn(async move {
513            if let Err(e) = sender.send(msg).await {
514                panic!(
515                    "Receiver has failed with {}, propagating error. SqsCompletionHandler",
516                    e
517                )
518            }
519        });
520    }
521
522    async fn _p(&self, _p: std::marker::PhantomData<SqsT>) {
523        panic!("Invalid to call p");
524        let msg = SqsCompletionHandlerMessage::_p { _p };
525        if let Err(_e) = self.sender.clone().send(msg).await {
526            panic!("Receiver has failed, propagating error. _p")
527        }
528        self.queue_len
529            .clone()
530            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
531    }
532}
533
534impl<CE, ProcErr, SqsT> Drop for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
535where
536    CE: Send + Sync + Clone + 'static,
537    ProcErr: Debug + Send + Sync + 'static,
538    SqsT: Sqs + Clone + Send + Sync + 'static,
539{
540    fn drop(&mut self) {
541        self.inner_rc
542            .clone()
543            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
544    }
545}
546
547#[async_trait]
548impl<CE, ProcErr, SqsT> CompletionHandler for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
549where
550    CE: Send + Sync + Clone + 'static,
551    ProcErr: Debug + Send + Sync + 'static,
552    SqsT: Sqs + Clone + Send + Sync + 'static,
553{
554    type Message = SqsMessage;
555    type CompletedEvent = OutputEvent<CE, ProcErr>;
556
557    async fn mark_complete(&self, msg: Self::Message, completed_event: Self::CompletedEvent) {
558        SqsCompletionHandlerActor::mark_complete(self, msg, completed_event).await
559    }
560
561    async fn ack_message(&self, msg: Self::Message) {
562        SqsCompletionHandlerActor::ack_message(self, msg).await
563    }
564
565    async fn ack_all(&self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
566        SqsCompletionHandlerActor::ack_all(self, notify).await
567    }
568}