spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::{
8    error::GenServerError,
9    tasks::InitResult::{NoSuccess, Success},
10};
11
/// How long [`GenServerHandle::call`] waits for a reply before giving up (5 seconds).
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
13
14#[derive(Debug)]
15pub struct GenServerHandle<G: GenServer + 'static> {
16    pub tx: mpsc::Sender<GenServerInMsg<G>>,
17    /// Cancellation token to stop the GenServer
18    cancellation_token: CancellationToken,
19}
20
21impl<G: GenServer> Clone for GenServerHandle<G> {
22    fn clone(&self) -> Self {
23        Self {
24            tx: self.tx.clone(),
25            cancellation_token: self.cancellation_token.clone(),
26        }
27    }
28}
29
30impl<G: GenServer> GenServerHandle<G> {
31    pub(crate) fn new(gen_server: G) -> Self {
32        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
33        let cancellation_token = CancellationToken::new();
34        let handle = GenServerHandle {
35            tx,
36            cancellation_token,
37        };
38        let handle_clone = handle.clone();
39        // Ignore the JoinHandle for now. Maybe we'll use it in the future
40        let _join_handle = rt::spawn(async move {
41            if gen_server.run(&handle, &mut rx).await.is_err() {
42                tracing::trace!("GenServer crashed")
43            };
44        });
45        handle_clone
46    }
47
48    pub(crate) fn new_blocking(gen_server: G) -> Self {
49        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
50        let cancellation_token = CancellationToken::new();
51        let handle = GenServerHandle {
52            tx,
53            cancellation_token,
54        };
55        let handle_clone = handle.clone();
56        // Ignore the JoinHandle for now. Maybe we'll use it in the future
57        let _join_handle = rt::spawn_blocking(|| {
58            rt::block_on(async move {
59                if gen_server.run(&handle, &mut rx).await.is_err() {
60                    tracing::trace!("GenServer crashed")
61                };
62            })
63        });
64        handle_clone
65    }
66
67    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
68        self.tx.clone()
69    }
70
71    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
72        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
73    }
74
75    pub async fn call_with_timeout(
76        &mut self,
77        message: G::CallMsg,
78        duration: Duration,
79    ) -> Result<G::OutMsg, GenServerError> {
80        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
81        self.tx.send(GenServerInMsg::Call {
82            sender: oneshot_tx,
83            message,
84        })?;
85
86        match timeout(duration, oneshot_rx).await {
87            Ok(Ok(result)) => result,
88            Ok(Err(_)) => Err(GenServerError::Server),
89            Err(_) => Err(GenServerError::CallTimeout),
90        }
91    }
92
93    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
94        self.tx
95            .send(GenServerInMsg::Cast { message })
96            .map_err(|_error| GenServerError::Server)
97    }
98
99    pub fn cancellation_token(&self) -> CancellationToken {
100        self.cancellation_token.clone()
101    }
102}
103
104pub enum GenServerInMsg<G: GenServer> {
105    Call {
106        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
107        message: G::CallMsg,
108    },
109    Cast {
110        message: G::CastMsg,
111    },
112}
113
114pub enum CallResponse<G: GenServer> {
115    Reply(G::OutMsg),
116    Unused,
117    Stop(G::OutMsg),
118}
119
/// Outcome of [`GenServer::handle_cast`], as interpreted by the default `receive`.
pub enum CastResponse {
    /// Keep serving.
    NoReply,
    /// Returned by the default `handle_cast`; an error is logged and the server stops.
    Unused,
    /// Stop the server.
    Stop,
}
125
126pub enum InitResult<G: GenServer> {
127    Success(G),
128    NoSuccess(G),
129}
130
131pub trait GenServer: Send + Sized {
132    type CallMsg: Clone + Send + Sized + Sync;
133    type CastMsg: Clone + Send + Sized + Sync;
134    type OutMsg: Send + Sized;
135    type Error: Debug + Send;
136
137    fn start(self) -> GenServerHandle<Self> {
138        GenServerHandle::new(self)
139    }
140
141    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
142    /// happen if the task is blocking the thread. As such, for sync compute task
143    /// or other blocking tasks need to be in their own separate thread, and the OS
144    /// will manage them through hardware interrupts.
145    /// Start blocking provides such thread.
146    fn start_blocking(self) -> GenServerHandle<Self> {
147        GenServerHandle::new_blocking(self)
148    }
149
150    fn run(
151        self,
152        handle: &GenServerHandle<Self>,
153        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
154    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
155        async {
156            let res = match self.init(handle).await {
157                Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
158                Ok(NoSuccess(intermediate_state)) => {
159                    // new_state is NoSuccess, this means the initialization failed, but the error was handled
160                    // in callback. No need to report the error.
161                    // Just skip main_loop and return the state to teardown the GenServer
162                    Ok(intermediate_state)
163                }
164                Err(err) => {
165                    tracing::error!("Initialization failed with unhandled error: {err:?}");
166                    Err(GenServerError::Initialization)
167                }
168            };
169
170            handle.cancellation_token().cancel();
171            if let Ok(final_state) = res {
172                if let Err(err) = final_state.teardown(handle).await {
173                    tracing::error!("Error during teardown: {err:?}");
174                }
175            }
176            Ok(())
177        }
178    }
179
180    /// Initialization function. It's called before main loop. It
181    /// can be overrided on implementations in case initial steps are
182    /// required.
183    fn init(
184        self,
185        _handle: &GenServerHandle<Self>,
186    ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
187        async { Ok(Success(self)) }
188    }
189
190    fn main_loop(
191        mut self,
192        handle: &GenServerHandle<Self>,
193        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
194    ) -> impl Future<Output = Self> + Send {
195        async {
196            loop {
197                if !self.receive(handle, rx).await {
198                    break;
199                }
200            }
201            tracing::trace!("Stopping GenServer");
202            self
203        }
204    }
205
206    fn receive(
207        &mut self,
208        handle: &GenServerHandle<Self>,
209        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
210    ) -> impl Future<Output = bool> + Send {
211        async move {
212            let message = rx.recv().await;
213
214            let keep_running = match message {
215                Some(GenServerInMsg::Call { sender, message }) => {
216                    let (keep_running, response) =
217                        match AssertUnwindSafe(self.handle_call(message, handle))
218                            .catch_unwind()
219                            .await
220                        {
221                            Ok(response) => match response {
222                                CallResponse::Reply(response) => (true, Ok(response)),
223                                CallResponse::Stop(response) => (false, Ok(response)),
224                                CallResponse::Unused => {
225                                    tracing::error!("GenServer received unexpected CallMessage");
226                                    (false, Err(GenServerError::CallMsgUnused))
227                                }
228                            },
229                            Err(error) => {
230                                tracing::error!("Error in callback: '{error:?}'");
231                                (false, Err(GenServerError::Callback))
232                            }
233                        };
234                    // Send response back
235                    if sender.send(response).is_err() {
236                        tracing::error!(
237                            "GenServer failed to send response back, client must have died"
238                        )
239                    };
240                    keep_running
241                }
242                Some(GenServerInMsg::Cast { message }) => {
243                    match AssertUnwindSafe(self.handle_cast(message, handle))
244                        .catch_unwind()
245                        .await
246                    {
247                        Ok(response) => match response {
248                            CastResponse::NoReply => true,
249                            CastResponse::Stop => false,
250                            CastResponse::Unused => {
251                                tracing::error!("GenServer received unexpected CastMessage");
252                                false
253                            }
254                        },
255                        Err(error) => {
256                            tracing::trace!("Error in callback: '{error:?}'");
257                            false
258                        }
259                    }
260                }
261                None => {
262                    // Channel has been closed; won't receive further messages. Stop the server.
263                    false
264                }
265            };
266            keep_running
267        }
268    }
269
270    fn handle_call(
271        &mut self,
272        _message: Self::CallMsg,
273        _handle: &GenServerHandle<Self>,
274    ) -> impl Future<Output = CallResponse<Self>> + Send {
275        async { CallResponse::Unused }
276    }
277
278    fn handle_cast(
279        &mut self,
280        _message: Self::CastMsg,
281        _handle: &GenServerHandle<Self>,
282    ) -> impl Future<Output = CastResponse> + Send {
283        async { CastResponse::Unused }
284    }
285
286    /// Teardown function. It's called after the stop message is received.
287    /// It can be overrided on implementations in case final steps are required,
288    /// like closing streams, stopping timers, etc.
289    fn teardown(
290        self,
291        _handle: &GenServerHandle<Self>,
292    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
293        async { Ok(()) }
294    }
295}
296
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its OS thread with `thread::sleep`.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocking call inside an async handler.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that bumps a counter on every cast and re-schedules itself every 100ms.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => CallResponse::Reply(OutMsg::Count(self.count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.to_owned(), Unused);
            CastResponse::NoReply
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Simulate a slow operation that will not resolve in time
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(Unused)
                }
                SomeTaskCallMsg::FastOperation => {
                    // Simulate a fast operation that resolves in time
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(Unused)
                }
            }
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(Unused)));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports a handled failure (`NoSuccess`).
    struct SomeTaskThatFailsOnInit {
        /// Receiving half of a probe channel. `teardown` closes it so the test can
        /// observe, through the matching sender, that teardown ran.
        probe_rx: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(probe_rx: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { probe_rx }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            // Simulate an initialization failure by returning NoSuccess
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.probe_rx.lock().unwrap().close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (probe_tx, probe_rx) = mpsc::channel::<u8>();
            let _task = SomeTaskThatFailsOnInit::new(Arc::new(Mutex::new(probe_rx))).start();

            // Wait a while to ensure the task has time to run and fail
            rt::sleep(Duration::from_secs(1)).await;

            // Teardown closed the receiver; the sender observes that as a closed channel.
            assert!(probe_tx.is_closed())
        });
    }
}