spark_channel/
listener.rs

//! # Spark Listener
//!
//! This module provides generic channel plumbing for passing messages between modules. It is not tied to any particular module trait; instead, it is a generic building block for implementing such traits in an event-driven architecture.
4
5use core::any::Any;
6use std::{fmt::Debug, marker::PhantomData};
7
8use async_trait::async_trait;
9use tokio::sync::{mpsc, oneshot};
10
11use crate::callback::CallbackWrapper;
12
13/// A trait for handling module requests and commands.
14#[async_trait]
15pub trait SparkGenericModuleHandler<Message, Response, Error = eyre::Error> {
16    /// Handle a request, potentially returning an error.
17    async fn handle_request(&mut self, request: Message) -> Result<Response, Error>;
18
19    /// Handle a command, potentially returning an error.
20    async fn handle_command(&mut self, command: Message) -> Result<(), Error>;
21}
22
/// A token able to stop a running task when the module server is shut down.
///
/// [`run_module_server`] calls [`cancel_execution`](Self::cancel_execution) on the
/// token carried by a shutdown message, right before it leaves its receive loop.
pub trait SparkChannelCancellationTrait {
    /// Requests cancellation of the running task.
    fn cancel_execution(&self);
}
28
29/// The message type for the module dispatcher.
30pub enum SparkGenericModuleMessage<
31    Message,
32    Response,
33    CancellationMessage: SparkChannelCancellationTrait,
34    Error,
35> {
36    /// A request message with an error type.
37    Request(CallbackWrapper<Message, Result<Response, Error>>),
38
39    /// A command message.
40    Command(Message),
41
42    /// A shutdown message.
43    Shutdown(CancellationMessage),
44}
45
/// Adapts a `Result<Success, Error>` into the result type preferred by the caller.
///
/// The implementing type only selects which conversion is used: `into_result` is an
/// associated function and takes no `self`.
pub trait IntoResult<Success, Error> {
    /// The result type produced by the conversion.
    type Output;

    /// Converts `result` into [`Self::Output`].
    fn into_result(result: Result<Success, Error>) -> Self::Output;
}
54
55// Implementation for eyre::Result
56impl<T, E> IntoResult<T, E> for eyre::Result<T>
57where
58    E: std::fmt::Display + Send + Sync + 'static,
59{
60    type Output = eyre::Result<T>;
61
62    fn into_result(result: Result<T, E>) -> Self::Output {
63        result.map_err(|e| eyre::eyre!("{}", e))
64    }
65}
66
/// Default error type for the channel: a plain, human-readable message.
///
/// The dispatcher is generic over its error type, so any other type that is
/// `std::error::Error + Send + Sync + 'static` and convertible from `String` and
/// `&'static str` can be used instead.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SparkChannelError(pub String);

impl std::fmt::Display for SparkChannelError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for SparkChannelError {}

/// Builds an error from any borrowed string slice, not only `'static` literals.
///
/// This still satisfies `From<&'static str>` bounds, because `'static` is one of the
/// lifetimes covered by this impl.
impl From<&str> for SparkChannelError {
    fn from(s: &str) -> Self {
        Self(s.to_owned())
    }
}

/// Wraps an owned message without copying it.
impl From<String> for SparkChannelError {
    fn from(s: String) -> Self {
        Self(s)
    }
}
90
91impl<T> IntoResult<T, SparkChannelError> for Result<T, SparkChannelError> {
92    type Output = Result<T, SparkChannelError>;
93
94    fn into_result(result: Result<T, SparkChannelError>) -> Self::Output {
95        result
96    }
97}
98
99/// A dispatcher for module types.
100#[derive(Clone)]
101pub struct SparkGenericModuleDispatcher<
102    Message,
103    Response,
104    CancellationMessage: SparkChannelCancellationTrait,
105    Error,
106> {
107    /// The sender for the module dispatcher.
108    pub sender:
109        mpsc::Sender<SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>>,
110    /// Phantom data to mark the Error type
111    _error_marker: PhantomData<Error>,
112}
113
114impl<Message, Response, CancellationMessage, Error>
115    SparkGenericModuleDispatcher<Message, Response, CancellationMessage, Error>
116where
117    Error: std::error::Error + Send + Sync + 'static,
118    CancellationMessage: SparkChannelCancellationTrait,
119{
120    /// Creates a new module dispatcher.
121    #[must_use]
122    pub const fn new(
123        sender: mpsc::Sender<
124            SparkGenericModuleMessage<Message, Response, CancellationMessage, Error>,
125        >,
126    ) -> Self {
127        Self {
128            sender,
129            _error_marker: PhantomData,
130        }
131    }
132
133    /// Send a request and get typed response using your callback pattern.
134    async fn internal_request<Req, Resp, R>(&self, request: Req) -> R::Output
135    where
136        Req: Into<Message> + Send + 'static,
137        Resp: 'static + Send,
138        Response: 'static + Send + AsRef<dyn Any + Send>,
139        R: IntoResult<Resp, Error>,
140        Error: Debug + Send + Sync + From<String> + From<&'static str>,
141    {
142        let (callback_tx, callback_rx) = oneshot::channel();
143
144        let wrapper = CallbackWrapper {
145            message: request.into(),
146            sender: callback_tx,
147        };
148
149        // Handle channel send error
150        let send_result = self
151            .sender
152            .send(SparkGenericModuleMessage::Request(wrapper))
153            .await;
154
155        if let Err(_) = send_result {
156            return R::into_result(Err(Error::from("Failed to send request")));
157        }
158
159        // Handle channel receive error
160        let receive_result = callback_rx.await;
161        let result = match receive_result {
162            Ok(result) => result,
163            Err(_) => return R::into_result(Err(Error::from("Failed to receive response"))),
164        };
165
166        // Handle application result (success or error)
167        match result {
168            Ok(response) => {
169                // Try to downcast the response by moving it into a Box
170                // and then attempting to downcast the Box<dyn Any + Send> to Box<Resp>.
171                let boxed_response: Box<dyn Any + Send> = Box::new(response);
172                match boxed_response.downcast::<Resp>() {
173                    Ok(boxed_resp) => R::into_result(Ok(*boxed_resp)), // Unbox to get owned Resp
174                    Err(_) => R::into_result(Err(Error::from("Invalid response type"))),
175                }
176            },
177            Err(err) => {
178                // Use a static string with From trait to avoid lifetime issues
179                let error_message = format!("Error handling request: {:?}", err);
180                tracing::debug!("Received error on handler callback: {}", error_message);
181                R::into_result(Err(Error::from(error_message)))
182            },
183        }
184    }
185
186    /// Sends a request and get typed response using your callback pattern.
187    pub async fn request<Req>(&self, request: Req) -> Result<Response, Error>
188    where
189        Req: Into<Message> + Send + 'static,
190        Response: 'static + Send,
191        Response: AsRef<dyn Any + Send>,
192        Result<Response, Error>: IntoResult<Response, Error, Output = Result<Response, Error>>,
193        Error: Debug + Send + Sync + From<String> + From<&'static str>,
194    {
195        // Use the existing send_request method but with the known types
196        self.internal_request::<_, Response, Result<_, Error>>(request)
197            .await
198    }
199
200    /// Send a command (fire and forget).
201    pub async fn send_command<C, R>(&self, command: C) -> R::Output
202    where
203        C: Into<Message> + Send + 'static,
204        R: IntoResult<(), Error>,
205        Error: From<&'static str>,
206    {
207        let send_result = self
208            .sender
209            .send(SparkGenericModuleMessage::Command(command.into()))
210            .await;
211
212        match send_result {
213            Ok(_) => R::into_result(Ok(())),
214            Err(_) => R::into_result(Err(Error::from("Failed to send command"))),
215        }
216    }
217}
218
219/// Runs the module server.
220pub async fn run_module_server<Message, Response, CancellationToken, Error, H>(
221    mut handler: H,
222    mut receiver: mpsc::Receiver<
223        SparkGenericModuleMessage<Message, Response, CancellationToken, Error>,
224    >,
225) where
226    Message: Send + 'static,
227    Response: Debug + Send + 'static,
228    CancellationToken: SparkChannelCancellationTrait + Send + 'static,
229    Error: std::error::Error + Send + Sync + 'static,
230    H: SparkGenericModuleHandler<Message, Response, Error> + Send,
231{
232    while let Some(message) = receiver.recv().await {
233        match message {
234            SparkGenericModuleMessage::Request(wrapper) => {
235                let (request, callback) = wrapper.inner_owned();
236
237                // Call the handler and get the Result
238                let result = handler.handle_request(request).await;
239                if let Err(err) = &result {
240                    tracing::error!("Error handling request: {:?}", err);
241                }
242
243                if let Err(err) = callback.send(result) {
244                    tracing::error!("Failed to send result to callback: {:?}", err);
245                }
246            },
247            SparkGenericModuleMessage::Command(command) => {
248                let result = handler.handle_command(command).await;
249                if let Err(err) = &result {
250                    tracing::error!("Error handling command: {:?}", err);
251                }
252            },
253            SparkGenericModuleMessage::Shutdown(cancellation_token) => {
254                cancellation_token.cancel_execution();
255                break;
256            },
257        }
258    }
259}