1use std::collections::HashMap;
21use std::sync::Arc;
22
23use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
24use tokio::sync::{Mutex, mpsc, oneshot};
25
26use turbomcp_protocol::RequestContext;
27use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse, JsonRpcVersion};
28use turbomcp_protocol::types::{
29 CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsRequest,
30 ListRootsResult, PingRequest, PingResult,
31};
32
33use crate::routing::{RequestRouter, ServerRequestDispatcher};
34use crate::{ServerError, ServerResult};
35
36type MessageId = turbomcp_protocol::MessageId;
37
38#[derive(Clone)]
43pub struct StdioDispatcher {
44 request_tx: mpsc::UnboundedSender<StdioMessage>,
46 pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
48}
49
50impl std::fmt::Debug for StdioDispatcher {
51 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52 f.debug_struct("StdioDispatcher")
53 .field("has_request_tx", &true)
54 .field("pending_count", &"<mutex>")
55 .finish()
56 }
57}
58
59#[derive(Debug)]
61pub enum StdioMessage {
62 ServerRequest {
64 request: JsonRpcRequest,
66 },
67 Shutdown,
69}
70
71impl StdioDispatcher {
72 pub fn new(request_tx: mpsc::UnboundedSender<StdioMessage>) -> Self {
74 Self {
75 request_tx,
76 pending_requests: Arc::new(Mutex::new(HashMap::new())),
77 }
78 }
79
80 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
82 let (response_tx, response_rx) = oneshot::channel();
83
84 let request_id = match &request.id {
86 MessageId::String(s) => s.clone(),
87 MessageId::Number(n) => n.to_string(),
88 MessageId::Uuid(u) => u.to_string(),
89 };
90
91 self.pending_requests
93 .lock()
94 .await
95 .insert(request_id.clone(), response_tx);
96
97 self.request_tx
99 .send(StdioMessage::ServerRequest { request })
100 .map_err(|e| ServerError::Handler {
101 message: format!("Failed to send request to stdout: {}", e),
102 context: Some("stdio_dispatcher".to_string()),
103 })?;
104
105 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
107 Ok(Ok(response)) => Ok(response),
108 Ok(Err(_)) => Err(ServerError::Handler {
109 message: "Response channel closed".to_string(),
110 context: Some("stdio_dispatcher".to_string()),
111 }),
112 Err(_) => {
113 self.pending_requests.lock().await.remove(&request_id);
115 Err(ServerError::Handler {
116 message: "Request timeout (60s)".to_string(),
117 context: Some("stdio_dispatcher".to_string()),
118 })
119 }
120 }
121 }
122
123 fn generate_request_id() -> MessageId {
125 MessageId::String(uuid::Uuid::new_v4().to_string())
126 }
127}
128
129#[async_trait::async_trait]
130impl ServerRequestDispatcher for StdioDispatcher {
131 async fn send_elicitation(
132 &self,
133 request: ElicitRequest,
134 _ctx: RequestContext,
135 ) -> ServerResult<ElicitResult> {
136 let json_rpc_request = JsonRpcRequest {
137 jsonrpc: JsonRpcVersion,
138 method: "elicitation/create".to_string(),
139 params: Some(
140 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
141 message: format!("Failed to serialize elicitation request: {}", e),
142 context: Some("MCP compliance".to_string()),
143 })?,
144 ),
145 id: Self::generate_request_id(),
146 };
147
148 let response = self.send_request(json_rpc_request).await?;
149
150 if let Some(result) = response.result() {
151 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
152 message: format!("Invalid elicitation response format: {}", e),
153 context: Some("MCP compliance".to_string()),
154 })
155 } else if let Some(error) = response.error() {
156 Err(ServerError::Handler {
157 message: format!("Elicitation error: {}", error.message),
158 context: Some(format!("MCP error code: {}", error.code)),
159 })
160 } else {
161 Err(ServerError::Handler {
162 message: "Invalid elicitation response: missing result and error".to_string(),
163 context: Some("MCP compliance".to_string()),
164 })
165 }
166 }
167
168 async fn send_ping(
169 &self,
170 _request: PingRequest,
171 _ctx: RequestContext,
172 ) -> ServerResult<PingResult> {
173 let json_rpc_request = JsonRpcRequest {
174 jsonrpc: JsonRpcVersion,
175 method: "ping".to_string(),
176 params: None,
177 id: Self::generate_request_id(),
178 };
179
180 let response = self.send_request(json_rpc_request).await?;
181
182 if response.result().is_some() {
183 Ok(PingResult {
184 data: None,
185 _meta: None,
186 })
187 } else if let Some(error) = response.error() {
188 Err(ServerError::Handler {
189 message: format!("Ping error: {}", error.message),
190 context: Some(format!("MCP error code: {}", error.code)),
191 })
192 } else {
193 Err(ServerError::Handler {
194 message: "Invalid ping response".to_string(),
195 context: Some("MCP compliance".to_string()),
196 })
197 }
198 }
199
200 async fn send_create_message(
201 &self,
202 request: CreateMessageRequest,
203 _ctx: RequestContext,
204 ) -> ServerResult<CreateMessageResult> {
205 let json_rpc_request = JsonRpcRequest {
206 jsonrpc: JsonRpcVersion,
207 method: "sampling/createMessage".to_string(),
208 params: Some(
209 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
210 message: format!("Failed to serialize sampling request: {}", e),
211 context: Some("MCP compliance".to_string()),
212 })?,
213 ),
214 id: Self::generate_request_id(),
215 };
216
217 let response = self.send_request(json_rpc_request).await?;
218
219 if let Some(result) = response.result() {
220 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
221 message: format!("Invalid sampling response format: {}", e),
222 context: Some("MCP compliance".to_string()),
223 })
224 } else if let Some(error) = response.error() {
225 Err(ServerError::Handler {
226 message: format!("Sampling error: {}", error.message),
227 context: Some(format!("MCP error code: {}", error.code)),
228 })
229 } else {
230 Err(ServerError::Handler {
231 message: "Invalid sampling response: missing result and error".to_string(),
232 context: Some("MCP compliance".to_string()),
233 })
234 }
235 }
236
237 async fn send_list_roots(
238 &self,
239 _request: ListRootsRequest,
240 _ctx: RequestContext,
241 ) -> ServerResult<ListRootsResult> {
242 let json_rpc_request = JsonRpcRequest {
243 jsonrpc: JsonRpcVersion,
244 method: "roots/list".to_string(),
245 params: None,
246 id: Self::generate_request_id(),
247 };
248
249 let response = self.send_request(json_rpc_request).await?;
250
251 if let Some(result) = response.result() {
252 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
253 message: format!("Invalid roots response format: {}", e),
254 context: Some("MCP compliance".to_string()),
255 })
256 } else if let Some(error) = response.error() {
257 Err(ServerError::Handler {
258 message: format!("Roots error: {}", error.message),
259 context: Some(format!("MCP error code: {}", error.code)),
260 })
261 } else {
262 Err(ServerError::Handler {
263 message: "Invalid roots response: missing result and error".to_string(),
264 context: Some("MCP compliance".to_string()),
265 })
266 }
267 }
268
269 fn supports_bidirectional(&self) -> bool {
270 true
271 }
272
273 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
274 Ok(None)
275 }
276}
277
278pub async fn run_stdio_bidirectional(
286 router: Arc<RequestRouter>,
287 dispatcher: StdioDispatcher,
288 mut request_rx: mpsc::UnboundedReceiver<StdioMessage>,
289) -> Result<(), Box<dyn std::error::Error>> {
290 let stdin = tokio::io::stdin();
291 let stdout = tokio::io::stdout();
292 let mut reader = BufReader::new(stdin);
293 let mut line = String::new();
294
295 let stdout = Arc::new(Mutex::new(stdout));
296 let pending_requests = Arc::clone(&dispatcher.pending_requests);
297
298 let stdout_writer = Arc::clone(&stdout);
300 tokio::spawn(async move {
301 while let Some(msg) = request_rx.recv().await {
302 match msg {
303 StdioMessage::ServerRequest { request } => {
304 if let Ok(json) = serde_json::to_string(&request) {
305 let mut stdout = stdout_writer.lock().await;
306 let _ = stdout.write_all(json.as_bytes()).await;
307 let _ = stdout.write_all(b"\n").await;
308 let _ = stdout.flush().await;
309 }
310 }
311 StdioMessage::Shutdown => break,
312 }
313 }
314 });
315
316 loop {
318 line.clear();
319 match reader.read_line(&mut line).await {
320 Ok(0) => break, Ok(_) => {
322 if line.trim().is_empty() {
323 continue;
324 }
325
326 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
328 let request_id = match &response.id {
329 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
330 MessageId::String(s) => s.clone(),
331 MessageId::Number(n) => n.to_string(),
332 MessageId::Uuid(u) => u.to_string(),
333 },
334 _ => continue,
335 };
336
337 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
338 let _ = tx.send(response);
339 }
340 continue;
341 }
342
343 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&line) {
345 let router = Arc::clone(&router);
346 let stdout = Arc::clone(&stdout);
347
348 tokio::spawn(async move {
349 let ctx = router.create_context();
351 let response = router.route(request, ctx).await;
352
353 if let Ok(json) = serde_json::to_string(&response) {
354 let mut stdout = stdout.lock().await;
355 let _ = stdout.write_all(json.as_bytes()).await;
356 let _ = stdout.write_all(b"\n").await;
357 let _ = stdout.flush().await;
358 }
359 });
360 }
361 }
362 Err(_) => break,
363 }
364 }
365
366 Ok(())
367}
368
369pub struct TransportDispatcher<T>
389where
390 T: turbomcp_transport::Transport,
391{
392 transport: Arc<T>,
394 pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
396}
397
398impl<T> Clone for TransportDispatcher<T>
400where
401 T: turbomcp_transport::Transport,
402{
403 fn clone(&self) -> Self {
404 Self {
405 transport: Arc::clone(&self.transport),
406 pending_requests: Arc::clone(&self.pending_requests),
407 }
408 }
409}
410
411impl<T> std::fmt::Debug for TransportDispatcher<T>
412where
413 T: turbomcp_transport::Transport,
414{
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 f.debug_struct("TransportDispatcher")
417 .field("transport_type", &self.transport.transport_type())
418 .field("pending_count", &"<mutex>")
419 .finish()
420 }
421}
422
423impl<T> TransportDispatcher<T>
424where
425 T: turbomcp_transport::Transport,
426{
427 pub fn new(transport: T) -> Self {
429 Self {
430 transport: Arc::new(transport),
431 pending_requests: Arc::new(Mutex::new(HashMap::new())),
432 }
433 }
434
435 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
437 use turbomcp_transport::{TransportMessage, core::TransportMessageMetadata};
438
439 let (response_tx, response_rx) = oneshot::channel();
440
441 let request_id = match &request.id {
443 MessageId::String(s) => s.clone(),
444 MessageId::Number(n) => n.to_string(),
445 MessageId::Uuid(u) => u.to_string(),
446 };
447
448 self.pending_requests
450 .lock()
451 .await
452 .insert(request_id.clone(), response_tx);
453
454 let json = serde_json::to_vec(&request).map_err(|e| ServerError::Handler {
456 message: format!("Failed to serialize request: {}", e),
457 context: Some("transport_dispatcher".to_string()),
458 })?;
459
460 let transport_msg = TransportMessage::with_metadata(
462 MessageId::Uuid(uuid::Uuid::new_v4()),
463 bytes::Bytes::from(json),
464 TransportMessageMetadata::with_content_type("application/json"),
465 );
466
467 self.transport
468 .send(transport_msg)
469 .await
470 .map_err(|e| ServerError::Handler {
471 message: format!("Failed to send request via transport: {}", e),
472 context: Some("transport_dispatcher".to_string()),
473 })?;
474
475 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
477 Ok(Ok(response)) => Ok(response),
478 Ok(Err(_)) => Err(ServerError::Handler {
479 message: "Response channel closed".to_string(),
480 context: Some("transport_dispatcher".to_string()),
481 }),
482 Err(_) => {
483 self.pending_requests.lock().await.remove(&request_id);
485 Err(ServerError::Handler {
486 message: "Request timeout (60s)".to_string(),
487 context: Some("transport_dispatcher".to_string()),
488 })
489 }
490 }
491 }
492
493 fn generate_request_id() -> MessageId {
495 MessageId::String(uuid::Uuid::new_v4().to_string())
496 }
497
498 pub fn pending_requests(
500 &self,
501 ) -> Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> {
502 Arc::clone(&self.pending_requests)
503 }
504
505 pub fn transport(&self) -> Arc<T> {
507 Arc::clone(&self.transport)
508 }
509}
510
511#[async_trait::async_trait]
512impl<T> ServerRequestDispatcher for TransportDispatcher<T>
513where
514 T: turbomcp_transport::Transport + Send + Sync + 'static,
515{
516 async fn send_elicitation(
517 &self,
518 request: ElicitRequest,
519 _ctx: RequestContext,
520 ) -> ServerResult<ElicitResult> {
521 let json_rpc_request = JsonRpcRequest {
522 jsonrpc: JsonRpcVersion,
523 method: "elicitation/create".to_string(),
524 params: Some(
525 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
526 message: format!("Failed to serialize elicitation request: {}", e),
527 context: Some("MCP compliance".to_string()),
528 })?,
529 ),
530 id: Self::generate_request_id(),
531 };
532
533 let response = self.send_request(json_rpc_request).await?;
534
535 if let Some(result) = response.result() {
536 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
537 message: format!("Invalid elicitation response format: {}", e),
538 context: Some("MCP compliance".to_string()),
539 })
540 } else if let Some(error) = response.error() {
541 Err(ServerError::Handler {
542 message: format!("Elicitation error: {}", error.message),
543 context: Some(format!("MCP error code: {}", error.code)),
544 })
545 } else {
546 Err(ServerError::Handler {
547 message: "Invalid elicitation response: missing result and error".to_string(),
548 context: Some("MCP compliance".to_string()),
549 })
550 }
551 }
552
553 async fn send_ping(
554 &self,
555 _request: PingRequest,
556 _ctx: RequestContext,
557 ) -> ServerResult<PingResult> {
558 let json_rpc_request = JsonRpcRequest {
559 jsonrpc: JsonRpcVersion,
560 method: "ping".to_string(),
561 params: None,
562 id: Self::generate_request_id(),
563 };
564
565 let response = self.send_request(json_rpc_request).await?;
566
567 if response.result().is_some() {
568 Ok(PingResult {
569 data: None,
570 _meta: None,
571 })
572 } else if let Some(error) = response.error() {
573 Err(ServerError::Handler {
574 message: format!("Ping error: {}", error.message),
575 context: Some(format!("MCP error code: {}", error.code)),
576 })
577 } else {
578 Err(ServerError::Handler {
579 message: "Invalid ping response".to_string(),
580 context: Some("MCP compliance".to_string()),
581 })
582 }
583 }
584
585 async fn send_create_message(
586 &self,
587 request: CreateMessageRequest,
588 _ctx: RequestContext,
589 ) -> ServerResult<CreateMessageResult> {
590 let json_rpc_request = JsonRpcRequest {
591 jsonrpc: JsonRpcVersion,
592 method: "sampling/createMessage".to_string(),
593 params: Some(
594 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
595 message: format!("Failed to serialize sampling request: {}", e),
596 context: Some("MCP compliance".to_string()),
597 })?,
598 ),
599 id: Self::generate_request_id(),
600 };
601
602 let response = self.send_request(json_rpc_request).await?;
603
604 if let Some(result) = response.result() {
605 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
606 message: format!("Invalid sampling response format: {}", e),
607 context: Some("MCP compliance".to_string()),
608 })
609 } else if let Some(error) = response.error() {
610 Err(ServerError::Handler {
611 message: format!("Sampling error: {}", error.message),
612 context: Some(format!("MCP error code: {}", error.code)),
613 })
614 } else {
615 Err(ServerError::Handler {
616 message: "Invalid sampling response: missing result and error".to_string(),
617 context: Some("MCP compliance".to_string()),
618 })
619 }
620 }
621
622 async fn send_list_roots(
623 &self,
624 _request: ListRootsRequest,
625 _ctx: RequestContext,
626 ) -> ServerResult<ListRootsResult> {
627 let json_rpc_request = JsonRpcRequest {
628 jsonrpc: JsonRpcVersion,
629 method: "roots/list".to_string(),
630 params: None,
631 id: Self::generate_request_id(),
632 };
633
634 let response = self.send_request(json_rpc_request).await?;
635
636 if let Some(result) = response.result() {
637 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
638 message: format!("Invalid roots response format: {}", e),
639 context: Some("MCP compliance".to_string()),
640 })
641 } else if let Some(error) = response.error() {
642 Err(ServerError::Handler {
643 message: format!("Roots error: {}", error.message),
644 context: Some(format!("MCP error code: {}", error.code)),
645 })
646 } else {
647 Err(ServerError::Handler {
648 message: "Invalid roots response: missing result and error".to_string(),
649 context: Some("MCP compliance".to_string()),
650 })
651 }
652 }
653
654 fn supports_bidirectional(&self) -> bool {
655 self.transport.capabilities().supports_bidirectional
656 }
657
658 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
659 Ok(None)
660 }
661}
662
663pub async fn run_transport_bidirectional<T>(
689 router: Arc<RequestRouter>,
690 dispatcher: TransportDispatcher<T>,
691) -> Result<(), Box<dyn std::error::Error>>
692where
693 T: turbomcp_transport::Transport + Send + Sync + 'static,
694{
695 let transport = dispatcher.transport();
696 let pending_requests = dispatcher.pending_requests();
697
698 loop {
700 match transport.receive().await {
702 Ok(Some(message)) => {
703 if let Ok(response) = serde_json::from_slice::<JsonRpcResponse>(&message.payload) {
705 let request_id = match &response.id {
706 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
707 MessageId::String(s) => s.clone(),
708 MessageId::Number(n) => n.to_string(),
709 MessageId::Uuid(u) => u.to_string(),
710 },
711 _ => continue,
712 };
713
714 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
715 let _ = tx.send(response);
716 }
717 continue;
718 }
719
720 if let Ok(request) = serde_json::from_slice::<JsonRpcRequest>(&message.payload) {
722 let router = Arc::clone(&router);
723 let transport = Arc::clone(&transport);
724
725 tokio::spawn(async move {
726 let ctx = router.create_context();
728 let response = router.route(request, ctx).await;
729
730 if let Ok(json) = serde_json::to_vec(&response) {
732 use turbomcp_transport::{
733 TransportMessage, core::TransportMessageMetadata,
734 };
735 let transport_msg = TransportMessage::with_metadata(
736 MessageId::Uuid(uuid::Uuid::new_v4()),
737 bytes::Bytes::from(json),
738 TransportMessageMetadata::with_content_type("application/json"),
739 );
740 let _ = transport.send(transport_msg).await;
741 }
742 });
743 }
744 }
745 Ok(None) => {
746 tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
748 }
749 Err(e) => {
750 tracing::error!(error = %e, "Transport receive failed");
751 break;
752 }
753 }
754 }
755
756 Ok(())
757}