// xan_actor/actor.rs

1use crate::LifeCycle;
2use crate::error::ActorError;
3use crate::types::{JobSpec, Message};
4use std::collections::{HashMap, HashSet};
5use std::error::Error;
6
7pub enum ActorSystemCmd {
8    Register(
9        String,
10        tokio::sync::mpsc::UnboundedSender<Message>,
11        tokio::sync::mpsc::UnboundedSender<()>,
12        tokio::sync::mpsc::UnboundedSender<()>,
13        LifeCycle,
14        tokio::sync::oneshot::Sender<()>,
15    ),
16    Restart(String),
17    Unregister(String),
18    FilterAddress(String, tokio::sync::oneshot::Sender<Vec<String>>),
19    FindActor(
20        String,
21        tokio::sync::oneshot::Sender<
22            Option<(tokio::sync::mpsc::UnboundedSender<Message>, bool)>, // tx, ready
23        >,
24    ),
25    SetLifeCycle(String, LifeCycle),
26}
27
28#[async_trait::async_trait]
29pub trait Actor<T, R, E>
30where
31    Self: Sized + 'static,
32    T: Sized + Send + serde::de::DeserializeOwned,
33    R: Sized + Send + serde::Serialize,
34    E: Error + Send,
35{
36    fn address(&self) -> &str;
37
38    async fn actor(&mut self, msg: T) -> Result<R, E>;
39
40    async fn pre_start(&mut self) {}
41
42    async fn pre_restart(&mut self) {}
43
44    async fn post_stop(&mut self) {}
45
46    async fn post_restart(&mut self) {}
47
48    async fn run_actor(
49        &mut self,
50        actor_system_tx: tokio::sync::mpsc::UnboundedSender<ActorSystemCmd>,
51        kill_in_error: bool,
52        ready_tx: tokio::sync::mpsc::UnboundedSender<()>,
53    ) -> Result<(), ActorError> {
54        let mut restarted = false;
55        loop {
56            if restarted {
57                self.post_restart().await;
58            }
59            let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
60            let (kill_tx, mut kill_rx) = tokio::sync::mpsc::unbounded_channel::<()>();
61            let (restart_tx, mut restart_rx) = tokio::sync::mpsc::unbounded_channel::<()>();
62            let (result_tx, result_rx) = tokio::sync::oneshot::channel();
63
64            let _ = actor_system_tx.send(ActorSystemCmd::Register(
65                self.address().to_string(),
66                tx,
67                restart_tx,
68                kill_tx,
69                if restarted {
70                    LifeCycle::Restarting
71                } else {
72                    LifeCycle::Starting
73                },
74                result_tx,
75            ));
76            let _ = result_rx.await;
77            self.pre_start().await;
78            restarted = true;
79            let _ = actor_system_tx.send(ActorSystemCmd::SetLifeCycle(
80                self.address().to_string(),
81                LifeCycle::Receiving,
82            ));
83            let _ = ready_tx.send(());
84            if let Some(_) = loop {
85                tokio::select! {
86                    Some(mut msg) = rx.recv() => {
87                        let result_tx = msg.result_tx();
88                        let msg_de = match rmp_serde::from_slice::<T>(msg.inner()) {
89                            Ok(msg) => msg,
90                            Err(e) => {
91                                if kill_in_error {
92                                    error!("Deserialize message failed: {:?}", e);
93                                    break Some(());
94                                }
95                                debug!("Deserialize message failed: {:?}", e);
96                                break None;
97                            }
98                        };
99                        match self.actor(msg_de).await {
100                           Ok(result) => {
101                                if let Some(result_tx) = result_tx {
102                                    let result = rmp_serde::to_vec(&result)?;
103                                    let _ = result_tx.send(result);
104                                }
105                            }
106                           Err(e) => {
107                                if kill_in_error {
108                                    error!("Handler's result has error: {:?}", e);
109                                    break Some(());
110                                }
111                                debug!("Handler's result has error: {:?}", e);
112                                break None;
113                           }
114                       }
115                    }
116                    Some(_) = kill_rx.recv() => {
117                        info!("Kill actor: address={}", self.address());
118                        break Some(());
119                    }
120                    Some(_) = restart_rx.recv() => {
121                        info!("Restart actor: address={}", self.address());
122                        break None;
123                    }
124                };
125            } {
126                let _ = actor_system_tx.send(ActorSystemCmd::SetLifeCycle(
127                    self.address().to_string(),
128                    LifeCycle::Stopping,
129                ));
130                self.post_stop().await;
131                let _ = actor_system_tx.send(ActorSystemCmd::SetLifeCycle(
132                    self.address().to_string(),
133                    LifeCycle::Terminated,
134                ));
135                break Ok(());
136            }
137            let _ = actor_system_tx.send(ActorSystemCmd::SetLifeCycle(
138                self.address().to_string(),
139                LifeCycle::Stopping,
140            ));
141            self.pre_restart().await;
142            let _ = actor_system_tx.send(ActorSystemCmd::SetLifeCycle(
143                self.address().to_string(),
144                LifeCycle::Restarting,
145            ));
146        }
147    }
148    async fn register(mut self, actor_system: &mut ActorSystem, kill_in_error: bool) {
149        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
150        let actor_system_tx = actor_system.handler_tx();
151        let _ = tokio::task::spawn_blocking(move || {
152            tokio::runtime::Handle::current().block_on(self.run_actor(
153                actor_system_tx,
154                kill_in_error,
155                tx,
156            ))
157        });
158        let _ = rx.recv().await;
159    }
160}
161
162#[derive(Clone)]
163pub struct ActorSystem {
164    handler_tx: tokio::sync::mpsc::UnboundedSender<ActorSystemCmd>,
165}
166
167impl Default for ActorSystem {
168    fn default() -> Self {
169        let (handler_tx, handler_rx) = tokio::sync::mpsc::unbounded_channel();
170        let mut me = Self { handler_tx };
171        me.run(handler_rx);
172        me
173    }
174}
175
176impl ActorSystem {
177    pub fn new() -> Self {
178        Self::default()
179    }
180
181    pub fn handler_tx(&self) -> tokio::sync::mpsc::UnboundedSender<ActorSystemCmd> {
182        self.handler_tx.clone()
183    }
184
185    fn run(
186        &mut self,
187        mut handler_rx: tokio::sync::mpsc::UnboundedReceiver<ActorSystemCmd>,
188    ) -> tokio::task::JoinHandle<()> {
189        tokio::task::spawn_blocking(move || {
190            let mut address_list = HashSet::<String>::new();
191            let mut map = HashMap::<
192                String,
193                (
194                    tokio::sync::mpsc::UnboundedSender<Message>,
195                    tokio::sync::mpsc::UnboundedSender<()>,
196                    tokio::sync::mpsc::UnboundedSender<()>,
197                    LifeCycle,
198                ),
199            >::new();
200            while let Some(msg) = tokio::runtime::Handle::current().block_on(handler_rx.recv()) {
201                match msg {
202                    ActorSystemCmd::Register(
203                        address,
204                        tx,
205                        restart_tx,
206                        kill_tx,
207                        life_cycle,
208                        result_tx,
209                    ) => {
210                        debug!("Register actor with address {}", address);
211                        map.insert(address.clone(), (tx, restart_tx, kill_tx, life_cycle));
212                        address_list.insert(address);
213                        let _ = result_tx.send(());
214                    }
215                    ActorSystemCmd::Restart(address_regex) => {
216                        debug!("Restart actor with address {}", address_regex);
217                        let addresses = match filter_address(&address_list, &address_regex) {
218                            Ok(addresses) => addresses,
219                            Err(e) => {
220                                error!("Filter address failed: {:?}", e);
221                                continue;
222                            }
223                        };
224                        for address in addresses {
225                            if let Some((_tx, restart_tx, _kill_tx, _life_cycle)) =
226                                map.get(&address)
227                            {
228                                let _ = restart_tx.send(());
229                            }
230                        }
231                    }
232                    ActorSystemCmd::Unregister(address_regex) => {
233                        debug!("Unregister actor with address {}", address_regex);
234                        let addresses = match filter_address(&address_list, &address_regex) {
235                            Ok(addresses) => addresses,
236                            Err(e) => {
237                                error!("Filter address failed: {:?}", e);
238                                continue;
239                            }
240                        };
241                        for address in addresses {
242                            match map.entry(address.to_string()) {
243                                std::collections::hash_map::Entry::Occupied(mut entry) => {
244                                    let _ = entry.get_mut().2.send(());
245                                    entry.remove_entry();
246                                    address_list.remove(&address);
247                                }
248                                std::collections::hash_map::Entry::Vacant(_) => {
249                                    continue;
250                                }
251                            }
252                        }
253                    }
254                    ActorSystemCmd::FilterAddress(address_regex, result_tx) => {
255                        debug!("FilterAddress with regex {}", address_regex);
256                        let addresses = match filter_address(&address_list, &address_regex) {
257                            Ok(addresses) => addresses,
258                            Err(e) => {
259                                error!("Filter address failed: {:?}", e);
260                                continue;
261                            }
262                        };
263                        let _ = result_tx.send(addresses);
264                    }
265                    ActorSystemCmd::FindActor(address, result_tx) => {
266                        debug!("FindActor with address {}", address);
267                        if let Some((tx, _restart_tx, _kill_tx, life_cycle)) = map.get(&address) {
268                            let _ = result_tx.send(Some((
269                                tx.clone(),
270                                match life_cycle {
271                                    LifeCycle::Receiving => true,
272                                    _ => false,
273                                },
274                            )));
275                        } else {
276                            let _ = result_tx.send(None);
277                        }
278                    }
279                    ActorSystemCmd::SetLifeCycle(address, life_cycle) => {
280                        debug!(
281                            "SetLifecycle with address {} into {:?}",
282                            address, life_cycle
283                        );
284                        if let Some(actor) = map.get_mut(&address) {
285                            actor.3 = life_cycle;
286                        };
287                    }
288                };
289            }
290        })
291    }
292
293    pub async fn register(
294        &mut self,
295        address: String,
296        tx: tokio::sync::mpsc::UnboundedSender<Message>,
297        restart_tx: tokio::sync::mpsc::UnboundedSender<()>,
298        kill_tx: tokio::sync::mpsc::UnboundedSender<()>,
299        life_cycle: LifeCycle,
300    ) {
301        let (result_tx, result_rx) = tokio::sync::oneshot::channel();
302        let _ = self.handler_tx.send(ActorSystemCmd::Register(
303            address, tx, restart_tx, kill_tx, life_cycle, result_tx,
304        ));
305        let _ = result_rx.await;
306    }
307
308    pub fn set_lifecycle(&mut self, address: &str, life_cycle: LifeCycle) {
309        let _ = self.handler_tx.send(ActorSystemCmd::SetLifeCycle(
310            address.to_string(),
311            life_cycle,
312        ));
313    }
314
315    pub fn restart(&mut self, address_regex: String) {
316        let _ = self.handler_tx.send(ActorSystemCmd::Restart(address_regex));
317    }
318
319    pub fn unregister(&mut self, address_regex: String) {
320        let _ = self
321            .handler_tx
322            .send(ActorSystemCmd::Unregister(address_regex));
323    }
324
325    pub async fn send<T>(&self, address: String, msg: T) -> Result<(), ActorError>
326    where
327        T: serde::Serialize + serde::de::DeserializeOwned,
328    {
329        let (tx, rx) = tokio::sync::oneshot::channel();
330        let _ = self
331            .handler_tx
332            .send(ActorSystemCmd::FindActor(address.clone(), tx));
333        if let Ok(Some((tx, ready))) = rx.await {
334            if ready {
335                let _ = tx.send(Message::new(rmp_serde::to_vec(&msg)?, None))?;
336                Ok(())
337            } else {
338                Err(ActorError::ActorNotReady(address))
339            }
340        } else {
341            Err(ActorError::AddressNotFound(address))
342        }
343    }
344    pub async fn send_broadcast<T>(
345        &self,
346        address_regex: String,
347        msg: T,
348    ) -> Vec<Result<(), ActorError>>
349    where
350        T: serde::Serialize + serde::de::DeserializeOwned,
351    {
352        let (tx, rx) = tokio::sync::oneshot::channel();
353        let _ = self
354            .handler_tx
355            .send(ActorSystemCmd::FilterAddress(address_regex, tx));
356        let addresses = match rx.await {
357            Ok(addresses) => addresses,
358            Err(e) => {
359                error!("Receive address list failed: {:?}", e);
360                return vec![Err(ActorError::from(e))];
361            }
362        };
363        let mut result = Vec::new();
364        for address in addresses {
365            let (tx, rx) = tokio::sync::oneshot::channel();
366            let _ = self
367                .handler_tx
368                .send(ActorSystemCmd::FindActor(address.clone(), tx));
369            if let Ok(Some((tx, ready))) = rx.await {
370                if ready {
371                    match rmp_serde::to_vec(&msg) {
372                        Ok(x) => {
373                            let message = Message::new(x, None);
374                            result.push(
375                                tx.send(message)
376                                    .map(|_| ())
377                                    .map_err(|e| ActorError::UnboundedChannelSend(e)),
378                            );
379                        }
380                        Err(e) => {
381                            result.push(Err(ActorError::from(e)));
382                        }
383                    }
384                } else {
385                    result.push(Err(ActorError::ActorNotReady(address)));
386                }
387            } else {
388                result.push(Err(ActorError::AddressNotFound(address)));
389            }
390        }
391        result
392    }
393
394    pub async fn send_and_recv<T, R>(&self, address: String, msg: T) -> Result<R, ActorError>
395    where
396        T: serde::Serialize + serde::de::DeserializeOwned,
397        R: serde::Serialize + serde::de::DeserializeOwned,
398    {
399        let (tx, rx) = tokio::sync::oneshot::channel();
400        let _ = self
401            .handler_tx
402            .send(ActorSystemCmd::FindActor(address.clone(), tx));
403        if let Ok(Some((tx, ready))) = rx.await {
404            if ready {
405                let (result_tx, result_rx) = tokio::sync::oneshot::channel();
406                let _ = tx.send(Message::new(rmp_serde::to_vec(&msg)?, Some(result_tx)))?;
407                Ok(rmp_serde::from_slice::<R>(&result_rx.await?)?)
408            } else {
409                Err(ActorError::ActorNotReady(address))
410            }
411        } else {
412            Err(ActorError::AddressNotFound(address))
413        }
414    }
415
416    pub async fn run_job<T, R>(
417        &self,
418        address: String,
419        subscript: bool,
420        job: JobSpec,
421        msg: T,
422    ) -> Result<
423        Option<tokio::sync::mpsc::UnboundedReceiver<Result<R, rmp_serde::decode::Error>>>,
424        ActorError,
425    >
426    where
427        T: serde::Serialize + Clone,
428        R: serde::de::DeserializeOwned + Send + 'static,
429    {
430        let (tx, rx) = tokio::sync::oneshot::channel();
431        let msg = match rmp_serde::to_vec(&msg) {
432            Ok(msg) => msg,
433            Err(e) => {
434                error!("Serialize message failed: {:?}", e);
435                return Err(ActorError::from(e));
436            }
437        };
438        let _ = self
439            .handler_tx
440            .send(ActorSystemCmd::FindActor(address.clone(), tx));
441        if let Ok(Some((tx, ready))) = rx.await {
442            if ready {
443                let tx = tx.clone();
444                if subscript {
445                    let (sub_tx, sub_rx) = tokio::sync::mpsc::unbounded_channel();
446                    let msg = msg.clone();
447                    tokio::spawn(async move {
448                        let mut i = 0;
449                        if let Some(interval) = job.interval() {
450                            loop {
451                                i += 1;
452                                if job.start_at() <= std::time::SystemTime::now() {
453                                    let (result_tx, result_rx) = tokio::sync::oneshot::channel();
454                                    let _ = tx.send(Message::new(msg.clone(), Some(result_tx)));
455                                    let result = match result_rx.await {
456                                        Ok(result) => result,
457                                        Err(e) => {
458                                            error!("Receive result failed: {:?}", e);
459                                            break;
460                                        }
461                                    };
462                                    let _ = sub_tx.send(rmp_serde::from_slice::<R>(&result));
463                                    tokio::time::sleep(interval).await;
464                                    if let Some(max_iter) = job.max_iter() {
465                                        if i >= max_iter {
466                                            break;
467                                        }
468                                    }
469                                }
470                            }
471                        } else {
472                            if job.start_at() <= std::time::SystemTime::now() {
473                                let (result_tx, result_rx) = tokio::sync::oneshot::channel();
474                                let msg = match rmp_serde::to_vec(&msg) {
475                                    Ok(msg) => msg,
476                                    Err(e) => {
477                                        error!("Serialize message failed: {:?}", e);
478                                        return;
479                                    }
480                                };
481                                let _ = tx.send(Message::new(msg, Some(result_tx)));
482                                let result =
483                                    match result_rx.await.map(|x| rmp_serde::from_slice::<R>(&x)) {
484                                        Ok(result) => result,
485                                        Err(e) => {
486                                            error!("Receive result failed: {:?}", e);
487                                            return;
488                                        }
489                                    };
490                                let _ = sub_tx.send(result);
491                            }
492                        }
493                    });
494                    Ok(Some(sub_rx))
495                } else {
496                    tokio::spawn(async move {
497                        let mut i = 0;
498                        if let Some(interval) = job.interval() {
499                            loop {
500                                i += 1;
501                                if job.start_at() <= std::time::SystemTime::now() {
502                                    let _ = tx.send(Message::new(msg.clone(), None));
503                                    tokio::time::sleep(interval).await;
504                                    if let Some(max_iter) = job.max_iter() {
505                                        if i >= max_iter {
506                                            break;
507                                        }
508                                    }
509                                }
510                            }
511                        } else {
512                            if job.start_at() <= std::time::SystemTime::now() {
513                                let _ = tx.send(Message::new(msg.clone(), None));
514                            }
515                        }
516                    });
517                    Ok(None)
518                }
519            } else {
520                Err(ActorError::ActorNotReady(address))
521            }
522        } else {
523            Err(ActorError::AddressNotFound(address))
524        }
525    }
526}
527
528fn filter_address(
529    address_list: &HashSet<String>,
530    regex: &str,
531) -> Result<Vec<String>, regex::Error> {
532    let regex = regex::Regex::new(&format!("^{}$", regex.replace("*", "(\\S+)"))).map_err(|e| {
533        error!("Regex error: {:?}", e);
534        e
535    })?;
536    Ok(address_list
537        .iter()
538        .filter(|x| regex.is_match(x))
539        .map(|x| x.to_string())
540        .collect())
541}