spawned_concurrency/tasks/
gen_server.rs1use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::error::GenServerError;
8
/// How long [`GenServerHandle::call`] waits for a reply before giving up
/// with `GenServerError::CallTimeout` (five seconds).
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13 pub tx: mpsc::Sender<GenServerInMsg<G>>,
14 cancellation_token: CancellationToken,
16}
17
18impl<G: GenServer> Clone for GenServerHandle<G> {
19 fn clone(&self) -> Self {
20 Self {
21 tx: self.tx.clone(),
22 cancellation_token: self.cancellation_token.clone(),
23 }
24 }
25}
26
27impl<G: GenServer> GenServerHandle<G> {
28 pub(crate) fn new(gen_server: G) -> Self {
29 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
30 let cancellation_token = CancellationToken::new();
31 let handle = GenServerHandle {
32 tx,
33 cancellation_token,
34 };
35 let handle_clone = handle.clone();
36 let _join_handle = rt::spawn(async move {
38 if gen_server.run(&handle, &mut rx).await.is_err() {
39 tracing::trace!("GenServer crashed")
40 };
41 });
42 handle_clone
43 }
44
45 pub(crate) fn new_blocking(gen_server: G) -> Self {
46 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
47 let cancellation_token = CancellationToken::new();
48 let handle = GenServerHandle {
49 tx,
50 cancellation_token,
51 };
52 let handle_clone = handle.clone();
53 let _join_handle = rt::spawn_blocking(|| {
55 rt::block_on(async move {
56 if gen_server.run(&handle, &mut rx).await.is_err() {
57 tracing::trace!("GenServer crashed")
58 };
59 })
60 });
61 handle_clone
62 }
63
64 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
65 self.tx.clone()
66 }
67
68 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
69 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
70 }
71
72 pub async fn call_with_timeout(
73 &mut self,
74 message: G::CallMsg,
75 duration: Duration,
76 ) -> Result<G::OutMsg, GenServerError> {
77 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
78 self.tx.send(GenServerInMsg::Call {
79 sender: oneshot_tx,
80 message,
81 })?;
82
83 match timeout(duration, oneshot_rx).await {
84 Ok(Ok(result)) => result,
85 Ok(Err(_)) => Err(GenServerError::Server),
86 Err(_) => Err(GenServerError::CallTimeout),
87 }
88 }
89
90 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
91 self.tx
92 .send(GenServerInMsg::Cast { message })
93 .map_err(|_error| GenServerError::Server)
94 }
95
96 pub fn cancellation_token(&self) -> CancellationToken {
97 self.cancellation_token.clone()
98 }
99}
100
101pub enum GenServerInMsg<G: GenServer> {
102 Call {
103 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
104 message: G::CallMsg,
105 },
106 Cast {
107 message: G::CastMsg,
108 },
109}
110
111pub enum CallResponse<G: GenServer> {
112 Reply(G, G::OutMsg),
113 Unused,
114 Stop(G::OutMsg),
115}
116
117pub enum CastResponse<G: GenServer> {
118 NoReply(G),
119 Unused,
120 Stop,
121}
122
123pub trait GenServer: Send + Sized + Clone {
124 type CallMsg: Clone + Send + Sized + Sync;
125 type CastMsg: Clone + Send + Sized + Sync;
126 type OutMsg: Send + Sized;
127 type Error: Debug + Send;
128
129 fn start(self) -> GenServerHandle<Self> {
130 GenServerHandle::new(self)
131 }
132
133 fn start_blocking(self) -> GenServerHandle<Self> {
139 GenServerHandle::new_blocking(self)
140 }
141
142 fn run(
143 self,
144 handle: &GenServerHandle<Self>,
145 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
146 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
147 async {
148 let init_result = self
149 .clone()
150 .init(handle)
151 .await
152 .inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));
153
154 let res = match init_result {
155 Ok(new_state) => new_state.main_loop(handle, rx).await,
156 Err(_) => Err(GenServerError::Initialization),
157 };
158
159 handle.cancellation_token().cancel();
160 if let Err(err) = self.teardown(handle).await {
161 tracing::error!("Error during teardown: {err:?}");
162 }
163 res
164 }
165 }
166
167 fn init(
171 self,
172 _handle: &GenServerHandle<Self>,
173 ) -> impl Future<Output = Result<Self, Self::Error>> + Send {
174 async { Ok(self) }
175 }
176
177 fn main_loop(
178 mut self,
179 handle: &GenServerHandle<Self>,
180 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
181 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
182 async {
183 loop {
184 let (new_state, cont) = self.receive(handle, rx).await?;
185 self = new_state;
186 if !cont {
187 break;
188 }
189 }
190 tracing::trace!("Stopping GenServer");
191 Ok(())
192 }
193 }
194
195 fn receive(
196 self,
197 handle: &GenServerHandle<Self>,
198 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
199 ) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
200 async move {
201 let message = rx.recv().await;
202
203 let state_clone = self.clone();
205
206 let (keep_running, new_state) = match message {
207 Some(GenServerInMsg::Call { sender, message }) => {
208 let (keep_running, new_state, response) =
209 match AssertUnwindSafe(self.handle_call(message, handle))
210 .catch_unwind()
211 .await
212 {
213 Ok(response) => match response {
214 CallResponse::Reply(new_state, response) => {
215 (true, new_state, Ok(response))
216 }
217 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
218 CallResponse::Unused => {
219 tracing::error!("GenServer received unexpected CallMessage");
220 (false, state_clone, Err(GenServerError::CallMsgUnused))
221 }
222 },
223 Err(error) => {
224 tracing::error!(
225 "Error in callback, reverting state - Error: '{error:?}'"
226 );
227 (true, state_clone, Err(GenServerError::Callback))
228 }
229 };
230 if sender.send(response).is_err() {
232 tracing::error!(
233 "GenServer failed to send response back, client must have died"
234 )
235 };
236 (keep_running, new_state)
237 }
238 Some(GenServerInMsg::Cast { message }) => {
239 match AssertUnwindSafe(self.handle_cast(message, handle))
240 .catch_unwind()
241 .await
242 {
243 Ok(response) => match response {
244 CastResponse::NoReply(new_state) => (true, new_state),
245 CastResponse::Stop => (false, state_clone),
246 CastResponse::Unused => {
247 tracing::error!("GenServer received unexpected CastMessage");
248 (false, state_clone)
249 }
250 },
251 Err(error) => {
252 tracing::trace!(
253 "Error in callback, reverting state - Error: '{error:?}'"
254 );
255 (true, state_clone)
256 }
257 }
258 }
259 None => {
260 (false, self)
262 }
263 };
264 Ok((new_state, keep_running))
265 }
266 }
267
268 fn handle_call(
269 self,
270 _message: Self::CallMsg,
271 _handle: &GenServerHandle<Self>,
272 ) -> impl Future<Output = CallResponse<Self>> + Send {
273 async { CallResponse::Unused }
274 }
275
276 fn handle_cast(
277 self,
278 _message: Self::CastMsg,
279 _handle: &GenServerHandle<Self>,
280 ) -> impl Future<Output = CastResponse<Self>> + Send {
281 async { CastResponse::Unused }
282 }
283
284 fn teardown(
288 self,
289 _handle: &GenServerHandle<Self>,
290 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
291 async { Ok(()) }
292 }
293}
294
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its OS thread with `thread::sleep`.
    #[derive(Clone)]
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocks the whole thread, not just this task.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that re-casts to itself every 100ms, counting each tick.
    #[derive(Clone)]
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            let current = OutMsg::Count(self.count);
            match message {
                InMessage::GetCount => CallResponse::Reply(self, current),
                InMessage::Stop => CallResponse::Stop(current),
            }
        }

        async fn handle_cast(
            self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            let bumped = WellBehavedTask {
                count: self.count + 1,
            };
            println!("{:?}: good still alive", thread::current().id());
            // Schedule the next tick.
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(bumped)
        }
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            // Misbehaving server started as a regular async task.
            let mut bad_server = BadlyBehavedTask.start();
            let _ = bad_server.cast(()).await;
            let mut good_server = WellBehavedTask { count: 0 }.start();
            let _ = good_server.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let reply = good_server.call(InMessage::GetCount).await.unwrap();

            let OutMsg::Count(ticks) = reply;
            assert_ne!(ticks, 10);

            good_server.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            // Misbehaving server isolated on a blocking thread.
            let mut bad_server = BadlyBehavedTask.start_blocking();
            let _ = bad_server.cast(()).await;
            let mut good_server = WellBehavedTask { count: 0 }.start();
            let _ = good_server.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let reply = good_server.call(InMessage::GetCount).await.unwrap();

            let OutMsg::Count(ticks) = reply;
            assert_eq!(ticks, 10);

            good_server.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default, Clone)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            // Slow exceeds the caller's timeout; fast stays well within it.
            let delay = match message {
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply(self, ())
        }
    }

    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut server = SomeTask.start();

            let fast = server
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(())));

            let slow = server
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }
}