1use std::sync::atomic::Ordering;
30use std::sync::{Arc, Mutex as StdMutex};
31use std::time::Duration;
32
33use async_trait::async_trait;
34use bytes::Bytes;
35use futures::StreamExt;
36use serde_json;
37use tokio::io::{BufReader, Stdin, Stdout};
38use tokio::sync::{Mutex as TokioMutex, mpsc};
39use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
40use tracing::{debug, error, trace, warn};
41use turbomcp_protocol::MessageId;
42use uuid::Uuid;
43
44use crate::core::{
45 AtomicMetrics, Transport, TransportCapabilities, TransportConfig, TransportError,
46 TransportEventEmitter, TransportFactory, TransportMessage, TransportMessageMetadata,
47 TransportMetrics, TransportResult, TransportState, TransportType,
48};
49
/// Line-framed reader over buffered process stdin (one decoded item per line).
type StdinReader = FramedRead<BufReader<Stdin>, LinesCodec>;
/// Line-framed writer over process stdout (one encoded item per line).
type StdoutWriter = FramedWrite<Stdout, LinesCodec>;
53
/// Transport that exchanges line-delimited JSON messages over the process's
/// standard input and standard output.
#[derive(Debug)]
pub struct StdioTransport {
    /// Current connection state, shared behind a std mutex.
    state: Arc<StdMutex<TransportState>>,

    /// Capabilities reported by `capabilities()`; set once at construction.
    capabilities: TransportCapabilities,

    /// Active configuration; replaced as a whole by `with_config` and `configure`.
    config: Arc<StdMutex<TransportConfig>>,

    /// Atomic counters for connections, messages and bytes.
    metrics: Arc<AtomicMetrics>,

    /// Emitter used to publish connection, message and error events.
    event_emitter: TransportEventEmitter,

    /// Slot for a framed stdin reader.
    /// NOTE(review): in the code visible here this slot is only ever set to
    /// `None`; the live reader is moved into the spawned reader task instead.
    stdin_reader: Arc<TokioMutex<Option<StdinReader>>>,

    /// Framed stdout writer; `Some` after stream setup, `None` after disconnect.
    stdout_writer: Arc<TokioMutex<Option<StdoutWriter>>>,

    /// Receiving half of the channel fed by the stdin reader task.
    receive_channel: Arc<TokioMutex<Option<mpsc::Receiver<TransportMessage>>>>,

    /// Handle of the spawned stdin reader task, kept so it can be aborted.
    _task_handle: Arc<TokioMutex<Option<tokio::task::JoinHandle<()>>>>,
}
93
94impl StdioTransport {
95 #[must_use]
97 pub fn new() -> Self {
98 let (event_emitter, _) = TransportEventEmitter::new();
99
100 Self {
101 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
102 capabilities: TransportCapabilities {
103 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
104 supports_compression: false,
105 supports_streaming: true,
106 supports_bidirectional: true,
107 supports_multiplexing: false,
108 compression_algorithms: Vec::new(),
109 custom: std::collections::HashMap::new(),
110 },
111 config: Arc::new(StdMutex::new(TransportConfig {
112 transport_type: TransportType::Stdio,
113 ..Default::default()
114 })),
115 metrics: Arc::new(AtomicMetrics::default()),
116 event_emitter,
117 stdin_reader: Arc::new(TokioMutex::new(None)),
118 stdout_writer: Arc::new(TokioMutex::new(None)),
119 receive_channel: Arc::new(TokioMutex::new(None)),
120 _task_handle: Arc::new(TokioMutex::new(None)),
121 }
122 }
123
124 #[must_use]
126 pub fn with_config(config: TransportConfig) -> Self {
127 let transport = Self::new();
128 *transport.config.lock().expect("config mutex poisoned") = config;
130 transport
131 }
132
133 #[must_use]
135 pub fn with_event_emitter(event_emitter: TransportEventEmitter) -> Self {
136 let (_, _) = TransportEventEmitter::new();
137
138 Self {
139 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
140 capabilities: TransportCapabilities {
141 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
142 supports_compression: false,
143 supports_streaming: true,
144 supports_bidirectional: true,
145 supports_multiplexing: false,
146 compression_algorithms: Vec::new(),
147 custom: std::collections::HashMap::new(),
148 },
149 config: Arc::new(StdMutex::new(TransportConfig {
150 transport_type: TransportType::Stdio,
151 ..Default::default()
152 })),
153 metrics: Arc::new(AtomicMetrics::default()),
154 event_emitter,
155 stdin_reader: Arc::new(TokioMutex::new(None)),
156 stdout_writer: Arc::new(TokioMutex::new(None)),
157 receive_channel: Arc::new(TokioMutex::new(None)),
158 _task_handle: Arc::new(TokioMutex::new(None)),
159 }
160 }
161
162 fn set_state(&self, new_state: TransportState) {
163 let mut state = self.state.lock().expect("state mutex poisoned");
165 if *state != new_state {
166 trace!("Stdio transport state: {:?} -> {:?}", *state, new_state);
167 *state = new_state.clone();
168
169 match new_state {
170 TransportState::Connected => {
171 self.event_emitter
172 .emit_connected(TransportType::Stdio, "stdio://".to_string());
173 }
174 TransportState::Disconnected => {
175 self.event_emitter.emit_disconnected(
176 TransportType::Stdio,
177 "stdio://".to_string(),
178 None,
179 );
180 }
181 TransportState::Failed { reason } => {
182 self.event_emitter.emit_disconnected(
183 TransportType::Stdio,
184 "stdio://".to_string(),
185 Some(reason),
186 );
187 }
188 _ => {}
189 }
190 }
191 }
192
/// Placeholder with an empty body; calling it currently has no effect.
// `dead_code` is allowed because nothing in the visible code calls this.
#[allow(dead_code)]
fn heartbeat(&self) {
}
199
200 async fn setup_stdio_streams(&self) -> TransportResult<()> {
201 let stdin = tokio::io::stdin();
203 let reader = BufReader::new(stdin);
204 let mut stdin_reader = FramedRead::new(reader, LinesCodec::new());
205
206 let stdout = tokio::io::stdout();
208 *self.stdout_writer.lock().await = Some(FramedWrite::new(stdout, LinesCodec::new()));
209
210 let (tx, rx) = mpsc::channel(1000);
212 *self.receive_channel.lock().await = Some(rx);
213
214 {
216 let sender = tx;
217 let event_emitter = self.event_emitter.clone();
218 let metrics = self.metrics.clone();
219
220 let task_handle = tokio::spawn(async move {
221 while let Some(result) = stdin_reader.next().await {
222 match result {
223 Ok(line) => {
224 trace!("Received line: {}", line);
225
226 match Self::parse_message(&line) {
227 Ok(message) => {
228 let size = message.size();
229
230 metrics.messages_received.fetch_add(1, Ordering::Relaxed);
232 metrics
233 .bytes_received
234 .fetch_add(size as u64, Ordering::Relaxed);
235
236 event_emitter.emit_message_received(message.id.clone(), size);
238
239 match sender.try_send(message) {
241 Ok(()) => {}
242 Err(mpsc::error::TrySendError::Full(_)) => {
243 warn!(
244 "STDIO message channel full, applying backpressure"
245 );
246 continue;
248 }
249 Err(mpsc::error::TrySendError::Closed(_)) => {
250 debug!("Receive channel closed, stopping reader task");
251 break;
252 }
253 }
254 }
255 Err(e) => {
256 error!("Failed to parse message: {}", e);
257 event_emitter
258 .emit_error(e, Some("message parsing".to_string()));
259 }
260 }
261 }
262 Err(e) => {
263 error!("Failed to read from stdin: {}", e);
264 event_emitter.emit_error(
265 TransportError::ReceiveFailed(e.to_string()),
266 Some("stdin read".to_string()),
267 );
268 break;
269 }
270 }
271 }
272
273 debug!("Stdio reader task completed");
274 });
275
276 *self._task_handle.lock().await = Some(task_handle);
277 }
278
279 Ok(())
280 }
281
282 fn parse_message(line: &str) -> TransportResult<TransportMessage> {
283 let line = line.trim();
284 if line.is_empty() {
285 return Err(TransportError::ProtocolError("Empty message".to_string()));
286 }
287
288 let json_value: serde_json::Value = serde_json::from_str(line)
290 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
291
292 let message_id = json_value
294 .get("id")
295 .and_then(|id| match id {
296 serde_json::Value::String(s) => Some(MessageId::from(s.clone())),
297 serde_json::Value::Number(n) => n.as_i64().map(MessageId::from),
298 _ => None,
299 })
300 .unwrap_or_else(|| MessageId::from(Uuid::new_v4()));
301
302 let payload = Bytes::from(line.to_string());
304 let metadata = TransportMessageMetadata::with_content_type("application/json");
305
306 Ok(TransportMessage::with_metadata(
307 message_id, payload, metadata,
308 ))
309 }
310
311 fn serialize_message(message: &TransportMessage) -> TransportResult<String> {
312 let json_str = std::str::from_utf8(&message.payload)
314 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
315
316 if json_str.contains('\n') || json_str.contains('\r') {
320 return Err(TransportError::ProtocolError(
321 "Message contains embedded newlines (forbidden by MCP stdio specification)"
322 .to_string(),
323 ));
324 }
325
326 let _: serde_json::Value = serde_json::from_str(json_str)
328 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
329
330 Ok(json_str.to_string())
331 }
332}
333
334#[async_trait]
335impl Transport for StdioTransport {
/// Identifies this transport as [`TransportType::Stdio`].
fn transport_type(&self) -> TransportType {
    TransportType::Stdio
}
339
/// Returns the capabilities that were fixed when the transport was built.
fn capabilities(&self) -> &TransportCapabilities {
    &self.capabilities
}
343
344 async fn state(&self) -> TransportState {
345 self.state.lock().expect("state mutex poisoned").clone()
347 }
348
349 async fn connect(&self) -> TransportResult<()> {
350 if matches!(self.state().await, TransportState::Connected) {
351 return Ok(());
352 }
353
354 self.set_state(TransportState::Connecting);
355
356 match self.setup_stdio_streams().await {
357 Ok(()) => {
358 self.metrics.connections.fetch_add(1, Ordering::Relaxed);
360 self.set_state(TransportState::Connected);
361 debug!("Stdio transport connected");
362 Ok(())
363 }
364 Err(e) => {
365 self.metrics
367 .failed_connections
368 .fetch_add(1, Ordering::Relaxed);
369 self.set_state(TransportState::Failed {
370 reason: e.to_string(),
371 });
372 error!("Failed to connect stdio transport: {}", e);
373 Err(e)
374 }
375 }
376 }
377
378 async fn disconnect(&self) -> TransportResult<()> {
379 if matches!(self.state().await, TransportState::Disconnected) {
380 return Ok(());
381 }
382
383 self.set_state(TransportState::Disconnecting);
384
385 *self.stdin_reader.lock().await = None;
387 *self.stdout_writer.lock().await = None;
388 *self.receive_channel.lock().await = None;
389
390 if let Some(handle) = self._task_handle.lock().await.take() {
392 handle.abort();
393 }
394
395 self.set_state(TransportState::Disconnected);
396 debug!("Stdio transport disconnected");
397 Ok(())
398 }
399
400 async fn send(&self, message: TransportMessage) -> TransportResult<()> {
401 let state = self.state().await;
402 if !matches!(state, TransportState::Connected) {
403 return Err(TransportError::ConnectionFailed(format!(
404 "Transport not connected: {state}"
405 )));
406 }
407
408 let json_line = Self::serialize_message(&message)?;
409 let size = json_line.len();
410
411 let mut stdout_writer = self.stdout_writer.lock().await;
412 if let Some(writer) = stdout_writer.as_mut() {
413 if let Err(e) = writer.send(json_line).await {
414 error!("Failed to send message: {}", e);
415 self.set_state(TransportState::Failed {
416 reason: e.to_string(),
417 });
418 return Err(TransportError::SendFailed(e.to_string()));
419 }
420
421 use futures::SinkExt;
423 if let Err(e) = SinkExt::<String>::flush(writer).await {
424 error!("Failed to flush stdout: {}", e);
425 return Err(TransportError::SendFailed(e.to_string()));
426 }
427
428 self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
430 self.metrics
431 .bytes_sent
432 .fetch_add(size as u64, Ordering::Relaxed);
433
434 self.event_emitter.emit_message_sent(message.id, size);
436
437 trace!("Sent message: {} bytes", size);
438 Ok(())
439 } else {
440 Err(TransportError::SendFailed(
441 "Stdout writer not available".to_string(),
442 ))
443 }
444 }
445
446 async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
447 let state = self.state().await;
448 if !matches!(state, TransportState::Connected) {
449 return Err(TransportError::ConnectionFailed(format!(
450 "Transport not connected: {state}"
451 )));
452 }
453
454 let mut receive_channel = self.receive_channel.lock().await;
455 if let Some(receiver) = receive_channel.as_mut() {
456 match receiver.recv().await {
457 Some(message) => {
458 trace!("Received message: {} bytes", message.size());
459 Ok(Some(message))
460 }
461 None => {
462 warn!("Receive channel disconnected");
463 self.set_state(TransportState::Failed {
464 reason: "Receive channel disconnected".to_string(),
465 });
466 Err(TransportError::ReceiveFailed(
467 "Channel disconnected".to_string(),
468 ))
469 }
470 }
471 } else {
472 Err(TransportError::ReceiveFailed(
473 "Receive channel not available".to_string(),
474 ))
475 }
476 }
477
/// Returns a snapshot of the transport's counters.
async fn metrics(&self) -> TransportMetrics {
    self.metrics.snapshot()
}
482
483 fn endpoint(&self) -> Option<String> {
484 Some("stdio://".to_string())
485 }
486
487 async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
488 if config.transport_type != TransportType::Stdio {
489 return Err(TransportError::ConfigurationError(format!(
490 "Invalid transport type: {:?}",
491 config.transport_type
492 )));
493 }
494
495 if config.connect_timeout < Duration::from_millis(100) {
497 return Err(TransportError::ConfigurationError(
498 "Connect timeout too small".to_string(),
499 ));
500 }
501
502 *self.config.lock().expect("config mutex poisoned") = config;
504 debug!("Stdio transport configured");
505 Ok(())
506 }
507}
508
/// Stateless factory that builds [`StdioTransport`] instances.
#[derive(Debug, Default)]
pub struct StdioTransportFactory;
512
impl StdioTransportFactory {
    /// Creates the factory; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
520
521impl TransportFactory for StdioTransportFactory {
522 fn transport_type(&self) -> TransportType {
523 TransportType::Stdio
524 }
525
526 fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>> {
527 if config.transport_type != TransportType::Stdio {
528 return Err(TransportError::ConfigurationError(format!(
529 "Invalid transport type: {:?}",
530 config.transport_type
531 )));
532 }
533
534 let transport = StdioTransport::with_config(config);
535 Ok(Box::new(transport))
536 }
537
538 fn is_available(&self) -> bool {
539 true
541 }
542}
543
impl Default for StdioTransport {
    /// Equivalent to [`StdioTransport::new`].
    fn default() -> Self {
        Self::new()
    }
}
549
550#[cfg(test)]
551mod tests {
552 use super::*;
553 use pretty_assertions::assert_eq;
/// A new transport reports the stdio type and advertises streaming and
/// bidirectional support.
#[test]
fn test_stdio_transport_creation() {
    let transport = StdioTransport::new();
    assert_eq!(transport.transport_type(), TransportType::Stdio);

    let caps = transport.capabilities();
    assert!(caps.supports_streaming);
    assert!(caps.supports_bidirectional);
}
564
/// `with_config` stores the supplied configuration on the transport.
#[test]
fn test_stdio_transport_with_config() {
    let transport = StdioTransport::with_config(TransportConfig {
        transport_type: TransportType::Stdio,
        connect_timeout: Duration::from_secs(10),
        ..Default::default()
    });

    let stored_timeout = transport
        .config
        .lock()
        .expect("config mutex poisoned")
        .connect_timeout;
    assert_eq!(stored_timeout, Duration::from_secs(10));
}
583
584 #[tokio::test]
585 async fn test_stdio_transport_state_management() {
586 let transport = StdioTransport::new();
587 assert_eq!(transport.state().await, TransportState::Disconnected);
588 }
589
/// A JSON-RPC line with a string ID parses into a message with that ID, a
/// JSON content type, and a non-empty payload.
#[test]
fn test_message_parsing() {
    let parsed = StdioTransport::parse_message(
        r#"{"jsonrpc":"2.0","id":"test-123","method":"test","params":{}}"#,
    )
    .unwrap();

    assert_eq!(parsed.id, MessageId::from("test-123"));
    assert_eq!(parsed.content_type(), Some("application/json"));
    assert!(!parsed.payload.is_empty());
}
599
/// A numeric JSON-RPC ID is carried over as a numeric message ID.
#[test]
fn test_message_parsing_with_numeric_id() {
    let parsed = StdioTransport::parse_message(
        r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{}}"#,
    )
    .unwrap();

    assert_eq!(parsed.id, MessageId::from(42));
}
607
/// A message without an `id` field (a notification) gets a UUID message ID.
#[test]
fn test_message_parsing_without_id() {
    let parsed = StdioTransport::parse_message(
        r#"{"jsonrpc":"2.0","method":"notification","params":{}}"#,
    )
    .unwrap();

    assert!(
        matches!(parsed.id, MessageId::Uuid(_)),
        "Expected UUID message ID"
    );
}
622
/// Text that is not JSON is rejected with `SerializationFailed`.
#[test]
fn test_message_parsing_invalid_json() {
    let outcome = StdioTransport::parse_message("not json at all");

    assert!(matches!(
        outcome,
        Err(TransportError::SerializationFailed(_))
    ));
}
633
/// Empty and whitespace-only lines are rejected with `ProtocolError`.
#[test]
fn test_message_parsing_empty() {
    let empty = StdioTransport::parse_message("");
    assert!(matches!(empty, Err(TransportError::ProtocolError(_))));

    let blank = StdioTransport::parse_message(" ");
    assert!(matches!(blank, Err(TransportError::ProtocolError(_))));
}
642
/// Serializing a valid single-line JSON payload returns it unchanged.
#[test]
fn test_message_serialization() {
    let expected = r#"{"jsonrpc":"2.0","id":"test","method":"ping"}"#;
    let message = TransportMessage::new(MessageId::from("test"), Bytes::from(expected));

    let line = StdioTransport::serialize_message(&message).unwrap();
    assert_eq!(line, expected);
}
652
/// A payload that is not valid UTF-8 is rejected with `SerializationFailed`.
#[test]
fn test_message_serialization_invalid_utf8() {
    let not_utf8 = vec![0xFF, 0xFE, 0xFD];
    let message = TransportMessage::new(MessageId::from("test"), Bytes::from(not_utf8));

    let outcome = StdioTransport::serialize_message(&message);
    assert!(matches!(
        outcome,
        Err(TransportError::SerializationFailed(_))
    ));
}
664
/// A UTF-8 payload that is not JSON is rejected with `SerializationFailed`.
#[test]
fn test_message_serialization_invalid_json() {
    let message = TransportMessage::new(MessageId::from("test"), Bytes::from("not json"));

    let outcome = StdioTransport::serialize_message(&message);
    assert!(matches!(
        outcome,
        Err(TransportError::SerializationFailed(_))
    ));
}
676
/// A payload containing a raw line feed is rejected with `ProtocolError`.
#[test]
fn test_message_serialization_embedded_newline_lf() {
    // The `\n` below is a real line-feed byte inside the JSON string value.
    let raw_lf_json =
        "{\"jsonrpc\":\"2.0\",\"id\":\"test\",\"method\":\"test\",\"params\":{\"text\":\"line1\nline2\"}}";
    let message = TransportMessage::new(MessageId::from("test"), Bytes::from(raw_lf_json));

    let result = StdioTransport::serialize_message(&message);
    assert!(
        matches!(result, Err(TransportError::ProtocolError(_))),
        "Expected ProtocolError for message with LF, got: {:?}",
        result
    );
}
692
/// A payload containing a raw CR+LF pair is rejected with `ProtocolError`.
#[test]
fn test_message_serialization_embedded_newline_crlf() {
    let message = TransportMessage::new(
        MessageId::from("test"),
        Bytes::from("{\r\n\"jsonrpc\":\"2.0\",\"id\":\"test\"}"),
    );

    let result = StdioTransport::serialize_message(&message);
    assert!(
        matches!(result, Err(TransportError::ProtocolError(_))),
        "Expected ProtocolError for message with CRLF, got: {:?}",
        result
    );
}
707
/// A payload containing a lone raw carriage return is rejected with
/// `ProtocolError`.
#[test]
fn test_message_serialization_embedded_cr() {
    let message = TransportMessage::new(
        MessageId::from("test"),
        Bytes::from("{\r\"jsonrpc\":\"2.0\",\"id\":\"test\"}"),
    );

    let result = StdioTransport::serialize_message(&message);
    assert!(
        matches!(result, Err(TransportError::ProtocolError(_))),
        "Expected ProtocolError for message with CR, got: {:?}",
        result
    );
}
722
/// A single-line JSON payload is accepted and returned unchanged.
#[test]
fn test_message_serialization_valid_no_newlines() {
    let single_line =
        r#"{"jsonrpc":"2.0","id":"test","method":"test","params":{"text":"single line"}}"#;
    let message = TransportMessage::new(MessageId::from("test"), Bytes::from(single_line));

    let outcome = StdioTransport::serialize_message(&message);
    assert!(
        outcome.is_ok(),
        "Valid message without newlines should be accepted"
    );
    assert_eq!(outcome.unwrap(), single_line);
}
738
/// JSON whose strings contain ESCAPED line breaks (backslash followed by
/// `n`) has no raw line-break bytes and must therefore be accepted.
#[test]
fn test_message_serialization_escaped_newlines_allowed() {
    let escaped_json = r#"{"jsonrpc":"2.0","id":"test","method":"log","params":{"message":"line1\nline2\ntab:\there"}}"#;

    // Guard against a broken fixture: the raw string must hold no real
    // line-break bytes, only the two-character escape sequences.
    for (raw_byte, label) in [('\n', "newline"), ('\r', "CR")] {
        assert!(
            !escaped_json.contains(raw_byte),
            "Test setup error: raw string should not contain literal {label} bytes"
        );
    }

    let message = TransportMessage::new(MessageId::from("test"), Bytes::from(escaped_json));

    let result = StdioTransport::serialize_message(&message);
    assert!(
        result.is_ok(),
        "JSON with ESCAPED newlines (backslash-n) should be ALLOWED per MCP spec. Got: {:?}",
        result
    );
    assert_eq!(result.unwrap(), escaped_json);
}
777
/// The factory reports the stdio type, is available, and builds a stdio
/// transport from a stdio configuration.
#[test]
fn test_stdio_factory() {
    let factory = StdioTransportFactory::new();
    assert_eq!(factory.transport_type(), TransportType::Stdio);
    assert!(factory.is_available());

    let built = factory
        .create(TransportConfig {
            transport_type: TransportType::Stdio,
            ..Default::default()
        })
        .unwrap();
    assert_eq!(built.transport_type(), TransportType::Stdio);
}
792
/// The factory refuses a configuration meant for another transport type.
#[test]
fn test_stdio_factory_invalid_config() {
    let http_config = TransportConfig {
        transport_type: TransportType::Http,
        ..Default::default()
    };

    let outcome = StdioTransportFactory::new().create(http_config);
    assert!(matches!(outcome, Err(TransportError::ConfigurationError(_))));
}
804
805 #[tokio::test]
806 async fn test_configuration_validation() {
807 let transport = StdioTransport::new();
808
809 let valid_config = TransportConfig {
811 transport_type: TransportType::Stdio,
812 connect_timeout: Duration::from_secs(5),
813 ..Default::default()
814 };
815
816 assert!(transport.configure(valid_config).await.is_ok());
817
818 let invalid_config = TransportConfig {
820 transport_type: TransportType::Http,
821 ..Default::default()
822 };
823
824 let result = transport.configure(invalid_config).await;
825 assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
826
827 let invalid_timeout_config = TransportConfig {
829 transport_type: TransportType::Stdio,
830 connect_timeout: Duration::from_millis(50), ..Default::default()
832 };
833
834 let result = transport.configure(invalid_timeout_config).await;
835 assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
836 }
837}