spawned_concurrency/tasks/
gen_server.rs1use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::{
8 error::GenServerError,
9 tasks::InitResult::{NoSuccess, Success},
10};
11
/// How long [`GenServerHandle::call`] waits for a reply before giving up (five seconds).
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
13
14#[derive(Debug)]
15pub struct GenServerHandle<G: GenServer + 'static> {
16 pub tx: mpsc::Sender<GenServerInMsg<G>>,
17 cancellation_token: CancellationToken,
19}
20
21impl<G: GenServer> Clone for GenServerHandle<G> {
22 fn clone(&self) -> Self {
23 Self {
24 tx: self.tx.clone(),
25 cancellation_token: self.cancellation_token.clone(),
26 }
27 }
28}
29
30impl<G: GenServer> GenServerHandle<G> {
31 pub(crate) fn new(gen_server: G) -> Self {
32 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
33 let cancellation_token = CancellationToken::new();
34 let handle = GenServerHandle {
35 tx,
36 cancellation_token,
37 };
38 let handle_clone = handle.clone();
39 let _join_handle = rt::spawn(async move {
41 if gen_server.run(&handle, &mut rx).await.is_err() {
42 tracing::trace!("GenServer crashed")
43 };
44 });
45 handle_clone
46 }
47
48 pub(crate) fn new_blocking(gen_server: G) -> Self {
49 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
50 let cancellation_token = CancellationToken::new();
51 let handle = GenServerHandle {
52 tx,
53 cancellation_token,
54 };
55 let handle_clone = handle.clone();
56 let _join_handle = rt::spawn_blocking(|| {
58 rt::block_on(async move {
59 if gen_server.run(&handle, &mut rx).await.is_err() {
60 tracing::trace!("GenServer crashed")
61 };
62 })
63 });
64 handle_clone
65 }
66
67 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
68 self.tx.clone()
69 }
70
71 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
72 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
73 }
74
75 pub async fn call_with_timeout(
76 &mut self,
77 message: G::CallMsg,
78 duration: Duration,
79 ) -> Result<G::OutMsg, GenServerError> {
80 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
81 self.tx.send(GenServerInMsg::Call {
82 sender: oneshot_tx,
83 message,
84 })?;
85
86 match timeout(duration, oneshot_rx).await {
87 Ok(Ok(result)) => result,
88 Ok(Err(_)) => Err(GenServerError::Server),
89 Err(_) => Err(GenServerError::CallTimeout),
90 }
91 }
92
93 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
94 self.tx
95 .send(GenServerInMsg::Cast { message })
96 .map_err(|_error| GenServerError::Server)
97 }
98
99 pub fn cancellation_token(&self) -> CancellationToken {
100 self.cancellation_token.clone()
101 }
102}
103
104pub enum GenServerInMsg<G: GenServer> {
105 Call {
106 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
107 message: G::CallMsg,
108 },
109 Cast {
110 message: G::CastMsg,
111 },
112}
113
114pub enum CallResponse<G: GenServer> {
115 Reply(G, G::OutMsg),
116 Unused,
117 Stop(G::OutMsg),
118}
119
120pub enum CastResponse<G: GenServer> {
121 NoReply(G),
122 Unused,
123 Stop,
124}
125
126pub enum InitResult<G: GenServer> {
127 Success(G),
128 NoSuccess(G),
129}
130
131pub trait GenServer: Send + Sized + Clone {
132 type CallMsg: Clone + Send + Sized + Sync;
133 type CastMsg: Clone + Send + Sized + Sync;
134 type OutMsg: Send + Sized;
135 type Error: Debug + Send;
136
137 fn start(self) -> GenServerHandle<Self> {
138 GenServerHandle::new(self)
139 }
140
141 fn start_blocking(self) -> GenServerHandle<Self> {
147 GenServerHandle::new_blocking(self)
148 }
149
150 fn run(
151 self,
152 handle: &GenServerHandle<Self>,
153 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
154 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
155 async {
156 let res = match self.init(handle).await {
157 Ok(Success(new_state)) => new_state.main_loop(handle, rx).await,
158 Ok(NoSuccess(intermediate_state)) => {
159 Ok(intermediate_state)
163 }
164 Err(err) => {
165 tracing::error!("Initialization failed with unhandled error: {err:?}");
166 Err(GenServerError::Initialization)
167 }
168 };
169
170 handle.cancellation_token().cancel();
171 if let Ok(final_state) = res {
172 if let Err(err) = final_state.teardown(handle).await {
173 tracing::error!("Error during teardown: {err:?}");
174 }
175 }
176 Ok(())
177 }
178 }
179
180 fn init(
184 self,
185 _handle: &GenServerHandle<Self>,
186 ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
187 async { Ok(Success(self)) }
188 }
189
190 fn main_loop(
191 mut self,
192 handle: &GenServerHandle<Self>,
193 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
194 ) -> impl Future<Output = Result<Self, GenServerError>> + Send {
195 async {
196 loop {
197 let (new_state, cont) = self.receive(handle, rx).await?;
198 self = new_state;
199 if !cont {
200 break;
201 }
202 }
203 tracing::trace!("Stopping GenServer");
204 Ok(self)
205 }
206 }
207
208 fn receive(
209 self,
210 handle: &GenServerHandle<Self>,
211 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
212 ) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
213 async move {
214 let message = rx.recv().await;
215
216 let state_clone = self.clone();
218
219 let (keep_running, new_state) = match message {
220 Some(GenServerInMsg::Call { sender, message }) => {
221 let (keep_running, new_state, response) =
222 match AssertUnwindSafe(self.handle_call(message, handle))
223 .catch_unwind()
224 .await
225 {
226 Ok(response) => match response {
227 CallResponse::Reply(new_state, response) => {
228 (true, new_state, Ok(response))
229 }
230 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
231 CallResponse::Unused => {
232 tracing::error!("GenServer received unexpected CallMessage");
233 (false, state_clone, Err(GenServerError::CallMsgUnused))
234 }
235 },
236 Err(error) => {
237 tracing::error!(
238 "Error in callback, reverting state - Error: '{error:?}'"
239 );
240 (true, state_clone, Err(GenServerError::Callback))
241 }
242 };
243 if sender.send(response).is_err() {
245 tracing::error!(
246 "GenServer failed to send response back, client must have died"
247 )
248 };
249 (keep_running, new_state)
250 }
251 Some(GenServerInMsg::Cast { message }) => {
252 match AssertUnwindSafe(self.handle_cast(message, handle))
253 .catch_unwind()
254 .await
255 {
256 Ok(response) => match response {
257 CastResponse::NoReply(new_state) => (true, new_state),
258 CastResponse::Stop => (false, state_clone),
259 CastResponse::Unused => {
260 tracing::error!("GenServer received unexpected CastMessage");
261 (false, state_clone)
262 }
263 },
264 Err(error) => {
265 tracing::trace!(
266 "Error in callback, reverting state - Error: '{error:?}'"
267 );
268 (true, state_clone)
269 }
270 }
271 }
272 None => {
273 (false, self)
275 }
276 };
277 Ok((new_state, keep_running))
278 }
279 }
280
281 fn handle_call(
282 self,
283 _message: Self::CallMsg,
284 _handle: &GenServerHandle<Self>,
285 ) -> impl Future<Output = CallResponse<Self>> + Send {
286 async { CallResponse::Unused }
287 }
288
289 fn handle_cast(
290 self,
291 _message: Self::CastMsg,
292 _handle: &GenServerHandle<Self>,
293 ) -> impl Future<Output = CastResponse<Self>> + Send {
294 async { CastResponse::Unused }
295 }
296
297 fn teardown(
301 self,
302 _handle: &GenServerHandle<Self>,
303 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
304 async { Ok(()) }
305 }
306}
307
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its OS thread for two seconds.
    #[derive(Clone)]
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocks the thread (not the task) to starve co-scheduled tasks.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that increments `count` on every cast and re-casts to itself every 100 ms.
    #[derive(Clone)]
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => {
                    let count = self.count;
                    CallResponse::Reply(self, OutMsg::Count(count))
                }
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
            }
        }

        async fn handle_cast(
            mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.to_owned(), Unused);
            CastResponse::NoReply(self)
        }
    }

    /// With both servers on async tasks, the blocking server starves the well-behaved one,
    /// so it cannot reach ten ticks within one second.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// With the blocking server isolated via `start_blocking`, the well-behaved server
    /// ticks every 100 ms and reaches ten within one second.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server with one call that finishes within `TIMEOUT_DURATION` and one that does not.
    #[derive(Debug, Default, Clone)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(self, Unused)
                }
                SomeTaskCallMsg::FastOperation => {
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(self, Unused)
                }
            }
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(Unused)));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` returns `NoSuccess`; its teardown closes the channel receiver it
    /// owns, so the test can observe that teardown ran.
    #[derive(Clone)]
    struct SomeTaskThatFailsOnInit {
        receiver: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.receiver.lock().unwrap().close();
            Ok(())
        }
    }

    /// A `NoSuccess` init must still run teardown.
    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            // `channel` returns (sender, receiver); the server owns the receiver.
            let (tx, rx) = mpsc::channel::<u8>();
            let receiver = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(receiver).start();

            rt::sleep(Duration::from_secs(1)).await;

            // Teardown closed the receiver, which the sending half observes as closed.
            assert!(tx.is_closed())
        });
    }
}