tower_amqp/
lib.rs

1mod error;
2
3pub use error::PublishError;
4pub use lapin;
5
6use lapin::types::FieldTable;
7use tower::{BoxError, Layer, Service, ServiceExt};
8
9use std::convert::Infallible;
10use std::fmt::Debug;
11use std::future::Future;
12use std::sync::Arc;
13
14use self::error::HandlerError;
15use lapin::options::{BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions};
16use lapin::protocol::basic::AMQPProperties;
17use tokio_stream::StreamExt;
18
19pub trait AMQPTaskResult: Sized + Send + Debug {
20    type EncodeError: std::error::Error + Send + Sync + 'static;
21
22    fn encode(self) -> Result<Vec<u8>, Self::EncodeError>;
23
24    fn publish_exchange(&self) -> String;
25    fn publish_routing_key(&self) -> String;
26
27    fn publish(
28        self,
29        channel: &lapin::Channel,
30    ) -> impl Future<Output = Result<(), PublishError<Self>>> + Send {
31        async {
32            let exchange = self.publish_exchange();
33            let routing_key = self.publish_routing_key();
34            let payload = self.encode().map_err(PublishError::Encode)?;
35            channel
36                .basic_publish(
37                    &exchange,
38                    &routing_key,
39                    BasicPublishOptions::default(),
40                    &payload,
41                    AMQPProperties::default(),
42                )
43                .await?;
44            Ok(())
45        }
46    }
47}
48
49impl AMQPTaskResult for () {
50    type EncodeError = Infallible;
51
52    fn encode(self) -> Result<Vec<u8>, Self::EncodeError> {
53        unreachable!("empty result can't be encoded")
54    }
55
56    fn publish_exchange(&self) -> String {
57        unreachable!("empty result has no exchange")
58    }
59
60    fn publish_routing_key(&self) -> String {
61        unreachable!("empty result has no routing key")
62    }
63
64    async fn publish(self, _channel: &lapin::Channel) -> Result<(), PublishError<Self>> {
65        Ok(())
66    }
67}
68
69pub trait AMQPTask: Sized + Debug {
70    type DecodeError: std::error::Error + Send + Sync + 'static;
71    type TaskResult: AMQPTaskResult;
72
73    /// Decode a task from a byte slice, this allows tasks to use different serialization formats
74    fn decode(data: Vec<u8>) -> Result<Self, Self::DecodeError>;
75    fn queue() -> &'static str;
76
77    /// Debug representation of the task, override this if your task has such a huge payload that
78    /// it would clog up the logs
79    fn debug(&self) -> String {
80        format!("{:?}", self)
81    }
82}
83
/// Configuration for an `AMQPWorker`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkerConfig {
    /// What to do with a message whose payload fails to decode into a task.
    ///
    /// - If `true`, the worker acks (and thereby discards) the message. Use this if you are
    ///   sure that undecodable messages are a publisher error.
    /// - If `false`, the worker nacks the message. Use this if you might be attaching the
    ///   worker to the wrong queue, or if your decoding logic might be faulty.
    ///
    /// Defaults to `true`.
    pub ack_on_decode_error: bool,
}

impl Default for WorkerConfig {
    /// Acks messages that fail to decode.
    fn default() -> Self {
        Self {
            ack_on_decode_error: true,
        }
    }
}
98
99pub struct AMQPWorker<T, S> {
100    service: S,
101    inner: Arc<Inner>,
102    _phantom: std::marker::PhantomData<T>,
103}
104
105struct Inner {
106    channel: lapin::Channel,
107    config: WorkerConfig,
108    consumer_tag: String,
109}
110
111impl<T, S> Clone for AMQPWorker<T, S>
112where
113    S: Clone,
114{
115    fn clone(&self) -> Self {
116        Self {
117            service: self.service.clone(),
118            inner: self.inner.clone(),
119            _phantom: std::marker::PhantomData,
120        }
121    }
122}
123
124impl<T, S> AMQPWorker<T, S>
125where
126    T: AMQPTask + Send + 'static,
127    S: Service<T, Response = T::TaskResult> + Send + 'static + Clone,
128    <S as Service<T>>::Future: Send,
129    <S as Service<T>>::Error: Debug + Into<BoxError>,
130{
131    pub fn new(
132        consumer_tag: impl Into<String>,
133        service: S,
134        channel: lapin::Channel,
135        config: WorkerConfig,
136    ) -> Self {
137        Self {
138            service,
139            inner: Arc::new(Inner {
140                consumer_tag: consumer_tag.into(),
141                config,
142                channel,
143            }),
144            _phantom: std::marker::PhantomData,
145        }
146    }
147
148    pub fn consumer_tag(&self) -> &str {
149        &self.inner.consumer_tag
150    }
151
152    pub fn channel(&self) -> &lapin::Channel {
153        &self.inner.channel
154    }
155
156    pub fn config(&self) -> &WorkerConfig {
157        &self.inner.config
158    }
159
160    pub fn add_layer<L>(self, layer: L) -> AMQPWorker<T, L::Service>
161    where
162        L: Layer<S>,
163        <L as Layer<S>>::Service: Service<T, Response = T::TaskResult> + Send + 'static + Clone,
164        <<L as Layer<S>>::Service as Service<T>>::Error: Debug + Into<BoxError>,
165    {
166        let service = layer.layer(self.service);
167        AMQPWorker {
168            service,
169            inner: self.inner,
170            _phantom: std::marker::PhantomData,
171        }
172    }
173
174    #[tracing::instrument(level = "info", skip(self), fields(
175        consumer = self.consumer_tag(),
176        task = task.debug()
177    ))]
178    async fn handle_task(&mut self, task: T) -> Result<(), HandlerError<T>> {
179        tracing::info!("Calling service");
180        let task_result: T::TaskResult = self.service.call(task).await.map_err(Into::into)?;
181        task_result
182            .publish(&self.inner.channel)
183            .await
184            .map_err(|e| HandlerError::Publish(e))?;
185        Ok(())
186    }
187
188    async fn ready(&mut self) -> Result<Self, BoxError> {
189        let svc = self
190            .service
191            .clone()
192            .ready_oneshot()
193            .await
194            .map_err(Into::into)?;
195        Ok(AMQPWorker {
196            service: svc,
197            inner: self.inner.clone(),
198            _phantom: std::marker::PhantomData,
199        })
200    }
201
202    /// Start consuming tasks from the AMQP queue
203    pub async fn consume(
204        mut self,
205        consume_options: BasicConsumeOptions,
206        consume_arguments: FieldTable,
207    ) -> Result<(), BoxError> {
208        tracing::info!(
209            config = ?self.config(),
210            consumer_tag = self.consumer_tag(),
211            "Starting worker consumer"
212        );
213        let mut consumer = self
214            .inner
215            .channel
216            .basic_consume(
217                T::queue(),
218                &self.inner.consumer_tag,
219                consume_options,
220                consume_arguments,
221            )
222            .await?;
223        loop {
224            let mut worker = self.ready().await?;
225            tracing::info!(
226                consumer = worker.consumer_tag(),
227                "Consumer ready, waiting for delivery"
228            );
229            if let Some(attempted_delivery) = consumer.next().await {
230                tokio::spawn(async move {
231                    let delivery = match attempted_delivery {
232                        Ok(delivery) => delivery,
233                        Err(e) => {
234                            tracing::error!(
235                                name = worker.inner.consumer_tag,
236                                ?e,
237                                "Delivery of message failed"
238                            );
239                            return Ok(());
240                        }
241                    };
242                    let delivery_tag = delivery.delivery_tag;
243                    let task = match T::decode(delivery.data) {
244                        Ok(task) => task,
245                        Err(e) => {
246                            tracing::error!(
247                                consumer = worker.inner.consumer_tag,
248                                ?e,
249                                ack = worker.inner.config.ack_on_decode_error,
250                                "Failed to decode task"
251                            );
252                            if worker.config().ack_on_decode_error {
253                                worker
254                                    .channel()
255                                    .basic_ack(delivery_tag, BasicAckOptions::default())
256                                    .await?;
257                            }
258                            return Ok(());
259                        }
260                    };
261                    let task_debug = task.debug();
262                    match worker.handle_task(task).await {
263                        Ok(_) => {
264                            tracing::info!(
265                                consumer = worker.consumer_tag(),
266                                delivery_tag = delivery.delivery_tag,
267                                "Delivery handled successfully"
268                            );
269                            worker
270                                .channel()
271                                .basic_ack(delivery_tag, BasicAckOptions::default())
272                                .await?;
273                            return Ok(());
274                        }
275                        Err(e) => {
276                            tracing::error!(
277                                consumer = ?worker.consumer_tag(),
278                                delivery_tag = delivery.delivery_tag,
279                                ?e,
280                                task = task_debug,
281                                "Task handler returned error"
282                            );
283                            worker
284                                .channel()
285                                .basic_nack(
286                                    delivery_tag,
287                                    BasicNackOptions {
288                                        multiple: false,
289                                        requeue: true,
290                                    },
291                                )
292                                .await?;
293                        }
294                    }
295                    Ok::<_, BoxError>(())
296                });
297            } else {
298                return Ok(());
299            }
300        }
301    }
302}