// spawned_concurrency/tasks/gen_server.rs
use futures::future::FutureExt as _;
use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken};
use std::{any::Any, fmt::Debug, future::Future, panic::AssertUnwindSafe};

use crate::error::GenServerError;

9pub struct GenServerHandle<G: GenServer + 'static> {
10 pub tx: mpsc::Sender<GenServerInMsg<G>>,
11 cancellation_token: CancellationToken,
13}
14
15impl<G: GenServer> Clone for GenServerHandle<G> {
16 fn clone(&self) -> Self {
17 Self {
18 tx: self.tx.clone(),
19 cancellation_token: self.cancellation_token.clone(),
20 }
21 }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25 pub(crate) fn new(initial_state: G::State) -> Self {
26 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27 let cancellation_token = CancellationToken::new();
28 let handle = GenServerHandle {
29 tx,
30 cancellation_token,
31 };
32 let mut gen_server: G = GenServer::new();
33 let handle_clone = handle.clone();
34 let _join_handle = rt::spawn(async move {
36 if gen_server
37 .run(&handle, &mut rx, initial_state)
38 .await
39 .is_err()
40 {
41 tracing::trace!("GenServer crashed")
42 };
43 });
44 handle_clone
45 }
46
47 pub(crate) fn new_blocking(initial_state: G::State) -> Self {
48 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
49 let cancellation_token = CancellationToken::new();
50 let handle = GenServerHandle {
51 tx,
52 cancellation_token,
53 };
54 let mut gen_server: G = GenServer::new();
55 let handle_clone = handle.clone();
56 let _join_handle = rt::spawn_blocking(|| {
58 rt::block_on(async move {
59 if gen_server
60 .run(&handle, &mut rx, initial_state)
61 .await
62 .is_err()
63 {
64 tracing::trace!("GenServer crashed")
65 };
66 })
67 });
68 handle_clone
69 }
70
71 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
72 self.tx.clone()
73 }
74
75 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
76 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
77 self.tx.send(GenServerInMsg::Call {
78 sender: oneshot_tx,
79 message,
80 })?;
81 match oneshot_rx.await {
82 Ok(result) => result,
83 Err(_) => Err(GenServerError::Server),
84 }
85 }
86
87 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
88 self.tx
89 .send(GenServerInMsg::Cast { message })
90 .map_err(|_error| GenServerError::Server)
91 }
92
93 pub fn cancellation_token(&self) -> CancellationToken {
94 self.cancellation_token.clone()
95 }
96}
97
98pub enum GenServerInMsg<G: GenServer> {
99 Call {
100 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
101 message: G::CallMsg,
102 },
103 Cast {
104 message: G::CastMsg,
105 },
106}
107
108pub enum CallResponse<G: GenServer> {
109 Reply(G::State, G::OutMsg),
110 Unused,
111 Stop(G::OutMsg),
112}
113
114pub enum CastResponse<G: GenServer> {
115 NoReply(G::State),
116 Unused,
117 Stop,
118}
119
120pub trait GenServer
121where
122 Self: Send + Sized,
123{
124 type CallMsg: Clone + Send + Sized + Sync;
125 type CastMsg: Clone + Send + Sized + Sync;
126 type OutMsg: Send + Sized;
127 type State: Clone + Send;
128 type Error: Debug + Send;
129
130 fn new() -> Self;
131
132 fn start(initial_state: Self::State) -> GenServerHandle<Self> {
133 GenServerHandle::new(initial_state)
134 }
135
136 fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
142 GenServerHandle::new_blocking(initial_state)
143 }
144
145 fn run(
146 &mut self,
147 handle: &GenServerHandle<Self>,
148 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
149 state: Self::State,
150 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
151 async {
152 match self.init(handle, state).await {
153 Ok(new_state) => {
154 self.main_loop(handle, rx, new_state).await?;
155 Ok(())
156 }
157 Err(err) => {
158 tracing::error!("Initialization failed: {err:?}");
159 Err(GenServerError::Initialization)
160 }
161 }
162 }
163 }
164
165 fn init(
169 &mut self,
170 _handle: &GenServerHandle<Self>,
171 state: Self::State,
172 ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
173 async { Ok(state) }
174 }
175
176 fn main_loop(
177 &mut self,
178 handle: &GenServerHandle<Self>,
179 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
180 mut state: Self::State,
181 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
182 async {
183 loop {
184 let (new_state, cont) = self.receive(handle, rx, state).await?;
185 state = new_state;
186 if !cont {
187 break;
188 }
189 }
190 tracing::trace!("Stopping GenServer");
191 handle.cancellation_token().cancel();
192 if let Err(err) = self.teardown(handle, state).await {
193 tracing::error!("Error during teardown: {err:?}");
194 }
195 Ok(())
196 }
197 }
198
199 fn receive(
200 &mut self,
201 handle: &GenServerHandle<Self>,
202 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
203 state: Self::State,
204 ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
205 async move {
206 let message = rx.recv().await;
207
208 let state_clone = state.clone();
210
211 let (keep_running, new_state) = match message {
212 Some(GenServerInMsg::Call { sender, message }) => {
213 let (keep_running, new_state, response) =
214 match AssertUnwindSafe(self.handle_call(message, handle, state))
215 .catch_unwind()
216 .await
217 {
218 Ok(response) => match response {
219 CallResponse::Reply(new_state, response) => {
220 (true, new_state, Ok(response))
221 }
222 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
223 CallResponse::Unused => {
224 tracing::error!("GenServer received unexpected CallMessage");
225 (false, state_clone, Err(GenServerError::CallMsgUnused))
226 }
227 },
228 Err(error) => {
229 tracing::error!(
230 "Error in callback, reverting state - Error: '{error:?}'"
231 );
232 (true, state_clone, Err(GenServerError::Callback))
233 }
234 };
235 if sender.send(response).is_err() {
237 tracing::error!(
238 "GenServer failed to send response back, client must have died"
239 )
240 };
241 (keep_running, new_state)
242 }
243 Some(GenServerInMsg::Cast { message }) => {
244 match AssertUnwindSafe(self.handle_cast(message, handle, state))
245 .catch_unwind()
246 .await
247 {
248 Ok(response) => match response {
249 CastResponse::NoReply(new_state) => (true, new_state),
250 CastResponse::Stop => (false, state_clone),
251 CastResponse::Unused => {
252 tracing::error!("GenServer received unexpected CastMessage");
253 (false, state_clone)
254 }
255 },
256 Err(error) => {
257 tracing::trace!(
258 "Error in callback, reverting state - Error: '{error:?}'"
259 );
260 (true, state_clone)
261 }
262 }
263 }
264 None => {
265 (false, state)
267 }
268 };
269 Ok((new_state, keep_running))
270 }
271 }
272
273 fn handle_call(
274 &mut self,
275 _message: Self::CallMsg,
276 _handle: &GenServerHandle<Self>,
277 _state: Self::State,
278 ) -> impl Future<Output = CallResponse<Self>> + Send {
279 async { CallResponse::Unused }
280 }
281
282 fn handle_cast(
283 &mut self,
284 _message: Self::CastMsg,
285 _handle: &GenServerHandle<Self>,
286 _state: Self::State,
287 ) -> impl Future<Output = CastResponse<Self>> + Send {
288 async { CastResponse::Unused }
289 }
290
291 fn teardown(
295 &mut self,
296 _handle: &GenServerHandle<Self>,
297 _state: Self::State,
298 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
299 async { Ok(()) }
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler deliberately blocks its OS thread.
    struct BadlyBehavedTask;

    /// Call messages understood by the test servers.
    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    /// Reply type of `WellBehavedTask`.
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        fn new() -> Self {
            BadlyBehavedTask
        }

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            // Any call simply stops the server.
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            // Yield to the runtime once, then hog the thread with a blocking sleep.
            rt::sleep(Duration::from_millis(20)).await;
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts casts and schedules another cast to itself via
    /// `send_after` 100 ms after each one.
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        fn new() -> Self {
            WellBehavedTask
        }

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            let count = state.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(state, OutMsg::Count(count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(state)
        }
    }

    /// Starts the bad server with `start_bad` and the good server next to it.
    /// After one second it passes the good server's count to `check`, then
    /// stops the good server.
    fn check_good_count(
        start_bad: fn(()) -> GenServerHandle<BadlyBehavedTask>,
        check: impl FnOnce(u64),
    ) {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut bad = start_bad(());
            let _ = bad.cast(()).await;
            let mut good = WellBehavedTask::start(CountState { count: 0 });
            let _ = good.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let OutMsg::Count(ticks) = good.call(InMessage::GetCount).await.unwrap();
            check(ticks);
            good.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        // The bad server shares the async workers, so the test expects the
        // good server to fall behind ten ticks.
        check_good_count(BadlyBehavedTask::start, |ticks| assert_ne!(ticks, 10));
    }

    #[test]
    pub fn badly_behaved_thread() {
        // The bad server runs on its own blocking thread, so the test expects
        // the good server to reach exactly ten ticks.
        check_good_count(BadlyBehavedTask::start_blocking, |ticks| {
            assert_eq!(ticks, 10)
        });
    }
}