spawned_concurrency/tasks/
gen_server.rs1use crate::{
4 error::GenServerError,
5 tasks::InitResult::{NoSuccess, Success},
6};
7use core::pin::pin;
8use futures::future::{self, FutureExt as _};
9use spawned_rt::{
10 tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken, JoinHandle},
11 threads,
12};
13use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
14
/// How long `GenServerHandle::call` waits for a reply before giving up (5 seconds).
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
16
17#[derive(Debug)]
18pub struct GenServerHandle<G: GenServer + 'static> {
19 pub tx: mpsc::Sender<GenServerInMsg<G>>,
20 cancellation_token: CancellationToken,
22}
23
24impl<G: GenServer> Clone for GenServerHandle<G> {
25 fn clone(&self) -> Self {
26 Self {
27 tx: self.tx.clone(),
28 cancellation_token: self.cancellation_token.clone(),
29 }
30 }
31}
32
33impl<G: GenServer> GenServerHandle<G> {
34 fn new(gen_server: G) -> Self {
35 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
36 let cancellation_token = CancellationToken::new();
37 let handle = GenServerHandle {
38 tx,
39 cancellation_token,
40 };
41 let handle_clone = handle.clone();
42 let inner_future = async move {
43 if let Err(error) = gen_server.run(&handle, &mut rx).await {
44 tracing::trace!(%error, "GenServer crashed")
45 }
46 };
47
48 #[cfg(debug_assertions)]
49 let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
51
52 let _join_handle = rt::spawn(inner_future);
54
55 handle_clone
56 }
57
58 fn new_blocking(gen_server: G) -> Self {
59 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
60 let cancellation_token = CancellationToken::new();
61 let handle = GenServerHandle {
62 tx,
63 cancellation_token,
64 };
65 let handle_clone = handle.clone();
66 let _join_handle = rt::spawn_blocking(|| {
68 rt::block_on(async move {
69 if let Err(error) = gen_server.run(&handle, &mut rx).await {
70 tracing::trace!(%error, "GenServer crashed")
71 };
72 })
73 });
74 handle_clone
75 }
76
77 fn new_on_thread(gen_server: G) -> Self {
78 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
79 let cancellation_token = CancellationToken::new();
80 let handle = GenServerHandle {
81 tx,
82 cancellation_token,
83 };
84 let handle_clone = handle.clone();
85 let _join_handle = threads::spawn(|| {
87 threads::block_on(async move {
88 if let Err(error) = gen_server.run(&handle, &mut rx).await {
89 tracing::trace!(%error, "GenServer crashed")
90 };
91 })
92 });
93 handle_clone
94 }
95
96 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
97 self.tx.clone()
98 }
99
100 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
101 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
102 }
103
104 pub async fn call_with_timeout(
105 &mut self,
106 message: G::CallMsg,
107 duration: Duration,
108 ) -> Result<G::OutMsg, GenServerError> {
109 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
110 self.tx.send(GenServerInMsg::Call {
111 sender: oneshot_tx,
112 message,
113 })?;
114
115 match timeout(duration, oneshot_rx).await {
116 Ok(Ok(result)) => result,
117 Ok(Err(_)) => Err(GenServerError::Server),
118 Err(_) => Err(GenServerError::CallTimeout),
119 }
120 }
121
122 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
123 self.tx
124 .send(GenServerInMsg::Cast { message })
125 .map_err(|_error| GenServerError::Server)
126 }
127
128 pub fn cancellation_token(&self) -> CancellationToken {
129 self.cancellation_token.clone()
130 }
131}
132
133pub enum GenServerInMsg<G: GenServer> {
134 Call {
135 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
136 message: G::CallMsg,
137 },
138 Cast {
139 message: G::CastMsg,
140 },
141}
142
143pub enum CallResponse<G: GenServer> {
144 Reply(G::OutMsg),
145 Unused,
146 Stop(G::OutMsg),
147}
148
/// Outcome returned by `GenServer::handle_cast`.
///
/// The enum carries no data, so it derives the cheap standard traits that make
/// it usable in logs and assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastResponse {
    /// The cast was handled; keep the server running.
    NoReply,
    /// The cast was not handled; the server logs an error and stops.
    Unused,
    /// The cast was handled; stop the server.
    Stop,
}
154
155pub enum InitResult<G: GenServer> {
156 Success(G),
157 NoSuccess(G),
158}
159
160pub trait GenServer: Send + Sized {
161 type CallMsg: Clone + Send + Sized + Sync;
162 type CastMsg: Clone + Send + Sized + Sync;
163 type OutMsg: Send + Sized;
164 type Error: Debug + Send;
165
166 fn start(self) -> GenServerHandle<Self> {
167 GenServerHandle::new(self)
168 }
169
170 fn start_blocking(self) -> GenServerHandle<Self> {
176 GenServerHandle::new_blocking(self)
177 }
178
179 fn start_on_thread(self) -> GenServerHandle<Self> {
185 GenServerHandle::new_on_thread(self)
186 }
187
188 fn run(
189 self,
190 handle: &GenServerHandle<Self>,
191 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
192 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
193 async {
194 let res = match self.init(handle).await {
195 Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
196 Ok(NoSuccess(intermediate_state)) => {
197 Ok(intermediate_state)
201 }
202 Err(err) => {
203 tracing::error!("Initialization failed with unhandled error: {err:?}");
204 Err(GenServerError::Initialization)
205 }
206 };
207
208 handle.cancellation_token().cancel();
209 if let Ok(final_state) = res {
210 if let Err(err) = final_state.teardown(handle).await {
211 tracing::error!("Error during teardown: {err:?}");
212 }
213 }
214 Ok(())
215 }
216 }
217
218 fn init(
222 self,
223 _handle: &GenServerHandle<Self>,
224 ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
225 async { Ok(Success(self)) }
226 }
227
228 fn main_loop(
229 mut self,
230 handle: &GenServerHandle<Self>,
231 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
232 ) -> impl Future<Output = Self> + Send {
233 async {
234 loop {
235 if !self.receive(handle, rx).await {
236 break;
237 }
238 }
239 tracing::trace!("Stopping GenServer");
240 self
241 }
242 }
243
244 fn receive(
245 &mut self,
246 handle: &GenServerHandle<Self>,
247 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
248 ) -> impl Future<Output = bool> + Send {
249 async move {
250 let message = rx.recv().await;
251
252 let keep_running = match message {
253 Some(GenServerInMsg::Call { sender, message }) => {
254 let (keep_running, response) =
255 match AssertUnwindSafe(self.handle_call(message, handle))
256 .catch_unwind()
257 .await
258 {
259 Ok(response) => match response {
260 CallResponse::Reply(response) => (true, Ok(response)),
261 CallResponse::Stop(response) => (false, Ok(response)),
262 CallResponse::Unused => {
263 tracing::error!("GenServer received unexpected CallMessage");
264 (false, Err(GenServerError::CallMsgUnused))
265 }
266 },
267 Err(error) => {
268 tracing::error!("Error in callback: '{error:?}'");
269 (false, Err(GenServerError::Callback))
270 }
271 };
272 if sender.send(response).is_err() {
274 tracing::error!(
275 "GenServer failed to send response back, client must have died"
276 )
277 };
278 keep_running
279 }
280 Some(GenServerInMsg::Cast { message }) => {
281 match AssertUnwindSafe(self.handle_cast(message, handle))
282 .catch_unwind()
283 .await
284 {
285 Ok(response) => match response {
286 CastResponse::NoReply => true,
287 CastResponse::Stop => false,
288 CastResponse::Unused => {
289 tracing::error!("GenServer received unexpected CastMessage");
290 false
291 }
292 },
293 Err(error) => {
294 tracing::trace!("Error in callback: '{error:?}'");
295 false
296 }
297 }
298 }
299 None => {
300 false
302 }
303 };
304 keep_running
305 }
306 }
307
308 fn handle_call(
309 &mut self,
310 _message: Self::CallMsg,
311 _handle: &GenServerHandle<Self>,
312 ) -> impl Future<Output = CallResponse<Self>> + Send {
313 async { CallResponse::Unused }
314 }
315
316 fn handle_cast(
317 &mut self,
318 _message: Self::CastMsg,
319 _handle: &GenServerHandle<Self>,
320 ) -> impl Future<Output = CastResponse> + Send {
321 async { CastResponse::Unused }
322 }
323
324 fn teardown(
328 self,
329 _handle: &GenServerHandle<Self>,
330 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
331 async { Ok(()) }
332 }
333}
334
335pub fn send_message_on<T, U>(
339 handle: GenServerHandle<T>,
340 future: U,
341 message: T::CastMsg,
342) -> JoinHandle<()>
343where
344 T: GenServer,
345 U: Future + Send + 'static,
346 <U as Future>::Output: Send,
347{
348 let cancelation_token = handle.cancellation_token();
349 let mut handle_clone = handle.clone();
350 let join_handle = rt::spawn(async move {
351 let is_cancelled = pin!(cancelation_token.cancelled());
352 let signal = pin!(future);
353 match future::select(is_cancelled, signal).await {
354 future::Either::Left(_) => tracing::debug!("GenServer stopped"),
355 future::Either::Right(_) => {
356 if let Err(e) = handle_clone.cast(message).await {
357 tracing::error!("Failed to send message: {e:?}")
358 }
359 }
360 }
361 });
362 join_handle
363}
364
365#[cfg(debug_assertions)]
366mod warn_on_block {
367 use super::*;
368
369 use std::time::Instant;
370 use tracing::warn;
371
372 pin_project_lite::pin_project! {
373 pub struct WarnOnBlocking<F: Future>{
374 #[pin]
375 inner: F
376 }
377 }
378
379 impl<F: Future> WarnOnBlocking<F> {
380 pub fn new(inner: F) -> Self {
381 Self { inner }
382 }
383 }
384
385 impl<F: Future> Future for WarnOnBlocking<F> {
386 type Output = F::Output;
387
388 fn poll(
389 self: std::pin::Pin<&mut Self>,
390 cx: &mut std::task::Context<'_>,
391 ) -> std::task::Poll<Self::Output> {
392 let type_id = std::any::type_name::<F>();
393 let task_id = rt::task_id();
394 let this = self.project();
395 let now = Instant::now();
396 let res = this.inner.poll(cx);
397 let elapsed = now.elapsed();
398 if elapsed > Duration::from_millis(10) {
399 warn!(task = ?task_id, future = ?type_id, elapsed = ?elapsed, "Blocking operation detected");
400 }
401 res
402 }
403 }
404}
405
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    struct BadlyBehavedTask;

    /// Call messages understood by the counting server.
    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    /// Replies produced by the counting server.
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        /// Any call stops the server.
        async fn handle_call(
            &mut self,
            _message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        /// Yields briefly, then blocks the current thread for two seconds.
        async fn handle_cast(
            &mut self,
            _message: Self::CastMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocking: this is the misbehaviour under test.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts the casts it receives and re-schedules itself.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        /// Replies with the current count; `Stop` additionally stops the server.
        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let reply = OutMsg::Count(self.count);
            match message {
                InMessage::GetCount => CallResponse::Reply(reply),
                InMessage::Stop => CallResponse::Stop(reply),
            }
        }

        /// Increments the counter and schedules the next cast 100 ms later.
        async fn handle_cast(
            &mut self,
            _message: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), Unused);
            CastResponse::NoReply
        }
    }

    /// With the blocking server started through `start`, the counting server
    /// is expected not to reach 10 ticks within one second.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut blocker = BadlyBehavedTask.start();
            let _ = blocker.cast(Unused).await;
            let mut counter = WellBehavedTask { count: 0 }.start();
            let _ = counter.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = counter.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            counter.call(InMessage::Stop).await.unwrap();
        });
    }

    /// With the blocking server started through `start_blocking`, the counting
    /// server is expected to reach exactly 10 ticks within one second.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut blocker = BadlyBehavedTask.start_blocking();
            let _ = blocker.cast(Unused).await;
            let mut counter = WellBehavedTask { count: 0 }.start();
            let _ = counter.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = counter.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            counter.call(InMessage::Stop).await.unwrap();
        });
    }

    /// Deadline used by the call-timeout test.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server whose calls take either half or twice `TIMEOUT_DURATION`.
    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        /// Sleeps for a message-dependent delay, then replies.
        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let delay = match message {
                // Longer than the caller's deadline.
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                // Shorter than the caller's deadline.
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply(Unused)
        }
    }

    /// A call answered within the deadline succeeds; a slower one yields
    /// `GenServerError::CallTimeout`.
    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut server = SomeTask.start();

            let fast = server
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(Unused)));

            let slow = server
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports `NoSuccess` and whose `teardown` closes a
    /// shared channel receiver so the test can observe that teardown ran.
    struct SomeTaskThatFailsOnInit {
        receiver: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        /// Always reports a handled initialization failure.
        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            Ok(NoSuccess(self))
        }

        /// Closes the shared receiver.
        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.receiver.lock().unwrap().close();
            Ok(())
        }
    }

    /// `teardown` must still run when `init` returns `NoSuccess`.
    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (sender, receiver) = mpsc::channel::<u8>();
            let shared_receiver = Arc::new(Mutex::new(receiver));
            let _task = SomeTaskThatFailsOnInit::new(shared_receiver).start();

            // Give the spawned server time to run `init` and `teardown`.
            rt::sleep(Duration::from_secs(1)).await;

            assert!(sender.is_closed());
        });
    }
}