1use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot, timeout, CancellationToken};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe, time::Duration};
6
7use crate::error::GenServerError;
8
/// Timeout used by [`GenServerHandle::call`], which does not take one from the caller.
const DEFAULT_CALL_TIMEOUT: Duration = Duration::from_secs(5);
10
/// Client-side handle to a [`GenServer`].
///
/// It bundles the sending half of the server's message channel with a
/// [`CancellationToken`]; the default `GenServer::run` cancels that token once
/// the server's main loop has finished.
#[derive(Debug)]
pub struct GenServerHandle<G: GenServer + 'static> {
    /// Sending half of the channel that carries [`GenServerInMsg`] values to the server.
    pub tx: mpsc::Sender<GenServerInMsg<G>>,
    /// Token shared by every clone of this handle; see [`GenServerHandle::cancellation_token`].
    cancellation_token: CancellationToken,
}
17
18impl<G: GenServer> Clone for GenServerHandle<G> {
19 fn clone(&self) -> Self {
20 Self {
21 tx: self.tx.clone(),
22 cancellation_token: self.cancellation_token.clone(),
23 }
24 }
25}
26
27impl<G: GenServer> GenServerHandle<G> {
28 pub(crate) fn new(initial_state: G::State) -> Self {
29 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
30 let cancellation_token = CancellationToken::new();
31 let handle = GenServerHandle {
32 tx,
33 cancellation_token,
34 };
35 let mut gen_server: G = GenServer::new();
36 let handle_clone = handle.clone();
37 let _join_handle = rt::spawn(async move {
39 if gen_server
40 .run(&handle, &mut rx, initial_state)
41 .await
42 .is_err()
43 {
44 tracing::trace!("GenServer crashed")
45 };
46 });
47 handle_clone
48 }
49
50 pub(crate) fn new_blocking(initial_state: G::State) -> Self {
51 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
52 let cancellation_token = CancellationToken::new();
53 let handle = GenServerHandle {
54 tx,
55 cancellation_token,
56 };
57 let mut gen_server: G = GenServer::new();
58 let handle_clone = handle.clone();
59 let _join_handle = rt::spawn_blocking(|| {
61 rt::block_on(async move {
62 if gen_server
63 .run(&handle, &mut rx, initial_state)
64 .await
65 .is_err()
66 {
67 tracing::trace!("GenServer crashed")
68 };
69 })
70 });
71 handle_clone
72 }
73
74 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
75 self.tx.clone()
76 }
77
78 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
79 self.call_with_timeout(message, DEFAULT_CALL_TIMEOUT).await
80 }
81
82 pub async fn call_with_timeout(
83 &mut self,
84 message: G::CallMsg,
85 duration: Duration,
86 ) -> Result<G::OutMsg, GenServerError> {
87 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
88 self.tx.send(GenServerInMsg::Call {
89 sender: oneshot_tx,
90 message,
91 })?;
92
93 match timeout(duration, oneshot_rx).await {
94 Ok(Ok(result)) => result,
95 Ok(Err(_)) => Err(GenServerError::Server),
96 Err(_) => Err(GenServerError::CallTimeout),
97 }
98 }
99
100 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
101 self.tx
102 .send(GenServerInMsg::Cast { message })
103 .map_err(|_error| GenServerError::Server)
104 }
105
106 pub fn cancellation_token(&self) -> CancellationToken {
107 self.cancellation_token.clone()
108 }
109}
110
/// Envelope for every message delivered to a [`GenServer`] through its channel.
pub enum GenServerInMsg<G: GenServer> {
    /// A request that expects an answer: the outcome of handling `message` is
    /// sent back through `sender`.
    Call {
        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
        message: G::CallMsg,
    },
    /// A one-way message; it carries no reply channel.
    Cast {
        message: G::CastMsg,
    },
}
120
/// Value returned by `GenServer::handle_call`, interpreted by the default
/// `GenServer::receive`.
pub enum CallResponse<G: GenServer> {
    /// Continue running with the given state and send the given value to the caller.
    Reply(G::State, G::OutMsg),
    /// Returned by the default `handle_call`. `receive` logs an error, answers
    /// the caller with `GenServerError::CallMsgUnused` and stops the server.
    Unused,
    /// Send the given value to the caller and stop the server.
    Stop(G::OutMsg),
}
126
/// Value returned by `GenServer::handle_cast`, interpreted by the default
/// `GenServer::receive`.
pub enum CastResponse<G: GenServer> {
    /// Continue running with the given state.
    NoReply(G::State),
    /// Returned by the default `handle_cast`. `receive` logs an error and stops
    /// the server.
    Unused,
    /// Stop the server.
    Stop,
}
132
133pub trait GenServer
134where
135 Self: Default + Send + Sized,
136{
137 type CallMsg: Clone + Send + Sized + Sync;
138 type CastMsg: Clone + Send + Sized + Sync;
139 type OutMsg: Send + Sized;
140 type State: Clone + Send;
141 type Error: Debug + Send;
142
143 fn new() -> Self {
144 Self::default()
145 }
146
147 fn start(initial_state: Self::State) -> GenServerHandle<Self> {
148 GenServerHandle::new(initial_state)
149 }
150
151 fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
157 GenServerHandle::new_blocking(initial_state)
158 }
159
160 fn run(
161 &mut self,
162 handle: &GenServerHandle<Self>,
163 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
164 state: Self::State,
165 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
166 async {
167 let init_result = self
168 .init(handle, state.clone())
169 .await
170 .inspect_err(|err| tracing::error!("Initialization failed: {err:?}"));
171
172 let res = match init_result {
173 Ok(new_state) => self.main_loop(handle, rx, new_state).await,
174 Err(_) => Err(GenServerError::Initialization),
175 };
176
177 handle.cancellation_token().cancel();
178 if let Err(err) = self.teardown(handle, state).await {
179 tracing::error!("Error during teardown: {err:?}");
180 }
181 res
182 }
183 }
184
185 fn init(
189 &mut self,
190 _handle: &GenServerHandle<Self>,
191 state: Self::State,
192 ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
193 async { Ok(state) }
194 }
195
196 fn main_loop(
197 &mut self,
198 handle: &GenServerHandle<Self>,
199 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
200 mut state: Self::State,
201 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
202 async {
203 loop {
204 let (new_state, cont) = self.receive(handle, rx, state).await?;
205 state = new_state;
206 if !cont {
207 break;
208 }
209 }
210 tracing::trace!("Stopping GenServer");
211 Ok(())
212 }
213 }
214
215 fn receive(
216 &mut self,
217 handle: &GenServerHandle<Self>,
218 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
219 state: Self::State,
220 ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
221 async move {
222 let message = rx.recv().await;
223
224 let state_clone = state.clone();
226
227 let (keep_running, new_state) = match message {
228 Some(GenServerInMsg::Call { sender, message }) => {
229 let (keep_running, new_state, response) =
230 match AssertUnwindSafe(self.handle_call(message, handle, state))
231 .catch_unwind()
232 .await
233 {
234 Ok(response) => match response {
235 CallResponse::Reply(new_state, response) => {
236 (true, new_state, Ok(response))
237 }
238 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
239 CallResponse::Unused => {
240 tracing::error!("GenServer received unexpected CallMessage");
241 (false, state_clone, Err(GenServerError::CallMsgUnused))
242 }
243 },
244 Err(error) => {
245 tracing::error!(
246 "Error in callback, reverting state - Error: '{error:?}'"
247 );
248 (true, state_clone, Err(GenServerError::Callback))
249 }
250 };
251 if sender.send(response).is_err() {
253 tracing::error!(
254 "GenServer failed to send response back, client must have died"
255 )
256 };
257 (keep_running, new_state)
258 }
259 Some(GenServerInMsg::Cast { message }) => {
260 match AssertUnwindSafe(self.handle_cast(message, handle, state))
261 .catch_unwind()
262 .await
263 {
264 Ok(response) => match response {
265 CastResponse::NoReply(new_state) => (true, new_state),
266 CastResponse::Stop => (false, state_clone),
267 CastResponse::Unused => {
268 tracing::error!("GenServer received unexpected CastMessage");
269 (false, state_clone)
270 }
271 },
272 Err(error) => {
273 tracing::trace!(
274 "Error in callback, reverting state - Error: '{error:?}'"
275 );
276 (true, state_clone)
277 }
278 }
279 }
280 None => {
281 (false, state)
283 }
284 };
285 Ok((new_state, keep_running))
286 }
287 }
288
289 fn handle_call(
290 &mut self,
291 _message: Self::CallMsg,
292 _handle: &GenServerHandle<Self>,
293 _state: Self::State,
294 ) -> impl Future<Output = CallResponse<Self>> + Send {
295 async { CallResponse::Unused }
296 }
297
298 fn handle_cast(
299 &mut self,
300 _message: Self::CastMsg,
301 _handle: &GenServerHandle<Self>,
302 _state: Self::State,
303 ) -> impl Future<Output = CastResponse<Self>> + Send {
304 async { CastResponse::Unused }
305 }
306
307 fn teardown(
311 &mut self,
312 _handle: &GenServerHandle<Self>,
313 _state: Self::State,
314 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
315 async { Ok(()) }
316 }
317}
318
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    #[derive(Default)]
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            // Yield once, then deliberately block the thread for two seconds.
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts casts and re-schedules a cast to itself every 100 ms.
    #[derive(Default)]
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            // Both messages answer with the current count; only `Stop` ends the server.
            let reply = OutMsg::Count(state.count);
            match message {
                InMessage::GetCount => CallResponse::Reply(state, reply),
                InMessage::Stop => CallResponse::Stop(reply),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(state)
        }
    }

    /// The blocking server is started with `start`; after one second the
    /// counting server is expected NOT to have reached 10 ticks.
    /// (Presumably because the blocking sleep stalls the shared runtime -
    /// confirm against the runtime configuration.)
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// The blocking server is started with `start_blocking`; after one second
    /// the counting server is expected to have ticked exactly 10 times.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start_blocking(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    const TIMEOUT_DURATION: Duration = Duration::from_millis(100);

    /// Server whose calls take either half or twice `TIMEOUT_DURATION`.
    #[derive(Default)]
    struct SomeTask;

    #[derive(Clone)]
    enum SomeTaskCallMsg {
        SlowOperation,
        FastOperation,
    }

    impl GenServer for SomeTask {
        type CallMsg = SomeTaskCallMsg;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _handle: &GenServerHandle<Self>,
            _state: Self::State,
        ) -> CallResponse<Self> {
            // Slow: longer than the caller's timeout. Fast: well within it.
            let delay = match message {
                SomeTaskCallMsg::SlowOperation => TIMEOUT_DURATION * 2,
                SomeTaskCallMsg::FastOperation => TIMEOUT_DURATION / 2,
            };
            rt::sleep(delay).await;
            CallResponse::Reply((), ())
        }
    }

    /// A call answered within the timeout succeeds; one that takes longer
    /// yields `GenServerError::CallTimeout`.
    #[test]
    pub fn unresolving_task_times_out() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut unresolving_task = SomeTask::start(());

            let fast = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::FastOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(fast, Ok(())));

            let slow = unresolving_task
                .call_with_timeout(SomeTaskCallMsg::SlowOperation, TIMEOUT_DURATION)
                .await;
            assert!(matches!(slow, Err(GenServerError::CallTimeout)));
        });
    }
}