spawned_concurrency/tasks/
gen_server.rs1use crate::{
4    error::GenServerError,
5    tasks::InitResult::{NoSuccess, Success},
6};
7use futures::future::FutureExt as _;
8use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
9use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
10
/// How long [`GenServerHandle::call`] waits for a reply before failing with
/// `GenServerError::CallTimeout`.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
12
13#[derive(Debug)]
14pub struct GenServerHandle<G: GenServer + 'static> {
15    pub tx: mpsc::Sender<GenServerInMsg<G>>,
16    cancellation_token: CancellationToken,
18}
19
20impl<G: GenServer> Clone for GenServerHandle<G> {
21    fn clone(&self) -> Self {
22        Self {
23            tx: self.tx.clone(),
24            cancellation_token: self.cancellation_token.clone(),
25        }
26    }
27}
28
29impl<G: GenServer> GenServerHandle<G> {
30    pub(crate) fn new(gen_server: G) -> Self {
31        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
32        let cancellation_token = CancellationToken::new();
33        let handle = GenServerHandle {
34            tx,
35            cancellation_token,
36        };
37        let handle_clone = handle.clone();
38        let inner_future = async move {
39            if gen_server.run(&handle, &mut rx).await.is_err() {
40                tracing::trace!("GenServer crashed")
41            }
42        };
43
44        #[cfg(feature = "warn-on-block")]
45        let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
47
48        let _join_handle = rt::spawn(inner_future);
50
51        handle_clone
52    }
53
54    pub(crate) fn new_blocking(gen_server: G) -> Self {
55        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
56        let cancellation_token = CancellationToken::new();
57        let handle = GenServerHandle {
58            tx,
59            cancellation_token,
60        };
61        let handle_clone = handle.clone();
62        let _join_handle = rt::spawn_blocking(|| {
64            rt::block_on(async move {
65                if gen_server.run(&handle, &mut rx).await.is_err() {
66                    tracing::trace!("GenServer crashed")
67                };
68            })
69        });
70        handle_clone
71    }
72
73    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
74        self.tx.clone()
75    }
76
77    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
78        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
79    }
80
81    pub async fn call_with_timeout(
82        &mut self,
83        message: G::CallMsg,
84        duration: Duration,
85    ) -> Result<G::OutMsg, GenServerError> {
86        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
87        self.tx.send(GenServerInMsg::Call {
88            sender: oneshot_tx,
89            message,
90        })?;
91
92        match timeout(duration, oneshot_rx).await {
93            Ok(Ok(result)) => result,
94            Ok(Err(_)) => Err(GenServerError::Server),
95            Err(_) => Err(GenServerError::CallTimeout),
96        }
97    }
98
99    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
100        self.tx
101            .send(GenServerInMsg::Cast { message })
102            .map_err(|_error| GenServerError::Server)
103    }
104
105    pub fn cancellation_token(&self) -> CancellationToken {
106        self.cancellation_token.clone()
107    }
108}
109
110pub enum GenServerInMsg<G: GenServer> {
111    Call {
112        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
113        message: G::CallMsg,
114    },
115    Cast {
116        message: G::CastMsg,
117    },
118}
119
120pub enum CallResponse<G: GenServer> {
121    Reply(G::OutMsg),
122    Unused,
123    Stop(G::OutMsg),
124}
125
/// What a `handle_cast` callback asks the server to do after processing a message.
///
/// Derives `Debug`/`PartialEq` so outcomes can be logged and compared (e.g. in tests), and
/// `Copy` because the enum carries no data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastResponse {
    /// Keep running.
    NoReply,
    /// The server does not handle casts: an error is logged and the server stops.
    Unused,
    /// Stop the server.
    Stop,
}
131
132pub enum InitResult<G: GenServer> {
133    Success(G),
134    NoSuccess(G),
135}
136
137pub trait GenServer: Send + Sized {
138    type CallMsg: Clone + Send + Sized + Sync;
139    type CastMsg: Clone + Send + Sized + Sync;
140    type OutMsg: Send + Sized;
141    type Error: Debug + Send;
142
143    fn start(self) -> GenServerHandle<Self> {
144        GenServerHandle::new(self)
145    }
146
147    fn start_blocking(self) -> GenServerHandle<Self> {
153        GenServerHandle::new_blocking(self)
154    }
155
156    fn run(
157        self,
158        handle: &GenServerHandle<Self>,
159        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
160    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
161        async {
162            let res = match self.init(handle).await {
163                Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
164                Ok(NoSuccess(intermediate_state)) => {
165                    Ok(intermediate_state)
169                }
170                Err(err) => {
171                    tracing::error!("Initialization failed with unhandled error: {err:?}");
172                    Err(GenServerError::Initialization)
173                }
174            };
175
176            handle.cancellation_token().cancel();
177            if let Ok(final_state) = res {
178                if let Err(err) = final_state.teardown(handle).await {
179                    tracing::error!("Error during teardown: {err:?}");
180                }
181            }
182            Ok(())
183        }
184    }
185
186    fn init(
190        self,
191        _handle: &GenServerHandle<Self>,
192    ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
193        async { Ok(Success(self)) }
194    }
195
196    fn main_loop(
197        mut self,
198        handle: &GenServerHandle<Self>,
199        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200    ) -> impl Future<Output = Self> + Send {
201        async {
202            loop {
203                if !self.receive(handle, rx).await {
204                    break;
205                }
206            }
207            tracing::trace!("Stopping GenServer");
208            self
209        }
210    }
211
212    fn receive(
213        &mut self,
214        handle: &GenServerHandle<Self>,
215        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
216    ) -> impl Future<Output = bool> + Send {
217        async move {
218            let message = rx.recv().await;
219
220            let keep_running = match message {
221                Some(GenServerInMsg::Call { sender, message }) => {
222                    let (keep_running, response) =
223                        match AssertUnwindSafe(self.handle_call(message, handle))
224                            .catch_unwind()
225                            .await
226                        {
227                            Ok(response) => match response {
228                                CallResponse::Reply(response) => (true, Ok(response)),
229                                CallResponse::Stop(response) => (false, Ok(response)),
230                                CallResponse::Unused => {
231                                    tracing::error!("GenServer received unexpected CallMessage");
232                                    (false, Err(GenServerError::CallMsgUnused))
233                                }
234                            },
235                            Err(error) => {
236                                tracing::error!("Error in callback: '{error:?}'");
237                                (false, Err(GenServerError::Callback))
238                            }
239                        };
240                    if sender.send(response).is_err() {
242                        tracing::error!(
243                            "GenServer failed to send response back, client must have died"
244                        )
245                    };
246                    keep_running
247                }
248                Some(GenServerInMsg::Cast { message }) => {
249                    match AssertUnwindSafe(self.handle_cast(message, handle))
250                        .catch_unwind()
251                        .await
252                    {
253                        Ok(response) => match response {
254                            CastResponse::NoReply => true,
255                            CastResponse::Stop => false,
256                            CastResponse::Unused => {
257                                tracing::error!("GenServer received unexpected CastMessage");
258                                false
259                            }
260                        },
261                        Err(error) => {
262                            tracing::trace!("Error in callback: '{error:?}'");
263                            false
264                        }
265                    }
266                }
267                None => {
268                    false
270                }
271            };
272            keep_running
273        }
274    }
275
276    fn handle_call(
277        &mut self,
278        _message: Self::CallMsg,
279        _handle: &GenServerHandle<Self>,
280    ) -> impl Future<Output = CallResponse<Self>> + Send {
281        async { CallResponse::Unused }
282    }
283
284    fn handle_cast(
285        &mut self,
286        _message: Self::CastMsg,
287        _handle: &GenServerHandle<Self>,
288    ) -> impl Future<Output = CastResponse> + Send {
289        async { CastResponse::Unused }
290    }
291
292    fn teardown(
296        self,
297        _handle: &GenServerHandle<Self>,
298    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
299        async { Ok(()) }
300    }
301}
302
#[cfg(feature = "warn-on-block")]
mod warn_on_block {
    use super::*;

    use std::time::Instant;
    use tracing::warn;

    /// A single `poll` slower than this is reported as a blocking operation.
    const BLOCKING_POLL_THRESHOLD: Duration = Duration::from_millis(10);

    // Future adapter that times every `poll` of the wrapped future and logs a warning when a
    // single poll takes longer than `BLOCKING_POLL_THRESHOLD`.
    pin_project_lite::pin_project! {
        pub struct WarnOnBlocking<F: Future>{
            #[pin]
            inner: F
        }
    }

    impl<F: Future> WarnOnBlocking<F> {
        /// Wraps `inner`; its output is passed through unchanged.
        pub fn new(inner: F) -> Self {
            Self { inner }
        }
    }

    impl<F: Future> Future for WarnOnBlocking<F> {
        type Output = F::Output;

        fn poll(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Self::Output> {
            let this = self.project();
            let started = Instant::now();
            let res = this.inner.poll(cx);
            let elapsed = started.elapsed();
            if elapsed > BLOCKING_POLL_THRESHOLD {
                // Resolve the (possibly long) type name only on the slow path.
                let type_name = std::any::type_name::<F>();
                warn!(future = ?type_name, elapsed = ?elapsed, "Blocking operation detected");
            }
            res
        }
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocks the executing thread for far longer than a poll should take.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts casts and re-sends itself a cast every 100ms.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => CallResponse::Reply(OutMsg::Count(self.count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.to_owned(), Unused);
            CastResponse::NoReply
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Longer than the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(Unused)
                }
                SomeTaskCallMsg::FastOperation => {
                    // Comfortably within the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(Unused)
                }
            }
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(Unused)));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports a handled failure; its `teardown` closes `receiver` so the
    /// test can observe that teardown ran.
    struct SomeTaskThatFailsOnInit {
        receiver: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            // Handled failure: skip the main loop but still run teardown.
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.receiver.lock().unwrap().close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            // `mpsc::channel` returns `(sender, receiver)`.
            let (tx, rx) = mpsc::channel::<u8>();
            let receiver = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(receiver).start();

            rt::sleep(Duration::from_secs(1)).await;

            // Teardown closed the receiver, which the sender observes as a closed channel.
            assert!(tx.is_closed())
        });
    }
}