1use std::collections::HashMap;
8use std::fmt;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use bytes::Bytes;
13use futures::{Sink, Stream};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use tokio::sync::mpsc;
17use turbomcp_protocol::MessageId;
18
/// Convenience alias for a `Result` whose error type is [`TransportError`].
pub type TransportResult<T> = std::result::Result<T, TransportError>;
21
22#[derive(Error, Debug, Clone)]
24pub enum TransportError {
25 #[error("Connection failed: {0}")]
27 ConnectionFailed(String),
28
29 #[error("Connection lost: {0}")]
31 ConnectionLost(String),
32
33 #[error("Send failed: {0}")]
35 SendFailed(String),
36
37 #[error("Receive failed: {0}")]
39 ReceiveFailed(String),
40
41 #[error("Serialization failed: {0}")]
43 SerializationFailed(String),
44
45 #[error("Protocol error: {0}")]
47 ProtocolError(String),
48
49 #[error("Operation timed out")]
51 Timeout,
52
53 #[error(
58 "Connection timed out after {timeout:?} for operation: {operation}. \
59 If this is expected, increase the timeout with \
60 `TimeoutConfig {{ connect: Duration::from_secs({}) }}`",
61 timeout.as_secs() * 2
62 )]
63 ConnectionTimeout {
64 operation: String,
66 timeout: Duration,
68 },
69
70 #[error(
75 "Request timed out after {timeout:?} for operation: {operation}. \
76 If this is expected, increase the timeout with \
77 `TimeoutConfig {{ request: Some(Duration::from_secs({})) }}` \
78 or use `TimeoutConfig::patient()` for slow operations",
79 timeout.as_secs() * 2
80 )]
81 RequestTimeout {
82 operation: String,
84 timeout: Duration,
86 },
87
88 #[error(
93 "Total operation timed out after {timeout:?} for operation: {operation}. \
94 This includes retries. If this is expected, increase the timeout with \
95 `TimeoutConfig {{ total: Some(Duration::from_secs({})) }}`",
96 timeout.as_secs() * 2
97 )]
98 TotalTimeout {
99 operation: String,
101 timeout: Duration,
103 },
104
105 #[error(
110 "Read timed out after {timeout:?} while streaming response for operation: {operation}. \
111 If this is expected, increase the timeout with \
112 `TimeoutConfig {{ read: Some(Duration::from_secs({})) }}`",
113 timeout.as_secs() * 2
114 )]
115 ReadTimeout {
116 operation: String,
118 timeout: Duration,
120 },
121
122 #[error("Configuration error: {0}")]
124 ConfigurationError(String),
125
126 #[error("Authentication failed: {0}")]
128 AuthenticationFailed(String),
129
130 #[error("Rate limit exceeded")]
132 RateLimitExceeded,
133
134 #[error("Transport not available: {0}")]
136 NotAvailable(String),
137
138 #[error("IO error: {0}")]
140 Io(String),
141
142 #[error("Internal error: {0}")]
144 Internal(String),
145
146 #[error(
152 "Request size ({size} bytes) exceeds maximum allowed ({max} bytes). \
153 If this is expected, increase the limit with \
154 `LimitsConfig {{ max_request_size: Some({}) }}` or use `LimitsConfig::unlimited()` \
155 if running behind an API gateway.",
156 size
157 )]
158 RequestTooLarge {
159 size: usize,
161 max: usize,
163 },
164
165 #[error(
171 "Response size ({size} bytes) exceeds maximum allowed ({max} bytes). \
172 If this is expected, increase the limit with \
173 `LimitsConfig {{ max_response_size: Some({}) }}` or use `LimitsConfig::unlimited()` \
174 if running behind an API gateway.",
175 size
176 )]
177 ResponseTooLarge {
178 size: usize,
180 max: usize,
182 },
183}
184
/// The kinds of transport this module knows about.
///
/// Serialized by serde with lowercase variant names (`rename_all = "lowercase"`).
///
/// NOTE(review): with that rename rule serde spells `ChildProcess` as `childprocess`,
/// whereas the `Display` impl in this file prints `child_process` — confirm the difference
/// is intended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TransportType {
    /// Standard input/output.
    Stdio,
    /// HTTP.
    Http,
    /// WebSocket.
    WebSocket,
    /// TCP socket.
    Tcp,
    /// Unix domain socket (presumably — the variant is only named `Unix` here).
    Unix,
    /// A child process.
    ChildProcess,
    /// gRPC; only compiled in with the `grpc` feature.
    #[cfg(feature = "grpc")]
    Grpc,
    /// QUIC; only compiled in with the `quic` feature.
    #[cfg(feature = "quic")]
    Quic,
}
208
/// Connection state reported by a transport (see `Transport::state`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransportState {
    /// Not connected.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected; this is the only state `Transport::is_connected` treats as connected.
    Connected,
    /// A disconnect is in progress.
    Disconnecting,
    /// The transport failed.
    Failed {
        /// Human-readable failure reason; rendered by `Display` as `failed: <reason>`.
        reason: String,
    },
}
226
/// Feature flags and limits a transport advertises (see `Transport::capabilities`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TransportCapabilities {
    /// Maximum message size, or `None` when no limit is advertised.
    /// (Units are presumably bytes — confirm against `turbomcp_protocol::MAX_MESSAGE_SIZE`.)
    pub max_message_size: Option<usize>,

    /// Whether compression is supported.
    pub supports_compression: bool,

    /// Whether streaming is supported.
    pub supports_streaming: bool,

    /// Whether bidirectional communication is supported.
    pub supports_bidirectional: bool,

    /// Whether multiplexing is supported.
    pub supports_multiplexing: bool,

    /// Names of the supported compression algorithms.
    pub compression_algorithms: Vec<String>,

    /// Free-form, transport-specific capability values.
    pub custom: HashMap<String, serde_json::Value>,
}
251
/// Configuration handed to a transport (see `TransportFactory::create` and
/// `Transport::configure`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConfig {
    /// Which kind of transport this configuration is for.
    pub transport_type: TransportType,

    /// Timeout for establishing a connection.
    pub connect_timeout: Duration,

    /// Optional read timeout.
    pub read_timeout: Option<Duration>,

    /// Optional write timeout.
    pub write_timeout: Option<Duration>,

    /// Optional keep-alive interval.
    pub keep_alive: Option<Duration>,

    /// Optional cap on the number of connections.
    pub max_connections: Option<usize>,

    /// Whether compression is enabled.
    pub compression: bool,

    /// Name of the compression algorithm to use, if any.
    pub compression_algorithm: Option<String>,

    /// Size limits; falls back to `LimitsConfig::default()` when absent during deserialization.
    #[serde(default)]
    pub limits: crate::config::LimitsConfig,

    /// Timeout settings; falls back to `TimeoutConfig::default()` when absent during
    /// deserialization.
    #[serde(default)]
    pub timeouts: crate::config::TimeoutConfig,

    /// TLS settings; falls back to `TlsConfig::default()` when absent during deserialization.
    #[serde(default)]
    pub tls: crate::config::TlsConfig,

    /// Free-form, transport-specific configuration values.
    pub custom: HashMap<String, serde_json::Value>,
}
305
/// A message as it travels over a transport: an identifier, the raw payload bytes and
/// accompanying metadata.
#[derive(Debug, Clone)]
pub struct TransportMessage {
    /// Identifier of the message.
    pub id: MessageId,

    /// Raw payload bytes; `TransportMessage::size` reports their length.
    pub payload: Bytes,

    /// Metadata describing the payload.
    pub metadata: TransportMessageMetadata,
}
318
/// Optional metadata attached to a [`TransportMessage`].
///
/// All fields default to `None` / empty.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TransportMessageMetadata {
    /// Payload encoding, if any. `TransportMessage::is_compressed` returns `true` whenever
    /// this is `Some`.
    pub encoding: Option<String>,

    /// Content type of the payload, if known.
    pub content_type: Option<String>,

    /// Correlation identifier, if any.
    pub correlation_id: Option<String>,

    /// Arbitrary string headers.
    pub headers: HashMap<String, String>,

    /// Optional priority value (the ordering convention is not defined in this file).
    pub priority: Option<u8>,

    /// Optional time-to-live; the `with_ttl` builder stores it in milliseconds.
    pub ttl: Option<u64>,

    /// Heartbeat marker; the `heartbeat` builder sets it to `Some(true)`.
    pub is_heartbeat: Option<bool>,
}
343
/// A plain, serializable set of transport metric values.
///
/// `AtomicMetrics::snapshot` produces one of these from its atomic counters.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TransportMetrics {
    /// Number of bytes sent.
    pub bytes_sent: u64,

    /// Number of bytes received.
    pub bytes_received: u64,

    /// Number of messages sent.
    pub messages_sent: u64,

    /// Number of messages received.
    pub messages_received: u64,

    /// Connection counter.
    pub connections: u64,

    /// Failed-connection counter.
    pub failed_connections: u64,

    /// Average latency in milliseconds.
    pub average_latency_ms: f64,

    /// Number of currently active connections.
    pub active_connections: u64,

    /// Compression ratio, if known. `AtomicMetrics::snapshot` computes it as
    /// uncompressed bytes divided by compressed bytes.
    pub compression_ratio: Option<f64>,

    /// Extra free-form values; omitted from serialized output when empty and defaulted to
    /// empty when missing on input.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
}
390
/// Transport metrics stored in `AtomicU64` cells so they can be updated through `&self`.
///
/// The public counters are exposed directly; the latency average and compression byte
/// totals are private and maintained through `update_latency_us` / `record_compression`.
/// Use `snapshot` to obtain a plain [`TransportMetrics`] value.
#[derive(Debug)]
pub struct AtomicMetrics {
    /// Number of bytes sent.
    pub bytes_sent: std::sync::atomic::AtomicU64,

    /// Number of bytes received.
    pub bytes_received: std::sync::atomic::AtomicU64,

    /// Number of messages sent.
    pub messages_sent: std::sync::atomic::AtomicU64,

    /// Number of messages received.
    pub messages_received: std::sync::atomic::AtomicU64,

    /// Connection counter.
    pub connections: std::sync::atomic::AtomicU64,

    /// Failed-connection counter.
    pub failed_connections: std::sync::atomic::AtomicU64,

    /// Number of currently active connections.
    pub active_connections: std::sync::atomic::AtomicU64,

    /// Running latency average in microseconds; `0` doubles as "no sample recorded yet".
    avg_latency_us: std::sync::atomic::AtomicU64,

    /// Total uncompressed bytes passed to `record_compression`.
    uncompressed_bytes: std::sync::atomic::AtomicU64,

    /// Total compressed bytes passed to `record_compression`.
    compressed_bytes: std::sync::atomic::AtomicU64,
}
432
433impl Default for AtomicMetrics {
434 fn default() -> Self {
435 use std::sync::atomic::AtomicU64;
436 Self {
437 bytes_sent: AtomicU64::new(0),
438 bytes_received: AtomicU64::new(0),
439 messages_sent: AtomicU64::new(0),
440 messages_received: AtomicU64::new(0),
441 connections: AtomicU64::new(0),
442 failed_connections: AtomicU64::new(0),
443 active_connections: AtomicU64::new(0),
444 avg_latency_us: AtomicU64::new(0),
445 uncompressed_bytes: AtomicU64::new(0),
446 compressed_bytes: AtomicU64::new(0),
447 }
448 }
449}
450
451impl AtomicMetrics {
452 pub fn new() -> Self {
454 Self::default()
455 }
456
457 pub fn update_latency_us(&self, latency_us: u64) {
464 use std::sync::atomic::Ordering;
465
466 let current = self.avg_latency_us.load(Ordering::Relaxed);
467 let new_avg = if current == 0 {
468 latency_us
469 } else {
470 (current * 9 + latency_us) / 10
472 };
473 self.avg_latency_us.store(new_avg, Ordering::Relaxed);
474 }
475
476 pub fn record_compression(&self, uncompressed_size: u64, compressed_size: u64) {
482 use std::sync::atomic::Ordering;
483
484 self.uncompressed_bytes
485 .fetch_add(uncompressed_size, Ordering::Relaxed);
486 self.compressed_bytes
487 .fetch_add(compressed_size, Ordering::Relaxed);
488 }
489
490 pub fn snapshot(&self) -> TransportMetrics {
494 use std::sync::atomic::Ordering;
495
496 let avg_latency_us = self.avg_latency_us.load(Ordering::Relaxed);
497 let uncompressed = self.uncompressed_bytes.load(Ordering::Relaxed);
498 let compressed = self.compressed_bytes.load(Ordering::Relaxed);
499
500 let compression_ratio = if compressed > 0 && uncompressed > 0 {
501 Some(uncompressed as f64 / compressed as f64)
502 } else {
503 None
504 };
505
506 TransportMetrics {
507 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
508 bytes_received: self.bytes_received.load(Ordering::Relaxed),
509 messages_sent: self.messages_sent.load(Ordering::Relaxed),
510 messages_received: self.messages_received.load(Ordering::Relaxed),
511 connections: self.connections.load(Ordering::Relaxed),
512 failed_connections: self.failed_connections.load(Ordering::Relaxed),
513 active_connections: self.active_connections.load(Ordering::Relaxed),
514 average_latency_ms: (avg_latency_us as f64) / 1000.0, compression_ratio,
516 metadata: HashMap::new(), }
518 }
519
520 pub fn reset(&self) {
522 use std::sync::atomic::Ordering;
523
524 self.bytes_sent.store(0, Ordering::Relaxed);
525 self.bytes_received.store(0, Ordering::Relaxed);
526 self.messages_sent.store(0, Ordering::Relaxed);
527 self.messages_received.store(0, Ordering::Relaxed);
528 self.connections.store(0, Ordering::Relaxed);
529 self.failed_connections.store(0, Ordering::Relaxed);
530 self.active_connections.store(0, Ordering::Relaxed);
531 self.avg_latency_us.store(0, Ordering::Relaxed);
532 self.uncompressed_bytes.store(0, Ordering::Relaxed);
533 self.compressed_bytes.store(0, Ordering::Relaxed);
534 }
535}
536
/// Events published through [`TransportEventEmitter`].
#[derive(Debug, Clone)]
pub enum TransportEvent {
    /// A transport connected.
    Connected {
        /// Kind of transport that connected.
        transport_type: TransportType,
        /// Endpoint description.
        endpoint: String,
    },

    /// A transport disconnected.
    Disconnected {
        /// Kind of transport that disconnected.
        transport_type: TransportType,
        /// Endpoint description.
        endpoint: String,
        /// Optional reason for the disconnect.
        reason: Option<String>,
    },

    /// A message was sent.
    MessageSent {
        /// Identifier of the sent message.
        message_id: MessageId,
        /// Size reported for the message (presumably bytes — confirm with emit sites).
        size: usize,
    },

    /// A message was received.
    MessageReceived {
        /// Identifier of the received message.
        message_id: MessageId,
        /// Size reported for the message (presumably bytes — confirm with emit sites).
        size: usize,
    },

    /// An error occurred.
    Error {
        /// The error itself.
        error: TransportError,
        /// Optional free-form context.
        context: Option<String>,
    },

    /// New metric values are available.
    MetricsUpdated {
        /// The metric values.
        metrics: TransportMetrics,
    },
}
588
589#[async_trait]
594pub trait Transport: Send + Sync + std::fmt::Debug {
595 fn transport_type(&self) -> TransportType;
597
598 fn capabilities(&self) -> &TransportCapabilities;
600
601 async fn state(&self) -> TransportState;
603
604 async fn connect(&self) -> TransportResult<()>;
606
607 async fn disconnect(&self) -> TransportResult<()>;
609
610 async fn send(&self, message: TransportMessage) -> TransportResult<()>;
612
613 async fn receive(&self) -> TransportResult<Option<TransportMessage>>;
615
616 async fn metrics(&self) -> TransportMetrics;
618
619 async fn is_connected(&self) -> bool {
621 matches!(self.state().await, TransportState::Connected)
622 }
623
624 fn endpoint(&self) -> Option<String> {
626 None
627 }
628
629 async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
631 let _ = config;
633 Ok(())
634 }
635}
636
/// Extension of [`Transport`] for request/response exchanges with correlation tracking.
#[async_trait]
pub trait BidirectionalTransport: Transport {
    /// Sends `message` and resolves to a response message.
    ///
    /// `timeout` optionally bounds the exchange; behaviour for `None` is left to the
    /// implementation.
    async fn send_request(
        &self,
        message: TransportMessage,
        timeout: Option<Duration>,
    ) -> TransportResult<TransportMessage>;

    /// Begins tracking the given correlation id.
    async fn start_correlation(&self, correlation_id: String) -> TransportResult<()>;

    /// Stops tracking the given correlation id.
    async fn stop_correlation(&self, correlation_id: &str) -> TransportResult<()>;
}
656
/// Extension of [`Transport`] that exposes stream/sink handles.
///
/// NOTE(review): `SendStream` is bounded by `Stream` (it yields messages) while
/// `ReceiveStream` is bounded by `Sink` (it accepts messages). The names read as inverted
/// relative to those bounds — confirm the intended direction before relying on them.
#[async_trait]
pub trait StreamingTransport: Transport {
    /// A `Stream` yielding `TransportResult<TransportMessage>` items.
    type SendStream: Stream<Item = TransportResult<TransportMessage>> + Send + Unpin;

    /// A `Sink` accepting `TransportMessage` values and failing with [`TransportError`].
    type ReceiveStream: Sink<TransportMessage, Error = TransportError> + Send + Unpin;

    /// Obtains a [`StreamingTransport::SendStream`].
    async fn send_stream(&self) -> TransportResult<Self::SendStream>;

    /// Obtains a [`StreamingTransport::ReceiveStream`].
    async fn receive_stream(&self) -> TransportResult<Self::ReceiveStream>;
}
672
/// Creates boxed [`Transport`] instances for one transport kind.
pub trait TransportFactory: Send + Sync + std::fmt::Debug {
    /// Returns the kind of transport this factory creates.
    fn transport_type(&self) -> TransportType;

    /// Builds a transport from `config`.
    fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>>;

    /// Reports whether this factory can currently create transports. Defaults to `true`.
    fn is_available(&self) -> bool {
        true
    }
}
686
/// Cloneable handle for publishing [`TransportEvent`]s onto a bounded `tokio` mpsc channel.
#[derive(Debug, Clone)]
pub struct TransportEventEmitter {
    /// Sending half of the event channel.
    sender: mpsc::Sender<TransportEvent>,
}
692
693impl TransportEventEmitter {
694 #[must_use]
696 pub fn new() -> (Self, mpsc::Receiver<TransportEvent>) {
697 let (sender, receiver) = mpsc::channel(500); (Self { sender }, receiver)
699 }
700
701 pub fn emit(&self, event: TransportEvent) {
703 if self.sender.try_send(event).is_err() {
705 }
707 }
708
709 pub fn emit_connected(&self, transport_type: TransportType, endpoint: String) {
711 self.emit(TransportEvent::Connected {
712 transport_type,
713 endpoint,
714 });
715 }
716
717 pub fn emit_disconnected(
719 &self,
720 transport_type: TransportType,
721 endpoint: String,
722 reason: Option<String>,
723 ) {
724 self.emit(TransportEvent::Disconnected {
725 transport_type,
726 endpoint,
727 reason,
728 });
729 }
730
731 pub fn emit_message_sent(&self, message_id: MessageId, size: usize) {
733 self.emit(TransportEvent::MessageSent { message_id, size });
734 }
735
736 pub fn emit_message_received(&self, message_id: MessageId, size: usize) {
738 self.emit(TransportEvent::MessageReceived { message_id, size });
739 }
740
741 pub fn emit_error(&self, error: TransportError, context: Option<String>) {
743 self.emit(TransportEvent::Error { error, context });
744 }
745
746 pub fn emit_metrics_updated(&self, metrics: TransportMetrics) {
748 self.emit(TransportEvent::MetricsUpdated { metrics });
749 }
750}
751
752impl Default for TransportEventEmitter {
753 fn default() -> Self {
754 Self::new().0
755 }
756}
757
758impl Default for TransportCapabilities {
761 fn default() -> Self {
762 Self {
763 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
764 supports_compression: false,
765 supports_streaming: false,
766 supports_bidirectional: true,
767 supports_multiplexing: false,
768 compression_algorithms: Vec::new(),
769 custom: HashMap::new(),
770 }
771 }
772}
773
774impl Default for TransportConfig {
775 fn default() -> Self {
776 Self {
777 transport_type: TransportType::Stdio,
778 connect_timeout: Duration::from_secs(30),
779 read_timeout: None,
780 write_timeout: None,
781 keep_alive: None,
782 max_connections: None,
783 compression: false,
784 compression_algorithm: None,
785 limits: crate::config::LimitsConfig::default(),
786 timeouts: crate::config::TimeoutConfig::default(),
787 tls: crate::config::TlsConfig::default(),
788 custom: HashMap::new(),
789 }
790 }
791}
792
793impl TransportMessage {
794 pub fn new(id: MessageId, payload: Bytes) -> Self {
804 Self {
805 id,
806 payload,
807 metadata: TransportMessageMetadata::default(),
808 }
809 }
810
811 pub const fn with_metadata(
813 id: MessageId,
814 payload: Bytes,
815 metadata: TransportMessageMetadata,
816 ) -> Self {
817 Self {
818 id,
819 payload,
820 metadata,
821 }
822 }
823
824 pub const fn size(&self) -> usize {
826 self.payload.len()
827 }
828
829 pub const fn is_compressed(&self) -> bool {
831 self.metadata.encoding.is_some()
832 }
833
834 pub fn content_type(&self) -> Option<&str> {
836 self.metadata.content_type.as_deref()
837 }
838
839 pub fn correlation_id(&self) -> Option<&str> {
841 self.metadata.correlation_id.as_deref()
842 }
843}
844
845impl TransportMessageMetadata {
846 pub fn with_content_type(content_type: impl Into<String>) -> Self {
848 Self {
849 content_type: Some(content_type.into()),
850 ..Default::default()
851 }
852 }
853
854 pub fn with_correlation_id(correlation_id: impl Into<String>) -> Self {
856 Self {
857 correlation_id: Some(correlation_id.into()),
858 ..Default::default()
859 }
860 }
861
862 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
871 self.headers.insert(key.into(), value.into());
872 self
873 }
874
875 #[must_use]
877 pub const fn with_priority(mut self, priority: u8) -> Self {
878 self.priority = Some(priority);
879 self
880 }
881
882 #[must_use]
884 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
885 self.ttl = Some(ttl.as_millis() as u64);
886 self
887 }
888
889 #[must_use]
891 pub const fn heartbeat(mut self) -> Self {
892 self.is_heartbeat = Some(true);
893 self
894 }
895}
896
897impl fmt::Display for TransportType {
898 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
899 match self {
900 Self::Stdio => write!(f, "stdio"),
901 Self::Http => write!(f, "http"),
902 Self::WebSocket => write!(f, "websocket"),
903 Self::Tcp => write!(f, "tcp"),
904 Self::Unix => write!(f, "unix"),
905 Self::ChildProcess => write!(f, "child_process"),
906 #[cfg(feature = "grpc")]
907 Self::Grpc => write!(f, "grpc"),
908 #[cfg(feature = "quic")]
909 Self::Quic => write!(f, "quic"),
910 }
911 }
912}
913
914impl fmt::Display for TransportState {
915 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
916 match self {
917 Self::Disconnected => write!(f, "disconnected"),
918 Self::Connecting => write!(f, "connecting"),
919 Self::Connected => write!(f, "connected"),
920 Self::Disconnecting => write!(f, "disconnecting"),
921 Self::Failed { reason } => write!(f, "failed: {reason}"),
922 }
923 }
924}
925
926impl From<std::io::Error> for TransportError {
927 fn from(err: std::io::Error) -> Self {
928 Self::Io(err.to_string())
929 }
930}
931
932impl From<serde_json::Error> for TransportError {
933 fn from(err: serde_json::Error) -> Self {
934 Self::SerializationFailed(err.to_string())
935 }
936}
937
938pub fn validate_request_size(
960 size: usize,
961 limits: &crate::config::LimitsConfig,
962) -> TransportResult<()> {
963 if let Some(max_size) = limits.max_request_size
964 && size > max_size
965 {
966 return Err(TransportError::RequestTooLarge {
967 size,
968 max: max_size,
969 });
970 }
971 Ok(())
972}
973
974pub fn validate_response_size(
996 size: usize,
997 limits: &crate::config::LimitsConfig,
998) -> TransportResult<()> {
999 if let Some(max_size) = limits.max_response_size
1000 && size > max_size
1001 {
1002 return Err(TransportError::ResponseTooLarge {
1003 size,
1004 max: max_size,
1005 });
1006 }
1007 Ok(())
1008}
1009
#[cfg(test)]
mod tests {
    use super::*;

    /// Default capabilities advertise the protocol's max message size and bidirectionality.
    #[test]
    fn test_transport_capabilities_default() {
        let TransportCapabilities {
            max_message_size,
            supports_bidirectional,
            ..
        } = TransportCapabilities::default();

        assert_eq!(
            max_message_size,
            Some(turbomcp_protocol::MAX_MESSAGE_SIZE)
        );
        assert!(supports_bidirectional);
    }

    /// Default configuration targets stdio with a 30-second connect timeout.
    #[test]
    fn test_transport_config_default() {
        let defaults = TransportConfig::default();
        let expected_timeout = Duration::from_secs(30);

        assert_eq!(defaults.transport_type, TransportType::Stdio);
        assert_eq!(defaults.connect_timeout, expected_timeout);
    }

    /// `TransportMessage::new` keeps the id and payload and reports the payload length.
    #[test]
    fn test_transport_message_creation() {
        let id = MessageId::from("test");
        let payload = Bytes::from("test payload");

        let message = TransportMessage::new(id.clone(), payload.clone());

        assert_eq!(message.id, id);
        assert_eq!(message.payload, payload);
        assert_eq!(message.size(), 12);
    }

    /// The metadata builders set headers, priority and a millisecond TTL.
    #[test]
    fn test_transport_message_metadata() {
        let built = TransportMessageMetadata::default()
            .with_header("custom", "value")
            .with_priority(5)
            .with_ttl(Duration::from_secs(30));

        assert_eq!(built.headers.get("custom"), Some(&"value".to_string()));
        assert_eq!(built.priority, Some(5));
        assert_eq!(built.ttl, Some(30000));
    }

    /// `Display` for the always-available transport kinds.
    #[test]
    fn test_transport_types_display() {
        let expectations = [
            (TransportType::Stdio, "stdio"),
            (TransportType::Http, "http"),
            (TransportType::WebSocket, "websocket"),
            (TransportType::Tcp, "tcp"),
            (TransportType::Unix, "unix"),
        ];

        for (transport_type, expected) in expectations {
            assert_eq!(transport_type.to_string(), expected);
        }
    }

    /// `Display` for plain states and for a failure with a reason.
    #[test]
    fn test_transport_state_display() {
        assert_eq!(TransportState::Connected.to_string(), "connected");
        assert_eq!(TransportState::Disconnected.to_string(), "disconnected");

        let failed = TransportState::Failed {
            reason: "timeout".to_string(),
        };
        assert_eq!(failed.to_string(), "failed: timeout");
    }

    /// An emitted `Connected` event arrives on the receiver with its fields intact.
    #[tokio::test]
    async fn test_transport_event_emitter() {
        let (emitter, mut receiver) = TransportEventEmitter::new();

        emitter.emit_connected(TransportType::Stdio, "stdio://".to_string());

        match receiver.recv().await.unwrap() {
            TransportEvent::Connected {
                transport_type,
                endpoint,
            } => {
                assert_eq!(transport_type, TransportType::Stdio);
                assert_eq!(endpoint, "stdio://");
            }
            other => {
                // Any other variant is a failure; log it before aborting the test.
                eprintln!("Unexpected event variant: {other:?}");
                panic!("Expected Connected event");
            }
        }
    }
}