rust_mcp_transport/
transport.rs

1use crate::{error::TransportResult, message_dispatcher::MessageDispatcher};
2use crate::{schema::RequestId, SessionId};
3use async_trait::async_trait;
4use std::{pin::Pin, sync::Arc, time::Duration};
5use tokio::{
6    sync::oneshot::{self, Sender},
7    task::JoinHandle,
8};
9
/// Request timeout used by `TransportOptions::default`, in milliseconds (one minute).
const DEFAULT_TIMEOUT_MSEC: u64 = 60 * 1_000;
12
13/// Enum representing a stream that can either be readable or writable.
14/// This allows the reuse of the same traits for both MCP Server and MCP Client,
15/// where the data direction is reversed.
16///
17/// It encapsulates two types of I/O streams:
18/// - `Readable`: A stream that implements the `AsyncRead` trait for reading data asynchronously.
19/// - `Writable`: A stream that implements the `AsyncWrite` trait for writing data asynchronously.
20///
21pub enum IoStream {
22    Readable(Pin<Box<dyn tokio::io::AsyncRead + Send + Sync>>),
23    Writable(Pin<Box<dyn tokio::io::AsyncWrite + Send + Sync>>),
24}
25
26/// Configuration for the transport layer
27#[derive(Debug, Clone)]
28pub struct TransportOptions {
29    /// The timeout in milliseconds for requests.
30    ///
31    /// This value defines the maximum amount of time to wait for a response before
32    /// considering the request as timed out.
33    pub timeout: Duration,
34}
35impl Default for TransportOptions {
36    fn default() -> Self {
37        Self {
38            timeout: Duration::from_millis(DEFAULT_TIMEOUT_MSEC),
39        }
40    }
41}
42
43/// A trait for dispatching MCP (Message Communication Protocol) messages.
44///
45/// This trait is designed to be implemented by components such as clients, servers, or transports
46/// that send and receive messages in the MCP protocol. It defines the interface for transmitting messages,
47/// optionally awaiting responses, writing raw payloads, and handling batch communication.
48///
49/// # Associated Types
50///
51/// - `R`: The response type expected from a message. This must implement deserialization and be safe
52///   for concurrent use in async contexts.
53/// - `S`: The type of the outgoing message sent directly to the wire. Must be serializable.
54/// - `M`: The internal message type used for responses received from a remote peer.
55/// - `OM`: The outgoing message type submitted to the dispatcher. This is the higher-level form of `S`
56///   used by clients or services submitting requests.
57///
58#[async_trait]
59pub trait McpDispatch<R, S, M, OM>: Send + Sync + 'static
60where
61    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
62    S: Clone + Send + Sync + serde::Serialize + 'static,
63    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
64    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
65{
66    /// Sends a raw message represented by type `S` and optionally includes a `request_id`.
67    /// The `request_id` is used when sending a message in response to an MCP request.
68    /// It should match the `request_id` of the original request.
69    async fn send_message(
70        &self,
71        message: S,
72        request_timeout: Option<Duration>,
73    ) -> TransportResult<Option<R>>;
74
75    async fn send(&self, message: OM, timeout: Option<Duration>) -> TransportResult<Option<M>>;
76    async fn send_batch(
77        &self,
78        message: Vec<OM>,
79        timeout: Option<Duration>,
80    ) -> TransportResult<Option<Vec<M>>>;
81
82    /// Writes a string payload to the underlying asynchronous writable stream,
83    /// appending a newline character and flushing the stream afterward.
84    ///
85    async fn write_str(&self, payload: &str, skip_store: bool) -> TransportResult<()>;
86}
87
88/// A trait representing the transport layer for the MCP (Message Communication Protocol).
89///
90/// This trait abstracts the transport layer functionality required to send and receive messages
91/// within an MCP-based system. It provides methods to initialize the transport, send and receive
92/// messages, handle errors, manage pending requests, and implement keep-alive functionality.
93///
94/// # Associated Types
95///
96/// - `R`: The type of message expected to be received from the transport layer. Must be deserializable.
97/// - `S`: The type of message to be sent over the transport layer. Must be serializable.
98/// - `M`: The internal message type used by the dispatcher. Typically this wraps or transforms `R`.
99/// - `OR`: The outbound response type expected to be produced by the dispatcher when handling incoming messages.
100/// - `OM`: The outbound message type that the dispatcher expects to send as a reply to received messages.
101///
102#[async_trait]
103pub trait Transport<R, S, M, OR, OM>: Send + Sync + 'static
104where
105    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
106    S: Clone + Send + Sync + serde::Serialize + 'static,
107    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
108    OR: Clone + Send + Sync + serde::Serialize + 'static,
109    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
110{
111    async fn start(&self) -> TransportResult<tokio_stream::wrappers::ReceiverStream<R>>
112    where
113        MessageDispatcher<M>: McpDispatch<R, OR, M, OM>;
114    fn message_sender(&self) -> Arc<tokio::sync::RwLock<Option<MessageDispatcher<M>>>>;
115    fn error_stream(&self) -> &tokio::sync::RwLock<Option<IoStream>>;
116    async fn shut_down(&self) -> TransportResult<()>;
117    async fn is_shut_down(&self) -> bool;
118    async fn consume_string_payload(&self, payload: &str) -> TransportResult<()>;
119    async fn pending_request_tx(&self, request_id: &RequestId) -> Option<Sender<M>>;
120    async fn keep_alive(
121        &self,
122        interval: Duration,
123        disconnect_tx: oneshot::Sender<()>,
124    ) -> TransportResult<JoinHandle<()>>;
125    async fn session_id(&self) -> Option<SessionId> {
126        None
127    }
128}
129
130/// A composite trait that combines both transport and dispatch capabilities for the MCP protocol.
131///
132/// `TransportDispatcher` unifies the functionality of [`Transport`] and [`McpDispatch`], allowing implementors
133/// to both manage the transport layer and handle message dispatch logic in a single abstraction.
134///
135/// This trait applies to components responsible for the following operations:
136/// - Handle low-level I/O (stream management, payload parsing, lifecycle control)
137/// - Dispatch and route messages, potentially awaiting or sending responses
138///
139/// # Supertraits
140///
141/// - [`Transport<R, S, M, OR, OM>`]: Provides the transport-level operations (starting, shutting down,
142///   receiving messages, etc.).
143/// - [`McpDispatch<R, OR, M, OM>`]: Provides message-sending and dispatching capabilities.
144///
145/// # Associated Types
146///
147/// - `R`: The raw message type expected to be received. Must be deserializable.
148/// - `S`: The message type sent over the transport (often serialized directly to wire).
149/// - `M`: The internal message type used within the dispatcher.
150/// - `OR`: The outbound response type returned from processing a received message.
151/// - `OM`: The outbound message type submitted by clients or application code.
152///
153pub trait TransportDispatcher<R, S, M, OR, OM>:
154    Transport<R, S, M, OR, OM> + McpDispatch<R, OR, M, OM>
155where
156    R: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
157    S: Clone + Send + Sync + serde::Serialize + 'static,
158    M: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
159    OR: Clone + Send + Sync + serde::Serialize + 'static,
160    OM: Clone + Send + Sync + serde::de::DeserializeOwned + 'static,
161{
162}
163
164// pub trait IntoClientTransport {
165//     type TransportType: Transport<
166//         ServerMessages,
167//         MessageFromClient,
168//         ServerMessage,
169//         ClientMessages,
170//         ClientMessage,
171//     >;
172
173//     fn into_transport(self, session_id: Option<SessionId>) -> TransportResult<Self::TransportType>;
174// }
175
176// impl<T> IntoClientTransport for T
177// where
178//     T: Transport<ServerMessages, MessageFromClient, ServerMessage, ClientMessages, ClientMessage>,
179// {
180//     type TransportType = T;
181
182//     fn into_transport(self, _: Option<SessionId>) -> TransportResult<Self::TransportType> {
183//         Ok(self)
184//     }
185// }