// spawned_concurrency/tasks/gen_server.rs
use futures::future::FutureExt as _;
4use spawned_rt::tasks::{self as rt, mpsc, oneshot};
5use std::{fmt::Debug, future::Future, panic::AssertUnwindSafe};
6
7use crate::error::GenServerError;
8
9#[derive(Debug)]
10pub struct GenServerHandle<G: GenServer + 'static> {
11 pub tx: mpsc::Sender<GenServerInMsg<G>>,
12}
13
14impl<G: GenServer> Clone for GenServerHandle<G> {
15 fn clone(&self) -> Self {
16 Self {
17 tx: self.tx.clone(),
18 }
19 }
20}
21
22impl<G: GenServer> GenServerHandle<G> {
23 pub(crate) fn new(initial_state: G::State) -> Self {
24 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
25 let handle = GenServerHandle { tx };
26 let mut gen_server: G = GenServer::new();
27 let handle_clone = handle.clone();
28 let _join_handle = rt::spawn(async move {
30 if gen_server
31 .run(&handle, &mut rx, initial_state)
32 .await
33 .is_err()
34 {
35 tracing::trace!("GenServer crashed")
36 };
37 });
38 handle_clone
39 }
40
41 pub(crate) fn new_blocking(initial_state: G::State) -> Self {
42 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
43 let handle = GenServerHandle { tx };
44 let mut gen_server: G = GenServer::new();
45 let handle_clone = handle.clone();
46 let _join_handle = rt::spawn_blocking(|| {
48 rt::block_on(async move {
49 if gen_server
50 .run(&handle, &mut rx, initial_state)
51 .await
52 .is_err()
53 {
54 tracing::trace!("GenServer crashed")
55 };
56 })
57 });
58 handle_clone
59 }
60
61 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
62 self.tx.clone()
63 }
64
65 pub async fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
66 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
67 self.tx.send(GenServerInMsg::Call {
68 sender: oneshot_tx,
69 message,
70 })?;
71 match oneshot_rx.await {
72 Ok(result) => result,
73 Err(_) => Err(GenServerError::Server),
74 }
75 }
76
77 pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
78 self.tx
79 .send(GenServerInMsg::Cast { message })
80 .map_err(|_error| GenServerError::Server)
81 }
82}
83
84pub enum GenServerInMsg<G: GenServer> {
85 Call {
86 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
87 message: G::CallMsg,
88 },
89 Cast {
90 message: G::CastMsg,
91 },
92}
93
94pub enum CallResponse<G: GenServer> {
95 Reply(G::State, G::OutMsg),
96 Unused,
97 Stop(G::OutMsg),
98}
99
100pub enum CastResponse<G: GenServer> {
101 NoReply(G::State),
102 Unused,
103 Stop,
104}
105
106pub trait GenServer
107where
108 Self: Send + Sized,
109{
110 type CallMsg: Clone + Send + Sized + Sync;
111 type CastMsg: Clone + Send + Sized + Sync;
112 type OutMsg: Send + Sized;
113 type State: Clone + Send;
114 type Error: Debug + Send;
115
116 fn new() -> Self;
117
118 fn start(initial_state: Self::State) -> GenServerHandle<Self> {
119 GenServerHandle::new(initial_state)
120 }
121
122 fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
128 GenServerHandle::new_blocking(initial_state)
129 }
130
131 fn run(
132 &mut self,
133 handle: &GenServerHandle<Self>,
134 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
135 state: Self::State,
136 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
137 async {
138 match self.init(handle, state).await {
139 Ok(new_state) => {
140 self.main_loop(handle, rx, new_state).await?;
141 Ok(())
142 }
143 Err(err) => {
144 tracing::error!("Initialization failed: {err:?}");
145 Err(GenServerError::Initialization)
146 }
147 }
148 }
149 }
150
151 fn init(
155 &mut self,
156 _handle: &GenServerHandle<Self>,
157 state: Self::State,
158 ) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
159 async { Ok(state) }
160 }
161
162 fn main_loop(
163 &mut self,
164 handle: &GenServerHandle<Self>,
165 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
166 mut state: Self::State,
167 ) -> impl Future<Output = Result<(), GenServerError>> + Send {
168 async {
169 loop {
170 let (new_state, cont) = self.receive(handle, rx, state).await?;
171 if !cont {
172 break;
173 }
174 state = new_state;
175 }
176 tracing::trace!("Stopping GenServer");
177 Ok(())
178 }
179 }
180
181 fn receive(
182 &mut self,
183 handle: &GenServerHandle<Self>,
184 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
185 state: Self::State,
186 ) -> impl Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
187 async move {
188 let message = rx.recv().await;
189
190 let state_clone = state.clone();
192
193 let (keep_running, new_state) = match message {
194 Some(GenServerInMsg::Call { sender, message }) => {
195 let (keep_running, new_state, response) =
196 match AssertUnwindSafe(self.handle_call(message, handle, state))
197 .catch_unwind()
198 .await
199 {
200 Ok(response) => match response {
201 CallResponse::Reply(new_state, response) => {
202 (true, new_state, Ok(response))
203 }
204 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
205 CallResponse::Unused => {
206 tracing::error!("GenServer received unexpected CallMessage");
207 (false, state_clone, Err(GenServerError::CallMsgUnused))
208 }
209 },
210 Err(error) => {
211 tracing::error!(
212 "Error in callback, reverting state - Error: '{error:?}'"
213 );
214 (true, state_clone, Err(GenServerError::Callback))
215 }
216 };
217 if sender.send(response).is_err() {
219 tracing::error!(
220 "GenServer failed to send response back, client must have died"
221 )
222 };
223 (keep_running, new_state)
224 }
225 Some(GenServerInMsg::Cast { message }) => {
226 match AssertUnwindSafe(self.handle_cast(message, handle, state))
227 .catch_unwind()
228 .await
229 {
230 Ok(response) => match response {
231 CastResponse::NoReply(new_state) => (true, new_state),
232 CastResponse::Stop => (false, state_clone),
233 CastResponse::Unused => {
234 tracing::error!("GenServer received unexpected CastMessage");
235 (false, state_clone)
236 }
237 },
238 Err(error) => {
239 tracing::trace!(
240 "Error in callback, reverting state - Error: '{error:?}'"
241 );
242 (true, state_clone)
243 }
244 }
245 }
246 None => {
247 (false, state)
249 }
250 };
251 Ok((new_state, keep_running))
252 }
253 }
254
255 fn handle_call(
256 &mut self,
257 _message: Self::CallMsg,
258 _handle: &GenServerHandle<Self>,
259 _state: Self::State,
260 ) -> impl Future<Output = CallResponse<Self>> + Send {
261 async { CallResponse::Unused }
262 }
263
264 fn handle_cast(
265 &mut self,
266 _message: Self::CastMsg,
267 _handle: &GenServerHandle<Self>,
268 _state: Self::State,
269 ) -> impl Future<Output = CastResponse<Self>> + Send {
270 async { CastResponse::Unused }
271 }
272}
273
#[cfg(test)]
mod tests {

    use super::*;
    use crate::tasks::send_after;
    use std::{thread, time::Duration};

    /// Server whose cast handler blocks its thread with `thread::sleep`.
    struct BadlyBehavedTask;

    /// Call messages shared by both test servers.
    #[derive(Clone)]
    pub enum InMessage {
        GetCount,
        Stop,
    }

    /// Reply type of `WellBehavedTask`.
    #[derive(Clone)]
    pub enum OutMsg {
        Count(u64),
    }

    impl GenServer for BadlyBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = ();
        type State = ();
        type Error = ();

        fn new() -> Self {
            BadlyBehavedTask
        }

        async fn handle_call(
            &mut self,
            _: Self::CallMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CallResponse<Self> {
            CallResponse::Stop(())
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            _: &GenServerHandle<Self>,
            _: Self::State,
        ) -> CastResponse<Self> {
            rt::sleep(Duration::from_millis(20)).await;
            // Deliberately blocks the current thread (not an async sleep).
            thread::sleep(Duration::from_secs(2));
            CastResponse::Stop
        }
    }

    /// Server that counts casts and re-sends itself a cast every 100 ms.
    struct WellBehavedTask;

    #[derive(Clone)]
    struct CountState {
        pub count: u64,
    }

    impl GenServer for WellBehavedTask {
        type CallMsg = InMessage;
        type CastMsg = ();
        type OutMsg = OutMsg;
        type State = CountState;
        type Error = ();

        fn new() -> Self {
            WellBehavedTask
        }

        async fn handle_call(
            &mut self,
            message: Self::CallMsg,
            _: &GenServerHandle<Self>,
            state: Self::State,
        ) -> CallResponse<Self> {
            let reply = OutMsg::Count(state.count);
            match message {
                InMessage::GetCount => CallResponse::Reply(state, reply),
                InMessage::Stop => CallResponse::Stop(reply),
            }
        }

        async fn handle_cast(
            &mut self,
            _: Self::CastMsg,
            handle: &GenServerHandle<Self>,
            mut state: Self::State,
        ) -> CastResponse<Self> {
            state.count += 1;
            println!("{:?}: good still alive", thread::current().id());
            // Schedule the next tick.
            send_after(Duration::from_millis(100), handle.clone(), ());
            CastResponse::NoReply(state)
        }
    }

    /// Bad server started with `start` (regular async task): after one second
    /// the good server's count is expected NOT to be 10 (presumably because
    /// the blocked runtime thread stalls it).
    #[test]
    pub fn badly_behaved_thread_non_blocking() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_ne!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }

    /// Bad server started with `start_blocking`: after one second the good
    /// server's count is expected to be exactly 10.
    #[test]
    pub fn badly_behaved_thread() {
        let runtime = rt::Runtime::new().unwrap();
        runtime.block_on(async move {
            let mut badboy = BadlyBehavedTask::start_blocking(());
            let _ = badboy.cast(()).await;
            let mut goodboy = WellBehavedTask::start(CountState { count: 0 });
            let _ = goodboy.cast(()).await;
            rt::sleep(Duration::from_secs(1)).await;

            let OutMsg::Count(num) = goodboy.call(InMessage::GetCount).await.unwrap();
            assert_eq!(num, 10);

            goodboy.call(InMessage::Stop).await.unwrap();
        });
    }
}