spawned_concurrency/tasks/
gen_server.rs1use crate::{
4 error::GenServerError,
5 tasks::InitResult::{NoSuccess, Success},
6};
7use futures::future::FutureExt as _;
8use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
9use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
10
/// How long `GenServerHandle::call` waits for a reply before failing with
/// `GenServerError::CallTimeout`.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);

13#[derive(Debug)]
14pub struct GenServerHandle<G: GenServer + 'static> {
15 pub tx: mpsc::Sender<GenServerInMsg<G>>,
16 cancellation_token: CancellationToken,
18}
19
20impl<G: GenServer> Clone for GenServerHandle<G> {
21 fn clone(&self) -> Self {
22 Self {
23 tx: self.tx.clone(),
24 cancellation_token: self.cancellation_token.clone(),
25 }
26 }
27}
28
29impl<G: GenServer> GenServerHandle<G> {
30 pub(crate) fn new(gen_server: G) -> Self {
31 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
32 let cancellation_token = CancellationToken::new();
33 let handle = GenServerHandle {
34 tx,
35 cancellation_token,
36 };
37 let handle_clone = handle.clone();
38 let inner_future = async move {
39 if gen_server.run(&handle, &mut rx).await.is_err() {
40 tracing::trace!("GenServer crashed")
41 }
42 };
43
44 #[cfg(feature = "warn-on-block")]
45 let inner_future = warn_on_block::WarnOnBlocking::new(inner_future);
47
48 let _join_handle = rt::spawn(inner_future);
50
51 handle_clone
52 }
53
54 pub(crate) fn new_blocking(gen_server: G) -> Self {
55 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
56 let cancellation_token = CancellationToken::new();
57 let handle = GenServerHandle {
58 tx,
59 cancellation_token,
60 };
61 let handle_clone = handle.clone();
62 let _join_handle = rt::spawn_blocking(|| {
64 rt::block_on(async move {
65 if gen_server.run(&handle, &mut rx).await.is_err() {
66 tracing::trace!("GenServer crashed")
67 };
68 })
69 });
70 handle_clone
71 }
72
73 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
74 self.tx.clone()
75 }
76
77 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
78 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
79 }
80
81 pub async fn call_with_timeout(
82 &mut self,
83 message: G::CallMsg,
84 duration: Duration,
85 ) -> Result<G::OutMsg, GenServerError> {
86 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
87 self.tx.send(GenServerInMsg::Call {
88 sender: oneshot_tx,
89 message,
90 })?;
91
92 match timeout(duration, oneshot_rx).await {
93 Ok(Ok(result)) => result,
94 Ok(Err(_)) => Err(GenServerError::Server),
95 Err(_) => Err(GenServerError::CallTimeout),
96 }
97 }
98
99 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
100 self.tx
101 .send(GenServerInMsg::Cast { message })
102 .map_err(|_error| GenServerError::Server)
103 }
104
105 pub fn cancellation_token(&self) -> CancellationToken {
106 self.cancellation_token.clone()
107 }
108}
109
110pub enum GenServerInMsg<G: GenServer> {
111 Call {
112 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
113 message: G::CallMsg,
114 },
115 Cast {
116 message: G::CastMsg,
117 },
118}
119
120pub enum CallResponse<G: GenServer> {
121 Reply(G::OutMsg),
122 Unused,
123 Stop(G::OutMsg),
124}
125
/// What a `GenServer::handle_cast` implementation asks the server loop to do.
pub enum CastResponse {
    /// Keep processing messages.
    NoReply,
    /// The server does not handle casts; an error is logged and it stops.
    Unused,
    /// Stop the server.
    Stop,
}

132pub enum InitResult<G: GenServer> {
133 Success(G),
134 NoSuccess(G),
135}
136
137pub trait GenServer: Send + Sized {
138 type CallMsg: Clone + Send + Sized + Sync;
139 type CastMsg: Clone + Send + Sized + Sync;
140 type OutMsg: Send + Sized;
141 type Error: Debug + Send;
142
143 fn start(self) -> GenServerHandle<Self> {
144 GenServerHandle::new(self)
145 }
146
147 fn start_blocking(self) -> GenServerHandle<Self> {
153 GenServerHandle::new_blocking(self)
154 }
155
156 fn run(
157 self,
158 handle: &GenServerHandle<Self>,
159 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
160 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
161 async {
162 let res = match self.init(handle).await {
163 Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
164 Ok(NoSuccess(intermediate_state)) => {
165 Ok(intermediate_state)
169 }
170 Err(err) => {
171 tracing::error!("Initialization failed with unhandled error: {err:?}");
172 Err(GenServerError::Initialization)
173 }
174 };
175
176 handle.cancellation_token().cancel();
177 if let Ok(final_state) = res {
178 if let Err(err) = final_state.teardown(handle).await {
179 tracing::error!("Error during teardown: {err:?}");
180 }
181 }
182 Ok(())
183 }
184 }
185
186 fn init(
190 self,
191 _handle: &GenServerHandle<Self>,
192 ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
193 async { Ok(Success(self)) }
194 }
195
196 fn main_loop(
197 mut self,
198 handle: &GenServerHandle<Self>,
199 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200 ) -> impl Future<Output = Self> + Send {
201 async {
202 loop {
203 if !self.receive(handle, rx).await {
204 break;
205 }
206 }
207 tracing::trace!("Stopping GenServer");
208 self
209 }
210 }
211
212 fn receive(
213 &mut self,
214 handle: &GenServerHandle<Self>,
215 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
216 ) -> impl Future<Output = bool> + Send {
217 async move {
218 let message = rx.recv().await;
219
220 let keep_running = match message {
221 Some(GenServerInMsg::Call { sender, message }) => {
222 let (keep_running, response) =
223 match AssertUnwindSafe(self.handle_call(message, handle))
224 .catch_unwind()
225 .await
226 {
227 Ok(response) => match response {
228 CallResponse::Reply(response) => (true, Ok(response)),
229 CallResponse::Stop(response) => (false, Ok(response)),
230 CallResponse::Unused => {
231 tracing::error!("GenServer received unexpected CallMessage");
232 (false, Err(GenServerError::CallMsgUnused))
233 }
234 },
235 Err(error) => {
236 tracing::error!("Error in callback: '{error:?}'");
237 (false, Err(GenServerError::Callback))
238 }
239 };
240 if sender.send(response).is_err() {
242 tracing::error!(
243 "GenServer failed to send response back, client must have died"
244 )
245 };
246 keep_running
247 }
248 Some(GenServerInMsg::Cast { message }) => {
249 match AssertUnwindSafe(self.handle_cast(message, handle))
250 .catch_unwind()
251 .await
252 {
253 Ok(response) => match response {
254 CastResponse::NoReply => true,
255 CastResponse::Stop => false,
256 CastResponse::Unused => {
257 tracing::error!("GenServer received unexpected CastMessage");
258 false
259 }
260 },
261 Err(error) => {
262 tracing::trace!("Error in callback: '{error:?}'");
263 false
264 }
265 }
266 }
267 None => {
268 false
270 }
271 };
272 keep_running
273 }
274 }
275
276 fn handle_call(
277 &mut self,
278 _message: Self::CallMsg,
279 _handle: &GenServerHandle<Self>,
280 ) -> impl Future<Output = CallResponse<Self>> + Send {
281 async { CallResponse::Unused }
282 }
283
284 fn handle_cast(
285 &mut self,
286 _message: Self::CastMsg,
287 _handle: &GenServerHandle<Self>,
288 ) -> impl Future<Output = CastResponse> + Send {
289 async { CastResponse::Unused }
290 }
291
292 fn teardown(
296 self,
297 _handle: &GenServerHandle<Self>,
298 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
299 async { Ok(()) }
300 }
301}
302
#[cfg(feature = "warn-on-block")]
mod warn_on_block {
    use super::*;

    use std::time::Instant;
    use tracing::warn;

    /// A single `poll` taking longer than this is reported as blocking.
    const BLOCKING_THRESHOLD: Duration = Duration::from_millis(10);

    // Future adapter that times every `poll` of the wrapped future and logs a
    // warning when one exceeds `BLOCKING_THRESHOLD`.
    pin_project_lite::pin_project! {
        pub struct WarnOnBlocking<F: Future>{
            #[pin]
            inner: F
        }
    }

    impl<F: Future> WarnOnBlocking<F> {
        /// Wraps `inner` without changing its output.
        pub fn new(inner: F) -> Self {
            Self { inner }
        }
    }

    impl<F: Future> Future for WarnOnBlocking<F> {
        type Output = F::Output;

        fn poll(
            self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Self::Output> {
            let this = self.project();
            let started = Instant::now();
            let res = this.inner.poll(cx);
            let elapsed = started.elapsed();
            if elapsed > BLOCKING_THRESHOLD {
                // Resolve the (type-name) label only on the rare slow path.
                let future = std::any::type_name::<F>();
                warn!(future = ?future, elapsed = ?elapsed, "Blocking operation detected");
            }
            res
        }
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its OS thread for two seconds.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocks the thread this task is running on.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts casts and, on each one, schedules another cast to
    /// itself with `send_after` (100 ms).
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => CallResponse::Reply(OutMsg::Count(self.count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.to_owned(), Unused);
            CastResponse::NoReply
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            // NOTE(review): timing-dependent; assumes the blocking cast on a
            // regular task disturbs the good server's 100 ms ticks.
            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Takes longer than the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(Unused)
                }
                SomeTaskCallMsg::FastOperation => {
                    // Finishes well within the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(Unused)
                }
            }
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(Unused)));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports a handled failure; its `teardown` closes a
    /// channel receiver so the test can observe that teardown ran.
    struct SomeTaskThatFailsOnInit {
        receiver: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            // Handled failure: skip the main loop but still run teardown.
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.receiver.lock().unwrap().close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            // `mpsc::channel` returns `(sender, receiver)`; the server owns the
            // receiver and closes it during teardown.
            let (tx, rx) = mpsc::channel::<u8>();
            let receiver = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(receiver).start();

            rt::sleep(Duration::from_secs(1)).await;

            // The sender observes the receiver closed by teardown.
            assert!(tx.is_closed())
        });
    }
}