turbomcp_transport/shared.rs
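//! Shared transport wrappers for concurrent access
//!
//! This module provides thread-safe wrappers around [`Transport`] instances so that a
//! single transport can be shared by multiple async tasks without exposing the
//! underlying `Arc`/`Mutex` plumbing. Calls made through a wrapper are serialized by
//! an internal async mutex rather than executed in parallel.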
5
6use async_trait::async_trait;
7use std::sync::Arc;
8use tokio::sync::Mutex;
9
10use crate::core::{
11 Transport, TransportCapabilities, TransportConfig, TransportMessage, TransportMetrics,
12 TransportResult, TransportState, TransportType,
13};
14
15/// Thread-safe wrapper for sharing Transport instances across async tasks
16///
17/// This wrapper encapsulates Arc/Mutex complexity and provides a clean API
18/// for concurrent access to transport functionality. It addresses the limitations
19/// where Transport methods require `&mut self` but need to be shared across
20/// multiple async tasks.
21///
22/// # Design Rationale
23///
24/// Transport methods require `&mut self` because:
25/// - Connection state management requires mutation
26/// - Send/receive operations modify internal buffers and state
27/// - Connect/disconnect operations change connection status
28///
29/// While Transport implements Send + Sync, this only means it's safe to move/share
30/// between threads, not that multiple tasks can mutate it concurrently.
31///
32/// # Examples
33///
34/// ```rust,no_run
35/// use turbomcp_transport::{StdioTransport, SharedTransport};
36///
37/// # async fn example() -> turbomcp_transport::core::TransportResult<()> {
38/// let transport = StdioTransport::new();
39/// let shared = SharedTransport::new(transport);
40///
41/// // Connect once
42/// shared.connect().await?;
43///
44/// // Clone for sharing across tasks
45/// let shared1 = shared.clone();
46/// let shared2 = shared.clone();
47///
48/// // Both tasks can use the transport concurrently
49/// let handle1 = tokio::spawn(async move {
50/// shared1.is_connected().await
51/// });
52///
53/// let handle2 = tokio::spawn(async move {
54/// shared2.metrics().await
55/// });
56///
57/// let (connected, metrics) = tokio::try_join!(handle1, handle2).unwrap();
58/// # Ok(())
59/// # }
60/// ```
61pub struct SharedTransport<T: Transport> {
62 inner: Arc<Mutex<T>>,
63}
64
65impl<T: Transport> SharedTransport<T> {
66 /// Create a new shared transport wrapper
67 ///
68 /// Takes ownership of a Transport and wraps it for thread-safe sharing.
69 /// The original transport can no longer be accessed directly after this call.
70 pub fn new(transport: T) -> Self {
71 Self {
72 inner: Arc::new(Mutex::new(transport)),
73 }
74 }
75
76 /// Get transport type
77 ///
78 /// Returns the type of the underlying transport.
79 pub async fn transport_type(&self) -> TransportType {
80 self.inner.lock().await.transport_type()
81 }
82
83 /// Get transport capabilities
84 ///
85 /// Returns the capabilities of the underlying transport.
86 /// Note: This returns a clone since capabilities are typically small and immutable.
87 pub async fn capabilities(&self) -> TransportCapabilities {
88 self.inner.lock().await.capabilities().clone()
89 }
90
91 /// Get current transport state
92 ///
93 /// Returns the current connection state of the transport.
94 pub async fn state(&self) -> TransportState {
95 self.inner.lock().await.state().await
96 }
97
98 /// Connect to the transport endpoint
99 ///
100 /// Establishes a connection to the transport's target endpoint.
101 /// This method is thread-safe and will serialize connection attempts.
102 pub async fn connect(&self) -> TransportResult<()> {
103 self.inner.lock().await.connect().await
104 }
105
106 /// Disconnect from the transport
107 ///
108 /// Cleanly closes the transport connection.
109 /// This method is thread-safe and will serialize disconnection attempts.
110 pub async fn disconnect(&self) -> TransportResult<()> {
111 self.inner.lock().await.disconnect().await
112 }
113
114 /// Send a message through the transport
115 ///
116 /// Sends a message via the underlying transport. Messages are serialized
117 /// to ensure proper ordering and prevent race conditions.
118 pub async fn send(&self, message: TransportMessage) -> TransportResult<()> {
119 self.inner.lock().await.send(message).await
120 }
121
122 /// Receive a message from the transport
123 ///
124 /// Receives a message from the underlying transport. Receive operations
125 /// are serialized to ensure message ordering and prevent lost messages.
126 pub async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
127 self.inner.lock().await.receive().await
128 }
129
130 /// Get transport metrics
131 ///
132 /// Returns current metrics for the transport including message counts,
133 /// connection status, and performance statistics.
134 pub async fn metrics(&self) -> TransportMetrics {
135 self.inner.lock().await.metrics().await
136 }
137
138 /// Check if transport is connected
139 ///
140 /// Returns true if the transport is currently connected and ready
141 /// for message transmission.
142 pub async fn is_connected(&self) -> bool {
143 self.inner.lock().await.is_connected().await
144 }
145
146 /// Get endpoint information
147 ///
148 /// Returns information about the transport's endpoint configuration.
149 pub async fn endpoint(&self) -> Option<String> {
150 self.inner.lock().await.endpoint()
151 }
152
153 /// Configure the transport
154 ///
155 /// Sets the configuration for the transport.
156 pub async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
157 self.inner.lock().await.configure(config).await
158 }
159}
160
161impl<T: Transport> Clone for SharedTransport<T> {
162 /// Clone the shared transport for use in multiple async tasks
163 ///
164 /// This creates a new reference to the same underlying transport,
165 /// allowing multiple tasks to share access safely.
166 fn clone(&self) -> Self {
167 Self {
168 inner: Arc::clone(&self.inner),
169 }
170 }
171}
172
173impl<T: Transport> std::fmt::Debug for SharedTransport<T> {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("SharedTransport")
176 .field("inner", &"Arc<Mutex<Transport>>")
177 .finish()
178 }
179}
180
181// Implement Transport trait for SharedTransport to enable drop-in replacement
182#[async_trait]
183impl<T: Transport> Transport for SharedTransport<T> {
184 fn transport_type(&self) -> TransportType {
185 // Cannot implement: requires async mutex access
186 // Use SharedTransport::transport_type_async() instead
187 unimplemented!(
188 "SharedTransport::transport_type() cannot be called directly. \
189 Use the async version: transport_type_async()"
190 )
191 }
192
193 fn capabilities(&self) -> &TransportCapabilities {
194 // Cannot implement: cannot return reference from async mutex
195 // Use SharedTransport::capabilities_async() instead
196 unimplemented!(
197 "SharedTransport::capabilities() cannot be called directly. \
198 Use the async version: capabilities_async()"
199 )
200 }
201
202 async fn state(&self) -> TransportState {
203 self.state().await
204 }
205
206 async fn connect(&self) -> TransportResult<()> {
207 self.connect().await
208 }
209
210 async fn disconnect(&self) -> TransportResult<()> {
211 self.disconnect().await
212 }
213
214 async fn send(&self, message: TransportMessage) -> TransportResult<()> {
215 self.send(message).await
216 }
217
218 async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
219 self.receive().await
220 }
221
222 async fn metrics(&self) -> TransportMetrics {
223 self.metrics().await
224 }
225
226 async fn is_connected(&self) -> bool {
227 self.is_connected().await
228 }
229
230 fn endpoint(&self) -> Option<String> {
231 // Cannot implement: requires async mutex access
232 // Use SharedTransport::endpoint_async() instead
233 unimplemented!(
234 "SharedTransport::endpoint() cannot be called directly. \
235 Use the async version: endpoint_async()"
236 )
237 }
238
239 async fn configure(&self, config: TransportConfig) -> TransportResult<()> {
240 self.configure(config).await
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::stdio::StdioTransport;

    /// Compile-time check that a type can be moved into spawned tasks and shared.
    fn assert_send_sync<S: Send + Sync + 'static>() {}

    #[tokio::test]
    async fn test_shared_transport_creation() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // A clone is a second handle to the same transport, not a copy of it.
        let shared2 = shared.clone();
        assert!(Arc::ptr_eq(&shared.inner, &shared2.inner));
    }

    #[tokio::test]
    async fn test_shared_transport_cloning() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        let clones: Vec<_> = (0..10).map(|_| shared.clone()).collect();
        assert_eq!(clones.len(), 10);

        // Every clone must reference the same underlying transport allocation...
        assert!(clones
            .iter()
            .all(|clone| Arc::ptr_eq(&clone.inner, &shared.inner)));
        // ...and each holds one strong reference: the original plus ten clones.
        assert_eq!(Arc::strong_count(&shared.inner), 11);

        // Dropping the clones releases exactly those references.
        drop(clones);
        assert_eq!(Arc::strong_count(&shared.inner), 1);
    }

    #[tokio::test]
    async fn test_shared_transport_api_surface() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // Exercise every read-only wrapper. None of these needs a connected peer,
        // so they must all return without panicking.
        assert_eq!(shared.transport_type().await, TransportType::Stdio);
        let _capabilities = shared.capabilities().await;
        let _state = shared.state().await;
        let _metrics = shared.metrics().await;
        assert!(
            !shared.is_connected().await,
            "a freshly created transport should not report itself connected"
        );
        let _endpoint_info = shared.endpoint().await;
    }

    #[tokio::test]
    async fn test_shared_transport_type_compatibility() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // SharedTransport must be usable in generic contexts that need a
        // cloneable, thread-safe, debuggable value.
        fn takes_shared_transport<S>(_transport: S)
        where
            S: Clone + Send + Sync + std::fmt::Debug + 'static,
        {
        }

        takes_shared_transport(shared);
    }

    #[tokio::test]
    async fn test_shared_transport_send_sync() {
        assert_send_sync::<SharedTransport<StdioTransport>>();

        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // Moving a handle into a spawned task requires `Send + 'static`; a clone
        // made inside the task must still share the original transport.
        let handle = tokio::spawn(async move {
            let cloned = shared.clone();
            Arc::ptr_eq(&cloned.inner, &shared.inner)
        });

        assert!(handle.await.unwrap());
    }

    #[tokio::test]
    async fn test_shared_transport_thread_safety() {
        let transport = StdioTransport::new();
        let shared = SharedTransport::new(transport);

        // Query the transport from two spawned tasks through separate clones.
        let shared1 = shared.clone();
        let shared2 = shared.clone();

        let handle1 = tokio::spawn(async move { shared1.transport_type().await });
        let handle2 = tokio::spawn(async move { shared2.transport_type().await });

        let (type1, type2) = tokio::join!(handle1, handle2);
        let type1 = type1.unwrap();
        let type2 = type2.unwrap();

        // Both tasks must observe the wrapped transport's type.
        assert_eq!(type1, type2);
        assert_eq!(type1, TransportType::Stdio);
    }
}