spawned_concurrency/tasks/gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe};
6
7use crate::error::GenServerError;
8
9#[derive(Debug)]
10pub struct GenServerHandle<G: GenServer + 'static> {
11    pub tx: mpsc::Sender<GenServerInMsg<G>>,
12}
13
14impl<G: GenServer> Clone for GenServerHandle<G> {
15    fn clone(&self) -> Self {
16        Self {
17            tx: self.tx.clone(),
18        }
19    }
20}
21
22impl<G: GenServer> GenServerHandle<G> {
23    pub(crate) fn new(initial_state: G::State) -> Self {
24        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
25        let handle = GenServerHandle { tx };
26        let mut gen_server: G = GenServer::new();
27        let handle_clone = handle.clone();
28        // Ignore the JoinHandle for now. Maybe we'll use it in the future
29        let _join_handle = rt::spawn(async move {
30            if gen_server
31                .run(&handle, &mut rx, initial_state)
32                .await
33                .is_err()
34            {
35                tracing::trace!("GenServer crashed")
36            };
37        });
38        handle_clone
39    }
40
41    pub(crate) fn new_blocking(initial_state: G::State) -> Self {
42        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
43        let handle = GenServerHandle { tx };
44        let mut gen_server: G = GenServer::new();
45        let handle_clone = handle.clone();
46        // Ignore the JoinHandle for now. Maybe we'll use it in the future
47        let _join_handle = rt::spawn_blocking(|| {
48            rt::block_on(async move {
49                if gen_server
50                    .run(&handle, &mut rx, initial_state)
51                    .await
52                    .is_err()
53                {
54                    tracing::trace!("GenServer crashed")
55                };
56            })
57        });
58        handle_clone
59    }
60
61    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
62        self.tx.clone()
63    }
64
65    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
66        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
67        self.tx.send(GenServerInMsg::Call {
68            sender: oneshot_tx,
69            message,
70        })?;
71        match oneshot_rx.await {
72            Ok(result) => result,
73            Err(_) => Err(GenServerError::Server),
74        }
75    }
76
77    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
78        self.tx
79            .send(GenServerInMsg::Cast { message })
80            .map_err(|_error| GenServerError::Server)
81    }
82}
83
84pub enum GenServerInMsg<G: GenServer> {
85    Call {
86        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
87        message: G::CallMsg,
88    },
89    Cast {
90        message: G::CastMsg,
91    },
92}
93
94pub enum CallResponse<G: GenServer> {
95    Reply(G::State, G::OutMsg),
96    Unused,
97    Stop(G::OutMsg),
98}
99
100pub enum CastResponse<G: GenServer> {
101    NoReply(G::State),
102    Unused,
103    Stop,
104}
105
106pub trait GenServer
107where
108    Self: Send + Sized,
109{
110    type CallMsg: Clone + Send + Sized + Sync;
111    type CastMsg: Clone + Send + Sized + Sync;
112    type OutMsg: Send + Sized;
113    type State: Clone + Send;
114    type Error: Debug + Send;
115
116    fn new() -> Self;
117
118    fn start(initial_state: Self::State) -> GenServerHandle<Self> {
119        GenServerHandle::new(initial_state)
120    }
121
122    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
123    /// happen if the task is blocking the thread. As such, for sync compute task
124    /// or other blocking tasks need to be in their own separate thread, and the OS
125    /// will manage them through hardware interrupts.
126    /// Start blocking provides such thread.
127    fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
128        GenServerHandle::new_blocking(initial_state)
129    }
130
131    fn run(
132        &mut self,
133        handle: &GenServerHandle<Self>,
134        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
135        state: Self::State,
136    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
137        async {
138            match self.init(handle, state).await {
139                Ok(new_state) => {
140                    self.main_loop(handle, rx, new_state).await?;
141                    Ok(())
142                }
143                Err(err) => {
144                    tracing::error!("Initialization failed: {err:?}");
145                    Err(GenServerError::Initialization)
146                }
147            }
148        }
149    }
150
151    /// Initialization function. It's called before main loop. It
152    /// can be overrided on implementations in case initial steps are
153    /// required.
154    fn init(
155        &mut self,
156        _handle: &GenServerHandle<Self>,
157        state: Self::State,
158    ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
159        async { Ok(state) }
160    }
161
162    fn main_loop(
163        &mut self,
164        handle: &GenServerHandle<Self>,
165        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
166        mut state: Self::State,
167    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
168        async {
169            loop {
170                let (new_state, cont) = self.receive(handle, rx, state).await?;
171                if !cont {
172                    break;
173                }
174                state = new_state;
175            }
176            tracing::trace!("Stopping GenServer");
177            Ok(())
178        }
179    }
180
181    fn receive(
182        &mut self,
183        handle: &GenServerHandle<Self>,
184        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
185        state: Self::State,
186    ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
187        async move {
188            let message = rx.recv().await;
189
190            // Save current state in case of a rollback
191            let state_clone = state.clone();
192
193            let (keep_running, new_state) = match message {
194                Some(GenServerInMsg::Call { sender, message }) => {
195                    let (keep_running, new_state, response) =
196                        match AssertUnwindSafe(self.handle_call(message, handle, state))
197                            .catch_unwind()
198                            .await
199                        {
200                            Ok(response) => match response {
201                                CallResponse::Reply(new_state, response) => {
202                                    (true, new_state, Ok(response))
203                                }
204                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
205                                CallResponse::Unused => {
206                                    tracing::error!("GenServer received unexpected CallMessage");
207                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
208                                }
209                            },
210                            Err(error) => {
211                                tracing::error!(
212                                    "Error in callback, reverting state - Error: '{error:?}'"
213                                );
214                                (true, state_clone, Err(GenServerError::Callback))
215                            }
216                        };
217                    // Send response back
218                    if sender.send(response).is_err() {
219                        tracing::error!(
220                            "GenServer failed to send response back, client must have died"
221                        )
222                    };
223                    (keep_running, new_state)
224                }
225                Some(GenServerInMsg::Cast { message }) => {
226                    match AssertUnwindSafe(self.handle_cast(message, handle, state))
227                        .catch_unwind()
228                        .await
229                    {
230                        Ok(response) => match response {
231                            CastResponse::NoReply(new_state) => (true, new_state),
232                            CastResponse::Stop => (false, state_clone),
233                            CastResponse::Unused => {
234                                tracing::error!("GenServer received unexpected CastMessage");
235                                (false, state_clone)
236                            }
237                        },
238                        Err(error) => {
239                            tracing::trace!(
240                                "Error in callback, reverting state - Error: '{error:?}'"
241                            );
242                            (true, state_clone)
243                        }
244                    }
245                }
246                None => {
247                    // Channel has been closed; won't receive further messages. Stop the server.
248                    (false, state)
249                }
250            };
251            Ok((new_state, keep_running))
252        }
253    }
254
255    fn handle_call(
256        &mut self,
257        _message: Self::CallMsg,
258        _handle: &GenServerHandle<Self>,
259        _state: Self::State,
260    ) -> impl Future<Output = CallResponse<Self>> + Send {
261        async { CallResponse::Unused }
262    }
263
264    fn handle_cast(
265        &mut self,
266        _message: Self::CastMsg,
267        _handle: &GenServerHandle<Self>,
268        _state: Self::State,
269    ) -> impl Future<Output = CastResponse<Self>> + Send {
270        async { CastResponse::Unused }
271    }
272}
273
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose `handle_cast` blocks its OS thread with `thread::sleep`.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        fn new() -> Self {
            Self
        }

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            // Await an async sleep first, then deliberately block the thread.
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that bumps a counter on every cast and re-sends itself a cast
    /// 100 ms later via `send_after`.
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        fn new() -> Self {
            Self
        }

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            let current = state.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(state, OutMsg::Count(current)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(current)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CastResponse<Self> {
            let bumped = CountState {
                count: state.count + 1,
            };
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(bumped)
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut blocker = BadlyBehavedTask::start(());
            let _ = blocker.cast(()).await;
            let mut ticker = WellBehavedTask::start(CountState { count: 0 });
            let _ = ticker.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(ticks) = ticker.call(InMessage::GetCount).await.unwrap();
            // Expected to fall short of 10 ticks because the blocker runs as a
            // regular task here (timing-dependent).
            assert_ne!(ticks, 10);

            ticker.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut blocker = BadlyBehavedTask::start_blocking(());
            let _ = blocker.cast(()).await;
            let mut ticker = WellBehavedTask::start(CountState { count: 0 });
            let _ = ticker.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(ticks) = ticker.call(InMessage::GetCount).await.unwrap();
            // With the blocker on its own thread, the ticker is expected to hit
            // exactly 10 ticks in ~1s (timing-dependent).
            assert_eq!(ticks, 10);

            ticker.call(InMessage::Stop).await.unwrap();
        });
    }
}