1pub mod utils;
2use std::{
3 collections::HashMap,
4 sync::{atomic::AtomicBool, Arc},
5 time::Duration,
6};
7
8use contants::task::TaskStatus;
9use redis::{self};
10use serde_json::Value;
11pub mod contants;
12use handler::{Handler, HandlerFn, HandlerMap};
13pub mod handler;
14use log::{debug, error, info};
15use redis::AsyncCommands;
16use tokio::sync::{Semaphore, TryAcquireError};
17
18pub mod task;
19
20#[derive(Debug)]
21pub struct SimpleTaskApp {
22 pub tasks: HandlerMap,
23 redis_client: redis::Client,
24 should_stop: AtomicBool,
25 available_workers: Arc<Semaphore>,
26}
27
/// Redis list of ids of tasks that were sent but not yet picked up by a worker.
const WAITING_TASK_ID_QUEUE: &str = concat!("simple_task:", "waiting_task_id");
/// Redis list of ids of tasks a worker has moved out of the waiting queue.
const PROCESSING_TASK_ID_QUEUE: &str = concat!("simple_task:", "processing_task_id");
/// Prefix of the per-task Redis hashes; the full key is `<prefix>:<task id>`.
const TASK_TABLE: &str = concat!("simple_task:", "task");
31
32impl SimpleTaskApp {
33 pub fn new(redis_client: redis::Client, concurrency: usize) -> Self {
34 Self {
35 tasks: HandlerMap::new(),
36 redis_client,
37 should_stop: AtomicBool::new(false),
38 available_workers: Arc::new(Semaphore::new(concurrency as usize)),
39 }
40 }
41
42 pub fn gen_task_id() -> String {
43 uuid::Uuid::new_v4().to_string()
44 }
45
46 pub fn prepare_stop(&self) {
47 self.should_stop
48 .store(true, std::sync::atomic::Ordering::Relaxed);
49 }
50
51 pub fn should_stop(&self) -> bool {
52 self.should_stop.load(std::sync::atomic::Ordering::Relaxed)
53 }
54
55 pub fn wait_shutdown_background(self: &Arc<Self>) {
56 let _self = self.clone();
57 tokio::spawn(async move {
58 match tokio::signal::ctrl_c().await {
59 Ok(_) => {
60 _self.prepare_stop();
61 }
62 Err(e) => {}
63 };
64 });
65 }
66
67 pub async fn send_task(&self, handler_name: &str, input: Value) -> anyhow::Result<()> {
68 let mut conn = self.redis_client.get_async_connection().await?;
69
70 let task_id = Self::gen_task_id();
71 let task_id_key = join_key!(TASK_TABLE, &task_id);
72 let task_id_key = task_id_key.as_str();
73
74 conn.hset_multiple(
75 task_id_key,
76 &[
77 ("handler_name", handler_name),
78 ("input", &input.to_string()),
79 ("status", TaskStatus::Waiting.as_str()),
80 ],
81 )
82 .await?;
83 conn.lpush(WAITING_TASK_ID_QUEUE, task_id).await?;
84 Ok(())
85 }
86
87 pub async fn run_task(
88 self: &Arc<Self>,
89 task_id: &str,
90 handler: &Arc<Handler>,
91 input: Value,
92 ) -> anyhow::Result<()> {
93 let mut redis_conn = self.redis_client.get_async_connection().await?;
94 let f = &handler.func;
95 let task_id_key = join_key!(TASK_TABLE, task_id);
96 info!(
97 "Task {}, handler {}, task handler start.",
98 &task_id, &handler.name
99 );
100 match f(input).await {
101 Ok(result) => {
102 info!(
103 "Task {}, handler {}, task handler done",
104 &task_id, &handler.name
105 );
106 redis_conn
107 .hset_multiple(
108 task_id_key,
109 &[
110 ("status", TaskStatus::Done.as_str()),
111 ("result", &result.to_string()),
112 ],
113 )
114 .await?;
115 redis_conn
116 .lrem(PROCESSING_TASK_ID_QUEUE, 1, task_id)
117 .await?;
118 redis_conn.expire(task_id_key, 3600).await?;
119 }
120 Err(e) => {
121 error!(
122 "Task {}, handler {}, task handler failed {}",
123 &task_id, &handler.name, e
124 );
125 redis_conn
126 .hset_multiple(
127 task_id_key,
128 &[
129 ("status", TaskStatus::Error.as_str()),
130 ("error", &e.to_string()),
131 ],
132 )
133 .await?;
134 }
135 };
136 Ok(())
137 }
138
139 pub fn register_handler(&mut self, name: &str, func: HandlerFn) {
140 let name = name.to_string();
141 let task = handler::Handler {
142 name: name.clone(),
143 func,
144 };
145 self.tasks.insert(name, Arc::new(task));
146 }
147
148 pub fn log_for_start(&self) {
149 info!("Running simple task app...");
150 let m = &self.tasks;
151 info!("Registered handlers:");
152 for (name, _) in m.iter() {
153 info!(" {}", name);
154 }
155 let conn_info = self.redis_client.get_connection_info();
156 info!(
157 "Broker & result backend: redis://{}/{}",
158 conn_info.addr, conn_info.redis.db
159 );
160 info!(
161 "Concurrency: {}",
162 self.available_workers.available_permits()
163 );
164
165 info!("Press CTRL+C to quit");
166 }
167
168 pub async fn run(self) -> anyhow::Result<()> {
169 self.log_for_start();
170
171 let _self = Arc::new(self);
172 _self.wait_shutdown_background();
173 let mut conn = _self.redis_client.get_async_connection().await?;
174
175 loop {
176 let p = loop {
177 match _self.available_workers.clone().try_acquire_owned() {
178 Ok(p) => {
179 break p;
180 }
181 Err(e) => match e {
182 TryAcquireError::NoPermits => {
183 tokio::time::sleep(Duration::from_millis(100)).await;
184 }
185 TryAcquireError::Closed => {
186 anyhow::bail!("Semaphore closed");
187 }
188 },
189 }
190 };
191 if _self.should_stop() {
192 break;
193 }
194
195 let task_id: Option<String> = conn
196 .blmove(
197 WAITING_TASK_ID_QUEUE,
198 PROCESSING_TASK_ID_QUEUE,
199 redis::Direction::Right,
200 redis::Direction::Left,
201 1,
202 )
203 .await?;
204 let Some(ref task_id) = task_id else {
205 continue;
206 };
207
208 let task_id_key = join_key!(TASK_TABLE, task_id);
209
210 conn.hset(task_id_key, "status", TaskStatus::Processing.as_str())
211 .await?;
212
213 let task: HashMap<String, String> = conn.hgetall(task_id_key).await?;
214
215 let (input, handler_name) = match (task.get("input"), task.get("handler_name")) {
216 (Some(input), Some(handler_name)) => (input.clone(), handler_name.clone()),
217 _ => {
218 conn.hset_multiple(
219 task_id_key,
220 &[
221 ("status", TaskStatus::Error.as_str()),
222 ("error", "Missing field"),
223 ],
224 )
225 .await?;
226 continue;
227 }
228 };
229
230 let input: Value = serde_json::from_str(&input)?;
231 let handler = _self
232 .tasks
233 .get(&handler_name)
234 .ok_or(anyhow::anyhow!("Handler {} not found", handler_name))?
235 .clone();
236
237 info!("Task {} accepted, handler {}", &task_id, handler_name);
238
239 let task_id = task_id.clone();
240 let _self = _self.clone();
241 tokio::spawn(async move {
242 let _p = p;
243 if let Err(e) = _self.run_task(&task_id, &handler, input).await {
244 error!(
245 "Internal error when running task {}, handler {}, error {}",
246 &task_id, handler_name, e
247 );
248 }
249
250 debug!(
251 "Task done. Task id {}, handler {}, Releasing SemaphorePermit.",
252 &task_id, handler_name
253 );
254 });
255 }
256 Ok(())
257 }
258}
259
/// Registers an async handler function on a `SimpleTaskApp` under its type name.
///
/// `$handler` must be callable as `$handler(input)` and return a future; it is wrapped in an
/// `Arc`'d closure that boxes and pins that future. The registration name is
/// `$crate::utils::lang::type_name_of($handler)`, the same function `send_task!` uses to
/// derive the name it enqueues.
///
/// ```ignore
/// register_handler!(app, my_handler);
/// ```
#[macro_export]
macro_rules! register_handler {
    ($app:expr, $handler:expr) => {
        $app.register_handler(
            $crate::utils::lang::type_name_of($handler),
            std::sync::Arc::new(|payload| Box::pin($handler(payload))),
        )
    };
}
269
/// Enqueues a task for handler `$handler` with input `$input` through `$app`.
///
/// The handler name is derived with `$crate::utils::lang::type_name_of`, the same way
/// `register_handler!` names handlers. Expands to the future returned by
/// `$app.send_task(..)`, which still has to be awaited.
///
/// ```ignore
/// send_task!(app, my_handler, serde_json::json!({"x": 1})).await?;
/// ```
#[macro_export]
macro_rules! send_task {
    ($app:expr, $handler:expr, $input:expr) => {
        $app.send_task($crate::utils::lang::type_name_of($handler), $input)
    };
}
276
/// Joins Redis key segments with `:`.
///
/// With a single argument, that argument is returned unchanged. With two or more arguments,
/// the expansion is `&String`: a reference to the freshly joined key. Callers bind it with
/// `let` and pass it by reference to several Redis commands.
///
/// ```ignore
/// let key = join_key!(TASK_TABLE, task_id); // "simple_task:task:<task_id>"
/// ```
#[macro_export]
macro_rules! join_key {
    ($a:expr) => {
        $a
    };
    ($a:expr, $($b:expr),+) => {
        // A stack array is enough for `join`; allocating a Vec first was unnecessary.
        &[$a, $($b),+].join(":")
    };
}