1use std::sync::atomic::Ordering;
17use std::sync::{Arc, Mutex as StdMutex};
18use std::time::Duration;
19
20use async_trait::async_trait;
21use bytes::Bytes;
22use futures::StreamExt;
23use serde_json;
24use tokio::io::{BufReader, Stdin, Stdout};
25use tokio::sync::{Mutex as TokioMutex, mpsc};
26use tokio_util::codec::{FramedRead, FramedWrite, LinesCodec};
27use tracing::{debug, error, trace, warn};
28use turbomcp_protocol::MessageId;
29use uuid::Uuid;
30
31use crate::core::{
32 AtomicMetrics, Transport, TransportCapabilities, TransportConfig, TransportError,
33 TransportEventEmitter, TransportFactory, TransportMessage, TransportMessageMetadata,
34 TransportMetrics, TransportResult, TransportState, TransportType,
35};
36
/// Line-delimited reader over the process's standard input.
type StdinReader = FramedRead<BufReader<Stdin>, LinesCodec>;
/// Line-delimited writer over the process's standard output.
type StdoutWriter = FramedWrite<Stdout, LinesCodec>;

/// Transport that exchanges newline-delimited JSON messages over the
/// process's standard input and standard output.
///
/// Incoming lines are read by a background task and forwarded to
/// `receive()` through an mpsc channel; outgoing messages are written
/// one per line to stdout by `send()`.
#[derive(Debug)]
pub struct StdioTransport {
    /// Current lifecycle state. Updated through `set_state`, which also emits
    /// the matching connection event.
    state: Arc<StdMutex<TransportState>>,

    /// Static capabilities advertised by this transport.
    capabilities: TransportCapabilities,

    /// Active configuration; replaced wholesale by `configure`/`with_config`.
    config: Arc<StdMutex<TransportConfig>>,

    /// Lock-free counters for messages/bytes/connections.
    metrics: Arc<AtomicMetrics>,

    /// Sink for connection, message and error events.
    event_emitter: TransportEventEmitter,

    /// NOTE(review): never populated — `setup_stdio_streams` moves the reader
    /// into the background task instead. Only reset to `None` by `disconnect`.
    stdin_reader: Arc<TokioMutex<Option<StdinReader>>>,

    /// Stdout writer; `None` until connected and after disconnect.
    stdout_writer: Arc<TokioMutex<Option<StdoutWriter>>>,

    /// Receiving end of the channel fed by the stdin reader task.
    receive_channel: Arc<TokioMutex<Option<mpsc::Receiver<TransportMessage>>>>,

    /// Handle of the stdin reader task (the leading underscore is misleading:
    /// the field is used to abort the task on disconnect).
    _task_handle: Arc<TokioMutex<Option<tokio::task::JoinHandle<()>>>>,
}
80
81impl StdioTransport {
82 #[must_use]
84 pub fn new() -> Self {
85 let (event_emitter, _) = TransportEventEmitter::new();
86
87 Self {
88 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
89 capabilities: TransportCapabilities {
90 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
91 supports_compression: false,
92 supports_streaming: true,
93 supports_bidirectional: true,
94 supports_multiplexing: false,
95 compression_algorithms: Vec::new(),
96 custom: std::collections::HashMap::new(),
97 },
98 config: Arc::new(StdMutex::new(TransportConfig {
99 transport_type: TransportType::Stdio,
100 ..Default::default()
101 })),
102 metrics: Arc::new(AtomicMetrics::default()),
103 event_emitter,
104 stdin_reader: Arc::new(TokioMutex::new(None)),
105 stdout_writer: Arc::new(TokioMutex::new(None)),
106 receive_channel: Arc::new(TokioMutex::new(None)),
107 _task_handle: Arc::new(TokioMutex::new(None)),
108 }
109 }
110
111 #[must_use]
113 pub fn with_config(config: TransportConfig) -> Self {
114 let transport = Self::new();
115 *transport.config.lock().expect("config mutex poisoned") = config;
117 transport
118 }
119
120 #[must_use]
122 pub fn with_event_emitter(event_emitter: TransportEventEmitter) -> Self {
123 let (_, _) = TransportEventEmitter::new();
124
125 Self {
126 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
127 capabilities: TransportCapabilities {
128 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
129 supports_compression: false,
130 supports_streaming: true,
131 supports_bidirectional: true,
132 supports_multiplexing: false,
133 compression_algorithms: Vec::new(),
134 custom: std::collections::HashMap::new(),
135 },
136 config: Arc::new(StdMutex::new(TransportConfig {
137 transport_type: TransportType::Stdio,
138 ..Default::default()
139 })),
140 metrics: Arc::new(AtomicMetrics::default()),
141 event_emitter,
142 stdin_reader: Arc::new(TokioMutex::new(None)),
143 stdout_writer: Arc::new(TokioMutex::new(None)),
144 receive_channel: Arc::new(TokioMutex::new(None)),
145 _task_handle: Arc::new(TokioMutex::new(None)),
146 }
147 }
148
149 fn set_state(&self, new_state: TransportState) {
150 let mut state = self.state.lock().expect("state mutex poisoned");
152 if *state != new_state {
153 trace!("Stdio transport state: {:?} -> {:?}", *state, new_state);
154 *state = new_state.clone();
155
156 match new_state {
157 TransportState::Connected => {
158 self.event_emitter
159 .emit_connected(TransportType::Stdio, "stdio://".to_string());
160 }
161 TransportState::Disconnected => {
162 self.event_emitter.emit_disconnected(
163 TransportType::Stdio,
164 "stdio://".to_string(),
165 None,
166 );
167 }
168 TransportState::Failed { reason } => {
169 self.event_emitter.emit_disconnected(
170 TransportType::Stdio,
171 "stdio://".to_string(),
172 Some(reason),
173 );
174 }
175 _ => {}
176 }
177 }
178 }
179
180 #[allow(dead_code)]
182 fn heartbeat(&self) {
183 }
186
187 async fn setup_stdio_streams(&self) -> TransportResult<()> {
188 let stdin = tokio::io::stdin();
190 let reader = BufReader::new(stdin);
191 let mut stdin_reader = FramedRead::new(reader, LinesCodec::new());
192
193 let stdout = tokio::io::stdout();
195 *self.stdout_writer.lock().await = Some(FramedWrite::new(stdout, LinesCodec::new()));
196
197 let (tx, rx) = mpsc::channel(1000);
199 *self.receive_channel.lock().await = Some(rx);
200
201 {
203 let sender = tx;
204 let event_emitter = self.event_emitter.clone();
205 let metrics = self.metrics.clone();
206
207 let task_handle = tokio::spawn(async move {
208 while let Some(result) = stdin_reader.next().await {
209 match result {
210 Ok(line) => {
211 trace!("Received line: {}", line);
212
213 match Self::parse_message(&line) {
214 Ok(message) => {
215 let size = message.size();
216
217 metrics.messages_received.fetch_add(1, Ordering::Relaxed);
219 metrics
220 .bytes_received
221 .fetch_add(size as u64, Ordering::Relaxed);
222
223 event_emitter.emit_message_received(message.id.clone(), size);
225
226 match sender.try_send(message) {
228 Ok(()) => {}
229 Err(mpsc::error::TrySendError::Full(_)) => {
230 warn!(
231 "STDIO message channel full, applying backpressure"
232 );
233 continue;
235 }
236 Err(mpsc::error::TrySendError::Closed(_)) => {
237 debug!("Receive channel closed, stopping reader task");
238 break;
239 }
240 }
241 }
242 Err(e) => {
243 error!("Failed to parse message: {}", e);
244 event_emitter
245 .emit_error(e, Some("message parsing".to_string()));
246 }
247 }
248 }
249 Err(e) => {
250 error!("Failed to read from stdin: {}", e);
251 event_emitter.emit_error(
252 TransportError::ReceiveFailed(e.to_string()),
253 Some("stdin read".to_string()),
254 );
255 break;
256 }
257 }
258 }
259
260 debug!("Stdio reader task completed");
261 });
262
263 *self._task_handle.lock().await = Some(task_handle);
264 }
265
266 Ok(())
267 }
268
269 fn parse_message(line: &str) -> TransportResult<TransportMessage> {
270 let line = line.trim();
271 if line.is_empty() {
272 return Err(TransportError::ProtocolError("Empty message".to_string()));
273 }
274
275 let json_value: serde_json::Value = serde_json::from_str(line)
277 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
278
279 let message_id = json_value
281 .get("id")
282 .and_then(|id| match id {
283 serde_json::Value::String(s) => Some(MessageId::from(s.clone())),
284 serde_json::Value::Number(n) => n.as_i64().map(MessageId::from),
285 _ => None,
286 })
287 .unwrap_or_else(|| MessageId::from(Uuid::new_v4()));
288
289 let payload = Bytes::from(line.to_string());
291 let metadata = TransportMessageMetadata::with_content_type("application/json");
292
293 Ok(TransportMessage::with_metadata(
294 message_id, payload, metadata,
295 ))
296 }
297
298 fn serialize_message(message: &TransportMessage) -> TransportResult<String> {
299 let json_str = std::str::from_utf8(&message.payload)
301 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
302
303 let _: serde_json::Value = serde_json::from_str(json_str)
305 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
306
307 Ok(json_str.to_string())
308 }
309}
310
311#[async_trait]
312impl Transport for StdioTransport {
313 fn transport_type(&self) -> TransportType {
314 TransportType::Stdio
315 }
316
317 fn capabilities(&self) -> &TransportCapabilities {
318 &self.capabilities
319 }
320
321 async fn state(&self) -> TransportState {
322 self.state.lock().expect("state mutex poisoned").clone()
324 }
325
326 async fn connect(&self) -> TransportResult<()> {
327 if matches!(self.state().await, TransportState::Connected) {
328 return Ok(());
329 }
330
331 self.set_state(TransportState::Connecting);
332
333 match self.setup_stdio_streams().await {
334 Ok(()) => {
335 self.metrics.connections.fetch_add(1, Ordering::Relaxed);
337 self.set_state(TransportState::Connected);
338 debug!("Stdio transport connected");
339 Ok(())
340 }
341 Err(e) => {
342 self.metrics
344 .failed_connections
345 .fetch_add(1, Ordering::Relaxed);
346 self.set_state(TransportState::Failed {
347 reason: e.to_string(),
348 });
349 error!("Failed to connect stdio transport: {}", e);
350 Err(e)
351 }
352 }
353 }
354
355 async fn disconnect(&self) -> TransportResult<()> {
356 if matches!(self.state().await, TransportState::Disconnected) {
357 return Ok(());
358 }
359
360 self.set_state(TransportState::Disconnecting);
361
362 *self.stdin_reader.lock().await = None;
364 *self.stdout_writer.lock().await = None;
365 *self.receive_channel.lock().await = None;
366
367 if let Some(handle) = self._task_handle.lock().await.take() {
369 handle.abort();
370 }
371
372 self.set_state(TransportState::Disconnected);
373 debug!("Stdio transport disconnected");
374 Ok(())
375 }
376
377 async fn send(&self, message: TransportMessage) -> TransportResult<()> {
378 let state = self.state().await;
379 if !matches!(state, TransportState::Connected) {
380 return Err(TransportError::ConnectionFailed(format!(
381 "Transport not connected: {state}"
382 )));
383 }
384
385 let json_line = Self::serialize_message(&message)?;
386 let size = json_line.len();
387
388 let mut stdout_writer = self.stdout_writer.lock().await;
389 if let Some(writer) = stdout_writer.as_mut() {
390 if let Err(e) = writer.send(json_line).await {
391 error!("Failed to send message: {}", e);
392 self.set_state(TransportState::Failed {
393 reason: e.to_string(),
394 });
395 return Err(TransportError::SendFailed(e.to_string()));
396 }
397
398 use futures::SinkExt;
400 if let Err(e) = SinkExt::<String>::flush(writer).await {
401 error!("Failed to flush stdout: {}", e);
402 return Err(TransportError::SendFailed(e.to_string()));
403 }
404
405 self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
407 self.metrics
408 .bytes_sent
409 .fetch_add(size as u64, Ordering::Relaxed);
410
411 self.event_emitter.emit_message_sent(message.id, size);
413
414 trace!("Sent message: {} bytes", size);
415 Ok(())
416 } else {
417 Err(TransportError::SendFailed(
418 "Stdout writer not available".to_string(),
419 ))
420 }
421 }
422
423 async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
424 let state = self.state().await;
425 if !matches!(state, TransportState::Connected) {
426 return Err(TransportError::ConnectionFailed(format!(
427 "Transport not connected: {state}"
428 )));
429 }
430
431 let mut receive_channel = self.receive_channel.lock().await;
432 if let Some(receiver) = receive_channel.as_mut() {
433 match receiver.recv().await {
434 Some(message) => {
435 trace!("Received message: {} bytes", message.size());
436 Ok(Some(message))
437 }
438 None => {
439 warn!("Receive channel disconnected");
440 self.set_state(TransportState::Failed {
441 reason: "Receive channel disconnected".to_string(),
442 });
443 Err(TransportError::ReceiveFailed(
444 "Channel disconnected".to_string(),
445 ))
446 }
447 }
448 } else {
449 Err(TransportError::ReceiveFailed(
450 "Receive channel not available".to_string(),
451 ))
452 }
453 }
454
455 async fn metrics(&self) -> TransportMetrics {
456 self.metrics.snapshot()
458 }
459
460 fn endpoint(&self) -> Option<String> {
461 Some("stdio://".to_string())
462 }
463
464 async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
465 if config.transport_type != TransportType::Stdio {
466 return Err(TransportError::ConfigurationError(format!(
467 "Invalid transport type: {:?}",
468 config.transport_type
469 )));
470 }
471
472 if config.connect_timeout < Duration::from_millis(100) {
474 return Err(TransportError::ConfigurationError(
475 "Connect timeout too small".to_string(),
476 ));
477 }
478
479 *self.config.lock().expect("config mutex poisoned") = config;
481 debug!("Stdio transport configured");
482 Ok(())
483 }
484}
485
/// Stateless factory that produces [`StdioTransport`] instances through the
/// [`TransportFactory`] trait.
#[derive(Debug, Default)]
pub struct StdioTransportFactory;

impl StdioTransportFactory {
    /// Returns the factory; identical to `StdioTransportFactory::default()`.
    #[must_use]
    pub const fn new() -> Self {
        StdioTransportFactory
    }
}
497
498impl TransportFactory for StdioTransportFactory {
499 fn transport_type(&self) -> TransportType {
500 TransportType::Stdio
501 }
502
503 fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>> {
504 if config.transport_type != TransportType::Stdio {
505 return Err(TransportError::ConfigurationError(format!(
506 "Invalid transport type: {:?}",
507 config.transport_type
508 )));
509 }
510
511 let transport = StdioTransport::with_config(config);
512 Ok(Box::new(transport))
513 }
514
515 fn is_available(&self) -> bool {
516 true
518 }
519}
520
521impl Default for StdioTransport {
522 fn default() -> Self {
523 Self::new()
524 }
525}
526
#[cfg(test)]
mod tests {
    //! Unit tests for stdio message (de)serialization, construction,
    //! configuration validation and the factory. None of them touch the real
    //! process stdin/stdout.

    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_stdio_transport_creation() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
        assert!(transport.capabilities().supports_streaming);
        assert!(transport.capabilities().supports_bidirectional);
    }

    #[test]
    fn test_stdio_transport_with_config() {
        let config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_secs(10),
            ..Default::default()
        };

        let transport = StdioTransport::with_config(config);
        assert_eq!(
            transport
                .config
                .lock()
                .expect("config mutex poisoned")
                .connect_timeout,
            Duration::from_secs(10)
        );
    }

    #[test]
    fn test_stdio_transport_with_event_emitter() {
        let (event_emitter, _) = TransportEventEmitter::new();
        let transport = StdioTransport::with_event_emitter(event_emitter);
        assert_eq!(transport.transport_type(), TransportType::Stdio);
        assert!(transport.capabilities().supports_bidirectional);
        assert!(!transport.capabilities().supports_compression);
    }

    #[tokio::test]
    async fn test_stdio_transport_state_management() {
        let transport = StdioTransport::new();
        assert_eq!(transport.state().await, TransportState::Disconnected);
    }

    #[tokio::test]
    async fn test_disconnect_when_never_connected_is_noop() {
        let transport = StdioTransport::new();
        assert!(transport.disconnect().await.is_ok());
        assert_eq!(transport.state().await, TransportState::Disconnected);
    }

    #[tokio::test]
    async fn test_send_and_receive_require_connection() {
        let transport = StdioTransport::new();

        let message = TransportMessage::new(MessageId::from("x"), Bytes::from("{}"));
        assert!(matches!(
            transport.send(message).await,
            Err(TransportError::ConnectionFailed(_))
        ));
        assert!(matches!(
            transport.receive().await,
            Err(TransportError::ConnectionFailed(_))
        ));
    }

    #[test]
    fn test_message_parsing() {
        let json_line = r#"{"jsonrpc":"2.0","id":"test-123","method":"test","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert_eq!(message.id, MessageId::from("test-123"));
        assert_eq!(message.content_type(), Some("application/json"));
        assert!(!message.payload.is_empty());
    }

    #[test]
    fn test_message_parsing_with_numeric_id() {
        let json_line = r#"{"jsonrpc":"2.0","id":42,"method":"test","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert_eq!(message.id, MessageId::from(42));
    }

    #[test]
    fn test_message_parsing_without_id() {
        let json_line = r#"{"jsonrpc":"2.0","method":"notification","params":{}}"#;
        let message = StdioTransport::parse_message(json_line).unwrap();

        assert!(
            matches!(message.id, MessageId::Uuid(_)),
            "Expected UUID message ID"
        );
    }

    #[test]
    fn test_message_parsing_invalid_json() {
        let result = StdioTransport::parse_message("not json at all");

        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_parsing_empty() {
        for blank in ["", " "] {
            let result = StdioTransport::parse_message(blank);
            assert!(matches!(result, Err(TransportError::ProtocolError(_))));
        }
    }

    #[test]
    fn test_message_serialization() {
        let json_str = r#"{"jsonrpc":"2.0","id":"test","method":"ping"}"#;
        let payload = Bytes::from(json_str);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let serialized = StdioTransport::serialize_message(&message).unwrap();
        assert_eq!(serialized, json_str);
    }

    #[test]
    fn test_message_serialization_invalid_utf8() {
        let payload = Bytes::from(vec![0xFF, 0xFE, 0xFD]);
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_message_serialization_invalid_json() {
        let payload = Bytes::from("not json");
        let message = TransportMessage::new(MessageId::from("test"), payload);

        let result = StdioTransport::serialize_message(&message);
        assert!(matches!(
            result,
            Err(TransportError::SerializationFailed(_))
        ));
    }

    #[test]
    fn test_stdio_factory() {
        let factory = StdioTransportFactory::new();
        assert_eq!(factory.transport_type(), TransportType::Stdio);
        assert!(factory.is_available());

        let config = TransportConfig {
            transport_type: TransportType::Stdio,
            ..Default::default()
        };

        let transport = factory.create(config).unwrap();
        assert_eq!(transport.transport_type(), TransportType::Stdio);
    }

    #[test]
    fn test_stdio_factory_invalid_config() {
        let factory = StdioTransportFactory::new();
        let config = TransportConfig {
            transport_type: TransportType::Http,
            ..Default::default()
        };

        let result = factory.create(config);
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    }

    #[tokio::test]
    async fn test_configuration_validation() {
        let transport = StdioTransport::new();

        // Correct transport type and a timeout above the 100 ms minimum.
        let valid_config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_secs(5),
            ..Default::default()
        };
        assert!(transport.configure(valid_config).await.is_ok());

        // Wrong transport type.
        let invalid_config = TransportConfig {
            transport_type: TransportType::Http,
            ..Default::default()
        };
        let result = transport.configure(invalid_config).await;
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));

        // Timeout below the 100 ms minimum.
        let invalid_timeout_config = TransportConfig {
            transport_type: TransportType::Stdio,
            connect_timeout: Duration::from_millis(50),
            ..Default::default()
        };
        let result = transport.configure(invalid_timeout_config).await;
        assert!(matches!(result, Err(TransportError::ConfigurationError(_))));
    }
}