spawned_concurrency/threads/
gen_server.rs1use spawned_rt::threads::{self as rt, mpsc, oneshot, CancellationToken};
4use std::{
5 fmt::Debug,
6 panic::{catch_unwind, AssertUnwindSafe},
7};
8
9use crate::error::GenServerError;
10
11#[derive(Debug)]
12pub struct GenServerHandle<G: GenServer + 'static> {
13 pub tx: mpsc::Sender<GenServerInMsg<G>>,
14 cancellation_token: CancellationToken,
15}
16
17impl<G: GenServer> Clone for GenServerHandle<G> {
18 fn clone(&self) -> Self {
19 Self {
20 tx: self.tx.clone(),
21 cancellation_token: self.cancellation_token.clone(),
22 }
23 }
24}
25
26impl<G: GenServer> GenServerHandle<G> {
27 pub(crate) fn new(gen_server: G) -> Self {
28 let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
29 let cancellation_token = CancellationToken::new();
30 let handle = GenServerHandle {
31 tx,
32 cancellation_token,
33 };
34 let handle_clone = handle.clone();
35 let _join_handle = rt::spawn(move || {
37 if gen_server.run(&handle, &mut rx).is_err() {
38 tracing::trace!("GenServer crashed")
39 };
40 });
41 handle_clone
42 }
43
44 pub fn sender(&self) -> mpsc::Sender<GenServerInMsg<G>> {
45 self.tx.clone()
46 }
47
48 pub fn call(&mut self, message: G::CallMsg) -> Result<G::OutMsg, GenServerError> {
49 let (oneshot_tx, oneshot_rx) = oneshot::channel::<Result<G::OutMsg, GenServerError>>();
50 self.tx.send(GenServerInMsg::Call {
51 sender: oneshot_tx,
52 message,
53 })?;
54 match oneshot_rx.recv() {
55 Ok(result) => result,
56 Err(_) => Err(GenServerError::Server),
57 }
58 }
59
60 pub fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
61 self.tx
62 .send(GenServerInMsg::Cast { message })
63 .map_err(|_error| GenServerError::Server)
64 }
65
66 pub fn cancellation_token(&self) -> CancellationToken {
67 self.cancellation_token.clone()
68 }
69}
70
71pub enum GenServerInMsg<G: GenServer> {
72 Call {
73 sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
74 message: G::CallMsg,
75 },
76 Cast {
77 message: G::CastMsg,
78 },
79}
80
81pub enum CallResponse<G: GenServer> {
82 Reply(G::OutMsg),
83 Unused,
84 Stop(G::OutMsg),
85}
86
/// Value returned by `GenServer::handle_cast`.
///
/// The effects described on each variant are those implemented by the default
/// `GenServer::receive`.
///
/// The enum is fieldless, so it derives the usual value-type traits; this lets
/// callers log responses and compare them in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CastResponse {
    /// Keep the server running.
    NoReply,
    /// Returned by the default `handle_cast`: an error is logged and the
    /// server stops.
    Unused,
    /// Stop the server.
    Stop,
}
92
93pub trait GenServer: Send + Sized {
94 type CallMsg: Clone + Send + Sized;
95 type CastMsg: Clone + Send + Sized;
96 type OutMsg: Send + Sized;
97 type Error: Debug;
98
99 fn start(self) -> GenServerHandle<Self> {
100 GenServerHandle::new(self)
101 }
102
103 fn start_blocking(self) -> GenServerHandle<Self> {
106 GenServerHandle::new(self)
107 }
108
109 fn run(
110 self,
111 handle: &GenServerHandle<Self>,
112 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
113 ) -> Result<(), GenServerError> {
114 let mut cancellation_token = handle.cancellation_token.clone();
115 let res = match self.init(handle) {
116 Ok(new_state) => Ok(new_state.main_loop(handle, rx)?),
117 Err(err) => {
118 tracing::error!("Initialization failed: {err:?}");
119 Err(GenServerError::Initialization)
120 }
121 };
122 cancellation_token.cancel();
123 res
124 }
125
126 fn init(self, _handle: &GenServerHandle<Self>) -> Result<Self, Self::Error> {
130 Ok(self)
131 }
132
133 fn main_loop(
134 mut self,
135 handle: &GenServerHandle<Self>,
136 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
137 ) -> Result<(), GenServerError> {
138 loop {
139 if !self.receive(handle, rx)? {
140 break;
141 }
142 }
143 tracing::trace!("Stopping GenServer");
144 Ok(())
145 }
146
147 fn receive(
148 &mut self,
149 handle: &GenServerHandle<Self>,
150 rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
151 ) -> Result<bool, GenServerError> {
152 let message = rx.recv().ok();
153
154 let keep_running = match message {
155 Some(GenServerInMsg::Call { sender, message }) => {
156 let (keep_running, response) = match catch_unwind(AssertUnwindSafe(|| {
157 self.handle_call(message, handle)
158 })) {
159 Ok(response) => match response {
160 CallResponse::Reply(response) => (true, Ok(response)),
161 CallResponse::Stop(response) => (false, Ok(response)),
162 CallResponse::Unused => {
163 tracing::error!("GenServer received unexpected CallMessage");
164 (false, Err(GenServerError::CallMsgUnused))
165 }
166 },
167 Err(error) => {
168 tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
169 (true, Err(GenServerError::Callback))
170 }
171 };
172 if sender.send(response).is_err() {
174 tracing::trace!("GenServer failed to send response back, client must have died")
175 };
176 keep_running
177 }
178 Some(GenServerInMsg::Cast { message }) => {
179 match catch_unwind(AssertUnwindSafe(|| self.handle_cast(message, handle))) {
180 Ok(response) => match response {
181 CastResponse::NoReply => true,
182 CastResponse::Stop => false,
183 CastResponse::Unused => {
184 tracing::error!("GenServer received unexpected CastMessage");
185 false
186 }
187 },
188 Err(error) => {
189 tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
190 true
191 }
192 }
193 }
194 None => {
195 false
197 }
198 };
199 Ok(keep_running)
200 }
201
202 fn handle_call(
203 &mut self,
204 _message: Self::CallMsg,
205 _handle: &GenServerHandle<Self>,
206 ) -> CallResponse<Self> {
207 CallResponse::Unused
208 }
209
210 fn handle_cast(
211 &mut self,
212 _message: Self::CastMsg,
213 _handle: &GenServerHandle<Self>,
214 ) -> CastResponse {
215 CastResponse::Unused
216 }
217}