Skip to main content

sentry_arroyo/processing/strategies/
run_task.rs

1use crate::processing::strategies::{
2    merge_commit_request, CommitRequest, MessageRejected, ProcessingStrategy, StrategyError,
3    SubmitError,
4};
5use crate::types::Message;
6use std::time::Duration;
7
8pub struct RunTask<TTransformed, F, N> {
9    pub function: F,
10    pub next_step: N,
11    pub message_carried_over: Option<Message<TTransformed>>,
12    pub commit_request_carried_over: Option<CommitRequest>,
13}
14
15impl<TTransformed, F, N> RunTask<TTransformed, F, N> {
16    pub fn new(function: F, next_step: N) -> Self {
17        Self {
18            function,
19            next_step,
20            message_carried_over: None,
21            commit_request_carried_over: None,
22        }
23    }
24}
25
26impl<TPayload, TTransformed, F, N> ProcessingStrategy<TPayload> for RunTask<TTransformed, F, N>
27where
28    TTransformed: Send + Sync,
29    F: FnMut(Message<TPayload>) -> Result<Message<TTransformed>, SubmitError<TPayload>>
30        + Send
31        + Sync
32        + 'static,
33    N: ProcessingStrategy<TTransformed> + 'static,
34{
35    fn poll(&mut self) -> Result<Option<CommitRequest>, StrategyError> {
36        match self.next_step.poll() {
37            Ok(commit_request) => {
38                self.commit_request_carried_over =
39                    merge_commit_request(self.commit_request_carried_over.take(), commit_request)
40            }
41            Err(invalid_message) => return Err(invalid_message),
42        }
43
44        if let Some(message) = self.message_carried_over.take() {
45            match self.next_step.submit(message) {
46                Err(SubmitError::MessageRejected(MessageRejected {
47                    message: transformed_message,
48                })) => {
49                    self.message_carried_over = Some(transformed_message);
50                }
51                Err(SubmitError::InvalidMessage(invalid_message)) => {
52                    return Err(invalid_message.into());
53                }
54                Ok(_) => {}
55            }
56        }
57
58        Ok(self.commit_request_carried_over.take())
59    }
60
61    fn submit(&mut self, message: Message<TPayload>) -> Result<(), SubmitError<TPayload>> {
62        if self.message_carried_over.is_some() {
63            return Err(SubmitError::MessageRejected(MessageRejected { message }));
64        }
65
66        let next_message = (self.function)(message)?;
67
68        match self.next_step.submit(next_message) {
69            Err(SubmitError::MessageRejected(MessageRejected {
70                message: transformed_message,
71            })) => {
72                self.message_carried_over = Some(transformed_message);
73            }
74            Err(SubmitError::InvalidMessage(invalid_message)) => {
75                return Err(SubmitError::InvalidMessage(invalid_message));
76            }
77            Ok(_) => {}
78        }
79        Ok(())
80    }
81
82    fn terminate(&mut self) {
83        self.next_step.terminate()
84    }
85
86    fn join(&mut self, timeout: Option<Duration>) -> Result<Option<CommitRequest>, StrategyError> {
87        let next_commit = self.next_step.join(timeout)?;
88        Ok(merge_commit_request(
89            self.commit_request_carried_over.take(),
90            next_commit,
91        ))
92    }
93}
94
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        processing::strategies::noop::Noop,
        types::{BrokerMessage, InnerMessage, Message, Partition, Topic},
    };
    use chrono::Utc;

    /// Transformation that forwards the message unchanged.
    fn identity(value: Message<String>) -> Result<Message<String>, SubmitError<String>> {
        Ok(value)
    }

    /// Builds a broker message carrying `payload` at offset 0 of partition 0 of topic `test`.
    fn make_message(payload: &str) -> Message<String> {
        Message {
            inner_message: InnerMessage::BrokerMessage(BrokerMessage::new(
                payload.to_string(),
                Partition::new(Topic::new("test"), 0),
                0,
                Utc::now(),
            )),
        }
    }

    /// Downstream step that rejects its first `rejections_left` submissions and counts the
    /// submissions it accepts afterwards.
    struct RejectFirst {
        rejections_left: usize,
        accepted: usize,
    }

    impl ProcessingStrategy<String> for RejectFirst {
        fn poll(&mut self) -> Result<Option<CommitRequest>, StrategyError> {
            Ok(None)
        }

        fn submit(&mut self, message: Message<String>) -> Result<(), SubmitError<String>> {
            if self.rejections_left > 0 {
                self.rejections_left -= 1;
                return Err(SubmitError::MessageRejected(MessageRejected { message }));
            }
            self.accepted += 1;
            Ok(())
        }

        fn terminate(&mut self) {}

        fn join(
            &mut self,
            _timeout: Option<Duration>,
        ) -> Result<Option<CommitRequest>, StrategyError> {
            Ok(None)
        }
    }

    #[test]
    fn test_run_task() {
        let mut strategy = RunTask::new(identity, Noop {});

        strategy.submit(make_message("Hello world")).unwrap();
    }

    #[test]
    fn test_run_task_carries_over_rejected_message() {
        let mut strategy = RunTask::new(
            identity,
            RejectFirst {
                rejections_left: 1,
                accepted: 0,
            },
        );

        // The downstream step rejects the first message, so RunTask buffers it.
        assert!(strategy.submit(make_message("first")).is_ok());
        assert!(strategy.message_carried_over.is_some());
        assert_eq!(strategy.next_step.accepted, 0);

        // While a message is pending, new submissions are pushed back to the caller.
        assert!(matches!(
            strategy.submit(make_message("second")),
            Err(SubmitError::MessageRejected(_))
        ));

        // Polling retries the pending message, which the downstream step now accepts.
        assert!(ProcessingStrategy::<String>::poll(&mut strategy).is_ok());
        assert!(strategy.message_carried_over.is_none());
        assert_eq!(strategy.next_step.accepted, 1);
    }
}