1use core::any::Any;
6use std::{fmt::Debug, marker::PhantomData};
7
8use async_trait::async_trait;
9use tokio::sync::{mpsc, oneshot};
10
11use crate::callback::CallbackWrapper;
12
13#[async_trait]
15pub trait SparkGenericModuleHandler<Message, Response, Error = eyre::Error> {
16 async fn handle_request(&mut self, request: Message) -> Result<Response, Error>;
18
19 async fn handle_command(&mut self, command: Message) -> Result<(), Error>;
21}
22
/// Payload carried by a shutdown message.
///
/// The server loop invokes [`cancel_execution`](Self::cancel_execution) on the
/// payload right before it stops processing messages.
pub trait SparkChannelCancellationTrait {
    /// Signals cancellation; called through a shared reference.
    fn cancel_execution(&self);
}
28
29pub enum SparkGenericModuleMessage<
31 Message,
32 Response,
33 CancellationMessage: SparkChannelCancellationTrait,
34 Error,
35> {
36 Request(CallbackWrapper<Message, Result<Response, Error>>),
38
39 Command(Message),
41
42 Shutdown(CancellationMessage),
44}
45
/// Adapter that turns a plain `Result<Success, Error>` into the result type
/// the caller asked for.
///
/// The implementing type is only a selector: `into_result` is an associated
/// function (no `self`), and the produced value has type [`Self::Output`].
pub trait IntoResult<Success, Error> {
    /// The type handed back to the caller.
    type Output;

    /// Converts `result` into [`Self::Output`].
    fn into_result(result: Result<Success, Error>) -> Self::Output;
}
54
55impl<T, E> IntoResult<T, E> for eyre::Result<T>
57where
58 E: std::fmt::Display + Send + Sync + 'static,
59{
60 type Output = eyre::Result<T>;
61
62 fn into_result(result: Result<T, E>) -> Self::Output {
63 result.map_err(|e| eyre::eyre!("{}", e))
64 }
65}
66
/// String-backed error type for channel failures.
///
/// The wrapped message is public and is exactly what `Display` prints.
#[derive(Debug, Clone)]
pub struct SparkChannelError(pub String);

impl std::fmt::Display for SparkChannelError {
    /// Writes the wrapped message verbatim (width/fill flags are not applied).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for SparkChannelError {}

impl From<&'static str> for SparkChannelError {
    /// Builds an error from a static message by copying it into a `String`.
    fn from(message: &'static str) -> Self {
        Self(message.to_owned())
    }
}

impl From<String> for SparkChannelError {
    /// Builds an error that takes ownership of `message`.
    fn from(message: String) -> Self {
        Self(message)
    }
}
90
91impl<T> IntoResult<T, SparkChannelError> for Result<T, SparkChannelError> {
92 type Output = Result<T, SparkChannelError>;
93
94 fn into_result(result: Result<T, SparkChannelError>) -> Self::Output {
95 result
96 }
97}
98
99#[derive(Clone)]
101pub struct SparkGenericModuleDispatcher<
102 Message,
103 Response,
104 CancellationMessage: SparkChannelCancellationTrait,
105 Error,
106> {
107 pub sender:
109 mpsc::Sender<SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>>,
110 _error_marker: PhantomData<Error>,
112}
113
114impl<Message, Response, CancellationMessage, Error>
115 SparkGenericModuleDispatcher<Message, Response, CancellationMessage, Error>
116where
117 Error: std::error::Error + Send + Sync + 'static,
118 CancellationMessage: SparkChannelCancellationTrait,
119{
120 #[must_use]
122 pub const fn new(
123 sender: mpsc::Sender<
124 SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>,
125 >,
126 ) -> Self {
127 Self {
128 sender,
129 _error_marker: PhantomData,
130 }
131 }
132
133 async fn internal_request<Req, Resp, R>(&self, request: Req) -> R::Output
135 where
136 Req: Into<Message> + Send + 'static,
137 Resp: 'static + Send,
138 Response: 'static + Send + AsRef<dyn Any + Send>,
139 R: IntoResult<Resp, Error>,
140 Error: Debug + Send + Sync + From<String> + From<&'static str>,
141 {
142 let (callback_tx, callback_rx) = oneshot::channel();
143
144 let wrapper = CallbackWrapper {
145 message: request.into(),
146 sender: callback_tx,
147 };
148
149 let send_result = self
151 .sender
152 .send(SparkGenericModuleMessage::Request(wrapper))
153 .await;
154
155 if let Err(_) = send_result {
156 return R::into_result(Err(Error::from("Failed to send request")));
157 }
158
159 let receive_result = callback_rx.await;
161 let result = match receive_result {
162 Ok(result) => result,
163 Err(_) => return R::into_result(Err(Error::from("Failed to receive response"))),
164 };
165
166 match result {
168 Ok(response) => {
169 let boxed_response: Box<dyn Any + Send> = Box::new(response);
172 match boxed_response.downcast::<Resp>() {
173 Ok(boxed_resp) => R::into_result(Ok(*boxed_resp)), Err(_) => R::into_result(Err(Error::from("Invalid response type"))),
175 }
176 },
177 Err(err) => {
178 let error_message = format!("Error handling request: {:?}", err);
180 tracing::debug!("Received error on handler callback: {}", error_message);
181 R::into_result(Err(Error::from(error_message)))
182 },
183 }
184 }
185
186 pub async fn request<Req>(&self, request: Req) -> Result<Response, Error>
188 where
189 Req: Into<Message> + Send + 'static,
190 Response: 'static + Send,
191 Response: AsRef<dyn Any + Send>,
192 Result<Response, Error>: IntoResult<Response, Error, Output = Result<Response, Error>>,
193 Error: Debug + Send + Sync + From<String> + From<&'static str>,
194 {
195 self.internal_request::<_, Response, Result<_, Error>>(request)
197 .await
198 }
199
200 pub async fn send_command<C, R>(&self, command: C) -> R::Output
202 where
203 C: Into<Message> + Send + 'static,
204 R: IntoResult<(), Error>,
205 Error: From<&'static str>,
206 {
207 let send_result = self
208 .sender
209 .send(SparkGenericModuleMessage::Command(command.into()))
210 .await;
211
212 match send_result {
213 Ok(_) => R::into_result(Ok(())),
214 Err(_) => R::into_result(Err(Error::from("Failed to send command"))),
215 }
216 }
217}
218
219pub async fn run_module_server<Message, Response, CancellationToken, Error, H>(
221 mut handler: H,
222 mut receiver: mpsc::Receiver<
223 SparkGenericModuleMessage<Message, Response, CancellationToken, Error>,
224 >,
225) where
226 Message: Send + 'static,
227 Response: Debug + Send + 'static,
228 CancellationToken: SparkChannelCancellationTrait + Send + 'static,
229 Error: std::error::Error + Send + Sync + 'static,
230 H: SparkGenericModuleHandler<Message, Response, Error> + Send,
231{
232 while let Some(message) = receiver.recv().await {
233 match message {
234 SparkGenericModuleMessage::Request(wrapper) => {
235 let (request, callback) = wrapper.inner_owned();
236
237 let result = handler.handle_request(request).await;
239 if let Err(err) = &result {
240 tracing::error!("Error handling request: {:?}", err);
241 }
242
243 if let Err(err) = callback.send(result) {
244 tracing::error!("Failed to send result to callback: {:?}", err);
245 }
246 },
247 SparkGenericModuleMessage::Command(command) => {
248 let result = handler.handle_command(command).await;
249 if let Err(err) = &result {
250 tracing::error!("Error handling command: {:?}", err);
251 }
252 },
253 SparkGenericModuleMessage::Shutdown(cancellation_token) => {
254 cancellation_token.cancel_execution();
255 break;
256 },
257 }
258 }
259}