spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe};
6
7use crate::error::GenServerError;
8
9#[derive(Debug)]
10pub struct GenServerHandle<G: GenServer + 'static> {
11    pub tx: mpsc::Sender<GenServerInMsg<G>>,
12    /// Cancellation token to stop the GenServer
13    cancellation_token: CancellationToken,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17    fn clone(&self) -> Self {
18        Self {
19            tx: self.tx.clone(),
20            cancellation_token: self.cancellation_token.clone(),
21        }
22    }
23}
24
25impl<G: GenServer> GenServerHandle<G> {
26    pub(crate) fn new(initial_state: G::State) -> Self {
27        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
28        let cancellation_token = CancellationToken::new();
29        let handle = GenServerHandle {
30            tx,
31            cancellation_token,
32        };
33        let mut gen_server: G = GenServer::new();
34        let handle_clone = handle.clone();
35        // Ignore the JoinHandle for now. Maybe we'll use it in the future
36        let _join_handle = rt::spawn(async move {
37            if gen_server
38                .run(&handle, &mut rx, initial_state)
39                .await
40                .is_err()
41            {
42                tracing::trace!("GenServer crashed")
43            };
44        });
45        handle_clone
46    }
47
48    pub(crate) fn new_blocking(initial_state: G::State) -> Self {
49        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
50        let cancellation_token = CancellationToken::new();
51        let handle = GenServerHandle {
52            tx,
53            cancellation_token,
54        };
55        let mut gen_server: G = GenServer::new();
56        let handle_clone = handle.clone();
57        // Ignore the JoinHandle for now. Maybe we'll use it in the future
58        let _join_handle = rt::spawn_blocking(|| {
59            rt::block_on(async move {
60                if gen_server
61                    .run(&handle, &mut rx, initial_state)
62                    .await
63                    .is_err()
64                {
65                    tracing::trace!("GenServer crashed")
66                };
67            })
68        });
69        handle_clone
70    }
71
72    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
73        self.tx.clone()
74    }
75
76    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
77        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
78        self.tx.send(GenServerInMsg::Call {
79            sender: oneshot_tx,
80            message,
81        })?;
82        match oneshot_rx.await {
83            Ok(result) => result,
84            Err(_) => Err(GenServerError::Server),
85        }
86    }
87
88    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
89        self.tx
90            .send(GenServerInMsg::Cast { message })
91            .map_err(|_error| GenServerError::Server)
92    }
93
94    pub fn cancellation_token(&self) -> CancellationToken {
95        self.cancellation_token.clone()
96    }
97}
98
99pub enum GenServerInMsg<G: GenServer> {
100    Call {
101        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
102        message: G::CallMsg,
103    },
104    Cast {
105        message: G::CastMsg,
106    },
107}
108
109pub enum CallResponse<G: GenServer> {
110    Reply(G::State, G::OutMsg),
111    Unused,
112    Stop(G::OutMsg),
113}
114
115pub enum CastResponse<G: GenServer> {
116    NoReply(G::State),
117    Unused,
118    Stop,
119}
120
121pub trait GenServer
122where
123    Self: Send + Sized,
124{
125    type CallMsg: Clone + Send + Sized + Sync;
126    type CastMsg: Clone + Send + Sized + Sync;
127    type OutMsg: Send + Sized;
128    type State: Clone + Send;
129    type Error: Debug + Send;
130
131    fn new() -> Self;
132
133    fn start(initial_state: Self::State) -> GenServerHandle<Self> {
134        GenServerHandle::new(initial_state)
135    }
136
137    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
138    /// happen if the task is blocking the thread. As such, for sync compute task
139    /// or other blocking tasks need to be in their own separate thread, and the OS
140    /// will manage them through hardware interrupts.
141    /// Start blocking provides such thread.
142    fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
143        GenServerHandle::new_blocking(initial_state)
144    }
145
146    fn run(
147        &mut self,
148        handle: &GenServerHandle<Self>,
149        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
150        state: Self::State,
151    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
152        async {
153            match self.init(handle, state).await {
154                Ok(new_state) => {
155                    self.main_loop(handle, rx, new_state).await?;
156                    Ok(())
157                }
158                Err(err) => {
159                    tracing::error!("Initialization failed: {err:?}");
160                    Err(GenServerError::Initialization)
161                }
162            }
163        }
164    }
165
166    /// Initialization function. It's called before main loop. It
167    /// can be overrided on implementations in case initial steps are
168    /// required.
169    fn init(
170        &mut self,
171        _handle: &GenServerHandle<Self>,
172        state: Self::State,
173    ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
174        async { Ok(state) }
175    }
176
177    fn main_loop(
178        &mut self,
179        handle: &GenServerHandle<Self>,
180        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
181        mut state: Self::State,
182    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
183        async {
184            loop {
185                let (new_state, cont) = self.receive(handle, rx, state).await?;
186                state = new_state;
187                if !cont {
188                    break;
189                }
190            }
191            tracing::trace!("Stopping GenServer");
192            handle.cancellation_token().cancel();
193            if let Err(err) = self.teardown(handle, state).await {
194                tracing::error!("Error during teardown: {err:?}");
195            }
196            Ok(())
197        }
198    }
199
200    fn receive(
201        &mut self,
202        handle: &GenServerHandle<Self>,
203        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
204        state: Self::State,
205    ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
206        async move {
207            let message = rx.recv().await;
208
209            // Save current state in case of a rollback
210            let state_clone = state.clone();
211
212            let (keep_running, new_state) = match message {
213                Some(GenServerInMsg::Call { sender, message }) => {
214                    let (keep_running, new_state, response) =
215                        match AssertUnwindSafe(self.handle_call(message, handle, state))
216                            .catch_unwind()
217                            .await
218                        {
219                            Ok(response) => match response {
220                                CallResponse::Reply(new_state, response) => {
221                                    (true, new_state, Ok(response))
222                                }
223                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
224                                CallResponse::Unused => {
225                                    tracing::error!("GenServer received unexpected CallMessage");
226                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
227                                }
228                            },
229                            Err(error) => {
230                                tracing::error!(
231                                    "Error in callback, reverting state - Error: '{error:?}'"
232                                );
233                                (true, state_clone, Err(GenServerError::Callback))
234                            }
235                        };
236                    // Send response back
237                    if sender.send(response).is_err() {
238                        tracing::error!(
239                            "GenServer failed to send response back, client must have died"
240                        )
241                    };
242                    (keep_running, new_state)
243                }
244                Some(GenServerInMsg::Cast { message }) => {
245                    match AssertUnwindSafe(self.handle_cast(message, handle, state))
246                        .catch_unwind()
247                        .await
248                    {
249                        Ok(response) => match response {
250                            CastResponse::NoReply(new_state) => (true, new_state),
251                            CastResponse::Stop => (false, state_clone),
252                            CastResponse::Unused => {
253                                tracing::error!("GenServer received unexpected CastMessage");
254                                (false, state_clone)
255                            }
256                        },
257                        Err(error) => {
258                            tracing::trace!(
259                                "Error in callback, reverting state - Error: '{error:?}'"
260                            );
261                            (true, state_clone)
262                        }
263                    }
264                }
265                None => {
266                    // Channel has been closed; won't receive further messages. Stop the server.
267                    (false, state)
268                }
269            };
270            Ok((new_state, keep_running))
271        }
272    }
273
274    fn handle_call(
275        &mut self,
276        _message: Self::CallMsg,
277        _handle: &GenServerHandle<Self>,
278        _state: Self::State,
279    ) -> impl Future<Output = CallResponse<Self>> + Send {
280        async { CallResponse::Unused }
281    }
282
283    fn handle_cast(
284        &mut self,
285        _message: Self::CastMsg,
286        _handle: &GenServerHandle<Self>,
287        _state: Self::State,
288    ) -> impl Future<Output = CastResponse<Self>> + Send {
289        async { CastResponse::Unused }
290    }
291
292    /// Teardown function. It's called after the stop message is received.
293    /// It can be overrided on implementations in case final steps are required,
294    /// like closing streams, stopping timers, etc.
295    fn teardown(
296        &mut self,
297        _handle: &GenServerHandle<Self>,
298        _state: Self::State,
299    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
300        async { Ok(()) }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its OS thread for two seconds with
    /// `thread::sleep` instead of awaiting.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        fn new() -> Self {
            Self
        }

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            // Any call stops this server.
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocking: holds the thread without yielding to the runtime.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Counter server: each cast bumps the count and uses `send_after` to schedule
    /// another `()` message 100ms later.
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        fn new() -> Self {
            Self
        }

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            let count = state.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(state, OutMsg::Count(count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(state)
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            // The blocking cast handler is expected to stall the counter.
            // NOTE(review): depends on runtime scheduling/timing; may be flaky.
            assert_ne!(num, 10);
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start_blocking(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            // With the blocking server on its own thread the counter ticks undisturbed.
            // NOTE(review): exact equality with 10 ticks in 1s is timing-sensitive.
            assert_eq!(num, 10);
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }
}
439}