spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::error::GenServerError;
8
/// How long [`GenServerHandle::call`] waits for a reply before giving up
/// with `GenServerError::CallTimeout`.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13    pub tx: mpsc::Sender<GenServerInMsg<G>>,
14    /// Cancellation token to stop the GenServer
15    cancellation_token: CancellationToken,
16}
17
18impl<G: GenServer> Clone for GenServerHandle<G> {
19    fn clone(&self) -> Self {
20        Self {
21            tx: self.tx.clone(),
22            cancellation_token: self.cancellation_token.clone(),
23        }
24    }
25}
26
27impl<G: GenServer> GenServerHandle<G> {
28    pub(crate) fn new(gen_server: G) -> Self {
29        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
30        let cancellation_token = CancellationToken::new();
31        let handle = GenServerHandle {
32            tx,
33            cancellation_token,
34        };
35        let handle_clone = handle.clone();
36        // Ignore the JoinHandle for now. Maybe we'll use it in the future
37        let _join_handle = rt::spawn(async move {
38            if gen_server.run(&handle, &mut rx).await.is_err() {
39                tracing::trace!("GenServer crashed")
40            };
41        });
42        handle_clone
43    }
44
45    pub(crate) fn new_blocking(gen_server: G) -> Self {
46        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
47        let cancellation_token = CancellationToken::new();
48        let handle = GenServerHandle {
49            tx,
50            cancellation_token,
51        };
52        let handle_clone = handle.clone();
53        // Ignore the JoinHandle for now. Maybe we'll use it in the future
54        let _join_handle = rt::spawn_blocking(|| {
55            rt::block_on(async move {
56                if gen_server.run(&handle, &mut rx).await.is_err() {
57                    tracing::trace!("GenServer crashed")
58                };
59            })
60        });
61        handle_clone
62    }
63
64    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
65        self.tx.clone()
66    }
67
68    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
69        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
70    }
71
72    pub async fn call_with_timeout(
73        &mut self,
74        message: G::CallMsg,
75        duration: Duration,
76    ) -> Result<G::OutMsg, GenServerError> {
77        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
78        self.tx.send(GenServerInMsg::Call {
79            sender: oneshot_tx,
80            message,
81        })?;
82
83        match timeout(duration, oneshot_rx).await {
84            Ok(Ok(result)) => result,
85            Ok(Err(_)) => Err(GenServerError::Server),
86            Err(_) => Err(GenServerError::CallTimeout),
87        }
88    }
89
90    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
91        self.tx
92            .send(GenServerInMsg::Cast { message })
93            .map_err(|_error| GenServerError::Server)
94    }
95
96    pub fn cancellation_token(&self) -> CancellationToken {
97        self.cancellation_token.clone()
98    }
99}
100
101pub enum GenServerInMsg<G: GenServer> {
102    Call {
103        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
104        message: G::CallMsg,
105    },
106    Cast {
107        message: G::CastMsg,
108    },
109}
110
111pub enum CallResponse<G: GenServer> {
112    Reply(G, G::OutMsg),
113    Unused,
114    Stop(G::OutMsg),
115}
116
117pub enum CastResponse<G: GenServer> {
118    NoReply(G),
119    Unused,
120    Stop,
121}
122
123pub trait GenServer: Send + Sized + Clone {
124    type CallMsg: Clone + Send + Sized + Sync;
125    type CastMsg: Clone + Send + Sized + Sync;
126    type OutMsg: Send + Sized;
127    type Error: Debug + Send;
128
129    fn start(self) -> GenServerHandle<Self> {
130        GenServerHandle::new(self)
131    }
132
133    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
134    /// happen if the task is blocking the thread. As such, for sync compute task
135    /// or other blocking tasks need to be in their own separate thread, and the OS
136    /// will manage them through hardware interrupts.
137    /// Start blocking provides such thread.
138    fn start_blocking(self) -> GenServerHandle<Self> {
139        GenServerHandle::new_blocking(self)
140    }
141
142    fn run(
143        self,
144        handle: &GenServerHandle<Self>,
145        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
146    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
147        async {
148            let init_result = self
149                .init(handle)
150                .await
151                .inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));
152
153            let res = match init_result {
154                Ok(new_state) => new_state.main_loop(handle, rx).await,
155                Err(_) => Err(GenServerError::Initialization),
156            };
157
158            handle.cancellation_token().cancel();
159            if let Ok(final_state) = res {
160                if let Err(err) = final_state.teardown(handle).await {
161                    tracing::error!("Error during teardown: {err:?}");
162                }
163            }
164            Ok(())
165        }
166    }
167
168    /// Initialization function. It's called before main loop. It
169    /// can be overrided on implementations in case initial steps are
170    /// required.
171    fn init(
172        self,
173        _handle: &GenServerHandle<Self>,
174    ) -> impl Future<Output = Result<Self, Self::Error>> + Send {
175        async { Ok(self) }
176    }
177
178    fn main_loop(
179        mut self,
180        handle: &GenServerHandle<Self>,
181        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
182    ) -> impl Future<Output = Result<Self, GenServerError>> + Send {
183        async {
184            loop {
185                let (new_state, cont) = self.receive(handle, rx).await?;
186                self = new_state;
187                if !cont {
188                    break;
189                }
190            }
191            tracing::trace!("Stopping GenServer");
192            Ok(self)
193        }
194    }
195
196    fn receive(
197        self,
198        handle: &GenServerHandle<Self>,
199        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200    ) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
201        async move {
202            let message = rx.recv().await;
203
204            // Save current state in case of a rollback
205            let state_clone = self.clone();
206
207            let (keep_running, new_state) = match message {
208                Some(GenServerInMsg::Call { sender, message }) => {
209                    let (keep_running, new_state, response) =
210                        match AssertUnwindSafe(self.handle_call(message, handle))
211                            .catch_unwind()
212                            .await
213                        {
214                            Ok(response) => match response {
215                                CallResponse::Reply(new_state, response) => {
216                                    (true, new_state, Ok(response))
217                                }
218                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
219                                CallResponse::Unused => {
220                                    tracing::error!("GenServer received unexpected CallMessage");
221                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
222                                }
223                            },
224                            Err(error) => {
225                                tracing::error!(
226                                    "Error in callback, reverting state - Error: '{error:?}'"
227                                );
228                                (true, state_clone, Err(GenServerError::Callback))
229                            }
230                        };
231                    // Send response back
232                    if sender.send(response).is_err() {
233                        tracing::error!(
234                            "GenServer failed to send response back, client must have died"
235                        )
236                    };
237                    (keep_running, new_state)
238                }
239                Some(GenServerInMsg::Cast { message }) => {
240                    match AssertUnwindSafe(self.handle_cast(message, handle))
241                        .catch_unwind()
242                        .await
243                    {
244                        Ok(response) => match response {
245                            CastResponse::NoReply(new_state) => (true, new_state),
246                            CastResponse::Stop => (false, state_clone),
247                            CastResponse::Unused => {
248                                tracing::error!("GenServer received unexpected CastMessage");
249                                (false, state_clone)
250                            }
251                        },
252                        Err(error) => {
253                            tracing::trace!(
254                                "Error in callback, reverting state - Error: '{error:?}'"
255                            );
256                            (true, state_clone)
257                        }
258                    }
259                }
260                None => {
261                    // Channel has been closed; won't receive further messages. Stop the server.
262                    (false, self)
263                }
264            };
265            Ok((new_state, keep_running))
266        }
267    }
268
269    fn handle_call(
270        self,
271        _message: Self::CallMsg,
272        _handle: &GenServerHandle<Self>,
273    ) -> impl Future<Output = CallResponse<Self>> + Send {
274        async { CallResponse::Unused }
275    }
276
277    fn handle_cast(
278        self,
279        _message: Self::CastMsg,
280        _handle: &GenServerHandle<Self>,
281    ) -> impl Future<Output = CastResponse<Self>> + Send {
282        async { CastResponse::Unused }
283    }
284
285    /// Teardown function. It's called after the stop message is received.
286    /// It can be overrided on implementations in case final steps are required,
287    /// like closing streams, stopping timers, etc.
288    fn teardown(
289        self,
290        _handle: &GenServerHandle<Self>,
291    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
292        async { Ok(()) }
293    }
294}
295
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its OS thread for two seconds.
    #[derive(Clone)]
    struct BadlyBehavedTask;

    /// Call messages understood by the test servers.
    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    /// Replies produced by `WellBehavedTask`.
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Blocks the whole worker thread, not just this task.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Counter that re-casts to itself every 100 ms and counts each cast.
    #[derive(Clone)]
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let current = self.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(self, OutMsg::Count(current)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(current)),
            }
        }

        async fn handle_cast(
            mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(self)
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut bad_server = BadlyBehavedTask.start();
            let _ = bad_server.cast(()).await;
            let mut counter = WellBehavedTask { count: 0 }.start();
            let _ = counter.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            // The bad server shares the async runtime, so the counter is
            // expected to fall behind.
            let OutMsg::Count(ticks) = counter.call(InMessage::GetCount).await.unwrap();
            assert_ne!(ticks, 10);

            counter.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut bad_server = BadlyBehavedTask.start_blocking();
            let _ = bad_server.cast(()).await;
            let mut counter = WellBehavedTask { count: 0 }.start();
            let _ = counter.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            // With the bad server on its own thread, the counter keeps pace.
            let OutMsg::Count(ticks) = counter.call(InMessage::GetCount).await.unwrap();
            assert_eq!(ticks, 10);

            counter.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server whose call handler sleeps for a message-dependent duration.
    #[derive(Debug, Default, Clone)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let delay = match message {
                // Simulate a slow operation that will not resolve in time
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                // Simulate a fast operation that resolves in time
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply(self, ())
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut server = SomeTask.start();

            let fast = server
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(())));

            let slow = server
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }
}
472}