// spawned_concurrency/threads/gen_server.rs
use std::{
    any::Any,
    fmt::Debug,
    panic::{catch_unwind, AssertUnwindSafe},
};

use spawned_rt::threads::{self as rt, mpsc, oneshot};

use crate::error::GenServerError;

11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13 pub tx: mpsc::Sender<GenServerInMsg<G>>,
14}
15
16impl<G: GenServer> Clone for GenServerHandle<G> {
17 fn clone(&self) -> Self {
18 Self {
19 tx: self.tx.clone(),
20 }
21 }
22}
23
24impl<G: GenServer> GenServerHandle<G> {
25 pub(crate) fn new(initial_state: G::State) -> Self {
26 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
27 let handle = GenServerHandle { tx };
28 let mut gen_server: G = GenServer::new();
29 let handle_clone = handle.clone();
30 let _join_handle = rt::spawn(move || {
32 if gen_server.run(&handle, &mut rx, initial_state).is_err() {
33 tracing::trace!("GenServer crashed")
34 };
35 });
36 handle_clone
37 }
38
39 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
40 self.tx.clone()
41 }
42
43 pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
44 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
45 self.tx.send(GenServerInMsg::Call {
46 sender: oneshot_tx,
47 message,
48 })?;
49 match oneshot_rx.recv() {
50 Ok(result) => result,
51 Err(_) => Err(GenServerError::Server),
52 }
53 }
54
55 pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
56 self.tx
57 .send(GenServerInMsg::Cast { message })
58 .map_err(|_error| GenServerError::Server)
59 }
60}
61
62pub enum GenServerInMsg<G: GenServer> {
63 Call {
64 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
65 message: G::CallMsg,
66 },
67 Cast {
68 message: G::CastMsg,
69 },
70}
71
72pub enum CallResponse<G: GenServer> {
73 Reply(G::State, G::OutMsg),
74 Unused,
75 Stop(G::OutMsg),
76}
77
78pub enum CastResponse<G: GenServer> {
79 NoReply(G::State),
80 Unused,
81 Stop,
82}
83
84pub trait GenServer
85where
86 Self: Default + Send + Sized,
87{
88 type CallMsg: Clone + Send + Sized;
89 type CastMsg: Clone + Send + Sized;
90 type OutMsg: Send + Sized;
91 type State: Clone + Send;
92 type Error: Debug;
93
94 fn new() -> Self {
95 Self::default()
96 }
97
98 fn start(initial_state: Self::State) -> GenServerHandle<Self> {
99 GenServerHandle::new(initial_state)
100 }
101
102 fn start_blocking(initial_state: Self::State) -> GenServerHandle<Self> {
105 GenServerHandle::new(initial_state)
106 }
107
108 fn run(
109 &mut self,
110 handle: &GenServerHandle<Self>,
111 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
112 state: Self::State,
113 ) -> Result<(), GenServerError> {
114 match self.init(handle, state) {
115 Ok(new_state) => {
116 self.main_loop(handle, rx, new_state)?;
117 Ok(())
118 }
119 Err(err) => {
120 tracing::error!("Initialization failed: {err:?}");
121 Err(GenServerError::Initialization)
122 }
123 }
124 }
125
126 fn init(
130 &mut self,
131 _handle: &GenServerHandle<Self>,
132 state: Self::State,
133 ) -> Result<Self::State, Self::Error> {
134 Ok(state)
135 }
136
137 fn main_loop(
138 &mut self,
139 handle: &GenServerHandle<Self>,
140 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
141 mut state: Self::State,
142 ) -> Result<(), GenServerError> {
143 loop {
144 let (new_state, cont) = self.receive(handle, rx, state)?;
145 if !cont {
146 break;
147 }
148 state = new_state;
149 }
150 tracing::trace!("Stopping GenServer");
151 Ok(())
152 }
153
154 fn receive(
155 &mut self,
156 handle: &GenServerHandle<Self>,
157 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
158 state: Self::State,
159 ) -> Result<(Self::State, bool), GenServerError> {
160 let message = rx.recv().ok();
161
162 let state_clone = state.clone();
164
165 let (keep_running, new_state) = match message {
166 Some(GenServerInMsg::Call { sender, message }) => {
167 let (keep_running, new_state, response) =
168 match catch_unwind(AssertUnwindSafe(|| {
169 self.handle_call(message, handle, state)
170 })) {
171 Ok(response) => match response {
172 CallResponse::Reply(new_state, response) => {
173 (true, new_state, Ok(response))
174 }
175 CallResponse::Stop(response) => (false, state_clone, Ok(response)),
176 CallResponse::Unused => {
177 tracing::error!("GenServer received unexpected CallMessage");
178 (false, state_clone, Err(GenServerError::CallMsgUnused))
179 }
180 },
181 Err(error) => {
182 tracing::trace!(
183 "Error in callback, reverting state - Error: '{error:?}'"
184 );
185 (true, state_clone, Err(GenServerError::Callback))
186 }
187 };
188 if sender.send(response).is_err() {
190 tracing::trace!("GenServer failed to send response back, client must have died")
191 };
192 (keep_running, new_state)
193 }
194 Some(GenServerInMsg::Cast { message }) => {
195 match catch_unwind(AssertUnwindSafe(|| {
196 self.handle_cast(message, handle, state)
197 })) {
198 Ok(response) => match response {
199 CastResponse::NoReply(new_state) => (true, new_state),
200 CastResponse::Stop => (false, state_clone),
201 CastResponse::Unused => {
202 tracing::error!("GenServer received unexpected CastMessage");
203 (false, state_clone)
204 }
205 },
206 Err(error) => {
207 tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
208 (true, state_clone)
209 }
210 }
211 }
212 None => {
213 (false, state)
215 }
216 };
217 Ok((new_state, keep_running))
218 }
219
220 fn handle_call(
221 &mut self,
222 _message: Self::CallMsg,
223 _handle: &GenServerHandle<Self>,
224 _state: Self::State,
225 ) -> CallResponse<Self> {
226 CallResponse::Unused
227 }
228
229 fn handle_cast(
230 &mut self,
231 _message: Self::CastMsg,
232 _handle: &GenServerHandle<Self>,
233 _state: Self::State,
234 ) -> CastResponse<Self> {
235 CastResponse::Unused
236 }
237}