spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::{
8    error::GenServerError,
9    tasks::InitResult::{NoSuccess, Success},
10};
11
/// Timeout applied by `GenServerHandle::call` when the caller does not pick one.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
13
14#[derive(Debug)]
15pub struct GenServerHandle<G: GenServer + 'static> {
16    pub tx: mpsc::Sender<GenServerInMsg<G>>,
17    /// Cancellation token to stop the GenServer
18    cancellation_token: CancellationToken,
19}
20
21impl<G: GenServer> Clone for GenServerHandle<G> {
22    fn clone(&self) -> Self {
23        Self {
24            tx: self.tx.clone(),
25            cancellation_token: self.cancellation_token.clone(),
26        }
27    }
28}
29
30impl<G: GenServer> GenServerHandle<G> {
31    pub(crate) fn new(gen_server: G) -> Self {
32        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
33        let cancellation_token = CancellationToken::new();
34        let handle = GenServerHandle {
35            tx,
36            cancellation_token,
37        };
38        let handle_clone = handle.clone();
39        // Ignore the JoinHandle for now. Maybe we'll use it in the future
40        let _join_handle = rt::spawn(async move {
41            if gen_server.run(&handle, &mut rx).await.is_err() {
42                tracing::trace!("GenServer crashed")
43            };
44        });
45        handle_clone
46    }
47
48    pub(crate) fn new_blocking(gen_server: G) -> Self {
49        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
50        let cancellation_token = CancellationToken::new();
51        let handle = GenServerHandle {
52            tx,
53            cancellation_token,
54        };
55        let handle_clone = handle.clone();
56        // Ignore the JoinHandle for now. Maybe we'll use it in the future
57        let _join_handle = rt::spawn_blocking(|| {
58            rt::block_on(async move {
59                if gen_server.run(&handle, &mut rx).await.is_err() {
60                    tracing::trace!("GenServer crashed")
61                };
62            })
63        });
64        handle_clone
65    }
66
67    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
68        self.tx.clone()
69    }
70
71    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
72        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
73    }
74
75    pub async fn call_with_timeout(
76        &mut self,
77        message: G::CallMsg,
78        duration: Duration,
79    ) -> Result<G::OutMsg, GenServerError> {
80        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
81        self.tx.send(GenServerInMsg::Call {
82            sender: oneshot_tx,
83            message,
84        })?;
85
86        match timeout(duration, oneshot_rx).await {
87            Ok(Ok(result)) => result,
88            Ok(Err(_)) => Err(GenServerError::Server),
89            Err(_) => Err(GenServerError::CallTimeout),
90        }
91    }
92
93    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
94        self.tx
95            .send(GenServerInMsg::Cast { message })
96            .map_err(|_error| GenServerError::Server)
97    }
98
99    pub fn cancellation_token(&self) -> CancellationToken {
100        self.cancellation_token.clone()
101    }
102}
103
104pub enum GenServerInMsg<G: GenServer> {
105    Call {
106        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
107        message: G::CallMsg,
108    },
109    Cast {
110        message: G::CastMsg,
111    },
112}
113
114pub enum CallResponse<G: GenServer> {
115    Reply(G, G::OutMsg),
116    Unused,
117    Stop(G::OutMsg),
118}
119
120pub enum CastResponse<G: GenServer> {
121    NoReply(G),
122    Unused,
123    Stop,
124}
125
126pub enum InitResult<G: GenServer> {
127    Success(G),
128    NoSuccess(G),
129}
130
131pub trait GenServer: Send + Sized + Clone {
132    type CallMsg: Clone + Send + Sized + Sync;
133    type CastMsg: Clone + Send + Sized + Sync;
134    type OutMsg: Send + Sized;
135    type Error: Debug + Send;
136
137    fn start(self) -> GenServerHandle<Self> {
138        GenServerHandle::new(self)
139    }
140
141    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
142    /// happen if the task is blocking the thread. As such, for sync compute task
143    /// or other blocking tasks need to be in their own separate thread, and the OS
144    /// will manage them through hardware interrupts.
145    /// Start blocking provides such thread.
146    fn start_blocking(self) -> GenServerHandle<Self> {
147        GenServerHandle::new_blocking(self)
148    }
149
150    fn run(
151        self,
152        handle: &GenServerHandle<Self>,
153        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
154    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
155        async {
156            let res = match self.init(handle).await {
157                Ok(Success(new_state)) => new_state.main_loop(handle, rx).await,
158                Ok(NoSuccess(intermediate_state)) => {
159                    // new_state is NoSuccess, this means the initialization failed, but the error was handled
160                    // in callback. No need to report the error.
161                    // Just skip main_loop and return the state to teardown the GenServer
162                    Ok(intermediate_state)
163                }
164                Err(err) => {
165                    tracing::error!("Initialization failed with unhandled error: {err:?}");
166                    Err(GenServerError::Initialization)
167                }
168            };
169
170            handle.cancellation_token().cancel();
171            if let Ok(final_state) = res {
172                if let Err(err) = final_state.teardown(handle).await {
173                    tracing::error!("Error during teardown: {err:?}");
174                }
175            }
176            Ok(())
177        }
178    }
179
180    /// Initialization function. It's called before main loop. It
181    /// can be overrided on implementations in case initial steps are
182    /// required.
183    fn init(
184        self,
185        _handle: &GenServerHandle<Self>,
186    ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
187        async { Ok(Success(self)) }
188    }
189
190    fn main_loop(
191        mut self,
192        handle: &GenServerHandle<Self>,
193        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
194    ) -> impl Future<Output = Result<Self, GenServerError>> + Send {
195        async {
196            loop {
197                let (new_state, cont) = self.receive(handle, rx).await?;
198                self = new_state;
199                if !cont {
200                    break;
201                }
202            }
203            tracing::trace!("Stopping GenServer");
204            Ok(self)
205        }
206    }
207
208    fn receive(
209        self,
210        handle: &GenServerHandle<Self>,
211        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
212    ) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
213        async move {
214            let message = rx.recv().await;
215
216            // Save current state in case of a rollback
217            let state_clone = self.clone();
218
219            let (keep_running, new_state) = match message {
220                Some(GenServerInMsg::Call { sender, message }) => {
221                    let (keep_running, new_state, response) =
222                        match AssertUnwindSafe(self.handle_call(message, handle))
223                            .catch_unwind()
224                            .await
225                        {
226                            Ok(response) => match response {
227                                CallResponse::Reply(new_state, response) => {
228                                    (true, new_state, Ok(response))
229                                }
230                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
231                                CallResponse::Unused => {
232                                    tracing::error!("GenServer received unexpected CallMessage");
233                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
234                                }
235                            },
236                            Err(error) => {
237                                tracing::error!(
238                                    "Error in callback, reverting state - Error: '{error:?}'"
239                                );
240                                (true, state_clone, Err(GenServerError::Callback))
241                            }
242                        };
243                    // Send response back
244                    if sender.send(response).is_err() {
245                        tracing::error!(
246                            "GenServer failed to send response back, client must have died"
247                        )
248                    };
249                    (keep_running, new_state)
250                }
251                Some(GenServerInMsg::Cast { message }) => {
252                    match AssertUnwindSafe(self.handle_cast(message, handle))
253                        .catch_unwind()
254                        .await
255                    {
256                        Ok(response) => match response {
257                            CastResponse::NoReply(new_state) => (true, new_state),
258                            CastResponse::Stop => (false, state_clone),
259                            CastResponse::Unused => {
260                                tracing::error!("GenServer received unexpected CastMessage");
261                                (false, state_clone)
262                            }
263                        },
264                        Err(error) => {
265                            tracing::trace!(
266                                "Error in callback, reverting state - Error: '{error:?}'"
267                            );
268                            (true, state_clone)
269                        }
270                    }
271                }
272                None => {
273                    // Channel has been closed; won't receive further messages. Stop the server.
274                    (false, self)
275                }
276            };
277            Ok((new_state, keep_running))
278        }
279    }
280
281    fn handle_call(
282        self,
283        _message: Self::CallMsg,
284        _handle: &GenServerHandle<Self>,
285    ) -> impl Future<Output = CallResponse<Self>> + Send {
286        async { CallResponse::Unused }
287    }
288
289    fn handle_cast(
290        self,
291        _message: Self::CastMsg,
292        _handle: &GenServerHandle<Self>,
293    ) -> impl Future<Output = CastResponse<Self>> + Send {
294        async { CastResponse::Unused }
295    }
296
297    /// Teardown function. It's called after the stop message is received.
298    /// It can be overrided on implementations in case final steps are required,
299    /// like closing streams, stopping timers, etc.
300    fn teardown(
301        self,
302        _handle: &GenServerHandle<Self>,
303    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
304        async { Ok(()) }
305    }
306}
307
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    #[derive(Clone)]
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            self,
            _message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            self,
            _message: Self::CastMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately a blocking sleep: this is the bad behaviour under test.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that bumps a counter on every cast and schedules the next cast
    /// to itself 100 ms later.
    #[derive(Clone)]
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let current = self.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(self, OutMsg::Count(current)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(current)),
            }
        }

        async fn handle_cast(
            mut self,
            _message: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), Unused);
            CastResponse::NoReply(self)
        }
    }

    /// The badly behaved server is started with `start()`; the test expects the
    /// well behaved counter not to have reached 10 after one second.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// The badly behaved server is started with `start_blocking()`; the test
    /// expects the well behaved counter to be exactly 10 after one second.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default, Clone)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let delay = match message {
                // Simulate a slow operation that will not resolve in time
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                // Simulate a fast operation that resolves in time
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply(self, Unused)
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut task = SomeTask.start();

            let fast_result = task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast_result, Ok(Unused)));

            let slow_result = task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow_result, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports `NoSuccess`; its `teardown` closes the
    /// receiver it owns so the test can observe that teardown ran.
    #[derive(Clone)]
    struct SomeTaskThatFailsOnInit {
        receiver_channel: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver_channel: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver_channel }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            // Simulate an initialization failure by returning NoSuccess
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            let mut receiver = self.receiver_channel.lock().unwrap();
            receiver.close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (tx, rx) = mpsc::channel::<u8>();
            let receiver_channel = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(receiver_channel).start();

            // Wait a while to ensure the task has time to run and fail
            rt::sleep(Duration::from_secs(1)).await;

            // The teardown function closes the receiver, which the sender observes as closed
            assert!(tx.is_closed())
        });
    }
}