spawned_concurrency/threads/
gen_server.rs1use spawned_rt::threads::{self as rt, mpsc, oneshot};
4use std::{
5 fmt::Debug,
6 panic::{catch_unwind, AssertUnwindSafe},
7};
8
9use crate::error::GenServerError;
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13 pub tx: mpsc::Sender<GenServerInMsg<G>>,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17 fn clone(&self) -> Self {
18 Self {
19 tx: self.tx.clone(),
20 }
21 }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25 pub(crate) fn new(initial_state: G::State) -> Self {
26 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27 let handle = GenServerHandle { tx };
28 let mut gen_server: G = GenServer::new();
29 let handle_clone = handle.clone();
30 let _join_handle = rt::spawn(move || {
32 if gen_server.run(&handle, &mut rx, initial_state).is_err() {
33 tracing::trace!("GenServer crashed")
34 };
35 });
36 handle_clone
37 }
38
39 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
40 self.tx.clone()
41 }
42
43 pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
44 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
45 self.tx.send(GenServerInMsg::Call {
46 sender: oneshot_tx,
47 message,
48 })?;
49 match oneshot_rx.recv() {
50 Ok(result) => result,
51 Err(_) => Err(GenServerError::Server),
52 }
53 }
54
55 pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
56 self.tx
57 .send(GenServerInMsg::Cast { message })
58 .map_err(|_error| GenServerError::Server)
59 }
60}
61
62pub enum GenServerInMsg<G: GenServer> {
63 Call {
64 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
65 message: G::CallMsg,
66 },
67 Cast {
68 message: G::CastMsg,
69 },
70}
71
72pub enum CallResponse<G: GenServer> {
73 Reply(G::State, G::OutMsg),
74 Unused,
75 Stop(G::OutMsg),
76}
77
78pub enum CastResponse<G: GenServer> {
79 NoReply(G::State),
80 Unused,
81 Stop,
82}
83
84pub trait GenServer
85where
86 Self: Send + Sized,
87{
88 type CallMsg: Clone + Send + Sized;
89 type CastMsg: Clone + Send + Sized;
90 type OutMsg: Send + Sized;
91 type State: Clone + Send;
92 type Error: Debug;
93
94 fn new() -> Self;
95
96 fn start(initial_state: Self::State) -> GenServerHandle<Self> {
97 GenServerHandle::new(initial_state)
98 }
99
100 fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
103 GenServerHandle::new(initial_state)
104 }
105
106 fn run(
107 &mut self,
108 handle: &GenServerHandle<Self>,
109 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
110 state: Self::State,
111 ) -> Result<(), GenServerError> {
112 match self.init(handle, state) {
113 Ok(new_state) => {
114 self.main_loop(handle, rx, new_state)?;
115 Ok(())
116 }
117 Err(err) => {
118 tracing::error!("Initialization failed: {err:?}");
119 Err(GenServerError::Initialization)
120 }
121 }
122 }
123
124 fn init(
128 &mut self,
129 _handle: &GenServerHandle<Self>,
130 state: Self::State,
131 ) -> Result<Self::State, Self::Error> {
132 Ok(state)
133 }
134
135 fn main_loop(
136 &mut self,
137 handle: &GenServerHandle<Self>,
138 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
139 mut state: Self::State,
140 ) -> Result<(), GenServerError> {
141 loop {
142 let (new_state, cont) = self.receive(handle, rx, state)?;
143 if !cont {
144 break;
145 }
146 state = new_state;
147 }
148 tracing::trace!("Stopping GenServer");
149 Ok(())
150 }
151
152 fn receive(
153 &mut self,
154 handle: &GenServerHandle<Self>,
155 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
156 state: Self::State,
157 ) -> Result<(Self::State, bool), GenServerError> {
158 let message = rx.recv().ok();
159
160 let state_clone = state.clone();
162
163 let (keep_running, new_state) = match message {
164 Some(GenServerInMsg::Call { sender, message }) => {
165 let (keep_running, new_state, response) =
166 match catch_unwind(AssertUnwindSafe(|| {
167 self.handle_call(message, handle, state)
168 })) {
169 Ok(response) => match response {
170 CallResponse::Reply(new_state, response) => {
171 (true, new_state, Ok(response))
172 }
173 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
174 CallResponse::Unused => {
175 tracing::error!("GenServer received unexpected CallMessage");
176 (false, state_clone, Err(GenServerError::CallMsgUnused))
177 }
178 },
179 Err(error) => {
180 tracing::trace!(
181 "Error in callback, reverting state - Error: '{error:?}'"
182 );
183 (true, state_clone, Err(GenServerError::Callback))
184 }
185 };
186 if sender.send(response).is_err() {
188 tracing::trace!("GenServer failed to send response back, client must have died")
189 };
190 (keep_running, new_state)
191 }
192 Some(GenServerInMsg::Cast { message }) => {
193 match catch_unwind(AssertUnwindSafe(|| {
194 self.handle_cast(message, handle, state)
195 })) {
196 Ok(response) => match response {
197 CastResponse::NoReply(new_state) => (true, new_state),
198 CastResponse::Stop => (false, state_clone),
199 CastResponse::Unused => {
200 tracing::error!("GenServer received unexpected CastMessage");
201 (false, state_clone)
202 }
203 },
204 Err(error) => {
205 tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
206 (true, state_clone)
207 }
208 }
209 }
210 None => {
211 (false, state)
213 }
214 };
215 Ok((new_state, keep_running))
216 }
217
218 fn handle_call(
219 &mut self,
220 _message: Self::CallMsg,
221 _handle: &GenServerHandle<Self>,
222 _state: Self::State,
223 ) -> CallResponse<Self> {
224 CallResponse::Unused
225 }
226
227 fn handle_cast(
228 &mut self,
229 _message: Self::CastMsg,
230 _handle: &GenServerHandle<Self>,
231 _state: Self::State,
232 ) -> CastResponse<Self> {
233 CastResponse::Unused
234 }
235}