// sqs_lambda/sqs_consumer.rs

1use std::time::Duration;
2
3use lambda_runtime::Context;
4use log::{debug, error};
5use rusoto_sqs::Message as SqsMessage;
6use rusoto_sqs::{ReceiveMessageRequest, Sqs};
7use tokio::sync::mpsc::{channel, Sender};
8
9use tracing::instrument;
10
11use crate::consumer::Consumer;
12use async_trait::async_trait;
13
14use crate::completion_handler::CompletionHandler;
15use crate::event_processor::EventProcessorActor;
16
17use chrono::Utc;
18use std::marker::PhantomData;
19
/// Builder for [`ConsumePolicy`].
///
/// Options left unset fall back to the defaults applied in
/// [`ConsumePolicyBuilder::build`]; the deadline itself is supplied to `build`.
#[derive(Debug, Clone, Default)]
pub struct ConsumePolicyBuilder {
    /// Minimum time that must remain before the deadline for the policy to
    /// keep consuming.
    stop_at: Option<Duration>,
    /// Limit on consecutive empty receives before the policy stops consuming.
    max_empty_receives: Option<u16>,
}
26
27impl ConsumePolicyBuilder {
28    pub fn with_max_empty_receives(mut self, arg: u16) -> Self {
29        self.max_empty_receives = Some(arg);
30        self
31    }
32
33    pub fn with_stop_at(mut self, arg: Duration) -> Self {
34        self.stop_at = Some(arg);
35        self
36    }
37
38    pub fn build(self, deadline: impl IntoDeadline) -> ConsumePolicy {
39        ConsumePolicy::new(
40            deadline,
41            self.stop_at.unwrap_or_else(|| Duration::from_secs(10)),
42            self.max_empty_receives.unwrap_or_else(|| 1),
43        )
44    }
45}
46
/// Decides whether a consumer should keep receiving messages, based on a
/// deadline and on the current streak of empty receives.
#[derive(Clone, Debug)]
pub struct ConsumePolicy {
    /// Deadline as Unix epoch milliseconds (compared against
    /// `Utc::now().timestamp_millis()`).
    deadline: i64,
    /// Minimum time that must remain before `deadline` to keep consuming.
    stop_at: Duration,
    /// Limit on the consecutive-empty-receive streak.
    max_empty_receives: u16,
    /// Current number of consecutive empty receives.
    empty_receives: u16,
}
54
/// Conversion of a value into a consumption deadline.
///
/// `ConsumePolicy` subtracts `Utc::now().timestamp_millis()` from the result,
/// so implementors should produce Unix epoch milliseconds.
pub trait IntoDeadline {
    /// Consumes `self`, yielding the deadline in epoch milliseconds.
    fn into_deadline(self) -> i64;
}
58
59impl IntoDeadline for Context {
60    fn into_deadline(self) -> i64 {
61        self.deadline
62    }
63}
64
65impl IntoDeadline for i64 {
66    fn into_deadline(self) -> i64 {
67        self
68    }
69}
70
71impl ConsumePolicy {
72    pub fn new(deadline: impl IntoDeadline, stop_at: Duration, max_empty_receives: u16) -> Self {
73        Self {
74            deadline: deadline.into_deadline(),
75            stop_at,
76            max_empty_receives,
77            empty_receives: 0,
78        }
79    }
80
81    pub fn get_time_remaining_millis(&self) -> i64 {
82        self.deadline - Utc::now().timestamp_millis()
83    }
84
85    pub fn should_consume(&self) -> bool {
86        (self.stop_at.as_millis() <= self.get_time_remaining_millis() as u128)
87            && self.empty_receives <= self.max_empty_receives
88    }
89
90    pub fn register_received(&mut self, any: bool) {
91        if any {
92            self.empty_receives = 0;
93        } else {
94            self.empty_receives += 1;
95        }
96    }
97}
98
99pub struct SqsConsumer<S, CH>
100where
101    S: Sqs + Send + Sync + 'static,
102    CH: CompletionHandler + Clone + Send + Sync + 'static,
103{
104    sqs_client: S,
105    queue_url: String,
106    stored_events: Vec<SqsMessage>,
107    consume_policy: ConsumePolicy,
108    completion_handler: CH,
109    shutdown_subscriber: Option<tokio::sync::oneshot::Sender<()>>,
110    self_actor: Option<SqsConsumerActor<S, CH>>,
111}
112
113impl<S, CH> SqsConsumer<S, CH>
114where
115    S: Sqs + Send + Sync + 'static,
116    CH: CompletionHandler + Clone + Send + Sync + 'static,
117{
118    pub fn new(
119        sqs_client: S,
120        queue_url: String,
121        consume_policy: ConsumePolicy,
122        completion_handler: CH,
123        shutdown_subscriber: tokio::sync::oneshot::Sender<()>,
124    ) -> SqsConsumer<S, CH>
125    where
126        S: Sqs,
127    {
128        Self {
129            sqs_client,
130            queue_url,
131            stored_events: Vec::with_capacity(20),
132            consume_policy,
133            completion_handler,
134            shutdown_subscriber: Some(shutdown_subscriber),
135            self_actor: None,
136        }
137    }
138}
139impl<S: Sqs + Send + Sync + 'static, CH: CompletionHandler + Clone + Send + Sync + 'static>
140    SqsConsumer<S, CH>
141{
142    #[instrument(skip(self))]
143    pub async fn batch_get_events(&self, wait_time_seconds: i64) -> eyre::Result<Vec<SqsMessage>> {
144        debug!("Calling receive_message");
145        let visibility_timeout = Duration::from_millis(self.consume_policy.get_time_remaining_millis() as u64).as_secs() + 1;
146        let recv = self.sqs_client.receive_message(ReceiveMessageRequest {
147            max_number_of_messages: Some(10),
148            queue_url: self.queue_url.clone(),
149            wait_time_seconds: Some(wait_time_seconds),
150            visibility_timeout: Some(visibility_timeout as i64),
151            ..Default::default()
152        });
153
154        let recv = tokio::time::timeout(Duration::from_secs(wait_time_seconds as u64 + 20), recv)
155            .await??;
156        debug!("Called receive_message : {:?}", recv);
157
158        Ok(recv.messages.unwrap_or(vec![]))
159    }
160}
161
162#[derive_aktor::derive_actor]
163impl<S: Sqs + Send + Sync + 'static, CH: CompletionHandler + Clone + Send + Sync + 'static>
164    SqsConsumer<S, CH>
165{
166    #[instrument(skip(self, event_processor))]
167    pub async fn get_new_event(&mut self, event_processor: EventProcessorActor<SqsMessage>) {
168        debug!("New event request");
169        let should_consume = self.consume_policy.should_consume();
170
171        if self.stored_events.is_empty() && should_consume {
172            let new_events = match self.batch_get_events(1).await {
173                Ok(new_events) => new_events,
174                Err(e) => {
175                    error!("Failed to get new events with: {:?}", e);
176                    tokio::time::delay_for(Duration::from_secs(1)).await;
177                    self.self_actor
178                        .clone()
179                        .unwrap()
180                        .get_next_event(event_processor)
181                        .await;
182
183                    // Register the empty receive on error
184                    self.consume_policy.register_received(false);
185                    return;
186                }
187            };
188
189            self.consume_policy
190                .register_received(!new_events.is_empty());
191            self.stored_events.extend(new_events);
192        }
193
194        if !should_consume {
195            debug!("Done consuming, forcing ack");
196            let (tx, shutdown_notify) = tokio::sync::oneshot::channel();
197
198            // If we're past the point of consuming it's time to start acking
199            self.completion_handler.ack_all(Some(tx)).await;
200
201            let _ = shutdown_notify.await;
202            debug!("Ack complete");
203        }
204
205        if self.stored_events.is_empty() && !should_consume {
206            debug!("No more events to process, and we should not consume more");
207            let shutdown_subscriber = std::mem::replace(&mut self.shutdown_subscriber, None);
208            match shutdown_subscriber {
209                Some(shutdown_subscriber) => {
210                    shutdown_subscriber.send(()).unwrap();
211                }
212                None => debug!("Attempted to shut down with empty shutdown_subscriber"),
213            };
214
215            event_processor.stop_processing().await;
216            drop(event_processor);
217            return;
218        }
219
220        if let Some(next_event) = self.stored_events.pop() {
221            debug!("Sending next event to processor");
222            event_processor.process_event(next_event).await;
223            debug!("Sent next event to processor");
224        } else {
225            tokio::time::delay_for(Duration::from_millis(50)).await;
226            debug!("No events to send to processor");
227            self.self_actor
228                .clone()
229                .unwrap()
230                .get_next_event(event_processor)
231                .await;
232        }
233    }
234
235    pub async fn _p(&self, __p: PhantomData<(S, CH)>) {}
236}
237
238#[async_trait]
239impl<S, CH> Consumer<SqsMessage> for SqsConsumerActor<S, CH>
240where
241    S: Sqs + Send + Sync + 'static,
242    CH: CompletionHandler + Clone + Send + Sync + 'static,
243{
244    #[instrument(skip(self, event_processor))]
245    async fn get_next_event(&self, event_processor: EventProcessorActor<SqsMessage>) {
246        let msg = SqsConsumerMessage::get_new_event { event_processor };
247        self.queue_len
248            .clone()
249            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
250
251        let mut sender = self.sender.clone();
252        tokio::task::spawn(async move {
253            if let Err(e) = sender.send(msg).await {
254                panic!(
255                    concat!(
256                        "Receiver has failed with {}, propagating error. ",
257                        "SqsConsumerActor.get_next_event"
258                    ),
259                    e
260                )
261            }
262        });
263    }
264}