// ultrafast_mcp_transport/streamable_http/client.rs

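//! Streamable HTTP transport implementation.
//!
//! This module implements an MCP-compliant Streamable HTTP client transport that follows
//! the MCP specification: JSON-RPC messages are sent with HTTP POST to the server's `/mcp`
//! endpoint, the session is identified by the `mcp-session-id` header, and server-to-client
//! messages can be received over an SSE stream opened with HTTP GET.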
5
6use crate::{Result, Transport, TransportError};
7use async_trait::async_trait;
8
9use ultrafast_mcp_core::protocol::JsonRpcMessage;
10use ultrafast_mcp_core::utils::generate_state;
11
12/// Streamable HTTP client configuration
13#[derive(Debug, Clone)]
14pub struct StreamableHttpClientConfig {
15    pub base_url: String,
16    pub session_id: Option<String>,
17    pub protocol_version: String,
18    pub timeout: std::time::Duration,
19    pub max_retries: u32,
20    pub auth_token: Option<String>,
21    pub oauth_config: Option<ultrafast_mcp_auth::OAuthConfig>,
22    pub auth_method: Option<ultrafast_mcp_auth::AuthMethod>,
23}
24
25impl Default for StreamableHttpClientConfig {
26    fn default() -> Self {
27        Self {
28            base_url: "http://127.0.0.1:8080".to_string(),
29            session_id: None,
30            protocol_version: "2025-06-18".to_string(),
31            timeout: std::time::Duration::from_secs(30),
32            max_retries: 3,
33            auth_token: None,
34            oauth_config: None,
35            auth_method: None,
36        }
37    }
38}
39
40impl StreamableHttpClientConfig {
41    /// Set Bearer token authentication
42    pub fn with_bearer_auth(mut self, token: String) -> Self {
43        self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::bearer(token));
44        self
45    }
46
47    /// Set OAuth authentication
48    pub fn with_oauth_auth(mut self, config: ultrafast_mcp_auth::OAuthConfig) -> Self {
49        self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::oauth(config));
50        self
51    }
52
53    /// Set API key authentication
54    pub fn with_api_key_auth(mut self, api_key: String) -> Self {
55        self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::api_key(api_key));
56        self
57    }
58
59    /// Set API key authentication with custom header name
60    pub fn with_api_key_auth_custom(mut self, api_key: String, header_name: String) -> Self {
61        let api_key_auth =
62            ultrafast_mcp_auth::ApiKeyAuth::new(api_key).with_header_name(header_name);
63        let auth_method = ultrafast_mcp_auth::AuthMethod::ApiKey(api_key_auth);
64        self.auth_method = Some(auth_method);
65        self
66    }
67
68    /// Set Basic authentication
69    pub fn with_basic_auth(mut self, username: String, password: String) -> Self {
70        self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::basic(username, password));
71        self
72    }
73
74    /// Set custom header authentication
75    pub fn with_custom_auth(mut self) -> Self {
76        self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::custom());
77        self
78    }
79
80    /// Set authentication method
81    pub fn with_auth_method(mut self, auth_method: ultrafast_mcp_auth::AuthMethod) -> Self {
82        self.auth_method = Some(auth_method);
83        self
84    }
85}
86
87/// Streamable HTTP client - MCP-compliant request/response implementation
88pub struct StreamableHttpClient {
89    client: reqwest::Client,
90    config: StreamableHttpClientConfig,
91    session_id: Option<String>,
92    pending_response: Option<JsonRpcMessage>,
93    oauth_client: Option<ultrafast_mcp_auth::OAuthClient>,
94    access_token: Option<String>,
95    token_expiry: Option<std::time::SystemTime>,
96    auth_middleware: Option<ultrafast_mcp_auth::ClientAuthMiddleware>,
97}
98
99impl StreamableHttpClient {
100    pub fn new(config: StreamableHttpClientConfig) -> Result<Self> {
101        let client = reqwest::Client::builder()
102            .timeout(config.timeout)
103            .build()
104            .map_err(|e| TransportError::InitializationError {
105                message: format!("Failed to create HTTP client: {e}"),
106            })?;
107
108        let oauth_client = config
109            .oauth_config
110            .as_ref()
111            .map(|config| ultrafast_mcp_auth::OAuthClient::from_config(config.clone()));
112
113        let access_token = config.auth_token.clone();
114
115        let auth_middleware = config
116            .auth_method
117            .as_ref()
118            .map(|auth_method| ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method.clone()));
119
120        Ok(Self {
121            client,
122            config,
123            session_id: None,
124            pending_response: None,
125            oauth_client,
126            access_token,
127            token_expiry: None,
128            auth_middleware,
129        })
130    }
131
132    /// Authenticate using OAuth 2.1 if configured
133    pub async fn authenticate(&mut self) -> Result<()> {
134        if let Some(oauth_client) = &self.oauth_client {
135            // Generate PKCE parameters
136            let pkce_params = ultrafast_mcp_auth::generate_pkce_params().map_err(|e| {
137                TransportError::AuthenticationError {
138                    message: format!("Failed to generate PKCE: {e}"),
139                }
140            })?;
141
142            // Generate state for CSRF protection
143            let state = generate_state();
144
145            // Get authorization URL
146            let auth_url = oauth_client
147                .get_authorization_url_with_pkce(state, pkce_params.clone())
148                .await
149                .map_err(|e| TransportError::AuthenticationError {
150                    message: format!("Failed to get auth URL: {e}"),
151                })?;
152
153            // In a real implementation, you would:
154            // 1. Open the auth_url in a browser
155            // 2. Wait for user to complete authorization
156            // 3. Receive the authorization code via callback
157            // For now, we'll simulate this with a placeholder
158
159            tracing::info!("OAuth authentication URL: {}", auth_url);
160            tracing::warn!(
161                "OAuth authentication requires manual user interaction. Please complete the flow manually."
162            );
163
164            // For testing purposes, we'll use a mock token
165            self.access_token = Some("mock_oauth_token".to_string());
166            self.token_expiry =
167                Some(std::time::SystemTime::now() + std::time::Duration::from_secs(3600));
168        }
169
170        Ok(())
171    }
172
173    /// Refresh OAuth token if needed
174    async fn refresh_token_if_needed(&mut self) -> Result<()> {
175        if let Some(expiry) = self.token_expiry {
176            if std::time::SystemTime::now() >= expiry {
177                tracing::info!("OAuth token expired, refreshing...");
178                self.authenticate().await?;
179            }
180        }
181        Ok(())
182    }
183
184    /// Get current authentication headers
185    async fn get_auth_headers(&mut self) -> Result<Vec<(String, String)>> {
186        let mut headers = Vec::new();
187
188        // Use auth middleware if available
189        if let Some(auth_middleware) = &mut self.auth_middleware {
190            let auth_headers = auth_middleware.get_headers().await.map_err(|e| {
191                TransportError::AuthenticationError {
192                    message: format!("Failed to get auth headers: {e}"),
193                }
194            })?;
195
196            headers.extend(auth_headers.into_iter());
197        } else {
198            // Fallback to legacy OAuth token handling
199            self.refresh_token_if_needed().await?;
200
201            // Add OAuth token if available
202            if let Some(token) = &self.access_token {
203                headers.push(("Authorization".to_string(), format!("Bearer {token}")));
204            }
205        }
206
207        Ok(headers)
208    }
209
210    /// Connect to the Streamable HTTP server
211    pub async fn connect(&mut self) -> Result<String> {
212        // Authenticate if OAuth is configured
213        if self.oauth_client.is_some() {
214            self.authenticate().await?;
215        }
216
217        // For Streamable HTTP, we just establish a session ID without sending initialize
218        // The client will handle the initialize request separately
219        let session_id = self
220            .config
221            .session_id
222            .clone()
223            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
224
225        // Store session ID
226        self.session_id = Some(session_id.clone());
227
228        Ok(session_id)
229    }
230
231    /// Send a message and get immediate response
232    async fn send_message_internal(&mut self, message: JsonRpcMessage) -> Result<JsonRpcMessage> {
233        let session_id =
234            self.session_id
235                .clone()
236                .ok_or_else(|| TransportError::ConnectionError {
237                    message: "Not connected".to_string(),
238                })?;
239
240        let url = format!("{}/mcp", self.config.base_url);
241
242        // Get authentication headers
243        let auth_headers = self.get_auth_headers().await?;
244
245        let mut request_builder = self
246            .client
247            .post(&url)
248            .header("content-type", "application/json")
249            .header("accept", "application/json, text/event-stream") // Required Accept header
250            .header("mcp-session-id", session_id)
251            .header("mcp-protocol-version", &self.config.protocol_version)
252            .json(&message); // Send direct JSON-RPC message
253
254        // Add authentication headers
255        for (key, value) in auth_headers {
256            request_builder = request_builder.header(key, value);
257        }
258
259        let response = request_builder
260            .send()
261            .await
262            .map_err(|e| TransportError::NetworkError {
263                message: format!("Failed to send message: {e}"),
264            })?;
265
266        if !response.status().is_success() {
267            let error_text = response.text().await.unwrap_or_default();
268            return Err(TransportError::NetworkError {
269                message: format!("Send failed: {error_text}"),
270            });
271        }
272
273        // Parse the response - it should be a single JSON-RPC message
274        let response_message: JsonRpcMessage =
275            response
276                .json()
277                .await
278                .map_err(|e| TransportError::SerializationError {
279                    message: format!("Failed to parse response: {e}"),
280                })?;
281
282        Ok(response_message)
283    }
284
285    /// Send a JSON-RPC notification (fire-and-forget, do not wait for response)
286    pub async fn send_notification_internal(&mut self, message: JsonRpcMessage) -> Result<()> {
287        let session_id =
288            self.session_id
289                .clone()
290                .ok_or_else(|| TransportError::ConnectionError {
291                    message: "Not connected".to_string(),
292                })?;
293
294        let url = format!("{}/mcp", self.config.base_url);
295
296        // Get authentication headers
297        let auth_headers = self.get_auth_headers().await?;
298
299        let mut request_builder = self
300            .client
301            .post(&url)
302            .header("content-type", "application/json")
303            .header("accept", "application/json, text/event-stream")
304            .header("mcp-session-id", session_id)
305            .header("mcp-protocol-version", &self.config.protocol_version)
306            .json(&message);
307
308        // Add authentication headers
309        for (key, value) in auth_headers {
310            request_builder = request_builder.header(key, value);
311        }
312
313        // Fire and forget: do not block on response
314        let _ = request_builder.send().await;
315        Ok(())
316    }
317
318    /// Get current connection health
319    pub async fn get_health(&mut self) -> crate::TransportHealth {
320        crate::TransportHealth {
321            state: if self.session_id.is_some() {
322                crate::ConnectionState::Connected
323            } else {
324                crate::ConnectionState::Disconnected
325            },
326            connection_duration: None,
327            messages_sent: 0,
328            messages_received: 0,
329            error_count: 0,
330            last_activity: None,
331            last_error: None,
332        }
333    }
334
335    /// Check if the connection is healthy
336    pub async fn is_healthy(&self) -> bool {
337        self.session_id.is_some()
338    }
339
340    /// Reconnect to the server
341    pub async fn reconnect(&mut self) -> Result<()> {
342        self.session_id = None;
343        self.pending_response = None;
344        self.connect().await?;
345        Ok(())
346    }
347
348    /// Reset the client state
349    pub async fn reset(&mut self) -> Result<()> {
350        self.session_id = None;
351        self.pending_response = None;
352        self.access_token = None;
353        self.token_expiry = None;
354        Ok(())
355    }
356
357    /// Start an SSE stream for server-to-client communication
358    pub async fn start_sse_stream(&mut self) -> Result<reqwest::Response> {
359        let session_id =
360            self.session_id
361                .clone()
362                .ok_or_else(|| TransportError::ConnectionError {
363                    message: "Not connected".to_string(),
364                })?;
365
366        let url = format!("{}/mcp", self.config.base_url);
367
368        // Get authentication headers
369        let auth_headers = self.get_auth_headers().await?;
370
371        let mut request_builder = self
372            .client
373            .get(&url)
374            .header("accept", "text/event-stream") // SSE-specific Accept header
375            .header("mcp-session-id", session_id)
376            .header("mcp-protocol-version", &self.config.protocol_version);
377
378        // Add authentication headers
379        for (key, value) in auth_headers {
380            request_builder = request_builder.header(key, value);
381        }
382
383        let response = request_builder
384            .send()
385            .await
386            .map_err(|e| TransportError::NetworkError {
387                message: format!("Failed to start SSE stream: {e}"),
388            })?;
389
390        if !response.status().is_success() {
391            let error_text = response.text().await.unwrap_or_default();
392            return Err(TransportError::NetworkError {
393                message: format!("SSE stream failed: {error_text}"),
394            });
395        }
396
397        Ok(response)
398    }
399
400    /// Resume an SSE stream from a specific event ID
401    pub async fn resume_sse_stream(&mut self, last_event_id: &str) -> Result<reqwest::Response> {
402        let session_id =
403            self.session_id
404                .clone()
405                .ok_or_else(|| TransportError::ConnectionError {
406                    message: "Not connected".to_string(),
407                })?;
408
409        let url = format!("{}/mcp", self.config.base_url);
410
411        // Get authentication headers
412        let auth_headers = self.get_auth_headers().await?;
413
414        let mut request_builder = self
415            .client
416            .get(&url)
417            .header("accept", "text/event-stream")
418            .header("mcp-session-id", session_id)
419            .header("mcp-protocol-version", &self.config.protocol_version)
420            .header("last-event-id", last_event_id); // Resume from specific event
421
422        // Add authentication headers
423        for (key, value) in auth_headers {
424            request_builder = request_builder.header(key, value);
425        }
426
427        let response = request_builder
428            .send()
429            .await
430            .map_err(|e| TransportError::NetworkError {
431                message: format!("Failed to resume SSE stream: {e}"),
432            })?;
433
434        if !response.status().is_success() {
435            let error_text = response.text().await.unwrap_or_default();
436            return Err(TransportError::NetworkError {
437                message: format!("SSE stream resume failed: {error_text}"),
438            });
439        }
440
441        Ok(response)
442    }
443}
444
445#[async_trait]
446impl Transport for StreamableHttpClient {
447    async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()> {
448        // For notifications, use fire-and-forget
449        if matches!(message, JsonRpcMessage::Notification(_)) {
450            self.send_notification_internal(message).await
451        } else {
452            // For requests, wait for response
453            let response = self.send_message_internal(message).await?;
454            self.pending_response = Some(response);
455            Ok(())
456        }
457    }
458
459    async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
460        // Return the pending response if available
461        if let Some(response) = self.pending_response.take() {
462            Ok(response)
463        } else {
464            // No pending response, connection is closed
465            Err(TransportError::ConnectionClosed)
466        }
467    }
468
469    async fn close(&mut self) -> Result<()> {
470        // Close the session using DELETE method
471        if let Some(session_id) = self.session_id.clone() {
472            let url = format!("{}/mcp", self.config.base_url);
473
474            // Get authentication headers
475            let auth_headers = self.get_auth_headers().await?;
476
477            let mut request_builder = self
478                .client
479                .delete(&url)
480                .header("mcp-session-id", session_id)
481                .header("mcp-protocol-version", &self.config.protocol_version);
482
483            // Add authentication headers
484            for (key, value) in auth_headers {
485                request_builder = request_builder.header(key, value);
486            }
487
488            let _ = request_builder.send().await;
489        }
490
491        Ok(())
492    }
493
494    fn get_state(&self) -> crate::ConnectionState {
495        if self.session_id.is_some() {
496            crate::ConnectionState::Connected
497        } else {
498            crate::ConnectionState::Disconnected
499        }
500    }
501
502    fn get_health(&self) -> crate::TransportHealth {
503        // This is a blocking call, so we can't use the async version
504        crate::TransportHealth {
505            state: self.get_state(),
506            last_activity: None,
507            messages_sent: 0,
508            messages_received: 0,
509            connection_duration: None,
510            error_count: 0,
511            last_error: None,
512        }
513    }
514
515    async fn reconnect(&mut self) -> Result<()> {
516        self.reconnect().await
517    }
518
519    async fn reset(&mut self) -> Result<()> {
520        self.reset().await
521    }
522}