// spawned_concurrency/tasks/gen_server.rs
use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe};

use futures::future::FutureExt as _;
use spawned_rt::tasks::{self as rt, mpsc, oneshot, CancellationToken};

use crate::error::GenServerError;
9#[derive(Debug)]
10pub struct GenServerHandle<G: GenServer + 'static> {
11 pub tx: mpsc::Sender<GenServerInMsg<G>>,
12 cancellation_token: CancellationToken,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17 fn clone(&self) -> Self {
18 Self {
19 tx: self.tx.clone(),
20 cancellation_token: self.cancellation_token.clone(),
21 }
22 }
23}
24
25impl<G: GenServer> GenServerHandle<G> {
26 pub(crate) fn new(initial_state: G::State) -> Self {
27 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
28 let cancellation_token = CancellationToken::new();
29 let handle = GenServerHandle {
30 tx,
31 cancellation_token,
32 };
33 let mut gen_server: G = GenServer::new();
34 let handle_clone = handle.clone();
35 let _join_handle = rt::spawn(async move {
37 if gen_server
38 .run(&handle, &mut rx, initial_state)
39 .await
40 .is_err()
41 {
42 tracing::trace!("GenServer crashed")
43 };
44 });
45 handle_clone
46 }
47
48 pub(crate) fn new_blocking(initial_state: G::State) -> Self {
49 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
50 let cancellation_token = CancellationToken::new();
51 let handle = GenServerHandle {
52 tx,
53 cancellation_token,
54 };
55 let mut gen_server: G = GenServer::new();
56 let handle_clone = handle.clone();
57 let _join_handle = rt::spawn_blocking(|| {
59 rt::block_on(async move {
60 if gen_server
61 .run(&handle, &mut rx, initial_state)
62 .await
63 .is_err()
64 {
65 tracing::trace!("GenServer crashed")
66 };
67 })
68 });
69 handle_clone
70 }
71
72 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
73 self.tx.clone()
74 }
75
76 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
77 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
78 self.tx.send(GenServerInMsg::Call {
79 sender: oneshot_tx,
80 message,
81 })?;
82 match oneshot_rx.await {
83 Ok(result) => result,
84 Err(_) => Err(GenServerError::Server),
85 }
86 }
87
88 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
89 self.tx
90 .send(GenServerInMsg::Cast { message })
91 .map_err(|_error| GenServerError::Server)
92 }
93
94 pub fn cancellation_token(&self) -> CancellationToken {
95 self.cancellation_token.clone()
96 }
97}
98
99pub enum GenServerInMsg<G: GenServer> {
100 Call {
101 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
102 message: G::CallMsg,
103 },
104 Cast {
105 message: G::CastMsg,
106 },
107}
108
109pub enum CallResponse<G: GenServer> {
110 Reply(G::State, G::OutMsg),
111 Unused,
112 Stop(G::OutMsg),
113}
114
115pub enum CastResponse<G: GenServer> {
116 NoReply(G::State),
117 Unused,
118 Stop,
119}
120
121pub trait GenServer
122where
123 Self: Send + Sized,
124{
125 type CallMsg: Clone + Send + Sized + Sync;
126 type CastMsg: Clone + Send + Sized + Sync;
127 type OutMsg: Send + Sized;
128 type State: Clone + Send;
129 type Error: Debug + Send;
130
131 fn new() -> Self;
132
133 fn start(initial_state: Self::State) -> GenServerHandle<Self> {
134 GenServerHandle::new(initial_state)
135 }
136
137 fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
143 GenServerHandle::new_blocking(initial_state)
144 }
145
146 fn run(
147 &mut self,
148 handle: &GenServerHandle<Self>,
149 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
150 state: Self::State,
151 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
152 async {
153 match self.init(handle, state).await {
154 Ok(new_state) => {
155 self.main_loop(handle, rx, new_state).await?;
156 Ok(())
157 }
158 Err(err) => {
159 tracing::error!("Initialization failed: {err:?}");
160 Err(GenServerError::Initialization)
161 }
162 }
163 }
164 }
165
166 fn init(
170 &mut self,
171 _handle: &GenServerHandle<Self>,
172 state: Self::State,
173 ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
174 async { Ok(state) }
175 }
176
177 fn main_loop(
178 &mut self,
179 handle: &GenServerHandle<Self>,
180 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
181 mut state: Self::State,
182 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
183 async {
184 loop {
185 let (new_state, cont) = self.receive(handle, rx, state).await?;
186 state = new_state;
187 if !cont {
188 break;
189 }
190 }
191 tracing::trace!("Stopping GenServer");
192 handle.cancellation_token().cancel();
193 if let Err(err) = self.teardown(handle, state).await {
194 tracing::error!("Error during teardown: {err:?}");
195 }
196 Ok(())
197 }
198 }
199
200 fn receive(
201 &mut self,
202 handle: &GenServerHandle<Self>,
203 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
204 state: Self::State,
205 ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
206 async move {
207 let message = rx.recv().await;
208
209 let state_clone = state.clone();
211
212 let (keep_running, new_state) = match message {
213 Some(GenServerInMsg::Call { sender, message }) => {
214 let (keep_running, new_state, response) =
215 match AssertUnwindSafe(self.handle_call(message, handle, state))
216 .catch_unwind()
217 .await
218 {
219 Ok(response) => match response {
220 CallResponse::Reply(new_state, response) => {
221 (true, new_state, Ok(response))
222 }
223 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
224 CallResponse::Unused => {
225 tracing::error!("GenServer received unexpected CallMessage");
226 (false, state_clone, Err(GenServerError::CallMsgUnused))
227 }
228 },
229 Err(error) => {
230 tracing::error!(
231 "Error in callback, reverting state - Error: '{error:?}'"
232 );
233 (true, state_clone, Err(GenServerError::Callback))
234 }
235 };
236 if sender.send(response).is_err() {
238 tracing::error!(
239 "GenServer failed to send response back, client must have died"
240 )
241 };
242 (keep_running, new_state)
243 }
244 Some(GenServerInMsg::Cast { message }) => {
245 match AssertUnwindSafe(self.handle_cast(message, handle, state))
246 .catch_unwind()
247 .await
248 {
249 Ok(response) => match response {
250 CastResponse::NoReply(new_state) => (true, new_state),
251 CastResponse::Stop => (false, state_clone),
252 CastResponse::Unused => {
253 tracing::error!("GenServer received unexpected CastMessage");
254 (false, state_clone)
255 }
256 },
257 Err(error) => {
258 tracing::trace!(
259 "Error in callback, reverting state - Error: '{error:?}'"
260 );
261 (true, state_clone)
262 }
263 }
264 }
265 None => {
266 (false, state)
268 }
269 };
270 Ok((new_state, keep_running))
271 }
272 }
273
274 fn handle_call(
275 &mut self,
276 _message: Self::CallMsg,
277 _handle: &GenServerHandle<Self>,
278 _state: Self::State,
279 ) -> impl Future<Output = CallResponse<Self>> + Send {
280 async { CallResponse::Unused }
281 }
282
283 fn handle_cast(
284 &mut self,
285 _message: Self::CastMsg,
286 _handle: &GenServerHandle<Self>,
287 _state: Self::State,
288 ) -> impl Future<Output = CastResponse<Self>> + Send {
289 async { CastResponse::Unused }
290 }
291
292 fn teardown(
296 &mut self,
297 _handle: &GenServerHandle<Self>,
298 _state: Self::State,
299 ) -> impl Future<Output = Result<(), Self::Error>> + Send {
300 async { Ok(()) }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler parks its whole OS thread for two seconds.
    struct BadlyBehavedTask;

    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        fn new() -> Self {
            Self
        }

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // `thread::sleep` (not `rt::sleep`) on purpose: it blocks the OS
            // thread instead of yielding back to the scheduler.
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts casts and re-casts to itself every 100 ms.
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        fn new() -> Self {
            Self
        }

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            let count = state.count;
            match message {
                InMessage::GetCount => CallResponse::Reply(state, OutMsg::Count(count)),
                InMessage::Stop => CallResponse::Stop(OutMsg::Count(count)),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(state)
        }
    }

    /// Shared scenario for both tests.
    ///
    /// Starts a `BadlyBehavedTask` with `start_bad` and casts to it. It then
    /// runs a `WellBehavedTask` for one second and hands its count to `check`,
    /// after which the counter is stopped.
    fn run_scenario(start_bad: fn(()) -> GenServerHandle<BadlyBehavedTask>, check: fn(u64)) {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut bad = start_bad(());
            let _ = bad.cast(()).await;
            let mut good = WellBehavedTask::start(CountState { count: 0 });
            let _ = good.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;
            let OutMsg::Count(num) = good.call(InMessage::GetCount).await.unwrap();
            check(num);
            good.call(InMessage::Stop).await.unwrap();
        });
    }

    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        // Sharing the async workers with the blocking server is expected to
        // delay the counter.
        run_scenario(BadlyBehavedTask::start, |num| assert_ne!(num, 10));
    }

    #[test]
    pub fn badly_behaved_thread() {
        // With the bad server on its own blocking thread, the counter keeps pace.
        run_scenario(BadlyBehavedTask::start_blocking, |num| assert_eq!(num, 10));
    }
}