spawned_concurrency/tasks/
gen_server.rs1use crate::{
4 error::GenServerError,
5 tasks::InitResult::{NoSuccess, Success},
6};
7use futures::future::FutureExt as _;
8use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
9use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
10
/// How long [`GenServerHandle::call`] waits for a reply before giving up (5 seconds).
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
12
13#[derive(Debug)]
14pub struct GenServerHandle<G: GenServer + 'static> {
15 pub tx: mpsc::Sender<GenServerInMsg<G>>,
16 cancellation_token: CancellationToken,
18}
19
20impl<G: GenServer> Clone for GenServerHandle<G> {
21 fn clone(&self) -> Self {
22 Self {
23 tx: self.tx.clone(),
24 cancellation_token: self.cancellation_token.clone(),
25 }
26 }
27}
28
29impl<G: GenServer> GenServerHandle<G> {
30 pub(crate) fn new(gen_server: G) -> Self {
31 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
32 let cancellation_token = CancellationToken::new();
33 let handle = GenServerHandle {
34 tx,
35 cancellation_token,
36 };
37 let handle_clone = handle.clone();
38 let inner_future = async move {
39 if let Err(error) = gen_server.run(&handle, &mut rx).await {
40 tracing::trace!(%error, "GenServer crashed")
41 }
42 };
43
44 #[cfg(debug_assertions)]
45 let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
47
48 let _join_handle = rt::spawn(inner_future);
50
51 handle_clone
52 }
53
54 pub(crate) fn new_blocking(gen_server: G) -> Self {
55 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
56 let cancellation_token = CancellationToken::new();
57 let handle = GenServerHandle {
58 tx,
59 cancellation_token,
60 };
61 let handle_clone = handle.clone();
62 let _join_handle = rt::spawn_blocking(|| {
64 rt::block_on(async move {
65 if let Err(error) = gen_server.run(&handle, &mut rx).await {
66 tracing::trace!(%error, "GenServer crashed")
67 };
68 })
69 });
70 handle_clone
71 }
72
73 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
74 self.tx.clone()
75 }
76
77 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
78 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
79 }
80
81 pub async fn call_with_timeout(
82 &mut self,
83 message: G::CallMsg,
84 duration: Duration,
85 ) -> Result<G::OutMsg, GenServerError> {
86 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
87 self.tx.send(GenServerInMsg::Call {
88 sender: oneshot_tx,
89 message,
90 })?;
91
92 match timeout(duration, oneshot_rx).await {
93 Ok(Ok(result)) => result,
94 Ok(Err(_)) => Err(GenServerError::Server),
95 Err(_) => Err(GenServerError::CallTimeout),
96 }
97 }
98
99 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
100 self.tx
101 .send(GenServerInMsg::Cast { message })
102 .map_err(|_error| GenServerError::Server)
103 }
104
105 pub fn cancellation_token(&self) -> CancellationToken {
106 self.cancellation_token.clone()
107 }
108}
109
110pub enum GenServerInMsg<G: GenServer> {
111 Call {
112 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
113 message: G::CallMsg,
114 },
115 Cast {
116 message: G::CastMsg,
117 },
118}
119
120pub enum CallResponse<G: GenServer> {
121 Reply(G::OutMsg),
122 Unused,
123 Stop(G::OutMsg),
124}
125
/// Value returned by [`GenServer::handle_cast`]; casts never carry a reply.
pub enum CastResponse {
    /// Keep the server running.
    NoReply,
    /// The cast was not handled. The default `receive` logs an error and stops the server.
    Unused,
    /// Stop the server.
    Stop,
}
131
132pub enum InitResult<G: GenServer> {
133 Success(G),
134 NoSuccess(G),
135}
136
137pub trait GenServer: Send + Sized {
138 type CallMsg: Clone + Send + Sized + Sync;
139 type CastMsg: Clone + Send + Sized + Sync;
140 type OutMsg: Send + Sized;
141 type Error: Debug + Send;
142
143 fn start(self) -> GenServerHandle<Self> {
144 GenServerHandle::new(self)
145 }
146
147 fn start_blocking(self) -> GenServerHandle<Self> {
153 GenServerHandle::new_blocking(self)
154 }
155
156 fn run(
157 self,
158 handle: &GenServerHandle<Self>,
159 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
160 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
161 async {
162 let res = match self.init(handle).await {
163 Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
164 Ok(NoSuccess(intermediate_state)) => {
165 Ok(intermediate_state)
169 }
170 Err(err) => {
171 tracing::error!("Initialization failed with unhandled error: {err:?}");
172 Err(GenServerError::Initialization)
173 }
174 };
175
176 handle.cancellation_token().cancel();
177 if let Ok(final_state) = res {
178 if let Err(err) = final_state.teardown(handle).await {
179 tracing::error!("Error during teardown: {err:?}");
180 }
181 }
182 Ok(())
183 }
184 }
185
186 fn init(
190 self,
191 _handle: &GenServerHandle<Self>,
192 ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
193 async { Ok(Success(self)) }
194 }
195
196 fn main_loop(
197 mut self,
198 handle: &GenServerHandle<Self>,
199 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200 ) -> impl Future<Output = Self> + Send {
201 async {
202 loop {
203 if !self.receive(handle, rx).await {
204 break;
205 }
206 }
207 tracing::trace!("Stopping GenServer");
208 self
209 }
210 }
211
212 fn receive(
213 &mut self,
214 handle: &GenServerHandle<Self>,
215 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
216 ) -> impl Future<Output = bool> + Send {
217 async move {
218 let message = rx.recv().await;
219
220 let keep_running = match message {
221 Some(GenServerInMsg::Call { sender, message }) => {
222 let (keep_running, response) =
223 match AssertUnwindSafe(self.handle_call(message, handle))
224 .catch_unwind()
225 .await
226 {
227 Ok(response) => match response {
228 CallResponse::Reply(response) => (true, Ok(response)),
229 CallResponse::Stop(response) => (false, Ok(response)),
230 CallResponse::Unused => {
231 tracing::error!("GenServer received unexpected CallMessage");
232 (false, Err(GenServerError::CallMsgUnused))
233 }
234 },
235 Err(error) => {
236 tracing::error!("Error in callback: '{error:?}'");
237 (false, Err(GenServerError::Callback))
238 }
239 };
240 if sender.send(response).is_err() {
242 tracing::error!(
243 "GenServer failed to send response back, client must have died"
244 )
245 };
246 keep_running
247 }
248 Some(GenServerInMsg::Cast { message }) => {
249 match AssertUnwindSafe(self.handle_cast(message, handle))
250 .catch_unwind()
251 .await
252 {
253 Ok(response) => match response {
254 CastResponse::NoReply => true,
255 CastResponse::Stop => false,
256 CastResponse::Unused => {
257 tracing::error!("GenServer received unexpected CastMessage");
258 false
259 }
260 },
261 Err(error) => {
262 tracing::trace!("Error in callback: '{error:?}'");
263 false
264 }
265 }
266 }
267 None => {
268 false
270 }
271 };
272 keep_running
273 }
274 }
275
276 fn handle_call(
277 &mut self,
278 _message: Self::CallMsg,
279 _handle: &GenServerHandle<Self>,
280 ) -> impl Future<Output = CallResponse<Self>> + Send {
281 async { CallResponse::Unused }
282 }
283
284 fn handle_cast(
285 &mut self,
286 _message: Self::CastMsg,
287 _handle: &GenServerHandle<Self>,
288 ) -> impl Future<Output = CastResponse> + Send {
289 async { CastResponse::Unused }
290 }
291
292 fn teardown(
296 self,
297 _handle: &GenServerHandle<Self>,
298 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
299 async { Ok(()) }
300 }
301}
302
303#[cfg(debug_assertions)]
304mod warn_on_block {
305 use super::*;
306
307 use std::time::Instant;
308 use tracing::warn;
309
310 pin_project_lite::pin_project! {
311 pub struct WarnOnBlocking<F: Future>{
312 #[pin]
313 inner: F
314 }
315 }
316
317 impl<F: Future> WarnOnBlocking<F> {
318 pub fn new(inner: F) -> Self {
319 Self { inner }
320 }
321 }
322
323 impl<F: Future> Future for WarnOnBlocking<F> {
324 type Output = F::Output;
325
326 fn poll(
327 self: std::pin::Pin<&mut Self>,
328 cx: &mut std::task::Context<'_>,
329 ) -> std::task::Poll<Self::Output> {
330 let type_id = std::any::type_name::<F>();
331 let task_id = rt::task_id();
332 let this = self.project();
333 let now = Instant::now();
334 let res = this.inner.poll(cx);
335 let elapsed = now.elapsed();
336 if elapsed > Duration::from_millis(10) {
337 warn!(task = ?task_id, future = ?type_id, elapsed = ?elapsed, "Blocking operation detected");
338 }
339 res
340 }
341 }
342}
343
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its OS thread with `thread::sleep`.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately block the thread to simulate a misbehaving callback.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that increments `count` on every cast and schedules another cast to itself
    /// 100 ms later.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => CallResponse::Reply(OutMsg::Count(self.count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.to_owned(), Unused);
            CastResponse::NoReply
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            // NOTE(review): timing-sensitive, since it expects ~1 s / 100 ms worth of casts.
            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Longer than the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(Unused)
                }
                SomeTaskCallMsg::FastOperation => {
                    // Comfortably within the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(Unused)
                }
            }
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(Unused)));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports a handled failure; its `teardown` closes a receiver so
    /// the test can observe that teardown ran.
    struct SomeTaskThatFailsOnInit {
        receiver_channel: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver_channel: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver_channel }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            // Handled failure: the main loop is skipped but teardown still runs.
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.receiver_channel.lock().unwrap().close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (tx, rx) = mpsc::channel::<u8>();
            let receiver_channel = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(receiver_channel).start();

            // Give the task time to run init and teardown.
            rt::sleep(Duration::from_secs(1)).await;

            // Teardown closed the receiver, which the sender observes as a closed channel.
            assert!(tx.is_closed())
        });
    }
}
565}