turbomcp_transport/
stdio.rs

1//! Standard I/O transport implementation.
2//!
3//! This transport uses stdin/stdout for communication, which is the
4//! standard way MCP servers communicate with clients. It supports
5//! JSON-RPC over newline-delimited JSON.
6//!
7//! # Interior Mutability Pattern
8//!
//! This transport picks each synchronization primitive according to how long
//! the protected data must stay locked in async code:
//!
//! - **std::sync::Mutex** for state/config (short-lived locks, never held across `.await`)
//! - **AtomicMetrics** for lock-free counter updates (no lock acquisition on hot paths)
//! - **tokio::sync::Mutex** for I/O streams (only where a lock must be held across `.await`)
15
16use std::sync::atomic::Ordering;
17use std::sync::{Arc, Mutex as StdMutex};
18use std::time::Duration;
19
20use async_trait::async_trait;
21use bytes::Bytes;
22use futures::StreamExt;
23use serde_json;
24use tokio::io::{BufReader, Stdin, Stdout};
25use tokio::sync::{Mutex as TokioMutex, mpsc};
26use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
27use tracing::{debug, error, trace, warn};
28use turbomcp_protocol::MessageId;
29use uuid::Uuid;
30
31use crate::core::{
32    AtomicMetrics, Transport, TransportCapabilities, TransportConfig, TransportError,
33    TransportEventEmitter, TransportFactory, TransportMessage, TransportMessageMetadata,
34    TransportMetrics, TransportResult, TransportState, TransportType,
35};
36
37// Type alias to reduce complexity for clippy
38type StdinReader = FramedRead<BufReader<Stdin>, LinesCodec>;
39type StdoutWriter = FramedWrite<Stdout, LinesCodec>;
40
/// Standard I/O transport implementation
///
/// Reads newline-delimited JSON messages from stdin on a background task and
/// writes them to stdout, one message per line.
///
/// # Interior Mutability Architecture
///
/// Each field uses the cheapest primitive that fits how long it must stay locked:
///
/// - `state`: std::sync::Mutex (short-lived locks, never held across .await)
/// - `config`: std::sync::Mutex (infrequent updates, short-lived locks)
/// - `metrics`: AtomicMetrics (lock-free counters, no lock acquisition)
/// - I/O streams: tokio::sync::Mutex (held across .await, necessary for async I/O)
#[derive(Debug)]
pub struct StdioTransport {
    /// Transport state (std::sync::Mutex - never crosses await)
    state: Arc<StdMutex<TransportState>>,

    /// Transport capabilities (immutable after construction)
    capabilities: TransportCapabilities,

    /// Transport configuration (std::sync::Mutex - infrequent access)
    config: Arc<StdMutex<TransportConfig>>,

    /// Lock-free atomic counters for messages, bytes, and connections
    metrics: Arc<AtomicMetrics>,

    /// Emitter for connection, message, and error events
    event_emitter: TransportEventEmitter,

    /// Stdin reader (tokio::sync::Mutex - crosses await boundaries)
    ///
    /// NOTE(review): in this file the field is only ever set to `None`; the actual
    /// reader is created in `setup_stdio_streams` and moved into the background task.
    stdin_reader: Arc<TokioMutex<Option<StdinReader>>>,

    /// Stdout writer (tokio::sync::Mutex - crosses await boundaries)
    stdout_writer: Arc<TokioMutex<Option<StdoutWriter>>>,

    /// Receiving end of the channel fed by the background stdin reader task
    /// (tokio::sync::Mutex - held across `recv().await`)
    receive_channel: Arc<TokioMutex<Option<mpsc::Receiver<TransportMessage>>>>,

    /// Handle of the background stdin reader task, aborted on disconnect
    /// (tokio::sync::Mutex - crosses await boundaries). Despite the leading
    /// underscore, this field is read and written.
    _task_handle: Arc<TokioMutex<Option<tokio::task::JoinHandle<()>>>>,
}
80
81impl StdioTransport {
82    /// Create a new stdio transport
83    #[must_use]
84    pub fn new() -> Self {
85        let (event_emitter, _) = TransportEventEmitter::new();
86
87        Self {
88            state: Arc::new(StdMutex::new(TransportState::Disconnected)),
89            capabilities: TransportCapabilities {
90                max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
91                supports_compression: false,
92                supports_streaming: true,
93                supports_bidirectional: true,
94                supports_multiplexing: false,
95                compression_algorithms: Vec::new(),
96                custom: std::collections::HashMap::new(),
97            },
98            config: Arc::new(StdMutex::new(TransportConfig {
99                transport_type: TransportType::Stdio,
100                ..Default::default()
101            })),
102            metrics: Arc::new(AtomicMetrics::default()),
103            event_emitter,
104            stdin_reader: Arc::new(TokioMutex::new(None)),
105            stdout_writer: Arc::new(TokioMutex::new(None)),
106            receive_channel: Arc::new(TokioMutex::new(None)),
107            _task_handle: Arc::new(TokioMutex::new(None)),
108        }
109    }
110
111    /// Create a stdio transport with custom configuration
112    #[must_use]
113    pub fn with_config(config: TransportConfig) -> Self {
114        let transport = Self::new();
115        // std::sync::Mutex: .lock() returns LockResult, use expect() for poisoned mutex
116        *transport.config.lock().expect("config mutex poisoned") = config;
117        transport
118    }
119
120    /// Create a stdio transport with event emitter
121    #[must_use]
122    pub fn with_event_emitter(event_emitter: TransportEventEmitter) -> Self {
123        let (_, _) = TransportEventEmitter::new();
124
125        Self {
126            state: Arc::new(StdMutex::new(TransportState::Disconnected)),
127            capabilities: TransportCapabilities {
128                max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
129                supports_compression: false,
130                supports_streaming: true,
131                supports_bidirectional: true,
132                supports_multiplexing: false,
133                compression_algorithms: Vec::new(),
134                custom: std::collections::HashMap::new(),
135            },
136            config: Arc::new(StdMutex::new(TransportConfig {
137                transport_type: TransportType::Stdio,
138                ..Default::default()
139            })),
140            metrics: Arc::new(AtomicMetrics::default()),
141            event_emitter,
142            stdin_reader: Arc::new(TokioMutex::new(None)),
143            stdout_writer: Arc::new(TokioMutex::new(None)),
144            receive_channel: Arc::new(TokioMutex::new(None)),
145            _task_handle: Arc::new(TokioMutex::new(None)),
146        }
147    }
148
149    fn set_state(&self, new_state: TransportState) {
150        // std::sync::Mutex: short-lived lock, never crosses await
151        let mut state = self.state.lock().expect("state mutex poisoned");
152        if *state != new_state {
153            trace!("Stdio transport state: {:?} -> {:?}", *state, new_state);
154            *state = new_state.clone();
155
156            match new_state {
157                TransportState::Connected => {
158                    self.event_emitter
159                        .emit_connected(TransportType::Stdio, "stdio://".to_string());
160                }
161                TransportState::Disconnected => {
162                    self.event_emitter.emit_disconnected(
163                        TransportType::Stdio,
164                        "stdio://".to_string(),
165                        None,
166                    );
167                }
168                TransportState::Failed { reason } => {
169                    self.event_emitter.emit_disconnected(
170                        TransportType::Stdio,
171                        "stdio://".to_string(),
172                        Some(reason),
173                    );
174                }
175                _ => {}
176            }
177        }
178    }
179
180    /// Send a ping/heartbeat to stdout to keep the connection lively (optional for stdio)
181    #[allow(dead_code)]
182    fn heartbeat(&self) {
183        // No-op: AtomicMetrics are updated directly at send/receive sites
184        // No dedicated heartbeat counter needed
185    }
186
187    async fn setup_stdio_streams(&self) -> TransportResult<()> {
188        // Setup stdin reader
189        let stdin = tokio::io::stdin();
190        let reader = BufReader::new(stdin);
191        let mut stdin_reader = FramedRead::new(reader, LinesCodec::new());
192
193        // Setup stdout writer
194        let stdout = tokio::io::stdout();
195        *self.stdout_writer.lock().await = Some(FramedWrite::new(stdout, LinesCodec::new()));
196
197        // Setup message receive channel (bounded for backpressure)
198        let (tx, rx) = mpsc::channel(1000);
199        *self.receive_channel.lock().await = Some(rx);
200
201        // Start background reader task
202        {
203            let sender = tx;
204            let event_emitter = self.event_emitter.clone();
205            let metrics = self.metrics.clone();
206
207            let task_handle = tokio::spawn(async move {
208                while let Some(result) = stdin_reader.next().await {
209                    match result {
210                        Ok(line) => {
211                            trace!("Received line: {}", line);
212
213                            match Self::parse_message(&line) {
214                                Ok(message) => {
215                                    let size = message.size();
216
217                                    // Update metrics (lock-free atomic operations)
218                                    metrics.messages_received.fetch_add(1, Ordering::Relaxed);
219                                    metrics
220                                        .bytes_received
221                                        .fetch_add(size as u64, Ordering::Relaxed);
222
223                                    // Emit event
224                                    event_emitter.emit_message_received(message.id.clone(), size);
225
226                                    // Use try_send with backpressure handling
227                                    match sender.try_send(message) {
228                                        Ok(()) => {}
229                                        Err(mpsc::error::TrySendError::Full(_)) => {
230                                            warn!(
231                                                "STDIO message channel full, applying backpressure"
232                                            );
233                                            // Apply backpressure by dropping this message
234                                            continue;
235                                        }
236                                        Err(mpsc::error::TrySendError::Closed(_)) => {
237                                            debug!("Receive channel closed, stopping reader task");
238                                            break;
239                                        }
240                                    }
241                                }
242                                Err(e) => {
243                                    error!("Failed to parse message: {}", e);
244                                    event_emitter
245                                        .emit_error(e, Some("message parsing".to_string()));
246                                }
247                            }
248                        }
249                        Err(e) => {
250                            error!("Failed to read from stdin: {}", e);
251                            event_emitter.emit_error(
252                                TransportError::ReceiveFailed(e.to_string()),
253                                Some("stdin read".to_string()),
254                            );
255                            break;
256                        }
257                    }
258                }
259
260                debug!("Stdio reader task completed");
261            });
262
263            *self._task_handle.lock().await = Some(task_handle);
264        }
265
266        Ok(())
267    }
268
269    fn parse_message(line: &str) -> TransportResult<TransportMessage> {
270        let line = line.trim();
271        if line.is_empty() {
272            return Err(TransportError::ProtocolError("Empty message".to_string()));
273        }
274
275        // Parse JSON
276        let json_value: serde_json::Value = serde_json::from_str(line)
277            .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
278
279        // Extract message ID
280        let message_id = json_value
281            .get("id")
282            .and_then(|id| match id {
283                serde_json::Value::String(s) => Some(MessageId::from(s.clone())),
284                serde_json::Value::Number(n) => n.as_i64().map(MessageId::from),
285                _ => None,
286            })
287            .unwrap_or_else(|| MessageId::from(Uuid::new_v4()));
288
289        // Create transport message
290        let payload = Bytes::from(line.to_string());
291        let metadata = TransportMessageMetadata::with_content_type("application/json");
292
293        Ok(TransportMessage::with_metadata(
294            message_id, payload, metadata,
295        ))
296    }
297
298    fn serialize_message(message: &TransportMessage) -> TransportResult<String> {
299        // Convert bytes back to string for stdio transport
300        let json_str = std::str::from_utf8(&message.payload)
301            .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
302
303        // Validate JSON
304        let _: serde_json::Value = serde_json::from_str(json_str)
305            .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
306
307        Ok(json_str.to_string())
308    }
309}
310
311#[async_trait]
312impl Transport for StdioTransport {
313    fn transport_type(&self) -> TransportType {
314        TransportType::Stdio
315    }
316
317    fn capabilities(&self) -> &TransportCapabilities {
318        &self.capabilities
319    }
320
321    async fn state(&self) -> TransportState {
322        // std::sync::Mutex: short-lived lock for reading state
323        self.state.lock().expect("state mutex poisoned").clone()
324    }
325
326    async fn connect(&self) -> TransportResult<()> {
327        if matches!(self.state().await, TransportState::Connected) {
328            return Ok(());
329        }
330
331        self.set_state(TransportState::Connecting);
332
333        match self.setup_stdio_streams().await {
334            Ok(()) => {
335                // AtomicMetrics: lock-free increment
336                self.metrics.connections.fetch_add(1, Ordering::Relaxed);
337                self.set_state(TransportState::Connected);
338                debug!("Stdio transport connected");
339                Ok(())
340            }
341            Err(e) => {
342                // AtomicMetrics: lock-free increment
343                self.metrics
344                    .failed_connections
345                    .fetch_add(1, Ordering::Relaxed);
346                self.set_state(TransportState::Failed {
347                    reason: e.to_string(),
348                });
349                error!("Failed to connect stdio transport: {}", e);
350                Err(e)
351            }
352        }
353    }
354
355    async fn disconnect(&self) -> TransportResult<()> {
356        if matches!(self.state().await, TransportState::Disconnected) {
357            return Ok(());
358        }
359
360        self.set_state(TransportState::Disconnecting);
361
362        // Close streams
363        *self.stdin_reader.lock().await = None;
364        *self.stdout_writer.lock().await = None;
365        *self.receive_channel.lock().await = None;
366
367        // Cancel background task
368        if let Some(handle) = self._task_handle.lock().await.take() {
369            handle.abort();
370        }
371
372        self.set_state(TransportState::Disconnected);
373        debug!("Stdio transport disconnected");
374        Ok(())
375    }
376
377    async fn send(&self, message: TransportMessage) -> TransportResult<()> {
378        let state = self.state().await;
379        if !matches!(state, TransportState::Connected) {
380            return Err(TransportError::ConnectionFailed(format!(
381                "Transport not connected: {state}"
382            )));
383        }
384
385        let json_line = Self::serialize_message(&message)?;
386        let size = json_line.len();
387
388        let mut stdout_writer = self.stdout_writer.lock().await;
389        if let Some(writer) = stdout_writer.as_mut() {
390            if let Err(e) = writer.send(json_line).await {
391                error!("Failed to send message: {}", e);
392                self.set_state(TransportState::Failed {
393                    reason: e.to_string(),
394                });
395                return Err(TransportError::SendFailed(e.to_string()));
396            }
397
398            // Flush to ensure message is sent immediately
399            use futures::SinkExt;
400            if let Err(e) = SinkExt::<String>::flush(writer).await {
401                error!("Failed to flush stdout: {}", e);
402                return Err(TransportError::SendFailed(e.to_string()));
403            }
404
405            // Update metrics (lock-free atomic operations)
406            self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
407            self.metrics
408                .bytes_sent
409                .fetch_add(size as u64, Ordering::Relaxed);
410
411            // Emit event
412            self.event_emitter.emit_message_sent(message.id, size);
413
414            trace!("Sent message: {} bytes", size);
415            Ok(())
416        } else {
417            Err(TransportError::SendFailed(
418                "Stdout writer not available".to_string(),
419            ))
420        }
421    }
422
423    async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
424        let state = self.state().await;
425        if !matches!(state, TransportState::Connected) {
426            return Err(TransportError::ConnectionFailed(format!(
427                "Transport not connected: {state}"
428            )));
429        }
430
431        let mut receive_channel = self.receive_channel.lock().await;
432        if let Some(receiver) = receive_channel.as_mut() {
433            match receiver.recv().await {
434                Some(message) => {
435                    trace!("Received message: {} bytes", message.size());
436                    Ok(Some(message))
437                }
438                None => {
439                    warn!("Receive channel disconnected");
440                    self.set_state(TransportState::Failed {
441                        reason: "Receive channel disconnected".to_string(),
442                    });
443                    Err(TransportError::ReceiveFailed(
444                        "Channel disconnected".to_string(),
445                    ))
446                }
447            }
448        } else {
449            Err(TransportError::ReceiveFailed(
450                "Receive channel not available".to_string(),
451            ))
452        }
453    }
454
455    async fn metrics(&self) -> TransportMetrics {
456        // AtomicMetrics: lock-free snapshot with Ordering::Relaxed
457        self.metrics.snapshot()
458    }
459
460    fn endpoint(&self) -> Option<String> {
461        Some("stdio://".to_string())
462    }
463
464    async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
465        if config.transport_type != TransportType::Stdio {
466            return Err(TransportError::ConfigurationError(format!(
467                "Invalid transport type: {:?}",
468                config.transport_type
469            )));
470        }
471
472        // Validate configuration
473        if config.connect_timeout < Duration::from_millis(100) {
474            return Err(TransportError::ConfigurationError(
475                "Connect timeout too small".to_string(),
476            ));
477        }
478
479        // std::sync::Mutex: short-lived lock for updating config
480        *self.config.lock().expect("config mutex poisoned") = config;
481        debug!("Stdio transport configured");
482        Ok(())
483    }
484}
485
/// Stateless factory that produces `StdioTransport` instances.
#[derive(Debug, Default)]
pub struct StdioTransportFactory;

impl StdioTransportFactory {
    /// Returns a stdio transport factory. Usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        StdioTransportFactory
    }
}
496}
497
498impl TransportFactory for StdioTransportFactory {
499    fn transport_type(&self) -> TransportType {
500        TransportType::Stdio
501    }
502
503    fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>> {
504        if config.transport_type != TransportType::Stdio {
505            return Err(TransportError::ConfigurationError(format!(
506                "Invalid transport type: {:?}",
507                config.transport_type
508            )));
509        }
510
511        let transport = StdioTransport::with_config(config);
512        Ok(Box::new(transport))
513    }
514
515    fn is_available(&self) -> bool {
516        // Stdio is always available
517        true
518    }
519}
520
521impl Default for StdioTransport {
522    fn default() -> Self {
523        Self::new()
524    }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_stdio_transport_creation() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
        assert!(transport.capabilities().supports_streaming);
        assert!(transport.capabilities().supports_bidirectional);
    }

    #[test]
    fn test_stdio_transport_with_config() {
        let config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_secs(10),
            ..Default::default()
        };

        let transport = StdioTransport::with_config(config);
        assert_eq!(
            transport
                .config
                .lock()
                .expect("config mutex poisoned")
                .connect_timeout,
            Duration::from_secs(10)
        );
    }

    #[tokio::test]
    async fn test_stdio_transport_state_management() {
        let transport = StdioTransport::new();
        assert_eq!(transport.state().await, TransportState::Disconnected);
    }

    #[test]
    fn test_stdio_endpoint() {
        let transport = StdioTransport::new();
        assert_eq!(transport.endpoint(), Some("stdio://".to_string()));
    }

    #[test]
    fn test_message_parsing() {
        let json_line = r#"{"jsonrpc":"2.0","id":"test-123","method":"test","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert_eq!(message.id, MessageId::from("test-123"));
        assert_eq!(message.content_type(), Some("application/json"));
        assert!(!message.payload.is_empty());
    }

    #[test]
    fn test_message_parsing_with_numeric_id() {
        let json_line = r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert_eq!(message.id, MessageId::from(42));
    }

    #[test]
    fn test_message_parsing_trims_surrounding_whitespace() {
        let message = StdioTransport::parse_message("  {\"id\":7}\r\n").unwrap();

        assert_eq!(message.id, MessageId::from(7));
        // The stored payload is the trimmed line.
        assert_eq!(std::str::from_utf8(&message.payload).unwrap(), "{\"id\":7}");
    }

    #[test]
    fn test_message_parsing_without_id() {
        let json_line = r#"{"jsonrpc":"2.0","method":"notification","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        // A UUID is generated when no ID is present.
        assert!(
            matches!(message.id, MessageId::Uuid(_)),
            "Expected UUID message ID, got {:?}",
            message.id
        );
    }

    #[test]
    fn test_message_parsing_invalid_json() {
        let invalid_json = "not json at all";
        let result = StdioTransport::parse_message(invalid_json);

        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_parsing_empty() {
        let result = StdioTransport::parse_message("");
        assert!(matches!(result, Err(TransportError::ProtocolError(_))));

        let result = StdioTransport::parse_message("   ");
        assert!(matches!(result, Err(TransportError::ProtocolError(_))));
    }

    #[test]
    fn test_message_serialization() {
        let json_str = r#"{"jsonrpc":"2.0","id":"test","method":"ping"}"#;
        let payload = Bytes::from(json_str);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let serialized = StdioTransport::serialize_message(&message).unwrap();
        assert_eq!(serialized, json_str);
    }

    #[test]
    fn test_message_serialization_invalid_utf8() {
        let payload = Bytes::from(vec![0xFF, 0xFE, 0xFD]); // Invalid UTF-8
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_serialization_invalid_json() {
        let payload = Bytes::from("not json");
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_stdio_factory() {
        let factory = StdioTransportFactory::new();
        assert_eq!(factory.transport_type(), TransportType::Stdio);
        assert!(factory.is_available());

        let config = TransportConfig {
            transport_type: TransportType::Stdio,
            ..Default::default()
        };

        let transport = factory.create(config).unwrap();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
    }

    #[test]
    fn test_stdio_factory_invalid_config() {
        let factory = StdioTransportFactory::new();
        let config = TransportConfig {
            transport_type: TransportType::Http, // Wrong type
            ..Default::default()
        };

        let result = factory.create(config);
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    }

    #[tokio::test]
    async fn test_configuration_validation() {
        let transport = StdioTransport::new();

        // Valid configuration
        let valid_config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_secs(5),
            ..Default::default()
        };

        assert!(transport.configure(valid_config).await.is_ok());

        // Invalid transport type
        let invalid_config = TransportConfig {
            transport_type: TransportType::Http,
            ..Default::default()
        };

        let result = transport.configure(invalid_config).await;
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));

        // Invalid timeout
        let invalid_timeout_config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_millis(50), // Too small
            ..Default::default()
        };

        let result = transport.configure(invalid_timeout_config).await;
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    }
}