spawned_concurrency/threads/
gen_server.rs

1//! GenServer trait and structs to create an abstraction similar to Erlang gen_server.
2//! See examples/name_server for a usage example.
3use spawned_rt::threads::{self as rt, mpsc, oneshot, CancellationToken};
4use std::{
5    fmt::Debug,
6    panic::{catch_unwind, AssertUnwindSafe},
7};
8
9use crate::error::GenServerError;
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13    pub tx: mpsc::Sender<GenServerInMsg<G>>,
14    cancellation_token: CancellationToken,
15}
16
17impl<G: GenServer> Clone for GenServerHandle<G> {
18    fn clone(&self) -> Self {
19        Self {
20            tx: self.tx.clone(),
21            cancellation_token: self.cancellation_token.clone(),
22        }
23    }
24}
25
26impl<G: GenServer> GenServerHandle<G> {
27    pub(crate) fn new(gen_server: G) -> Self {
28        let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
29        let cancellation_token = CancellationToken::new();
30        let handle = GenServerHandle {
31            tx,
32            cancellation_token,
33        };
34        let handle_clone = handle.clone();
35        // Ignore the JoinHandle for now. Maybe we'll use it in the future
36        let _join_handle = rt::spawn(move || {
37            if gen_server.run(&handle, &mut rx).is_err() {
38                tracing::trace!("GenServer crashed")
39            };
40        });
41        handle_clone
42    }
43
44    pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
45        self.tx.clone()
46    }
47
48    pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
49        let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
50        self.tx.send(GenServerInMsg::Call {
51            sender: oneshot_tx,
52            message,
53        })?;
54        match oneshot_rx.recv() {
55            Ok(result) => result,
56            Err(_) => Err(GenServerError::Server),
57        }
58    }
59
60    pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
61        self.tx
62            .send(GenServerInMsg::Cast { message })
63            .map_err(|_error| GenServerError::Server)
64    }
65
66    pub fn cancellation_token(&self) -> CancellationToken {
67        self.cancellation_token.clone()
68    }
69}
70
71pub enum GenServerInMsg<G: GenServer> {
72    Call {
73        sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
74        message: G::CallMsg,
75    },
76    Cast {
77        message: G::CastMsg,
78    },
79}
80
81pub enum CallResponse<G: GenServer> {
82    Reply(G::OutMsg),
83    Unused,
84    Stop(G::OutMsg),
85}
86
/// Value returned by [`GenServer::handle_cast`] to steer the default receive loop.
///
/// The enum carries no data, so it derives the usual value traits, which makes
/// handler results easy to compare, copy and log.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastResponse {
    /// Message handled; keep the server running.
    NoReply,
    /// The cast was not handled; an error is logged and the server stops.
    Unused,
    /// Stop the server.
    Stop,
}
92
93pub trait GenServer: Send + Sized {
94    type CallMsg: Clone + Send + Sized;
95    type CastMsg: Clone + Send + Sized;
96    type OutMsg: Send + Sized;
97    type Error: Debug;
98
99    fn start(self) -> GenServerHandle<Self> {
100        GenServerHandle::new(self)
101    }
102
103    /// We copy the same interface as tasks, but all threads can work
104    /// while blocking by default
105    fn start_blocking(self) -> GenServerHandle<Self> {
106        GenServerHandle::new(self)
107    }
108
109    fn run(
110        self,
111        handle: &GenServerHandle<Self>,
112        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
113    ) -> Result<(), GenServerError> {
114        let mut cancellation_token = handle.cancellation_token.clone();
115        let res = match self.init(handle) {
116            Ok(new_state) => Ok(new_state.main_loop(handle, rx)?),
117            Err(err) => {
118                tracing::error!("Initialization failed: {err:?}");
119                Err(GenServerError::Initialization)
120            }
121        };
122        cancellation_token.cancel();
123        res
124    }
125
126    /// Initialization function. It's called before main loop. It
127    /// can be overrided on implementations in case initial steps are
128    /// required.
129    fn init(self, _handle: &GenServerHandle<Self>) -> Result<Self, Self::Error> {
130        Ok(self)
131    }
132
133    fn main_loop(
134        mut self,
135        handle: &GenServerHandle<Self>,
136        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
137    ) -> Result<(), GenServerError> {
138        loop {
139            if !self.receive(handle, rx)? {
140                break;
141            }
142        }
143        tracing::trace!("Stopping GenServer");
144        Ok(())
145    }
146
147    fn receive(
148        &mut self,
149        handle: &GenServerHandle<Self>,
150        rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
151    ) -> Result<bool, GenServerError> {
152        let message = rx.recv().ok();
153
154        let keep_running = match message {
155            Some(GenServerInMsg::Call { sender, message }) => {
156                let (keep_running, response) = match catch_unwind(AssertUnwindSafe(|| {
157                    self.handle_call(message, handle)
158                })) {
159                    Ok(response) => match response {
160                        CallResponse::Reply(response) => (true, Ok(response)),
161                        CallResponse::Stop(response) => (false, Ok(response)),
162                        CallResponse::Unused => {
163                            tracing::error!("GenServer received unexpected CallMessage");
164                            (false, Err(GenServerError::CallMsgUnused))
165                        }
166                    },
167                    Err(error) => {
168                        tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
169                        (true, Err(GenServerError::Callback))
170                    }
171                };
172                // Send response back
173                if sender.send(response).is_err() {
174                    tracing::trace!("GenServer failed to send response back, client must have died")
175                };
176                keep_running
177            }
178            Some(GenServerInMsg::Cast { message }) => {
179                match catch_unwind(AssertUnwindSafe(|| self.handle_cast(message, handle))) {
180                    Ok(response) => match response {
181                        CastResponse::NoReply => true,
182                        CastResponse::Stop => false,
183                        CastResponse::Unused => {
184                            tracing::error!("GenServer received unexpected CastMessage");
185                            false
186                        }
187                    },
188                    Err(error) => {
189                        tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
190                        true
191                    }
192                }
193            }
194            None => {
195                // Channel has been closed; won't receive further messages. Stop the server.
196                false
197            }
198        };
199        Ok(keep_running)
200    }
201
202    fn handle_call(
203        &mut self,
204        _message: Self::CallMsg,
205        _handle: &GenServerHandle<Self>,
206    ) -> CallResponse<Self> {
207        CallResponse::Unused
208    }
209
210    fn handle_cast(
211        &mut self,
212        _message: Self::CastMsg,
213        _handle: &GenServerHandle<Self>,
214    ) -> CastResponse {
215        CastResponse::Unused
216    }
217}