turbomcp_transport/
core.rs

//! Core transport traits, types, and errors.
//!
//! This module defines the fundamental abstractions for sending and receiving MCP messages
//! over different communication protocols. The central piece is the [`Transport`] trait,
//! which provides a generic interface for all transport implementations.

7use std::collections::HashMap;
8use std::fmt;
9use std::time::Duration;
10
11use async_trait::async_trait;
12use bytes::Bytes;
13use futures::{Sink, Stream};
14use serde::{Deserialize, Serialize};
15use thiserror::Error;
16use tokio::sync::mpsc;
17use turbomcp_protocol::MessageId;
18
19/// A specialized `Result` type for transport operations.
20pub type TransportResult<T> = std::result::Result<T, TransportError>;
21
22/// Represents errors that can occur during transport operations.
23#[derive(Error, Debug, Clone)]
24pub enum TransportError {
25    /// Failed to establish a connection.
26    #[error("Connection failed: {0}")]
27    ConnectionFailed(String),
28
29    /// An established connection was lost.
30    #[error("Connection lost: {0}")]
31    ConnectionLost(String),
32
33    /// Failed to send a message.
34    #[error("Send failed: {0}")]
35    SendFailed(String),
36
37    /// Failed to receive a message.
38    #[error("Receive failed: {0}")]
39    ReceiveFailed(String),
40
41    /// Failed to serialize or deserialize a message.
42    #[error("Serialization failed: {0}")]
43    SerializationFailed(String),
44
45    /// A protocol-level error occurred.
46    #[error("Protocol error: {0}")]
47    ProtocolError(String),
48
49    /// The operation did not complete within the specified timeout.
50    #[error("Operation timed out")]
51    Timeout,
52
53    /// The transport was configured with invalid parameters.
54    #[error("Configuration error: {0}")]
55    ConfigurationError(String),
56
57    /// Authentication with the remote endpoint failed.
58    #[error("Authentication failed: {0}")]
59    AuthenticationFailed(String),
60
61    /// The request was rejected due to rate limiting.
62    #[error("Rate limit exceeded")]
63    RateLimitExceeded,
64
65    /// The requested transport is not available.
66    #[error("Transport not available: {0}")]
67    NotAvailable(String),
68
69    /// An underlying I/O error occurred.
70    #[error("IO error: {0}")]
71    Io(String),
72
73    /// An unexpected internal error occurred.
74    #[error("Internal error: {0}")]
75    Internal(String),
76}
77
78/// Enumerates the types of transports supported by the system.
79#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
80#[serde(rename_all = "lowercase")]
81pub enum TransportType {
82    /// Standard Input/Output, for command-line servers.
83    Stdio,
84    /// HTTP, including Server-Sent Events (SSE).
85    Http,
86    /// WebSocket for full-duplex communication.
87    WebSocket,
88    /// TCP sockets for network communication.
89    Tcp,
90    /// Unix domain sockets for local inter-process communication.
91    Unix,
92    /// A transport that manages a child process.
93    ChildProcess,
94    /// gRPC for high-performance RPC.
95    #[cfg(feature = "grpc")]
96    Grpc,
97    /// QUIC for a modern, multiplexed transport.
98    #[cfg(feature = "quic")]
99    Quic,
100}
101
102/// Represents the current state of a transport connection.
103#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
104pub enum TransportState {
105    /// The transport is not connected.
106    Disconnected,
107    /// The transport is in the process of connecting.
108    Connecting,
109    /// The transport is connected and ready to send/receive messages.
110    Connected,
111    /// The transport is in the process of disconnecting.
112    Disconnecting,
113    /// The transport has encountered an unrecoverable error.
114    Failed {
115        /// A description of the failure reason.
116        reason: String,
117    },
118}
119
120/// Describes the capabilities of a transport implementation.
121#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
122pub struct TransportCapabilities {
123    /// The maximum message size in bytes that the transport can handle.
124    pub max_message_size: Option<usize>,
125
126    /// Whether the transport supports message compression.
127    pub supports_compression: bool,
128
129    /// Whether the transport supports streaming data.
130    pub supports_streaming: bool,
131
132    /// Whether the transport supports full-duplex bidirectional communication.
133    pub supports_bidirectional: bool,
134
135    /// Whether the transport can handle multiple concurrent requests over a single connection.
136    pub supports_multiplexing: bool,
137
138    /// A list of supported compression algorithms.
139    pub compression_algorithms: Vec<String>,
140
141    /// A map for any other custom capabilities.
142    pub custom: HashMap<String, serde_json::Value>,
143}
144
145/// Configuration for a transport instance.
146#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct TransportConfig {
148    /// The type of the transport.
149    pub transport_type: TransportType,
150
151    /// The maximum time to wait for a connection to be established.
152    pub connect_timeout: Duration,
153
154    /// The maximum time to wait for a read operation to complete.
155    pub read_timeout: Option<Duration>,
156
157    /// The maximum time to wait for a write operation to complete.
158    pub write_timeout: Option<Duration>,
159
160    /// The interval for sending keep-alive messages to maintain the connection.
161    pub keep_alive: Option<Duration>,
162
163    /// The maximum number of concurrent connections allowed.
164    pub max_connections: Option<usize>,
165
166    /// Whether to enable message compression.
167    pub compression: bool,
168
169    /// The preferred compression algorithm to use.
170    pub compression_algorithm: Option<String>,
171
172    /// A map for any other custom configuration.
173    pub custom: HashMap<String, serde_json::Value>,
174}
175
176/// A wrapper for a message being sent or received over a transport.
177#[derive(Debug, Clone)]
178pub struct TransportMessage {
179    /// The unique identifier of the message.
180    pub id: MessageId,
181
182    /// The binary payload of the message.
183    pub payload: Bytes,
184
185    /// Metadata associated with the message.
186    pub metadata: TransportMessageMetadata,
187}
188
189/// Metadata associated with a `TransportMessage`.
190#[derive(Debug, Clone, Default, Serialize, Deserialize)]
191pub struct TransportMessageMetadata {
192    /// The encoding of the message payload (e.g., "gzip").
193    pub encoding: Option<String>,
194
195    /// The MIME type of the message payload (e.g., "application/json").
196    pub content_type: Option<String>,
197
198    /// An ID used to correlate requests and responses.
199    pub correlation_id: Option<String>,
200
201    /// A map of custom headers.
202    pub headers: HashMap<String, String>,
203
204    /// The priority of the message (higher numbers indicate higher priority).
205    pub priority: Option<u8>,
206
207    /// The time-to-live for the message, in milliseconds.
208    pub ttl: Option<u64>,
209
210    /// A marker indicating that this is a heartbeat message.
211    pub is_heartbeat: Option<bool>,
212}
213
214/// A serializable snapshot of a transport's performance metrics.
215///
216/// This struct provides a consistent view of metrics for external monitoring.
217/// For internal, high-performance updates, `AtomicMetrics` is preferred.
218///
219/// # Custom Transport Metrics
220/// Transport implementations can store custom metrics in the `metadata` field.
221/// ```
222/// # use turbomcp_transport::core::TransportMetrics;
223/// # use serde_json::json;
224/// let mut metrics = TransportMetrics::default();
225/// metrics.metadata.insert("active_correlations".to_string(), json!(42));
226/// ```
227#[derive(Debug, Clone, Default, Serialize, Deserialize)]
228pub struct TransportMetrics {
229    /// Total number of bytes sent.
230    pub bytes_sent: u64,
231
232    /// Total number of bytes received.
233    pub bytes_received: u64,
234
235    /// Total number of messages sent.
236    pub messages_sent: u64,
237
238    /// Total number of messages received.
239    pub messages_received: u64,
240
241    /// Total number of connection attempts.
242    pub connections: u64,
243
244    /// Total number of failed connection attempts.
245    pub failed_connections: u64,
246
247    /// The average latency of operations, in milliseconds.
248    pub average_latency_ms: f64,
249
250    /// The current number of active connections.
251    pub active_connections: u64,
252
253    /// The compression ratio (uncompressed size / compressed size), if applicable.
254    pub compression_ratio: Option<f64>,
255
256    /// A map for custom, transport-specific metrics.
257    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
258    pub metadata: HashMap<String, serde_json::Value>,
259}
260
/// Live transport metrics stored in `AtomicU64` counters.
///
/// Recording a metric is a single atomic operation, so no mutex is needed. The methods
/// defined for this type use `Ordering::Relaxed` throughout. Use `snapshot` to produce a
/// serializable [`TransportMetrics`].
///
/// # Consistency
/// Counters are independent. With relaxed ordering there is no cross-field guarantee, so
/// a snapshot taken while other threads are recording may combine values from slightly
/// different moments.
#[derive(Debug)]
pub struct AtomicMetrics {
    /// Bytes sent.
    pub bytes_sent: std::sync::atomic::AtomicU64,

    /// Bytes received.
    pub bytes_received: std::sync::atomic::AtomicU64,

    /// Messages sent.
    pub messages_sent: std::sync::atomic::AtomicU64,

    /// Messages received.
    pub messages_received: std::sync::atomic::AtomicU64,

    /// Connection attempts.
    pub connections: std::sync::atomic::AtomicU64,

    /// Failed connection attempts.
    pub failed_connections: std::sync::atomic::AtomicU64,

    /// Connections currently open.
    pub active_connections: std::sync::atomic::AtomicU64,

    /// Exponential moving average of latency, in microseconds (`0` = no sample yet).
    avg_latency_us: std::sync::atomic::AtomicU64,

    /// Running total of bytes before compression.
    uncompressed_bytes: std::sync::atomic::AtomicU64,

    /// Running total of bytes after compression.
    compressed_bytes: std::sync::atomic::AtomicU64,
}

303impl Default for AtomicMetrics {
304    fn default() -> Self {
305        use std::sync::atomic::AtomicU64;
306        Self {
307            bytes_sent: AtomicU64::new(0),
308            bytes_received: AtomicU64::new(0),
309            messages_sent: AtomicU64::new(0),
310            messages_received: AtomicU64::new(0),
311            connections: AtomicU64::new(0),
312            failed_connections: AtomicU64::new(0),
313            active_connections: AtomicU64::new(0),
314            avg_latency_us: AtomicU64::new(0),
315            uncompressed_bytes: AtomicU64::new(0),
316            compressed_bytes: AtomicU64::new(0),
317        }
318    }
319}
320
321impl AtomicMetrics {
322    /// Creates a new `AtomicMetrics` instance with all counters initialized to zero.
323    pub fn new() -> Self {
324        Self::default()
325    }
326
327    /// Updates the average latency using an exponential moving average (EMA).
328    ///
329    /// This method uses an EMA with alpha = 0.1 for smooth latency tracking.
330    ///
331    /// # Arguments
332    /// * `latency_us` - The new latency measurement in microseconds.
333    pub fn update_latency_us(&self, latency_us: u64) {
334        use std::sync::atomic::Ordering;
335
336        let current = self.avg_latency_us.load(Ordering::Relaxed);
337        let new_avg = if current == 0 {
338            latency_us
339        } else {
340            // EMA with alpha = 0.1: new_avg = old_avg * 0.9 + new_value * 0.1
341            (current * 9 + latency_us) / 10
342        };
343        self.avg_latency_us.store(new_avg, Ordering::Relaxed);
344    }
345
346    /// Records compression statistics to track the compression ratio.
347    ///
348    /// # Arguments
349    /// * `uncompressed_size` - The size of the data before compression.
350    /// * `compressed_size` - The size of the data after compression.
351    pub fn record_compression(&self, uncompressed_size: u64, compressed_size: u64) {
352        use std::sync::atomic::Ordering;
353
354        self.uncompressed_bytes
355            .fetch_add(uncompressed_size, Ordering::Relaxed);
356        self.compressed_bytes
357            .fetch_add(compressed_size, Ordering::Relaxed);
358    }
359
360    /// Creates a serializable `TransportMetrics` snapshot from the current atomic values.
361    ///
362    /// This method uses `Ordering::Relaxed` for maximum performance.
363    pub fn snapshot(&self) -> TransportMetrics {
364        use std::sync::atomic::Ordering;
365
366        let avg_latency_us = self.avg_latency_us.load(Ordering::Relaxed);
367        let uncompressed = self.uncompressed_bytes.load(Ordering::Relaxed);
368        let compressed = self.compressed_bytes.load(Ordering::Relaxed);
369
370        let compression_ratio = if compressed > 0 && uncompressed > 0 {
371            Some(uncompressed as f64 / compressed as f64)
372        } else {
373            None
374        };
375
376        TransportMetrics {
377            bytes_sent: self.bytes_sent.load(Ordering::Relaxed),
378            bytes_received: self.bytes_received.load(Ordering::Relaxed),
379            messages_sent: self.messages_sent.load(Ordering::Relaxed),
380            messages_received: self.messages_received.load(Ordering::Relaxed),
381            connections: self.connections.load(Ordering::Relaxed),
382            failed_connections: self.failed_connections.load(Ordering::Relaxed),
383            active_connections: self.active_connections.load(Ordering::Relaxed),
384            average_latency_ms: (avg_latency_us as f64) / 1000.0, // Convert μs to ms
385            compression_ratio,
386            metadata: HashMap::new(), // Empty metadata for base atomic metrics
387        }
388    }
389
390    /// Resets all atomic metric counters to zero.
391    pub fn reset(&self) {
392        use std::sync::atomic::Ordering;
393
394        self.bytes_sent.store(0, Ordering::Relaxed);
395        self.bytes_received.store(0, Ordering::Relaxed);
396        self.messages_sent.store(0, Ordering::Relaxed);
397        self.messages_received.store(0, Ordering::Relaxed);
398        self.connections.store(0, Ordering::Relaxed);
399        self.failed_connections.store(0, Ordering::Relaxed);
400        self.active_connections.store(0, Ordering::Relaxed);
401        self.avg_latency_us.store(0, Ordering::Relaxed);
402        self.uncompressed_bytes.store(0, Ordering::Relaxed);
403        self.compressed_bytes.store(0, Ordering::Relaxed);
404    }
405}
406
407/// Represents events that occur within a transport's lifecycle.
408#[derive(Debug, Clone)]
409pub enum TransportEvent {
410    /// A new connection has been established.
411    Connected {
412        /// The type of the transport that connected.
413        transport_type: TransportType,
414        /// The endpoint of the connection.
415        endpoint: String,
416    },
417
418    /// A connection has been lost.
419    Disconnected {
420        /// The type of the transport that disconnected.
421        transport_type: TransportType,
422        /// The endpoint of the connection.
423        endpoint: String,
424        /// An optional reason for the disconnection.
425        reason: Option<String>,
426    },
427
428    /// A message has been successfully sent.
429    MessageSent {
430        /// The ID of the sent message.
431        message_id: MessageId,
432        /// The size of the sent message in bytes.
433        size: usize,
434    },
435
436    /// A message has been successfully received.
437    MessageReceived {
438        /// The ID of the received message.
439        message_id: MessageId,
440        /// The size of the received message in bytes.
441        size: usize,
442    },
443
444    /// An error has occurred in the transport.
445    Error {
446        /// The error that occurred.
447        error: TransportError,
448        /// Optional additional context about the error.
449        context: Option<String>,
450    },
451
452    /// The transport's metrics have been updated.
453    MetricsUpdated {
454        /// The updated metrics snapshot.
455        metrics: TransportMetrics,
456    },
457}
458
459/// The core trait for all transport implementations.
460///
461/// This trait defines the essential, asynchronous operations for a message-based
462/// communication channel, such as connecting, disconnecting, sending, and receiving.
463#[async_trait]
464pub trait Transport: Send + Sync + std::fmt::Debug {
465    /// Returns the type of this transport.
466    fn transport_type(&self) -> TransportType;
467
468    /// Returns the capabilities of this transport.
469    fn capabilities(&self) -> &TransportCapabilities;
470
471    /// Returns the current state of the transport.
472    async fn state(&self) -> TransportState;
473
474    /// Establishes a connection to the remote endpoint.
475    async fn connect(&self) -> TransportResult<()>;
476
477    /// Closes the connection to the remote endpoint.
478    async fn disconnect(&self) -> TransportResult<()>;
479
480    /// Sends a single message over the transport.
481    async fn send(&self, message: TransportMessage) -> TransportResult<()>;
482
483    /// Receives a single message from the transport in a non-blocking way.
484    async fn receive(&self) -> TransportResult<Option<TransportMessage>>;
485
486    /// Returns a snapshot of the transport's current performance metrics.
487    async fn metrics(&self) -> TransportMetrics;
488
489    /// Returns `true` if the transport is currently in the `Connected` state.
490    async fn is_connected(&self) -> bool {
491        matches!(self.state().await, TransportState::Connected)
492    }
493
494    /// Returns the endpoint address or identifier for this transport, if applicable.
495    fn endpoint(&self) -> Option<String> {
496        None
497    }
498
499    /// Applies a new configuration to the transport.
500    async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
501        // Default implementation does nothing. Transports can override this.
502        let _ = config;
503        Ok(())
504    }
505}
506
507/// A trait for transports that support full-duplex, bidirectional communication.
508///
509/// This extends the base `Transport` trait with the ability to send a request and
510/// await a correlated response.
511#[async_trait]
512pub trait BidirectionalTransport: Transport {
513    /// Sends a request message and waits for a corresponding response.
514    async fn send_request(
515        &self,
516        message: TransportMessage,
517        timeout: Option<Duration>,
518    ) -> TransportResult<TransportMessage>;
519
520    /// Starts tracking a request-response correlation.
521    async fn start_correlation(&self, correlation_id: String) -> TransportResult<()>;
522
523    /// Stops tracking a request-response correlation.
524    async fn stop_correlation(&self, correlation_id: &str) -> TransportResult<()>;
525}
526
527/// A trait for transports that support streaming data.
528#[async_trait]
529pub trait StreamingTransport: Transport {
530    /// The type of the stream used for sending messages.
531    type SendStream: Stream<Item = TransportResult<TransportMessage>> + Send + Unpin;
532
533    /// The type of the sink used for receiving messages.
534    type ReceiveStream: Sink<TransportMessage, Error = TransportError> + Send + Unpin;
535
536    /// Returns a stream for sending messages.
537    async fn send_stream(&self) -> TransportResult<Self::SendStream>;
538
539    /// Returns a sink for receiving messages.
540    async fn receive_stream(&self) -> TransportResult<Self::ReceiveStream>;
541}
542
543/// A factory for creating instances of a specific transport type.
544pub trait TransportFactory: Send + Sync + std::fmt::Debug {
545    /// Returns the type of transport this factory creates.
546    fn transport_type(&self) -> TransportType;
547
548    /// Creates a new transport instance with the given configuration.
549    fn create(&self, config: TransportConfig) -> TransportResult<Box<dyn Transport>>;
550
551    /// Returns `true` if this transport is available on the current system.
552    fn is_available(&self) -> bool {
553        true
554    }
555}
556
557/// An emitter for broadcasting `TransportEvent`s to listeners.
558#[derive(Debug, Clone)]
559pub struct TransportEventEmitter {
560    sender: mpsc::Sender<TransportEvent>,
561}
562
563impl TransportEventEmitter {
564    /// Creates a new event emitter and a corresponding receiver.
565    #[must_use]
566    pub fn new() -> (Self, mpsc::Receiver<TransportEvent>) {
567        let (sender, receiver) = mpsc::channel(500); // Bounded channel for backpressure
568        (Self { sender }, receiver)
569    }
570
571    /// Emits an event, dropping it if the channel is full to avoid blocking.
572    pub fn emit(&self, event: TransportEvent) {
573        // Use try_send for non-blocking event emission.
574        if self.sender.try_send(event).is_err() {
575            // Ignore the error if the channel is full or closed.
576        }
577    }
578
579    /// Emits a `Connected` event.
580    pub fn emit_connected(&self, transport_type: TransportType, endpoint: String) {
581        self.emit(TransportEvent::Connected {
582            transport_type,
583            endpoint,
584        });
585    }
586
587    /// Emits a `Disconnected` event.
588    pub fn emit_disconnected(
589        &self,
590        transport_type: TransportType,
591        endpoint: String,
592        reason: Option<String>,
593    ) {
594        self.emit(TransportEvent::Disconnected {
595            transport_type,
596            endpoint,
597            reason,
598        });
599    }
600
601    /// Emits a `MessageSent` event.
602    pub fn emit_message_sent(&self, message_id: MessageId, size: usize) {
603        self.emit(TransportEvent::MessageSent { message_id, size });
604    }
605
606    /// Emits a `MessageReceived` event.
607    pub fn emit_message_received(&self, message_id: MessageId, size: usize) {
608        self.emit(TransportEvent::MessageReceived { message_id, size });
609    }
610
611    /// Emits an `Error` event.
612    pub fn emit_error(&self, error: TransportError, context: Option<String>) {
613        self.emit(TransportEvent::Error { error, context });
614    }
615
616    /// Emits a `MetricsUpdated` event.
617    pub fn emit_metrics_updated(&self, metrics: TransportMetrics) {
618        self.emit(TransportEvent::MetricsUpdated { metrics });
619    }
620}
621
622impl Default for TransportEventEmitter {
623    fn default() -> Self {
624        Self::new().0
625    }
626}
627
// Implementations for common types

630impl Default for TransportCapabilities {
631    fn default() -> Self {
632        Self {
633            max_message_size: Some(turbomcp_protocol::MAX_MESSAGE_SIZE),
634            supports_compression: false,
635            supports_streaming: false,
636            supports_bidirectional: true,
637            supports_multiplexing: false,
638            compression_algorithms: Vec::new(),
639            custom: HashMap::new(),
640        }
641    }
642}
643
644impl Default for TransportConfig {
645    fn default() -> Self {
646        Self {
647            transport_type: TransportType::Stdio,
648            connect_timeout: Duration::from_secs(30),
649            read_timeout: None,
650            write_timeout: None,
651            keep_alive: None,
652            max_connections: None,
653            compression: false,
654            compression_algorithm: None,
655            custom: HashMap::new(),
656        }
657    }
658}
659
660impl TransportMessage {
661    /// Creates a new `TransportMessage` with a given ID and payload.
662    ///
663    /// # Example
664    /// ```
665    /// # use turbomcp_transport::core::TransportMessage;
666    /// # use turbomcp_protocol::MessageId;
667    /// # use bytes::Bytes;
668    /// let msg = TransportMessage::new(MessageId::from(1), Bytes::from("hello"));
669    /// ```
670    pub fn new(id: MessageId, payload: Bytes) -> Self {
671        Self {
672            id,
673            payload,
674            metadata: TransportMessageMetadata::default(),
675        }
676    }
677
678    /// Creates a new `TransportMessage` with the given ID, payload, and metadata.
679    pub const fn with_metadata(
680        id: MessageId,
681        payload: Bytes,
682        metadata: TransportMessageMetadata,
683    ) -> Self {
684        Self {
685            id,
686            payload,
687            metadata,
688        }
689    }
690
691    /// Returns the size of the message payload in bytes.
692    pub const fn size(&self) -> usize {
693        self.payload.len()
694    }
695
696    /// Returns `true` if the message is compressed.
697    pub const fn is_compressed(&self) -> bool {
698        self.metadata.encoding.is_some()
699    }
700
701    /// Returns the content type of the message, if specified.
702    pub fn content_type(&self) -> Option<&str> {
703        self.metadata.content_type.as_deref()
704    }
705
706    /// Returns the correlation ID of the message, if specified.
707    pub fn correlation_id(&self) -> Option<&str> {
708        self.metadata.correlation_id.as_deref()
709    }
710}
711
712impl TransportMessageMetadata {
713    /// Creates a new `TransportMessageMetadata` with a specified content type.
714    pub fn with_content_type(content_type: impl Into<String>) -> Self {
715        Self {
716            content_type: Some(content_type.into()),
717            ..Default::default()
718        }
719    }
720
721    /// Creates a new `TransportMessageMetadata` with a specified correlation ID.
722    pub fn with_correlation_id(correlation_id: impl Into<String>) -> Self {
723        Self {
724            correlation_id: Some(correlation_id.into()),
725            ..Default::default()
726        }
727    }
728
729    /// Adds a header to the metadata using a builder pattern.
730    ///
731    /// # Example
732    /// ```
733    /// # use turbomcp_transport::core::TransportMessageMetadata;
734    /// let metadata = TransportMessageMetadata::default()
735    ///     .with_header("X-Request-ID", "123");
736    /// ```
737    pub fn with_header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
738        self.headers.insert(key.into(), value.into());
739        self
740    }
741
742    /// Sets the priority of the message.
743    #[must_use]
744    pub const fn with_priority(mut self, priority: u8) -> Self {
745        self.priority = Some(priority);
746        self
747    }
748
749    /// Sets the time-to-live for the message.
750    #[must_use]
751    pub const fn with_ttl(mut self, ttl: Duration) -> Self {
752        self.ttl = Some(ttl.as_millis() as u64);
753        self
754    }
755
756    /// Marks the message as a heartbeat.
757    #[must_use]
758    pub const fn heartbeat(mut self) -> Self {
759        self.is_heartbeat = Some(true);
760        self
761    }
762}
763
764impl fmt::Display for TransportType {
765    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
766        match self {
767            Self::Stdio => write!(f, "stdio"),
768            Self::Http => write!(f, "http"),
769            Self::WebSocket => write!(f, "websocket"),
770            Self::Tcp => write!(f, "tcp"),
771            Self::Unix => write!(f, "unix"),
772            Self::ChildProcess => write!(f, "child_process"),
773            #[cfg(feature = "grpc")]
774            Self::Grpc => write!(f, "grpc"),
775            #[cfg(feature = "quic")]
776            Self::Quic => write!(f, "quic"),
777        }
778    }
779}
780
781impl fmt::Display for TransportState {
782    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
783        match self {
784            Self::Disconnected => write!(f, "disconnected"),
785            Self::Connecting => write!(f, "connecting"),
786            Self::Connected => write!(f, "connected"),
787            Self::Disconnecting => write!(f, "disconnecting"),
788            Self::Failed { reason } => write!(f, "failed: {reason}"),
789        }
790    }
791}
792
793impl From<std::io::Error> for TransportError {
794    fn from(err: std::io::Error) -> Self {
795        Self::Io(err.to_string())
796    }
797}
798
799impl From<serde_json::Error> for TransportError {
800    fn from(err: serde_json::Error) -> Self {
801        Self::SerializationFailed(err.to_string())
802    }
803}
804
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::Ordering;

    #[test]
    fn test_transport_capabilities_default() {
        let caps = TransportCapabilities::default();
        assert_eq!(
            caps.max_message_size,
            Some(turbomcp_protocol::MAX_MESSAGE_SIZE)
        );
        assert!(caps.supports_bidirectional);
    }

    #[test]
    fn test_transport_config_default() {
        let config = TransportConfig::default();
        assert_eq!(config.transport_type, TransportType::Stdio);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
    }

    #[test]
    fn test_transport_message_creation() {
        let id = MessageId::from("test");
        let payload = Bytes::from("test payload");
        let msg = TransportMessage::new(id.clone(), payload.clone());

        assert_eq!(msg.id, id);
        assert_eq!(msg.payload, payload);
        assert_eq!(msg.size(), 12);
        assert!(!msg.is_compressed());
        assert_eq!(msg.content_type(), None);
    }

    #[test]
    fn test_transport_message_accessors() {
        let msg = TransportMessage::with_metadata(
            MessageId::from("corr"),
            Bytes::from("{}"),
            TransportMessageMetadata::with_correlation_id("abc"),
        );
        assert_eq!(msg.correlation_id(), Some("abc"));
        assert_eq!(msg.content_type(), None);
    }

    #[test]
    fn test_transport_message_metadata() {
        let metadata = TransportMessageMetadata::default()
            .with_header("custom", "value")
            .with_priority(5)
            .with_ttl(Duration::from_secs(30));

        assert_eq!(metadata.headers.get("custom"), Some(&"value".to_string()));
        assert_eq!(metadata.priority, Some(5));
        assert_eq!(metadata.ttl, Some(30000));
    }

    #[test]
    fn test_transport_types_display() {
        assert_eq!(TransportType::Stdio.to_string(), "stdio");
        assert_eq!(TransportType::Http.to_string(), "http");
        assert_eq!(TransportType::WebSocket.to_string(), "websocket");
        assert_eq!(TransportType::Tcp.to_string(), "tcp");
        assert_eq!(TransportType::Unix.to_string(), "unix");
        assert_eq!(TransportType::ChildProcess.to_string(), "child_process");
    }

    #[test]
    fn test_transport_state_display() {
        assert_eq!(TransportState::Connected.to_string(), "connected");
        assert_eq!(TransportState::Disconnected.to_string(), "disconnected");
        assert_eq!(
            TransportState::Failed {
                reason: "timeout".to_string()
            }
            .to_string(),
            "failed: timeout"
        );
    }

    #[test]
    fn test_atomic_metrics_snapshot_and_reset() {
        let metrics = AtomicMetrics::new();
        metrics.bytes_sent.fetch_add(10, Ordering::Relaxed);
        metrics.update_latency_us(100);
        metrics.update_latency_us(200);
        metrics.record_compression(100, 50);

        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.bytes_sent, 10);
        // First sample seeds the average; second gives (9 * 100 + 200) / 10 = 110 us.
        assert!((snapshot.average_latency_ms - 0.11).abs() < 1e-9);
        assert_eq!(snapshot.compression_ratio, Some(2.0));
        assert!(snapshot.metadata.is_empty());

        metrics.reset();
        let snapshot = metrics.snapshot();
        assert_eq!(snapshot.bytes_sent, 0);
        assert_eq!(snapshot.compression_ratio, None);
        assert!(snapshot.average_latency_ms.abs() < f64::EPSILON);
    }

    #[test]
    fn test_default_emitter_discards_events() {
        // The default emitter's receiver is already dropped; emitting must not block or panic.
        let emitter = TransportEventEmitter::default();
        emitter.emit_connected(TransportType::Stdio, "stdio://".to_string());
    }

    #[tokio::test]
    async fn test_transport_event_emitter() {
        let (emitter, mut receiver) = TransportEventEmitter::new();

        emitter.emit_connected(TransportType::Stdio, "stdio://".to_string());

        match receiver.recv().await {
            Some(TransportEvent::Connected {
                transport_type,
                endpoint,
            }) => {
                assert_eq!(transport_type, TransportType::Stdio);
                assert_eq!(endpoint, "stdio://");
            }
            other => panic!("expected a Connected event, got {other:?}"),
        }
    }
}