1use std::sync::atomic::Ordering;
30use std::sync::{Arc, Mutex as StdMutex};
31use std::time::Duration;
32
33use async_trait::async_trait;
34use bytes::Bytes;
35use futures::StreamExt;
36use serde_json;
37use tokio::io::{BufReader, Stdin, Stdout};
38use tokio::sync::{Mutex as TokioMutex, mpsc};
39use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
40use tracing::{debug, error, trace, warn};
41use turbomcp_protocol::MessageId;
42use uuid::Uuid;
43
44use crate::core::{
45 AtomicMetrics, Transport, TransportCapabilities, TransportConfig, TransportError,
46 TransportEventEmitter, TransportFactory, TransportMessage, TransportMessageMetadata,
47 TransportMetrics, TransportResult, TransportState, TransportType,
48};
49
/// Stdin wrapped in a `BufReader` and framed by `LinesCodec`.
type StdinReader = FramedRead<BufReader<Stdin>, LinesCodec>;
/// Stdout framed by `LinesCodec`.
type StdoutWriter = FramedWrite<Stdout, LinesCodec>;
53
/// Transport that exchanges newline-delimited JSON messages over the
/// process's standard input and standard output.
#[derive(Debug)]
pub struct StdioTransport {
    /// Current lifecycle state; changed through `set_state`.
    state: Arc<StdMutex<TransportState>>,

    /// Capability description handed out by `capabilities()`.
    capabilities: TransportCapabilities,

    /// Active configuration; its `limits` are consulted when validating
    /// message sizes.
    config: Arc<StdMutex<TransportConfig>>,

    /// Atomic counters (messages/bytes sent and received, connections,
    /// failed connections).
    metrics: Arc<AtomicMetrics>,

    /// Emitter used for connected / disconnected / error / message events.
    event_emitter: TransportEventEmitter,

    /// Slot for a framed stdin reader. In the code visible in this file it is
    /// only ever assigned `None`; the live reader is moved into the reader
    /// task instead.
    stdin_reader: Arc<TokioMutex<Option<StdinReader>>>,

    /// Framed stdout writer; `Some` after `setup_stdio_streams`, `None` after
    /// `disconnect`.
    stdout_writer: Arc<TokioMutex<Option<StdoutWriter>>>,

    /// Receiving half of the channel fed by the background reader task.
    receive_channel: Arc<TokioMutex<Option<mpsc::Receiver<TransportMessage>>>>,

    /// Handle of the background reader task, kept so `disconnect` can abort it.
    _task_handle: Arc<TokioMutex<Option<tokio::task::JoinHandle<()>>>>,
}
93
94impl StdioTransport {
95 #[must_use]
97 pub fn new() -> Self {
98 let (event_emitter, _) = TransportEventEmitter::new();
99
100 Self {
101 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
102 capabilities: TransportCapabilities {
103 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
104 supports_compression: false,
105 supports_streaming: true,
106 supports_bidirectional: true,
107 supports_multiplexing: false,
108 compression_algorithms: Vec::new(),
109 custom: std::collections::HashMap::new(),
110 },
111 config: Arc::new(StdMutex::new(TransportConfig {
112 transport_type: TransportType::Stdio,
113 ..Default::default()
114 })),
115 metrics: Arc::new(AtomicMetrics::default()),
116 event_emitter,
117 stdin_reader: Arc::new(TokioMutex::new(None)),
118 stdout_writer: Arc::new(TokioMutex::new(None)),
119 receive_channel: Arc::new(TokioMutex::new(None)),
120 _task_handle: Arc::new(TokioMutex::new(None)),
121 }
122 }
123
124 #[must_use]
126 pub fn with_config(config: TransportConfig) -> Self {
127 let transport = Self::new();
128 *transport.config.lock().expect("config mutex poisoned") = config;
130 transport
131 }
132
133 #[must_use]
135 pub fn with_event_emitter(event_emitter: TransportEventEmitter) -> Self {
136 let (_, _) = TransportEventEmitter::new();
137
138 Self {
139 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
140 capabilities: TransportCapabilities {
141 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
142 supports_compression: false,
143 supports_streaming: true,
144 supports_bidirectional: true,
145 supports_multiplexing: false,
146 compression_algorithms: Vec::new(),
147 custom: std::collections::HashMap::new(),
148 },
149 config: Arc::new(StdMutex::new(TransportConfig {
150 transport_type: TransportType::Stdio,
151 ..Default::default()
152 })),
153 metrics: Arc::new(AtomicMetrics::default()),
154 event_emitter,
155 stdin_reader: Arc::new(TokioMutex::new(None)),
156 stdout_writer: Arc::new(TokioMutex::new(None)),
157 receive_channel: Arc::new(TokioMutex::new(None)),
158 _task_handle: Arc::new(TokioMutex::new(None)),
159 }
160 }
161
162 fn set_state(&self, new_state: TransportState) {
163 let mut state = self.state.lock().expect("state mutex poisoned");
165 if *state != new_state {
166 trace!("Stdio transport state: {:?} -> {:?}", *state, new_state);
167 *state = new_state.clone();
168
169 match new_state {
170 TransportState::Connected => {
171 self.event_emitter
172 .emit_connected(TransportType::Stdio, "stdio://".to_string());
173 }
174 TransportState::Disconnected => {
175 self.event_emitter.emit_disconnected(
176 TransportType::Stdio,
177 "stdio://".to_string(),
178 None,
179 );
180 }
181 TransportState::Failed { reason } => {
182 self.event_emitter.emit_disconnected(
183 TransportType::Stdio,
184 "stdio://".to_string(),
185 Some(reason),
186 );
187 }
188 _ => {}
189 }
190 }
191 }
192
/// Heartbeat hook; currently a no-op with an empty body.
// `dead_code` is allowed because nothing visible in this file calls it.
#[allow(dead_code)]
fn heartbeat(&self) {
}
199
200 async fn setup_stdio_streams(&self) -> TransportResult<()> {
201 let stdin = tokio::io::stdin();
203 let reader = BufReader::new(stdin);
204 let mut stdin_reader = FramedRead::new(reader, LinesCodec::new());
205
206 let stdout = tokio::io::stdout();
208 *self.stdout_writer.lock().await = Some(FramedWrite::new(stdout, LinesCodec::new()));
209
210 let (tx, rx) = mpsc::channel(1000);
212 *self.receive_channel.lock().await = Some(rx);
213
214 {
216 let sender = tx;
217 let event_emitter = self.event_emitter.clone();
218 let metrics = self.metrics.clone();
219 let config = self.config.clone();
220
221 let task_handle = tokio::spawn(async move {
222 while let Some(result) = stdin_reader.next().await {
223 match result {
224 Ok(line) => {
225 trace!("Received line: {}", line);
226
227 let size = line.len();
229 let limits = {
230 let cfg = config.lock().expect("config mutex poisoned");
231 cfg.limits.clone()
232 };
233
234 if let Err(e) = crate::core::validate_response_size(size, &limits) {
235 error!("Response size validation failed: {}", e);
236 event_emitter.emit_error(
237 e.clone(),
238 Some("response size validation".to_string()),
239 );
240 continue;
242 }
243
244 match Self::parse_message(&line) {
245 Ok(message) => {
246 let size = message.size();
247
248 metrics.messages_received.fetch_add(1, Ordering::Relaxed);
250 metrics
251 .bytes_received
252 .fetch_add(size as u64, Ordering::Relaxed);
253
254 event_emitter.emit_message_received(message.id.clone(), size);
256
257 match sender.try_send(message) {
259 Ok(()) => {}
260 Err(mpsc::error::TrySendError::Full(_)) => {
261 warn!(
262 "STDIO message channel full, applying backpressure"
263 );
264 continue;
266 }
267 Err(mpsc::error::TrySendError::Closed(_)) => {
268 debug!("Receive channel closed, stopping reader task");
269 break;
270 }
271 }
272 }
273 Err(e) => {
274 error!("Failed to parse message: {}", e);
275 event_emitter
276 .emit_error(e, Some("message parsing".to_string()));
277 }
278 }
279 }
280 Err(e) => {
281 error!("Failed to read from stdin: {}", e);
282 event_emitter.emit_error(
283 TransportError::ReceiveFailed(e.to_string()),
284 Some("stdin read".to_string()),
285 );
286 break;
287 }
288 }
289 }
290
291 debug!("Stdio reader task completed");
292 });
293
294 *self._task_handle.lock().await = Some(task_handle);
295 }
296
297 Ok(())
298 }
299
300 fn parse_message(line: &str) -> TransportResult<TransportMessage> {
301 let line = line.trim();
302 if line.is_empty() {
303 return Err(TransportError::ProtocolError("Empty message".to_string()));
304 }
305
306 let json_value: serde_json::Value = serde_json::from_str(line)
308 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
309
310 let message_id = json_value
312 .get("id")
313 .and_then(|id| match id {
314 serde_json::Value::String(s) => Some(MessageId::from(s.clone())),
315 serde_json::Value::Number(n) => n.as_i64().map(MessageId::from),
316 _ => None,
317 })
318 .unwrap_or_else(|| MessageId::from(Uuid::new_v4()));
319
320 let payload = Bytes::from(line.to_string());
322 let metadata = TransportMessageMetadata::with_content_type("application/json");
323
324 Ok(TransportMessage::with_metadata(
325 message_id, payload, metadata,
326 ))
327 }
328
329 fn serialize_message(message: &TransportMessage) -> TransportResult<String> {
330 let json_str = std::str::from_utf8(&message.payload)
332 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
333
334 if json_str.contains('\n') || json_str.contains('\r') {
338 return Err(TransportError::ProtocolError(
339 "Message contains embedded newlines (forbidden by MCP stdio specification)"
340 .to_string(),
341 ));
342 }
343
344 let _: serde_json::Value = serde_json::from_str(json_str)
346 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
347
348 Ok(json_str.to_string())
349 }
350}
351
352#[async_trait]
353impl Transport for StdioTransport {
/// Always reports `TransportType::Stdio`.
fn transport_type(&self) -> TransportType {
    TransportType::Stdio
}
357
/// Borrows the capability description stored on the transport.
fn capabilities(&self) -> &TransportCapabilities {
    &self.capabilities
}
361
362 async fn state(&self) -> TransportState {
363 self.state.lock().expect("state mutex poisoned").clone()
365 }
366
367 async fn connect(&self) -> TransportResult<()> {
368 if matches!(self.state().await, TransportState::Connected) {
369 return Ok(());
370 }
371
372 self.set_state(TransportState::Connecting);
373
374 match self.setup_stdio_streams().await {
375 Ok(()) => {
376 self.metrics.connections.fetch_add(1, Ordering::Relaxed);
378 self.set_state(TransportState::Connected);
379 debug!("Stdio transport connected");
380 Ok(())
381 }
382 Err(e) => {
383 self.metrics
385 .failed_connections
386 .fetch_add(1, Ordering::Relaxed);
387 self.set_state(TransportState::Failed {
388 reason: e.to_string(),
389 });
390 error!("Failed to connect stdio transport: {}", e);
391 Err(e)
392 }
393 }
394 }
395
396 async fn disconnect(&self) -> TransportResult<()> {
397 if matches!(self.state().await, TransportState::Disconnected) {
398 return Ok(());
399 }
400
401 self.set_state(TransportState::Disconnecting);
402
403 *self.stdin_reader.lock().await = None;
405 *self.stdout_writer.lock().await = None;
406 *self.receive_channel.lock().await = None;
407
408 if let Some(handle) = self._task_handle.lock().await.take() {
410 handle.abort();
411 }
412
413 self.set_state(TransportState::Disconnected);
414 debug!("Stdio transport disconnected");
415 Ok(())
416 }
417
418 async fn send(&self, message: TransportMessage) -> TransportResult<()> {
419 let state = self.state().await;
420 if !matches!(state, TransportState::Connected) {
421 return Err(TransportError::ConnectionFailed(format!(
422 "Transport not connected: {state}"
423 )));
424 }
425
426 let json_line = Self::serialize_message(&message)?;
427 let size = json_line.len();
428
429 let config = self.config.lock().expect("config mutex poisoned").clone();
431 crate::core::validate_request_size(size, &config.limits)?;
432
433 let mut stdout_writer = self.stdout_writer.lock().await;
434 if let Some(writer) = stdout_writer.as_mut() {
435 if let Err(e) = writer.send(json_line).await {
436 error!("Failed to send message: {}", e);
437 self.set_state(TransportState::Failed {
438 reason: e.to_string(),
439 });
440 return Err(TransportError::SendFailed(e.to_string()));
441 }
442
443 use futures::SinkExt;
445 if let Err(e) = SinkExt::<String>::flush(writer).await {
446 error!("Failed to flush stdout: {}", e);
447 return Err(TransportError::SendFailed(e.to_string()));
448 }
449
450 self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
452 self.metrics
453 .bytes_sent
454 .fetch_add(size as u64, Ordering::Relaxed);
455
456 self.event_emitter.emit_message_sent(message.id, size);
458
459 trace!("Sent message: {} bytes", size);
460 Ok(())
461 } else {
462 Err(TransportError::SendFailed(
463 "Stdout writer not available".to_string(),
464 ))
465 }
466 }
467
468 async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
469 let state = self.state().await;
470 if !matches!(state, TransportState::Connected) {
471 return Err(TransportError::ConnectionFailed(format!(
472 "Transport not connected: {state}"
473 )));
474 }
475
476 let mut receive_channel = self.receive_channel.lock().await;
477 if let Some(receiver) = receive_channel.as_mut() {
478 match receiver.recv().await {
479 Some(message) => {
480 trace!("Received message: {} bytes", message.size());
481 Ok(Some(message))
482 }
483 None => {
484 warn!("Receive channel disconnected");
485 self.set_state(TransportState::Failed {
486 reason: "Receive channel disconnected".to_string(),
487 });
488 Err(TransportError::ReceiveFailed(
489 "Channel disconnected".to_string(),
490 ))
491 }
492 }
493 } else {
494 Err(TransportError::ReceiveFailed(
495 "Receive channel not available".to_string(),
496 ))
497 }
498 }
499
/// Returns a snapshot of the transport's atomic metrics.
async fn metrics(&self) -> TransportMetrics {
    self.metrics.snapshot()
}
504
505 fn endpoint(&self) -> Option<String> {
506 Some("stdio://".to_string())
507 }
508
509 async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
510 if config.transport_type != TransportType::Stdio {
511 return Err(TransportError::ConfigurationError(format!(
512 "Invalid transport type: {:?}",
513 config.transport_type
514 )));
515 }
516
517 if config.connect_timeout < Duration::from_millis(100) {
519 return Err(TransportError::ConfigurationError(
520 "Connect timeout too small".to_string(),
521 ));
522 }
523
524 *self.config.lock().expect("config mutex poisoned") = config;
526 debug!("Stdio transport configured");
527 Ok(())
528 }
529}
530
/// Stateless factory that builds `StdioTransport` instances.
#[derive(Debug, Default)]
pub struct StdioTransportFactory;
534
535impl StdioTransportFactory {
536 #[must_use]
538 pub const fn new() -> Self {
539 Self
540 }
541}
542
543impl TransportFactory for StdioTransportFactory {
544 fn transport_type(&self) -> TransportType {
545 TransportType::Stdio
546 }
547
548 fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>> {
549 if config.transport_type != TransportType::Stdio {
550 return Err(TransportError::ConfigurationError(format!(
551 "Invalid transport type: {:?}",
552 config.transport_type
553 )));
554 }
555
556 let transport = StdioTransport::with_config(config);
557 Ok(Box::new(transport))
558 }
559
560 fn is_available(&self) -> bool {
561 true
563 }
564}
565
566impl Default for StdioTransport {
567 fn default() -> Self {
568 Self::new()
569 }
570}
571
#[cfg(test)]
mod tests {
    //! Unit tests for construction, configuration, message parsing and
    //! serialization, and the factory. None of them touch real stdin/stdout.

    use super::*;
    use pretty_assertions::assert_eq;

    // --- Construction and configuration -----------------------------------

    #[test]
    fn test_stdio_transport_creation() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
        assert!(transport.capabilities().supports_streaming);
        assert!(transport.capabilities().supports_bidirectional);
    }

    #[test]
    fn test_stdio_transport_with_config() {
        let config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_secs(10),
            ..Default::default()
        };

        let transport = StdioTransport::with_config(config);
        assert_eq!(
            transport
                .config
                .lock()
                .expect("config mutex poisoned")
                .connect_timeout,
            Duration::from_secs(10)
        );
    }

    #[tokio::test]
    async fn test_stdio_transport_state_management() {
        let transport = StdioTransport::new();
        assert_eq!(transport.state().await, TransportState::Disconnected);
    }

    // --- parse_message ----------------------------------------------------

    #[test]
    fn test_message_parsing() {
        let json_line = r#"{"jsonrpc":"2.0","id":"test-123","method":"test","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert_eq!(message.id, MessageId::from("test-123"));
        assert_eq!(message.content_type(), Some("application/json"));
        assert!(!message.payload.is_empty());
    }

    #[test]
    fn test_message_parsing_with_numeric_id() {
        let json_line = r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert_eq!(message.id, MessageId::from(42));
    }

    #[test]
    fn test_message_parsing_without_id() {
        let json_line = r#"{"jsonrpc":"2.0","method":"notification","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        // Without an `id` field the parser falls back to a generated UUID.
        assert!(
            matches!(message.id, MessageId::Uuid(_)),
            "Expected UUID message ID"
        );
    }

    #[test]
    fn test_message_parsing_invalid_json() {
        let invalid_json = "not json at all";
        let result = StdioTransport::parse_message(invalid_json);

        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_parsing_empty() {
        let result = StdioTransport::parse_message("");
        assert!(matches!(result, Err(TransportError::ProtocolError(_))));

        // Whitespace-only input is trimmed to empty and rejected the same way.
        let result = StdioTransport::parse_message("   ");
        assert!(matches!(result, Err(TransportError::ProtocolError(_))));
    }

    // --- serialize_message ------------------------------------------------

    #[test]
    fn test_message_serialization() {
        let json_str = r#"{"jsonrpc":"2.0","id":"test","method":"ping"}"#;
        let payload = Bytes::from(json_str);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let serialized = StdioTransport::serialize_message(&message).unwrap();
        assert_eq!(serialized, json_str);
    }

    #[test]
    fn test_message_serialization_invalid_utf8() {
        // These bytes are not valid UTF-8.
        let payload = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_serialization_invalid_json() {
        let payload = Bytes::from("not json");
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_serialization_embedded_newline_lf() {
        // The raw string deliberately spans two source lines so that it
        // contains a literal LF byte.
        let json_with_newline = r#"{"jsonrpc":"2.0","id":"test","method":"test","params":{"text":"line1
line2"}}"#;
        let payload = Bytes::from(json_with_newline);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(
            matches!(result, Err(TransportError::ProtocolError(_))),
            "Expected ProtocolError for message with LF, got: {:?}",
            result
        );
    }

    #[test]
    fn test_message_serialization_embedded_newline_crlf() {
        let json_with_crlf = "{\r\n\"jsonrpc\":\"2.0\",\"id\":\"test\"}";
        let payload = Bytes::from(json_with_crlf);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(
            matches!(result, Err(TransportError::ProtocolError(_))),
            "Expected ProtocolError for message with CRLF, got: {:?}",
            result
        );
    }

    #[test]
    fn test_message_serialization_embedded_cr() {
        let json_with_cr = "{\r\"jsonrpc\":\"2.0\",\"id\":\"test\"}";
        let payload = Bytes::from(json_with_cr);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(
            matches!(result, Err(TransportError::ProtocolError(_))),
            "Expected ProtocolError for message with CR, got: {:?}",
            result
        );
    }

    #[test]
    fn test_message_serialization_valid_no_newlines() {
        let valid_json =
            r#"{"jsonrpc":"2.0","id":"test","method":"test","params":{"text":"single line"}}"#;
        let payload = Bytes::from(valid_json);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(
            result.is_ok(),
            "Valid message without newlines should be accepted"
        );
        assert_eq!(result.unwrap(), valid_json);
    }

    #[test]
    fn test_message_serialization_escaped_newlines_allowed() {
        // In a raw string, `\n` is two characters (backslash + 'n'), i.e. a
        // JSON escape sequence, not a newline byte.
        let json_with_escaped_newlines = r#"{"jsonrpc":"2.0","id":"test","method":"log","params":{"message":"line1\nline2\ntab:\there"}}"#;

        assert!(
            !json_with_escaped_newlines.contains('\n'),
            "Test setup error: raw string should not contain literal newline bytes"
        );
        assert!(
            !json_with_escaped_newlines.contains('\r'),
            "Test setup error: raw string should not contain literal CR bytes"
        );

        let payload = Bytes::from(json_with_escaped_newlines);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(
            result.is_ok(),
            "JSON with ESCAPED newlines (backslash-n) should be ALLOWED per MCP spec. Got: {:?}",
            result
        );
        assert_eq!(result.unwrap(), json_with_escaped_newlines);
    }

    // --- Factory and configure --------------------------------------------

    #[test]
    fn test_stdio_factory() {
        let factory = StdioTransportFactory::new();
        assert_eq!(factory.transport_type(), TransportType::Stdio);
        assert!(factory.is_available());

        let config = TransportConfig {
            transport_type: TransportType::Stdio,
            ..Default::default()
        };

        let transport = factory.create(config).unwrap();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
    }

    #[test]
    fn test_stdio_factory_invalid_config() {
        let factory = StdioTransportFactory::new();
        let config = TransportConfig {
            transport_type: TransportType::Http,
            ..Default::default()
        };

        let result = factory.create(config);
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    }

    #[tokio::test]
    async fn test_configuration_validation() {
        let transport = StdioTransport::new();

        let valid_config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_secs(5),
            ..Default::default()
        };

        assert!(transport.configure(valid_config).await.is_ok());

        // Wrong transport type is rejected.
        let invalid_config = TransportConfig {
            transport_type: TransportType::Http,
            ..Default::default()
        };

        let result = transport.configure(invalid_config).await;
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));

        // A connect timeout below the 100 ms floor is rejected.
        let invalid_timeout_config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_millis(50),
            ..Default::default()
        };

        let result = transport.configure(invalid_timeout_config).await;
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    }
}