1#[cfg(feature = "http")]
24pub mod http;
25
26#[cfg(feature = "websocket")]
28pub mod websocket;
29
30use std::collections::HashMap;
31use std::sync::Arc;
32
33use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
34use tokio::sync::{Mutex, mpsc, oneshot};
35
36use turbomcp_protocol::RequestContext;
37use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse, JsonRpcVersion};
38use turbomcp_protocol::types::{
39 CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsRequest,
40 ListRootsResult, PingRequest, PingResult,
41};
42
43use crate::routing::{RequestRouter, ServerRequestDispatcher};
44use crate::{ServerError, ServerResult};
45
/// Local shorthand for the protocol's JSON-RPC message id type.
type MessageId = turbomcp_protocol::MessageId;
47
/// Dispatches server-initiated JSON-RPC requests (elicitation, sampling,
/// roots listing, ping) to a client connected over stdio.
///
/// Outgoing requests are queued on `request_tx`, and client responses are
/// correlated back to the waiting caller through `pending_requests`. Clones
/// share both the queue and the pending-request map.
#[derive(Clone)]
pub struct StdioDispatcher {
    /// Sender half of the outgoing-message queue; the receiver is the one
    /// passed to `run_stdio_bidirectional`, whose writer task drains it.
    request_tx: mpsc::UnboundedSender<StdioMessage>,
    /// In-flight requests awaiting a response, keyed by the stringified
    /// request id.
    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
}
59
60impl std::fmt::Debug for StdioDispatcher {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("StdioDispatcher")
63 .field("has_request_tx", &true)
64 .field("pending_count", &"<mutex>")
65 .finish()
66 }
67}
68
/// Message consumed by the stdout writer task of `run_stdio_bidirectional`.
#[derive(Debug)]
pub enum StdioMessage {
    /// A server-initiated request to serialize and write to stdout as one
    /// JSON line.
    ServerRequest {
        /// The JSON-RPC request to send to the client.
        request: JsonRpcRequest,
    },
    /// Stops the writer task (its receive loop breaks on this message).
    Shutdown,
}
80
81impl StdioDispatcher {
82 pub fn new(request_tx: mpsc::UnboundedSender<StdioMessage>) -> Self {
84 Self {
85 request_tx,
86 pending_requests: Arc::new(Mutex::new(HashMap::new())),
87 }
88 }
89
90 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
92 let (response_tx, response_rx) = oneshot::channel();
93
94 let request_id = match &request.id {
96 MessageId::String(s) => s.clone(),
97 MessageId::Number(n) => n.to_string(),
98 MessageId::Uuid(u) => u.to_string(),
99 };
100
101 self.pending_requests
103 .lock()
104 .await
105 .insert(request_id.clone(), response_tx);
106
107 self.request_tx
109 .send(StdioMessage::ServerRequest { request })
110 .map_err(|e| ServerError::Handler {
111 message: format!("Failed to send request to stdout: {}", e),
112 context: Some("stdio_dispatcher".to_string()),
113 })?;
114
115 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
117 Ok(Ok(response)) => Ok(response),
118 Ok(Err(_)) => Err(ServerError::Handler {
119 message: "Response channel closed".to_string(),
120 context: Some("stdio_dispatcher".to_string()),
121 }),
122 Err(_) => {
123 self.pending_requests.lock().await.remove(&request_id);
125 Err(ServerError::Handler {
126 message: "Request timeout (60s)".to_string(),
127 context: Some("stdio_dispatcher".to_string()),
128 })
129 }
130 }
131 }
132
133 fn generate_request_id() -> MessageId {
135 MessageId::String(uuid::Uuid::new_v4().to_string())
136 }
137}
138
139#[async_trait::async_trait]
140impl ServerRequestDispatcher for StdioDispatcher {
141 async fn send_elicitation(
142 &self,
143 request: ElicitRequest,
144 _ctx: RequestContext,
145 ) -> ServerResult<ElicitResult> {
146 let json_rpc_request = JsonRpcRequest {
147 jsonrpc: JsonRpcVersion,
148 method: "elicitation/create".to_string(),
149 params: Some(
150 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
151 message: format!("Failed to serialize elicitation request: {}", e),
152 context: Some("MCP compliance".to_string()),
153 })?,
154 ),
155 id: Self::generate_request_id(),
156 };
157
158 let response = self.send_request(json_rpc_request).await?;
159
160 if let Some(result) = response.result() {
161 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
162 message: format!("Invalid elicitation response format: {}", e),
163 context: Some("MCP compliance".to_string()),
164 })
165 } else if let Some(error) = response.error() {
166 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
168 error.code,
169 &error.message,
170 )))
171 } else {
172 Err(ServerError::Handler {
173 message: "Invalid elicitation response: missing result and error".to_string(),
174 context: Some("MCP compliance".to_string()),
175 })
176 }
177 }
178
179 async fn send_ping(
180 &self,
181 _request: PingRequest,
182 _ctx: RequestContext,
183 ) -> ServerResult<PingResult> {
184 let json_rpc_request = JsonRpcRequest {
185 jsonrpc: JsonRpcVersion,
186 method: "ping".to_string(),
187 params: None,
188 id: Self::generate_request_id(),
189 };
190
191 let response = self.send_request(json_rpc_request).await?;
192
193 if response.result().is_some() {
194 Ok(PingResult {
195 data: None,
196 _meta: None,
197 })
198 } else if let Some(error) = response.error() {
199 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
201 error.code,
202 &error.message,
203 )))
204 } else {
205 Err(ServerError::Handler {
206 message: "Invalid ping response".to_string(),
207 context: Some("MCP compliance".to_string()),
208 })
209 }
210 }
211
212 async fn send_create_message(
213 &self,
214 request: CreateMessageRequest,
215 _ctx: RequestContext,
216 ) -> ServerResult<CreateMessageResult> {
217 let json_rpc_request = JsonRpcRequest {
218 jsonrpc: JsonRpcVersion,
219 method: "sampling/createMessage".to_string(),
220 params: Some(
221 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
222 message: format!("Failed to serialize sampling request: {}", e),
223 context: Some("MCP compliance".to_string()),
224 })?,
225 ),
226 id: Self::generate_request_id(),
227 };
228
229 let response = self.send_request(json_rpc_request).await?;
230
231 if let Some(result) = response.result() {
232 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
233 message: format!("Invalid sampling response format: {}", e),
234 context: Some("MCP compliance".to_string()),
235 })
236 } else if let Some(error) = response.error() {
237 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
239 error.code,
240 &error.message,
241 )))
242 } else {
243 Err(ServerError::Handler {
244 message: "Invalid sampling response: missing result and error".to_string(),
245 context: Some("MCP compliance".to_string()),
246 })
247 }
248 }
249
250 async fn send_list_roots(
251 &self,
252 _request: ListRootsRequest,
253 _ctx: RequestContext,
254 ) -> ServerResult<ListRootsResult> {
255 let json_rpc_request = JsonRpcRequest {
256 jsonrpc: JsonRpcVersion,
257 method: "roots/list".to_string(),
258 params: None,
259 id: Self::generate_request_id(),
260 };
261
262 let response = self.send_request(json_rpc_request).await?;
263
264 if let Some(result) = response.result() {
265 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
266 message: format!("Invalid roots response format: {}", e),
267 context: Some("MCP compliance".to_string()),
268 })
269 } else if let Some(error) = response.error() {
270 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
272 error.code,
273 &error.message,
274 )))
275 } else {
276 Err(ServerError::Handler {
277 message: "Invalid roots response: missing result and error".to_string(),
278 context: Some("MCP compliance".to_string()),
279 })
280 }
281 }
282
283 fn supports_bidirectional(&self) -> bool {
284 true
285 }
286
287 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
288 Ok(None)
289 }
290}
291
292pub async fn run_stdio_bidirectional(
300 router: Arc<RequestRouter>,
301 dispatcher: StdioDispatcher,
302 mut request_rx: mpsc::UnboundedReceiver<StdioMessage>,
303) -> Result<(), Box<dyn std::error::Error>> {
304 let stdin = tokio::io::stdin();
305 let stdout = tokio::io::stdout();
306 let mut reader = BufReader::new(stdin);
307 let mut line = String::new();
308
309 let stdout = Arc::new(Mutex::new(stdout));
310 let pending_requests = Arc::clone(&dispatcher.pending_requests);
311
312 let stdout_writer = Arc::clone(&stdout);
314 tokio::spawn(async move {
315 while let Some(msg) = request_rx.recv().await {
316 match msg {
317 StdioMessage::ServerRequest { request } => {
318 if let Ok(json) = serde_json::to_string(&request) {
319 let mut stdout = stdout_writer.lock().await;
320 let _ = stdout.write_all(json.as_bytes()).await;
321 let _ = stdout.write_all(b"\n").await;
322 let _ = stdout.flush().await;
323 }
324 }
325 StdioMessage::Shutdown => break,
326 }
327 }
328 });
329
330 loop {
332 line.clear();
333 match reader.read_line(&mut line).await {
334 Ok(0) => break, Ok(_) => {
336 if line.trim().is_empty() {
337 continue;
338 }
339
340 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
342 let request_id = match &response.id {
343 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
344 MessageId::String(s) => s.clone(),
345 MessageId::Number(n) => n.to_string(),
346 MessageId::Uuid(u) => u.to_string(),
347 },
348 _ => continue,
349 };
350
351 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
352 let _ = tx.send(response);
353 }
354 continue;
355 }
356
357 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&line) {
359 let router = Arc::clone(&router);
360 let stdout = Arc::clone(&stdout);
361
362 tokio::spawn(async move {
363 let ctx = router.create_context();
365 let response = router.route(request, ctx).await;
366
367 if let Ok(json) = serde_json::to_string(&response) {
368 let mut stdout = stdout.lock().await;
369 let _ = stdout.write_all(json.as_bytes()).await;
370 let _ = stdout.write_all(b"\n").await;
371 let _ = stdout.flush().await;
372 }
373 });
374 }
375 }
376 Err(_) => break,
377 }
378 }
379
380 Ok(())
381}
382
/// Dispatches server-initiated JSON-RPC requests (elicitation, sampling,
/// roots listing, ping) over a generic [`turbomcp_transport::Transport`].
///
/// Requests are serialized and sent through the shared transport. Responses
/// received by `run_transport_bidirectional` are matched to waiting callers
/// through `pending_requests`. Clones share both the transport and the map.
pub struct TransportDispatcher<T>
where
    T: turbomcp_transport::Transport,
{
    /// Shared transport, used for sending requests here and for receiving
    /// messages in `run_transport_bidirectional`.
    transport: Arc<T>,
    /// In-flight requests awaiting a response, keyed by the stringified
    /// request id.
    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
}
411
412impl<T> Clone for TransportDispatcher<T>
414where
415 T: turbomcp_transport::Transport,
416{
417 fn clone(&self) -> Self {
418 Self {
419 transport: Arc::clone(&self.transport),
420 pending_requests: Arc::clone(&self.pending_requests),
421 }
422 }
423}
424
425impl<T> std::fmt::Debug for TransportDispatcher<T>
426where
427 T: turbomcp_transport::Transport,
428{
429 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
430 f.debug_struct("TransportDispatcher")
431 .field("transport_type", &self.transport.transport_type())
432 .field("pending_count", &"<mutex>")
433 .finish()
434 }
435}
436
437impl<T> TransportDispatcher<T>
438where
439 T: turbomcp_transport::Transport,
440{
441 pub fn new(transport: T) -> Self {
443 Self {
444 transport: Arc::new(transport),
445 pending_requests: Arc::new(Mutex::new(HashMap::new())),
446 }
447 }
448
449 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
451 use turbomcp_transport::{TransportMessage, core::TransportMessageMetadata};
452
453 let (response_tx, response_rx) = oneshot::channel();
454
455 let request_id = match &request.id {
457 MessageId::String(s) => s.clone(),
458 MessageId::Number(n) => n.to_string(),
459 MessageId::Uuid(u) => u.to_string(),
460 };
461
462 self.pending_requests
464 .lock()
465 .await
466 .insert(request_id.clone(), response_tx);
467
468 let json = serde_json::to_vec(&request).map_err(|e| ServerError::Handler {
470 message: format!("Failed to serialize request: {}", e),
471 context: Some("transport_dispatcher".to_string()),
472 })?;
473
474 let transport_msg = TransportMessage::with_metadata(
476 MessageId::Uuid(uuid::Uuid::new_v4()),
477 bytes::Bytes::from(json),
478 TransportMessageMetadata::with_content_type("application/json"),
479 );
480
481 self.transport
482 .send(transport_msg)
483 .await
484 .map_err(|e| ServerError::Handler {
485 message: format!("Failed to send request via transport: {}", e),
486 context: Some("transport_dispatcher".to_string()),
487 })?;
488
489 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
491 Ok(Ok(response)) => Ok(response),
492 Ok(Err(_)) => Err(ServerError::Handler {
493 message: "Response channel closed".to_string(),
494 context: Some("transport_dispatcher".to_string()),
495 }),
496 Err(_) => {
497 self.pending_requests.lock().await.remove(&request_id);
499 Err(ServerError::Handler {
500 message: "Request timeout (60s)".to_string(),
501 context: Some("transport_dispatcher".to_string()),
502 })
503 }
504 }
505 }
506
507 fn generate_request_id() -> MessageId {
509 MessageId::String(uuid::Uuid::new_v4().to_string())
510 }
511
512 pub fn pending_requests(
514 &self,
515 ) -> Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> {
516 Arc::clone(&self.pending_requests)
517 }
518
519 pub fn transport(&self) -> Arc<T> {
521 Arc::clone(&self.transport)
522 }
523}
524
525#[async_trait::async_trait]
526impl<T> ServerRequestDispatcher for TransportDispatcher<T>
527where
528 T: turbomcp_transport::Transport + Send + Sync + 'static,
529{
530 async fn send_elicitation(
531 &self,
532 request: ElicitRequest,
533 _ctx: RequestContext,
534 ) -> ServerResult<ElicitResult> {
535 let json_rpc_request = JsonRpcRequest {
536 jsonrpc: JsonRpcVersion,
537 method: "elicitation/create".to_string(),
538 params: Some(
539 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
540 message: format!("Failed to serialize elicitation request: {}", e),
541 context: Some("MCP compliance".to_string()),
542 })?,
543 ),
544 id: Self::generate_request_id(),
545 };
546
547 let response = self.send_request(json_rpc_request).await?;
548
549 if let Some(result) = response.result() {
550 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
551 message: format!("Invalid elicitation response format: {}", e),
552 context: Some("MCP compliance".to_string()),
553 })
554 } else if let Some(error) = response.error() {
555 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
557 error.code,
558 &error.message,
559 )))
560 } else {
561 Err(ServerError::Handler {
562 message: "Invalid elicitation response: missing result and error".to_string(),
563 context: Some("MCP compliance".to_string()),
564 })
565 }
566 }
567
568 async fn send_ping(
569 &self,
570 _request: PingRequest,
571 _ctx: RequestContext,
572 ) -> ServerResult<PingResult> {
573 let json_rpc_request = JsonRpcRequest {
574 jsonrpc: JsonRpcVersion,
575 method: "ping".to_string(),
576 params: None,
577 id: Self::generate_request_id(),
578 };
579
580 let response = self.send_request(json_rpc_request).await?;
581
582 if response.result().is_some() {
583 Ok(PingResult {
584 data: None,
585 _meta: None,
586 })
587 } else if let Some(error) = response.error() {
588 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
590 error.code,
591 &error.message,
592 )))
593 } else {
594 Err(ServerError::Handler {
595 message: "Invalid ping response".to_string(),
596 context: Some("MCP compliance".to_string()),
597 })
598 }
599 }
600
601 async fn send_create_message(
602 &self,
603 request: CreateMessageRequest,
604 _ctx: RequestContext,
605 ) -> ServerResult<CreateMessageResult> {
606 let json_rpc_request = JsonRpcRequest {
607 jsonrpc: JsonRpcVersion,
608 method: "sampling/createMessage".to_string(),
609 params: Some(
610 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
611 message: format!("Failed to serialize sampling request: {}", e),
612 context: Some("MCP compliance".to_string()),
613 })?,
614 ),
615 id: Self::generate_request_id(),
616 };
617
618 let response = self.send_request(json_rpc_request).await?;
619
620 if let Some(result) = response.result() {
621 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
622 message: format!("Invalid sampling response format: {}", e),
623 context: Some("MCP compliance".to_string()),
624 })
625 } else if let Some(error) = response.error() {
626 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
628 error.code,
629 &error.message,
630 )))
631 } else {
632 Err(ServerError::Handler {
633 message: "Invalid sampling response: missing result and error".to_string(),
634 context: Some("MCP compliance".to_string()),
635 })
636 }
637 }
638
639 async fn send_list_roots(
640 &self,
641 _request: ListRootsRequest,
642 _ctx: RequestContext,
643 ) -> ServerResult<ListRootsResult> {
644 let json_rpc_request = JsonRpcRequest {
645 jsonrpc: JsonRpcVersion,
646 method: "roots/list".to_string(),
647 params: None,
648 id: Self::generate_request_id(),
649 };
650
651 let response = self.send_request(json_rpc_request).await?;
652
653 if let Some(result) = response.result() {
654 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
655 message: format!("Invalid roots response format: {}", e),
656 context: Some("MCP compliance".to_string()),
657 })
658 } else if let Some(error) = response.error() {
659 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
661 error.code,
662 &error.message,
663 )))
664 } else {
665 Err(ServerError::Handler {
666 message: "Invalid roots response: missing result and error".to_string(),
667 context: Some("MCP compliance".to_string()),
668 })
669 }
670 }
671
672 fn supports_bidirectional(&self) -> bool {
673 self.transport.capabilities().supports_bidirectional
674 }
675
676 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
677 Ok(None)
678 }
679}
680
681pub async fn run_transport_bidirectional<T>(
707 router: Arc<RequestRouter>,
708 dispatcher: TransportDispatcher<T>,
709) -> Result<(), Box<dyn std::error::Error>>
710where
711 T: turbomcp_transport::Transport + Send + Sync + 'static,
712{
713 let transport = dispatcher.transport();
714 let pending_requests = dispatcher.pending_requests();
715
716 loop {
718 match transport.receive().await {
720 Ok(Some(message)) => {
721 if let Ok(response) = serde_json::from_slice::<JsonRpcResponse>(&message.payload) {
723 let request_id = match &response.id {
724 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
725 MessageId::String(s) => s.clone(),
726 MessageId::Number(n) => n.to_string(),
727 MessageId::Uuid(u) => u.to_string(),
728 },
729 _ => continue,
730 };
731
732 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
733 let _ = tx.send(response);
734 }
735 continue;
736 }
737
738 if let Ok(request) = serde_json::from_slice::<JsonRpcRequest>(&message.payload) {
740 let router = Arc::clone(&router);
741 let transport = Arc::clone(&transport);
742
743 tokio::spawn(async move {
744 let ctx = router.create_context();
746 let response = router.route(request, ctx).await;
747
748 if let Ok(json) = serde_json::to_vec(&response) {
750 use turbomcp_transport::{
751 TransportMessage, core::TransportMessageMetadata,
752 };
753 let transport_msg = TransportMessage::with_metadata(
754 MessageId::Uuid(uuid::Uuid::new_v4()),
755 bytes::Bytes::from(json),
756 TransportMessageMetadata::with_content_type("application/json"),
757 );
758 let _ = transport.send(transport_msg).await;
759 }
760 });
761 }
762 }
763 Ok(None) => {
764 tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
766 }
767 Err(e) => {
768 tracing::error!(error = %e, "Transport receive failed");
769 break;
770 }
771 }
772 }
773
774 Ok(())
775}