spawned_concurrency/tasks/
gen_server.rs1use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::error::GenServerError;
8
/// Timeout applied by [`GenServerHandle::call`] when the caller does not choose one
/// (five seconds).
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_millis(5_000);
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13 pub tx: mpsc::Sender<GenServerInMsg<G>>,
14 cancellation_token: CancellationToken,
16}
17
18impl<G: GenServer> Clone for GenServerHandle<G> {
19 fn clone(&self) -> Self {
20 Self {
21 tx: self.tx.clone(),
22 cancellation_token: self.cancellation_token.clone(),
23 }
24 }
25}
26
27impl<G: GenServer> GenServerHandle<G> {
28 pub(crate) fn new(gen_server: G) -> Self {
29 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
30 let cancellation_token = CancellationToken::new();
31 let handle = GenServerHandle {
32 tx,
33 cancellation_token,
34 };
35 let handle_clone = handle.clone();
36 let _join_handle = rt::spawn(async move {
38 if gen_server.run(&handle, &mut rx).await.is_err() {
39 tracing::trace!("GenServer crashed")
40 };
41 });
42 handle_clone
43 }
44
45 pub(crate) fn new_blocking(gen_server: G) -> Self {
46 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
47 let cancellation_token = CancellationToken::new();
48 let handle = GenServerHandle {
49 tx,
50 cancellation_token,
51 };
52 let handle_clone = handle.clone();
53 let _join_handle = rt::spawn_blocking(|| {
55 rt::block_on(async move {
56 if gen_server.run(&handle, &mut rx).await.is_err() {
57 tracing::trace!("GenServer crashed")
58 };
59 })
60 });
61 handle_clone
62 }
63
64 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
65 self.tx.clone()
66 }
67
68 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
69 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
70 }
71
72 pub async fn call_with_timeout(
73 &mut self,
74 message: G::CallMsg,
75 duration: Duration,
76 ) -> Result<G::OutMsg, GenServerError> {
77 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
78 self.tx.send(GenServerInMsg::Call {
79 sender: oneshot_tx,
80 message,
81 })?;
82
83 match timeout(duration, oneshot_rx).await {
84 Ok(Ok(result)) => result,
85 Ok(Err(_)) => Err(GenServerError::Server),
86 Err(_) => Err(GenServerError::CallTimeout),
87 }
88 }
89
90 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
91 self.tx
92 .send(GenServerInMsg::Cast { message })
93 .map_err(|_error| GenServerError::Server)
94 }
95
96 pub fn cancellation_token(&self) -> CancellationToken {
97 self.cancellation_token.clone()
98 }
99}
100
101pub enum GenServerInMsg<G: GenServer> {
102 Call {
103 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
104 message: G::CallMsg,
105 },
106 Cast {
107 message: G::CastMsg,
108 },
109}
110
111pub enum CallResponse<G: GenServer> {
112 Reply(G, G::OutMsg),
113 Unused,
114 Stop(G::OutMsg),
115}
116
117pub enum CastResponse<G: GenServer> {
118 NoReply(G),
119 Unused,
120 Stop,
121}
122
123pub trait GenServer: Send + Sized + Clone {
124 type CallMsg: Clone + Send + Sized + Sync;
125 type CastMsg: Clone + Send + Sized + Sync;
126 type OutMsg: Send + Sized;
127 type Error: Debug + Send;
128
129 fn start(self) -> GenServerHandle<Self> {
130 GenServerHandle::new(self)
131 }
132
133 fn start_blocking(self) -> GenServerHandle<Self> {
139 GenServerHandle::new_blocking(self)
140 }
141
142 fn run(
143 self,
144 handle: &GenServerHandle<Self>,
145 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
146 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
147 async {
148 let init_result = self
149 .init(handle)
150 .await
151 .inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));
152
153 let res = match init_result {
154 Ok(new_state) => new_state.main_loop(handle, rx).await,
155 Err(_) => Err(GenServerError::Initialization),
156 };
157
158 handle.cancellation_token().cancel();
159 if let Ok(final_state) = res {
160 if let Err(err) = final_state.teardown(handle).await {
161 tracing::error!("Error during teardown: {err:?}");
162 }
163 }
164 Ok(())
165 }
166 }
167
168 fn init(
172 self,
173 _handle: &GenServerHandle<Self>,
174 ) -> impl Future<Output = Result<Self, Self::Error>> + Send {
175 async { Ok(self) }
176 }
177
178 fn main_loop(
179 mut self,
180 handle: &GenServerHandle<Self>,
181 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
182 ) -> impl Future<Output = Result<Self, GenServerError>> + Send {
183 async {
184 loop {
185 let (new_state, cont) = self.receive(handle, rx).await?;
186 self = new_state;
187 if !cont {
188 break;
189 }
190 }
191 tracing::trace!("Stopping GenServer");
192 Ok(self)
193 }
194 }
195
196 fn receive(
197 self,
198 handle: &GenServerHandle<Self>,
199 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200 ) -> impl Future<Output = Result<(Self, bool), GenServerError>> + Send {
201 async move {
202 let message = rx.recv().await;
203
204 let state_clone = self.clone();
206
207 let (keep_running, new_state) = match message {
208 Some(GenServerInMsg::Call { sender, message }) => {
209 let (keep_running, new_state, response) =
210 match AssertUnwindSafe(self.handle_call(message, handle))
211 .catch_unwind()
212 .await
213 {
214 Ok(response) => match response {
215 CallResponse::Reply(new_state, response) => {
216 (true, new_state, Ok(response))
217 }
218 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
219 CallResponse::Unused => {
220 tracing::error!("GenServer received unexpected CallMessage");
221 (false, state_clone, Err(GenServerError::CallMsgUnused))
222 }
223 },
224 Err(error) => {
225 tracing::error!(
226 "Error in callback, reverting state - Error: '{error:?}'"
227 );
228 (true, state_clone, Err(GenServerError::Callback))
229 }
230 };
231 if sender.send(response).is_err() {
233 tracing::error!(
234 "GenServer failed to send response back, client must have died"
235 )
236 };
237 (keep_running, new_state)
238 }
239 Some(GenServerInMsg::Cast { message }) => {
240 match AssertUnwindSafe(self.handle_cast(message, handle))
241 .catch_unwind()
242 .await
243 {
244 Ok(response) => match response {
245 CastResponse::NoReply(new_state) => (true, new_state),
246 CastResponse::Stop => (false, state_clone),
247 CastResponse::Unused => {
248 tracing::error!("GenServer received unexpected CastMessage");
249 (false, state_clone)
250 }
251 },
252 Err(error) => {
253 tracing::trace!(
254 "Error in callback, reverting state - Error: '{error:?}'"
255 );
256 (true, state_clone)
257 }
258 }
259 }
260 None => {
261 (false, self)
263 }
264 };
265 Ok((new_state, keep_running))
266 }
267 }
268
269 fn handle_call(
270 self,
271 _message: Self::CallMsg,
272 _handle: &GenServerHandle<Self>,
273 ) -> impl Future<Output = CallResponse<Self>> + Send {
274 async { CallResponse::Unused }
275 }
276
277 fn handle_cast(
278 self,
279 _message: Self::CastMsg,
280 _handle: &GenServerHandle<Self>,
281 ) -> impl Future<Output = CastResponse<Self>> + Send {
282 async { CastResponse::Unused }
283 }
284
285 fn teardown(
289 self,
290 _handle: &GenServerHandle<Self>,
291 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
292 async { Ok(()) }
293 }
294}
295
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its OS thread with `thread::sleep`.
    /// Used to check whether other servers keep making progress meanwhile.
    #[derive(Clone)]
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocking: holds the executing thread for 2 s instead
            // of yielding to the runtime.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts the casts it receives. After each one it schedules
    /// another cast to itself 100 ms later, so the count grows steadily while
    /// it is able to run.
    #[derive(Clone)]
    struct WellBehavedTask {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                InMessage::GetCount => {
                    let count = self.count;
                    CallResponse::Reply(self, OutMsg::Count(count))
                }
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(self.count)),
            }
        }

        async fn handle_cast(
            mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
        ) -> CastResponse<Self> {
            self.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            // Presumably delivers `()` as a cast to this same server after
            // 100 ms (see `crate::tasks::send_after`), re-arming the counter.
            send_after(Duration::from_millis(100), handle.to_owned(), ());
            CastResponse::NoReply(self)
        }
    }

    /// Both servers are started with `start()`, so they share the async
    /// runtime. The bad server's blocking sleep is expected to delay the good
    /// one, so its count should not be exactly 10 after about one second.
    /// Timing-sensitive.
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start();
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_ne!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// The bad server is started with `start_blocking()`, so its blocking
    /// sleep runs off the async workers. The good server should then tick
    /// roughly every 100 ms, reaching 10 after about one second.
    /// Timing-sensitive: the 11th tick races the `GetCount` call.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask.start_blocking();
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask { count: 0 }.start();
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let count = goodboy.call(InMessage::GetCount).await.unwrap();

            match count {
                OutMsg::Count(num) => {
                    assert_eq!(num, 10);
                }
            }
            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    // Timeout used by `unresolving_task_times_out`. The slow operation takes
    // twice this long and the fast one half as long.
    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    #[derive(Debug, Default, Clone)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = ();
        type OutMsg = ();
        type Error = ();

        async fn handle_call(
            self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
        ) -> CallResponse<Self> {
            match message {
                SomeTaskCallMsg::SlowOperation => {
                    // Exceeds the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION * 2).await;
                    CallResponse::Reply(self, ())
                }
                SomeTaskCallMsg::FastOperation => {
                    // Finishes well within the caller's timeout.
                    rt::sleep(TIMEOUT_DURATION / 2).await;
                    CallResponse::Reply(self, ())
                }
            }
        }
    }

    /// `call_with_timeout` returns the reply for a fast handler and
    /// `GenServerError::CallTimeout` for one slower than the timeout.
    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask.start();

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Ok(())));

            let result = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(result, Err(GenServerError::CallTimeout)));
        });
    }
}
472}