turbomcp_server/
runtime.rs

1//! Runtime components for bidirectional transport
2//!
3//! This module provides unified bidirectional communication support for all
4//! duplex transports (STDIO, TCP, Unix Socket, HTTP, WebSocket) with full MCP 2025-06-18 compliance.
5//!
6//! ## Architecture
7//!
8//! **Generic Abstraction**: `TransportDispatcher<T>` works with any `Transport`
9//! - Sends server-initiated requests via transport
10//! - Correlates responses with pending requests
11//! - Implements `ServerRequestDispatcher` trait
12//!
13//! **Specialized Implementations**:
14//! - `StdioDispatcher`: Optimized for stdin/stdout (line-delimited JSON)
15//! - `TransportDispatcher<TcpTransport>`: For TCP sockets
16//! - `TransportDispatcher<UnixTransport>`: For Unix domain sockets
17//! - `http::HttpDispatcher`: For HTTP + SSE sessions (feature-gated)
18//! - WebSocket: Uses transport layer's bidirectional infrastructure (`turbomcp_transport::axum::WebSocketDispatcher`)
19//!
20//! All share the same request correlation and error handling logic.
21
22// HTTP bidirectional support (feature-gated)
23#[cfg(feature = "http")]
24pub mod http;
25
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
30use tokio::sync::{Mutex, mpsc, oneshot};
31use tokio::task::JoinSet;
32
33use turbomcp_protocol::RequestContext;
34use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse, JsonRpcVersion};
35use turbomcp_protocol::types::{
36    CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsRequest,
37    ListRootsResult, PingRequest, PingResult,
38};
39
40use crate::routing::{RequestRouter, ServerRequestDispatcher};
41use crate::{ServerError, ServerResult};
42
43type MessageId = turbomcp_protocol::MessageId;
44
45/// STDIO dispatcher for server-initiated requests
46///
47/// This dispatcher implements the MCP 2025-06-18 specification for stdio transport,
48/// allowing servers to make requests to clients (server→client capability).
49#[derive(Clone)]
50pub struct StdioDispatcher {
51    /// Channel for sending messages to stdout writer
52    request_tx: mpsc::UnboundedSender<StdioMessage>,
53    /// Pending server-initiated requests awaiting responses
54    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
55}
56
57impl std::fmt::Debug for StdioDispatcher {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("StdioDispatcher")
60            .field("has_request_tx", &true)
61            .field("pending_count", &"<mutex>")
62            .finish()
63    }
64}
65
66/// Internal message type for STDIO transport
67#[derive(Debug)]
68pub enum StdioMessage {
69    /// Server request to be sent to client
70    ServerRequest {
71        /// The JSON-RPC request
72        request: JsonRpcRequest,
73    },
74    /// Shutdown signal
75    Shutdown,
76}
77
78impl StdioDispatcher {
79    /// Create a new STDIO dispatcher
80    pub fn new(request_tx: mpsc::UnboundedSender<StdioMessage>) -> Self {
81        Self {
82            request_tx,
83            pending_requests: Arc::new(Mutex::new(HashMap::new())),
84        }
85    }
86
87    /// Send a JSON-RPC request and wait for response
88    async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
89        let (response_tx, response_rx) = oneshot::channel();
90
91        // Extract request ID for correlation
92        let request_id = match &request.id {
93            MessageId::String(s) => s.clone(),
94            MessageId::Number(n) => n.to_string(),
95            MessageId::Uuid(u) => u.to_string(),
96        };
97
98        // Register pending request
99        self.pending_requests
100            .lock()
101            .await
102            .insert(request_id.clone(), response_tx);
103
104        // Send to stdout writer
105        self.request_tx
106            .send(StdioMessage::ServerRequest { request })
107            .map_err(|e| ServerError::Handler {
108                message: format!("Failed to send request to stdout: {}", e),
109                context: Some("stdio_dispatcher".to_string()),
110            })?;
111
112        // Wait for response with timeout (60 seconds per MCP recommendation)
113        match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
114            Ok(Ok(response)) => Ok(response),
115            Ok(Err(_)) => Err(ServerError::Handler {
116                message: "Response channel closed".to_string(),
117                context: Some("stdio_dispatcher".to_string()),
118            }),
119            Err(_) => {
120                // Timeout - remove from pending
121                self.pending_requests.lock().await.remove(&request_id);
122                Err(ServerError::Handler {
123                    message: "Request timeout (60s)".to_string(),
124                    context: Some("stdio_dispatcher".to_string()),
125                })
126            }
127        }
128    }
129
130    /// Generate a unique request ID
131    fn generate_request_id() -> MessageId {
132        MessageId::String(uuid::Uuid::new_v4().to_string())
133    }
134}
135
136#[async_trait::async_trait]
137impl ServerRequestDispatcher for StdioDispatcher {
138    async fn send_elicitation(
139        &self,
140        request: ElicitRequest,
141        _ctx: RequestContext,
142    ) -> ServerResult<ElicitResult> {
143        let json_rpc_request = JsonRpcRequest {
144            jsonrpc: JsonRpcVersion,
145            method: "elicitation/create".to_string(),
146            params: Some(
147                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
148                    message: format!("Failed to serialize elicitation request: {}", e),
149                    context: Some("MCP compliance".to_string()),
150                })?,
151            ),
152            id: Self::generate_request_id(),
153        };
154
155        let response = self.send_request(json_rpc_request).await?;
156
157        if let Some(result) = response.result() {
158            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
159                message: format!("Invalid elicitation response format: {}", e),
160                context: Some("MCP compliance".to_string()),
161            })
162        } else if let Some(error) = response.error() {
163            // Preserve client error code by wrapping as Protocol error
164            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
165                error.code,
166                &error.message,
167            )))
168        } else {
169            Err(ServerError::Handler {
170                message: "Invalid elicitation response: missing result and error".to_string(),
171                context: Some("MCP compliance".to_string()),
172            })
173        }
174    }
175
176    async fn send_ping(
177        &self,
178        _request: PingRequest,
179        _ctx: RequestContext,
180    ) -> ServerResult<PingResult> {
181        let json_rpc_request = JsonRpcRequest {
182            jsonrpc: JsonRpcVersion,
183            method: "ping".to_string(),
184            params: None,
185            id: Self::generate_request_id(),
186        };
187
188        let response = self.send_request(json_rpc_request).await?;
189
190        if response.result().is_some() {
191            Ok(PingResult {
192                data: None,
193                _meta: None,
194            })
195        } else if let Some(error) = response.error() {
196            // Preserve client error code by wrapping as Protocol error
197            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
198                error.code,
199                &error.message,
200            )))
201        } else {
202            Err(ServerError::Handler {
203                message: "Invalid ping response".to_string(),
204                context: Some("MCP compliance".to_string()),
205            })
206        }
207    }
208
209    async fn send_create_message(
210        &self,
211        request: CreateMessageRequest,
212        _ctx: RequestContext,
213    ) -> ServerResult<CreateMessageResult> {
214        let json_rpc_request = JsonRpcRequest {
215            jsonrpc: JsonRpcVersion,
216            method: "sampling/createMessage".to_string(),
217            params: Some(
218                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
219                    message: format!("Failed to serialize sampling request: {}", e),
220                    context: Some("MCP compliance".to_string()),
221                })?,
222            ),
223            id: Self::generate_request_id(),
224        };
225
226        let response = self.send_request(json_rpc_request).await?;
227
228        if let Some(result) = response.result() {
229            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
230                message: format!("Invalid sampling response format: {}", e),
231                context: Some("MCP compliance".to_string()),
232            })
233        } else if let Some(error) = response.error() {
234            // Preserve client error code by wrapping as Protocol error
235            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
236                error.code,
237                &error.message,
238            )))
239        } else {
240            Err(ServerError::Handler {
241                message: "Invalid sampling response: missing result and error".to_string(),
242                context: Some("MCP compliance".to_string()),
243            })
244        }
245    }
246
247    async fn send_list_roots(
248        &self,
249        _request: ListRootsRequest,
250        _ctx: RequestContext,
251    ) -> ServerResult<ListRootsResult> {
252        let json_rpc_request = JsonRpcRequest {
253            jsonrpc: JsonRpcVersion,
254            method: "roots/list".to_string(),
255            params: None,
256            id: Self::generate_request_id(),
257        };
258
259        let response = self.send_request(json_rpc_request).await?;
260
261        if let Some(result) = response.result() {
262            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
263                message: format!("Invalid roots response format: {}", e),
264                context: Some("MCP compliance".to_string()),
265            })
266        } else if let Some(error) = response.error() {
267            // Preserve client error code by wrapping as Protocol error
268            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
269                error.code,
270                &error.message,
271            )))
272        } else {
273            Err(ServerError::Handler {
274                message: "Invalid roots response: missing result and error".to_string(),
275                context: Some("MCP compliance".to_string()),
276            })
277        }
278    }
279
280    fn supports_bidirectional(&self) -> bool {
281        true
282    }
283
284    async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
285        Ok(None)
286    }
287}
288
289/// Run MCP server over STDIO transport with full bidirectional support
290///
291/// This runtime implements the complete MCP 2025-06-18 stdio protocol:
292/// - Reads JSON-RPC from stdin (client requests AND server response correlations)
293/// - Writes JSON-RPC to stdout (server responses AND server requests)
294/// - Maintains request/response correlation
295/// - Handles errors per MCP spec
296/// - Manages task lifecycle with JoinSet for clean shutdown
297pub async fn run_stdio_bidirectional(
298    router: Arc<RequestRouter>,
299    dispatcher: StdioDispatcher,
300    mut request_rx: mpsc::UnboundedReceiver<StdioMessage>,
301) -> Result<(), Box<dyn std::error::Error>> {
302    let stdin = tokio::io::stdin();
303    let stdout = tokio::io::stdout();
304    let mut reader = BufReader::new(stdin);
305    let mut line = String::new();
306
307    let stdout = Arc::new(Mutex::new(stdout));
308    let pending_requests = Arc::clone(&dispatcher.pending_requests);
309
310    // ✅ Create JoinSet to manage all spawned tasks
311    let mut tasks = JoinSet::new();
312
313    // ✅ Spawn stdout writer task and store handle in JoinSet
314    let stdout_writer = Arc::clone(&stdout);
315    tasks.spawn(async move {
316        while let Some(msg) = request_rx.recv().await {
317            match msg {
318                StdioMessage::ServerRequest { request } => {
319                    if let Ok(json) = serde_json::to_string(&request) {
320                        let mut stdout = stdout_writer.lock().await;
321                        let _ = stdout.write_all(json.as_bytes()).await;
322                        let _ = stdout.write_all(b"\n").await;
323                        let _ = stdout.flush().await;
324                    }
325                }
326                StdioMessage::Shutdown => break,
327            }
328        }
329    });
330
331    // Main stdin reader loop
332    loop {
333        line.clear();
334        match reader.read_line(&mut line).await {
335            Ok(0) => break, // EOF
336            Ok(_) => {
337                if line.trim().is_empty() {
338                    continue;
339                }
340
341                // Try parsing as JSON-RPC response first (for server-initiated request responses)
342                if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
343                    let request_id = match &response.id {
344                        turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
345                            MessageId::String(s) => s.clone(),
346                            MessageId::Number(n) => n.to_string(),
347                            MessageId::Uuid(u) => u.to_string(),
348                        },
349                        _ => continue,
350                    };
351
352                    if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
353                        let _ = tx.send(response);
354                    }
355                    continue;
356                }
357
358                // Try parsing as JSON-RPC request (client-initiated)
359                if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&line) {
360                    let router = Arc::clone(&router);
361                    let stdout = Arc::clone(&stdout);
362
363                    // ✅ Spawn request handler and store handle in JoinSet
364                    tasks.spawn(async move {
365                        // Create properly configured context with server-to-client capabilities
366                        let ctx = router.create_context(None, None);
367                        let response = router.route(request, ctx).await;
368
369                        if let Ok(json) = serde_json::to_string(&response) {
370                            let mut stdout = stdout.lock().await;
371                            let _ = stdout.write_all(json.as_bytes()).await;
372                            let _ = stdout.write_all(b"\n").await;
373                            let _ = stdout.flush().await;
374                        }
375                    });
376                }
377            }
378            Err(_) => break,
379        }
380    }
381
382    // ✅ GRACEFUL SHUTDOWN: Wait for all tasks to complete
383    tracing::debug!(
384        "STDIO dispatcher shutting down, waiting for {} tasks",
385        tasks.len()
386    );
387
388    // Signal writer task to shutdown by dropping the channel
389    // The request_rx.recv() in the writer task will return None, causing it to exit
390    drop(dispatcher);
391
392    // Wait for all tasks to complete with timeout
393    let shutdown_timeout = std::time::Duration::from_secs(5);
394    let start = std::time::Instant::now();
395
396    while let Some(result) = tokio::time::timeout(
397        shutdown_timeout.saturating_sub(start.elapsed()),
398        tasks.join_next(),
399    )
400    .await
401    .ok()
402    .flatten()
403    {
404        match result {
405            Ok(()) => {
406                tracing::debug!("Task completed successfully during shutdown");
407            }
408            Err(e) if e.is_panic() => {
409                tracing::warn!("Task panicked during shutdown: {:?}", e);
410            }
411            Err(e) if e.is_cancelled() => {
412                tracing::debug!("Task was cancelled during shutdown");
413            }
414            Err(e) => {
415                tracing::debug!("Task error during shutdown: {:?}", e);
416            }
417        }
418    }
419
420    // ✅ Abort remaining tasks if timeout occurred
421    if !tasks.is_empty() {
422        tracing::warn!(
423            "Aborting {} tasks due to shutdown timeout ({}s)",
424            tasks.len(),
425            shutdown_timeout.as_secs()
426        );
427        tasks.shutdown().await;
428    }
429
430    tracing::debug!("STDIO dispatcher shutdown complete");
431    Ok(())
432}
433
434// ============================================================================
435// Generic Transport Dispatcher (TCP, Unix Socket, and other Transport impls)
436// ============================================================================
437
438/// Generic dispatcher for any Transport implementation
439///
440/// This provides bidirectional MCP support for any transport that implements
441/// the `Transport` trait. Unlike `StdioDispatcher` which directly reads/writes
442/// stdin/stdout, this uses the Transport trait's `send()` and `receive()` methods.
443///
444/// **Usage**:
445/// ```rust,ignore
446/// use turbomcp_transport::TcpTransport;
447/// use turbomcp_server::runtime::TransportDispatcher;
448///
449/// let addr = "127.0.0.1:8080".parse().unwrap();
450/// let transport = TcpTransport::new_server(addr);
451/// let dispatcher = TransportDispatcher::new(transport);
452/// ```
453pub struct TransportDispatcher<T>
454where
455    T: turbomcp_transport::Transport,
456{
457    /// The underlying transport
458    transport: Arc<T>,
459    /// Pending server-initiated requests awaiting responses
460    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
461}
462
463// Manual Clone implementation: Arc cloning doesn't require T: Clone
464impl<T> Clone for TransportDispatcher<T>
465where
466    T: turbomcp_transport::Transport,
467{
468    fn clone(&self) -> Self {
469        Self {
470            transport: Arc::clone(&self.transport),
471            pending_requests: Arc::clone(&self.pending_requests),
472        }
473    }
474}
475
476impl<T> std::fmt::Debug for TransportDispatcher<T>
477where
478    T: turbomcp_transport::Transport,
479{
480    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481        f.debug_struct("TransportDispatcher")
482            .field("transport_type", &self.transport.transport_type())
483            .field("pending_count", &"<mutex>")
484            .finish()
485    }
486}
487
488impl<T> TransportDispatcher<T>
489where
490    T: turbomcp_transport::Transport,
491{
492    /// Create a new transport dispatcher
493    pub fn new(transport: T) -> Self {
494        Self {
495            transport: Arc::new(transport),
496            pending_requests: Arc::new(Mutex::new(HashMap::new())),
497        }
498    }
499
500    /// Send a JSON-RPC request and wait for response
501    async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
502        use turbomcp_transport::{TransportMessage, core::TransportMessageMetadata};
503
504        let (response_tx, response_rx) = oneshot::channel();
505
506        // Extract request ID for correlation
507        let request_id = match &request.id {
508            MessageId::String(s) => s.clone(),
509            MessageId::Number(n) => n.to_string(),
510            MessageId::Uuid(u) => u.to_string(),
511        };
512
513        // Register pending request
514        self.pending_requests
515            .lock()
516            .await
517            .insert(request_id.clone(), response_tx);
518
519        // Serialize request to JSON
520        let json = serde_json::to_vec(&request).map_err(|e| ServerError::Handler {
521            message: format!("Failed to serialize request: {}", e),
522            context: Some("transport_dispatcher".to_string()),
523        })?;
524
525        // Send via transport
526        let transport_msg = TransportMessage::with_metadata(
527            MessageId::Uuid(uuid::Uuid::new_v4()),
528            bytes::Bytes::from(json),
529            TransportMessageMetadata::with_content_type("application/json"),
530        );
531
532        self.transport
533            .send(transport_msg)
534            .await
535            .map_err(|e| ServerError::Handler {
536                message: format!("Failed to send request via transport: {}", e),
537                context: Some("transport_dispatcher".to_string()),
538            })?;
539
540        // Wait for response with timeout (60 seconds per MCP recommendation)
541        match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
542            Ok(Ok(response)) => Ok(response),
543            Ok(Err(_)) => Err(ServerError::Handler {
544                message: "Response channel closed".to_string(),
545                context: Some("transport_dispatcher".to_string()),
546            }),
547            Err(_) => {
548                // Timeout - remove from pending
549                self.pending_requests.lock().await.remove(&request_id);
550                Err(ServerError::Handler {
551                    message: "Request timeout (60s)".to_string(),
552                    context: Some("transport_dispatcher".to_string()),
553                })
554            }
555        }
556    }
557
558    /// Generate a unique request ID
559    fn generate_request_id() -> MessageId {
560        MessageId::String(uuid::Uuid::new_v4().to_string())
561    }
562
563    /// Get access to pending requests for response correlation
564    pub fn pending_requests(
565        &self,
566    ) -> Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> {
567        Arc::clone(&self.pending_requests)
568    }
569
570    /// Get access to the transport
571    pub fn transport(&self) -> Arc<T> {
572        Arc::clone(&self.transport)
573    }
574}
575
576#[async_trait::async_trait]
577impl<T> ServerRequestDispatcher for TransportDispatcher<T>
578where
579    T: turbomcp_transport::Transport + Send + Sync + 'static,
580{
581    async fn send_elicitation(
582        &self,
583        request: ElicitRequest,
584        _ctx: RequestContext,
585    ) -> ServerResult<ElicitResult> {
586        let json_rpc_request = JsonRpcRequest {
587            jsonrpc: JsonRpcVersion,
588            method: "elicitation/create".to_string(),
589            params: Some(
590                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
591                    message: format!("Failed to serialize elicitation request: {}", e),
592                    context: Some("MCP compliance".to_string()),
593                })?,
594            ),
595            id: Self::generate_request_id(),
596        };
597
598        let response = self.send_request(json_rpc_request).await?;
599
600        if let Some(result) = response.result() {
601            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
602                message: format!("Invalid elicitation response format: {}", e),
603                context: Some("MCP compliance".to_string()),
604            })
605        } else if let Some(error) = response.error() {
606            // Preserve client error code by wrapping as Protocol error
607            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
608                error.code,
609                &error.message,
610            )))
611        } else {
612            Err(ServerError::Handler {
613                message: "Invalid elicitation response: missing result and error".to_string(),
614                context: Some("MCP compliance".to_string()),
615            })
616        }
617    }
618
619    async fn send_ping(
620        &self,
621        _request: PingRequest,
622        _ctx: RequestContext,
623    ) -> ServerResult<PingResult> {
624        let json_rpc_request = JsonRpcRequest {
625            jsonrpc: JsonRpcVersion,
626            method: "ping".to_string(),
627            params: None,
628            id: Self::generate_request_id(),
629        };
630
631        let response = self.send_request(json_rpc_request).await?;
632
633        if response.result().is_some() {
634            Ok(PingResult {
635                data: None,
636                _meta: None,
637            })
638        } else if let Some(error) = response.error() {
639            // Preserve client error code by wrapping as Protocol error
640            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
641                error.code,
642                &error.message,
643            )))
644        } else {
645            Err(ServerError::Handler {
646                message: "Invalid ping response".to_string(),
647                context: Some("MCP compliance".to_string()),
648            })
649        }
650    }
651
652    async fn send_create_message(
653        &self,
654        request: CreateMessageRequest,
655        _ctx: RequestContext,
656    ) -> ServerResult<CreateMessageResult> {
657        let json_rpc_request = JsonRpcRequest {
658            jsonrpc: JsonRpcVersion,
659            method: "sampling/createMessage".to_string(),
660            params: Some(
661                serde_json::to_value(&request).map_err(|e| ServerError::Handler {
662                    message: format!("Failed to serialize sampling request: {}", e),
663                    context: Some("MCP compliance".to_string()),
664                })?,
665            ),
666            id: Self::generate_request_id(),
667        };
668
669        let response = self.send_request(json_rpc_request).await?;
670
671        if let Some(result) = response.result() {
672            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
673                message: format!("Invalid sampling response format: {}", e),
674                context: Some("MCP compliance".to_string()),
675            })
676        } else if let Some(error) = response.error() {
677            // Preserve client error code by wrapping as Protocol error
678            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
679                error.code,
680                &error.message,
681            )))
682        } else {
683            Err(ServerError::Handler {
684                message: "Invalid sampling response: missing result and error".to_string(),
685                context: Some("MCP compliance".to_string()),
686            })
687        }
688    }
689
690    async fn send_list_roots(
691        &self,
692        _request: ListRootsRequest,
693        _ctx: RequestContext,
694    ) -> ServerResult<ListRootsResult> {
695        let json_rpc_request = JsonRpcRequest {
696            jsonrpc: JsonRpcVersion,
697            method: "roots/list".to_string(),
698            params: None,
699            id: Self::generate_request_id(),
700        };
701
702        let response = self.send_request(json_rpc_request).await?;
703
704        if let Some(result) = response.result() {
705            serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
706                message: format!("Invalid roots response format: {}", e),
707                context: Some("MCP compliance".to_string()),
708            })
709        } else if let Some(error) = response.error() {
710            // Preserve client error code by wrapping as Protocol error
711            Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
712                error.code,
713                &error.message,
714            )))
715        } else {
716            Err(ServerError::Handler {
717                message: "Invalid roots response: missing result and error".to_string(),
718                context: Some("MCP compliance".to_string()),
719            })
720        }
721    }
722
723    fn supports_bidirectional(&self) -> bool {
724        self.transport.capabilities().supports_bidirectional
725    }
726
727    async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
728        Ok(None)
729    }
730}
731
732/// Run MCP server with any Transport implementation with full bidirectional support
733///
734/// This is a generic runtime that works with TCP, Unix Socket, or any other
735/// transport implementing the `Transport` trait.
736///
737/// **Architecture**:
738/// - Spawns receiver task: reads from transport, routes through router
739/// - Transport send: used for both responses and server-initiated requests
740/// - Correlation: matches responses to pending requests
741///
742/// **Usage**:
743/// ```rust,ignore
744/// use std::sync::Arc;
745/// use turbomcp_transport::TcpTransport;
746/// use turbomcp_server::runtime::{TransportDispatcher, run_transport_bidirectional};
747/// use turbomcp_server::routing::RequestRouter;
748///
749/// let addr = "127.0.0.1:8080".parse().unwrap();
750/// let transport = TcpTransport::new_server(addr);
751/// let dispatcher = TransportDispatcher::new(transport);
752/// let router = Arc::new(RequestRouter::new());
753///
754/// // In async context:
755/// run_transport_bidirectional(router, dispatcher).await?;
756/// ```
757pub async fn run_transport_bidirectional<T>(
758    router: Arc<RequestRouter>,
759    dispatcher: TransportDispatcher<T>,
760) -> Result<(), Box<dyn std::error::Error>>
761where
762    T: turbomcp_transport::Transport + Send + Sync + 'static,
763{
764    let transport = dispatcher.transport();
765    let pending_requests = dispatcher.pending_requests();
766
767    // Main message processing loop
768    loop {
769        // Receive message from transport
770        match transport.receive().await {
771            Ok(Some(message)) => {
772                // Try parsing as JSON-RPC response first (for server-initiated request responses)
773                if let Ok(response) = serde_json::from_slice::<JsonRpcResponse>(&message.payload) {
774                    let request_id = match &response.id {
775                        turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
776                            MessageId::String(s) => s.clone(),
777                            MessageId::Number(n) => n.to_string(),
778                            MessageId::Uuid(u) => u.to_string(),
779                        },
780                        _ => continue,
781                    };
782
783                    if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
784                        let _ = tx.send(response);
785                    }
786                    continue;
787                }
788
789                // Try parsing as JSON-RPC request (client-initiated)
790                if let Ok(request) = serde_json::from_slice::<JsonRpcRequest>(&message.payload) {
791                    let router = Arc::clone(&router);
792                    let transport = Arc::clone(&transport);
793
794                    tokio::spawn(async move {
795                        // Create properly configured context with server-to-client capabilities
796                        let ctx = router.create_context(None, None);
797                        let response = router.route(request, ctx).await;
798
799                        // Send response back via transport
800                        if let Ok(json) = serde_json::to_vec(&response) {
801                            use turbomcp_transport::{
802                                TransportMessage, core::TransportMessageMetadata,
803                            };
804                            let transport_msg = TransportMessage::with_metadata(
805                                MessageId::Uuid(uuid::Uuid::new_v4()),
806                                bytes::Bytes::from(json),
807                                TransportMessageMetadata::with_content_type("application/json"),
808                            );
809                            let _ = transport.send(transport_msg).await;
810                        }
811                    });
812                }
813            }
814            Ok(None) => {
815                // No message available, sleep briefly
816                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
817            }
818            Err(e) => {
819                tracing::error!(error = %e, "Transport receive failed");
820                break;
821            }
822        }
823    }
824
825    Ok(())
826}
827
828// ============================================================================
829// Unit Tests
830// ============================================================================
831
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_stdio_dispatcher_clean_shutdown() {
        // Dropping a StdioDispatcher (which owns the only sender here) must not panic.
        let (tx, _rx) = mpsc::unbounded_channel();
        let dispatcher = StdioDispatcher::new(tx);

        drop(dispatcher);
    }

    #[tokio::test]
    async fn test_stdio_dispatcher_creation() {
        // A StdioDispatcher can be created and cloned (clones are needed for
        // concurrent usage), and the channel it was built from stays usable.
        let (tx, mut rx) = mpsc::unbounded_channel();
        let dispatcher = StdioDispatcher::new(tx.clone());
        let _dispatcher2 = dispatcher.clone();

        assert!(tx.send(StdioMessage::Shutdown).is_ok());

        // Check that the message actually reached the receiving side, not just
        // that `send` reported success. The whole queue is drained so the check
        // does not depend on whether construction enqueues anything itself.
        let mut saw_shutdown = false;
        while let Ok(msg) = rx.try_recv() {
            saw_shutdown |= matches!(msg, StdioMessage::Shutdown);
        }
        assert!(saw_shutdown, "Shutdown message was not delivered");
    }

    #[tokio::test]
    async fn test_joinset_task_tracking() {
        // JoinSet tracks every spawned task and is empty once all are joined.
        let mut tasks = JoinSet::new();

        for i in 0..5 {
            tasks.spawn(async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(i * 10)).await;
            });
        }

        assert_eq!(tasks.len(), 5);

        let mut completed = 0;
        while let Some(result) = tasks.join_next().await {
            assert!(result.is_ok());
            completed += 1;
        }

        assert_eq!(completed, 5);
        assert!(tasks.is_empty());
    }

    #[tokio::test]
    async fn test_joinset_with_timeout() {
        // Waiting on a slow task with a short timeout returns promptly with Err.
        let mut tasks = JoinSet::new();

        tasks.spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
        });

        let timeout = std::time::Duration::from_millis(100);
        let start = std::time::Instant::now();

        let result = tokio::time::timeout(timeout, tasks.join_next()).await;

        assert!(result.is_err());
        assert!(start.elapsed() < std::time::Duration::from_secs(1));

        // Abort the still-sleeping task so it does not outlive the test.
        tasks.shutdown().await;
    }

    #[tokio::test]
    async fn test_stdio_message_types() {
        // StdioMessage variants can be constructed and matched.
        // `JsonRpcRequest` and `JsonRpcVersion` come from the parent module via `super::*`.
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            method: "test".to_string(),
            params: None,
            id: MessageId::String("test-1".to_string()),
        };

        let msg = StdioMessage::ServerRequest { request };
        assert!(
            matches!(msg, StdioMessage::ServerRequest { .. }),
            "Expected ServerRequest"
        );

        let shutdown_msg = StdioMessage::Shutdown;
        assert!(
            matches!(shutdown_msg, StdioMessage::Shutdown),
            "Expected Shutdown"
        );
    }

    #[tokio::test]
    async fn test_pending_requests_cleanup() {
        // Pending-request bookkeeping: insert, remove once, and make sure a
        // removed-but-unanswered sender releases its waiting receiver.
        let pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
            Arc::new(Mutex::new(HashMap::new()));

        let (tx, rx) = oneshot::channel();
        pending_requests
            .lock()
            .await
            .insert("test-id".to_string(), tx);

        assert_eq!(pending_requests.lock().await.len(), 1);

        let removed = pending_requests.lock().await.remove("test-id");
        assert!(removed.is_some());
        assert!(pending_requests.lock().await.is_empty());

        // A second removal of the same id finds nothing. The message loop only
        // forwards a response when `remove` returns a sender, so a duplicate
        // response is ignored.
        assert!(pending_requests.lock().await.remove("test-id").is_none());

        // Dropping the sender without replying must wake the waiter with an error
        // instead of leaving it blocked forever.
        drop(removed);
        assert!(rx.await.is_err());
    }
}