spawned_concurrency/tasks/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use crate::{
4    error::GenServerError,
5    tasks::InitResult::{NoSuccess, Success},
6};
7use core::pin::pin;
8use futures::future::{self, FutureExt as _};
9use spawned_rt::{
10    tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken, JoinHandle},
11    threads,
12};
13use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
14
/// Time `GenServerHandle::call` waits for a reply before giving up.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(5);
16
17#[derive(Debug)]
18pub struct GenServerHandle<G: GenServer + 'static> {
19    pub tx: mpsc::Sender<GenServerInMsg<G>>,
20    /// Cancellation token to stop the GenServer
21    cancellation_token: CancellationToken,
22}
23
24impl<G: GenServer> Clone for GenServerHandle<G> {
25    fn clone(&self) -> Self {
26        Self {
27            tx: self.tx.clone(),
28            cancellation_token: self.cancellation_token.clone(),
29        }
30    }
31}
32
33impl<G: GenServer> GenServerHandle<G> {
34    fn new(gen_server: G) -> Self {
35        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
36        let cancellation_token = CancellationToken::new();
37        let handle = GenServerHandle {
38            tx,
39            cancellation_token,
40        };
41        let handle_clone = handle.clone();
42        let inner_future = async move {
43            if let Err(error) = gen_server.run(&handle, &mut rx).await {
44                tracing::trace!(%error, "GenServer crashed")
45            }
46        };
47
48        #[cfg(debug_assertions)]
49        // Optionally warn if the GenServer future blocks for too much time
50        let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
51
52        // Ignore the JoinHandle for now. Maybe we'll use it in the future
53        let _join_handle = rt::spawn(inner_future);
54
55        handle_clone
56    }
57
58    fn new_blocking(gen_server: G) -> Self {
59        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
60        let cancellation_token = CancellationToken::new();
61        let handle = GenServerHandle {
62            tx,
63            cancellation_token,
64        };
65        let handle_clone = handle.clone();
66        // Ignore the JoinHandle for now. Maybe we'll use it in the future
67        let _join_handle = rt::spawn_blocking(|| {
68            rt::block_on(async move {
69                if let Err(error) = gen_server.run(&handle, &mut rx).await {
70                    tracing::trace!(%error, "GenServer crashed")
71                };
72            })
73        });
74        handle_clone
75    }
76
77    fn new_on_thread(gen_server: G) -> Self {
78        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
79        let cancellation_token = CancellationToken::new();
80        let handle = GenServerHandle {
81            tx,
82            cancellation_token,
83        };
84        let handle_clone = handle.clone();
85        // Ignore the JoinHandle for now. Maybe we'll use it in the future
86        let _join_handle = threads::spawn(|| {
87            threads::block_on(async move {
88                if let Err(error) = gen_server.run(&handle, &mut rx).await {
89                    tracing::trace!(%error, "GenServer crashed")
90                };
91            })
92        });
93        handle_clone
94    }
95
96    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
97        self.tx.clone()
98    }
99
100    pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
101        self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
102    }
103
104    pub async fn call_with_timeout(
105        &mut self,
106        message: G::CallMsg,
107        duration: Duration,
108    ) -> Result<G::OutMsg, GenServerError> {
109        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
110        self.tx.send(GenServerInMsg::Call {
111            sender: oneshot_tx,
112            message,
113        })?;
114
115        match timeout(duration, oneshot_rx).await {
116            Ok(Ok(result)) => result,
117            Ok(Err(_)) => Err(GenServerError::Server),
118            Err(_) => Err(GenServerError::CallTimeout),
119        }
120    }
121
122    pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
123        self.tx
124            .send(GenServerInMsg::Cast { message })
125            .map_err(|_error| GenServerError::Server)
126    }
127
128    pub fn cancellation_token(&self) -> CancellationToken {
129        self.cancellation_token.clone()
130    }
131}
132
133pub enum GenServerInMsg<G: GenServer> {
134    Call {
135        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
136        message: G::CallMsg,
137    },
138    Cast {
139        message: G::CastMsg,
140    },
141}
142
143pub enum CallResponse<G: GenServer> {
144    Reply(G::OutMsg),
145    Unused,
146    Stop(G::OutMsg),
147}
148
/// Outcome of `GenServer::handle_cast`.
pub enum CastResponse {
    /// The cast was handled; keep the server running.
    NoReply,
    /// The cast was not handled; the server logs an error and stops.
    Unused,
    /// Stop the server.
    Stop,
}
154
155pub enum InitResult<G: GenServer> {
156    Success(G),
157    NoSuccess(G),
158}
159
160pub trait GenServer: Send + Sized {
161    type CallMsg: Clone + Send + Sized + Sync;
162    type CastMsg: Clone + Send + Sized + Sync;
163    type OutMsg: Send + Sized;
164    type Error: Debug + Send;
165
166    fn start(self) -> GenServerHandle<Self> {
167        GenServerHandle::new(self)
168    }
169
170    /// Tokio tasks depend on a coolaborative multitasking model. "work stealing" can't
171    /// happen if the task is blocking the thread. As such, for sync compute task
172    /// or other blocking tasks need to be in their own separate thread, and the OS
173    /// will manage them through hardware interrupts.
174    /// Start blocking provides such thread.
175    fn start_blocking(self) -> GenServerHandle<Self> {
176        GenServerHandle::new_blocking(self)
177    }
178
179    /// For some "singleton" GenServers that run througout the whole execution of the
180    /// program, it makes sense to run in their own dedicated thread to avoid interference
181    /// with the rest of the tasks' runtime.
182    /// The use of tokio::task::spawm_blocking is not recommended for these scenarios
183    /// as it is a limited thread pool better suited for blocking IO tasks that eventually end
184    fn start_on_thread(self) -> GenServerHandle<Self> {
185        GenServerHandle::new_on_thread(self)
186    }
187
188    fn run(
189        self,
190        handle: &GenServerHandle<Self>,
191        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
192    ) -> impl Future<Output = Result<(), GenServerError>> + Send {
193        async {
194            let res = match self.init(handle).await {
195                Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
196                Ok(NoSuccess(intermediate_state)) => {
197                    // new_state is NoSuccess, this means the initialization failed, but the error was handled
198                    // in callback. No need to report the error.
199                    // Just skip main_loop and return the state to teardown the GenServer
200                    Ok(intermediate_state)
201                }
202                Err(err) => {
203                    tracing::error!("Initialization failed with unhandled error: {err:?}");
204                    Err(GenServerError::Initialization)
205                }
206            };
207
208            handle.cancellation_token().cancel();
209            if let Ok(final_state) = res {
210                if let Err(err) = final_state.teardown(handle).await {
211                    tracing::error!("Error during teardown: {err:?}");
212                }
213            }
214            Ok(())
215        }
216    }
217
218    /// Initialization function. It's called before main loop. It
219    /// can be overrided on implementations in case initial steps are
220    /// required.
221    fn init(
222        self,
223        _handle: &GenServerHandle<Self>,
224    ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
225        async { Ok(Success(self)) }
226    }
227
228    fn main_loop(
229        mut self,
230        handle: &GenServerHandle<Self>,
231        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
232    ) -> impl Future<Output = Self> + Send {
233        async {
234            loop {
235                if !self.receive(handle, rx).await {
236                    break;
237                }
238            }
239            tracing::trace!("Stopping GenServer");
240            self
241        }
242    }
243
244    fn receive(
245        &mut self,
246        handle: &GenServerHandle<Self>,
247        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
248    ) -> impl Future<Output = bool> + Send {
249        async move {
250            let message = rx.recv().await;
251
252            let keep_running = match message {
253                Some(GenServerInMsg::Call { sender, message }) => {
254                    let (keep_running, response) =
255                        match AssertUnwindSafe(self.handle_call(message, handle))
256                            .catch_unwind()
257                            .await
258                        {
259                            Ok(response) => match response {
260                                CallResponse::Reply(response) => (true, Ok(response)),
261                                CallResponse::Stop(response) => (false, Ok(response)),
262                                CallResponse::Unused => {
263                                    tracing::error!("GenServer received unexpected CallMessage");
264                                    (false, Err(GenServerError::CallMsgUnused))
265                                }
266                            },
267                            Err(error) => {
268                                tracing::error!("Error in callback: '{error:?}'");
269                                (false, Err(GenServerError::Callback))
270                            }
271                        };
272                    // Send response back
273                    if sender.send(response).is_err() {
274                        tracing::error!(
275                            "GenServer failed to send response back, client must have died"
276                        )
277                    };
278                    keep_running
279                }
280                Some(GenServerInMsg::Cast { message }) => {
281                    match AssertUnwindSafe(self.handle_cast(message, handle))
282                        .catch_unwind()
283                        .await
284                    {
285                        Ok(response) => match response {
286                            CastResponse::NoReply => true,
287                            CastResponse::Stop => false,
288                            CastResponse::Unused => {
289                                tracing::error!("GenServer received unexpected CastMessage");
290                                false
291                            }
292                        },
293                        Err(error) => {
294                            tracing::trace!("Error in callback: '{error:?}'");
295                            false
296                        }
297                    }
298                }
299                None => {
300                    // Channel has been closed; won't receive further messages. Stop the server.
301                    false
302                }
303            };
304            keep_running
305        }
306    }
307
308    fn handle_call(
309        &mut self,
310        _message: Self::CallMsg,
311        _handle: &GenServerHandle<Self>,
312    ) -> impl Future<Output = CallResponse<Self>> + Send {
313        async { CallResponse::Unused }
314    }
315
316    fn handle_cast(
317        &mut self,
318        _message: Self::CastMsg,
319        _handle: &GenServerHandle<Self>,
320    ) -> impl Future<Output = CastResponse> + Send {
321        async { CastResponse::Unused }
322    }
323
324    /// Teardown function. It's called after the stop message is received.
325    /// It can be overrided on implementations in case final steps are required,
326    /// like closing streams, stopping timers, etc.
327    fn teardown(
328        self,
329        _handle: &GenServerHandle<Self>,
330    ) -> impl Future<Output = Result<(), Self::Error>> + Send {
331        async { Ok(()) }
332    }
333}
334
335/// Spawns a task that awaits on a future and sends a message to a GenServer
336/// on completion.
337/// This function returns a handle to the spawned task.
338pub fn send_message_on<T, U>(
339    handle: GenServerHandle<T>,
340    future: U,
341    message: T::CastMsg,
342) -> JoinHandle<()>
343where
344    T: GenServer,
345    U: Future + Send + 'static,
346    <U as Future>::Output: Send,
347{
348    let cancelation_token = handle.cancellation_token();
349    let mut handle_clone = handle.clone();
350    let join_handle = rt::spawn(async move {
351        let is_cancelled = pin!(cancelation_token.cancelled());
352        let signal = pin!(future);
353        match future::select(is_cancelled, signal).await {
354            future::Either::Left(_) => tracing::debug!("GenServer stopped"),
355            future::Either::Right(_) => {
356                if let Err(e) = handle_clone.cast(message).await {
357                    tracing::error!("Failed to send message: {e:?}")
358                }
359            }
360        }
361    });
362    join_handle
363}
364
365#[cfg(debug_assertions)]
366mod warn_on_block {
367    use super::*;
368
369    use std::time::Instant;
370    use tracing::warn;
371
372    pin_project_lite::pin_project! {
373        pub struct WarnOnBlocking<F: Future>{
374            #[pin]
375            inner: F
376        }
377    }
378
379    impl<F: Future> WarnOnBlocking<F> {
380        pub fn new(inner: F) -> Self {
381            Self { inner }
382        }
383    }
384
385    impl<F: Future> Future for WarnOnBlocking<F> {
386        type Output = F::Output;
387
388        fn poll(
389            self: std::pin::Pin<&mut Self>,
390            cx: &mut std::task::Context<'_>,
391        ) -> std::task::Poll<Self::Output> {
392            let type_id = std::any::type_name::<F>();
393            let task_id = rt::task_id();
394            let this = self.project();
395            let now = Instant::now();
396            let res = this.inner.poll(cx);
397            let elapsed = now.elapsed();
398            if elapsed > Duration::from_millis(10) {
399                warn!(task = ?task_id, future = ?type_id, elapsed = ?elapsed, "Blocking operation detected");
400            }
401            res
402        }
403    }
404}
405
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// GenServer whose cast handler blocks its thread with `thread::sleep`.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse {
            // Yield once, then block the current thread for two seconds.
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// GenServer that counts the casts it receives and re-schedules one
    /// every 100 ms.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let current = OutMsg::Count(self.count);
            match message {
                InMessage::GetCount => CallResponse::Reply(current),
                InMessage::Stop => CallResponse::Stop(current),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), Unused);
            CastResponse::NoReply
        }
    }

    /// With the blocking server started through `start`, the well-behaved
    /// counter is expected not to have reached 10 after one second.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// With the blocking server started through `start_blocking`, the
    /// well-behaved counter is expected to be exactly 10 after one second.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// GenServer used to exercise call timeouts.
    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let delay = match message {
                // Simulate a slow operation that will not resolve in time
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                // Simulate a fast operation that resolves in time
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply(Unused)
        }
    }

    /// A call answered within the timeout succeeds; one answered too late
    /// yields `CallTimeout`.
    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let fast = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(Unused)));

            let slow = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }

    /// GenServer whose `init` reports `NoSuccess` and whose `teardown` closes
    /// the receiver it was given, so the test can observe that teardown ran.
    struct SomeTaskThatFailsOnInit {
        sender_channel: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(sender_channel: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { sender_channel }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            // Simulate an initialization failure by returning NoSuccess
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.sender_channel.lock().unwrap().close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            // `tx` is the sending half and `rx` the receiving half that the task closes.
            let (tx, rx) = mpsc::channel::<u8>();
            let shared_receiver = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(shared_receiver).start();

            // Wait a while to ensure the task has time to run and fail
            rt::sleep(Duration::from_secs(1)).await;

            // We make sure the teardown function has run by checking that the channel is closed
            assert!(tx.is_closed())
        });
    }
}
627}