spawned_concurrency/tasks/
gen_server.rs1use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe};
6
7use crate::error::GenServerError;
8
9pub struct GenServerHandle<G: GenServer + 'static> {
10    pub tx: mpsc::Sender<GenServerInMsg<G>>,
11    cancellation_token: CancellationToken,
13}
14
15impl<G: GenServer> Clone for GenServerHandle<G> {
16    fn clone(&self) -> Self {
17        Self {
18            tx: self.tx.clone(),
19            cancellation_token: self.cancellation_token.clone(),
20        }
21    }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25    pub(crate) fn new(initial_state: G::State) -> Self {
26        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27        let cancellation_token = CancellationToken::new();
28        let handle = GenServerHandle {
29            tx,
30            cancellation_token,
31        };
32        let mut gen_server: G = GenServer::new();
33        let handle_clone = handle.clone();
34        let _join_handle = rt::spawn(async move {
36            if gen_server
37                .run(&handle, &mut rx, initial_state)
38                .await
39                .is_err()
40            {
41                tracing::trace!("GenServer crashed")
42            };
43        });
44        handle_clone
45    }
46
47    pub(crate) fn new_blocking(initial_state: G::State) -> Self {
48        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
49        let cancellation_token = CancellationToken::new();
50        let handle = GenServerHandle {
51            tx,
52            cancellation_token,
53        };
54        let mut gen_server: G = GenServer::new();
55        let handle_clone = handle.clone();
56        let _join_handle = rt::spawn_blocking(|| {
58            rt::block_on(async move {
59                if gen_server
60                    .run(&handle, &mut rx, initial_state)
61                    .await
62                    .is_err()
63                {
64                    tracing::trace!("GenServer crashed")
65                };
66            })
67        });
68        handle_clone
69    }
70
71    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
72        self.tx.clone()
73    }
74
75    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
76        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
77        self.tx.send(GenServerInMsg::Call {
78            sender: oneshot_tx,
79            message,
80        })?;
81        match oneshot_rx.await {
82            Ok(result) => result,
83            Err(_) => Err(GenServerError::Server),
84        }
85    }
86
87    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
88        self.tx
89            .send(GenServerInMsg::Cast { message })
90            .map_err(|_error| GenServerError::Server)
91    }
92
93    pub fn cancellation_token(&self) -> CancellationToken {
94        self.cancellation_token.clone()
95    }
96}
97
98pub enum GenServerInMsg<G: GenServer> {
99    Call {
100        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
101        message: G::CallMsg,
102    },
103    Cast {
104        message: G::CastMsg,
105    },
106}
107
108pub enum CallResponse<G: GenServer> {
109    Reply(G::State, G::OutMsg),
110    Unused,
111    Stop(G::OutMsg),
112}
113
114pub enum CastResponse<G: GenServer> {
115    NoReply(G::State),
116    Unused,
117    Stop,
118}
119
120pub trait GenServer
121where
122    Self: Send + Sized,
123{
124    type CallMsg: Clone + Send + Sized + Sync;
125    type CastMsg: Clone + Send + Sized + Sync;
126    type OutMsg: Send + Sized;
127    type State: Clone + Send;
128    type Error: Debug + Send;
129
130    fn new() -> Self;
131
132    fn start(initial_state: Self::State) -> GenServerHandle<Self> {
133        GenServerHandle::new(initial_state)
134    }
135
136    fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
142        GenServerHandle::new_blocking(initial_state)
143    }
144
145    fn run(
146        &mut self,
147        handle: &GenServerHandle<Self>,
148        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
149        state: Self::State,
150    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
151        async {
152            match self.init(handle, state).await {
153                Ok(new_state) => {
154                    self.main_loop(handle, rx, new_state).await?;
155                    Ok(())
156                }
157                Err(err) => {
158                    tracing::error!("Initialization failed: {err:?}");
159                    Err(GenServerError::Initialization)
160                }
161            }
162        }
163    }
164
165    fn init(
169        &mut self,
170        _handle: &GenServerHandle<Self>,
171        state: Self::State,
172    ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
173        async { Ok(state) }
174    }
175
176    fn main_loop(
177        &mut self,
178        handle: &GenServerHandle<Self>,
179        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
180        mut state: Self::State,
181    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
182        async {
183            loop {
184                let (new_state, cont) = self.receive(handle, rx, state).await?;
185                state = new_state;
186                if !cont {
187                    break;
188                }
189            }
190            tracing::trace!("Stopping GenServer");
191            handle.cancellation_token().cancel();
192            if let Err(err) = self.teardown(handle, state).await {
193                tracing::error!("Error during teardown: {err:?}");
194            }
195            Ok(())
196        }
197    }
198
199    fn receive(
200        &mut self,
201        handle: &GenServerHandle<Self>,
202        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
203        state: Self::State,
204    ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
205        async move {
206            let message = rx.recv().await;
207
208            let state_clone = state.clone();
210
211            let (keep_running, new_state) = match message {
212                Some(GenServerInMsg::Call { sender, message }) => {
213                    let (keep_running, new_state, response) =
214                        match AssertUnwindSafe(self.handle_call(message, handle, state))
215                            .catch_unwind()
216                            .await
217                        {
218                            Ok(response) => match response {
219                                CallResponse::Reply(new_state, response) => {
220                                    (true, new_state, Ok(response))
221                                }
222                                CallResponse::Stop(response) => (false, state_clone, Ok(response)),
223                                CallResponse::Unused => {
224                                    tracing::error!("GenServer received unexpected CallMessage");
225                                    (false, state_clone, Err(GenServerError::CallMsgUnused))
226                                }
227                            },
228                            Err(error) => {
229                                tracing::error!(
230                                    "Error in callback, reverting state - Error: '{error:?}'"
231                                );
232                                (true, state_clone, Err(GenServerError::Callback))
233                            }
234                        };
235                    if sender.send(response).is_err() {
237                        tracing::error!(
238                            "GenServer failed to send response back, client must have died"
239                        )
240                    };
241                    (keep_running, new_state)
242                }
243                Some(GenServerInMsg::Cast { message }) => {
244                    match AssertUnwindSafe(self.handle_cast(message, handle, state))
245                        .catch_unwind()
246                        .await
247                    {
248                        Ok(response) => match response {
249                            CastResponse::NoReply(new_state) => (true, new_state),
250                            CastResponse::Stop => (false, state_clone),
251                            CastResponse::Unused => {
252                                tracing::error!("GenServer received unexpected CastMessage");
253                                (false, state_clone)
254                            }
255                        },
256                        Err(error) => {
257                            tracing::trace!(
258                                "Error in callback, reverting state - Error: '{error:?}'"
259                            );
260                            (true, state_clone)
261                        }
262                    }
263                }
264                None => {
265                    (false, state)
267                }
268            };
269            Ok((new_state, keep_running))
270        }
271    }
272
273    fn handle_call(
274        &mut self,
275        _message: Self::CallMsg,
276        _handle: &GenServerHandle<Self>,
277        _state: Self::State,
278    ) -> impl Future<Output = CallResponse<Self>> + Send {
279        async { CallResponse::Unused }
280    }
281
282    fn handle_cast(
283        &mut self,
284        _message: Self::CastMsg,
285        _handle: &GenServerHandle<Self>,
286        _state: Self::State,
287    ) -> impl Future<Output = CastResponse<Self>> + Send {
288        async { CastResponse::Unused }
289    }
290
291    fn teardown(
295        &mut self,
296        _handle: &GenServerHandle<Self>,
297        _state: Self::State,
298    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
299        async { Ok(()) }
300    }
301}
302
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        fn new() -> Self {
            Self
        }

        // Any call simply stops the server.
        async fn handle_call(
            &mut self,
            _message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
            _state: Self::State,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        // Yields briefly, then blocks the executing thread for two seconds.
        async fn handle_cast(
            &mut self,
            _message: Self::CastMsg,
            _handle: &GenServerHandle<Self>,
            _state: Self::State,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts the casts it receives.
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        fn new() -> Self {
            Self
        }

        // Both requests report the current count; `Stop` also ends the server.
        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            let current = state.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(state, OutMsg::Count(current)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(current)),
            }
        }

        // Bumps the counter and schedules the next cast through `send_after`
        // (presumably delivered to this same server after 100 ms — confirm
        // against `send_after`).
        async fn handle_cast(
            &mut self,
            _message: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(state)
        }
    }

    /// Misbehaving server started with `start`: after one second the
    /// well-behaved counter is expected to differ from 10.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// Misbehaving server started with `start_blocking`: after one second
    /// the well-behaved counter is expected to be exactly 10.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start_blocking(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }
}