simple_task/
lib.rs

1pub mod utils;
2use std::{
3    collections::HashMap,
4    sync::{atomic::AtomicBool, Arc},
5    time::Duration,
6};
7
8use contants::task::TaskStatus;
9use redis::{self};
10use serde_json::Value;
11pub mod contants;
12use handler::{Handler, HandlerFn, HandlerMap};
13pub mod handler;
14use log::{debug, error, info};
15use redis::AsyncCommands;
16use tokio::sync::{Semaphore, TryAcquireError};
17
18pub mod task;
19
/// Task application: holds the registered handlers, the Redis client used as
/// broker and result store, and the state that drives the worker loop.
#[derive(Debug)]
pub struct SimpleTaskApp {
    /// Registered handlers; `register_handler` inserts entries keyed by handler name.
    pub tasks: HandlerMap,
    /// Redis client from which every operation opens its own async connection.
    redis_client: redis::Client,
    /// Stop flag: set by `prepare_stop`, read by `should_stop`.
    should_stop: AtomicBool,
    /// Semaphore created in `new` with `concurrency` permits; `run` takes one
    /// permit for every task it spawns.
    available_workers: Arc<Semaphore>,
}
27
/// Redis list of task ids waiting to be picked up (`send_task` pushes onto it,
/// `run` moves ids out of it).
const WAITING_TASK_ID_QUEUE: &str = "simple_task:waiting_task_id";
/// Redis list that `run` moves task ids into when it takes them from the waiting queue.
const PROCESSING_TASK_ID_QUEUE: &str = "simple_task:processing_task_id";
/// Key prefix of the per-task Redis hash; the full key is this prefix joined
/// with the task id by `join_key!`.
const TASK_TABLE: &str = "simple_task:task";
31
32impl SimpleTaskApp {
33    pub fn new(redis_client: redis::Client, concurrency: usize) -> Self {
34        Self {
35            tasks: HandlerMap::new(),
36            redis_client,
37            should_stop: AtomicBool::new(false),
38            available_workers: Arc::new(Semaphore::new(concurrency as usize)),
39        }
40    }
41
42    pub fn gen_task_id() -> String {
43        uuid::Uuid::new_v4().to_string()
44    }
45
46    pub fn prepare_stop(&self) {
47        self.should_stop
48            .store(true, std::sync::atomic::Ordering::Relaxed);
49    }
50
51    pub fn should_stop(&self) -> bool {
52        self.should_stop.load(std::sync::atomic::Ordering::Relaxed)
53    }
54
55    pub fn wait_shutdown_background(self: &Arc<Self>) {
56        let _self = self.clone();
57        tokio::spawn(async move {
58            match tokio::signal::ctrl_c().await {
59                Ok(_) => {
60                    _self.prepare_stop();
61                }
62                Err(e) => {}
63            };
64        });
65    }
66
67    pub async fn send_task(&self, handler_name: &str, input: Value) -> anyhow::Result<()> {
68        let mut conn = self.redis_client.get_async_connection().await?;
69
70        let task_id = Self::gen_task_id();
71        let task_id_key = join_key!(TASK_TABLE, &task_id);
72        let task_id_key = task_id_key.as_str();
73
74        conn.hset_multiple(
75            task_id_key,
76            &[
77                ("handler_name", handler_name),
78                ("input", &input.to_string()),
79                ("status", TaskStatus::Waiting.as_str()),
80            ],
81        )
82        .await?;
83        conn.lpush(WAITING_TASK_ID_QUEUE, task_id).await?;
84        Ok(())
85    }
86
87    pub async fn run_task(
88        self: &Arc<Self>,
89        task_id: &str,
90        handler: &Arc<Handler>,
91        input: Value,
92    ) -> anyhow::Result<()> {
93        let mut redis_conn = self.redis_client.get_async_connection().await?;
94        let f = &handler.func;
95        let task_id_key = join_key!(TASK_TABLE, task_id);
96        info!(
97            "Task {}, handler {}, task handler start.",
98            &task_id, &handler.name
99        );
100        match f(input).await {
101            Ok(result) => {
102                info!(
103                    "Task {}, handler {}, task handler done",
104                    &task_id, &handler.name
105                );
106                redis_conn
107                    .hset_multiple(
108                        task_id_key,
109                        &[
110                            ("status", TaskStatus::Done.as_str()),
111                            ("result", &result.to_string()),
112                        ],
113                    )
114                    .await?;
115                redis_conn
116                    .lrem(PROCESSING_TASK_ID_QUEUE, 1, task_id)
117                    .await?;
118                redis_conn.expire(task_id_key, 3600).await?;
119            }
120            Err(e) => {
121                error!(
122                    "Task {}, handler {}, task handler failed {}",
123                    &task_id, &handler.name, e
124                );
125                redis_conn
126                    .hset_multiple(
127                        task_id_key,
128                        &[
129                            ("status", TaskStatus::Error.as_str()),
130                            ("error", &e.to_string()),
131                        ],
132                    )
133                    .await?;
134            }
135        };
136        Ok(())
137    }
138
139    pub fn register_handler(&mut self, name: &str, func: HandlerFn) {
140        let name = name.to_string();
141        let task = handler::Handler {
142            name: name.clone(),
143            func,
144        };
145        self.tasks.insert(name, Arc::new(task));
146    }
147
148    pub fn log_for_start(&self) {
149        info!("Running simple task app...");
150        let m = &self.tasks;
151        info!("Registered handlers:");
152        for (name, _) in m.iter() {
153            info!("  {}", name);
154        }
155        let conn_info = self.redis_client.get_connection_info();
156        info!(
157            "Broker & result backend: redis://{}/{}",
158            conn_info.addr, conn_info.redis.db
159        );
160        info!(
161            "Concurrency: {}",
162            self.available_workers.available_permits()
163        );
164
165        info!("Press CTRL+C to quit");
166    }
167
168    pub async fn run(self) -> anyhow::Result<()> {
169        self.log_for_start();
170
171        let _self = Arc::new(self);
172        _self.wait_shutdown_background();
173        let mut conn = _self.redis_client.get_async_connection().await?;
174
175        loop {
176            let p = loop {
177                match _self.available_workers.clone().try_acquire_owned() {
178                    Ok(p) => {
179                        break p;
180                    }
181                    Err(e) => match e {
182                        TryAcquireError::NoPermits => {
183                            tokio::time::sleep(Duration::from_millis(100)).await;
184                        }
185                        TryAcquireError::Closed => {
186                            anyhow::bail!("Semaphore closed");
187                        }
188                    },
189                }
190            };
191            if _self.should_stop() {
192                break;
193            }
194
195            let task_id: Option<String> = conn
196                .blmove(
197                    WAITING_TASK_ID_QUEUE,
198                    PROCESSING_TASK_ID_QUEUE,
199                    redis::Direction::Right,
200                    redis::Direction::Left,
201                    1,
202                )
203                .await?;
204            let Some(ref task_id) = task_id else {
205                continue;
206            };
207
208            let task_id_key = join_key!(TASK_TABLE, task_id);
209
210            conn.hset(task_id_key, "status", TaskStatus::Processing.as_str())
211                .await?;
212
213            let task: HashMap<String, String> = conn.hgetall(task_id_key).await?;
214
215            let (input, handler_name) = match (task.get("input"), task.get("handler_name")) {
216                (Some(input), Some(handler_name)) => (input.clone(), handler_name.clone()),
217                _ => {
218                    conn.hset_multiple(
219                        task_id_key,
220                        &[
221                            ("status", TaskStatus::Error.as_str()),
222                            ("error", "Missing field"),
223                        ],
224                    )
225                    .await?;
226                    continue;
227                }
228            };
229
230            let input: Value = serde_json::from_str(&input)?;
231            let handler = _self
232                .tasks
233                .get(&handler_name)
234                .ok_or(anyhow::anyhow!("Handler {} not found", handler_name))?
235                .clone();
236
237            info!("Task {} accepted, handler {}", &task_id, handler_name);
238
239            let task_id = task_id.clone();
240            let _self = _self.clone();
241            tokio::spawn(async move {
242                let _p = p;
243                if let Err(e) = _self.run_task(&task_id, &handler, input).await {
244                    error!(
245                        "Internal error when running task {}, handler {}, error {}",
246                        &task_id, handler_name, e
247                    );
248                }
249
250                debug!(
251                    "Task done. Task id {}, handler {}, Releasing SemaphorePermit.",
252                    &task_id, handler_name
253                );
254            });
255        }
256        Ok(())
257    }
258}
259
/// Registers the handler `$handler` on the app `$app`.
///
/// The registration name is whatever `$crate::utils::lang::type_name_of`
/// returns for `$handler`; the handler itself is wrapped in an `Arc`'d closure
/// that boxes and pins the future it returns. `$handler` is expanded twice, so
/// pass a plain path or identifier.
#[macro_export]
macro_rules! register_handler {
    ($app:expr, $handler:expr) => {
        $app.register_handler(
            $crate::utils::lang::type_name_of($handler),
            std::sync::Arc::new(|input| Box::pin($handler(input))),
        )
    };
}
269
/// Calls `$app.send_task` for the handler `$handler` with the input `$input`.
///
/// The handler name is whatever `$crate::utils::lang::type_name_of` returns
/// for `$handler`, the same lookup `register_handler!` uses. The expansion is
/// the un-awaited `send_task` call.
#[macro_export]
macro_rules! send_task {
    ($app:expr, $handler:expr, $input:expr) => {
        $app.send_task($crate::utils::lang::type_name_of($handler), $input)
    };
}
276
/// Joins key segments with `:`.
///
/// * One segment: expands to that expression unchanged.
/// * Two or more segments: expands to `&String` holding the segments joined
///   by `:`. Bind the result with `let` (or use it within one statement) so
///   the temporary `String` lives long enough.
///
/// The segments are collected in an array literal, so the first segment fixes
/// the element type and the following ones must coerce to it (e.g. `&str`
/// first, then `&String`).
#[macro_export]
macro_rules! join_key {
    ($a:expr) => {
        $a
    };
    ($a:expr, $($b:expr),+) => {
        // An array is enough for `join`; `vec!` allocated a heap buffer for nothing.
        &[$a, $($b),+].join(":")
    };
}