turbomcp_server/
runtime.rs

1//! Runtime components for bidirectional transport
2//!
3//! This module provides unified bidirectional communication support for all
4//! duplex transports (STDIO, TCP, Unix Socket, HTTP, WebSocket) with full MCP 2025-06-18 compliance.
5//!
6//! ## Architecture
7//!
8//! **Generic Abstraction**: `TransportDispatcher<T>` works with any `Transport`
9//! - Sends server-initiated requests via transport
10//! - Correlates responses with pending requests
11//! - Implements `ServerRequestDispatcher` trait
12//!
13//! **Specialized Implementations**:
14//! - `StdioDispatcher`: Optimized for stdin/stdout (line-delimited JSON)
15//! - `TransportDispatcher<TcpTransport>`: For TCP sockets
16//! - `TransportDispatcher<UnixTransport>`: For Unix domain sockets
17//! - `http::HttpDispatcher`: For HTTP + SSE sessions (feature-gated)
18//! - `websocket::WebSocketDispatcher`: For WebSocket connections (feature-gated)
19//!
20//! All share the same request correlation and error handling logic.
21
22// HTTP bidirectional support (feature-gated)
23#[cfg(feature = "http")]
24pub mod http;
25
26// WebSocket bidirectional support (feature-gated)
27#[cfg(feature = "websocket")]
28pub mod websocket;
29
30use std::collections::HashMap;
31use std::sync::Arc;
32
33use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
34use tokio::sync::{Mutex, mpsc, oneshot};
35use tokio::task::JoinSet;
36
37use turbomcp_protocol::RequestContext;
38use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse, JsonRpcVersion};
39use turbomcp_protocol::types::{
40    CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsRequest,
41    ListRootsResult, PingRequest, PingResult,
42};
43
44use crate::routing::{RequestRouter, ServerRequestDispatcher};
45use crate::{ServerError, ServerResult};
46
47type MessageId = turbomcp_protocol::MessageId;
48
49/// STDIO dispatcher for server-initiated requests
50///
51/// This dispatcher implements the MCP 2025-06-18 specification for stdio transport,
52/// allowing servers to make requests to clients (server→client capability).
53#[derive(Clone)]
54pub struct StdioDispatcher {
55    /// Channel for sending messages to stdout writer
56    request_tx: mpsc::UnboundedSender<StdioMessage>,
57    /// Pending server-initiated requests awaiting responses
58    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
59}
60
61impl std::fmt::Debug for StdioDispatcher {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        f.debug_struct("StdioDispatcher")
64            .field("has_request_tx", &true)
65            .field("pending_count", &"<mutex>")
66            .finish()
67    }
68}
69
70/// Internal message type for STDIO transport
71#[derive(Debug)]
72pub enum StdioMessage {
73    /// Server request to be sent to client
74    ServerRequest {
75        /// The JSON-RPC request
76        request: JsonRpcRequest,
77    },
78    /// Shutdown signal
79    Shutdown,
80}
81
82impl StdioDispatcher {
83    /// Create a new STDIO dispatcher
84    pub fn new(request_tx: mpsc::UnboundedSender<StdioMessage>) -> Self {
85        Self {
86            request_tx,
87            pending_requests: Arc::new(Mutex::new(HashMap::new())),
88        }
89    }
90
91    /// Send a JSON-RPC request and wait for response
92    async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
93        let (response_tx, response_rx) = oneshot::channel();
94
95        // Extract request ID for correlation
96        let request_id = match &request.id {
97            MessageId::String(s) => s.clone(),
98            MessageId::Number(n) => n.to_string(),
99            MessageId::Uuid(u) => u.to_string(),
100        };
101
102        // Register pending request
103        self.pending_requests
104            .lock()
105            .await
106            .insert(request_id.clone(), response_tx);
107
108        // Send to stdout writer
109        self.request_tx
110            .send(StdioMessage::ServerRequest { request })
111            .map_err(|e| ServerError::Handler {
112                message: format!("Failed to send request to stdout: {}", e),
113                context: Some("stdio_dispatcher".to_string()),
114            })?;
115
116        // Wait for response with timeout (60 seconds per MCP recommendation)
117        match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
118            Ok(Ok(response)) => Ok(response),
119            Ok(Err(_)) => Err(ServerError::Handler {
120                message: "Response channel closed".to_string(),
121                context: Some("stdio_dispatcher".to_string()),
122            }),
123            Err(_) => {
124                // Timeout - remove from pending
125                self.pending_requests.lock().await.remove(&request_id);
126                Err(ServerError::Handler {
127                    message: "Request timeout (60s)".to_string(),
128                    context: Some("stdio_dispatcher".to_string()),
129                })
130            }
131        }
132    }
133
134    /// Generate a unique request ID
135    fn generate_request_id() -> MessageId {
136        MessageId::String(uuid::Uuid::new_v4().to_string())
137    }
138}
139
140#[async_trait::async_trait]
141impl ServerRequestDispatcher for StdioDispatcher {
142    async fn send_elicitation(
143        &self,
144        request: ElicitRequest,
145        _ctx: RequestContext,
146    ) -> ServerResult<ElicitResult> {
147        let json_rpc_request = JsonRpcRequest {
148            jsonrpc: JsonRpcVersion,
149            method: "elicitation/create".to_string(),
150            params: Some(
151                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
152                    message: format!("Failed to serialize elicitation request: {}", e),
153                    context: Some("MCP compliance".to_string()),
154                })?,
155            ),
156            id: Self::generate_request_id(),
157        };
158
159        let response = self.send_request(json_rpc_request).await?;
160
161        if let Some(result) = response.result() {
162            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
163                message: format!("Invalid elicitation response format: {}", e),
164                context: Some("MCP compliance".to_string()),
165            })
166        } else if let Some(error) = response.error() {
167            // Preserve client error code by wrapping as Protocol error
168            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
169                error.code,
170                &error.message,
171            )))
172        } else {
173            Err(ServerError::Handler {
174                message: "Invalid elicitation response: missing result and error".to_string(),
175                context: Some("MCP compliance".to_string()),
176            })
177        }
178    }
179
180    async fn send_ping(
181        &self,
182        _request: PingRequest,
183        _ctx: RequestContext,
184    ) -> ServerResult<PingResult> {
185        let json_rpc_request = JsonRpcRequest {
186            jsonrpc: JsonRpcVersion,
187            method: "ping".to_string(),
188            params: None,
189            id: Self::generate_request_id(),
190        };
191
192        let response = self.send_request(json_rpc_request).await?;
193
194        if response.result().is_some() {
195            Ok(PingResult {
196                data: None,
197                _meta: None,
198            })
199        } else if let Some(error) = response.error() {
200            // Preserve client error code by wrapping as Protocol error
201            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
202                error.code,
203                &error.message,
204            )))
205        } else {
206            Err(ServerError::Handler {
207                message: "Invalid ping response".to_string(),
208                context: Some("MCP compliance".to_string()),
209            })
210        }
211    }
212
213    async fn send_create_message(
214        &self,
215        request: CreateMessageRequest,
216        _ctx: RequestContext,
217    ) -> ServerResult<CreateMessageResult> {
218        let json_rpc_request = JsonRpcRequest {
219            jsonrpc: JsonRpcVersion,
220            method: "sampling/createMessage".to_string(),
221            params: Some(
222                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
223                    message: format!("Failed to serialize sampling request: {}", e),
224                    context: Some("MCP compliance".to_string()),
225                })?,
226            ),
227            id: Self::generate_request_id(),
228        };
229
230        let response = self.send_request(json_rpc_request).await?;
231
232        if let Some(result) = response.result() {
233            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
234                message: format!("Invalid sampling response format: {}", e),
235                context: Some("MCP compliance".to_string()),
236            })
237        } else if let Some(error) = response.error() {
238            // Preserve client error code by wrapping as Protocol error
239            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
240                error.code,
241                &error.message,
242            )))
243        } else {
244            Err(ServerError::Handler {
245                message: "Invalid sampling response: missing result and error".to_string(),
246                context: Some("MCP compliance".to_string()),
247            })
248        }
249    }
250
251    async fn send_list_roots(
252        &self,
253        _request: ListRootsRequest,
254        _ctx: RequestContext,
255    ) -> ServerResult<ListRootsResult> {
256        let json_rpc_request = JsonRpcRequest {
257            jsonrpc: JsonRpcVersion,
258            method: "roots/list".to_string(),
259            params: None,
260            id: Self::generate_request_id(),
261        };
262
263        let response = self.send_request(json_rpc_request).await?;
264
265        if let Some(result) = response.result() {
266            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
267                message: format!("Invalid roots response format: {}", e),
268                context: Some("MCP compliance".to_string()),
269            })
270        } else if let Some(error) = response.error() {
271            // Preserve client error code by wrapping as Protocol error
272            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
273                error.code,
274                &error.message,
275            )))
276        } else {
277            Err(ServerError::Handler {
278                message: "Invalid roots response: missing result and error".to_string(),
279                context: Some("MCP compliance".to_string()),
280            })
281        }
282    }
283
284    fn supports_bidirectional(&self) -> bool {
285        true
286    }
287
288    async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
289        Ok(None)
290    }
291}
292
293/// Run MCP server over STDIO transport with full bidirectional support
294///
295/// This runtime implements the complete MCP 2025-06-18 stdio protocol:
296/// - Reads JSON-RPC from stdin (client requests AND server response correlations)
297/// - Writes JSON-RPC to stdout (server responses AND server requests)
298/// - Maintains request/response correlation
299/// - Handles errors per MCP spec
300/// - Manages task lifecycle with JoinSet for clean shutdown
301pub async fn run_stdio_bidirectional(
302    router: Arc<RequestRouter>,
303    dispatcher: StdioDispatcher,
304    mut request_rx: mpsc::UnboundedReceiver<StdioMessage>,
305) -> Result<(), Box<dyn std::error::Error>> {
306    let stdin = tokio::io::stdin();
307    let stdout = tokio::io::stdout();
308    let mut reader = BufReader::new(stdin);
309    let mut line = String::new();
310
311    let stdout = Arc::new(Mutex::new(stdout));
312    let pending_requests = Arc::clone(&dispatcher.pending_requests);
313
314    // ✅ Create JoinSet to manage all spawned tasks
315    let mut tasks = JoinSet::new();
316
317    // ✅ Spawn stdout writer task and store handle in JoinSet
318    let stdout_writer = Arc::clone(&stdout);
319    tasks.spawn(async move {
320        while let Some(msg) = request_rx.recv().await {
321            match msg {
322                StdioMessage::ServerRequest { request } => {
323                    if let Ok(json) = serde_json::to_string(&request) {
324                        let mut stdout = stdout_writer.lock().await;
325                        let _ = stdout.write_all(json.as_bytes()).await;
326                        let _ = stdout.write_all(b"\n").await;
327                        let _ = stdout.flush().await;
328                    }
329                }
330                StdioMessage::Shutdown => break,
331            }
332        }
333    });
334
335    // Main stdin reader loop
336    loop {
337        line.clear();
338        match reader.read_line(&mut line).await {
339            Ok(0) => break, // EOF
340            Ok(_) => {
341                if line.trim().is_empty() {
342                    continue;
343                }
344
345                // Try parsing as JSON-RPC response first (for server-initiated request responses)
346                if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
347                    let request_id = match &response.id {
348                        turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
349                            MessageId::String(s) => s.clone(),
350                            MessageId::Number(n) => n.to_string(),
351                            MessageId::Uuid(u) => u.to_string(),
352                        },
353                        _ => continue,
354                    };
355
356                    if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
357                        let _ = tx.send(response);
358                    }
359                    continue;
360                }
361
362                // Try parsing as JSON-RPC request (client-initiated)
363                if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&line) {
364                    let router = Arc::clone(&router);
365                    let stdout = Arc::clone(&stdout);
366
367                    // ✅ Spawn request handler and store handle in JoinSet
368                    tasks.spawn(async move {
369                        // Create properly configured context with server-to-client capabilities
370                        let ctx = router.create_context();
371                        let response = router.route(request, ctx).await;
372
373                        if let Ok(json) = serde_json::to_string(&response) {
374                            let mut stdout = stdout.lock().await;
375                            let _ = stdout.write_all(json.as_bytes()).await;
376                            let _ = stdout.write_all(b"\n").await;
377                            let _ = stdout.flush().await;
378                        }
379                    });
380                }
381            }
382            Err(_) => break,
383        }
384    }
385
386    // ✅ GRACEFUL SHUTDOWN: Wait for all tasks to complete
387    tracing::debug!(
388        "STDIO dispatcher shutting down, waiting for {} tasks",
389        tasks.len()
390    );
391
392    // Signal writer task to shutdown by dropping the channel
393    // The request_rx.recv() in the writer task will return None, causing it to exit
394    drop(dispatcher);
395
396    // Wait for all tasks to complete with timeout
397    let shutdown_timeout = std::time::Duration::from_secs(5);
398    let start = std::time::Instant::now();
399
400    while let Some(result) = tokio::time::timeout(
401        shutdown_timeout.saturating_sub(start.elapsed()),
402        tasks.join_next(),
403    )
404    .await
405    .ok()
406    .flatten()
407    {
408        match result {
409            Ok(()) => {
410                tracing::debug!("Task completed successfully during shutdown");
411            }
412            Err(e) if e.is_panic() => {
413                tracing::warn!("Task panicked during shutdown: {:?}", e);
414            }
415            Err(e) if e.is_cancelled() => {
416                tracing::debug!("Task was cancelled during shutdown");
417            }
418            Err(e) => {
419                tracing::debug!("Task error during shutdown: {:?}", e);
420            }
421        }
422    }
423
424    // ✅ Abort remaining tasks if timeout occurred
425    if !tasks.is_empty() {
426        tracing::warn!(
427            "Aborting {} tasks due to shutdown timeout ({}s)",
428            tasks.len(),
429            shutdown_timeout.as_secs()
430        );
431        tasks.shutdown().await;
432    }
433
434    tracing::debug!("STDIO dispatcher shutdown complete");
435    Ok(())
436}
437
438// ============================================================================
439// Generic Transport Dispatcher (TCP, Unix Socket, and other Transport impls)
440// ============================================================================
441
442/// Generic dispatcher for any Transport implementation
443///
444/// This provides bidirectional MCP support for any transport that implements
445/// the `Transport` trait. Unlike `StdioDispatcher` which directly reads/writes
446/// stdin/stdout, this uses the Transport trait's `send()` and `receive()` methods.
447///
448/// **Usage**:
449/// ```rust,ignore
450/// use turbomcp_transport::TcpTransport;
451/// use turbomcp_server::runtime::TransportDispatcher;
452///
453/// let addr = "127.0.0.1:8080".parse().unwrap();
454/// let transport = TcpTransport::new_server(addr);
455/// let dispatcher = TransportDispatcher::new(transport);
456/// ```
457pub struct TransportDispatcher<T>
458where
459    T: turbomcp_transport::Transport,
460{
461    /// The underlying transport
462    transport: Arc<T>,
463    /// Pending server-initiated requests awaiting responses
464    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
465}
466
467// Manual Clone implementation: Arc cloning doesn't require T: Clone
468impl<T> Clone for TransportDispatcher<T>
469where
470    T: turbomcp_transport::Transport,
471{
472    fn clone(&self) -> Self {
473        Self {
474            transport: Arc::clone(&self.transport),
475            pending_requests: Arc::clone(&self.pending_requests),
476        }
477    }
478}
479
480impl<T> std::fmt::Debug for TransportDispatcher<T>
481where
482    T: turbomcp_transport::Transport,
483{
484    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485        f.debug_struct("TransportDispatcher")
486            .field("transport_type", &self.transport.transport_type())
487            .field("pending_count", &"<mutex>")
488            .finish()
489    }
490}
491
492impl<T> TransportDispatcher<T>
493where
494    T: turbomcp_transport::Transport,
495{
496    /// Create a new transport dispatcher
497    pub fn new(transport: T) -> Self {
498        Self {
499            transport: Arc::new(transport),
500            pending_requests: Arc::new(Mutex::new(HashMap::new())),
501        }
502    }
503
504    /// Send a JSON-RPC request and wait for response
505    async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
506        use turbomcp_transport::{TransportMessage, core::TransportMessageMetadata};
507
508        let (response_tx, response_rx) = oneshot::channel();
509
510        // Extract request ID for correlation
511        let request_id = match &request.id {
512            MessageId::String(s) => s.clone(),
513            MessageId::Number(n) => n.to_string(),
514            MessageId::Uuid(u) => u.to_string(),
515        };
516
517        // Register pending request
518        self.pending_requests
519            .lock()
520            .await
521            .insert(request_id.clone(), response_tx);
522
523        // Serialize request to JSON
524        let json = serde_json::to_vec(&request).map_err(|e| ServerError::Handler {
525            message: format!("Failed to serialize request: {}", e),
526            context: Some("transport_dispatcher".to_string()),
527        })?;
528
529        // Send via transport
530        let transport_msg = TransportMessage::with_metadata(
531            MessageId::Uuid(uuid::Uuid::new_v4()),
532            bytes::Bytes::from(json),
533            TransportMessageMetadata::with_content_type("application/json"),
534        );
535
536        self.transport
537            .send(transport_msg)
538            .await
539            .map_err(|e| ServerError::Handler {
540                message: format!("Failed to send request via transport: {}", e),
541                context: Some("transport_dispatcher".to_string()),
542            })?;
543
544        // Wait for response with timeout (60 seconds per MCP recommendation)
545        match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
546            Ok(Ok(response)) => Ok(response),
547            Ok(Err(_)) => Err(ServerError::Handler {
548                message: "Response channel closed".to_string(),
549                context: Some("transport_dispatcher".to_string()),
550            }),
551            Err(_) => {
552                // Timeout - remove from pending
553                self.pending_requests.lock().await.remove(&request_id);
554                Err(ServerError::Handler {
555                    message: "Request timeout (60s)".to_string(),
556                    context: Some("transport_dispatcher".to_string()),
557                })
558            }
559        }
560    }
561
562    /// Generate a unique request ID
563    fn generate_request_id() -> MessageId {
564        MessageId::String(uuid::Uuid::new_v4().to_string())
565    }
566
567    /// Get access to pending requests for response correlation
568    pub fn pending_requests(
569        &self,
570    ) -> Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> {
571        Arc::clone(&self.pending_requests)
572    }
573
574    /// Get access to the transport
575    pub fn transport(&self) -> Arc<T> {
576        Arc::clone(&self.transport)
577    }
578}
579
580#[async_trait::async_trait]
581impl<T> ServerRequestDispatcher for TransportDispatcher<T>
582where
583    T: turbomcp_transport::Transport + Send + Sync + 'static,
584{
585    async fn send_elicitation(
586        &self,
587        request: ElicitRequest,
588        _ctx: RequestContext,
589    ) -> ServerResult<ElicitResult> {
590        let json_rpc_request = JsonRpcRequest {
591            jsonrpc: JsonRpcVersion,
592            method: "elicitation/create".to_string(),
593            params: Some(
594                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
595                    message: format!("Failed to serialize elicitation request: {}", e),
596                    context: Some("MCP compliance".to_string()),
597                })?,
598            ),
599            id: Self::generate_request_id(),
600        };
601
602        let response = self.send_request(json_rpc_request).await?;
603
604        if let Some(result) = response.result() {
605            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
606                message: format!("Invalid elicitation response format: {}", e),
607                context: Some("MCP compliance".to_string()),
608            })
609        } else if let Some(error) = response.error() {
610            // Preserve client error code by wrapping as Protocol error
611            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
612                error.code,
613                &error.message,
614            )))
615        } else {
616            Err(ServerError::Handler {
617                message: "Invalid elicitation response: missing result and error".to_string(),
618                context: Some("MCP compliance".to_string()),
619            })
620        }
621    }
622
623    async fn send_ping(
624        &self,
625        _request: PingRequest,
626        _ctx: RequestContext,
627    ) -> ServerResult<PingResult> {
628        let json_rpc_request = JsonRpcRequest {
629            jsonrpc: JsonRpcVersion,
630            method: "ping".to_string(),
631            params: None,
632            id: Self::generate_request_id(),
633        };
634
635        let response = self.send_request(json_rpc_request).await?;
636
637        if response.result().is_some() {
638            Ok(PingResult {
639                data: None,
640                _meta: None,
641            })
642        } else if let Some(error) = response.error() {
643            // Preserve client error code by wrapping as Protocol error
644            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
645                error.code,
646                &error.message,
647            )))
648        } else {
649            Err(ServerError::Handler {
650                message: "Invalid ping response".to_string(),
651                context: Some("MCP compliance".to_string()),
652            })
653        }
654    }
655
656    async fn send_create_message(
657        &self,
658        request: CreateMessageRequest,
659        _ctx: RequestContext,
660    ) -> ServerResult<CreateMessageResult> {
661        let json_rpc_request = JsonRpcRequest {
662            jsonrpc: JsonRpcVersion,
663            method: "sampling/createMessage".to_string(),
664            params: Some(
665                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
666                    message: format!("Failed to serialize sampling request: {}", e),
667                    context: Some("MCP compliance".to_string()),
668                })?,
669            ),
670            id: Self::generate_request_id(),
671        };
672
673        let response = self.send_request(json_rpc_request).await?;
674
675        if let Some(result) = response.result() {
676            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
677                message: format!("Invalid sampling response format: {}", e),
678                context: Some("MCP compliance".to_string()),
679            })
680        } else if let Some(error) = response.error() {
681            // Preserve client error code by wrapping as Protocol error
682            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
683                error.code,
684                &error.message,
685            )))
686        } else {
687            Err(ServerError::Handler {
688                message: "Invalid sampling response: missing result and error".to_string(),
689                context: Some("MCP compliance".to_string()),
690            })
691        }
692    }
693
694    async fn send_list_roots(
695        &self,
696        _request: ListRootsRequest,
697        _ctx: RequestContext,
698    ) -> ServerResult<ListRootsResult> {
699        let json_rpc_request = JsonRpcRequest {
700            jsonrpc: JsonRpcVersion,
701            method: "roots/list".to_string(),
702            params: None,
703            id: Self::generate_request_id(),
704        };
705
706        let response = self.send_request(json_rpc_request).await?;
707
708        if let Some(result) = response.result() {
709            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
710                message: format!("Invalid roots response format: {}", e),
711                context: Some("MCP compliance".to_string()),
712            })
713        } else if let Some(error) = response.error() {
714            // Preserve client error code by wrapping as Protocol error
715            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
716                error.code,
717                &error.message,
718            )))
719        } else {
720            Err(ServerError::Handler {
721                message: "Invalid roots response: missing result and error".to_string(),
722                context: Some("MCP compliance".to_string()),
723            })
724        }
725    }
726
727    fn supports_bidirectional(&self) -> bool {
728        self.transport.capabilities().supports_bidirectional
729    }
730
731    async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
732        Ok(None)
733    }
734}
735
736/// Run MCP server with any Transport implementation with full bidirectional support
737///
738/// This is a generic runtime that works with TCP, Unix Socket, or any other
739/// transport implementing the `Transport` trait.
740///
741/// **Architecture**:
742/// - Spawns receiver task: reads from transport, routes through router
743/// - Transport send: used for both responses and server-initiated requests
744/// - Correlation: matches responses to pending requests
745///
746/// **Usage**:
747/// ```rust,ignore
748/// use std::sync::Arc;
749/// use turbomcp_transport::TcpTransport;
750/// use turbomcp_server::runtime::{TransportDispatcher, run_transport_bidirectional};
751/// use turbomcp_server::routing::RequestRouter;
752///
753/// let addr = "127.0.0.1:8080".parse().unwrap();
754/// let transport = TcpTransport::new_server(addr);
755/// let dispatcher = TransportDispatcher::new(transport);
756/// let router = Arc::new(RequestRouter::new());
757///
758/// // In async context:
759/// run_transport_bidirectional(router, dispatcher).await?;
760/// ```
761pub async fn run_transport_bidirectional<T>(
762    router: Arc<RequestRouter>,
763    dispatcher: TransportDispatcher<T>,
764) -> Result<(), Box<dyn std::error::Error>>
765where
766    T: turbomcp_transport::Transport + Send + Sync + 'static,
767{
768    let transport = dispatcher.transport();
769    let pending_requests = dispatcher.pending_requests();
770
771    // Main message processing loop
772    loop {
773        // Receive message from transport
774        match transport.receive().await {
775            Ok(Some(message)) => {
776                // Try parsing as JSON-RPC response first (for server-initiated request responses)
777                if let Ok(response) = serde_json::from_slice::<JsonRpcResponse>(&message.payload) {
778                    let request_id = match &response.id {
779                        turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
780                            MessageId::String(s) => s.clone(),
781                            MessageId::Number(n) => n.to_string(),
782                            MessageId::Uuid(u) => u.to_string(),
783                        },
784                        _ => continue,
785                    };
786
787                    if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
788                        let _ = tx.send(response);
789                    }
790                    continue;
791                }
792
793                // Try parsing as JSON-RPC request (client-initiated)
794                if let Ok(request) = serde_json::from_slice::<JsonRpcRequest>(&message.payload) {
795                    let router = Arc::clone(&router);
796                    let transport = Arc::clone(&transport);
797
798                    tokio::spawn(async move {
799                        // Create properly configured context with server-to-client capabilities
800                        let ctx = router.create_context();
801                        let response = router.route(request, ctx).await;
802
803                        // Send response back via transport
804                        if let Ok(json) = serde_json::to_vec(&response) {
805                            use turbomcp_transport::{
806                                TransportMessage, core::TransportMessageMetadata,
807                            };
808                            let transport_msg = TransportMessage::with_metadata(
809                                MessageId::Uuid(uuid::Uuid::new_v4()),
810                                bytes::Bytes::from(json),
811                                TransportMessageMetadata::with_content_type("application/json"),
812                            );
813                            let _ = transport.send(transport_msg).await;
814                        }
815                    });
816                }
817            }
818            Ok(None) => {
819                // No message available, sleep briefly
820                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
821            }
822            Err(e) => {
823                tracing::error!(error = %e, "Transport receive failed");
824                break;
825            }
826        }
827    }
828
829    Ok(())
830}
831
832// ============================================================================
833// Unit Tests
834// ============================================================================
835
#[cfg(test)]
mod tests {
    use super::*;

    /// Dropping a freshly constructed `StdioDispatcher` must not panic.
    #[tokio::test]
    async fn test_stdio_dispatcher_clean_shutdown() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let dispatcher = StdioDispatcher::new(tx);

        // Should not panic on drop
        drop(dispatcher);
    }

    /// A `StdioDispatcher` can be constructed and cloned, and the sender it was
    /// built from remains usable afterwards.
    #[tokio::test]
    async fn test_stdio_dispatcher_creation() {
        // `_rx` is bound (not discarded with `_`) so the receiver stays alive;
        // a tokio unbounded `send` fails once its receiver has been dropped.
        let (tx, _rx) = mpsc::unbounded_channel();
        let dispatcher = StdioDispatcher::new(tx.clone());

        // Clone should work (needed for concurrent usage)
        let _dispatcher2 = dispatcher.clone();

        // Sending messages should work
        assert!(tx.send(StdioMessage::Shutdown).is_ok());
    }

    /// `JoinSet` tracks every spawned task and is empty once all are joined.
    #[tokio::test]
    async fn test_joinset_task_tracking() {
        let mut tasks = JoinSet::new();

        // Spawn tasks with staggered (short) sleeps so they finish at different times.
        for i in 0..5 {
            tasks.spawn(async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(i * 10)).await;
            });
        }

        assert_eq!(tasks.len(), 5);

        // Drain the set; every task must complete without panicking or being cancelled.
        let mut completed = 0;
        while let Some(result) = tasks.join_next().await {
            assert!(result.is_ok());
            completed += 1;
        }

        assert_eq!(completed, 5);
        assert!(tasks.is_empty());
    }

    /// Waiting on a slow `JoinSet` task can be bounded with `tokio::time::timeout`.
    #[tokio::test]
    async fn test_joinset_with_timeout() {
        let mut tasks = JoinSet::new();

        // A task that would outlive the test by far if it were not aborted.
        tasks.spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
        });

        let timeout = std::time::Duration::from_millis(100);
        let start = std::time::Instant::now();

        let result = tokio::time::timeout(timeout, tasks.join_next()).await;

        // The timeout must fire, and promptly (well before the task's 10s sleep).
        assert!(result.is_err());
        assert!(start.elapsed() < std::time::Duration::from_secs(1));

        // Abort the still-running task so it does not leak past the test.
        tasks.shutdown().await;
    }

    /// The `StdioMessage` variants can be constructed and matched.
    #[tokio::test]
    async fn test_stdio_message_types() {
        // `JsonRpcRequest`, `JsonRpcVersion` and `MessageId` are in scope via `super::*`.
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            method: "test".to_string(),
            params: None,
            id: MessageId::String("test-1".to_string()),
        };

        let msg = StdioMessage::ServerRequest { request };

        // Also check that the wrapped request is carried through unchanged.
        assert!(
            matches!(&msg, StdioMessage::ServerRequest { request } if request.method == "test"),
            "Expected ServerRequest"
        );

        let shutdown_msg = StdioMessage::Shutdown;
        assert!(
            matches!(shutdown_msg, StdioMessage::Shutdown),
            "Expected Shutdown"
        );
    }

    /// Entries in a pending-request map can be inserted and removed by id.
    #[tokio::test]
    async fn test_pending_requests_cleanup() {
        let pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
            Arc::new(Mutex::new(HashMap::new()));

        let (tx, _rx) = oneshot::channel();
        pending_requests
            .lock()
            .await
            .insert("test-id".to_string(), tx);

        assert_eq!(pending_requests.lock().await.len(), 1);

        // Removing by id yields the sender and leaves the map empty.
        let removed = pending_requests.lock().await.remove("test-id");
        assert!(removed.is_some());
        assert_eq!(pending_requests.lock().await.len(), 0);
    }
}