spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::error::GenServerError;
8
/// Reply deadline used by `GenServerHandle::call` when no explicit timeout is given.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
10
/// Cloneable handle used to talk to a running [`GenServer`].
///
/// Cloning a handle clones its channel sender and cancellation token rather
/// than recreating them, so all clones address the same server.
#[derive(Debug)]
pub struct GenServerHandle<G: GenServer + 'static> {
    /// Sending half of the server's message channel.
    pub tx: mpsc::Sender<GenServerInMsg<G>>,
    /// Cancellation token associated with the GenServer. `GenServer::run`
    /// cancels it once the server's loop has ended (or `init` failed).
    /// NOTE(review): nothing visible in this file observes the token to stop
    /// the server loop — confirm whether external code relies on it.
    cancellation_token: CancellationToken,
}
17
18impl<G: GenServer> Clone for GenServerHandle<G> {
19    fn clone(&self) -> Self {
20        Self {
21            tx: self.tx.clone(),
22            cancellation_token: self.cancellation_token.clone(),
23        }
24    }
25}
26
27impl<G: GenServer> GenServerHandle<G> {
28    pub(crate) fn new(initial_state: G::State) -> Self {
29        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
30        let cancellation_token = CancellationToken::new();
31        let handle = GenServerHandle {
32            tx,
33            cancellation_token,
34        };
35        let mut gen_server: G = GenServer::new();
36        let handle_clone = handle.clone();
37        // Ignore the JoinHandle for now. Maybe we'll use it in the future
38        let _join_handle = rt::spawn(async move {
39            if gen_server
40                .run(&handle, &mut rx, initial_state)
41                .await
42                .is_err()
43            {
44                tracing::trace!("GenServer crashed")
45            };
46        });
47        handle_clone
48    }
49
50    pub(crate) fn new_blocking(initial_state: G::State) -> Self {
51        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
52        let cancellation_token = CancellationToken::new();
53        let handle = GenServerHandle {
54            tx,
55            cancellation_token,
56        };
57        let mut gen_server: G = GenServer::new();
58        let handle_clone = handle.clone();
59        // Ignore the JoinHandle for now. Maybe we'll use it in the future
60        let _join_handle = rt::spawn_blocking(|| {
61            rt::block_on(async move {
62                if gen_server
63                    .run(&handle, &mut rx, initial_state)
64                    .await
65                    .is_err()
66                {
67                    tracing::trace!("GenServer crashed")
68                };
69            })
70        });
71        handle_clone
72    }
73
74    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
75        self.tx.clone()
76    }
77
78    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
79        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
80    }
81
82    pub async fn call_with_timeout(
83        &mut self,
84        message: G::CallMsg,
85        duration: Duration,
86    ) -> Result<G::OutMsg, GenServerError> {
87        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
88        self.tx.send(GenServerInMsg::Call {
89            sender: oneshot_tx,
90            message,
91        })?;
92
93        match timeout(duration, oneshot_rx).await {
94            Ok(Ok(result)) => result,
95            Ok(Err(_)) => Err(GenServerError::Server),
96            Err(_) => Err(GenServerError::CallTimeout),
97        }
98    }
99
100    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
101        self.tx
102            .send(GenServerInMsg::Cast { message })
103            .map_err(|_error| GenServerError::Server)
104    }
105
106    pub fn cancellation_token(&self) -> CancellationToken {
107        self.cancellation_token.clone()
108    }
109}
110
/// Message delivered to a server's mailbox through a [`GenServerHandle`].
pub enum GenServerInMsg<G: GenServer> {
    /// Synchronous request; the server answers through `sender`.
    Call {
        /// One-shot channel the server uses to send back the reply or an error.
        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
        /// Request payload handed to `GenServer::handle_call`.
        message: G::CallMsg,
    },
    /// Fire-and-forget message; no reply is sent.
    Cast {
        /// Payload handed to `GenServer::handle_cast`.
        message: G::CastMsg,
    },
}
120
/// Outcome of `GenServer::handle_call`, interpreted by the server loop.
pub enum CallResponse<G: GenServer> {
    /// Reply to the caller and keep running with the given new state.
    Reply(G::State, G::OutMsg),
    /// The server does not handle calls (returned by the default
    /// `handle_call`). An error is logged, the caller receives
    /// `GenServerError::CallMsgUnused`, and the server stops.
    Unused,
    /// Reply to the caller, then stop the server. The state is left as it was
    /// before the call.
    Stop(G::OutMsg),
}
126
/// Outcome of `GenServer::handle_cast`, interpreted by the server loop.
pub enum CastResponse<G: GenServer> {
    /// Keep running with the given new state.
    NoReply(G::State),
    /// The server does not handle casts (returned by the default
    /// `handle_cast`). An error is logged and the server stops.
    Unused,
    /// Stop the server.
    Stop,
}
132
133pub trait GenServer
134where
135    Self: Default + Send + Sized,
136{
137    type CallMsg: Clone + Send + Sized + Sync;
138    type CastMsg: Clone + Send + Sized + Sync;
139    type OutMsg: Send + Sized;
140    type State: Clone + Send;
141    type Error: Debug + Send;
142
143    fn new() -> Self {
144        Self::default()
145    }
146
147    fn start(initial_state: Self::State) -> GenServerHandle<Self> {
148        GenServerHandle::new(initial_state)
149    }
150
151    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
152    /// happen if the task is blocking the thread. As such, for sync compute task
153    /// or other blocking tasks need to be in their own separate thread, and the OS
154    /// will manage them through hardware interrupts.
155    /// Start blocking provides such thread.
156    fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
157        GenServerHandle::new_blocking(initial_state)
158    }
159
160    fn run(
161        &mut self,
162        handle: &GenServerHandle<Self>,
163        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
164        state: Self::State,
165    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
166        async {
167            let init_result = self
168                .init(handle, state.clone())
169                .await
170                .inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));
171
172            let res = match init_result {
173                Ok(new_state) => self.main_loop(handle, rx, new_state).await,
174                Err(_) => Err(GenServerError::Initialization),
175            };
176
177            handle.cancellation_token().cancel();
178            if let Err(err) = self.teardown(handle, state).await {
179                tracing::error!("Error during teardown: {err:?}");
180            }
181            res
182        }
183    }
184
185    /// Initialization function. It's called before main loop. It
186    /// can be overrided on implementations in case initial steps are
187    /// required.
188    fn init(
189        &mut self,
190        _handle: &GenServerHandle<Self>,
191        state: Self::State,
192    ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
193        async { Ok(state) }
194    }
195
196    fn main_loop(
197        &mut self,
198        handle: &GenServerHandle<Self>,
199        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200        mut state: Self::State,
201    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
202        async {
203            loop {
204                let (new_state, cont) = self.receive(handle, rx, state).await?;
205                state = new_state;
206                if !cont {
207                    break;
208                }
209            }
210            tracing::trace!("Stopping GenServer");
211            Ok(())
212        }
213    }
214
215    fn receive(
216        &mut self,
217        handle: &GenServerHandle<Self>,
218        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
219        state: Self::State,
220    ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
221        async move {
222            let message = rx.recv().await;
223
224            // Save current state in case of a rollback
225            let state_clone = state.clone();
226
227            let (keep_running, new_state) = match message {
228                Some(GenServerInMsg::Call { sender, message }) => {
229                    let (keep_running, new_state, response) =
230                        match AssertUnwindSafe(self.handle_call(message, handle, state))
231                            .catch_unwind()
232                            .await
233                        {
234                            Ok(response) => match response {
235                                CallResponse::Reply(new_state, response) => {
236                                    (true, new_state, Ok(response))
237                                }
238                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
239                                CallResponse::Unused => {
240                                    tracing::error!("GenServer received unexpected CallMessage");
241                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
242                                }
243                            },
244                            Err(error) => {
245                                tracing::error!(
246                                    "Error in callback, reverting state - Error: '{error:?}'"
247                                );
248                                (true, state_clone, Err(GenServerError::Callback))
249                            }
250                        };
251                    // Send response back
252                    if sender.send(response).is_err() {
253                        tracing::error!(
254                            "GenServer failed to send response back, client must have died"
255                        )
256                    };
257                    (keep_running, new_state)
258                }
259                Some(GenServerInMsg::Cast { message }) => {
260                    match AssertUnwindSafe(self.handle_cast(message, handle, state))
261                        .catch_unwind()
262                        .await
263                    {
264                        Ok(response) => match response {
265                            CastResponse::NoReply(new_state) => (true, new_state),
266                            CastResponse::Stop => (false, state_clone),
267                            CastResponse::Unused => {
268                                tracing::error!("GenServer received unexpected CastMessage");
269                                (false, state_clone)
270                            }
271                        },
272                        Err(error) => {
273                            tracing::trace!(
274                                "Error in callback, reverting state - Error: '{error:?}'"
275                            );
276                            (true, state_clone)
277                        }
278                    }
279                }
280                None => {
281                    // Channel has been closed; won't receive further messages. Stop the server.
282                    (false, state)
283                }
284            };
285            Ok((new_state, keep_running))
286        }
287    }
288
289    fn handle_call(
290        &mut self,
291        _message: Self::CallMsg,
292        _handle: &GenServerHandle<Self>,
293        _state: Self::State,
294    ) -> impl Future<Output = CallResponse<Self>> + Send {
295        async { CallResponse::Unused }
296    }
297
298    fn handle_cast(
299        &mut self,
300        _message: Self::CastMsg,
301        _handle: &GenServerHandle<Self>,
302        _state: Self::State,
303    ) -> impl Future<Output = CastResponse<Self>> + Send {
304        async { CastResponse::Unused }
305    }
306
307    /// Teardown function. It's called after the stop message is received.
308    /// It can be overrided on implementations in case final steps are required,
309    /// like closing streams, stopping timers, etc.
310    fn teardown(
311        &mut self,
312        _handle: &GenServerHandle<Self>,
313        _state: Self::State,
314    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
315        async { Ok(()) }
316    }
317}
318
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its thread with `thread::sleep` for
    /// two seconds. Its call handler always stops the server.
    #[derive(Default)]
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocking: holds the executing thread for 2 s.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that increments a counter on every cast. Each cast also schedules
    /// another cast to itself 100 ms later via `send_after`, so it keeps
    /// ticking for as long as it gets scheduled.
    #[derive(Default)]
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => {
                    let count = state.count;
                    CallResponse::Reply(state, OutMsg::Count(count))
                }
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            // Re-arm: schedule the next tick to this same server.
            send_after(Duration::from_millis(100), handle.to_owned(), ());
            CastResponse::NoReply(state)
        }
    }

    // Both servers run as async tasks. The test expects the blocking cast in
    // `BadlyBehavedTask` to keep `WellBehavedTask` from reaching 10 ticks in
    // ~1 s (1 s / 100 ms). Timing-sensitive; presumably depends on the
    // runtime's worker-thread count — TODO confirm.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    // Same scenario, but the badly behaved server runs via `start_blocking`
    // on its own thread. The good server is expected to tick exactly 10 times
    // in ~1 s. Timing-sensitive.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start_blocking(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server whose calls take either half (fast) or twice (slow)
    /// `TIMEOUT_DURATION` to answer.
    #[derive(Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
            _state: Self::State,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Simulate a slow operation that will not resolve in time
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply((), ())
                }
                SomeTaskCallMsg::FastOperation => {
                    // Simulate a fast operation that resolves in time
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply((), ())
                }
            }
        }
    }

    // A call that answers within the timeout succeeds. A call that does not
    // answer in time yields `GenServerError::CallTimeout`.
    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask::start(());

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(())));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }
}