1use std::error::Error;
32use std::fmt::{Debug, Display, Error as FmtError, Formatter};
33
34use async_trait::async_trait;
35use celery::{
36 error::BackendError::*,
37 error::BrokerError::BadRoutingPattern,
38 error::CeleryError::{self, *},
39 task::{AsyncResult, Signature, Task},
40 Celery, CeleryBuilder,
41};
42use tourniquet::{Connector, Next, RoundRobin};
43#[cfg(feature = "trace")]
44use tracing::{
45 field::{display, Empty},
46 instrument, Span,
47};
48
/// Newtype wrapper around [`CeleryError`] used as the error type of the celery round robin.
///
/// Both [`Next`] (from `tourniquet`) and [`CeleryError`] (from `celery`) are foreign to this
/// crate, so the orphan rule requires a local wrapper before `Next` can be implemented.
/// `is_next` decides, per error variant, whether the round robin should move on
/// (NOTE(review): presumably "retry on the next service" — confirm against tourniquet docs).
pub struct RRCeleryError(CeleryError);
52
53impl Next for RRCeleryError {
54 fn is_next(&self) -> bool {
55 match self.0 {
56 BrokerError(BadRoutingPattern(_)) => false,
57 BrokerError(_) | IoError(_) | ProtocolError(_) => true,
58 BackendError(NotConfigured | Timeout | Redis(_)) => true,
59 BackendError(Serialization(_) | Pool(_) | PoolCreationError(_) | TaskFailed(_)) => {
60 false
61 }
62 NoQueueToConsume
63 | ForcedShutdown
64 | TaskRegistrationError(_)
65 | UnregisteredTaskError(_) => false,
66 }
67 }
68}
69
70impl From<CeleryError> for RRCeleryError {
71 fn from(e: CeleryError) -> Self {
72 Self(e)
73 }
74}
75
76impl From<RRCeleryError> for CeleryError {
77 fn from(e: RRCeleryError) -> Self {
78 e.0
79 }
80}
81
82impl Display for RRCeleryError {
83 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
84 Display::fmt(&self.0, f)
85 }
86}
87
88impl Debug for RRCeleryError {
89 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
90 Debug::fmt(&self.0, f)
91 }
92}
93
94impl Error for RRCeleryError {
95 fn source(&self) -> Option<&(dyn Error + 'static)> {
96 Some(&self.0)
97 }
98}
99
/// Settings used to build and connect a [`Celery`] app for each URL the round robin hands to
/// its connector.
///
/// Every field is borrowed or `Copy`, so the connector is cheap to copy. It derives `Debug`
/// so a configuration can be logged.
#[derive(Debug, Clone, Copy)]
pub struct CeleryConnector<'a> {
    /// Application name passed to `CeleryBuilder::new`.
    pub name: &'a str,
    /// Queue passed to `CeleryBuilder::default_queue` when `Some`; when `None` the builder's
    /// own default is left untouched.
    pub default_queue: Option<&'a str>,
    /// `(pattern, queue)` pairs registered, in slice order, with `CeleryBuilder::task_route`.
    pub routes: &'a [(&'a str, &'a str)],
    /// Value passed to `CeleryBuilder::broker_connection_timeout` when `Some`
    /// (NOTE(review): presumably seconds — confirm against the celery crate docs).
    pub connection_timeout: Option<u32>,
}

impl<'a> Default for CeleryConnector<'a> {
    /// App named `"celery"`, with no default-queue override, no routes and no connection
    /// timeout override.
    fn default() -> Self {
        Self { name: "celery", default_queue: None, routes: &[], connection_timeout: None }
    }
}
121
122#[async_trait]
123impl<'a> Connector<String, Celery, RRCeleryError> for CeleryConnector<'a> {
124 #[cfg_attr(feature = "trace", tracing::instrument(skip(self), err))]
125 async fn connect(&self, url: &String) -> Result<Celery, RRCeleryError> {
126 let mut builder = CeleryBuilder::new(self.name, url.as_ref());
127
128 if let Some(queue) = self.default_queue {
129 builder = builder.default_queue(queue);
130 }
131 for (pattern, queue) in self.routes {
132 builder = builder.task_route(pattern, queue);
133 }
134 if let Some(timeout) = self.connection_timeout {
135 builder = builder.broker_connection_timeout(timeout);
136 }
137
138 Ok(builder.build().await?)
139 }
140}
141
142#[async_trait]
143pub trait RoundRobinExt {
144 async fn send_task<T, F>(&self, task_gen: F) -> Result<AsyncResult, CeleryError>
145 where
146 T: Task + 'static,
147 F: Fn() -> Signature<T> + Send + Sync;
148}
149
150#[async_trait]
151impl<SvcSrc, Conn> RoundRobinExt for RoundRobin<SvcSrc, Celery, RRCeleryError, Conn>
152where
153 SvcSrc: Debug + Send + Sync,
154 Conn: Connector<SvcSrc, Celery, RRCeleryError> + Send + Sync,
155{
156 #[cfg_attr(
161 feature = "trace",
162 instrument(
163 fields(task_name = display(Signature::<T>::task_name()), task_id = Empty),
164 skip(self, task_gen),
165 err,
166 ),
167 )]
168 async fn send_task<T, F>(&self, task_gen: F) -> Result<AsyncResult, CeleryError>
169 where
170 T: Task + 'static,
171 F: Fn() -> Signature<T> + Send + Sync,
172 {
173 log::debug!("Sending task {}", Signature::<T>::task_name());
174
175 let task_gen = &task_gen;
176 let task =
177 self.run(|celery| async move { Ok(celery.send_task(task_gen()).await?) }).await?;
178
179 #[cfg(feature = "trace")]
180 Span::current().record("task_id", &display(&task.task_id));
181
182 Ok(task)
183 }
184}
185
/// Round robin whose services are [`Celery`] apps, built from `String` broker URLs by a
/// `'static` [`CeleryConnector`], with failures classified through [`RRCeleryError`].
pub type CeleryRoundRobin = RoundRobin<String, Celery, RRCeleryError, CeleryConnector<'static>>;