1use std::fmt::Debug;
2use std::time::{Duration, Instant};
3
4use log::*;
5use rusoto_sqs::Message as SqsMessage;
6use rusoto_sqs::{DeleteMessageBatchRequest, DeleteMessageBatchRequestEntry, Sqs};
7use tokio::sync::mpsc::{channel, Sender};
8
9use crate::cache::Cache;
10use crate::completion_event_serializer::CompletionEventSerializer;
11use crate::event_emitter::EventEmitter;
12use crate::event_handler::{Completion, OutputEvent};
13use aktors::actor::Actor;
14use async_trait::async_trait;
15
16use crate::completion_handler::CompletionHandler;
17use color_eyre::Help;
18
/// Decides when buffered completions should be flushed: once a count
/// threshold is reached, or once enough time has passed since the last flush.
#[derive(Debug, Clone)]
pub struct CompletionPolicy {
    /// Flush as soon as this many items are pending.
    max_messages: u16,
    /// Flush once at least this much time has elapsed since `last_flush`.
    max_time_between_flushes: Duration,
    /// When the last flush was recorded (construction time initially).
    last_flush: Instant,
}

impl CompletionPolicy {
    /// Creates a policy whose flush timer starts now.
    pub fn new(max_messages: u16, max_time_between_flushes: Duration) -> Self {
        Self {
            max_messages,
            max_time_between_flushes,
            last_flush: Instant::now(),
        }
    }

    /// Returns `true` if `cur_messages` has reached `max_messages`, or if at
    /// least `max_time_between_flushes` has elapsed since the last flush.
    pub fn should_flush(&self, cur_messages: u16) -> bool {
        // `saturating_duration_since` yields a zero duration instead of `None`
        // if `now` ever compares earlier than `last_flush`, so this check can
        // no longer panic the way `checked_duration_since(..).unwrap()` could.
        cur_messages >= self.max_messages
            || Instant::now().saturating_duration_since(self.last_flush)
                >= self.max_time_between_flushes
    }

    /// Restarts the flush timer from the current instant.
    pub fn set_last_flush(&mut self) {
        self.last_flush = Instant::now();
    }
}
47
48pub struct SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
49where
50 SqsT: Sqs + Clone + Send + Sync + 'static,
51 CPE: Debug + Send + Sync + 'static,
52 CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
53 + Send
54 + Sync
55 + 'static,
56 Payload: Send + Sync + 'static,
57 CE: Send + Sync + Clone + 'static,
58 EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
59 OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
60 + Send
61 + Sync
62 + 'static,
63 CacheT: Cache + Send + Sync + Clone + 'static,
64 ProcErr: Debug + Send + Sync + 'static,
65{
66 sqs_client: SqsT,
67 queue_url: String,
68 completed_events: Vec<CE>,
69 identities: Vec<Vec<u8>>,
70 completed_messages: Vec<SqsMessage>,
71 completion_serializer: CP,
72 event_emitter: EE,
73 completion_policy: CompletionPolicy,
74 on_ack: OA,
75 self_actor: Option<SqsCompletionHandlerActor<CE, ProcErr, SqsT>>,
76 cache: CacheT,
77 _p: std::marker::PhantomData<ProcErr>,
78}
79
80impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
81 SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
82where
83 SqsT: Sqs + Clone + Send + Sync + 'static,
84 CPE: Debug + Send + Sync + 'static,
85 CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
86 + Send
87 + Sync
88 + 'static,
89 Payload: Send + Sync + 'static,
90 CE: Send + Sync + Clone + 'static,
91 EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
92 OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
93 + Send
94 + Sync
95 + 'static,
96 CacheT: Cache + Send + Sync + Clone + 'static,
97 ProcErr: Debug + Send + Sync + 'static,
98{
99 pub fn new(
100 sqs_client: SqsT,
101 queue_url: String,
102 completion_serializer: CP,
103 event_emitter: EE,
104 completion_policy: CompletionPolicy,
105 on_ack: OA,
106 cache: CacheT,
107 ) -> Self {
108 Self {
109 sqs_client,
110 queue_url,
111 completed_events: Vec::with_capacity(completion_policy.max_messages as usize),
112 identities: Vec::with_capacity(completion_policy.max_messages as usize),
113 completed_messages: Vec::with_capacity(completion_policy.max_messages as usize),
114 completion_serializer,
115 event_emitter,
116 completion_policy,
117 on_ack,
118 self_actor: None,
119 cache,
120 _p: std::marker::PhantomData,
121 }
122 }
123}
124
125async fn retry<F, T, E>(max_tries: u32, f: impl Fn() -> F) -> color_eyre::Result<T>
126where
127 T: Send,
128 F: std::future::Future<Output = Result<T, E>>,
129 E: std::error::Error + Send + Sync + 'static,
130{
131 let mut backoff = 2;
132 let mut errs: Result<T, _> = Err(eyre::eyre!("wait_loop failed"));
133 for i in 0..max_tries {
134 match (f)().await {
135 Ok(t) => return Ok(t),
136 Err(e) => {
137 errs = errs.error(e);
138 }
139 };
140
141 tokio::time::delay_for(Duration::from_millis(backoff)).await;
142 backoff *= i as u64;
143 }
144
145 errs
146}
147
148impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
149 SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
150where
151 SqsT: Sqs + Clone + Send + Sync + 'static,
152 CPE: Debug + Send + Sync + 'static,
153 CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
154 + Send
155 + Sync
156 + 'static,
157 Payload: Send + Sync + 'static,
158 CE: Send + Sync + Clone + 'static,
159 EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
160 OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
161 + Send
162 + Sync
163 + 'static,
164 CacheT: Cache + Send + Sync + Clone + 'static,
165 ProcErr: Debug + Send + Sync + 'static,
166{
167 #[tracing::instrument(skip(self))]
168 pub async fn ack_message(&mut self, sqs_message: SqsMessage) {
169 self.completed_messages.push(sqs_message);
170 if self
171 .completion_policy
172 .should_flush(self.completed_events.len() as u16)
173 {
174 self.ack_all(None).await;
175 self.completion_policy.set_last_flush();
176 }
177 }
178
179 #[tracing::instrument(skip(self, completed))]
180 pub async fn mark_complete(
181 &mut self,
182 sqs_message: SqsMessage,
183 completed: OutputEvent<CE, ProcErr>,
184 ) {
185 match completed.completed_event {
186 Completion::Total(ce) => {
187 info!("Marking all events complete - total success");
188 self.completed_events.push(ce);
189 self.completed_messages.push(sqs_message);
190 self.identities.extend(completed.identities);
191 }
192 Completion::Partial((ce, err)) => {
193 warn!("EventHandler was only partially successful: {:?}", err);
194 self.completed_events.push(ce);
195 self.identities.extend(completed.identities);
196 }
197 Completion::Error(e) => {
198 warn!("Event handler failed: {:?}", e);
199 }
200 };
201
202 info!(
203 "Marked event complete. {} completed events, {} completed messages",
204 self.completed_events.len(),
205 self.completed_messages.len(),
206 );
207
208 if self
209 .completion_policy
210 .should_flush(self.completed_events.len() as u16)
211 {
212 self.ack_all(None).await;
213 self.completion_policy.set_last_flush();
214 }
215 }
216
217 #[tracing::instrument(skip(self, notify))]
218 pub async fn ack_all(&mut self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
219 debug!("Flushing completed events");
220
221 let serialized_event = self
222 .completion_serializer
223 .serialize_completed_events(&self.completed_events[..]);
224
225 let serialized_event = match serialized_event {
226 Ok(serialized_event) => serialized_event,
227 Err(e) => {
228 self.completed_events.clear();
230 self.completed_messages.clear();
231
232 panic!("Serializing events failed: {:?}", e);
233 }
234 };
235
236 debug!("Emitting events");
237 self.event_emitter
238 .emit_event(serialized_event)
239 .await
240 .expect("Failed to emit event");
241
242 for identity in self.identities.drain(..) {
243 if let Err(e) = self.cache.store(identity).await {
244 warn!("Failed to cache with: {:?}", e);
245 }
246 }
247
248 let mut acks = vec![];
249
250 for chunk in self.completed_messages.chunks(10) {
251 let msg_ids: Vec<String> = chunk
252 .iter()
253 .map(|msg| msg.message_id.clone().unwrap())
254 .collect();
255
256 let entries: Vec<_> = chunk
257 .iter()
258 .map(|msg| DeleteMessageBatchRequestEntry {
259 id: msg.message_id.clone().unwrap(),
260 receipt_handle: msg.receipt_handle.clone().expect("Message missing receipt"),
261 })
262 .collect();
263
264 match retry(10, || async {
265 let dmb = self
266 .sqs_client
267 .delete_message_batch(DeleteMessageBatchRequest {
268 entries: entries.clone(),
269 queue_url: self.queue_url.clone(),
270 });
271
272 tokio::time::timeout(Duration::from_millis(250), dmb).await
273 })
274 .await
275 {
276 Ok(dmb) => acks.push((dmb, msg_ids)),
277 Err(e) => warn!("Failed to delete message, timed out: {:?}", e),
278 };
279 }
280
281 debug!("Acking all messages");
282
283 for (result, msg_ids) in acks {
284 match result {
285 Ok(batch_result) => {
286 for success in batch_result.successful {
287 (self.on_ack)(self.self_actor.clone().unwrap(), Ok(success.id))
288 }
289
290 for failure in batch_result.failed {
291 (self.on_ack)(self.self_actor.clone().unwrap(), Err(failure.id))
292 }
293 }
294 Err(e) => {
295 for msg_id in msg_ids {
296 (self.on_ack)(self.self_actor.clone().unwrap(), Err(msg_id))
297 }
298 warn!("Failed to acknowledge event: {:?}", e);
299 }
300 }
301 }
303 debug!("Acked");
304
305 self.completed_events.clear();
306 self.completed_messages.clear();
307
308 if let Some(notify) = notify {
309 let _ = notify.send(());
310 }
311 }
312}
313
314#[allow(non_camel_case_types)]
315pub enum SqsCompletionHandlerMessage<CE, ProcErr, SqsT>
316where
317 CE: Send + Sync + Clone + 'static,
318 ProcErr: Debug + Send + Sync + 'static,
319 SqsT: Sqs + Clone + Send + Sync + 'static,
320{
321 mark_complete {
322 msg: SqsMessage,
323 completed: OutputEvent<CE, ProcErr>,
324 },
325 ack_message {
326 msg: SqsMessage,
327 },
328 ack_all {
329 notify: Option<tokio::sync::oneshot::Sender<()>>,
330 },
331 _p {
332 _p: std::marker::PhantomData<SqsT>,
333 },
334}
335
336#[async_trait]
337impl<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
338 Actor<SqsCompletionHandlerMessage<CE, ProcErr, SqsT>>
339 for SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>
340where
341 SqsT: Sqs + Clone + Send + Sync + 'static,
342 CPE: Debug + Send + Sync + 'static,
343 CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
344 + Send
345 + Sync
346 + 'static,
347 Payload: Send + Sync + 'static,
348 CE: Send + Sync + Clone + 'static,
349 EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
350 OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
351 + Send
352 + Sync
353 + 'static,
354 CacheT: Cache + Send + Sync + Clone + 'static,
355 ProcErr: Debug + Send + Sync + 'static,
356{
357 #[tracing::instrument(skip(self, msg))]
358 async fn route_message(&mut self, msg: SqsCompletionHandlerMessage<CE, ProcErr, SqsT>) {
359 match msg {
360 SqsCompletionHandlerMessage::mark_complete { msg, completed } => {
361 self.mark_complete(msg, completed).await
362 }
363 SqsCompletionHandlerMessage::ack_all { notify } => self.ack_all(notify).await,
364 SqsCompletionHandlerMessage::ack_message { msg } => self.ack_message(msg).await,
365 SqsCompletionHandlerMessage::_p { .. } => (),
366 };
367 }
368
369 fn close(&mut self) {
370 self.self_actor = None;
371 }
372
373 fn get_actor_name(&self) -> &str {
374 &self.self_actor.as_ref().unwrap().actor_name
375 }
376}
377
378pub struct SqsCompletionHandlerActor<CE, ProcErr, SqsT>
379where
380 CE: Send + Sync + Clone + 'static,
381 ProcErr: Debug + Send + Sync + 'static,
382 SqsT: Sqs + Clone + Send + Sync + 'static,
383{
384 sender: Sender<SqsCompletionHandlerMessage<CE, ProcErr, SqsT>>,
385 inner_rc: std::sync::Arc<std::sync::atomic::AtomicUsize>,
386 queue_len: std::sync::Arc<std::sync::atomic::AtomicUsize>,
387 actor_name: String,
388 actor_uuid: uuid::Uuid,
389 actor_num: u32,
390}
391
392impl<CE, ProcErr, SqsT> Clone for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
393where
394 CE: Send + Sync + Clone + 'static,
395 ProcErr: Debug + Send + Sync + 'static,
396 SqsT: Sqs + Clone + Send + Sync + 'static,
397{
398 fn clone(&self) -> Self {
399 self.inner_rc
400 .clone()
401 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
402
403 Self {
404 sender: self.sender.clone(),
405 inner_rc: self.inner_rc.clone(),
406 queue_len: self.queue_len.clone(),
407 actor_name: format!(
408 "{} {} {}",
409 stringify!(SqsCompletionHandlerActor),
410 self.actor_uuid,
411 self.actor_num + 1,
412 ),
413 actor_uuid: self.actor_uuid,
414 actor_num: self.actor_num + 1,
415 }
416 }
417}
418
419impl<CE, ProcErr, SqsT> SqsCompletionHandlerActor<CE, ProcErr, SqsT>
420where
421 CE: Send + Sync + Clone + 'static,
422 ProcErr: Debug + Send + Sync + 'static,
423 SqsT: Sqs + Clone + Send + Sync + 'static,
424{
425 pub fn new<CPE, CP, Payload, EE, OA, CacheT>(
426 mut actor_impl: SqsCompletionHandler<SqsT, CPE, CP, CE, Payload, EE, OA, CacheT, ProcErr>,
427 ) -> (Self, tokio::task::JoinHandle<()>)
428 where
429 SqsT: Sqs + Clone + Send + Sync + 'static,
430 CPE: Debug + Send + Sync + 'static,
431 CP: CompletionEventSerializer<CompletedEvent = CE, Output = Payload, Error = CPE>
432 + Send
433 + Sync
434 + 'static,
435 Payload: Send + Sync + 'static,
436 EE: EventEmitter<Event = Payload> + Send + Sync + 'static,
437 OA: Fn(SqsCompletionHandlerActor<CE, ProcErr, SqsT>, Result<String, String>)
438 + Send
439 + Sync
440 + 'static,
441 CacheT: Cache + Send + Sync + Clone + 'static,
442 {
443 let (sender, receiver) = channel(1);
444 let inner_rc = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(1));
445
446 let queue_len = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
447
448 let actor_uuid = uuid::Uuid::new_v4();
449 let actor_name = format!("{} {} {}", stringify!(#actor_ty), actor_uuid, 0,);
450 let self_actor = Self {
451 sender,
452 inner_rc: inner_rc.clone(),
453 queue_len: queue_len.clone(),
454 actor_name,
455 actor_uuid,
456 actor_num: 0,
457 };
458
459 actor_impl.self_actor = Some(self_actor.clone());
460
461 let handle = tokio::task::spawn(aktors::actor::route_wrapper(aktors::actor::Router::new(
462 actor_impl, receiver, inner_rc, queue_len,
463 )));
464
465 (self_actor, handle)
466 }
467
468 pub async fn mark_complete(&self, msg: SqsMessage, completed: OutputEvent<CE, ProcErr>) {
469 let msg = SqsCompletionHandlerMessage::mark_complete { msg, completed };
470 let mut sender = self.sender.clone();
471
472 let queue_len = self.queue_len.clone();
473 queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
474
475 tokio::task::spawn(async move {
476 if let Err(e) = sender.send(msg).await {
477 panic!(
478 "Receiver has failed with {}, propagating error. SqsCompletionHandler",
479 e
480 )
481 }
482 });
483 }
484
485 pub async fn ack_message(&self, msg: SqsMessage) {
486 let msg = SqsCompletionHandlerMessage::ack_message { msg };
487 let mut sender = self.sender.clone();
488
489 let queue_len = self.queue_len.clone();
490 queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
491
492 tokio::task::spawn(async move {
493 if let Err(e) = sender.send(msg).await {
494 panic!(
495 concat!(
496 "Receiver has failed with {}, propagating error. ",
497 "SqsCompletionHandler"
498 ),
499 e
500 )
501 }
502 });
503 }
504
505 async fn ack_all(&self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
506 let msg = SqsCompletionHandlerMessage::ack_all { notify };
507 let mut sender = self.sender.clone();
508
509 let queue_len = self.queue_len.clone();
510 queue_len.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
511
512 tokio::task::spawn(async move {
513 if let Err(e) = sender.send(msg).await {
514 panic!(
515 "Receiver has failed with {}, propagating error. SqsCompletionHandler",
516 e
517 )
518 }
519 });
520 }
521
522 async fn _p(&self, _p: std::marker::PhantomData<SqsT>) {
523 panic!("Invalid to call p");
524 let msg = SqsCompletionHandlerMessage::_p { _p };
525 if let Err(_e) = self.sender.clone().send(msg).await {
526 panic!("Receiver has failed, propagating error. _p")
527 }
528 self.queue_len
529 .clone()
530 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
531 }
532}
533
534impl<CE, ProcErr, SqsT> Drop for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
535where
536 CE: Send + Sync + Clone + 'static,
537 ProcErr: Debug + Send + Sync + 'static,
538 SqsT: Sqs + Clone + Send + Sync + 'static,
539{
540 fn drop(&mut self) {
541 self.inner_rc
542 .clone()
543 .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
544 }
545}
546
547#[async_trait]
548impl<CE, ProcErr, SqsT> CompletionHandler for SqsCompletionHandlerActor<CE, ProcErr, SqsT>
549where
550 CE: Send + Sync + Clone + 'static,
551 ProcErr: Debug + Send + Sync + 'static,
552 SqsT: Sqs + Clone + Send + Sync + 'static,
553{
554 type Message = SqsMessage;
555 type CompletedEvent = OutputEvent<CE, ProcErr>;
556
557 async fn mark_complete(&self, msg: Self::Message, completed_event: Self::CompletedEvent) {
558 SqsCompletionHandlerActor::mark_complete(self, msg, completed_event).await
559 }
560
561 async fn ack_message(&self, msg: Self::Message) {
562 SqsCompletionHandlerActor::ack_message(self, msg).await
563 }
564
565 async fn ack_all(&self, notify: Option<tokio::sync::oneshot::Sender<()>>) {
566 SqsCompletionHandlerActor::ack_all(self, notify).await
567 }
568}