spawned_concurrency/tasks/
gen_server.rs1use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::{
8 error::GenServerError,
9 tasks::InitResult::{NoSuccess, Success},
10};
11
/// Timeout used by `GenServerHandle::call`, which does not take an explicit one.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
13
14#[derive(Debug)]
15pub struct GenServerHandle<G: GenServer + 'static> {
16 pub tx: mpsc::Sender<GenServerInMsg<G>>,
17 cancellation_token: CancellationToken,
19}
20
21impl<G: GenServer> Clone for GenServerHandle<G> {
22 fn clone(&self) -> Self {
23 Self {
24 tx: self.tx.clone(),
25 cancellation_token: self.cancellation_token.clone(),
26 }
27 }
28}
29
30impl<G: GenServer> GenServerHandle<G> {
31 pub(crate) fn new(gen_server: G) -> Self {
32 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
33 let cancellation_token = CancellationToken::new();
34 let handle = GenServerHandle {
35 tx,
36 cancellation_token,
37 };
38 let handle_clone = handle.clone();
39 let _join_handle = rt::spawn(async move {
41 if gen_server.run(&handle, &mut rx).await.is_err() {
42 tracing::trace!("GenServer crashed")
43 };
44 });
45 handle_clone
46 }
47
48 pub(crate) fn new_blocking(gen_server: G) -> Self {
49 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
50 let cancellation_token = CancellationToken::new();
51 let handle = GenServerHandle {
52 tx,
53 cancellation_token,
54 };
55 let handle_clone = handle.clone();
56 let _join_handle = rt::spawn_blocking(|| {
58 rt::block_on(async move {
59 if gen_server.run(&handle, &mut rx).await.is_err() {
60 tracing::trace!("GenServer crashed")
61 };
62 })
63 });
64 handle_clone
65 }
66
67 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
68 self.tx.clone()
69 }
70
71 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
72 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
73 }
74
75 pub async fn call_with_timeout(
76 &mut self,
77 message: G::CallMsg,
78 duration: Duration,
79 ) -> Result<G::OutMsg, GenServerError> {
80 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
81 self.tx.send(GenServerInMsg::Call {
82 sender: oneshot_tx,
83 message,
84 })?;
85
86 match timeout(duration, oneshot_rx).await {
87 Ok(Ok(result)) => result,
88 Ok(Err(_)) => Err(GenServerError::Server),
89 Err(_) => Err(GenServerError::CallTimeout),
90 }
91 }
92
93 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
94 self.tx
95 .send(GenServerInMsg::Cast { message })
96 .map_err(|_error| GenServerError::Server)
97 }
98
99 pub fn cancellation_token(&self) -> CancellationToken {
100 self.cancellation_token.clone()
101 }
102}
103
104pub enum GenServerInMsg<G: GenServer> {
105 Call {
106 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
107 message: G::CallMsg,
108 },
109 Cast {
110 message: G::CastMsg,
111 },
112}
113
114pub enum CallResponse<G: GenServer> {
115 Reply(G::OutMsg),
116 Unused,
117 Stop(G::OutMsg),
118}
119
/// What `GenServer::handle_cast` tells the server loop to do next.
pub enum CastResponse {
    /// Keep the server running.
    NoReply,
    /// The cast was not handled: the loop logs an error and stops.
    Unused,
    /// Stop the server.
    Stop,
}
125
126pub enum InitResult<G: GenServer> {
127 Success(G),
128 NoSuccess(G),
129}
130
131pub trait GenServer: Send + Sized {
132 type CallMsg: Clone + Send + Sized + Sync;
133 type CastMsg: Clone + Send + Sized + Sync;
134 type OutMsg: Send + Sized;
135 type Error: Debug + Send;
136
137 fn start(self) -> GenServerHandle<Self> {
138 GenServerHandle::new(self)
139 }
140
141 fn start_blocking(self) -> GenServerHandle<Self> {
147 GenServerHandle::new_blocking(self)
148 }
149
150 fn run(
151 self,
152 handle: &GenServerHandle<Self>,
153 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
154 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
155 async {
156 let res = match self.init(handle).await {
157 Ok(Success(new_state)) => Ok(new_state.main_loop(handle, rx).await),
158 Ok(NoSuccess(intermediate_state)) => {
159 Ok(intermediate_state)
163 }
164 Err(err) => {
165 tracing::error!("Initialization failed with unhandled error: {err:?}");
166 Err(GenServerError::Initialization)
167 }
168 };
169
170 handle.cancellation_token().cancel();
171 if let Ok(final_state) = res {
172 if let Err(err) = final_state.teardown(handle).await {
173 tracing::error!("Error during teardown: {err:?}");
174 }
175 }
176 Ok(())
177 }
178 }
179
180 fn init(
184 self,
185 _handle: &GenServerHandle<Self>,
186 ) -> impl Future<Output = Result<InitResult<Self>, Self::Error>> + Send {
187 async { Ok(Success(self)) }
188 }
189
190 fn main_loop(
191 mut self,
192 handle: &GenServerHandle<Self>,
193 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
194 ) -> impl Future<Output = Self> + Send {
195 async {
196 loop {
197 if !self.receive(handle, rx).await {
198 break;
199 }
200 }
201 tracing::trace!("Stopping GenServer");
202 self
203 }
204 }
205
206 fn receive(
207 &mut self,
208 handle: &GenServerHandle<Self>,
209 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
210 ) -> impl Future<Output = bool> + Send {
211 async move {
212 let message = rx.recv().await;
213
214 let keep_running = match message {
215 Some(GenServerInMsg::Call { sender, message }) => {
216 let (keep_running, response) =
217 match AssertUnwindSafe(self.handle_call(message, handle))
218 .catch_unwind()
219 .await
220 {
221 Ok(response) => match response {
222 CallResponse::Reply(response) => (true, Ok(response)),
223 CallResponse::Stop(response) => (false, Ok(response)),
224 CallResponse::Unused => {
225 tracing::error!("GenServer received unexpected CallMessage");
226 (false, Err(GenServerError::CallMsgUnused))
227 }
228 },
229 Err(error) => {
230 tracing::error!("Error in callback: '{error:?}'");
231 (false, Err(GenServerError::Callback))
232 }
233 };
234 if sender.send(response).is_err() {
236 tracing::error!(
237 "GenServer failed to send response back, client must have died"
238 )
239 };
240 keep_running
241 }
242 Some(GenServerInMsg::Cast { message }) => {
243 match AssertUnwindSafe(self.handle_cast(message, handle))
244 .catch_unwind()
245 .await
246 {
247 Ok(response) => match response {
248 CastResponse::NoReply => true,
249 CastResponse::Stop => false,
250 CastResponse::Unused => {
251 tracing::error!("GenServer received unexpected CastMessage");
252 false
253 }
254 },
255 Err(error) => {
256 tracing::trace!("Error in callback: '{error:?}'");
257 false
258 }
259 }
260 }
261 None => {
262 false
264 }
265 };
266 keep_running
267 }
268 }
269
270 fn handle_call(
271 &mut self,
272 _message: Self::CallMsg,
273 _handle: &GenServerHandle<Self>,
274 ) -> impl Future<Output = CallResponse<Self>> + Send {
275 async { CallResponse::Unused }
276 }
277
278 fn handle_cast(
279 &mut self,
280 _message: Self::CastMsg,
281 _handle: &GenServerHandle<Self>,
282 ) -> impl Future<Output = CastResponse> + Send {
283 async { CastResponse::Unused }
284 }
285
286 fn teardown(
290 self,
291 _handle: &GenServerHandle<Self>,
292 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
293 async { Ok(()) }
294 }
295}
296
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{messages::Unused, tasks::send_after};
    use std::{
        sync::{Arc, Mutex},
        thread,
        time::Duration,
    };

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    struct BadlyBehavedTask;

    /// Call messages understood by the test servers below.
    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    /// Reply produced by `WellBehavedTask`.
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(Unused)
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse {
            // Yield once, then deliberately block the executing thread.
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts the casts it receives and re-schedules one every
    /// 100 ms through `send_after`.
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = Unused;
        type OutMsg = OutMsg;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let current = OutMsg::Count(self.count);
            match message {
                InMessage::GetCount => CallResponse::Reply(current),
                InMessage::Stop => CallResponse::Stop(current),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), Unused);
            CastResponse::NoReply
        }
    }

    /// With the badly behaved server started through `start`, the well
    /// behaved one is expected NOT to reach a count of 10 within a second.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// With the badly behaved server started through `start_blocking`, the
    /// well behaved one is expected to reach a count of exactly 10 within a
    /// second.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(Unused).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(Unused).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// Call timeout used by `unresolving_task_times_out`.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server with one call that finishes within `TIMEOUT_DURATION` and one
    /// that does not.
    #[derive(Debug, Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Twice the timeout: the caller gives up first.
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(Unused)
                }
                SomeTaskCallMsg::FastOperation => {
                    // Half the timeout: the reply arrives in time.
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(Unused)
                }
            }
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let fast = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(Unused)));

            let slow = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }

    /// Server whose `init` reports `NoSuccess` and whose `teardown` closes the
    /// receiver it was given, so the test can observe that teardown ran.
    struct SomeTaskThatFailsOnInit {
        receiver: Arc<Mutex<mpsc::Receiver<u8>>>,
    }

    impl SomeTaskThatFailsOnInit {
        pub fn new(receiver: Arc<Mutex<mpsc::Receiver<u8>>>) -> Self {
            Self { receiver }
        }
    }

    impl GenServer for SomeTaskThatFailsOnInit {
        type CallMsg = Unused;
        type CastMsg = Unused;
        type OutMsg = Unused;
        type Error = Unused;

        async fn init(
            self,
            _handle: &GenServerHandle<Self>,
        ) -> Result<InitResult<Self>, Self::Error> {
            Ok(NoSuccess(self))
        }

        async fn teardown(self, _handle: &GenServerHandle<Self>) -> Result<(), Self::Error> {
            self.receiver.lock().unwrap().close();
            Ok(())
        }
    }

    #[test]
    pub fn task_fails_with_intermediate_state() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let (tx, rx) = mpsc::channel::<u8>();
            let receiver = Arc::new(Mutex::new(rx));
            let _task = SomeTaskThatFailsOnInit::new(receiver).start();

            // Give the spawned server time to run `init` and `teardown`.
            rt::sleep(Duration::from_secs(1)).await;

            // `teardown` closed the receiver, which the sender can observe.
            assert!(tx.is_closed())
        });
    }
}