turbomcp_client/client/dispatcher.rs
1//! Message dispatcher for routing JSON-RPC messages
2//!
3//! This module implements the message routing layer that solves the bidirectional
4//! communication problem. It runs a background task that reads ALL messages from
5//! the transport and routes them appropriately:
6//!
7//! - **Responses** → Routed to waiting `request()` calls via oneshot channels
8//! - **Requests** → Routed to registered request handler (for elicitation, sampling, etc.)
9//! - **Notifications** → Routed to registered notification handler
10//!
11//! ## Architecture
12//!
13//! ```text
//! ┌────────────────────────────────────────────────┐
//! │ MessageDispatcher                              │
//! │                                                │
//! │ Background Task (tokio::spawn):                │
//! │   loop {                                       │
//! │     msg = transport.receive().await            │
//! │     match parse(msg) {                         │
//! │       Response     => send to oneshot channel  │
//! │       Request      => call request_handler     │
//! │       Notification => call notif_handler       │
//! │     }                                          │
//! │   }                                            │
//! └────────────────────────────────────────────────┘
27//! ```
28//!
29//! This ensures that there's only ONE consumer of `transport.receive()`,
30//! eliminating race conditions by centralizing all message routing.
31
32use std::collections::HashMap;
33use std::sync::{Arc, Mutex}; // Use std::sync::Mutex for simpler synchronous access
34
35use tokio::sync::{Notify, oneshot};
36use turbomcp_protocol::jsonrpc::{
37 JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
38};
39use turbomcp_protocol::{Error, MessageId, Result};
40use turbomcp_transport::{Transport, TransportMessage};
41
42/// Type alias for request handler functions
43///
44/// The handler receives a request and processes it asynchronously.
45/// It's responsible for sending responses back via the transport.
46type RequestHandler = Arc<dyn Fn(JsonRpcRequest) -> Result<()> + Send + Sync>;
47
48/// Type alias for notification handler functions
49///
50/// The handler receives a notification and processes it asynchronously.
51type NotificationHandler = Arc<dyn Fn(JsonRpcNotification) -> Result<()> + Send + Sync>;
52
53/// Message dispatcher that routes incoming JSON-RPC messages
54///
55/// The dispatcher solves the bidirectional communication problem by being the
56/// SINGLE consumer of `transport.receive()`. It runs a background task that
57/// continuously reads messages and routes them to the appropriate handlers.
58///
59/// # Design Principles
60///
61/// 1. **Single Responsibility**: Only handles message routing, not processing
62/// 2. **Thread-Safe**: All state protected by Arc<Mutex<...>>
63/// 3. **Graceful Shutdown**: Supports clean shutdown via Notify signal
64/// 4. **Error Resilient**: Continues running even if individual messages fail
65/// 5. **Production-Ready**: Comprehensive logging and error handling
66///
67/// # Known Limitations
68///
69/// **Response Waiter Cleanup**: If a request is made but the response never arrives
70/// (e.g., server crash, network partition), the oneshot sender remains in the
71/// `response_waiters` HashMap indefinitely. This has minimal impact because:
72/// - Oneshot senders have a small memory footprint (~24 bytes)
73/// - In practice, responses arrive or clients timeout and drop the receiver
74/// - When a receiver is dropped, the send fails gracefully (error is ignored)
75///
76/// Future enhancement: Add a background cleanup task or request timeout mechanism
77/// to remove stale entries after a configurable duration.
78///
79/// # Example
80///
81/// ```rust,ignore
82/// let dispatcher = MessageDispatcher::new(Arc::new(transport));
83///
84/// // Register handlers
85/// dispatcher.set_request_handler(Arc::new(|req| {
86/// // Handle server-initiated requests (elicitation, sampling)
87/// Ok(())
88/// })).await;
89///
90/// // Wait for a response to a specific request
91/// let id = MessageId::from("req-123");
92/// let receiver = dispatcher.wait_for_response(id.clone()).await;
93///
94/// // The background task routes the response when it arrives
95/// let response = receiver.await?;
96/// ```
97pub(super) struct MessageDispatcher {
98 /// Map of request IDs to oneshot senders for response routing
99 ///
100 /// When `ProtocolClient::request()` sends a request, it registers a oneshot
101 /// channel here. When the dispatcher receives the corresponding response,
102 /// it sends it through the channel.
103 response_waiters: Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
104
105 /// Optional handler for server-initiated requests (elicitation, sampling)
106 ///
107 /// This is set by the Client to handle incoming requests from the server.
108 /// The handler is responsible for processing the request and sending a response.
109 request_handler: Arc<Mutex<Option<RequestHandler>>>,
110
111 /// Optional handler for server-initiated notifications
112 ///
113 /// This is set by the Client to handle incoming notifications from the server.
114 notification_handler: Arc<Mutex<Option<NotificationHandler>>>,
115
116 /// Shutdown signal for graceful termination
117 ///
118 /// When `shutdown()` is called, this notify wakes up the background task
119 /// which then exits cleanly.
120 shutdown: Arc<Notify>,
121}
122
123impl MessageDispatcher {
124 /// Create a new message dispatcher and start the background routing task
125 ///
126 /// The dispatcher immediately spawns a background task that continuously
127 /// reads messages from the transport and routes them appropriately.
128 ///
129 /// # Arguments
130 ///
131 /// * `transport` - The transport to read messages from
132 ///
133 /// # Returns
134 ///
135 /// Returns a new `MessageDispatcher` with the routing task running.
136 pub fn new<T: Transport + 'static>(transport: Arc<T>) -> Arc<Self> {
137 let dispatcher = Arc::new(Self {
138 response_waiters: Arc::new(Mutex::new(HashMap::new())),
139 request_handler: Arc::new(Mutex::new(None)),
140 notification_handler: Arc::new(Mutex::new(None)),
141 shutdown: Arc::new(Notify::new()),
142 });
143
144 // Start background routing task
145 Self::spawn_routing_task(dispatcher.clone(), transport);
146
147 dispatcher
148 }
149
150 /// Register a request handler for server-initiated requests
151 ///
152 /// This handler will be called when the server sends a request (like
153 /// elicitation/create or sampling/createMessage). The handler is responsible
154 /// for processing the request and sending a response back.
155 ///
156 /// # Arguments
157 ///
158 /// * `handler` - Function to handle incoming requests
159 pub fn set_request_handler(&self, handler: RequestHandler) {
160 *self.request_handler.lock().expect("handler mutex poisoned") = Some(handler);
161 tracing::debug!("Request handler registered with dispatcher");
162 }
163
164 /// Register a notification handler for server-initiated notifications
165 ///
166 /// This handler will be called when the server sends a notification.
167 ///
168 /// # Arguments
169 ///
170 /// * `handler` - Function to handle incoming notifications
171 pub fn set_notification_handler(&self, handler: NotificationHandler) {
172 *self
173 .notification_handler
174 .lock()
175 .expect("handler mutex poisoned") = Some(handler);
176 tracing::debug!("Notification handler registered with dispatcher");
177 }
178
179 /// Wait for a response to a specific request ID
180 ///
181 /// This method is called by `ProtocolClient::request()` before sending a request.
182 /// It registers a oneshot channel that will receive the response when it arrives.
183 ///
184 /// # Arguments
185 ///
186 /// * `id` - The request ID to wait for
187 ///
188 /// # Returns
189 ///
190 /// Returns a oneshot receiver that will be sent the response when it arrives.
191 ///
192 /// # Example
193 ///
194 /// ```rust,ignore
195 /// // Register waiter before sending request
196 /// let id = MessageId::from("req-123");
197 /// let receiver = dispatcher.wait_for_response(id.clone()).await;
198 ///
199 /// // Send request...
200 ///
201 /// // Wait for response
202 /// let response = receiver.await?;
203 /// ```
204 pub fn wait_for_response(&self, id: MessageId) -> oneshot::Receiver<JsonRpcResponse> {
205 let (tx, rx) = oneshot::channel();
206 self.response_waiters
207 .lock()
208 .expect("response_waiters mutex poisoned")
209 .insert(id.clone(), tx);
210 tracing::trace!("Registered response waiter for request ID: {:?}", id);
211 rx
212 }
213
214 /// Signal the dispatcher to shutdown gracefully
215 ///
216 /// This notifies the background routing task to exit cleanly.
217 /// The task will finish processing the current message and then terminate.
218 ///
219 /// This method is called automatically when the Client is dropped,
220 /// ensuring proper cleanup of background resources.
221 pub fn shutdown(&self) {
222 self.shutdown.notify_one();
223 tracing::info!("Message dispatcher shutdown initiated");
224 }
225
226 /// Spawn the background routing task
227 ///
228 /// This task continuously reads messages from the transport and routes them
229 /// to the appropriate handlers. It runs until `shutdown()` is called or
230 /// the transport is closed.
231 ///
232 /// # Arguments
233 ///
234 /// * `dispatcher` - Arc reference to the dispatcher
235 /// * `transport` - Arc reference to the transport
236 fn spawn_routing_task<T: Transport + 'static>(dispatcher: Arc<Self>, transport: Arc<T>) {
237 let response_waiters = dispatcher.response_waiters.clone();
238 let request_handler = dispatcher.request_handler.clone();
239 let notification_handler = dispatcher.notification_handler.clone();
240 let shutdown = dispatcher.shutdown.clone();
241
242 tokio::spawn(async move {
243 tracing::info!("Message dispatcher routing task started");
244
245 loop {
246 tokio::select! {
247 // Graceful shutdown
248 _ = shutdown.notified() => {
249 tracing::info!("Message dispatcher routing task shutting down");
250 break;
251 }
252
253 // Read and route messages
254 result = transport.receive() => {
255 match result {
256 Ok(Some(msg)) => {
257 // Route the message
258 if let Err(e) = Self::route_message(
259 msg,
260 &response_waiters,
261 &request_handler,
262 ¬ification_handler,
263 ).await {
264 tracing::error!("Error routing message: {}", e);
265 }
266 }
267 Ok(None) => {
268 // No message available - transport returned None
269 // Brief sleep to avoid busy-waiting
270 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
271 }
272 Err(e) => {
273 tracing::error!("Transport receive error: {}", e);
274 // Brief delay before retry to avoid tight error loop
275 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
276 }
277 }
278 }
279 }
280 }
281
282 tracing::info!("Message dispatcher routing task terminated");
283 });
284 }
285
286 /// Route an incoming message to the appropriate handler
287 ///
288 /// This is the core routing logic. It parses the raw transport message as
289 /// a JSON-RPC message and routes it based on type:
290 ///
291 /// - **Response**: Look up the waiting oneshot channel and send the response
292 /// - **Request**: Call the registered request handler
293 /// - **Notification**: Call the registered notification handler
294 ///
295 /// # Arguments
296 ///
297 /// * `msg` - The raw transport message to route
298 /// * `response_waiters` - Map of request IDs to oneshot senders
299 /// * `request_handler` - Optional request handler
300 /// * `notification_handler` - Optional notification handler
301 ///
302 /// # Errors
303 ///
304 /// Returns an error if the message cannot be parsed as valid JSON-RPC.
305 /// Handler errors are logged but do not propagate.
306 async fn route_message(
307 msg: TransportMessage,
308 response_waiters: &Arc<Mutex<HashMap<MessageId, oneshot::Sender<JsonRpcResponse>>>>,
309 request_handler: &Arc<Mutex<Option<RequestHandler>>>,
310 notification_handler: &Arc<Mutex<Option<NotificationHandler>>>,
311 ) -> Result<()> {
312 // Parse as JSON-RPC message
313 let json_msg: JsonRpcMessage = serde_json::from_slice(&msg.payload)
314 .map_err(|e| Error::protocol(format!("Invalid JSON-RPC message: {}", e)))?;
315
316 match json_msg {
317 JsonRpcMessage::Response(response) => {
318 // Route to waiting request() call
319 // ResponseId is Option<RequestId> where RequestId = MessageId
320 if let Some(request_id) = &response.id.0 {
321 if let Some(tx) = response_waiters
322 .lock()
323 .expect("response_waiters mutex poisoned")
324 .remove(request_id)
325 {
326 tracing::trace!("Routing response to request ID: {:?}", request_id);
327 // Send response through oneshot channel
328 // Ignore error if receiver was dropped (request timed out)
329 let _ = tx.send(response);
330 } else {
331 tracing::warn!(
332 "Received response for unknown/expired request ID: {:?}",
333 request_id
334 );
335 }
336 } else {
337 tracing::warn!("Received response with null ID (parse error)");
338 }
339 }
340
341 JsonRpcMessage::Request(request) => {
342 // Route to request handler (elicitation, sampling, etc.)
343 tracing::debug!(
344 "Routing server-initiated request: method={}, id={:?}",
345 request.method,
346 request.id
347 );
348
349 if let Some(handler) = request_handler
350 .lock()
351 .expect("request_handler mutex poisoned")
352 .as_ref()
353 {
354 // Call handler (handler is responsible for sending response)
355 if let Err(e) = handler(request) {
356 tracing::error!("Request handler error: {}", e);
357 }
358 } else {
359 tracing::warn!(
360 "Received server request but no handler registered: method={}",
361 request.method
362 );
363 }
364 }
365
366 JsonRpcMessage::Notification(notification) => {
367 // Route to notification handler
368 tracing::debug!(
369 "Routing server notification: method={}",
370 notification.method
371 );
372
373 if let Some(handler) = notification_handler
374 .lock()
375 .expect("notification_handler mutex poisoned")
376 .as_ref()
377 {
378 if let Err(e) = handler(notification) {
379 tracing::error!("Notification handler error: {}", e);
380 }
381 } else {
382 tracing::debug!(
383 "Received notification but no handler registered: method={}",
384 notification.method
385 );
386 }
387 }
388
389 JsonRpcMessage::RequestBatch(_)
390 | JsonRpcMessage::ResponseBatch(_)
391 | JsonRpcMessage::MessageBatch(_) => {
392 // Batch operations not yet supported
393 tracing::debug!("Received batch message (not yet supported)");
394 }
395 }
396
397 Ok(())
398 }
399}
400
401impl std::fmt::Debug for MessageDispatcher {
402 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
403 f.debug_struct("MessageDispatcher")
404 .field("response_waiters", &"<Arc<Mutex<HashMap>>>")
405 .field("request_handler", &"<Arc<Mutex<Option<Handler>>>>")
406 .field("notification_handler", &"<Arc<Mutex<Option<Handler>>>>")
407 .field("shutdown", &"<Arc<Notify>>")
408 .finish()
409 }
410}
411
412#[cfg(test)]
413mod tests {
414
415 // Note: Full integration tests with mock transport will be added
416 // in tests/bidirectional_integration.rs
417
418 #[test]
419 fn test_dispatcher_creation() {
420 // Smoke test to ensure the module compiles and basic structures work
421 // Full testing requires a mock transport
422 }
423}