spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::error::GenServerError;
8
/// Timeout applied by `GenServerHandle::call`, which delegates to
/// `GenServerHandle::call_with_timeout` with this duration.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(5);
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13    pub tx: mpsc::Sender<GenServerInMsg<G>>,
14    /// Cancellation token to stop the GenServer
15    cancellation_token: CancellationToken,
16}
17
18impl<G: GenServer> Clone for GenServerHandle<G> {
19    fn clone(&self) -> Self {
20        Self {
21            tx: self.tx.clone(),
22            cancellation_token: self.cancellation_token.clone(),
23        }
24    }
25}
26
27impl<G: GenServer> GenServerHandle<G> {
28    pub(crate) fn new(gen_server: G) -> Self {
29        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
30        let cancellation_token = CancellationToken::new();
31        let handle = GenServerHandle {
32            tx,
33            cancellation_token,
34        };
35        let handle_clone = handle.clone();
36        // Ignore the JoinHandle for now. Maybe we'll use it in the future
37        let _join_handle = rt::spawn(async move {
38            if gen_server.run(&handle, &mut rx).await.is_err() {
39                tracing::trace!("GenServer crashed")
40            };
41        });
42        handle_clone
43    }
44
45    pub(crate) fn new_blocking(gen_server: G) -> Self {
46        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
47        let cancellation_token = CancellationToken::new();
48        let handle = GenServerHandle {
49            tx,
50            cancellation_token,
51        };
52        let handle_clone = handle.clone();
53        // Ignore the JoinHandle for now. Maybe we'll use it in the future
54        let _join_handle = rt::spawn_blocking(|| {
55            rt::block_on(async move {
56                if gen_server.run(&handle, &mut rx).await.is_err() {
57                    tracing::trace!("GenServer crashed")
58                };
59            })
60        });
61        handle_clone
62    }
63
64    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
65        self.tx.clone()
66    }
67
68    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
69        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
70    }
71
72    pub async fn call_with_timeout(
73        &mut self,
74        message: G::CallMsg,
75        duration: Duration,
76    ) -> Result<G::OutMsg, GenServerError> {
77        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
78        self.tx.send(GenServerInMsg::Call {
79            sender: oneshot_tx,
80            message,
81        })?;
82
83        match timeout(duration, oneshot_rx).await {
84            Ok(Ok(result)) => result,
85            Ok(Err(_)) => Err(GenServerError::Server),
86            Err(_) => Err(GenServerError::CallTimeout),
87        }
88    }
89
90    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
91        self.tx
92            .send(GenServerInMsg::Cast { message })
93            .map_err(|_error| GenServerError::Server)
94    }
95
96    pub fn cancellation_token(&self) -> CancellationToken {
97        self.cancellation_token.clone()
98    }
99}
100
101pub enum GenServerInMsg<G: GenServer> {
102    Call {
103        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
104        message: G::CallMsg,
105    },
106    Cast {
107        message: G::CastMsg,
108    },
109}
110
111pub enum CallResponse<G: GenServer> {
112    Reply(G, G::OutMsg),
113    Unused,
114    Stop(G::OutMsg),
115}
116
117pub enum CastResponse<G: GenServer> {
118    NoReply(G),
119    Unused,
120    Stop,
121}
122
123pub trait GenServer: Send + Sized + Clone {
124    type CallMsg: Clone + Send + Sized + Sync;
125    type CastMsg: Clone + Send + Sized + Sync;
126    type OutMsg: Send + Sized;
127    type Error: Debug + Send;
128
129    fn start(self) -> GenServerHandle<Self> {
130        GenServerHandle::new(self)
131    }
132
133    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
134    /// happen if the task is blocking the thread. As such, for sync compute task
135    /// or other blocking tasks need to be in their own separate thread, and the OS
136    /// will manage them through hardware interrupts.
137    /// Start blocking provides such thread.
138    fn start_blocking(self) -> GenServerHandle<Self> {
139        GenServerHandle::new_blocking(self)
140    }
141
142    fn run(
143        self,
144        handle: &GenServerHandle<Self>,
145        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
146    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
147        async {
148            let init_result = self
149                .clone()
150                .init(handle)
151                .await
152                .inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));
153
154            let res = match init_result {
155                Ok(new_state) => new_state.main_loop(handle, rx).await,
156                Err(_) => Err(GenServerError::Initialization),
157            };
158
159            handle.cancellation_token().cancel();
160            if let Err(err) = self.teardown(handle).await {
161                tracing::error!("Error during teardown: {err:?}");
162            }
163            res
164        }
165    }
166
167    /// Initialization function. It's called before main loop. It
168    /// can be overrided on implementations in case initial steps are
169    /// required.
170    fn init(
171        self,
172        _handle: &GenServerHandle<Self>,
173    ) -> impl Future<Output = Result<Self, Self::Error>> + Send {
174        async { Ok(self) }
175    }
176
177    fn main_loop(
178        mut self,
179        handle: &GenServerHandle<Self>,
180        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
181    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
182        async {
183            loop {
184                let (new_state, cont) = self.receive(handle, rx).await?;
185                self = new_state;
186                if !cont {
187                    break;
188                }
189            }
190            tracing::trace!("Stopping GenServer");
191            Ok(())
192        }
193    }
194
195    fn receive(
196        self,
197        handle: &GenServerHandle<Self>,
198        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
199    ) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
200        async move {
201            let message = rx.recv().await;
202
203            // Save current state in case of a rollback
204            let state_clone = self.clone();
205
206            let (keep_running, new_state) = match message {
207                Some(GenServerInMsg::Call { sender, message }) => {
208                    let (keep_running, new_state, response) =
209                        match AssertUnwindSafe(self.handle_call(message, handle))
210                            .catch_unwind()
211                            .await
212                        {
213                            Ok(response) => match response {
214                                CallResponse::Reply(new_state, response) => {
215                                    (true, new_state, Ok(response))
216                                }
217                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
218                                CallResponse::Unused => {
219                                    tracing::error!("GenServer received unexpected CallMessage");
220                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
221                                }
222                            },
223                            Err(error) => {
224                                tracing::error!(
225                                    "Error in callback, reverting state - Error: '{error:?}'"
226                                );
227                                (true, state_clone, Err(GenServerError::Callback))
228                            }
229                        };
230                    // Send response back
231                    if sender.send(response).is_err() {
232                        tracing::error!(
233                            "GenServer failed to send response back, client must have died"
234                        )
235                    };
236                    (keep_running, new_state)
237                }
238                Some(GenServerInMsg::Cast { message }) => {
239                    match AssertUnwindSafe(self.handle_cast(message, handle))
240                        .catch_unwind()
241                        .await
242                    {
243                        Ok(response) => match response {
244                            CastResponse::NoReply(new_state) => (true, new_state),
245                            CastResponse::Stop => (false, state_clone),
246                            CastResponse::Unused => {
247                                tracing::error!("GenServer received unexpected CastMessage");
248                                (false, state_clone)
249                            }
250                        },
251                        Err(error) => {
252                            tracing::trace!(
253                                "Error in callback, reverting state - Error: '{error:?}'"
254                            );
255                            (true, state_clone)
256                        }
257                    }
258                }
259                None => {
260                    // Channel has been closed; won't receive further messages. Stop the server.
261                    (false, self)
262                }
263            };
264            Ok((new_state, keep_running))
265        }
266    }
267
268    fn handle_call(
269        self,
270        _message: Self::CallMsg,
271        _handle: &GenServerHandle<Self>,
272    ) -> impl Future<Output = CallResponse<Self>> + Send {
273        async { CallResponse::Unused }
274    }
275
276    fn handle_cast(
277        self,
278        _message: Self::CastMsg,
279        _handle: &GenServerHandle<Self>,
280    ) -> impl Future<Output = CastResponse<Self>> + Send {
281        async { CastResponse::Unused }
282    }
283
284    /// Teardown function. It's called after the stop message is received.
285    /// It can be overrided on implementations in case final steps are required,
286    /// like closing streams, stopping timers, etc.
287    fn teardown(
288        self,
289        _handle: &GenServerHandle<Self>,
290    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
291        async { Ok(()) }
292    }
293}
294
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Call messages understood by both scheduling test servers.
    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    /// Reply type of `WellBehavedTask`.
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    #[derive(Clone)]
    struct BadlyBehavedTask;

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            // Yield once, then hog the thread with a synchronous sleep.
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts the casts it receives and re-schedules a cast to
    /// itself every 100 ms.
    #[derive(Clone)]
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let current = self.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(self, OutMsg::Count(current)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(current)),
            }
        }

        async fn handle_cast(
            mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            // Schedule the next tick.
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(self)
        }
    }

    /// Runs the badly behaved server with `start()` and checks that the well
    /// behaved server did not reach a count of 10 after one second
    /// (presumably because the blocked thread delays its ticks).
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut blocker = BadlyBehavedTask.start();
            let _ = blocker.cast(()).await;
            let mut counter = WellBehavedTask { count: 0 }.start();
            let _ = counter.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = counter.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            counter.call(InMessage::Stop).await.unwrap();
        });
    }

    /// Runs the badly behaved server with `start_blocking()` and checks that
    /// the well behaved server reached a count of exactly 10 after one second.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut blocker = BadlyBehavedTask.start_blocking();
            let _ = blocker.cast(()).await;
            let mut counter = WellBehavedTask { count: 0 }.start();
            let _ = counter.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = counter.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            counter.call(InMessage::Stop).await.unwrap();
        });
    }

    /// Timeout handed to `call_with_timeout` in the timeout test.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server whose calls take either half or twice `TIMEOUT_DURATION`.
    #[derive(Debug, Default, Clone)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let delay = match message {
                // Simulate a slow operation that will not resolve in time
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                // Simulate a fast operation that resolves in time
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply(self, ())
        }
    }

    /// A fast call succeeds within the timeout, a slow one yields `CallTimeout`.
    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut server = SomeTask.start();

            let fast = server
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(())));

            let slow = server
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }
}