1use std::time::Duration;
2
3use lambda_runtime::Context;
4use log::{debug, error};
5use rusoto_sqs::Message as SqsMessage;
6use rusoto_sqs::{ReceiveMessageRequest, Sqs};
7use tokio::sync::mpsc::{channel, Sender};
8
9use tracing::instrument;
10
11use crate::consumer::Consumer;
12use async_trait::async_trait;
13
14use crate::completion_handler::CompletionHandler;
15use crate::event_processor::EventProcessorActor;
16
17use chrono::Utc;
18use std::marker::PhantomData;
19
/// Builder for [`ConsumePolicy`].
///
/// Every option is optional; `build` substitutes a default for anything left unset.
#[derive(Debug, Clone, Default)]
pub struct ConsumePolicyBuilder {
    // NOTE(review): no builder method sets or reads this field — the deadline is passed
    // straight to `build` instead. Candidate for removal once callers are checked.
    deadline: Option<i64>,
    /// Margin before the deadline at which consumption should stop.
    stop_at: Option<Duration>,
    /// Number of consecutive empty receives that are tolerated.
    max_empty_receives: Option<u16>,
}
26
27impl ConsumePolicyBuilder {
28 pub fn with_max_empty_receives(mut self, arg: u16) -> Self {
29 self.max_empty_receives = Some(arg);
30 self
31 }
32
33 pub fn with_stop_at(mut self, arg: Duration) -> Self {
34 self.stop_at = Some(arg);
35 self
36 }
37
38 pub fn build(self, deadline: impl IntoDeadline) -> ConsumePolicy {
39 ConsumePolicy::new(
40 deadline,
41 self.stop_at.unwrap_or_else(|| Duration::from_secs(10)),
42 self.max_empty_receives.unwrap_or_else(|| 1),
43 )
44 }
45}
46
/// Decides whether a consumer should keep receiving messages.
///
/// Consumption continues while enough time remains before `deadline` and the number of
/// consecutive empty receives stays within budget.
#[derive(Debug, Clone)]
pub struct ConsumePolicy {
    /// Absolute deadline in epoch milliseconds.
    deadline: i64,
    /// Stop once less than this much time remains before `deadline`.
    stop_at: Duration,
    /// Highest tolerated value of `empty_receives`.
    max_empty_receives: u16,
    /// Consecutive receives that returned no messages; reset to 0 by a non-empty receive.
    empty_receives: u16,
}
54
/// Conversion into an absolute deadline in epoch milliseconds.
///
/// Consumes `self`; implementors return the deadline value that [`ConsumePolicy`] compares
/// against the current time.
pub trait IntoDeadline {
    /// Returns the deadline as epoch milliseconds.
    fn into_deadline(self) -> i64;
}
58
59impl IntoDeadline for Context {
60 fn into_deadline(self) -> i64 {
61 self.deadline
62 }
63}
64
65impl IntoDeadline for i64 {
66 fn into_deadline(self) -> i64 {
67 self
68 }
69}
70
71impl ConsumePolicy {
72 pub fn new(deadline: impl IntoDeadline, stop_at: Duration, max_empty_receives: u16) -> Self {
73 Self {
74 deadline: deadline.into_deadline(),
75 stop_at,
76 max_empty_receives,
77 empty_receives: 0,
78 }
79 }
80
81 pub fn get_time_remaining_millis(&self) -> i64 {
82 self.deadline - Utc::now().timestamp_millis()
83 }
84
85 pub fn should_consume(&self) -> bool {
86 (self.stop_at.as_millis() <= self.get_time_remaining_millis() as u128)
87 && self.empty_receives <= self.max_empty_receives
88 }
89
90 pub fn register_received(&mut self, any: bool) {
91 if any {
92 self.empty_receives = 0;
93 } else {
94 self.empty_receives += 1;
95 }
96 }
97}
98
99pub struct SqsConsumer<S, CH>
100where
101 S: Sqs + Send + Sync + 'static,
102 CH: CompletionHandler + Clone + Send + Sync + 'static,
103{
104 sqs_client: S,
105 queue_url: String,
106 stored_events: Vec<SqsMessage>,
107 consume_policy: ConsumePolicy,
108 completion_handler: CH,
109 shutdown_subscriber: Option<tokio::sync::oneshot::Sender<()>>,
110 self_actor: Option<SqsConsumerActor<S, CH>>,
111}
112
113impl<S, CH> SqsConsumer<S, CH>
114where
115 S: Sqs + Send + Sync + 'static,
116 CH: CompletionHandler + Clone + Send + Sync + 'static,
117{
118 pub fn new(
119 sqs_client: S,
120 queue_url: String,
121 consume_policy: ConsumePolicy,
122 completion_handler: CH,
123 shutdown_subscriber: tokio::sync::oneshot::Sender<()>,
124 ) -> SqsConsumer<S, CH>
125 where
126 S: Sqs,
127 {
128 Self {
129 sqs_client,
130 queue_url,
131 stored_events: Vec::with_capacity(20),
132 consume_policy,
133 completion_handler,
134 shutdown_subscriber: Some(shutdown_subscriber),
135 self_actor: None,
136 }
137 }
138}
139impl<S: Sqs + Send + Sync + 'static, CH: CompletionHandler + Clone + Send + Sync + 'static>
140 SqsConsumer<S, CH>
141{
142 #[instrument(skip(self))]
143 pub async fn batch_get_events(&self, wait_time_seconds: i64) -> eyre::Result<Vec<SqsMessage>> {
144 debug!("Calling receive_message");
145 let visibility_timeout = Duration::from_millis(self.consume_policy.get_time_remaining_millis() as u64).as_secs() + 1;
146 let recv = self.sqs_client.receive_message(ReceiveMessageRequest {
147 max_number_of_messages: Some(10),
148 queue_url: self.queue_url.clone(),
149 wait_time_seconds: Some(wait_time_seconds),
150 visibility_timeout: Some(visibility_timeout as i64),
151 ..Default::default()
152 });
153
154 let recv = tokio::time::timeout(Duration::from_secs(wait_time_seconds as u64 + 20), recv)
155 .await??;
156 debug!("Called receive_message : {:?}", recv);
157
158 Ok(recv.messages.unwrap_or(vec![]))
159 }
160}
161
162#[derive_aktor::derive_actor]
163impl<S: Sqs + Send + Sync + 'static, CH: CompletionHandler + Clone + Send + Sync + 'static>
164 SqsConsumer<S, CH>
165{
166 #[instrument(skip(self, event_processor))]
167 pub async fn get_new_event(&mut self, event_processor: EventProcessorActor<SqsMessage>) {
168 debug!("New event request");
169 let should_consume = self.consume_policy.should_consume();
170
171 if self.stored_events.is_empty() && should_consume {
172 let new_events = match self.batch_get_events(1).await {
173 Ok(new_events) => new_events,
174 Err(e) => {
175 error!("Failed to get new events with: {:?}", e);
176 tokio::time::delay_for(Duration::from_secs(1)).await;
177 self.self_actor
178 .clone()
179 .unwrap()
180 .get_next_event(event_processor)
181 .await;
182
183 self.consume_policy.register_received(false);
185 return;
186 }
187 };
188
189 self.consume_policy
190 .register_received(!new_events.is_empty());
191 self.stored_events.extend(new_events);
192 }
193
194 if !should_consume {
195 debug!("Done consuming, forcing ack");
196 let (tx, shutdown_notify) = tokio::sync::oneshot::channel();
197
198 self.completion_handler.ack_all(Some(tx)).await;
200
201 let _ = shutdown_notify.await;
202 debug!("Ack complete");
203 }
204
205 if self.stored_events.is_empty() && !should_consume {
206 debug!("No more events to process, and we should not consume more");
207 let shutdown_subscriber = std::mem::replace(&mut self.shutdown_subscriber, None);
208 match shutdown_subscriber {
209 Some(shutdown_subscriber) => {
210 shutdown_subscriber.send(()).unwrap();
211 }
212 None => debug!("Attempted to shut down with empty shutdown_subscriber"),
213 };
214
215 event_processor.stop_processing().await;
216 drop(event_processor);
217 return;
218 }
219
220 if let Some(next_event) = self.stored_events.pop() {
221 debug!("Sending next event to processor");
222 event_processor.process_event(next_event).await;
223 debug!("Sent next event to processor");
224 } else {
225 tokio::time::delay_for(Duration::from_millis(50)).await;
226 debug!("No events to send to processor");
227 self.self_actor
228 .clone()
229 .unwrap()
230 .get_next_event(event_processor)
231 .await;
232 }
233 }
234
235 pub async fn _p(&self, __p: PhantomData<(S, CH)>) {}
236}
237
238#[async_trait]
239impl<S, CH> Consumer<SqsMessage> for SqsConsumerActor<S, CH>
240where
241 S: Sqs + Send + Sync + 'static,
242 CH: CompletionHandler + Clone + Send + Sync + 'static,
243{
244 #[instrument(skip(self, event_processor))]
245 async fn get_next_event(&self, event_processor: EventProcessorActor<SqsMessage>) {
246 let msg = SqsConsumerMessage::get_new_event { event_processor };
247 self.queue_len
248 .clone()
249 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
250
251 let mut sender = self.sender.clone();
252 tokio::task::spawn(async move {
253 if let Err(e) = sender.send(msg).await {
254 panic!(
255 concat!(
256 "Receiver has failed with {}, propagating error. ",
257 "SqsConsumerActor.get_next_event"
258 ),
259 e
260 )
261 }
262 });
263 }
264}