spawned_concurrency/threads/
gen_server.rs

//! `GenServer` trait and supporting structs that provide an abstraction similar to Erlang's
//! `gen_server`. See `examples/name_server` for a usage example.
3use spawned_rt::threads::{self as rt, mpsc, oneshot};
4use std::{
5    fmt::Debug,
6    panic::{catch_unwind, AssertUnwindSafe},
7};
8
9use crate::error::GenServerError;
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13    pub tx: mpsc::Sender<GenServerInMsg<G>>,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17    fn clone(&self) -> Self {
18        Self {
19            tx: self.tx.clone(),
20        }
21    }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25    pub(crate) fn new(initial_state: G::State) -> Self {
26        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27        let handle = GenServerHandle { tx };
28        let mut gen_server: G = GenServer::new();
29        let handle_clone = handle.clone();
30        // Ignore the JoinHandle for now. Maybe we'll use it in the future
31        let _join_handle = rt::spawn(move || {
32            if gen_server.run(&handle, &mut rx, initial_state).is_err() {
33                tracing::trace!("GenServer crashed")
34            };
35        });
36        handle_clone
37    }
38
39    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
40        self.tx.clone()
41    }
42
43    pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
44        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
45        self.tx.send(GenServerInMsg::Call {
46            sender: oneshot_tx,
47            message,
48        })?;
49        match oneshot_rx.recv() {
50            Ok(result) => result,
51            Err(_) => Err(GenServerError::Server),
52        }
53    }
54
55    pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
56        self.tx
57            .send(GenServerInMsg::Cast { message })
58            .map_err(|_error| GenServerError::Server)
59    }
60}
61
62pub enum GenServerInMsg<G: GenServer> {
63    Call {
64        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
65        message: G::CallMsg,
66    },
67    Cast {
68        message: G::CastMsg,
69    },
70}
71
72pub enum CallResponse<G: GenServer> {
73    Reply(G::State, G::OutMsg),
74    Unused,
75    Stop(G::OutMsg),
76}
77
78pub enum CastResponse<G: GenServer> {
79    NoReply(G::State),
80    Unused,
81    Stop,
82}
83
84pub trait GenServer
85where
86    Self: Send + Sized,
87{
88    type CallMsg: Clone + Send + Sized;
89    type CastMsg: Clone + Send + Sized;
90    type OutMsg: Send + Sized;
91    type State: Clone + Send;
92    type Error: Debug;
93
94    fn new() -> Self;
95
96    fn start(initial_state: Self::State) -> GenServerHandle<Self> {
97        GenServerHandle::new(initial_state)
98    }
99
100    /// We copy the same interface as tasks, but all threads can work
101    /// while blocking by default
102    fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
103        GenServerHandle::new(initial_state)
104    }
105
106    fn run(
107        &mut self,
108        handle: &GenServerHandle<Self>,
109        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
110        state: Self::State,
111    ) -> Result<(), GenServerError> {
112        match self.init(handle, state) {
113            Ok(new_state) => {
114                self.main_loop(handle, rx, new_state)?;
115                Ok(())
116            }
117            Err(err) => {
118                tracing::error!("Initialization failed: {err:?}");
119                Err(GenServerError::Initialization)
120            }
121        }
122    }
123
124    /// Initialization function. It's called before main loop. It
125    /// can be overrided on implementations in case initial steps are
126    /// required.
127    fn init(
128        &mut self,
129        _handle: &GenServerHandle<Self>,
130        state: Self::State,
131    ) -> Result<Self::State, Self::Error> {
132        Ok(state)
133    }
134
135    fn main_loop(
136        &mut self,
137        handle: &GenServerHandle<Self>,
138        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
139        mut state: Self::State,
140    ) -> Result<(), GenServerError> {
141        loop {
142            let (new_state, cont) = self.receive(handle, rx, state)?;
143            if !cont {
144                break;
145            }
146            state = new_state;
147        }
148        tracing::trace!("Stopping GenServer");
149        Ok(())
150    }
151
152    fn receive(
153        &mut self,
154        handle: &GenServerHandle<Self>,
155        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
156        state: Self::State,
157    ) -> Result<(Self::State, bool), GenServerError> {
158        let message = rx.recv().ok();
159
160        // Save current state in case of a rollback
161        let state_clone = state.clone();
162
163        let (keep_running, new_state) = match message {
164            Some(GenServerInMsg::Call { sender, message }) => {
165                let (keep_running, new_state, response) =
166                    match catch_unwind(AssertUnwindSafe(|| {
167                        self.handle_call(message, handle, state)
168                    })) {
169                        Ok(response) => match response {
170                            CallResponse::Reply(new_state, response) => {
171                                (true, new_state, Ok(response))
172                            }
173                            CallResponse::Stop(response) => (false, state_clone, Ok(response)),
174                            CallResponse::Unused => {
175                                tracing::error!("GenServer received unexpected CallMessage");
176                                (false, state_clone, Err(GenServerError::CallMsgUnused))
177                            }
178                        },
179                        Err(error) => {
180                            tracing::trace!(
181                                "Error in callback, reverting state - Error: '{error:?}'"
182                            );
183                            (true, state_clone, Err(GenServerError::Callback))
184                        }
185                    };
186                // Send response back
187                if sender.send(response).is_err() {
188                    tracing::trace!("GenServer failed to send response back, client must have died")
189                };
190                (keep_running, new_state)
191            }
192            Some(GenServerInMsg::Cast { message }) => {
193                match catch_unwind(AssertUnwindSafe(|| {
194                    self.handle_cast(message, handle, state)
195                })) {
196                    Ok(response) => match response {
197                        CastResponse::NoReply(new_state) => (true, new_state),
198                        CastResponse::Stop => (false, state_clone),
199                        CastResponse::Unused => {
200                            tracing::error!("GenServer received unexpected CastMessage");
201                            (false, state_clone)
202                        }
203                    },
204                    Err(error) => {
205                        tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
206                        (true, state_clone)
207                    }
208                }
209            }
210            None => {
211                // Channel has been closed; won't receive further messages. Stop the server.
212                (false, state)
213            }
214        };
215        Ok((new_state, keep_running))
216    }
217
218    fn handle_call(
219        &mut self,
220        _message: Self::CallMsg,
221        _handle: &GenServerHandle<Self>,
222        _state: Self::State,
223    ) -> CallResponse<Self> {
224        CallResponse::Unused
225    }
226
227    fn handle_cast(
228        &mut self,
229        _message: Self::CastMsg,
230        _handle: &GenServerHandle<Self>,
231        _state: Self::State,
232    ) -> CastResponse<Self> {
233        CastResponse::Unused
234    }
235}