1#[cfg(feature = "http")]
24pub mod http;
25
26#[cfg(feature = "websocket")]
28pub mod websocket;
29
30use std::collections::HashMap;
31use std::sync::Arc;
32
33use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
34use tokio::sync::{Mutex, mpsc, oneshot};
35use tokio::task::JoinSet;
36
37use turbomcp_protocol::RequestContext;
38use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse, JsonRpcVersion};
39use turbomcp_protocol::types::{
40 CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsRequest,
41 ListRootsResult, PingRequest, PingResult,
42};
43
44use crate::routing::{RequestRouter, ServerRequestDispatcher};
45use crate::{ServerError, ServerResult};
46
/// Local shorthand for the protocol's JSON-RPC message id type.
type MessageId = turbomcp_protocol::MessageId;
48
49#[derive(Clone)]
54pub struct StdioDispatcher {
55 request_tx: mpsc::UnboundedSender<StdioMessage>,
57 pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
59}
60
61impl std::fmt::Debug for StdioDispatcher {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 f.debug_struct("StdioDispatcher")
64 .field("has_request_tx", &true)
65 .field("pending_count", &"<mutex>")
66 .finish()
67 }
68}
69
70#[derive(Debug)]
72pub enum StdioMessage {
73 ServerRequest {
75 request: JsonRpcRequest,
77 },
78 Shutdown,
80}
81
82impl StdioDispatcher {
83 pub fn new(request_tx: mpsc::UnboundedSender<StdioMessage>) -> Self {
85 Self {
86 request_tx,
87 pending_requests: Arc::new(Mutex::new(HashMap::new())),
88 }
89 }
90
91 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
93 let (response_tx, response_rx) = oneshot::channel();
94
95 let request_id = match &request.id {
97 MessageId::String(s) => s.clone(),
98 MessageId::Number(n) => n.to_string(),
99 MessageId::Uuid(u) => u.to_string(),
100 };
101
102 self.pending_requests
104 .lock()
105 .await
106 .insert(request_id.clone(), response_tx);
107
108 self.request_tx
110 .send(StdioMessage::ServerRequest { request })
111 .map_err(|e| ServerError::Handler {
112 message: format!("Failed to send request to stdout: {}", e),
113 context: Some("stdio_dispatcher".to_string()),
114 })?;
115
116 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
118 Ok(Ok(response)) => Ok(response),
119 Ok(Err(_)) => Err(ServerError::Handler {
120 message: "Response channel closed".to_string(),
121 context: Some("stdio_dispatcher".to_string()),
122 }),
123 Err(_) => {
124 self.pending_requests.lock().await.remove(&request_id);
126 Err(ServerError::Handler {
127 message: "Request timeout (60s)".to_string(),
128 context: Some("stdio_dispatcher".to_string()),
129 })
130 }
131 }
132 }
133
134 fn generate_request_id() -> MessageId {
136 MessageId::String(uuid::Uuid::new_v4().to_string())
137 }
138}
139
140#[async_trait::async_trait]
141impl ServerRequestDispatcher for StdioDispatcher {
142 async fn send_elicitation(
143 &self,
144 request: ElicitRequest,
145 _ctx: RequestContext,
146 ) -> ServerResult<ElicitResult> {
147 let json_rpc_request = JsonRpcRequest {
148 jsonrpc: JsonRpcVersion,
149 method: "elicitation/create".to_string(),
150 params: Some(
151 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
152 message: format!("Failed to serialize elicitation request: {}", e),
153 context: Some("MCP compliance".to_string()),
154 })?,
155 ),
156 id: Self::generate_request_id(),
157 };
158
159 let response = self.send_request(json_rpc_request).await?;
160
161 if let Some(result) = response.result() {
162 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
163 message: format!("Invalid elicitation response format: {}", e),
164 context: Some("MCP compliance".to_string()),
165 })
166 } else if let Some(error) = response.error() {
167 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
169 error.code,
170 &error.message,
171 )))
172 } else {
173 Err(ServerError::Handler {
174 message: "Invalid elicitation response: missing result and error".to_string(),
175 context: Some("MCP compliance".to_string()),
176 })
177 }
178 }
179
180 async fn send_ping(
181 &self,
182 _request: PingRequest,
183 _ctx: RequestContext,
184 ) -> ServerResult<PingResult> {
185 let json_rpc_request = JsonRpcRequest {
186 jsonrpc: JsonRpcVersion,
187 method: "ping".to_string(),
188 params: None,
189 id: Self::generate_request_id(),
190 };
191
192 let response = self.send_request(json_rpc_request).await?;
193
194 if response.result().is_some() {
195 Ok(PingResult {
196 data: None,
197 _meta: None,
198 })
199 } else if let Some(error) = response.error() {
200 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
202 error.code,
203 &error.message,
204 )))
205 } else {
206 Err(ServerError::Handler {
207 message: "Invalid ping response".to_string(),
208 context: Some("MCP compliance".to_string()),
209 })
210 }
211 }
212
213 async fn send_create_message(
214 &self,
215 request: CreateMessageRequest,
216 _ctx: RequestContext,
217 ) -> ServerResult<CreateMessageResult> {
218 let json_rpc_request = JsonRpcRequest {
219 jsonrpc: JsonRpcVersion,
220 method: "sampling/createMessage".to_string(),
221 params: Some(
222 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
223 message: format!("Failed to serialize sampling request: {}", e),
224 context: Some("MCP compliance".to_string()),
225 })?,
226 ),
227 id: Self::generate_request_id(),
228 };
229
230 let response = self.send_request(json_rpc_request).await?;
231
232 if let Some(result) = response.result() {
233 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
234 message: format!("Invalid sampling response format: {}", e),
235 context: Some("MCP compliance".to_string()),
236 })
237 } else if let Some(error) = response.error() {
238 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
240 error.code,
241 &error.message,
242 )))
243 } else {
244 Err(ServerError::Handler {
245 message: "Invalid sampling response: missing result and error".to_string(),
246 context: Some("MCP compliance".to_string()),
247 })
248 }
249 }
250
251 async fn send_list_roots(
252 &self,
253 _request: ListRootsRequest,
254 _ctx: RequestContext,
255 ) -> ServerResult<ListRootsResult> {
256 let json_rpc_request = JsonRpcRequest {
257 jsonrpc: JsonRpcVersion,
258 method: "roots/list".to_string(),
259 params: None,
260 id: Self::generate_request_id(),
261 };
262
263 let response = self.send_request(json_rpc_request).await?;
264
265 if let Some(result) = response.result() {
266 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
267 message: format!("Invalid roots response format: {}", e),
268 context: Some("MCP compliance".to_string()),
269 })
270 } else if let Some(error) = response.error() {
271 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
273 error.code,
274 &error.message,
275 )))
276 } else {
277 Err(ServerError::Handler {
278 message: "Invalid roots response: missing result and error".to_string(),
279 context: Some("MCP compliance".to_string()),
280 })
281 }
282 }
283
284 fn supports_bidirectional(&self) -> bool {
285 true
286 }
287
288 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
289 Ok(None)
290 }
291}
292
293pub async fn run_stdio_bidirectional(
302 router: Arc<RequestRouter>,
303 dispatcher: StdioDispatcher,
304 mut request_rx: mpsc::UnboundedReceiver<StdioMessage>,
305) -> Result<(), Box<dyn std::error::Error>> {
306 let stdin = tokio::io::stdin();
307 let stdout = tokio::io::stdout();
308 let mut reader = BufReader::new(stdin);
309 let mut line = String::new();
310
311 let stdout = Arc::new(Mutex::new(stdout));
312 let pending_requests = Arc::clone(&dispatcher.pending_requests);
313
314 let mut tasks = JoinSet::new();
316
317 let stdout_writer = Arc::clone(&stdout);
319 tasks.spawn(async move {
320 while let Some(msg) = request_rx.recv().await {
321 match msg {
322 StdioMessage::ServerRequest { request } => {
323 if let Ok(json) = serde_json::to_string(&request) {
324 let mut stdout = stdout_writer.lock().await;
325 let _ = stdout.write_all(json.as_bytes()).await;
326 let _ = stdout.write_all(b"\n").await;
327 let _ = stdout.flush().await;
328 }
329 }
330 StdioMessage::Shutdown => break,
331 }
332 }
333 });
334
335 loop {
337 line.clear();
338 match reader.read_line(&mut line).await {
339 Ok(0) => break, Ok(_) => {
341 if line.trim().is_empty() {
342 continue;
343 }
344
345 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
347 let request_id = match &response.id {
348 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
349 MessageId::String(s) => s.clone(),
350 MessageId::Number(n) => n.to_string(),
351 MessageId::Uuid(u) => u.to_string(),
352 },
353 _ => continue,
354 };
355
356 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
357 let _ = tx.send(response);
358 }
359 continue;
360 }
361
362 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&line) {
364 let router = Arc::clone(&router);
365 let stdout = Arc::clone(&stdout);
366
367 tasks.spawn(async move {
369 let ctx = router.create_context();
371 let response = router.route(request, ctx).await;
372
373 if let Ok(json) = serde_json::to_string(&response) {
374 let mut stdout = stdout.lock().await;
375 let _ = stdout.write_all(json.as_bytes()).await;
376 let _ = stdout.write_all(b"\n").await;
377 let _ = stdout.flush().await;
378 }
379 });
380 }
381 }
382 Err(_) => break,
383 }
384 }
385
386 tracing::debug!(
388 "STDIO dispatcher shutting down, waiting for {} tasks",
389 tasks.len()
390 );
391
392 drop(dispatcher);
395
396 let shutdown_timeout = std::time::Duration::from_secs(5);
398 let start = std::time::Instant::now();
399
400 while let Some(result) = tokio::time::timeout(
401 shutdown_timeout.saturating_sub(start.elapsed()),
402 tasks.join_next(),
403 )
404 .await
405 .ok()
406 .flatten()
407 {
408 match result {
409 Ok(()) => {
410 tracing::debug!("Task completed successfully during shutdown");
411 }
412 Err(e) if e.is_panic() => {
413 tracing::warn!("Task panicked during shutdown: {:?}", e);
414 }
415 Err(e) if e.is_cancelled() => {
416 tracing::debug!("Task was cancelled during shutdown");
417 }
418 Err(e) => {
419 tracing::debug!("Task error during shutdown: {:?}", e);
420 }
421 }
422 }
423
424 if !tasks.is_empty() {
426 tracing::warn!(
427 "Aborting {} tasks due to shutdown timeout ({}s)",
428 tasks.len(),
429 shutdown_timeout.as_secs()
430 );
431 tasks.shutdown().await;
432 }
433
434 tracing::debug!("STDIO dispatcher shutdown complete");
435 Ok(())
436}
437
438pub struct TransportDispatcher<T>
458where
459 T: turbomcp_transport::Transport,
460{
461 transport: Arc<T>,
463 pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
465}
466
467impl<T> Clone for TransportDispatcher<T>
469where
470 T: turbomcp_transport::Transport,
471{
472 fn clone(&self) -> Self {
473 Self {
474 transport: Arc::clone(&self.transport),
475 pending_requests: Arc::clone(&self.pending_requests),
476 }
477 }
478}
479
480impl<T> std::fmt::Debug for TransportDispatcher<T>
481where
482 T: turbomcp_transport::Transport,
483{
484 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
485 f.debug_struct("TransportDispatcher")
486 .field("transport_type", &self.transport.transport_type())
487 .field("pending_count", &"<mutex>")
488 .finish()
489 }
490}
491
492impl<T> TransportDispatcher<T>
493where
494 T: turbomcp_transport::Transport,
495{
496 pub fn new(transport: T) -> Self {
498 Self {
499 transport: Arc::new(transport),
500 pending_requests: Arc::new(Mutex::new(HashMap::new())),
501 }
502 }
503
504 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
506 use turbomcp_transport::{TransportMessage, core::TransportMessageMetadata};
507
508 let (response_tx, response_rx) = oneshot::channel();
509
510 let request_id = match &request.id {
512 MessageId::String(s) => s.clone(),
513 MessageId::Number(n) => n.to_string(),
514 MessageId::Uuid(u) => u.to_string(),
515 };
516
517 self.pending_requests
519 .lock()
520 .await
521 .insert(request_id.clone(), response_tx);
522
523 let json = serde_json::to_vec(&request).map_err(|e| ServerError::Handler {
525 message: format!("Failed to serialize request: {}", e),
526 context: Some("transport_dispatcher".to_string()),
527 })?;
528
529 let transport_msg = TransportMessage::with_metadata(
531 MessageId::Uuid(uuid::Uuid::new_v4()),
532 bytes::Bytes::from(json),
533 TransportMessageMetadata::with_content_type("application/json"),
534 );
535
536 self.transport
537 .send(transport_msg)
538 .await
539 .map_err(|e| ServerError::Handler {
540 message: format!("Failed to send request via transport: {}", e),
541 context: Some("transport_dispatcher".to_string()),
542 })?;
543
544 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
546 Ok(Ok(response)) => Ok(response),
547 Ok(Err(_)) => Err(ServerError::Handler {
548 message: "Response channel closed".to_string(),
549 context: Some("transport_dispatcher".to_string()),
550 }),
551 Err(_) => {
552 self.pending_requests.lock().await.remove(&request_id);
554 Err(ServerError::Handler {
555 message: "Request timeout (60s)".to_string(),
556 context: Some("transport_dispatcher".to_string()),
557 })
558 }
559 }
560 }
561
562 fn generate_request_id() -> MessageId {
564 MessageId::String(uuid::Uuid::new_v4().to_string())
565 }
566
567 pub fn pending_requests(
569 &self,
570 ) -> Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> {
571 Arc::clone(&self.pending_requests)
572 }
573
574 pub fn transport(&self) -> Arc<T> {
576 Arc::clone(&self.transport)
577 }
578}
579
580#[async_trait::async_trait]
581impl<T> ServerRequestDispatcher for TransportDispatcher<T>
582where
583 T: turbomcp_transport::Transport + Send + Sync + 'static,
584{
585 async fn send_elicitation(
586 &self,
587 request: ElicitRequest,
588 _ctx: RequestContext,
589 ) -> ServerResult<ElicitResult> {
590 let json_rpc_request = JsonRpcRequest {
591 jsonrpc: JsonRpcVersion,
592 method: "elicitation/create".to_string(),
593 params: Some(
594 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
595 message: format!("Failed to serialize elicitation request: {}", e),
596 context: Some("MCP compliance".to_string()),
597 })?,
598 ),
599 id: Self::generate_request_id(),
600 };
601
602 let response = self.send_request(json_rpc_request).await?;
603
604 if let Some(result) = response.result() {
605 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
606 message: format!("Invalid elicitation response format: {}", e),
607 context: Some("MCP compliance".to_string()),
608 })
609 } else if let Some(error) = response.error() {
610 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
612 error.code,
613 &error.message,
614 )))
615 } else {
616 Err(ServerError::Handler {
617 message: "Invalid elicitation response: missing result and error".to_string(),
618 context: Some("MCP compliance".to_string()),
619 })
620 }
621 }
622
623 async fn send_ping(
624 &self,
625 _request: PingRequest,
626 _ctx: RequestContext,
627 ) -> ServerResult<PingResult> {
628 let json_rpc_request = JsonRpcRequest {
629 jsonrpc: JsonRpcVersion,
630 method: "ping".to_string(),
631 params: None,
632 id: Self::generate_request_id(),
633 };
634
635 let response = self.send_request(json_rpc_request).await?;
636
637 if response.result().is_some() {
638 Ok(PingResult {
639 data: None,
640 _meta: None,
641 })
642 } else if let Some(error) = response.error() {
643 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
645 error.code,
646 &error.message,
647 )))
648 } else {
649 Err(ServerError::Handler {
650 message: "Invalid ping response".to_string(),
651 context: Some("MCP compliance".to_string()),
652 })
653 }
654 }
655
656 async fn send_create_message(
657 &self,
658 request: CreateMessageRequest,
659 _ctx: RequestContext,
660 ) -> ServerResult<CreateMessageResult> {
661 let json_rpc_request = JsonRpcRequest {
662 jsonrpc: JsonRpcVersion,
663 method: "sampling/createMessage".to_string(),
664 params: Some(
665 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
666 message: format!("Failed to serialize sampling request: {}", e),
667 context: Some("MCP compliance".to_string()),
668 })?,
669 ),
670 id: Self::generate_request_id(),
671 };
672
673 let response = self.send_request(json_rpc_request).await?;
674
675 if let Some(result) = response.result() {
676 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
677 message: format!("Invalid sampling response format: {}", e),
678 context: Some("MCP compliance".to_string()),
679 })
680 } else if let Some(error) = response.error() {
681 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
683 error.code,
684 &error.message,
685 )))
686 } else {
687 Err(ServerError::Handler {
688 message: "Invalid sampling response: missing result and error".to_string(),
689 context: Some("MCP compliance".to_string()),
690 })
691 }
692 }
693
694 async fn send_list_roots(
695 &self,
696 _request: ListRootsRequest,
697 _ctx: RequestContext,
698 ) -> ServerResult<ListRootsResult> {
699 let json_rpc_request = JsonRpcRequest {
700 jsonrpc: JsonRpcVersion,
701 method: "roots/list".to_string(),
702 params: None,
703 id: Self::generate_request_id(),
704 };
705
706 let response = self.send_request(json_rpc_request).await?;
707
708 if let Some(result) = response.result() {
709 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
710 message: format!("Invalid roots response format: {}", e),
711 context: Some("MCP compliance".to_string()),
712 })
713 } else if let Some(error) = response.error() {
714 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
716 error.code,
717 &error.message,
718 )))
719 } else {
720 Err(ServerError::Handler {
721 message: "Invalid roots response: missing result and error".to_string(),
722 context: Some("MCP compliance".to_string()),
723 })
724 }
725 }
726
727 fn supports_bidirectional(&self) -> bool {
728 self.transport.capabilities().supports_bidirectional
729 }
730
731 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
732 Ok(None)
733 }
734}
735
736pub async fn run_transport_bidirectional<T>(
762 router: Arc<RequestRouter>,
763 dispatcher: TransportDispatcher<T>,
764) -> Result<(), Box<dyn std::error::Error>>
765where
766 T: turbomcp_transport::Transport + Send + Sync + 'static,
767{
768 let transport = dispatcher.transport();
769 let pending_requests = dispatcher.pending_requests();
770
771 loop {
773 match transport.receive().await {
775 Ok(Some(message)) => {
776 if let Ok(response) = serde_json::from_slice::<JsonRpcResponse>(&message.payload) {
778 let request_id = match &response.id {
779 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
780 MessageId::String(s) => s.clone(),
781 MessageId::Number(n) => n.to_string(),
782 MessageId::Uuid(u) => u.to_string(),
783 },
784 _ => continue,
785 };
786
787 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
788 let _ = tx.send(response);
789 }
790 continue;
791 }
792
793 if let Ok(request) = serde_json::from_slice::<JsonRpcRequest>(&message.payload) {
795 let router = Arc::clone(&router);
796 let transport = Arc::clone(&transport);
797
798 tokio::spawn(async move {
799 let ctx = router.create_context();
801 let response = router.route(request, ctx).await;
802
803 if let Ok(json) = serde_json::to_vec(&response) {
805 use turbomcp_transport::{
806 TransportMessage, core::TransportMessageMetadata,
807 };
808 let transport_msg = TransportMessage::with_metadata(
809 MessageId::Uuid(uuid::Uuid::new_v4()),
810 bytes::Bytes::from(json),
811 TransportMessageMetadata::with_content_type("application/json"),
812 );
813 let _ = transport.send(transport_msg).await;
814 }
815 });
816 }
817 }
818 Ok(None) => {
819 tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
821 }
822 Err(e) => {
823 tracing::error!(error = %e, "Transport receive failed");
824 break;
825 }
826 }
827 }
828
829 Ok(())
830}
831
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_stdio_dispatcher_clean_shutdown() {
        // Dropping a dispatcher must be free of side effects such as panics or hangs.
        let (sender, _receiver) = mpsc::unbounded_channel();
        drop(StdioDispatcher::new(sender));
    }

    #[tokio::test]
    async fn test_stdio_dispatcher_creation() {
        let (sender, _receiver) = mpsc::unbounded_channel();
        let original = StdioDispatcher::new(sender.clone());
        let _cloned = original.clone();

        // The receiver is still alive, so the channel must accept messages.
        assert!(sender.send(StdioMessage::Shutdown).is_ok());
    }

    #[tokio::test]
    async fn test_joinset_task_tracking() {
        let mut tasks = JoinSet::new();
        for delay_ms in (0..5).map(|i| i * 10) {
            tasks.spawn(async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
            });
        }
        assert_eq!(tasks.len(), 5);

        let mut completed = 0;
        while let Some(joined) = tasks.join_next().await {
            assert!(joined.is_ok());
            completed += 1;
        }

        assert_eq!(completed, 5);
        assert!(tasks.is_empty());
    }

    #[tokio::test]
    async fn test_joinset_with_timeout() {
        let mut tasks = JoinSet::new();
        tasks.spawn(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
        });

        let limit = std::time::Duration::from_millis(100);
        let started = std::time::Instant::now();
        let waited = tokio::time::timeout(limit, tasks.join_next()).await;

        // The long-running task cannot finish in time, so the wait must time out promptly.
        assert!(waited.is_err());
        assert!(started.elapsed() < std::time::Duration::from_secs(1));

        tasks.shutdown().await;
    }

    #[tokio::test]
    async fn test_stdio_message_types() {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            method: "test".to_string(),
            params: None,
            id: MessageId::String("test-1".to_string()),
        };

        let msg = StdioMessage::ServerRequest { request };
        assert!(
            matches!(msg, StdioMessage::ServerRequest { .. }),
            "Expected ServerRequest"
        );

        let shutdown_msg = StdioMessage::Shutdown;
        assert!(
            matches!(shutdown_msg, StdioMessage::Shutdown),
            "Expected Shutdown"
        );
    }

    #[tokio::test]
    async fn test_pending_requests_cleanup() {
        let pending: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
            Arc::new(Mutex::new(HashMap::new()));

        let (waiter_tx, _waiter_rx) = oneshot::channel();
        pending.lock().await.insert("test-id".to_string(), waiter_tx);
        assert_eq!(pending.lock().await.len(), 1);

        let removed = pending.lock().await.remove("test-id");
        assert!(removed.is_some());
        assert!(pending.lock().await.is_empty());
    }
}