turbomcp_client/client/
dispatcher.rs

1//! Message dispatcher for routing JSON-RPC messages
2//!
3//! This module implements the message routing layer that solves the bidirectional
4//! communication problem. It runs a background task that reads ALL messages from
5//! the transport and routes them appropriately:
6//!
7//! - **Responses** → Routed to waiting `request()` calls via oneshot channels
8//! - **Requests** → Routed to registered request handler (for elicitation, sampling, etc.)
9//! - **Notifications** → Routed to registered notification handler
10//!
11//! ## Architecture
12//!
13//! ```text
14//! ┌──────────────────────────────────────────────┐
15//! │          MessageDispatcher                   │
16//! │                                              │
17//! │  Background Task (tokio::spawn):             │
18//! │  loop {                                      │
19//! │    msg = transport.receive().await           │
20//! │    match parse(msg) {                        │
21//! │      Response => send to oneshot channel     │
22//! │      Request => call request_handler         │
23//! │      Notification => call notif_handler      │
24//! │    }                                         │
25//! │  }                                           │
26//! └──────────────────────────────────────────────┘
27//! ```
28//!
29//! This ensures that there's only ONE consumer of `transport.receive()`,
30//! eliminating race conditions by centralizing all message routing.
31
32use std::collections::HashMap;
33use std::sync::{Arc, Mutex}; // Use std::sync::Mutex for simpler synchronous access
34
35use tokio::sync::{Notify, oneshot};
36use turbomcp_protocol::jsonrpc::{
37    JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
38};
39use turbomcp_protocol::{Error, MessageId, Result};
40use turbomcp_transport::{Transport, TransportMessage};
41
42/// Type alias for request handler functions
43///
44/// The handler receives a request and processes it asynchronously.
45/// It's responsible for sending responses back via the transport.
46type RequestHandler = Arc<dyn Fn(JsonRpcRequest) -> Result<()> + Send + Sync>;
47
48/// Type alias for notification handler functions
49///
50/// The handler receives a notification and processes it asynchronously.
51type NotificationHandler = Arc<dyn Fn(JsonRpcNotification) -> Result<()> + Send + Sync>;
52
53/// Message dispatcher that routes incoming JSON-RPC messages
54///
55/// The dispatcher solves the bidirectional communication problem by being the
56/// SINGLE consumer of `transport.receive()`. It runs a background task that
57/// continuously reads messages and routes them to the appropriate handlers.
58///
59/// # Design Principles
60///
61/// 1. **Single Responsibility**: Only handles message routing, not processing
62/// 2. **Thread-Safe**: All state protected by Arc<Mutex<...>>
63/// 3. **Graceful Shutdown**: Supports clean shutdown via Notify signal
64/// 4. **Error Resilient**: Continues running even if individual messages fail
65/// 5. **Production-Ready**: Comprehensive logging and error handling
66///
67/// # Known Limitations
68///
69/// **Response Waiter Cleanup**: If a request is made but the response never arrives
70/// (e.g., server crash, network partition), the oneshot sender remains in the
71/// `response_waiters` HashMap indefinitely. This has minimal impact because:
72/// - Oneshot senders have a small memory footprint (~24 bytes)
73/// - In practice, responses arrive or clients timeout and drop the receiver
74/// - When a receiver is dropped, the send fails gracefully (error is ignored)
75///
76/// Future enhancement: Add a background cleanup task or request timeout mechanism
77/// to remove stale entries after a configurable duration.
78///
79/// # Example
80///
81/// ```rust,ignore
82/// let dispatcher = MessageDispatcher::new(Arc::new(transport));
83///
84/// // Register handlers
85/// dispatcher.set_request_handler(Arc::new(|req| {
86///     // Handle server-initiated requests (elicitation, sampling)
87///     Ok(())
88/// })).await;
89///
90/// // Wait for a response to a specific request
91/// let id = MessageId::from("req-123");
92/// let receiver = dispatcher.wait_for_response(id.clone()).await;
93///
94/// // The background task routes the response when it arrives
95/// let response = receiver.await?;
96/// ```
97pub(super) struct MessageDispatcher {
98    /// Map of request IDs to oneshot senders for response routing
99    ///
100    /// When `ProtocolClient::request()` sends a request, it registers a oneshot
101    /// channel here. When the dispatcher receives the corresponding response,
102    /// it sends it through the channel.
103    response_waiters: Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
104
105    /// Optional handler for server-initiated requests (elicitation, sampling)
106    ///
107    /// This is set by the Client to handle incoming requests from the server.
108    /// The handler is responsible for processing the request and sending a response.
109    request_handler: Arc<Mutex<Option<RequestHandler>>>,
110
111    /// Optional handler for server-initiated notifications
112    ///
113    /// This is set by the Client to handle incoming notifications from the server.
114    notification_handler: Arc<Mutex<Option<NotificationHandler>>>,
115
116    /// Shutdown signal for graceful termination
117    ///
118    /// When `shutdown()` is called, this notify wakes up the background task
119    /// which then exits cleanly.
120    shutdown: Arc<Notify>,
121}
122
123impl MessageDispatcher {
124    /// Create a new message dispatcher and start the background routing task
125    ///
126    /// The dispatcher immediately spawns a background task that continuously
127    /// reads messages from the transport and routes them appropriately.
128    ///
129    /// # Arguments
130    ///
131    /// * `transport` - The transport to read messages from
132    ///
133    /// # Returns
134    ///
135    /// Returns a new `MessageDispatcher` with the routing task running.
136    pub fn new<T: Transport + 'static>(transport: Arc<T>) -> Arc<Self> {
137        let dispatcher = Arc::new(Self {
138            response_waiters: Arc::new(Mutex::new(HashMap::new())),
139            request_handler: Arc::new(Mutex::new(None)),
140            notification_handler: Arc::new(Mutex::new(None)),
141            shutdown: Arc::new(Notify::new()),
142        });
143
144        // Start background routing task
145        Self::spawn_routing_task(dispatcher.clone(), transport);
146
147        dispatcher
148    }
149
150    /// Register a request handler for server-initiated requests
151    ///
152    /// This handler will be called when the server sends a request (like
153    /// elicitation/create or sampling/createMessage). The handler is responsible
154    /// for processing the request and sending a response back.
155    ///
156    /// # Arguments
157    ///
158    /// * `handler` - Function to handle incoming requests
159    pub fn set_request_handler(&self, handler: RequestHandler) {
160        *self.request_handler.lock().expect("handler mutex poisoned") = Some(handler);
161        tracing::debug!("Request handler registered with dispatcher");
162    }
163
164    /// Register a notification handler for server-initiated notifications
165    ///
166    /// This handler will be called when the server sends a notification.
167    ///
168    /// # Arguments
169    ///
170    /// * `handler` - Function to handle incoming notifications
171    pub fn set_notification_handler(&self, handler: NotificationHandler) {
172        *self
173            .notification_handler
174            .lock()
175            .expect("handler mutex poisoned") = Some(handler);
176        tracing::debug!("Notification handler registered with dispatcher");
177    }
178
179    /// Wait for a response to a specific request ID
180    ///
181    /// This method is called by `ProtocolClient::request()` before sending a request.
182    /// It registers a oneshot channel that will receive the response when it arrives.
183    ///
184    /// # Arguments
185    ///
186    /// * `id` - The request ID to wait for
187    ///
188    /// # Returns
189    ///
190    /// Returns a oneshot receiver that will be sent the response when it arrives.
191    ///
192    /// # Example
193    ///
194    /// ```rust,ignore
195    /// // Register waiter before sending request
196    /// let id = MessageId::from("req-123");
197    /// let receiver = dispatcher.wait_for_response(id.clone()).await;
198    ///
199    /// // Send request...
200    ///
201    /// // Wait for response
202    /// let response = receiver.await?;
203    /// ```
204    pub fn wait_for_response(&self, id: MessageId) -> oneshot::Receiver<JsonRpcResponse> {
205        let (tx, rx) = oneshot::channel();
206        self.response_waiters
207            .lock()
208            .expect("response_waiters mutex poisoned")
209            .insert(id.clone(), tx);
210        tracing::trace!("Registered response waiter for request ID: {:?}", id);
211        rx
212    }
213
214    /// Signal the dispatcher to shutdown gracefully
215    ///
216    /// This notifies the background routing task to exit cleanly.
217    /// The task will finish processing the current message and then terminate.
218    ///
219    /// This method is called automatically when the Client is dropped,
220    /// ensuring proper cleanup of background resources.
221    pub fn shutdown(&self) {
222        self.shutdown.notify_one();
223        tracing::info!("Message dispatcher shutdown initiated");
224    }
225
226    /// Spawn the background routing task
227    ///
228    /// This task continuously reads messages from the transport and routes them
229    /// to the appropriate handlers. It runs until `shutdown()` is called or
230    /// the transport is closed.
231    ///
232    /// # Arguments
233    ///
234    /// * `dispatcher` - Arc reference to the dispatcher
235    /// * `transport` - Arc reference to the transport
236    fn spawn_routing_task<T: Transport + 'static>(dispatcher: Arc<Self>, transport: Arc<T>) {
237        let response_waiters = dispatcher.response_waiters.clone();
238        let request_handler = dispatcher.request_handler.clone();
239        let notification_handler = dispatcher.notification_handler.clone();
240        let shutdown = dispatcher.shutdown.clone();
241
242        tokio::spawn(async move {
243            tracing::info!("Message dispatcher routing task started");
244
245            loop {
246                tokio::select! {
247                    // Graceful shutdown
248                    _ = shutdown.notified() => {
249                        tracing::info!("Message dispatcher routing task shutting down");
250                        break;
251                    }
252
253                    // Read and route messages
254                    result = transport.receive() => {
255                        match result {
256                            Ok(Some(msg)) => {
257                                // Route the message
258                                if let Err(e) = Self::route_message(
259                                    msg,
260                                    &response_waiters,
261                                    &request_handler,
262                                    &notification_handler,
263                                ).await {
264                                    tracing::error!("Error routing message: {}", e);
265                                }
266                            }
267                            Ok(None) => {
268                                // No message available - transport returned None
269                                // Brief sleep to avoid busy-waiting
270                                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
271                            }
272                            Err(e) => {
273                                tracing::error!("Transport receive error: {}", e);
274                                // Brief delay before retry to avoid tight error loop
275                                tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
276                            }
277                        }
278                    }
279                }
280            }
281
282            tracing::info!("Message dispatcher routing task terminated");
283        });
284    }
285
286    /// Route an incoming message to the appropriate handler
287    ///
288    /// This is the core routing logic. It parses the raw transport message as
289    /// a JSON-RPC message and routes it based on type:
290    ///
291    /// - **Response**: Look up the waiting oneshot channel and send the response
292    /// - **Request**: Call the registered request handler
293    /// - **Notification**: Call the registered notification handler
294    ///
295    /// # Arguments
296    ///
297    /// * `msg` - The raw transport message to route
298    /// * `response_waiters` - Map of request IDs to oneshot senders
299    /// * `request_handler` - Optional request handler
300    /// * `notification_handler` - Optional notification handler
301    ///
302    /// # Errors
303    ///
304    /// Returns an error if the message cannot be parsed as valid JSON-RPC.
305    /// Handler errors are logged but do not propagate.
306    async fn route_message(
307        msg: TransportMessage,
308        response_waiters: &Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
309        request_handler: &Arc<Mutex<Option<RequestHandler>>>,
310        notification_handler: &Arc<Mutex<Option<NotificationHandler>>>,
311    ) -> Result<()> {
312        // Parse as JSON-RPC message
313        let json_msg: JsonRpcMessage = serde_json::from_slice(&msg.payload)
314            .map_err(|e| Error::protocol(format!("Invalid JSON-RPC message: {}", e)))?;
315
316        match json_msg {
317            JsonRpcMessage::Response(response) => {
318                // Route to waiting request() call
319                // ResponseId is Option<RequestId> where RequestId = MessageId
320                if let Some(request_id) = &response.id.0 {
321                    if let Some(tx) = response_waiters
322                        .lock()
323                        .expect("response_waiters mutex poisoned")
324                        .remove(request_id)
325                    {
326                        tracing::trace!("Routing response to request ID: {:?}", request_id);
327                        // Send response through oneshot channel
328                        // Ignore error if receiver was dropped (request timed out)
329                        let _ = tx.send(response);
330                    } else {
331                        tracing::warn!(
332                            "Received response for unknown/expired request ID: {:?}",
333                            request_id
334                        );
335                    }
336                } else {
337                    tracing::warn!("Received response with null ID (parse error)");
338                }
339            }
340
341            JsonRpcMessage::Request(request) => {
342                // Route to request handler (elicitation, sampling, etc.)
343                tracing::debug!(
344                    "Routing server-initiated request: method={}, id={:?}",
345                    request.method,
346                    request.id
347                );
348
349                if let Some(handler) = request_handler
350                    .lock()
351                    .expect("request_handler mutex poisoned")
352                    .as_ref()
353                {
354                    // Call handler (handler is responsible for sending response)
355                    if let Err(e) = handler(request) {
356                        tracing::error!("Request handler error: {}", e);
357                    }
358                } else {
359                    tracing::warn!(
360                        "Received server request but no handler registered: method={}",
361                        request.method
362                    );
363                }
364            }
365
366            JsonRpcMessage::Notification(notification) => {
367                // Route to notification handler
368                tracing::debug!(
369                    "Routing server notification: method={}",
370                    notification.method
371                );
372
373                if let Some(handler) = notification_handler
374                    .lock()
375                    .expect("notification_handler mutex poisoned")
376                    .as_ref()
377                {
378                    if let Err(e) = handler(notification) {
379                        tracing::error!("Notification handler error: {}", e);
380                    }
381                } else {
382                    tracing::debug!(
383                        "Received notification but no handler registered: method={}",
384                        notification.method
385                    );
386                }
387            }
388
389            JsonRpcMessage::RequestBatch(_)
390            | JsonRpcMessage::ResponseBatch(_)
391            | JsonRpcMessage::MessageBatch(_) => {
392                // Batch operations not yet supported
393                tracing::debug!("Received batch message (not yet supported)");
394            }
395        }
396
397        Ok(())
398    }
399}
400
401impl std::fmt::Debug for MessageDispatcher {
402    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403        f.debug_struct("MessageDispatcher")
404            .field("response_waiters", &"<Arc<Mutex<HashMap>>>")
405            .field("request_handler", &"<Arc<Mutex<Option<Handler>>>>")
406            .field("notification_handler", &"<Arc<Mutex<Option<Handler>>>>")
407            .field("shutdown", &"<Arc<Notify>>")
408            .finish()
409    }
410}
411
#[cfg(test)]
mod tests {
    use super::*;

    // Note: Full integration tests with mock transport will be added
    // in tests/bidirectional_integration.rs

    /// Build a dispatcher directly, without spawning the routing task. This lets
    /// the synchronous registration APIs be exercised without a transport or a
    /// Tokio runtime.
    fn idle_dispatcher() -> MessageDispatcher {
        MessageDispatcher {
            response_waiters: Arc::new(Mutex::new(HashMap::new())),
            request_handler: Arc::new(Mutex::new(None)),
            notification_handler: Arc::new(Mutex::new(None)),
            shutdown: Arc::new(Notify::new()),
        }
    }

    #[test]
    fn handlers_start_unset_and_can_be_registered() {
        let dispatcher = idle_dispatcher();
        assert!(dispatcher.request_handler.lock().unwrap().is_none());
        assert!(dispatcher.notification_handler.lock().unwrap().is_none());

        dispatcher.set_request_handler(Arc::new(|_req: JsonRpcRequest| -> Result<()> {
            Ok(())
        }));
        dispatcher.set_notification_handler(Arc::new(
            |_notification: JsonRpcNotification| -> Result<()> { Ok(()) },
        ));

        assert!(dispatcher.request_handler.lock().unwrap().is_some());
        assert!(dispatcher.notification_handler.lock().unwrap().is_some());
    }

    #[test]
    fn debug_output_uses_placeholders() {
        let rendered = format!("{:?}", idle_dispatcher());
        assert!(rendered.starts_with("MessageDispatcher"));
        assert!(rendered.contains("response_waiters"));
        assert!(rendered.contains("<Arc<Notify>>"));
    }
}