1use std::collections::HashMap;
8use std::fmt;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use bytes::Bytes;
13use futures::{Sink, Stream};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use tokio::sync::mpsc;
17use turbomcp_protocol::MessageId;
18
/// Result type returned by every fallible transport operation in this module.
///
/// The error type is always [`TransportError`]. The alias is spelled with the
/// fully qualified `std::result::Result` path, so it is unaffected by any other
/// `Result` name that may be in scope.
pub type TransportResult<T> = std::result::Result<T, TransportError>;
21
/// Errors produced by transport operations.
///
/// Every payload is a plain `String`, which is what allows the type to derive
/// `Clone`; underlying errors such as `std::io::Error` and `serde_json::Error`
/// are flattened to their display text by the `From` impls in this module.
#[derive(Error, Debug, Clone)]
pub enum TransportError {
    /// Establishing a connection failed; the payload describes why.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// An existing connection was lost.
    #[error("Connection lost: {0}")]
    ConnectionLost(String),

    /// Sending a message failed.
    #[error("Send failed: {0}")]
    SendFailed(String),

    /// Receiving a message failed.
    #[error("Receive failed: {0}")]
    ReceiveFailed(String),

    /// (De)serialization failed; produced by `From<serde_json::Error>`.
    #[error("Serialization failed: {0}")]
    SerializationFailed(String),

    /// The peer violated the expected protocol.
    #[error("Protocol error: {0}")]
    ProtocolError(String),

    /// An operation did not complete in time.
    #[error("Operation timed out")]
    Timeout,

    /// The transport configuration is invalid or unsupported.
    #[error("Configuration error: {0}")]
    ConfigurationError(String),

    /// Authentication with the peer failed.
    #[error("Authentication failed: {0}")]
    AuthenticationFailed(String),

    /// A rate limit was exceeded.
    #[error("Rate limit exceeded")]
    RateLimitExceeded,

    /// The requested transport cannot be used in this environment.
    #[error("Transport not available: {0}")]
    NotAvailable(String),

    /// An I/O error; produced by `From<std::io::Error>` (message text only).
    #[error("IO error: {0}")]
    Io(String),

    /// Any other internal failure.
    #[error("Internal error: {0}")]
    Internal(String),
}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
80#[serde(rename_all = "lowercase")]
81pub enum TransportType {
82 Stdio,
84 Http,
86 WebSocket,
88 Tcp,
90 Unix,
92 ChildProcess,
94 #[cfg(feature = "grpc")]
96 Grpc,
97 #[cfg(feature = "quic")]
99 Quic,
100}
101
/// Connection lifecycle state reported by [`Transport::state`].
///
/// No serde rename attributes are applied, so variant names serialize as
/// written. The `Display` impl renders them in lowercase (e.g. `"connected"`,
/// `"failed: <reason>"`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransportState {
    /// Not connected.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected; the default `Transport::is_connected` returns `true` only here.
    Connected,
    /// A disconnect is in progress.
    Disconnecting,
    /// The transport entered a failed state.
    Failed {
        /// Human-readable description of the failure.
        reason: String,
    },
}
119
/// Capabilities advertised by a transport via [`Transport::capabilities`].
///
/// The `Default` impl in this module sets `max_message_size` to
/// `turbomcp_protocol::MAX_MESSAGE_SIZE`, `supports_bidirectional` to `true`,
/// every other flag to `false`, and leaves both collections empty.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TransportCapabilities {
    /// Largest message size the transport accepts; `None` means no limit is advertised.
    pub max_message_size: Option<usize>,

    /// Whether the transport supports compression.
    pub supports_compression: bool,

    /// Whether the transport supports streaming.
    pub supports_streaming: bool,

    /// Whether the transport supports bidirectional messaging.
    pub supports_bidirectional: bool,

    /// Whether the transport supports multiplexing.
    pub supports_multiplexing: bool,

    /// Names of supported compression algorithms (free-form strings).
    pub compression_algorithms: Vec<String>,

    /// Transport-specific extra capabilities as arbitrary JSON values.
    pub custom: HashMap<String, serde_json::Value>,
}
144
/// Configuration passed to [`TransportFactory::create`] or [`Transport::configure`].
///
/// The `Default` impl in this module selects `TransportType::Stdio`, a 30 s
/// `connect_timeout`, compression off, and `None`/empty for everything else.
/// How (and whether) each setting is enforced is up to the individual
/// transport implementation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConfig {
    /// Which transport this configuration targets.
    pub transport_type: TransportType,

    /// Time budget for establishing a connection.
    pub connect_timeout: Duration,

    /// Optional read timeout; `None` means none is configured.
    pub read_timeout: Option<Duration>,

    /// Optional write timeout; `None` means none is configured.
    pub write_timeout: Option<Duration>,

    /// Optional keep-alive interval; `None` means none is configured.
    pub keep_alive: Option<Duration>,

    /// Optional connection limit; `None` means none is configured.
    pub max_connections: Option<usize>,

    /// Whether compression is requested.
    pub compression: bool,

    /// Name of the compression algorithm to use, if any.
    pub compression_algorithm: Option<String>,

    /// Transport-specific extra settings as arbitrary JSON values.
    pub custom: HashMap<String, serde_json::Value>,
}
175
/// A single message moved by a transport: an identifier, raw payload bytes
/// and associated metadata.
#[derive(Debug, Clone)]
pub struct TransportMessage {
    /// Protocol-level message identifier.
    pub id: MessageId,

    /// Raw message body.
    pub payload: Bytes,

    /// Transport-level metadata (encoding, headers, TTL, ...).
    pub metadata: TransportMessageMetadata,
}
188
/// Optional transport-level metadata attached to a [`TransportMessage`].
///
/// Every field defaults to `None` / empty; use the builder-style helpers on
/// this type to populate it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TransportMessageMetadata {
    /// Payload encoding label. `TransportMessage::is_compressed` treats any
    /// `Some` value as "compressed" without inspecting the label.
    pub encoding: Option<String>,

    /// MIME-style content type of the payload.
    pub content_type: Option<String>,

    /// Identifier used to correlate related messages.
    pub correlation_id: Option<String>,

    /// Free-form string headers.
    pub headers: HashMap<String, String>,

    /// Message priority; no ordering semantics are defined in this module.
    pub priority: Option<u8>,

    /// Time-to-live in milliseconds (the unit `with_ttl` converts to).
    pub ttl: Option<u64>,

    /// `Some(true)` marks a heartbeat message (set by `heartbeat()`).
    pub is_heartbeat: Option<bool>,
}
213
/// Plain-data snapshot of transport metrics, e.g. as produced by
/// `AtomicMetrics::snapshot` or returned from [`Transport::metrics`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TransportMetrics {
    /// Total bytes sent.
    pub bytes_sent: u64,

    /// Total bytes received.
    pub bytes_received: u64,

    /// Total messages sent.
    pub messages_sent: u64,

    /// Total messages received.
    pub messages_received: u64,

    /// Total connections established.
    pub connections: u64,

    /// Total failed connection attempts.
    pub failed_connections: u64,

    /// Average latency in milliseconds (for `AtomicMetrics` snapshots, a
    /// moving average stored in microseconds and divided by 1000).
    pub average_latency_ms: f64,

    /// Currently active connections.
    pub active_connections: u64,

    /// Uncompressed-to-compressed byte ratio; `None` when no compression
    /// data has been recorded.
    pub compression_ratio: Option<f64>,

    /// Extra metrics; omitted from serialized output when empty and
    /// defaulted when absent on input.
    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
    pub metadata: HashMap<String, serde_json::Value>,
}
260
/// Lock-free metric counters that can be updated through a shared reference.
///
/// The public counters can be bumped directly (e.g. with `fetch_add`); the
/// private accumulators are maintained by `update_latency_us` and
/// `record_compression`. Call `snapshot` to obtain a [`TransportMetrics`].
#[derive(Debug)]
pub struct AtomicMetrics {
    /// Total bytes sent.
    pub bytes_sent: std::sync::atomic::AtomicU64,

    /// Total bytes received.
    pub bytes_received: std::sync::atomic::AtomicU64,

    /// Total messages sent.
    pub messages_sent: std::sync::atomic::AtomicU64,

    /// Total messages received.
    pub messages_received: std::sync::atomic::AtomicU64,

    /// Total connections established.
    pub connections: std::sync::atomic::AtomicU64,

    /// Total failed connection attempts.
    pub failed_connections: std::sync::atomic::AtomicU64,

    /// Currently active connections.
    pub active_connections: std::sync::atomic::AtomicU64,

    // Moving average of latency in microseconds; each new sample is weighted
    // 1/10 by `update_latency_us`. 0 means "no sample yet".
    avg_latency_us: std::sync::atomic::AtomicU64,

    // Cumulative payload size before compression (see `record_compression`).
    uncompressed_bytes: std::sync::atomic::AtomicU64,

    // Cumulative payload size after compression (see `record_compression`).
    compressed_bytes: std::sync::atomic::AtomicU64,
}
302
303impl Default for AtomicMetrics {
304 fn default() -> Self {
305 use std::sync::atomic::AtomicU64;
306 Self {
307 bytes_sent: AtomicU64::new(0),
308 bytes_received: AtomicU64::new(0),
309 messages_sent: AtomicU64::new(0),
310 messages_received: AtomicU64::new(0),
311 connections: AtomicU64::new(0),
312 failed_connections: AtomicU64::new(0),
313 active_connections: AtomicU64::new(0),
314 avg_latency_us: AtomicU64::new(0),
315 uncompressed_bytes: AtomicU64::new(0),
316 compressed_bytes: AtomicU64::new(0),
317 }
318 }
319}
320
321impl AtomicMetrics {
322 pub fn new() -> Self {
324 Self::default()
325 }
326
327 pub fn update_latency_us(&self, latency_us: u64) {
334 use std::sync::atomic::Ordering;
335
336 let current = self.avg_latency_us.load(Ordering::Relaxed);
337 let new_avg = if current == 0 {
338 latency_us
339 } else {
340 (current * 9 + latency_us) / 10
342 };
343 self.avg_latency_us.store(new_avg, Ordering::Relaxed);
344 }
345
346 pub fn record_compression(&self, uncompressed_size: u64, compressed_size: u64) {
352 use std::sync::atomic::Ordering;
353
354 self.uncompressed_bytes
355 .fetch_add(uncompressed_size, Ordering::Relaxed);
356 self.compressed_bytes
357 .fetch_add(compressed_size, Ordering::Relaxed);
358 }
359
360 pub fn snapshot(&self) -> TransportMetrics {
364 use std::sync::atomic::Ordering;
365
366 let avg_latency_us = self.avg_latency_us.load(Ordering::Relaxed);
367 let uncompressed = self.uncompressed_bytes.load(Ordering::Relaxed);
368 let compressed = self.compressed_bytes.load(Ordering::Relaxed);
369
370 let compression_ratio = if compressed > 0 && uncompressed > 0 {
371 Some(uncompressed as f64 / compressed as f64)
372 } else {
373 None
374 };
375
376 TransportMetrics {
377 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
378 bytes_received: self.bytes_received.load(Ordering::Relaxed),
379 messages_sent: self.messages_sent.load(Ordering::Relaxed),
380 messages_received: self.messages_received.load(Ordering::Relaxed),
381 connections: self.connections.load(Ordering::Relaxed),
382 failed_connections: self.failed_connections.load(Ordering::Relaxed),
383 active_connections: self.active_connections.load(Ordering::Relaxed),
384 average_latency_ms: (avg_latency_us as f64) / 1000.0, compression_ratio,
386 metadata: HashMap::new(), }
388 }
389
390 pub fn reset(&self) {
392 use std::sync::atomic::Ordering;
393
394 self.bytes_sent.store(0, Ordering::Relaxed);
395 self.bytes_received.store(0, Ordering::Relaxed);
396 self.messages_sent.store(0, Ordering::Relaxed);
397 self.messages_received.store(0, Ordering::Relaxed);
398 self.connections.store(0, Ordering::Relaxed);
399 self.failed_connections.store(0, Ordering::Relaxed);
400 self.active_connections.store(0, Ordering::Relaxed);
401 self.avg_latency_us.store(0, Ordering::Relaxed);
402 self.uncompressed_bytes.store(0, Ordering::Relaxed);
403 self.compressed_bytes.store(0, Ordering::Relaxed);
404 }
405}
406
/// Lifecycle and traffic notifications, typically published through a
/// [`TransportEventEmitter`] (one `emit_*` helper exists per variant).
#[derive(Debug, Clone)]
pub enum TransportEvent {
    /// A transport connected.
    Connected {
        /// Kind of transport that connected.
        transport_type: TransportType,
        /// Endpoint description (free-form string).
        endpoint: String,
    },

    /// A transport disconnected.
    Disconnected {
        /// Kind of transport that disconnected.
        transport_type: TransportType,
        /// Endpoint description (free-form string).
        endpoint: String,
        /// Optional human-readable reason.
        reason: Option<String>,
    },

    /// A message was sent.
    MessageSent {
        /// Identifier of the sent message.
        message_id: MessageId,
        /// Size reported by the emitter (presumably payload bytes — confirm at call sites).
        size: usize,
    },

    /// A message was received.
    MessageReceived {
        /// Identifier of the received message.
        message_id: MessageId,
        /// Size reported by the emitter (presumably payload bytes — confirm at call sites).
        size: usize,
    },

    /// An error occurred.
    Error {
        /// The error itself.
        error: TransportError,
        /// Optional context describing where/why it happened.
        context: Option<String>,
    },

    /// A fresh metrics snapshot is available.
    MetricsUpdated {
        /// The snapshot.
        metrics: TransportMetrics,
    },
}
458
/// Core abstraction over a message transport.
///
/// Implementors must be `Send + Sync + Debug`; async methods are provided via
/// `async_trait`. All methods take `&self`, so implementations that change
/// state (connect/disconnect/send) presumably rely on interior mutability.
#[async_trait]
pub trait Transport: Send + Sync + std::fmt::Debug {
    /// Returns which kind of transport this is.
    fn transport_type(&self) -> TransportType;

    /// Returns the capabilities this transport advertises.
    fn capabilities(&self) -> &TransportCapabilities;

    /// Returns the current connection state.
    async fn state(&self) -> TransportState;

    /// Establishes the connection.
    async fn connect(&self) -> TransportResult<()>;

    /// Tears down the connection.
    async fn disconnect(&self) -> TransportResult<()>;

    /// Sends a single message.
    async fn send(&self, message: TransportMessage) -> TransportResult<()>;

    /// Receives a single message.
    ///
    /// NOTE(review): the meaning of `Ok(None)` (no message available vs.
    /// end of stream) is implementation-defined — confirm per transport.
    async fn receive(&self) -> TransportResult<Option<TransportMessage>>;

    /// Returns a snapshot of this transport's metrics.
    async fn metrics(&self) -> TransportMetrics;

    /// Returns `true` exactly when [`Transport::state`] reports
    /// [`TransportState::Connected`].
    async fn is_connected(&self) -> bool {
        matches!(self.state().await, TransportState::Connected)
    }

    /// Returns a description of the remote endpoint, if known.
    ///
    /// The default implementation returns `None`.
    fn endpoint(&self) -> Option<String> {
        None
    }

    /// Applies a new configuration.
    ///
    /// The default implementation ignores `config` entirely and returns
    /// `Ok(())`; transports that support reconfiguration must override it.
    async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
        // Explicitly discard the argument to keep the default a no-op.
        let _ = config;
        Ok(())
    }
}
506
/// A [`Transport`] that supports request/response exchanges and explicit
/// correlation tracking.
#[async_trait]
pub trait BidirectionalTransport: Transport {
    /// Sends `message` and returns a response message.
    ///
    /// `timeout` optionally bounds the wait; how the response is matched to
    /// the request is implementation-defined (presumably via the correlation
    /// id in the metadata — confirm per implementation).
    async fn send_request(
        &self,
        message: TransportMessage,
        timeout: Option<Duration>,
    ) -> TransportResult<TransportMessage>;

    /// Begins tracking the given correlation id.
    async fn start_correlation(&self, correlation_id: String) -> TransportResult<()>;

    /// Stops tracking the given correlation id.
    async fn stop_correlation(&self, correlation_id: &str) -> TransportResult<()>;
}
526
/// A [`Transport`] that can hand out stream/sink handles for continuous
/// message flow.
///
/// NOTE(review): the associated-type names look swapped relative to their
/// bounds. `SendStream` is a `Stream` that *yields* messages (what a receive
/// side would normally be), while `ReceiveStream` is a `Sink` that *accepts*
/// messages (what a send side would normally be). Renaming would break
/// implementors, so confirm the intended direction before relying on the
/// names.
#[async_trait]
pub trait StreamingTransport: Transport {
    /// Stream yielding `TransportResult<TransportMessage>` items.
    type SendStream: Stream<Item = TransportResult<TransportMessage>> + Send + Unpin;

    /// Sink accepting `TransportMessage` values, failing with `TransportError`.
    type ReceiveStream: Sink<TransportMessage, Error = TransportError> + Send + Unpin;

    /// Returns the `SendStream` handle.
    async fn send_stream(&self) -> TransportResult<Self::SendStream>;

    /// Returns the `ReceiveStream` handle.
    async fn receive_stream(&self) -> TransportResult<Self::ReceiveStream>;
}
542
/// Creates boxed [`Transport`] instances of one particular [`TransportType`].
pub trait TransportFactory: Send + Sync + std::fmt::Debug {
    /// Returns the kind of transport this factory creates.
    fn transport_type(&self) -> TransportType;

    /// Builds a new transport from `config`.
    ///
    /// # Errors
    ///
    /// Returns a [`TransportError`] when the transport cannot be created;
    /// which variant is used is up to the implementation.
    fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>>;

    /// Reports whether this factory can be used in the current environment.
    ///
    /// The default implementation always returns `true`.
    fn is_available(&self) -> bool {
        true
    }
}
556
/// Callback interface for consumers of [`TransportEvent`]s.
#[async_trait]
pub trait TransportEventListener: Send + Sync {
    /// Handles one event.
    async fn on_event(&self, event: TransportEvent);
}
563
/// Publishes [`TransportEvent`]s into a bounded tokio mpsc channel.
///
/// Cloning the emitter yields another sender handle for the same channel.
#[derive(Debug, Clone)]
pub struct TransportEventEmitter {
    // Sending half of the event channel; the receiver is handed to the caller
    // at construction time.
    sender: mpsc::Sender<TransportEvent>,
}
569
570impl TransportEventEmitter {
571 #[must_use]
573 pub fn new() -> (Self, mpsc::Receiver<TransportEvent>) {
574 let (sender, receiver) = mpsc::channel(500); (Self { sender }, receiver)
576 }
577
578 pub fn emit(&self, event: TransportEvent) {
580 if self.sender.try_send(event).is_err() {
582 }
584 }
585
586 pub fn emit_connected(&self, transport_type: TransportType, endpoint: String) {
588 self.emit(TransportEvent::Connected {
589 transport_type,
590 endpoint,
591 });
592 }
593
594 pub fn emit_disconnected(
596 &self,
597 transport_type: TransportType,
598 endpoint: String,
599 reason: Option<String>,
600 ) {
601 self.emit(TransportEvent::Disconnected {
602 transport_type,
603 endpoint,
604 reason,
605 });
606 }
607
608 pub fn emit_message_sent(&self, message_id: MessageId, size: usize) {
610 self.emit(TransportEvent::MessageSent { message_id, size });
611 }
612
613 pub fn emit_message_received(&self, message_id: MessageId, size: usize) {
615 self.emit(TransportEvent::MessageReceived { message_id, size });
616 }
617
618 pub fn emit_error(&self, error: TransportError, context: Option<String>) {
620 self.emit(TransportEvent::Error { error, context });
621 }
622
623 pub fn emit_metrics_updated(&self, metrics: TransportMetrics) {
625 self.emit(TransportEvent::MetricsUpdated { metrics });
626 }
627}
628
629impl Default for TransportEventEmitter {
630 fn default() -> Self {
631 Self::new().0
632 }
633}
634
635impl Default for TransportCapabilities {
638 fn default() -> Self {
639 Self {
640 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
641 supports_compression: false,
642 supports_streaming: false,
643 supports_bidirectional: true,
644 supports_multiplexing: false,
645 compression_algorithms: Vec::new(),
646 custom: HashMap::new(),
647 }
648 }
649}
650
651impl Default for TransportConfig {
652 fn default() -> Self {
653 Self {
654 transport_type: TransportType::Stdio,
655 connect_timeout: Duration::from_secs(30),
656 read_timeout: None,
657 write_timeout: None,
658 keep_alive: None,
659 max_connections: None,
660 compression: false,
661 compression_algorithm: None,
662 custom: HashMap::new(),
663 }
664 }
665}
666
667impl TransportMessage {
668 pub fn new(id: MessageId, payload: Bytes) -> Self {
678 Self {
679 id,
680 payload,
681 metadata: TransportMessageMetadata::default(),
682 }
683 }
684
685 pub const fn with_metadata(
687 id: MessageId,
688 payload: Bytes,
689 metadata: TransportMessageMetadata,
690 ) -> Self {
691 Self {
692 id,
693 payload,
694 metadata,
695 }
696 }
697
698 pub const fn size(&self) -> usize {
700 self.payload.len()
701 }
702
703 pub const fn is_compressed(&self) -> bool {
705 self.metadata.encoding.is_some()
706 }
707
708 pub fn content_type(&self) -> Option<&str> {
710 self.metadata.content_type.as_deref()
711 }
712
713 pub fn correlation_id(&self) -> Option<&str> {
715 self.metadata.correlation_id.as_deref()
716 }
717}
718
719impl TransportMessageMetadata {
720 pub fn with_content_type(content_type: impl Into<String>) -> Self {
722 Self {
723 content_type: Some(content_type.into()),
724 ..Default::default()
725 }
726 }
727
728 pub fn with_correlation_id(correlation_id: impl Into<String>) -> Self {
730 Self {
731 correlation_id: Some(correlation_id.into()),
732 ..Default::default()
733 }
734 }
735
736 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
745 self.headers.insert(key.into(), value.into());
746 self
747 }
748
749 #[must_use]
751 pub const fn with_priority(mut self, priority: u8) -> Self {
752 self.priority = Some(priority);
753 self
754 }
755
756 #[must_use]
758 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
759 self.ttl = Some(ttl.as_millis() as u64);
760 self
761 }
762
763 #[must_use]
765 pub const fn heartbeat(mut self) -> Self {
766 self.is_heartbeat = Some(true);
767 self
768 }
769}
770
771impl fmt::Display for TransportType {
772 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
773 match self {
774 Self::Stdio => write!(f, "stdio"),
775 Self::Http => write!(f, "http"),
776 Self::WebSocket => write!(f, "websocket"),
777 Self::Tcp => write!(f, "tcp"),
778 Self::Unix => write!(f, "unix"),
779 Self::ChildProcess => write!(f, "child_process"),
780 #[cfg(feature = "grpc")]
781 Self::Grpc => write!(f, "grpc"),
782 #[cfg(feature = "quic")]
783 Self::Quic => write!(f, "quic"),
784 }
785 }
786}
787
788impl fmt::Display for TransportState {
789 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
790 match self {
791 Self::Disconnected => write!(f, "disconnected"),
792 Self::Connecting => write!(f, "connecting"),
793 Self::Connected => write!(f, "connected"),
794 Self::Disconnecting => write!(f, "disconnecting"),
795 Self::Failed { reason } => write!(f, "failed: {reason}"),
796 }
797 }
798}
799
800impl From<std::io::Error> for TransportError {
801 fn from(err: std::io::Error) -> Self {
802 Self::Io(err.to_string())
803 }
804}
805
806impl From<serde_json::Error> for TransportError {
807 fn from(err: serde_json::Error) -> Self {
808 Self::SerializationFailed(err.to_string())
809 }
810}
811
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_capabilities_default() {
        let caps = TransportCapabilities::default();
        assert_eq!(
            caps.max_message_size,
            Some(turbomcp_protocol::MAX_MESSAGE_SIZE)
        );
        assert!(caps.supports_bidirectional);
        assert!(!caps.supports_compression);
        assert!(!caps.supports_streaming);
        assert!(!caps.supports_multiplexing);
        assert!(caps.compression_algorithms.is_empty());
    }

    #[test]
    fn test_transport_config_default() {
        let config = TransportConfig::default();
        assert_eq!(config.transport_type, TransportType::Stdio);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.read_timeout, None);
        assert!(!config.compression);
    }

    #[test]
    fn test_transport_message_creation() {
        let id = MessageId::from("test");
        let payload = Bytes::from("test payload");
        let msg = TransportMessage::new(id.clone(), payload.clone());

        assert_eq!(msg.id, id);
        assert_eq!(msg.payload, payload);
        assert_eq!(msg.size(), 12);
        assert!(!msg.is_compressed());
        assert_eq!(msg.content_type(), None);
        assert_eq!(msg.correlation_id(), None);
    }

    #[test]
    fn test_transport_message_metadata() {
        let metadata = TransportMessageMetadata::default()
            .with_header("custom", "value")
            .with_priority(5)
            .with_ttl(Duration::from_secs(30))
            .heartbeat();

        assert_eq!(metadata.headers.get("custom"), Some(&"value".to_string()));
        assert_eq!(metadata.priority, Some(5));
        assert_eq!(metadata.ttl, Some(30000));
        assert_eq!(metadata.is_heartbeat, Some(true));
    }

    #[test]
    fn test_transport_types_display() {
        assert_eq!(TransportType::Stdio.to_string(), "stdio");
        assert_eq!(TransportType::Http.to_string(), "http");
        assert_eq!(TransportType::WebSocket.to_string(), "websocket");
        assert_eq!(TransportType::Tcp.to_string(), "tcp");
        assert_eq!(TransportType::Unix.to_string(), "unix");
        assert_eq!(TransportType::ChildProcess.to_string(), "child_process");
    }

    #[test]
    fn test_transport_state_display() {
        assert_eq!(TransportState::Connected.to_string(), "connected");
        assert_eq!(TransportState::Disconnected.to_string(), "disconnected");
        assert_eq!(TransportState::Connecting.to_string(), "connecting");
        assert_eq!(TransportState::Disconnecting.to_string(), "disconnecting");
        assert_eq!(
            TransportState::Failed {
                reason: "timeout".to_string()
            }
            .to_string(),
            "failed: timeout"
        );
    }

    #[test]
    fn test_atomic_metrics_snapshot_and_reset() {
        use std::sync::atomic::Ordering;

        let metrics = AtomicMetrics::new();
        metrics.bytes_sent.fetch_add(10, Ordering::Relaxed);
        // First sample seeds the average; the second is weighted 1/10:
        // (1000 * 9 + 2000) / 10 = 1100 us = 1.1 ms.
        metrics.update_latency_us(1000);
        metrics.update_latency_us(2000);
        metrics.record_compression(1000, 250);

        let snap = metrics.snapshot();
        assert_eq!(snap.bytes_sent, 10);
        assert!((snap.average_latency_ms - 1.1).abs() < 1e-9);
        assert_eq!(snap.compression_ratio, Some(4.0));
        assert!(snap.metadata.is_empty());

        metrics.reset();
        let snap = metrics.snapshot();
        assert_eq!(snap.bytes_sent, 0);
        assert!(snap.average_latency_ms.abs() < 1e-9);
        assert_eq!(snap.compression_ratio, None);
    }

    #[tokio::test]
    async fn test_transport_event_emitter() {
        let (emitter, mut receiver) = TransportEventEmitter::new();

        emitter.emit_connected(TransportType::Stdio, "stdio://".to_string());

        let event = receiver
            .recv()
            .await
            .expect("emitter should deliver the queued event");
        match event {
            TransportEvent::Connected {
                transport_type,
                endpoint,
            } => {
                assert_eq!(transport_type, TransportType::Stdio);
                assert_eq!(endpoint, "stdio://");
            }
            other => panic!("Expected Connected event, got {other:?}"),
        }
    }
}