// turbomcp_transport/shared.rs

//! Shared transport wrappers for concurrent access.
//!
//! This module provides thread-safe wrappers around `Transport` instances that let
//! multiple async tasks share one transport without dealing with `Arc`/`Mutex` directly.

use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::Mutex;

use crate::core::{
    Transport, TransportCapabilities, TransportConfig, TransportMessage, TransportMetrics,
    TransportResult, TransportState, TransportType,
};

15/// Thread-safe wrapper for sharing Transport instances across async tasks
16///
17/// This wrapper encapsulates Arc/Mutex complexity and provides a clean API
18/// for concurrent access to transport functionality. It addresses the limitations
19/// where Transport methods require `&mut self` but need to be shared across
20/// multiple async tasks.
21///
22/// # Design Rationale
23///
24/// Transport methods require `&mut self` because:
25/// - Connection state management requires mutation
26/// - Send/receive operations modify internal buffers and state
27/// - Connect/disconnect operations change connection status
28///
29/// While Transport implements Send + Sync, this only means it's safe to move/share
30/// between threads, not that multiple tasks can mutate it concurrently.
31///
32/// # Examples
33///
34/// ```rust,no_run
35/// use turbomcp_transport::{StdioTransport, SharedTransport};
36///
37/// # async fn example() -> turbomcp_transport::core::TransportResult<()> {
38/// let transport = StdioTransport::new();
39/// let shared = SharedTransport::new(transport);
40///
41/// // Connect once
42/// shared.connect().await?;
43///
44/// // Clone for sharing across tasks
45/// let shared1 = shared.clone();
46/// let shared2 = shared.clone();
47///
48/// // Both tasks can use the transport concurrently
49/// let handle1 = tokio::spawn(async move {
50///     shared1.is_connected().await
51/// });
52///
53/// let handle2 = tokio::spawn(async move {
54///     shared2.metrics().await
55/// });
56///
57/// let (connected, metrics) = tokio::try_join!(handle1, handle2).unwrap();
58/// # Ok(())
59/// # }
60/// ```
61pub struct SharedTransport<T: Transport> {
62    inner: Arc<Mutex<T>>,
63}
64
65impl<T: Transport> SharedTransport<T> {
66    /// Create a new shared transport wrapper
67    ///
68    /// Takes ownership of a Transport and wraps it for thread-safe sharing.
69    /// The original transport can no longer be accessed directly after this call.
70    pub fn new(transport: T) -> Self {
71        Self {
72            inner: Arc::new(Mutex::new(transport)),
73        }
74    }
75
76    /// Get transport type
77    ///
78    /// Returns the type of the underlying transport.
79    pub async fn transport_type(&self) -> TransportType {
80        self.inner.lock().await.transport_type()
81    }
82
83    /// Get transport capabilities
84    ///
85    /// Returns the capabilities of the underlying transport.
86    /// Note: This returns a clone since capabilities are typically small and immutable.
87    pub async fn capabilities(&self) -> TransportCapabilities {
88        self.inner.lock().await.capabilities().clone()
89    }
90
91    /// Get current transport state
92    ///
93    /// Returns the current connection state of the transport.
94    pub async fn state(&self) -> TransportState {
95        self.inner.lock().await.state().await
96    }
97
98    /// Connect to the transport endpoint
99    ///
100    /// Establishes a connection to the transport's target endpoint.
101    /// This method is thread-safe and will serialize connection attempts.
102    pub async fn connect(&self) -> TransportResult<()> {
103        self.inner.lock().await.connect().await
104    }
105
106    /// Disconnect from the transport
107    ///
108    /// Cleanly closes the transport connection.
109    /// This method is thread-safe and will serialize disconnection attempts.
110    pub async fn disconnect(&self) -> TransportResult<()> {
111        self.inner.lock().await.disconnect().await
112    }
113
114    /// Send a message through the transport
115    ///
116    /// Sends a message via the underlying transport. Messages are serialized
117    /// to ensure proper ordering and prevent race conditions.
118    pub async fn send(&self, message: TransportMessage) -> TransportResult<()> {
119        self.inner.lock().await.send(message).await
120    }
121
122    /// Receive a message from the transport
123    ///
124    /// Receives a message from the underlying transport. Receive operations
125    /// are serialized to ensure message ordering and prevent lost messages.
126    pub async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
127        self.inner.lock().await.receive().await
128    }
129
130    /// Get transport metrics
131    ///
132    /// Returns current metrics for the transport including message counts,
133    /// connection status, and performance statistics.
134    pub async fn metrics(&self) -> TransportMetrics {
135        self.inner.lock().await.metrics().await
136    }
137
138    /// Check if transport is connected
139    ///
140    /// Returns true if the transport is currently connected and ready
141    /// for message transmission.
142    pub async fn is_connected(&self) -> bool {
143        self.inner.lock().await.is_connected().await
144    }
145
146    /// Get endpoint information
147    ///
148    /// Returns information about the transport's endpoint configuration.
149    pub async fn endpoint(&self) -> Option<String> {
150        self.inner.lock().await.endpoint()
151    }
152
153    /// Configure the transport
154    ///
155    /// Sets the configuration for the transport.
156    pub async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
157        self.inner.lock().await.configure(config).await
158    }
159}
160
161impl<T: Transport> Clone for SharedTransport<T> {
162    /// Clone the shared transport for use in multiple async tasks
163    ///
164    /// This creates a new reference to the same underlying transport,
165    /// allowing multiple tasks to share access safely.
166    fn clone(&self) -> Self {
167        Self {
168            inner: Arc::clone(&self.inner),
169        }
170    }
171}
172
173impl<T: Transport> std::fmt::Debug for SharedTransport<T> {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("SharedTransport")
176            .field("inner", &"Arc<Mutex<Transport>>")
177            .finish()
178    }
179}
180
181// Implement Transport trait for SharedTransport to enable drop-in replacement
182#[async_trait]
183impl<T: Transport> Transport for SharedTransport<T> {
184    fn transport_type(&self) -> TransportType {
185        // Cannot implement: requires async mutex access
186        // Use SharedTransport::transport_type_async() instead
187        unimplemented!(
188            "SharedTransport::transport_type() cannot be called directly. \
189             Use the async version: transport_type_async()"
190        )
191    }
192
193    fn capabilities(&self) -> &TransportCapabilities {
194        // Cannot implement: cannot return reference from async mutex
195        // Use SharedTransport::capabilities_async() instead
196        unimplemented!(
197            "SharedTransport::capabilities() cannot be called directly. \
198             Use the async version: capabilities_async()"
199        )
200    }
201
202    async fn state(&self) -> TransportState {
203        self.state().await
204    }
205
206    async fn connect(&self) -> TransportResult<()> {
207        self.connect().await
208    }
209
210    async fn disconnect(&self) -> TransportResult<()> {
211        self.disconnect().await
212    }
213
214    async fn send(&self, message: TransportMessage) -> TransportResult<()> {
215        self.send(message).await
216    }
217
218    async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
219        self.receive().await
220    }
221
222    async fn metrics(&self) -> TransportMetrics {
223        self.metrics().await
224    }
225
226    async fn is_connected(&self) -> bool {
227        self.is_connected().await
228    }
229
230    fn endpoint(&self) -> Option<String> {
231        // Cannot implement: requires async mutex access
232        // Use SharedTransport::endpoint_async() instead
233        unimplemented!(
234            "SharedTransport::endpoint() cannot be called directly. \
235             Use the async version: endpoint_async()"
236        )
237    }
238
239    async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
240        self.configure(config).await
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stdio::StdioTransport;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_shared_transport_creation() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // A freshly created wrapper can be cloned.
        let _shared2 = shared.clone();
    }

    #[tokio::test]
    async fn test_shared_transport_cloning() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        let clones: Vec<_> = (0..10).map(|_| shared.clone()).collect();
        assert_eq!(clones.len(), 10);

        // Every clone must point at the very same mutex-protected transport, not a copy.
        assert!(clones.iter().all(|clone| Arc::ptr_eq(&clone.inner, &shared.inner)));
        assert_eq!(Arc::strong_count(&shared.inner), 11);

        // Dropping the clones releases their references again.
        drop(clones);
        assert_eq!(Arc::strong_count(&shared.inner), 1);
    }

    #[tokio::test]
    async fn test_shared_transport_api_surface() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // The wrapper exposes the expected async API. None of these queries returns a
        // Result, and none of them requires a prior connect().
        let _transport_type = shared.transport_type().await;
        let _capabilities = shared.capabilities().await;
        let _state = shared.state().await;
        let _metrics = shared.metrics().await;
        let _is_connected = shared.is_connected().await;
        let _endpoint_info = shared.endpoint().await;
    }

    #[tokio::test]
    async fn test_shared_transport_type_compatibility() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // SharedTransport must be usable in generic contexts with these bounds.
        fn takes_shared_transport<T>(_transport: T)
        where
            T: Clone + Send + Sync + 'static,
        {
        }

        takes_shared_transport(shared);
    }

    #[tokio::test]
    async fn test_shared_transport_send_sync() {
        // Compile-time check: the wrapper is both Send and Sync.
        fn assert_send_sync<S: Send + Sync>() {}
        assert_send_sync::<SharedTransport<StdioTransport>>();

        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // Moving the wrapper into a spawned task requires Send.
        let handle = tokio::spawn(async move {
            let _cloned = shared.clone();
        });

        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_shared_transport_thread_safety() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        let shared1 = shared.clone();
        let shared2 = shared.clone();

        // Query the same transport from two tasks concurrently.
        let handle1 = tokio::spawn(async move { shared1.transport_type().await });
        let handle2 = tokio::spawn(async move { shared2.transport_type().await });

        let (type1, type2) = tokio::join!(handle1, handle2);
        let type1 = type1.unwrap();
        let type2 = type2.unwrap();

        // Both tasks must observe the same transport type.
        assert_eq!(type1, type2);
    }
}