1use std::collections::HashMap;
8use std::fmt;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use bytes::Bytes;
13use futures::{Sink, Stream};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use tokio::sync::mpsc;
17use turbomcp_protocol::MessageId;
18
19pub type TransportResult<T> = std::result::Result<T, TransportError>;
21
22#[derive(Error, Debug, Clone)]
24pub enum TransportError {
25 #[error("Connection failed: {0}")]
27 ConnectionFailed(String),
28
29 #[error("Connection lost: {0}")]
31 ConnectionLost(String),
32
33 #[error("Send failed: {0}")]
35 SendFailed(String),
36
37 #[error("Receive failed: {0}")]
39 ReceiveFailed(String),
40
41 #[error("Serialization failed: {0}")]
43 SerializationFailed(String),
44
45 #[error("Protocol error: {0}")]
47 ProtocolError(String),
48
49 #[error("Operation timed out")]
51 Timeout,
52
53 #[error("Configuration error: {0}")]
55 ConfigurationError(String),
56
57 #[error("Authentication failed: {0}")]
59 AuthenticationFailed(String),
60
61 #[error("Rate limit exceeded")]
63 RateLimitExceeded,
64
65 #[error("Transport not available: {0}")]
67 NotAvailable(String),
68
69 #[error("IO error: {0}")]
71 Io(String),
72
73 #[error("Internal error: {0}")]
75 Internal(String),
76}
77
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
80#[serde(rename_all = "lowercase")]
81pub enum TransportType {
82 Stdio,
84 Http,
86 WebSocket,
88 Tcp,
90 Unix,
92 ChildProcess,
94 #[cfg(feature = "grpc")]
96 Grpc,
97 #[cfg(feature = "quic")]
99 Quic,
100}
101
102#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
104pub enum TransportState {
105 Disconnected,
107 Connecting,
109 Connected,
111 Disconnecting,
113 Failed {
115 reason: String,
117 },
118}
119
120#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
122pub struct TransportCapabilities {
123 pub max_message_size: Option<usize>,
125
126 pub supports_compression: bool,
128
129 pub supports_streaming: bool,
131
132 pub supports_bidirectional: bool,
134
135 pub supports_multiplexing: bool,
137
138 pub compression_algorithms: Vec<String>,
140
141 pub custom: HashMap<String, serde_json::Value>,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct TransportConfig {
148 pub transport_type: TransportType,
150
151 pub connect_timeout: Duration,
153
154 pub read_timeout: Option<Duration>,
156
157 pub write_timeout: Option<Duration>,
159
160 pub keep_alive: Option<Duration>,
162
163 pub max_connections: Option<usize>,
165
166 pub compression: bool,
168
169 pub compression_algorithm: Option<String>,
171
172 pub custom: HashMap<String, serde_json::Value>,
174}
175
176#[derive(Debug, Clone)]
178pub struct TransportMessage {
179 pub id: MessageId,
181
182 pub payload: Bytes,
184
185 pub metadata: TransportMessageMetadata,
187}
188
189#[derive(Debug, Clone, Default, Serialize, Deserialize)]
191pub struct TransportMessageMetadata {
192 pub encoding: Option<String>,
194
195 pub content_type: Option<String>,
197
198 pub correlation_id: Option<String>,
200
201 pub headers: HashMap<String, String>,
203
204 pub priority: Option<u8>,
206
207 pub ttl: Option<u64>,
209
210 pub is_heartbeat: Option<bool>,
212}
213
214#[derive(Debug, Clone, Default, Serialize, Deserialize)]
228pub struct TransportMetrics {
229 pub bytes_sent: u64,
231
232 pub bytes_received: u64,
234
235 pub messages_sent: u64,
237
238 pub messages_received: u64,
240
241 pub connections: u64,
243
244 pub failed_connections: u64,
246
247 pub average_latency_ms: f64,
249
250 pub active_connections: u64,
252
253 pub compression_ratio: Option<f64>,
255
256 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
258 pub metadata: HashMap<String, serde_json::Value>,
259}
260
/// Transport counters stored as atomics so they can be updated through `&self`.
///
/// The public fields are raw counters; the private fields back the derived values
/// (average latency and compression ratio) reported by `snapshot`.
#[derive(Debug)]
pub struct AtomicMetrics {
    /// Total bytes sent.
    pub bytes_sent: std::sync::atomic::AtomicU64,

    /// Total bytes received.
    pub bytes_received: std::sync::atomic::AtomicU64,

    /// Total messages sent.
    pub messages_sent: std::sync::atomic::AtomicU64,

    /// Total messages received.
    pub messages_received: std::sync::atomic::AtomicU64,

    /// Number of connections counted so far.
    pub connections: std::sync::atomic::AtomicU64,

    /// Number of failed connections counted so far.
    pub failed_connections: std::sync::atomic::AtomicU64,

    /// Number of currently active connections.
    pub active_connections: std::sync::atomic::AtomicU64,

    /// Running latency average in microseconds; `0` doubles as "no sample yet".
    avg_latency_us: std::sync::atomic::AtomicU64,

    /// Sum of the uncompressed sizes passed to `record_compression`.
    uncompressed_bytes: std::sync::atomic::AtomicU64,

    /// Sum of the compressed sizes passed to `record_compression`.
    compressed_bytes: std::sync::atomic::AtomicU64,
}
302
303impl Default for AtomicMetrics {
304 fn default() -> Self {
305 use std::sync::atomic::AtomicU64;
306 Self {
307 bytes_sent: AtomicU64::new(0),
308 bytes_received: AtomicU64::new(0),
309 messages_sent: AtomicU64::new(0),
310 messages_received: AtomicU64::new(0),
311 connections: AtomicU64::new(0),
312 failed_connections: AtomicU64::new(0),
313 active_connections: AtomicU64::new(0),
314 avg_latency_us: AtomicU64::new(0),
315 uncompressed_bytes: AtomicU64::new(0),
316 compressed_bytes: AtomicU64::new(0),
317 }
318 }
319}
320
321impl AtomicMetrics {
322 pub fn new() -> Self {
324 Self::default()
325 }
326
327 pub fn update_latency_us(&self, latency_us: u64) {
334 use std::sync::atomic::Ordering;
335
336 let current = self.avg_latency_us.load(Ordering::Relaxed);
337 let new_avg = if current == 0 {
338 latency_us
339 } else {
340 (current * 9 + latency_us) / 10
342 };
343 self.avg_latency_us.store(new_avg, Ordering::Relaxed);
344 }
345
346 pub fn record_compression(&self, uncompressed_size: u64, compressed_size: u64) {
352 use std::sync::atomic::Ordering;
353
354 self.uncompressed_bytes
355 .fetch_add(uncompressed_size, Ordering::Relaxed);
356 self.compressed_bytes
357 .fetch_add(compressed_size, Ordering::Relaxed);
358 }
359
360 pub fn snapshot(&self) -> TransportMetrics {
364 use std::sync::atomic::Ordering;
365
366 let avg_latency_us = self.avg_latency_us.load(Ordering::Relaxed);
367 let uncompressed = self.uncompressed_bytes.load(Ordering::Relaxed);
368 let compressed = self.compressed_bytes.load(Ordering::Relaxed);
369
370 let compression_ratio = if compressed > 0 && uncompressed > 0 {
371 Some(uncompressed as f64 / compressed as f64)
372 } else {
373 None
374 };
375
376 TransportMetrics {
377 bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
378 bytes_received: self.bytes_received.load(Ordering::Relaxed),
379 messages_sent: self.messages_sent.load(Ordering::Relaxed),
380 messages_received: self.messages_received.load(Ordering::Relaxed),
381 connections: self.connections.load(Ordering::Relaxed),
382 failed_connections: self.failed_connections.load(Ordering::Relaxed),
383 active_connections: self.active_connections.load(Ordering::Relaxed),
384 average_latency_ms: (avg_latency_us as f64) / 1000.0, compression_ratio,
386 metadata: HashMap::new(), }
388 }
389
390 pub fn reset(&self) {
392 use std::sync::atomic::Ordering;
393
394 self.bytes_sent.store(0, Ordering::Relaxed);
395 self.bytes_received.store(0, Ordering::Relaxed);
396 self.messages_sent.store(0, Ordering::Relaxed);
397 self.messages_received.store(0, Ordering::Relaxed);
398 self.connections.store(0, Ordering::Relaxed);
399 self.failed_connections.store(0, Ordering::Relaxed);
400 self.active_connections.store(0, Ordering::Relaxed);
401 self.avg_latency_us.store(0, Ordering::Relaxed);
402 self.uncompressed_bytes.store(0, Ordering::Relaxed);
403 self.compressed_bytes.store(0, Ordering::Relaxed);
404 }
405}
406
407#[derive(Debug, Clone)]
409pub enum TransportEvent {
410 Connected {
412 transport_type: TransportType,
414 endpoint: String,
416 },
417
418 Disconnected {
420 transport_type: TransportType,
422 endpoint: String,
424 reason: Option<String>,
426 },
427
428 MessageSent {
430 message_id: MessageId,
432 size: usize,
434 },
435
436 MessageReceived {
438 message_id: MessageId,
440 size: usize,
442 },
443
444 Error {
446 error: TransportError,
448 context: Option<String>,
450 },
451
452 MetricsUpdated {
454 metrics: TransportMetrics,
456 },
457}
458
459#[async_trait]
464pub trait Transport: Send + Sync + std::fmt::Debug {
465 fn transport_type(&self) -> TransportType;
467
468 fn capabilities(&self) -> &TransportCapabilities;
470
471 async fn state(&self) -> TransportState;
473
474 async fn connect(&self) -> TransportResult<()>;
476
477 async fn disconnect(&self) -> TransportResult<()>;
479
480 async fn send(&self, message: TransportMessage) -> TransportResult<()>;
482
483 async fn receive(&self) -> TransportResult<Option<TransportMessage>>;
485
486 async fn metrics(&self) -> TransportMetrics;
488
489 async fn is_connected(&self) -> bool {
491 matches!(self.state().await, TransportState::Connected)
492 }
493
494 fn endpoint(&self) -> Option<String> {
496 None
497 }
498
499 async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
501 let _ = config;
503 Ok(())
504 }
505}
506
507#[async_trait]
512pub trait BidirectionalTransport: Transport {
513 async fn send_request(
515 &self,
516 message: TransportMessage,
517 timeout: Option<Duration>,
518 ) -> TransportResult<TransportMessage>;
519
520 async fn start_correlation(&self, correlation_id: String) -> TransportResult<()>;
522
523 async fn stop_correlation(&self, correlation_id: &str) -> TransportResult<()>;
525}
526
527#[async_trait]
529pub trait StreamingTransport: Transport {
530 type SendStream: Stream<Item = TransportResult<TransportMessage>> + Send + Unpin;
532
533 type ReceiveStream: Sink<TransportMessage, Error = TransportError> + Send + Unpin;
535
536 async fn send_stream(&self) -> TransportResult<Self::SendStream>;
538
539 async fn receive_stream(&self) -> TransportResult<Self::ReceiveStream>;
541}
542
543pub trait TransportFactory: Send + Sync + std::fmt::Debug {
545 fn transport_type(&self) -> TransportType;
547
548 fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>>;
550
551 fn is_available(&self) -> bool {
553 true
554 }
555}
556
557#[derive(Debug, Clone)]
559pub struct TransportEventEmitter {
560 sender: mpsc::Sender<TransportEvent>,
561}
562
563impl TransportEventEmitter {
564 #[must_use]
566 pub fn new() -> (Self, mpsc::Receiver<TransportEvent>) {
567 let (sender, receiver) = mpsc::channel(500); (Self { sender }, receiver)
569 }
570
571 pub fn emit(&self, event: TransportEvent) {
573 if self.sender.try_send(event).is_err() {
575 }
577 }
578
579 pub fn emit_connected(&self, transport_type: TransportType, endpoint: String) {
581 self.emit(TransportEvent::Connected {
582 transport_type,
583 endpoint,
584 });
585 }
586
587 pub fn emit_disconnected(
589 &self,
590 transport_type: TransportType,
591 endpoint: String,
592 reason: Option<String>,
593 ) {
594 self.emit(TransportEvent::Disconnected {
595 transport_type,
596 endpoint,
597 reason,
598 });
599 }
600
601 pub fn emit_message_sent(&self, message_id: MessageId, size: usize) {
603 self.emit(TransportEvent::MessageSent { message_id, size });
604 }
605
606 pub fn emit_message_received(&self, message_id: MessageId, size: usize) {
608 self.emit(TransportEvent::MessageReceived { message_id, size });
609 }
610
611 pub fn emit_error(&self, error: TransportError, context: Option<String>) {
613 self.emit(TransportEvent::Error { error, context });
614 }
615
616 pub fn emit_metrics_updated(&self, metrics: TransportMetrics) {
618 self.emit(TransportEvent::MetricsUpdated { metrics });
619 }
620}
621
622impl Default for TransportEventEmitter {
623 fn default() -> Self {
624 Self::new().0
625 }
626}
627
628impl Default for TransportCapabilities {
631 fn default() -> Self {
632 Self {
633 max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
634 supports_compression: false,
635 supports_streaming: false,
636 supports_bidirectional: true,
637 supports_multiplexing: false,
638 compression_algorithms: Vec::new(),
639 custom: HashMap::new(),
640 }
641 }
642}
643
644impl Default for TransportConfig {
645 fn default() -> Self {
646 Self {
647 transport_type: TransportType::Stdio,
648 connect_timeout: Duration::from_secs(30),
649 read_timeout: None,
650 write_timeout: None,
651 keep_alive: None,
652 max_connections: None,
653 compression: false,
654 compression_algorithm: None,
655 custom: HashMap::new(),
656 }
657 }
658}
659
660impl TransportMessage {
661 pub fn new(id: MessageId, payload: Bytes) -> Self {
671 Self {
672 id,
673 payload,
674 metadata: TransportMessageMetadata::default(),
675 }
676 }
677
678 pub const fn with_metadata(
680 id: MessageId,
681 payload: Bytes,
682 metadata: TransportMessageMetadata,
683 ) -> Self {
684 Self {
685 id,
686 payload,
687 metadata,
688 }
689 }
690
691 pub const fn size(&self) -> usize {
693 self.payload.len()
694 }
695
696 pub const fn is_compressed(&self) -> bool {
698 self.metadata.encoding.is_some()
699 }
700
701 pub fn content_type(&self) -> Option<&str> {
703 self.metadata.content_type.as_deref()
704 }
705
706 pub fn correlation_id(&self) -> Option<&str> {
708 self.metadata.correlation_id.as_deref()
709 }
710}
711
712impl TransportMessageMetadata {
713 pub fn with_content_type(content_type: impl Into<String>) -> Self {
715 Self {
716 content_type: Some(content_type.into()),
717 ..Default::default()
718 }
719 }
720
721 pub fn with_correlation_id(correlation_id: impl Into<String>) -> Self {
723 Self {
724 correlation_id: Some(correlation_id.into()),
725 ..Default::default()
726 }
727 }
728
729 pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
738 self.headers.insert(key.into(), value.into());
739 self
740 }
741
742 #[must_use]
744 pub const fn with_priority(mut self, priority: u8) -> Self {
745 self.priority = Some(priority);
746 self
747 }
748
749 #[must_use]
751 pub const fn with_ttl(mut self, ttl: Duration) -> Self {
752 self.ttl = Some(ttl.as_millis() as u64);
753 self
754 }
755
756 #[must_use]
758 pub const fn heartbeat(mut self) -> Self {
759 self.is_heartbeat = Some(true);
760 self
761 }
762}
763
764impl fmt::Display for TransportType {
765 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766 match self {
767 Self::Stdio => write!(f, "stdio"),
768 Self::Http => write!(f, "http"),
769 Self::WebSocket => write!(f, "websocket"),
770 Self::Tcp => write!(f, "tcp"),
771 Self::Unix => write!(f, "unix"),
772 Self::ChildProcess => write!(f, "child_process"),
773 #[cfg(feature = "grpc")]
774 Self::Grpc => write!(f, "grpc"),
775 #[cfg(feature = "quic")]
776 Self::Quic => write!(f, "quic"),
777 }
778 }
779}
780
781impl fmt::Display for TransportState {
782 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783 match self {
784 Self::Disconnected => write!(f, "disconnected"),
785 Self::Connecting => write!(f, "connecting"),
786 Self::Connected => write!(f, "connected"),
787 Self::Disconnecting => write!(f, "disconnecting"),
788 Self::Failed { reason } => write!(f, "failed: {reason}"),
789 }
790 }
791}
792
793impl From<std::io::Error> for TransportError {
794 fn from(err: std::io::Error) -> Self {
795 Self::Io(err.to_string())
796 }
797}
798
799impl From<serde_json::Error> for TransportError {
800 fn from(err: serde_json::Error) -> Self {
801 Self::SerializationFailed(err.to_string())
802 }
803}
804
#[cfg(test)]
mod tests {
    use super::*;

    /// Default capabilities use the protocol size limit and are bidirectional.
    #[test]
    fn test_transport_capabilities_default() {
        let TransportCapabilities {
            max_message_size,
            supports_bidirectional,
            ..
        } = TransportCapabilities::default();

        assert_eq!(max_message_size, Some(turbomcp_protocol::MAX_MESSAGE_SIZE));
        assert!(supports_bidirectional);
    }

    /// Default configuration targets stdio with a 30 second connect timeout.
    #[test]
    fn test_transport_config_default() {
        let TransportConfig {
            transport_type,
            connect_timeout,
            ..
        } = TransportConfig::default();

        assert_eq!(transport_type, TransportType::Stdio);
        assert_eq!(connect_timeout, Duration::from_secs(30));
    }

    /// `TransportMessage::new` keeps the id and payload it was given.
    #[test]
    fn test_transport_message_creation() {
        let expected_id = MessageId::from("test");
        let expected_payload = Bytes::from("test payload");

        let msg = TransportMessage::new(expected_id.clone(), expected_payload.clone());

        assert_eq!(msg.id, expected_id);
        assert_eq!(msg.payload, expected_payload);
        assert_eq!(msg.size(), 12);
    }

    /// The metadata builders set headers, priority and a millisecond TTL.
    #[test]
    fn test_transport_message_metadata() {
        let metadata = TransportMessageMetadata::default()
            .with_header("custom", "value")
            .with_priority(5)
            .with_ttl(Duration::from_secs(30));

        assert_eq!(metadata.headers.get("custom"), Some(&"value".to_string()));
        assert_eq!(metadata.priority, Some(5));
        assert_eq!(metadata.ttl, Some(30000));
    }

    /// `Display` for the always-available transport kinds.
    #[test]
    fn test_transport_types_display() {
        let expected = [
            (TransportType::Stdio, "stdio"),
            (TransportType::Http, "http"),
            (TransportType::WebSocket, "websocket"),
            (TransportType::Tcp, "tcp"),
            (TransportType::Unix, "unix"),
        ];

        for (transport_type, text) in expected {
            assert_eq!(transport_type.to_string(), text);
        }
    }

    /// `Display` for plain states and for a failure with a reason.
    #[test]
    fn test_transport_state_display() {
        let failed = TransportState::Failed {
            reason: "timeout".to_string(),
        };

        assert_eq!(TransportState::Connected.to_string(), "connected");
        assert_eq!(TransportState::Disconnected.to_string(), "disconnected");
        assert_eq!(failed.to_string(), "failed: timeout");
    }

    /// An emitted `Connected` event arrives on the receiver unchanged.
    #[tokio::test]
    async fn test_transport_event_emitter() {
        let (emitter, mut receiver) = TransportEventEmitter::new();

        emitter.emit_connected(TransportType::Stdio, "stdio://".to_string());

        let event = receiver.recv().await.unwrap();
        if let TransportEvent::Connected {
            transport_type,
            endpoint,
        } = &event
        {
            assert_eq!(*transport_type, TransportType::Stdio);
            assert_eq!(endpoint, "stdio://");
        } else {
            // Print the offending event first, then fail with the expected-variant message.
            eprintln!("Unexpected event variant: {event:?}");
            assert!(
                matches!(event, TransportEvent::Connected { .. }),
                "Expected Connected event"
            );
        }
    }
}