1#[cfg(feature = "http")]
24pub mod http;
25
26use std::collections::HashMap;
27use std::sync::Arc;
28
29use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
30use tokio::sync::{Mutex, mpsc, oneshot};
31use tokio::task::JoinSet;
32
33use turbomcp_protocol::RequestContext;
34use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcResponse, JsonRpcVersion};
35use turbomcp_protocol::types::{
36 CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, ListRootsRequest,
37 ListRootsResult, PingRequest, PingResult,
38};
39
40use crate::routing::{RequestRouter, ServerRequestDispatcher};
41use crate::{ServerError, ServerResult};
42
/// Local shorthand for the protocol's message id type; its variants are stringified to key
/// the pending-request maps below.
type MessageId = turbomcp_protocol::MessageId;
44
/// Server-to-client request dispatcher for the STDIO transport.
///
/// Outgoing requests are queued on an unbounded channel whose receiver is consumed by the
/// stdout writer task in [`run_stdio_bidirectional`]; responses read from stdin are matched
/// back to their callers through `pending_requests`. Clones share the same channel sender and
/// the same pending-request map.
#[derive(Clone)]
pub struct StdioDispatcher {
    /// Queue of messages for the stdout writer task.
    request_tx: mpsc::UnboundedSender<StdioMessage>,
    /// In-flight server-initiated requests, keyed by stringified request id, each waiting for
    /// its response on a oneshot channel.
    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
}
56
57impl std::fmt::Debug for StdioDispatcher {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("StdioDispatcher")
60 .field("has_request_tx", &true)
61 .field("pending_count", &"<mutex>")
62 .finish()
63 }
64}
65
/// Messages consumed by the stdout writer task spawned in [`run_stdio_bidirectional`].
#[derive(Debug)]
pub enum StdioMessage {
    /// A server-initiated JSON-RPC request to serialize and write to stdout.
    ServerRequest {
        /// The request to send to the client.
        request: JsonRpcRequest,
    },
    /// Makes the writer task leave its receive loop.
    Shutdown,
}
77
78impl StdioDispatcher {
79 pub fn new(request_tx: mpsc::UnboundedSender<StdioMessage>) -> Self {
81 Self {
82 request_tx,
83 pending_requests: Arc::new(Mutex::new(HashMap::new())),
84 }
85 }
86
87 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
89 let (response_tx, response_rx) = oneshot::channel();
90
91 let request_id = match &request.id {
93 MessageId::String(s) => s.clone(),
94 MessageId::Number(n) => n.to_string(),
95 MessageId::Uuid(u) => u.to_string(),
96 };
97
98 self.pending_requests
100 .lock()
101 .await
102 .insert(request_id.clone(), response_tx);
103
104 self.request_tx
106 .send(StdioMessage::ServerRequest { request })
107 .map_err(|e| ServerError::Handler {
108 message: format!("Failed to send request to stdout: {}", e),
109 context: Some("stdio_dispatcher".to_string()),
110 })?;
111
112 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
114 Ok(Ok(response)) => Ok(response),
115 Ok(Err(_)) => Err(ServerError::Handler {
116 message: "Response channel closed".to_string(),
117 context: Some("stdio_dispatcher".to_string()),
118 }),
119 Err(_) => {
120 self.pending_requests.lock().await.remove(&request_id);
122 Err(ServerError::Handler {
123 message: "Request timeout (60s)".to_string(),
124 context: Some("stdio_dispatcher".to_string()),
125 })
126 }
127 }
128 }
129
130 fn generate_request_id() -> MessageId {
132 MessageId::String(uuid::Uuid::new_v4().to_string())
133 }
134}
135
136#[async_trait::async_trait]
137impl ServerRequestDispatcher for StdioDispatcher {
138 async fn send_elicitation(
139 &self,
140 request: ElicitRequest,
141 _ctx: RequestContext,
142 ) -> ServerResult<ElicitResult> {
143 let json_rpc_request = JsonRpcRequest {
144 jsonrpc: JsonRpcVersion,
145 method: "elicitation/create".to_string(),
146 params: Some(
147 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
148 message: format!("Failed to serialize elicitation request: {}", e),
149 context: Some("MCP compliance".to_string()),
150 })?,
151 ),
152 id: Self::generate_request_id(),
153 };
154
155 let response = self.send_request(json_rpc_request).await?;
156
157 if let Some(result) = response.result() {
158 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
159 message: format!("Invalid elicitation response format: {}", e),
160 context: Some("MCP compliance".to_string()),
161 })
162 } else if let Some(error) = response.error() {
163 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
165 error.code,
166 &error.message,
167 )))
168 } else {
169 Err(ServerError::Handler {
170 message: "Invalid elicitation response: missing result and error".to_string(),
171 context: Some("MCP compliance".to_string()),
172 })
173 }
174 }
175
176 async fn send_ping(
177 &self,
178 _request: PingRequest,
179 _ctx: RequestContext,
180 ) -> ServerResult<PingResult> {
181 let json_rpc_request = JsonRpcRequest {
182 jsonrpc: JsonRpcVersion,
183 method: "ping".to_string(),
184 params: None,
185 id: Self::generate_request_id(),
186 };
187
188 let response = self.send_request(json_rpc_request).await?;
189
190 if response.result().is_some() {
191 Ok(PingResult {
192 data: None,
193 _meta: None,
194 })
195 } else if let Some(error) = response.error() {
196 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
198 error.code,
199 &error.message,
200 )))
201 } else {
202 Err(ServerError::Handler {
203 message: "Invalid ping response".to_string(),
204 context: Some("MCP compliance".to_string()),
205 })
206 }
207 }
208
209 async fn send_create_message(
210 &self,
211 request: CreateMessageRequest,
212 _ctx: RequestContext,
213 ) -> ServerResult<CreateMessageResult> {
214 let json_rpc_request = JsonRpcRequest {
215 jsonrpc: JsonRpcVersion,
216 method: "sampling/createMessage".to_string(),
217 params: Some(
218 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
219 message: format!("Failed to serialize sampling request: {}", e),
220 context: Some("MCP compliance".to_string()),
221 })?,
222 ),
223 id: Self::generate_request_id(),
224 };
225
226 let response = self.send_request(json_rpc_request).await?;
227
228 if let Some(result) = response.result() {
229 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
230 message: format!("Invalid sampling response format: {}", e),
231 context: Some("MCP compliance".to_string()),
232 })
233 } else if let Some(error) = response.error() {
234 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
236 error.code,
237 &error.message,
238 )))
239 } else {
240 Err(ServerError::Handler {
241 message: "Invalid sampling response: missing result and error".to_string(),
242 context: Some("MCP compliance".to_string()),
243 })
244 }
245 }
246
247 async fn send_list_roots(
248 &self,
249 _request: ListRootsRequest,
250 _ctx: RequestContext,
251 ) -> ServerResult<ListRootsResult> {
252 let json_rpc_request = JsonRpcRequest {
253 jsonrpc: JsonRpcVersion,
254 method: "roots/list".to_string(),
255 params: None,
256 id: Self::generate_request_id(),
257 };
258
259 let response = self.send_request(json_rpc_request).await?;
260
261 if let Some(result) = response.result() {
262 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
263 message: format!("Invalid roots response format: {}", e),
264 context: Some("MCP compliance".to_string()),
265 })
266 } else if let Some(error) = response.error() {
267 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
269 error.code,
270 &error.message,
271 )))
272 } else {
273 Err(ServerError::Handler {
274 message: "Invalid roots response: missing result and error".to_string(),
275 context: Some("MCP compliance".to_string()),
276 })
277 }
278 }
279
280 fn supports_bidirectional(&self) -> bool {
281 true
282 }
283
284 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
285 Ok(None)
286 }
287}
288
289pub async fn run_stdio_bidirectional(
298 router: Arc<RequestRouter>,
299 dispatcher: StdioDispatcher,
300 mut request_rx: mpsc::UnboundedReceiver<StdioMessage>,
301) -> Result<(), Box<dyn std::error::Error>> {
302 let stdin = tokio::io::stdin();
303 let stdout = tokio::io::stdout();
304 let mut reader = BufReader::new(stdin);
305 let mut line = String::new();
306
307 let stdout = Arc::new(Mutex::new(stdout));
308 let pending_requests = Arc::clone(&dispatcher.pending_requests);
309
310 let mut tasks = JoinSet::new();
312
313 let stdout_writer = Arc::clone(&stdout);
315 tasks.spawn(async move {
316 while let Some(msg) = request_rx.recv().await {
317 match msg {
318 StdioMessage::ServerRequest { request } => {
319 if let Ok(json) = serde_json::to_string(&request) {
320 let mut stdout = stdout_writer.lock().await;
321 let _ = stdout.write_all(json.as_bytes()).await;
322 let _ = stdout.write_all(b"\n").await;
323 let _ = stdout.flush().await;
324 }
325 }
326 StdioMessage::Shutdown => break,
327 }
328 }
329 });
330
331 loop {
333 line.clear();
334 match reader.read_line(&mut line).await {
335 Ok(0) => break, Ok(_) => {
337 if line.trim().is_empty() {
338 continue;
339 }
340
341 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(&line) {
343 let request_id = match &response.id {
344 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
345 MessageId::String(s) => s.clone(),
346 MessageId::Number(n) => n.to_string(),
347 MessageId::Uuid(u) => u.to_string(),
348 },
349 _ => continue,
350 };
351
352 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
353 let _ = tx.send(response);
354 }
355 continue;
356 }
357
358 if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(&line) {
360 let router = Arc::clone(&router);
361 let stdout = Arc::clone(&stdout);
362
363 tasks.spawn(async move {
365 let ctx = router.create_context(None, None);
367 let response = router.route(request, ctx).await;
368
369 if let Ok(json) = serde_json::to_string(&response) {
370 let mut stdout = stdout.lock().await;
371 let _ = stdout.write_all(json.as_bytes()).await;
372 let _ = stdout.write_all(b"\n").await;
373 let _ = stdout.flush().await;
374 }
375 });
376 }
377 }
378 Err(_) => break,
379 }
380 }
381
382 tracing::debug!(
384 "STDIO dispatcher shutting down, waiting for {} tasks",
385 tasks.len()
386 );
387
388 drop(dispatcher);
391
392 let shutdown_timeout = std::time::Duration::from_secs(5);
394 let start = std::time::Instant::now();
395
396 while let Some(result) = tokio::time::timeout(
397 shutdown_timeout.saturating_sub(start.elapsed()),
398 tasks.join_next(),
399 )
400 .await
401 .ok()
402 .flatten()
403 {
404 match result {
405 Ok(()) => {
406 tracing::debug!("Task completed successfully during shutdown");
407 }
408 Err(e) if e.is_panic() => {
409 tracing::warn!("Task panicked during shutdown: {:?}", e);
410 }
411 Err(e) if e.is_cancelled() => {
412 tracing::debug!("Task was cancelled during shutdown");
413 }
414 Err(e) => {
415 tracing::debug!("Task error during shutdown: {:?}", e);
416 }
417 }
418 }
419
420 if !tasks.is_empty() {
422 tracing::warn!(
423 "Aborting {} tasks due to shutdown timeout ({}s)",
424 tasks.len(),
425 shutdown_timeout.as_secs()
426 );
427 tasks.shutdown().await;
428 }
429
430 tracing::debug!("STDIO dispatcher shutdown complete");
431 Ok(())
432}
433
/// Server-to-client request dispatcher that sends requests over a generic
/// [`turbomcp_transport::Transport`].
///
/// Incoming responses are expected to be matched against `pending_requests` by
/// [`run_transport_bidirectional`], which receives from the same shared transport.
pub struct TransportDispatcher<T>
where
    T: turbomcp_transport::Transport,
{
    /// Transport shared with the receive loop and with every clone of this dispatcher.
    transport: Arc<T>,
    /// In-flight server-initiated requests, keyed by stringified request id, each waiting for
    /// its response on a oneshot channel.
    pending_requests: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>>,
}
462
463impl<T> Clone for TransportDispatcher<T>
465where
466 T: turbomcp_transport::Transport,
467{
468 fn clone(&self) -> Self {
469 Self {
470 transport: Arc::clone(&self.transport),
471 pending_requests: Arc::clone(&self.pending_requests),
472 }
473 }
474}
475
476impl<T> std::fmt::Debug for TransportDispatcher<T>
477where
478 T: turbomcp_transport::Transport,
479{
480 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481 f.debug_struct("TransportDispatcher")
482 .field("transport_type", &self.transport.transport_type())
483 .field("pending_count", &"<mutex>")
484 .finish()
485 }
486}
487
488impl<T> TransportDispatcher<T>
489where
490 T: turbomcp_transport::Transport,
491{
492 pub fn new(transport: T) -> Self {
494 Self {
495 transport: Arc::new(transport),
496 pending_requests: Arc::new(Mutex::new(HashMap::new())),
497 }
498 }
499
500 async fn send_request(&self, request: JsonRpcRequest) -> ServerResult<JsonRpcResponse> {
502 use turbomcp_transport::{TransportMessage, core::TransportMessageMetadata};
503
504 let (response_tx, response_rx) = oneshot::channel();
505
506 let request_id = match &request.id {
508 MessageId::String(s) => s.clone(),
509 MessageId::Number(n) => n.to_string(),
510 MessageId::Uuid(u) => u.to_string(),
511 };
512
513 self.pending_requests
515 .lock()
516 .await
517 .insert(request_id.clone(), response_tx);
518
519 let json = serde_json::to_vec(&request).map_err(|e| ServerError::Handler {
521 message: format!("Failed to serialize request: {}", e),
522 context: Some("transport_dispatcher".to_string()),
523 })?;
524
525 let transport_msg = TransportMessage::with_metadata(
527 MessageId::Uuid(uuid::Uuid::new_v4()),
528 bytes::Bytes::from(json),
529 TransportMessageMetadata::with_content_type("application/json"),
530 );
531
532 self.transport
533 .send(transport_msg)
534 .await
535 .map_err(|e| ServerError::Handler {
536 message: format!("Failed to send request via transport: {}", e),
537 context: Some("transport_dispatcher".to_string()),
538 })?;
539
540 match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx).await {
542 Ok(Ok(response)) => Ok(response),
543 Ok(Err(_)) => Err(ServerError::Handler {
544 message: "Response channel closed".to_string(),
545 context: Some("transport_dispatcher".to_string()),
546 }),
547 Err(_) => {
548 self.pending_requests.lock().await.remove(&request_id);
550 Err(ServerError::Handler {
551 message: "Request timeout (60s)".to_string(),
552 context: Some("transport_dispatcher".to_string()),
553 })
554 }
555 }
556 }
557
558 fn generate_request_id() -> MessageId {
560 MessageId::String(uuid::Uuid::new_v4().to_string())
561 }
562
563 pub fn pending_requests(
565 &self,
566 ) -> Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> {
567 Arc::clone(&self.pending_requests)
568 }
569
570 pub fn transport(&self) -> Arc<T> {
572 Arc::clone(&self.transport)
573 }
574}
575
576#[async_trait::async_trait]
577impl<T> ServerRequestDispatcher for TransportDispatcher<T>
578where
579 T: turbomcp_transport::Transport + Send + Sync + 'static,
580{
581 async fn send_elicitation(
582 &self,
583 request: ElicitRequest,
584 _ctx: RequestContext,
585 ) -> ServerResult<ElicitResult> {
586 let json_rpc_request = JsonRpcRequest {
587 jsonrpc: JsonRpcVersion,
588 method: "elicitation/create".to_string(),
589 params: Some(
590 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
591 message: format!("Failed to serialize elicitation request: {}", e),
592 context: Some("MCP compliance".to_string()),
593 })?,
594 ),
595 id: Self::generate_request_id(),
596 };
597
598 let response = self.send_request(json_rpc_request).await?;
599
600 if let Some(result) = response.result() {
601 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
602 message: format!("Invalid elicitation response format: {}", e),
603 context: Some("MCP compliance".to_string()),
604 })
605 } else if let Some(error) = response.error() {
606 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
608 error.code,
609 &error.message,
610 )))
611 } else {
612 Err(ServerError::Handler {
613 message: "Invalid elicitation response: missing result and error".to_string(),
614 context: Some("MCP compliance".to_string()),
615 })
616 }
617 }
618
619 async fn send_ping(
620 &self,
621 _request: PingRequest,
622 _ctx: RequestContext,
623 ) -> ServerResult<PingResult> {
624 let json_rpc_request = JsonRpcRequest {
625 jsonrpc: JsonRpcVersion,
626 method: "ping".to_string(),
627 params: None,
628 id: Self::generate_request_id(),
629 };
630
631 let response = self.send_request(json_rpc_request).await?;
632
633 if response.result().is_some() {
634 Ok(PingResult {
635 data: None,
636 _meta: None,
637 })
638 } else if let Some(error) = response.error() {
639 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
641 error.code,
642 &error.message,
643 )))
644 } else {
645 Err(ServerError::Handler {
646 message: "Invalid ping response".to_string(),
647 context: Some("MCP compliance".to_string()),
648 })
649 }
650 }
651
652 async fn send_create_message(
653 &self,
654 request: CreateMessageRequest,
655 _ctx: RequestContext,
656 ) -> ServerResult<CreateMessageResult> {
657 let json_rpc_request = JsonRpcRequest {
658 jsonrpc: JsonRpcVersion,
659 method: "sampling/createMessage".to_string(),
660 params: Some(
661 serde_json::to_value(&request).map_err(|e| ServerError::Handler {
662 message: format!("Failed to serialize sampling request: {}", e),
663 context: Some("MCP compliance".to_string()),
664 })?,
665 ),
666 id: Self::generate_request_id(),
667 };
668
669 let response = self.send_request(json_rpc_request).await?;
670
671 if let Some(result) = response.result() {
672 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
673 message: format!("Invalid sampling response format: {}", e),
674 context: Some("MCP compliance".to_string()),
675 })
676 } else if let Some(error) = response.error() {
677 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
679 error.code,
680 &error.message,
681 )))
682 } else {
683 Err(ServerError::Handler {
684 message: "Invalid sampling response: missing result and error".to_string(),
685 context: Some("MCP compliance".to_string()),
686 })
687 }
688 }
689
690 async fn send_list_roots(
691 &self,
692 _request: ListRootsRequest,
693 _ctx: RequestContext,
694 ) -> ServerResult<ListRootsResult> {
695 let json_rpc_request = JsonRpcRequest {
696 jsonrpc: JsonRpcVersion,
697 method: "roots/list".to_string(),
698 params: None,
699 id: Self::generate_request_id(),
700 };
701
702 let response = self.send_request(json_rpc_request).await?;
703
704 if let Some(result) = response.result() {
705 serde_json::from_value(result.clone()).map_err(|e| ServerError::Handler {
706 message: format!("Invalid roots response format: {}", e),
707 context: Some("MCP compliance".to_string()),
708 })
709 } else if let Some(error) = response.error() {
710 Err(ServerError::Protocol(turbomcp_protocol::Error::rpc(
712 error.code,
713 &error.message,
714 )))
715 } else {
716 Err(ServerError::Handler {
717 message: "Invalid roots response: missing result and error".to_string(),
718 context: Some("MCP compliance".to_string()),
719 })
720 }
721 }
722
723 fn supports_bidirectional(&self) -> bool {
724 self.transport.capabilities().supports_bidirectional
725 }
726
727 async fn get_client_capabilities(&self) -> ServerResult<Option<serde_json::Value>> {
728 Ok(None)
729 }
730}
731
732pub async fn run_transport_bidirectional<T>(
758 router: Arc<RequestRouter>,
759 dispatcher: TransportDispatcher<T>,
760) -> Result<(), Box<dyn std::error::Error>>
761where
762 T: turbomcp_transport::Transport + Send + Sync + 'static,
763{
764 let transport = dispatcher.transport();
765 let pending_requests = dispatcher.pending_requests();
766
767 loop {
769 match transport.receive().await {
771 Ok(Some(message)) => {
772 if let Ok(response) = serde_json::from_slice::<JsonRpcResponse>(&message.payload) {
774 let request_id = match &response.id {
775 turbomcp_protocol::jsonrpc::ResponseId(Some(id)) => match id {
776 MessageId::String(s) => s.clone(),
777 MessageId::Number(n) => n.to_string(),
778 MessageId::Uuid(u) => u.to_string(),
779 },
780 _ => continue,
781 };
782
783 if let Some(tx) = pending_requests.lock().await.remove(&request_id) {
784 let _ = tx.send(response);
785 }
786 continue;
787 }
788
789 if let Ok(request) = serde_json::from_slice::<JsonRpcRequest>(&message.payload) {
791 let router = Arc::clone(&router);
792 let transport = Arc::clone(&transport);
793
794 tokio::spawn(async move {
795 let ctx = router.create_context(None, None);
797 let response = router.route(request, ctx).await;
798
799 if let Ok(json) = serde_json::to_vec(&response) {
801 use turbomcp_transport::{
802 TransportMessage, core::TransportMessageMetadata,
803 };
804 let transport_msg = TransportMessage::with_metadata(
805 MessageId::Uuid(uuid::Uuid::new_v4()),
806 bytes::Bytes::from(json),
807 TransportMessageMetadata::with_content_type("application/json"),
808 );
809 let _ = transport.send(transport_msg).await;
810 }
811 });
812 }
813 }
814 Ok(None) => {
815 tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
817 }
818 Err(e) => {
819 tracing::error!(error = %e, "Transport receive failed");
820 break;
821 }
822 }
823 }
824
825 Ok(())
826}
827
#[cfg(test)]
mod tests {
    use super::*;

    /// Dropping a dispatcher that never ran the I/O loop must not panic or hang.
    #[tokio::test]
    async fn test_stdio_dispatcher_clean_shutdown() {
        let (sender, _receiver) = mpsc::unbounded_channel();
        drop(StdioDispatcher::new(sender));
    }

    /// Cloning a dispatcher works, and the channel stays usable while the receiver is alive.
    #[tokio::test]
    async fn test_stdio_dispatcher_creation() {
        let (sender, _receiver) = mpsc::unbounded_channel();
        let first = StdioDispatcher::new(sender.clone());
        let _second = first.clone();

        let sent = sender.send(StdioMessage::Shutdown);
        assert!(sent.is_ok());
    }

    /// Every task spawned into a JoinSet is tracked and joined exactly once.
    #[tokio::test]
    async fn test_joinset_task_tracking() {
        let mut tasks = JoinSet::new();
        (0..5).for_each(|step| {
            tasks.spawn(async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(step * 10)).await;
            });
        });
        assert_eq!(tasks.len(), 5);

        let mut joined = 0;
        while let Some(outcome) = tasks.join_next().await {
            assert!(outcome.is_ok());
            joined += 1;
        }

        assert_eq!(joined, 5);
        assert!(tasks.is_empty());
    }

    /// Waiting on a long-running task with a short timeout gives up promptly.
    #[tokio::test]
    async fn test_joinset_with_timeout() {
        let mut tasks = JoinSet::new();
        tasks.spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;
        });

        let limit = std::time::Duration::from_millis(100);
        let started = std::time::Instant::now();
        let waited = tokio::time::timeout(limit, tasks.join_next()).await;

        assert!(waited.is_err());
        assert!(started.elapsed() < std::time::Duration::from_secs(1));

        tasks.shutdown().await;
    }

    /// Both `StdioMessage` variants can be constructed and matched.
    #[tokio::test]
    async fn test_stdio_message_types() {
        let request = JsonRpcRequest {
            jsonrpc: JsonRpcVersion,
            method: "test".to_string(),
            params: None,
            id: MessageId::String("test-1".to_string()),
        };

        assert!(matches!(
            StdioMessage::ServerRequest { request },
            StdioMessage::ServerRequest { .. }
        ));
        assert!(matches!(StdioMessage::Shutdown, StdioMessage::Shutdown));
    }

    /// Entries registered in a pending-request map can be removed again by id.
    #[tokio::test]
    async fn test_pending_requests_cleanup() {
        let pending: Arc<Mutex<HashMap<String, oneshot::Sender<JsonRpcResponse>>>> =
            Arc::new(Mutex::new(HashMap::new()));

        let (sender, _receiver) = oneshot::channel();
        pending.lock().await.insert("test-id".to_string(), sender);
        assert_eq!(pending.lock().await.len(), 1);

        let removed = pending.lock().await.remove("test-id");
        assert!(removed.is_some());
        assert!(pending.lock().await.is_empty());
    }
}