1mod error;
2
3pub use error::PublishError;
4pub use lapin;
5
6use lapin::types::FieldTable;
7use tower::{BoxError, Layer, Service, ServiceExt};
8
9use std::convert::Infallible;
10use std::fmt::Debug;
11use std::future::Future;
12use std::sync::Arc;
13
14use self::error::HandlerError;
15use lapin::options::{BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions};
16use lapin::protocol::basic::AMQPProperties;
17use tokio_stream::StreamExt;
18
19pub trait AMQPTaskResult: Sized + Send + Debug {
20 type EncodeError: std::error::Error + Send + Sync + 'static;
21
22 fn encode(self) -> Result<Vec<u8>, Self::EncodeError>;
23
24 fn publish_exchange(&self) -> String;
25 fn publish_routing_key(&self) -> String;
26
27 fn publish(
28 self,
29 channel: &lapin::Channel,
30 ) -> impl Future<Output = Result<(), PublishError<Self>>> + Send {
31 async {
32 let exchange = self.publish_exchange();
33 let routing_key = self.publish_routing_key();
34 let payload = self.encode().map_err(PublishError::Encode)?;
35 channel
36 .basic_publish(
37 &exchange,
38 &routing_key,
39 BasicPublishOptions::default(),
40 &payload,
41 AMQPProperties::default(),
42 )
43 .await?;
44 Ok(())
45 }
46 }
47}
48
49impl AMQPTaskResult for () {
50 type EncodeError = Infallible;
51
52 fn encode(self) -> Result<Vec<u8>, Self::EncodeError> {
53 unreachable!("empty result can't be encoded")
54 }
55
56 fn publish_exchange(&self) -> String {
57 unreachable!("empty result has no exchange")
58 }
59
60 fn publish_routing_key(&self) -> String {
61 unreachable!("empty result has no routing key")
62 }
63
64 async fn publish(self, _channel: &lapin::Channel) -> Result<(), PublishError<Self>> {
65 Ok(())
66 }
67}
68
69pub trait AMQPTask: Sized + Debug {
70 type DecodeError: std::error::Error + Send + Sync + 'static;
71 type TaskResult: AMQPTaskResult;
72
73 fn decode(data: Vec<u8>) -> Result<Self, Self::DecodeError>;
75 fn queue() -> &'static str;
76
77 fn debug(&self) -> String {
80 format!("{:?}", self)
81 }
82}
83
/// Behavior switches for an [`AMQPWorker`].
#[derive(Clone, Debug)]
pub struct WorkerConfig {
    /// If `true`, a delivery whose payload fails [`AMQPTask::decode`] is acked.
    /// If `false`, the worker neither acks nor nacks it.
    pub ack_on_decode_error: bool,
}
90
91impl Default for WorkerConfig {
92 fn default() -> Self {
93 Self {
94 ack_on_decode_error: true,
95 }
96 }
97}
98
99pub struct AMQPWorker<T, S> {
100 service: S,
101 inner: Arc<Inner>,
102 _phantom: std::marker::PhantomData<T>,
103}
104
105struct Inner {
106 channel: lapin::Channel,
107 config: WorkerConfig,
108 consumer_tag: String,
109}
110
111impl<T, S> Clone for AMQPWorker<T, S>
112where
113 S: Clone,
114{
115 fn clone(&self) -> Self {
116 Self {
117 service: self.service.clone(),
118 inner: self.inner.clone(),
119 _phantom: std::marker::PhantomData,
120 }
121 }
122}
123
124impl<T, S> AMQPWorker<T, S>
125where
126 T: AMQPTask + Send + 'static,
127 S: Service<T, Response = T::TaskResult> + Send + 'static + Clone,
128 <S as Service<T>>::Future: Send,
129 <S as Service<T>>::Error: Debug + Into<BoxError>,
130{
131 pub fn new(
132 consumer_tag: impl Into<String>,
133 service: S,
134 channel: lapin::Channel,
135 config: WorkerConfig,
136 ) -> Self {
137 Self {
138 service,
139 inner: Arc::new(Inner {
140 consumer_tag: consumer_tag.into(),
141 config,
142 channel,
143 }),
144 _phantom: std::marker::PhantomData,
145 }
146 }
147
148 pub fn consumer_tag(&self) -> &str {
149 &self.inner.consumer_tag
150 }
151
152 pub fn channel(&self) -> &lapin::Channel {
153 &self.inner.channel
154 }
155
156 pub fn config(&self) -> &WorkerConfig {
157 &self.inner.config
158 }
159
160 pub fn add_layer<L>(self, layer: L) -> AMQPWorker<T, L::Service>
161 where
162 L: Layer<S>,
163 <L as Layer<S>>::Service: Service<T, Response = T::TaskResult> + Send + 'static + Clone,
164 <<L as Layer<S>>::Service as Service<T>>::Error: Debug + Into<BoxError>,
165 {
166 let service = layer.layer(self.service);
167 AMQPWorker {
168 service,
169 inner: self.inner,
170 _phantom: std::marker::PhantomData,
171 }
172 }
173
174 #[tracing::instrument(level = "info", skip(self), fields(
175 consumer = self.consumer_tag(),
176 task = task.debug()
177 ))]
178 async fn handle_task(&mut self, task: T) -> Result<(), HandlerError<T>> {
179 tracing::info!("Calling service");
180 let task_result: T::TaskResult = self.service.call(task).await.map_err(Into::into)?;
181 task_result
182 .publish(&self.inner.channel)
183 .await
184 .map_err(|e| HandlerError::Publish(e))?;
185 Ok(())
186 }
187
188 async fn ready(&mut self) -> Result<Self, BoxError> {
189 let svc = self
190 .service
191 .clone()
192 .ready_oneshot()
193 .await
194 .map_err(Into::into)?;
195 Ok(AMQPWorker {
196 service: svc,
197 inner: self.inner.clone(),
198 _phantom: std::marker::PhantomData,
199 })
200 }
201
202 pub async fn consume(
204 mut self,
205 consume_options: BasicConsumeOptions,
206 consume_arguments: FieldTable,
207 ) -> Result<(), BoxError> {
208 tracing::info!(
209 config = ?self.config(),
210 consumer_tag = self.consumer_tag(),
211 "Starting worker consumer"
212 );
213 let mut consumer = self
214 .inner
215 .channel
216 .basic_consume(
217 T::queue(),
218 &self.inner.consumer_tag,
219 consume_options,
220 consume_arguments,
221 )
222 .await?;
223 loop {
224 let mut worker = self.ready().await?;
225 tracing::info!(
226 consumer = worker.consumer_tag(),
227 "Consumer ready, waiting for delivery"
228 );
229 if let Some(attempted_delivery) = consumer.next().await {
230 tokio::spawn(async move {
231 let delivery = match attempted_delivery {
232 Ok(delivery) => delivery,
233 Err(e) => {
234 tracing::error!(
235 name = worker.inner.consumer_tag,
236 ?e,
237 "Delivery of message failed"
238 );
239 return Ok(());
240 }
241 };
242 let delivery_tag = delivery.delivery_tag;
243 let task = match T::decode(delivery.data) {
244 Ok(task) => task,
245 Err(e) => {
246 tracing::error!(
247 consumer = worker.inner.consumer_tag,
248 ?e,
249 ack = worker.inner.config.ack_on_decode_error,
250 "Failed to decode task"
251 );
252 if worker.config().ack_on_decode_error {
253 worker
254 .channel()
255 .basic_ack(delivery_tag, BasicAckOptions::default())
256 .await?;
257 }
258 return Ok(());
259 }
260 };
261 let task_debug = task.debug();
262 match worker.handle_task(task).await {
263 Ok(_) => {
264 tracing::info!(
265 consumer = worker.consumer_tag(),
266 delivery_tag = delivery.delivery_tag,
267 "Delivery handled successfully"
268 );
269 worker
270 .channel()
271 .basic_ack(delivery_tag, BasicAckOptions::default())
272 .await?;
273 return Ok(());
274 }
275 Err(e) => {
276 tracing::error!(
277 consumer = ?worker.consumer_tag(),
278 delivery_tag = delivery.delivery_tag,
279 ?e,
280 task = task_debug,
281 "Task handler returned error"
282 );
283 worker
284 .channel()
285 .basic_nack(
286 delivery_tag,
287 BasicNackOptions {
288 multiple: false,
289 requeue: true,
290 },
291 )
292 .await?;
293 }
294 }
295 Ok::<_, BoxError>(())
296 });
297 } else {
298 return Ok(());
299 }
300 }
301 }
302}