spawned_concurrency/tasks/
gen_server.rs1use crate::{
4    error::GenServerError,
5    tasks::InitResult::{NoSuccess, Success},
6};
7use futures::future::FutureExt as _;
8use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
9use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
10
/// Timeout (5 seconds) that [`GenServerHandle::call`] passes to
/// [`GenServerHandle::call_with_timeout`].
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(5);
12
/// Client-side handle to a [`GenServer`].
///
/// Bundles the sending half of the server's message channel with a
/// [`CancellationToken`]; `GenServer::run` cancels that token once the server
/// has finished initialising/looping and before it tears down.
#[derive(Debug)]
pub struct GenServerHandle<G: GenServer + 'static> {
    /// Sender used to deliver [`GenServerInMsg`] values to the server.
    pub tx: mpsc::Sender<GenServerInMsg<G>>,
    /// Token cloned together with the handle; exposed through
    /// [`GenServerHandle::cancellation_token`].
    cancellation_token: CancellationToken,
}
19
20impl<G: GenServer> Clone for GenServerHandle<G> {
21    fn clone(&self) -> Self {
22        Self {
23            tx: self.tx.clone(),
24            cancellation_token: self.cancellation_token.clone(),
25        }
26    }
27}
28
29impl<G: GenServer> GenServerHandle<G> {
30    pub(crate) fn new(gen_server: G) -> Self {
31        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
32        let cancellation_token = CancellationToken::new();
33        let handle = GenServerHandle {
34            tx,
35            cancellation_token,
36        };
37        let handle_clone = handle.clone();
38        let inner_future = async move {
39            if let Err(error) = gen_server.run(&handle, &mut rx).await {
40                tracing::trace!(%error, "GenServer crashed")
41            }
42        };
43
44        #[cfg(debug_assertions)]
45        let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
47
48        let _join_handle = rt::spawn(inner_future);
50
51        handle_clone
52    }
53
54    pub(crate) fn new_blocking(gen_server: G) -> Self {
55        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
56        let cancellation_token = CancellationToken::new();
57        let handle = GenServerHandle {
58            tx,
59            cancellation_token,
60        };
61        let handle_clone = handle.clone();
62        let _join_handle = rt::spawn_blocking(|| {
64            rt::block_on(async move {
65                if let Err(error) = gen_server.run(&handle, &mut rx).await {
66                    tracing::trace!(%error, "GenServer crashed")
67                };
68            })
69        });
70        handle_clone
71    }
72
73    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
74        self.tx.clone()
75    }
76
77    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
78        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
79    }
80
81    pub async fn call_with_timeout(
82        &mut self,
83        message: G::CallMsg,
84        duration: Duration,
85    ) -> Result<G::OutMsg, GenServerError> {
86        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
87        self.tx.send(GenServerInMsg::Call {
88            sender: oneshot_tx,
89            message,
90        })?;
91
92        match timeout(duration, oneshot_rx).await {
93            Ok(Ok(result)) => result,
94            Ok(Err(_)) => Err(GenServerError::Server),
95            Err(_) => Err(GenServerError::CallTimeout),
96        }
97    }
98
99    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
100        self.tx
101            .send(GenServerInMsg::Cast { message })
102            .map_err(|_error| GenServerError::Server)
103    }
104
105    pub fn cancellation_token(&self) -> CancellationToken {
106        self.cancellation_token.clone()
107    }
108}
109
/// Envelope for everything sent to a [`GenServer`] over its channel.
pub enum GenServerInMsg<G: GenServer> {
    /// A request that expects an answer.
    Call {
        /// Oneshot sender on which the server delivers the reply (or an error).
        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
        /// Payload handed to `GenServer::handle_call`.
        message: G::CallMsg,
    },
    /// A request with no reply channel.
    Cast {
        /// Payload handed to `GenServer::handle_cast`.
        message: G::CastMsg,
    },
}
119
/// Outcome of `GenServer::handle_call`, interpreted by `GenServer::receive`.
pub enum CallResponse<G: GenServer> {
    /// Reply with the value and keep the server running.
    Reply(G::OutMsg),
    /// Returned by the default `handle_call`; `receive` logs an error, replies
    /// with `GenServerError::CallMsgUnused` and stops the server.
    Unused,
    /// Reply with the value and stop the server.
    Stop(G::OutMsg),
}
125
/// Outcome of `GenServer::handle_cast`, interpreted by `GenServer::receive`.
pub enum CastResponse {
    /// Keep the server running.
    NoReply,
    /// Returned by the default `handle_cast`; `receive` logs an error and
    /// stops the server.
    Unused,
    /// Stop the server.
    Stop,
}
131
/// Outcome of a `GenServer::init` call that did not return `Err`.
pub enum InitResult<G: GenServer> {
    /// `run` continues into `main_loop` with this state.
    Success(G),
    /// `run` skips `main_loop` and goes straight to `teardown` with this state.
    NoSuccess(G),
}
136
137pub trait GenServer: Send + Sized {
    /// Message type delivered to `handle_call`.
    type CallMsg: Clone + Send + Sized + Sync;
    /// Message type delivered to `handle_cast`.
    type CastMsg: Clone + Send + Sized + Sync;
    /// Value carried by `CallResponse::Reply` / `CallResponse::Stop` and sent
    /// back to the caller.
    type OutMsg: Send + Sized;
    /// Error type returned by `init` and `teardown`.
    type Error: Debug + Send;
142
143    fn start(self) -> GenServerHandle<Self> {
144        GenServerHandle::new(self)
145    }
146
147    fn start_blocking(self) -> GenServerHandle<Self> {
153        GenServerHandle::new_blocking(self)
154    }
155
156    fn run(
157        self,
158        handle: &GenServerHandle<Self>,
159        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
160    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
161        async {
162            let res = match self.init(handle).await {
163                Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
164                Ok(NoSuccess(intermediate_state)) => {
165                    Ok(intermediate_state)
169                }
170                Err(err) => {
171                    tracing::error!("Initialization failed with unhandled error: {err:?}");
172                    Err(GenServerError::Initialization)
173                }
174            };
175
176            handle.cancellation_token().cancel();
177            if let Ok(final_state) = res {
178                if let Err(err) = final_state.teardown(handle).await {
179                    tracing::error!("Error during teardown: {err:?}");
180                }
181            }
182            Ok(())
183        }
184    }
185
186    fn init(
190        self,
191        _handle: &GenServerHandle<Self>,
192    ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
193        async { Ok(Success(self)) }
194    }
195
196    fn main_loop(
197        mut self,
198        handle: &GenServerHandle<Self>,
199        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200    ) -> impl Future<Output = Self> + Send {
201        async {
202            loop {
203                if !self.receive(handle, rx).await {
204                    break;
205                }
206            }
207            tracing::trace!("Stopping GenServer");
208            self
209        }
210    }
211
212    fn receive(
213        &mut self,
214        handle: &GenServerHandle<Self>,
215        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
216    ) -> impl Future<Output = bool> + Send {
217        async move {
218            let message = rx.recv().await;
219
220            let keep_running = match message {
221                Some(GenServerInMsg::Call { sender, message }) => {
222                    let (keep_running, response) =
223                        match AssertUnwindSafe(self.handle_call(message, handle))
224                            .catch_unwind()
225                            .await
226                        {
227                            Ok(response) => match response {
228                                CallResponse::Reply(response) => (true, Ok(response)),
229                                CallResponse::Stop(response) => (false, Ok(response)),
230                                CallResponse::Unused => {
231                                    tracing::error!("GenServer received unexpected CallMessage");
232                                    (false, Err(GenServerError::CallMsgUnused))
233                                }
234                            },
235                            Err(error) => {
236                                tracing::error!("Error in callback: '{error:?}'");
237                                (false, Err(GenServerError::Callback))
238                            }
239                        };
240                    if sender.send(response).is_err() {
242                        tracing::error!(
243                            "GenServer failed to send response back, client must have died"
244                        )
245                    };
246                    keep_running
247                }
248                Some(GenServerInMsg::Cast { message }) => {
249                    match AssertUnwindSafe(self.handle_cast(message, handle))
250                        .catch_unwind()
251                        .await
252                    {
253                        Ok(response) => match response {
254                            CastResponse::NoReply => true,
255                            CastResponse::Stop => false,
256                            CastResponse::Unused => {
257                                tracing::error!("GenServer received unexpected CastMessage");
258                                false
259                            }
260                        },
261                        Err(error) => {
262                            tracing::trace!("Error in callback: '{error:?}'");
263                            false
264                        }
265                    }
266                }
267                None => {
268                    false
270                }
271            };
272            keep_running
273        }
274    }
275
276    fn handle_call(
277        &mut self,
278        _message: Self::CallMsg,
279        _handle: &GenServerHandle<Self>,
280    ) -> impl Future<Output = CallResponse<Self>> + Send {
281        async { CallResponse::Unused }
282    }
283
284    fn handle_cast(
285        &mut self,
286        _message: Self::CastMsg,
287        _handle: &GenServerHandle<Self>,
288    ) -> impl Future<Output = CastResponse> + Send {
289        async { CastResponse::Unused }
290    }
291
292    fn teardown(
296        self,
297        _handle: &GenServerHandle<Self>,
298    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
299        async { Ok(()) }
300    }
301}
302
303#[cfg(debug_assertions)]
304mod warn_on_block {
305    use super::*;
306
307    use std::time::Instant;
308    use tracing::warn;
309
310    pin_project_lite::pin_project! {
311        pub struct WarnOnBlocking<F: Future>{
312            #[pin]
313            inner: F
314        }
315    }
316
317    impl<F: Future> WarnOnBlocking<F> {
318        pub fn new(inner: F) -> Self {
319            Self { inner }
320        }
321    }
322
323    impl<F: Future> Future for WarnOnBlocking<F> {
324        type Output = F::Output;
325
326        fn poll(
327            self: std::pin::Pin<&mut Self>,
328            cx: &mut std::task::Context<'_>,
329        ) -> std::task::Poll<Self::Output> {
330            let type_id = std::any::type_name::<F>();
331            let task_id = rt::task_id();
332            let this = self.project();
333            let now = Instant::now();
334            let res = this.inner.poll(cx);
335            let elapsed = now.elapsed();
336            if elapsed > Duration::from_millis(10) {
337                warn!(task = ?task_id, future = ?type_id, elapsed = ?elapsed, "Blocking operation detected");
338            }
339            res
340        }
341    }
342}
343
344#[cfg(test)]
345mod tests {
346
347    use super::*;
348    use crate::{messages::Unused, tasks::send_after};
349    use std::{
350        sync::{Arc, Mutex},
351        thread,
352        time::Duration,
353    };
354
    /// Test server whose cast handler calls `thread::sleep` inside async code.
    struct BadlyBehavedTask;
356
    /// Call messages understood by the test servers.
    #[derive(Clone)]
    pub enum InMessage {
        /// Ask `WellBehavedTask` for its current counter.
        GetCount,
        /// Ask `WellBehavedTask` to reply with its counter and stop.
        Stop,
    }
    /// Reply type of `WellBehavedTask`.
    #[derive(Clone)]
    pub enum OutMsg {
        /// The server's counter value at the time of the call.
        Count(u64),
    }
366
367    impl GenServer for BadlyBehavedTask {
368        type CallMsg = InMessage;
369        type CastMsg = Unused;
370        type OutMsg = Unused;
371        type Error = Unused;
372
373        async fn handle_call(
374            &mut self,
375            _: Self::CallMsg,
376            _: &GenServerHandle<Self>,
377        ) -> CallResponse<Self> {
378            CallResponse::Stop(Unused)
379        }
380
381        async fn handle_cast(
382            &mut self,
383            _: Self::CastMsg,
384            _: &GenServerHandle<Self>,
385        ) -> CastResponse {
386            rt::sleep(Duration::from_millis(20)).await;
387            thread::sleep(Duration::from_secs(2));
388            CastResponse::Stop
389        }
390    }
391
    /// Test server that counts the casts it has handled.
    struct WellBehavedTask {
        /// Incremented once per handled cast.
        pub count: u64,
    }
395
396    impl GenServer for WellBehavedTask {
397        type CallMsg = InMessage;
398        type CastMsg = Unused;
399        type OutMsg = OutMsg;
400        type Error = Unused;
401
402        async fn handle_call(
403            &mut self,
404            message: Self::CallMsg,
405            _: &GenServerHandle<Self>,
406        ) -> CallResponse<Self> {
407            match message {
408                InMessage::GetCount => CallResponse::Reply(OutMsg::Count(self.count)),
409                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
410            }
411        }
412
413        async fn handle_cast(
414            &mut self,
415            _: Self::CastMsg,
416            handle: &GenServerHandle<Self>,
417        ) -> CastResponse {
418            self.count += 1;
419            println!("{:?}: good still alive", thread::current().id());
420            send_after(Duration::from_millis(100), handle.to_owned(), Unused);
421            CastResponse::NoReply
422        }
423    }
424
425    #[test]
426    pub fn badly_behaved_thread_non_blocking() {
427        let runtime = rt::Runtime::new().unwrap();
428        runtime.block_on(async move {
429            let mut badboy = BadlyBehavedTask.start();
430            let _ = badboy.cast(Unused).await;
431            let mut goodboy = WellBehavedTask { count: 0 }.start();
432            let _ = goodboy.cast(Unused).await;
433            rt::sleep(Duration::from_secs(1)).await;
434            let count = goodboy.call(InMessage::GetCount).await.unwrap();
435
436            match count {
437                OutMsg::Count(num) => {
438                    assert_ne!(num, 10);
439                }
440            }
441            goodboy.call(InMessage::Stop).await.unwrap();
442        });
443    }
444
445    #[test]
446    pub fn badly_behaved_thread() {
447        let runtime = rt::Runtime::new().unwrap();
448        runtime.block_on(async move {
449            let mut badboy = BadlyBehavedTask.start_blocking();
450            let _ = badboy.cast(Unused).await;
451            let mut goodboy = WellBehavedTask { count: 0 }.start();
452            let _ = goodboy.cast(Unused).await;
453            rt::sleep(Duration::from_secs(1)).await;
454            let count = goodboy.call(InMessage::GetCount).await.unwrap();
455
456            match count {
457                OutMsg::Count(num) => {
458                    assert_eq!(num, 10);
459                }
460            }
461            goodboy.call(InMessage::Stop).await.unwrap();
462        });
463    }
464
    /// Call timeout used by `unresolving_task_times_out`; `SomeTask` sleeps for
    /// twice or half this long depending on the message.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);
466
    /// Test server whose call handler sleeps before replying.
    #[derive(Debug, Default)]
    struct SomeTask;
469
    /// Call messages for `SomeTask`.
    #[derive(Clone)]
    enum SomeTaskCallMsg {
        /// Handled after sleeping `TIMEOUT_DURATION * 2`.
        SlowOperation,
        /// Handled after sleeping `TIMEOUT_DURATION / 2`.
        FastOperation,
    }
475
476    impl GenServer for SomeTask {
477        type CallMsg = SomeTaskCallMsg;
478        type CastMsg = Unused;
479        type OutMsg = Unused;
480        type Error = Unused;
481
482        async fn handle_call(
483            &mut self,
484            message: Self::CallMsg,
485            _handle: &GenServerHandle<Self>,
486        ) -> CallResponse<Self> {
487            match message {
488                SomeTaskCallMsg::SlowOperation => {
489                    rt::sleep(TIMEOUT_DURATION * 2).await;
491                    CallResponse::Reply(Unused)
492                }
493                SomeTaskCallMsg::FastOperation => {
494                    rt::sleep(TIMEOUT_DURATION / 2).await;
496                    CallResponse::Reply(Unused)
497                }
498            }
499        }
500    }
501
502    #[test]
503    pub fn unresolving_task_times_out() {
504        let runtime = rt::Runtime::new().unwrap();
505        runtime.block_on(async move {
506            let mut unresolving_task = SomeTask.start();
507
508            let result = unresolving_task
509                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
510                .await;
511            assert!(matches!(result, Ok(Unused)));
512
513            let result = unresolving_task
514                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
515                .await;
516            assert!(matches!(result, Err(GenServerError::CallTimeout)));
517        });
518    }
519
    /// Test server whose `init` returns `NoSuccess` and whose `teardown`
    /// closes the channel end it holds.
    struct SomeTaskThatFailsOnInit {
        /// Despite the name, this holds the *receiving* half of a channel.
        sender_channel: Arc<Mutex<mpsc::Receiver<u8>>>,
    }
523
524    impl SomeTaskThatFailsOnInit {
525        pub fn new(sender_channel: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
526            Self { sender_channel }
527        }
528    }
529
530    impl GenServer for SomeTaskThatFailsOnInit {
531        type CallMsg = Unused;
532        type CastMsg = Unused;
533        type OutMsg = Unused;
534        type Error = Unused;
535
536        async fn init(
537            self,
538            _handle: &GenServerHandle<Self>,
539        ) -> Result<InitResult<Self>, Self::Error> {
540            Ok(NoSuccess(self))
542        }
543
544        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
545            self.sender_channel.lock().unwrap().close();
546            Ok(())
547        }
548    }
549
550    #[test]
551    pub fn task_fails_with_intermediate_state() {
552        let runtime = rt::Runtime::new().unwrap();
553        runtime.block_on(async move {
554            let (rx, tx) = mpsc::channel::<u8>();
555            let sender_channel = Arc::new(Mutex::new(tx));
556            let _task = SomeTaskThatFailsOnInit::new(sender_channel).start();
557
558            rt::sleep(Duration::from_secs(1)).await;
560
561            assert!(rx.is_closed())
563        });
564    }
565}