turbomcp_client/client/
protocol.rs

//! Protocol client for JSON-RPC communication
//!
//! This module provides the ProtocolClient, which handles the low-level
//! JSON-RPC protocol communication with MCP servers.
//!
//! ## Bidirectional Communication Architecture
//!
//! The ProtocolClient uses a MessageDispatcher to solve the bidirectional
//! communication problem. Calling `transport.receive()` directly from several
//! code paths caused race conditions, because one caller could consume a
//! message intended for another. Instead, all incoming messages flow through
//! a single, centralized routing layer:
//!
//! ```text
//! ProtocolClient::request()
//!     ↓
//!   1. Register oneshot channel with dispatcher
//!   2. Send request via transport
//!   3. Wait on oneshot channel
//!     ↓
//! MessageDispatcher (background task)
//!     ↓
//!   Continuously reads transport.receive()
//!   Routes responses → oneshot channels
//!   Routes requests → Client handlers
//! ```
//!
//! This ensures there is only ONE consumer of `transport.receive()`,
//! eliminating the race condition.
29
30use std::sync::Arc;
31use std::sync::atomic::{AtomicU64, Ordering};
32
33use turbomcp_protocol::jsonrpc::{JsonRpcRequest, JsonRpcVersion};
34use turbomcp_protocol::{Error, Result};
35use turbomcp_transport::{Transport, TransportMessage};
36
37use super::dispatcher::MessageDispatcher;
38
39/// JSON-RPC protocol handler for MCP communication
40///
41/// Handles request/response correlation, serialization, and protocol-level concerns.
42/// This is the abstraction layer between raw Transport and high-level Client APIs.
43///
44/// ## Architecture
45///
46/// The ProtocolClient now uses a MessageDispatcher to handle bidirectional
47/// communication correctly. The dispatcher runs a background task that:
48/// - Reads ALL messages from the transport
49/// - Routes responses to waiting request() calls
50/// - Routes incoming requests to registered handlers
51///
52/// This eliminates race conditions by centralizing all message routing
53/// in a single background task.
54#[derive(Debug)]
55pub(super) struct ProtocolClient<T: Transport> {
56    transport: Arc<T>,
57    dispatcher: Arc<MessageDispatcher>,
58    next_id: AtomicU64,
59}
60
61impl<T: Transport + 'static> ProtocolClient<T> {
62    /// Create a new protocol client with message dispatcher
63    ///
64    /// This automatically starts the message routing background task.
65    pub(super) fn new(transport: T) -> Self {
66        let transport = Arc::new(transport);
67        let dispatcher = MessageDispatcher::new(transport.clone());
68
69        Self {
70            transport,
71            dispatcher,
72            next_id: AtomicU64::new(1),
73        }
74    }
75
76    /// Get the message dispatcher for handler registration
77    ///
78    /// This allows the Client to register request/notification handlers
79    /// with the dispatcher.
80    pub(super) fn dispatcher(&self) -> &Arc<MessageDispatcher> {
81        &self.dispatcher
82    }
83
84    /// Send JSON-RPC request and await typed response
85    ///
86    /// ## New Architecture (v2.0+)
87    ///
88    /// Instead of calling `transport.receive()` directly (which created the
89    /// race condition), this method now:
90    ///
91    /// 1. Registers a oneshot channel with the dispatcher BEFORE sending
92    /// 2. Sends the request via transport
93    /// 3. Waits on the oneshot channel for the response
94    ///
95    /// The dispatcher's background task receives the response and routes it
96    /// to the oneshot channel. This ensures responses always reach the right
97    /// request() call, even when the server sends requests (elicitation, etc.)
98    /// in between.
99    ///
100    /// ## Example Flow with Elicitation
101    ///
102    /// ```text
103    /// Client: call_tool("test") → request(id=1)
104    ///   1. Register oneshot channel for id=1
105    ///   2. Send tools/call request
106    ///   3. Wait on channel...
107    ///
108    /// Server: Sends elicitation/create request (id=2)
109    ///   → Dispatcher routes to request handler
110    ///   → Client processes elicitation
111    ///   → Client sends elicitation response
112    ///
113    /// Server: Sends tools/call response (id=1)
114    ///   → Dispatcher routes to oneshot channel for id=1
115    ///   → request() receives response ✓
116    /// ```
117    pub(super) async fn request<R: serde::de::DeserializeOwned>(
118        &self,
119        method: &str,
120        params: Option<serde_json::Value>,
121    ) -> Result<R> {
122        // Generate unique request ID
123        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
124        let request_id = turbomcp_protocol::MessageId::from(id.to_string());
125
126        // Build JSON-RPC request
127        let request = JsonRpcRequest {
128            jsonrpc: JsonRpcVersion,
129            id: request_id.clone(),
130            method: method.to_string(),
131            params,
132        };
133
134        // Step 1: Register oneshot channel BEFORE sending request
135        // This ensures the dispatcher can route the response when it arrives
136        let response_receiver = self.dispatcher.wait_for_response(request_id.clone());
137
138        // Step 2: Serialize and send request
139        let payload = serde_json::to_vec(&request)
140            .map_err(|e| Error::protocol(format!("Failed to serialize request: {e}")))?;
141
142        let message = TransportMessage::new(
143            turbomcp_protocol::MessageId::from(format!("req-{id}")),
144            payload.into(),
145        );
146
147        self.transport
148            .send(message)
149            .await
150            .map_err(|e| Error::transport(format!("Transport send failed: {e}")))?;
151
152        // Step 3: Wait for response via oneshot channel
153        // The dispatcher's background task will send the response when it arrives
154        let response = response_receiver
155            .await
156            .map_err(|_| Error::transport("Response channel closed".to_string()))?;
157
158        // Handle JSON-RPC errors
159        if let Some(error) = response.error() {
160            return Err(Error::rpc(error.code, &error.message));
161        }
162
163        // Deserialize result
164        serde_json::from_value(response.result().unwrap_or_default().clone())
165            .map_err(|e| Error::protocol(format!("Failed to deserialize response: {e}")))
166    }
167
168    /// Send JSON-RPC notification (no response expected)
169    pub(super) async fn notify(
170        &self,
171        method: &str,
172        params: Option<serde_json::Value>,
173    ) -> Result<()> {
174        let request = serde_json::json!({
175            "jsonrpc": "2.0",
176            "method": method,
177            "params": params
178        });
179
180        let payload = serde_json::to_vec(&request)
181            .map_err(|e| Error::protocol(format!("Failed to serialize notification: {e}")))?;
182
183        let message = TransportMessage::new(
184            turbomcp_protocol::MessageId::from("notification"),
185            payload.into(),
186        );
187
188        self.transport
189            .send(message)
190            .await
191            .map_err(|e| Error::transport(format!("Transport send failed: {e}")))
192    }
193
194    /// Connect the transport
195    #[allow(dead_code)] // Reserved for future use
196    pub(super) async fn connect(&self) -> Result<()> {
197        self.transport
198            .connect()
199            .await
200            .map_err(|e| Error::transport(format!("Transport connect failed: {e}")))
201    }
202
203    /// Disconnect the transport
204    #[allow(dead_code)] // Reserved for future use
205    pub(super) async fn disconnect(&self) -> Result<()> {
206        self.transport
207            .disconnect()
208            .await
209            .map_err(|e| Error::transport(format!("Transport disconnect failed: {e}")))
210    }
211
212    /// Get transport reference
213    ///
214    /// Returns an Arc reference to the transport, allowing it to be shared
215    /// with other components (like the message dispatcher).
216    pub(super) fn transport(&self) -> &Arc<T> {
217        &self.transport
218    }
219}