spawned_concurrency/threads/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use spawned_rt::threads::{self as rt, mpsc, oneshot};
4use std::{
5    fmt::Debug,
6    panic::{catch_unwind, AssertUnwindSafe},
7};
8
9use crate::error::GenServerError;
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13    pub tx: mpsc::Sender<GenServerInMsg<G>>,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17    fn clone(&self) -> Self {
18        Self {
19            tx: self.tx.clone(),
20        }
21    }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25    pub(crate) fn new(initial_state: G::State) -> Self {
26        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27        let handle = GenServerHandle { tx };
28        let mut gen_server: G = GenServer::new();
29        let handle_clone = handle.clone();
30        // Ignore the JoinHandle for now. Maybe we'll use it in the future
31        let _join_handle = rt::spawn(move || {
32            if gen_server.run(&handle, &mut rx, initial_state).is_err() {
33                tracing::trace!("GenServer crashed")
34            };
35        });
36        handle_clone
37    }
38
39    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
40        self.tx.clone()
41    }
42
43    pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
44        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
45        self.tx.send(GenServerInMsg::Call {
46            sender: oneshot_tx,
47            message,
48        })?;
49        match oneshot_rx.recv() {
50            Ok(result) => result,
51            Err(_) => Err(GenServerError::Server),
52        }
53    }
54
55    pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
56        self.tx
57            .send(GenServerInMsg::Cast { message })
58            .map_err(|_error| GenServerError::Server)
59    }
60}
61
62pub enum GenServerInMsg<G: GenServer> {
63    Call {
64        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
65        message: G::CallMsg,
66    },
67    Cast {
68        message: G::CastMsg,
69    },
70}
71
72pub enum CallResponse<G: GenServer> {
73    Reply(G::State, G::OutMsg),
74    Unused,
75    Stop(G::OutMsg),
76}
77
78pub enum CastResponse<G: GenServer> {
79    NoReply(G::State),
80    Unused,
81    Stop,
82}
83
84pub trait GenServer
85where
86    Self: Default + Send + Sized,
87{
88    type CallMsg: Clone + Send + Sized;
89    type CastMsg: Clone + Send + Sized;
90    type OutMsg: Send + Sized;
91    type State: Clone + Send;
92    type Error: Debug;
93
94    fn new() -> Self {
95        Self::default()
96    }
97
98    fn start(initial_state: Self::State) -> GenServerHandle<Self> {
99        GenServerHandle::new(initial_state)
100    }
101
102    /// We copy the same interface as tasks, but all threads can work
103    /// while blocking by default
104    fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
105        GenServerHandle::new(initial_state)
106    }
107
108    fn run(
109        &mut self,
110        handle: &GenServerHandle<Self>,
111        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
112        state: Self::State,
113    ) -> Result<(), GenServerError> {
114        match self.init(handle, state) {
115            Ok(new_state) => {
116                self.main_loop(handle, rx, new_state)?;
117                Ok(())
118            }
119            Err(err) => {
120                tracing::error!("Initialization failed: {err:?}");
121                Err(GenServerError::Initialization)
122            }
123        }
124    }
125
126    /// Initialization function. It's called before main loop. It
127    /// can be overrided on implementations in case initial steps are
128    /// required.
129    fn init(
130        &mut self,
131        _handle: &GenServerHandle<Self>,
132        state: Self::State,
133    ) -> Result<Self::State, Self::Error> {
134        Ok(state)
135    }
136
137    fn main_loop(
138        &mut self,
139        handle: &GenServerHandle<Self>,
140        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
141        mut state: Self::State,
142    ) -> Result<(), GenServerError> {
143        loop {
144            let (new_state, cont) = self.receive(handle, rx, state)?;
145            if !cont {
146                break;
147            }
148            state = new_state;
149        }
150        tracing::trace!("Stopping GenServer");
151        Ok(())
152    }
153
154    fn receive(
155        &mut self,
156        handle: &GenServerHandle<Self>,
157        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
158        state: Self::State,
159    ) -> Result<(Self::State, bool), GenServerError> {
160        let message = rx.recv().ok();
161
162        // Save current state in case of a rollback
163        let state_clone = state.clone();
164
165        let (keep_running, new_state) = match message {
166            Some(GenServerInMsg::Call { sender, message }) => {
167                let (keep_running, new_state, response) =
168                    match catch_unwind(AssertUnwindSafe(|| {
169                        self.handle_call(message, handle, state)
170                    })) {
171                        Ok(response) => match response {
172                            CallResponse::Reply(new_state, response) => {
173                                (true, new_state, Ok(response))
174                            }
175                            CallResponse::Stop(response) => (false, state_clone, Ok(response)),
176                            CallResponse::Unused => {
177                                tracing::error!("GenServer received unexpected CallMessage");
178                                (false, state_clone, Err(GenServerError::CallMsgUnused))
179                            }
180                        },
181                        Err(error) => {
182                            tracing::trace!(
183                                "Error in callback, reverting state - Error: '{error:?}'"
184                            );
185                            (true, state_clone, Err(GenServerError::Callback))
186                        }
187                    };
188                // Send response back
189                if sender.send(response).is_err() {
190                    tracing::trace!("GenServer failed to send response back, client must have died")
191                };
192                (keep_running, new_state)
193            }
194            Some(GenServerInMsg::Cast { message }) => {
195                match catch_unwind(AssertUnwindSafe(|| {
196                    self.handle_cast(message, handle, state)
197                })) {
198                    Ok(response) => match response {
199                        CastResponse::NoReply(new_state) => (true, new_state),
200                        CastResponse::Stop => (false, state_clone),
201                        CastResponse::Unused => {
202                            tracing::error!("GenServer received unexpected CastMessage");
203                            (false, state_clone)
204                        }
205                    },
206                    Err(error) => {
207                        tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
208                        (true, state_clone)
209                    }
210                }
211            }
212            None => {
213                // Channel has been closed; won't receive further messages. Stop the server.
214                (false, state)
215            }
216        };
217        Ok((new_state, keep_running))
218    }
219
220    fn handle_call(
221        &mut self,
222        _message: Self::CallMsg,
223        _handle: &GenServerHandle<Self>,
224        _state: Self::State,
225    ) -> CallResponse<Self> {
226        CallResponse::Unused
227    }
228
229    fn handle_cast(
230        &mut self,
231        _message: Self::CastMsg,
232        _handle: &GenServerHandle<Self>,
233        _state: Self::State,
234    ) -> CastResponse<Self> {
235        CastResponse::Unused
236    }
237}