// spawned_concurrency/threads/gen_server.rs

//! GenServer trait and structs that provide an abstraction similar to Erlang's `gen_server`.
//! See `examples/name_server` for a usage example.
use spawned_rt::threads::{self as rt, mpsc, oneshot};
use std::{
    any::Any,
    fmt::Debug,
    panic::{catch_unwind, AssertUnwindSafe},
};

use crate::error::GenServerError;

11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13    pub tx: mpsc::Sender<GenServerInMsg<G>>,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17    fn clone(&self) -> Self {
18        Self {
19            tx: self.tx.clone(),
20        }
21    }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25    pub(crate) fn new(gen_server: G) -> Self {
26        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27        let handle = GenServerHandle { tx };
28        let handle_clone = handle.clone();
29        // Ignore the JoinHandle for now. Maybe we'll use it in the future
30        let _join_handle = rt::spawn(move || {
31            if gen_server.run(&handle, &mut rx).is_err() {
32                tracing::trace!("GenServer crashed")
33            };
34        });
35        handle_clone
36    }
37
38    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
39        self.tx.clone()
40    }
41
42    pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
43        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
44        self.tx.send(GenServerInMsg::Call {
45            sender: oneshot_tx,
46            message,
47        })?;
48        match oneshot_rx.recv() {
49            Ok(result) => result,
50            Err(_) => Err(GenServerError::Server),
51        }
52    }
53
54    pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
55        self.tx
56            .send(GenServerInMsg::Cast { message })
57            .map_err(|_error| GenServerError::Server)
58    }
59}
60
61pub enum GenServerInMsg<G: GenServer> {
62    Call {
63        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
64        message: G::CallMsg,
65    },
66    Cast {
67        message: G::CastMsg,
68    },
69}
70
71pub enum CallResponse<G: GenServer> {
72    Reply(G, G::OutMsg),
73    Unused,
74    Stop(G::OutMsg),
75}
76
77pub enum CastResponse<G: GenServer> {
78    NoReply(G),
79    Unused,
80    Stop,
81}
82
83pub trait GenServer: Send + Sized + Clone {
84    type CallMsg: Clone + Send + Sized;
85    type CastMsg: Clone + Send + Sized;
86    type OutMsg: Send + Sized;
87    type Error: Debug;
88
89    fn start(self) -> GenServerHandle<Self> {
90        GenServerHandle::new(self)
91    }
92
93    /// We copy the same interface as tasks, but all threads can work
94    /// while blocking by default
95    fn start_blocking(self) -> GenServerHandle<Self> {
96        GenServerHandle::new(self)
97    }
98
99    fn run(
100        self,
101        handle: &GenServerHandle<Self>,
102        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
103    ) -> Result<(), GenServerError> {
104        match self.init(handle) {
105            Ok(new_state) => {
106                new_state.main_loop(handle, rx)?;
107                Ok(())
108            }
109            Err(err) => {
110                tracing::error!("Initialization failed: {err:?}");
111                Err(GenServerError::Initialization)
112            }
113        }
114    }
115
116    /// Initialization function. It's called before main loop. It
117    /// can be overrided on implementations in case initial steps are
118    /// required.
119    fn init(self, _handle: &GenServerHandle<Self>) -> Result<Self, Self::Error> {
120        Ok(self)
121    }
122
123    fn main_loop(
124        mut self,
125        handle: &GenServerHandle<Self>,
126        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
127    ) -> Result<(), GenServerError> {
128        loop {
129            let (new_state, cont) = self.receive(handle, rx)?;
130            if !cont {
131                break;
132            }
133            self = new_state;
134        }
135        tracing::trace!("Stopping GenServer");
136        Ok(())
137    }
138
139    fn receive(
140        self,
141        handle: &GenServerHandle<Self>,
142        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
143    ) -> Result<(Self, bool), GenServerError> {
144        let message = rx.recv().ok();
145
146        // Save current state in case of a rollback
147        let state_clone = self.clone();
148
149        let (keep_running, new_state) = match message {
150            Some(GenServerInMsg::Call { sender, message }) => {
151                let (keep_running, new_state, response) =
152                    match catch_unwind(AssertUnwindSafe(|| self.handle_call(message, handle))) {
153                        Ok(response) => match response {
154                            CallResponse::Reply(new_state, response) => {
155                                (true, new_state, Ok(response))
156                            }
157                            CallResponse::Stop(response) => (false, state_clone, Ok(response)),
158                            CallResponse::Unused => {
159                                tracing::error!("GenServer received unexpected CallMessage");
160                                (false, state_clone, Err(GenServerError::CallMsgUnused))
161                            }
162                        },
163                        Err(error) => {
164                            tracing::trace!(
165                                "Error in callback, reverting state - Error: '{error:?}'"
166                            );
167                            (true, state_clone, Err(GenServerError::Callback))
168                        }
169                    };
170                // Send response back
171                if sender.send(response).is_err() {
172                    tracing::trace!("GenServer failed to send response back, client must have died")
173                };
174                (keep_running, new_state)
175            }
176            Some(GenServerInMsg::Cast { message }) => {
177                match catch_unwind(AssertUnwindSafe(|| self.handle_cast(message, handle))) {
178                    Ok(response) => match response {
179                        CastResponse::NoReply(new_state) => (true, new_state),
180                        CastResponse::Stop => (false, state_clone),
181                        CastResponse::Unused => {
182                            tracing::error!("GenServer received unexpected CastMessage");
183                            (false, state_clone)
184                        }
185                    },
186                    Err(error) => {
187                        tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
188                        (true, state_clone)
189                    }
190                }
191            }
192            None => {
193                // Channel has been closed; won't receive further messages. Stop the server.
194                (false, self)
195            }
196        };
197        Ok((new_state, keep_running))
198    }
199
200    fn handle_call(
201        self,
202        _message: Self::CallMsg,
203        _handle: &GenServerHandle<Self>,
204    ) -> CallResponse<Self> {
205        CallResponse::Unused
206    }
207
208    fn handle_cast(
209        self,
210        _message: Self::CastMsg,
211        _handle: &GenServerHandle<Self>,
212    ) -> CastResponse<Self> {
213        CastResponse::Unused
214    }
215}