// Streamable HTTP client transport (ultrafast_mcp_transport/streamable_http/client.rs).

use async_trait::async_trait;

use ultrafast_mcp_core::protocol::JsonRpcMessage;
use ultrafast_mcp_core::utils::generate_state;

use crate::{Result, Transport, TransportError};

12#[derive(Debug, Clone)]
14pub struct StreamableHttpClientConfig {
15 pub base_url: String,
16 pub session_id: Option<String>,
17 pub protocol_version: String,
18 pub timeout: std::time::Duration,
19 pub max_retries: u32,
20 pub auth_token: Option<String>,
21 pub oauth_config: Option<ultrafast_mcp_auth::OAuthConfig>,
22 pub auth_method: Option<ultrafast_mcp_auth::AuthMethod>,
23}
24
25impl Default for StreamableHttpClientConfig {
26 fn default() -> Self {
27 Self {
28 base_url: "http://127.0.0.1:8080".to_string(),
29 session_id: None,
30 protocol_version: "2025-06-18".to_string(),
31 timeout: std::time::Duration::from_secs(30),
32 max_retries: 3,
33 auth_token: None,
34 oauth_config: None,
35 auth_method: None,
36 }
37 }
38}
39
40impl StreamableHttpClientConfig {
41 pub fn with_bearer_auth(mut self, token: String) -> Self {
43 self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::bearer(token));
44 self
45 }
46
47 pub fn with_oauth_auth(mut self, config: ultrafast_mcp_auth::OAuthConfig) -> Self {
49 self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::oauth(config));
50 self
51 }
52
53 pub fn with_api_key_auth(mut self, api_key: String) -> Self {
55 self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::api_key(api_key));
56 self
57 }
58
59 pub fn with_api_key_auth_custom(mut self, api_key: String, header_name: String) -> Self {
61 let api_key_auth =
62 ultrafast_mcp_auth::ApiKeyAuth::new(api_key).with_header_name(header_name);
63 let auth_method = ultrafast_mcp_auth::AuthMethod::ApiKey(api_key_auth);
64 self.auth_method = Some(auth_method);
65 self
66 }
67
68 pub fn with_basic_auth(mut self, username: String, password: String) -> Self {
70 self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::basic(username, password));
71 self
72 }
73
74 pub fn with_custom_auth(mut self) -> Self {
76 self.auth_method = Some(ultrafast_mcp_auth::AuthMethod::custom());
77 self
78 }
79
80 pub fn with_auth_method(mut self, auth_method: ultrafast_mcp_auth::AuthMethod) -> Self {
82 self.auth_method = Some(auth_method);
83 self
84 }
85}
86
87pub struct StreamableHttpClient {
89 client: reqwest::Client,
90 config: StreamableHttpClientConfig,
91 session_id: Option<String>,
92 pending_response: Option<JsonRpcMessage>,
93 oauth_client: Option<ultrafast_mcp_auth::OAuthClient>,
94 access_token: Option<String>,
95 token_expiry: Option<std::time::SystemTime>,
96 auth_middleware: Option<ultrafast_mcp_auth::ClientAuthMiddleware>,
97}
98
99impl StreamableHttpClient {
100 pub fn new(config: StreamableHttpClientConfig) -> Result<Self> {
101 let client = reqwest::Client::builder()
102 .timeout(config.timeout)
103 .build()
104 .map_err(|e| TransportError::InitializationError {
105 message: format!("Failed to create HTTP client: {e}"),
106 })?;
107
108 let oauth_client = config
109 .oauth_config
110 .as_ref()
111 .map(|config| ultrafast_mcp_auth::OAuthClient::from_config(config.clone()));
112
113 let access_token = config.auth_token.clone();
114
115 let auth_middleware = config
116 .auth_method
117 .as_ref()
118 .map(|auth_method| ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method.clone()));
119
120 Ok(Self {
121 client,
122 config,
123 session_id: None,
124 pending_response: None,
125 oauth_client,
126 access_token,
127 token_expiry: None,
128 auth_middleware,
129 })
130 }
131
132 pub async fn authenticate(&mut self) -> Result<()> {
134 if let Some(oauth_client) = &self.oauth_client {
135 let pkce_params = ultrafast_mcp_auth::generate_pkce_params().map_err(|e| {
137 TransportError::AuthenticationError {
138 message: format!("Failed to generate PKCE: {e}"),
139 }
140 })?;
141
142 let state = generate_state();
144
145 let auth_url = oauth_client
147 .get_authorization_url_with_pkce(state, pkce_params.clone())
148 .await
149 .map_err(|e| TransportError::AuthenticationError {
150 message: format!("Failed to get auth URL: {e}"),
151 })?;
152
153 tracing::info!("OAuth authentication URL: {}", auth_url);
160 tracing::warn!(
161 "OAuth authentication requires manual user interaction. Please complete the flow manually."
162 );
163
164 self.access_token = Some("mock_oauth_token".to_string());
166 self.token_expiry =
167 Some(std::time::SystemTime::now() + std::time::Duration::from_secs(3600));
168 }
169
170 Ok(())
171 }
172
173 async fn refresh_token_if_needed(&mut self) -> Result<()> {
175 if let Some(expiry) = self.token_expiry {
176 if std::time::SystemTime::now() >= expiry {
177 tracing::info!("OAuth token expired, refreshing...");
178 self.authenticate().await?;
179 }
180 }
181 Ok(())
182 }
183
184 async fn get_auth_headers(&mut self) -> Result<Vec<(String, String)>> {
186 let mut headers = Vec::new();
187
188 if let Some(auth_middleware) = &mut self.auth_middleware {
190 let auth_headers = auth_middleware.get_headers().await.map_err(|e| {
191 TransportError::AuthenticationError {
192 message: format!("Failed to get auth headers: {e}"),
193 }
194 })?;
195
196 headers.extend(auth_headers.into_iter());
197 } else {
198 self.refresh_token_if_needed().await?;
200
201 if let Some(token) = &self.access_token {
203 headers.push(("Authorization".to_string(), format!("Bearer {token}")));
204 }
205 }
206
207 Ok(headers)
208 }
209
210 pub async fn connect(&mut self) -> Result<String> {
212 if self.oauth_client.is_some() {
214 self.authenticate().await?;
215 }
216
217 let session_id = self
220 .config
221 .session_id
222 .clone()
223 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
224
225 self.session_id = Some(session_id.clone());
227
228 Ok(session_id)
229 }
230
231 async fn send_message_internal(&mut self, message: JsonRpcMessage) -> Result<JsonRpcMessage> {
233 let session_id =
234 self.session_id
235 .clone()
236 .ok_or_else(|| TransportError::ConnectionError {
237 message: "Not connected".to_string(),
238 })?;
239
240 let url = format!("{}/mcp", self.config.base_url);
241
242 let auth_headers = self.get_auth_headers().await?;
244
245 let mut request_builder = self
246 .client
247 .post(&url)
248 .header("content-type", "application/json")
249 .header("accept", "application/json, text/event-stream") .header("mcp-session-id", session_id)
251 .header("mcp-protocol-version", &self.config.protocol_version)
252 .json(&message); for (key, value) in auth_headers {
256 request_builder = request_builder.header(key, value);
257 }
258
259 let response = request_builder
260 .send()
261 .await
262 .map_err(|e| TransportError::NetworkError {
263 message: format!("Failed to send message: {e}"),
264 })?;
265
266 if !response.status().is_success() {
267 let error_text = response.text().await.unwrap_or_default();
268 return Err(TransportError::NetworkError {
269 message: format!("Send failed: {error_text}"),
270 });
271 }
272
273 let response_message: JsonRpcMessage =
275 response
276 .json()
277 .await
278 .map_err(|e| TransportError::SerializationError {
279 message: format!("Failed to parse response: {e}"),
280 })?;
281
282 Ok(response_message)
283 }
284
285 pub async fn send_notification_internal(&mut self, message: JsonRpcMessage) -> Result<()> {
287 let session_id =
288 self.session_id
289 .clone()
290 .ok_or_else(|| TransportError::ConnectionError {
291 message: "Not connected".to_string(),
292 })?;
293
294 let url = format!("{}/mcp", self.config.base_url);
295
296 let auth_headers = self.get_auth_headers().await?;
298
299 let mut request_builder = self
300 .client
301 .post(&url)
302 .header("content-type", "application/json")
303 .header("accept", "application/json, text/event-stream")
304 .header("mcp-session-id", session_id)
305 .header("mcp-protocol-version", &self.config.protocol_version)
306 .json(&message);
307
308 for (key, value) in auth_headers {
310 request_builder = request_builder.header(key, value);
311 }
312
313 let _ = request_builder.send().await;
315 Ok(())
316 }
317
318 pub async fn get_health(&mut self) -> crate::TransportHealth {
320 crate::TransportHealth {
321 state: if self.session_id.is_some() {
322 crate::ConnectionState::Connected
323 } else {
324 crate::ConnectionState::Disconnected
325 },
326 connection_duration: None,
327 messages_sent: 0,
328 messages_received: 0,
329 error_count: 0,
330 last_activity: None,
331 last_error: None,
332 }
333 }
334
335 pub async fn is_healthy(&self) -> bool {
337 self.session_id.is_some()
338 }
339
340 pub async fn reconnect(&mut self) -> Result<()> {
342 self.session_id = None;
343 self.pending_response = None;
344 self.connect().await?;
345 Ok(())
346 }
347
348 pub async fn reset(&mut self) -> Result<()> {
350 self.session_id = None;
351 self.pending_response = None;
352 self.access_token = None;
353 self.token_expiry = None;
354 Ok(())
355 }
356
357 pub async fn start_sse_stream(&mut self) -> Result<reqwest::Response> {
359 let session_id =
360 self.session_id
361 .clone()
362 .ok_or_else(|| TransportError::ConnectionError {
363 message: "Not connected".to_string(),
364 })?;
365
366 let url = format!("{}/mcp", self.config.base_url);
367
368 let auth_headers = self.get_auth_headers().await?;
370
371 let mut request_builder = self
372 .client
373 .get(&url)
374 .header("accept", "text/event-stream") .header("mcp-session-id", session_id)
376 .header("mcp-protocol-version", &self.config.protocol_version);
377
378 for (key, value) in auth_headers {
380 request_builder = request_builder.header(key, value);
381 }
382
383 let response = request_builder
384 .send()
385 .await
386 .map_err(|e| TransportError::NetworkError {
387 message: format!("Failed to start SSE stream: {e}"),
388 })?;
389
390 if !response.status().is_success() {
391 let error_text = response.text().await.unwrap_or_default();
392 return Err(TransportError::NetworkError {
393 message: format!("SSE stream failed: {error_text}"),
394 });
395 }
396
397 Ok(response)
398 }
399
400 pub async fn resume_sse_stream(&mut self, last_event_id: &str) -> Result<reqwest::Response> {
402 let session_id =
403 self.session_id
404 .clone()
405 .ok_or_else(|| TransportError::ConnectionError {
406 message: "Not connected".to_string(),
407 })?;
408
409 let url = format!("{}/mcp", self.config.base_url);
410
411 let auth_headers = self.get_auth_headers().await?;
413
414 let mut request_builder = self
415 .client
416 .get(&url)
417 .header("accept", "text/event-stream")
418 .header("mcp-session-id", session_id)
419 .header("mcp-protocol-version", &self.config.protocol_version)
420 .header("last-event-id", last_event_id); for (key, value) in auth_headers {
424 request_builder = request_builder.header(key, value);
425 }
426
427 let response = request_builder
428 .send()
429 .await
430 .map_err(|e| TransportError::NetworkError {
431 message: format!("Failed to resume SSE stream: {e}"),
432 })?;
433
434 if !response.status().is_success() {
435 let error_text = response.text().await.unwrap_or_default();
436 return Err(TransportError::NetworkError {
437 message: format!("SSE stream resume failed: {error_text}"),
438 });
439 }
440
441 Ok(response)
442 }
443}
444
445#[async_trait]
446impl Transport for StreamableHttpClient {
447 async fn send_message(&mut self, message: JsonRpcMessage) -> Result<()> {
448 if matches!(message, JsonRpcMessage::Notification(_)) {
450 self.send_notification_internal(message).await
451 } else {
452 let response = self.send_message_internal(message).await?;
454 self.pending_response = Some(response);
455 Ok(())
456 }
457 }
458
459 async fn receive_message(&mut self) -> Result<JsonRpcMessage> {
460 if let Some(response) = self.pending_response.take() {
462 Ok(response)
463 } else {
464 Err(TransportError::ConnectionClosed)
466 }
467 }
468
469 async fn close(&mut self) -> Result<()> {
470 if let Some(session_id) = self.session_id.clone() {
472 let url = format!("{}/mcp", self.config.base_url);
473
474 let auth_headers = self.get_auth_headers().await?;
476
477 let mut request_builder = self
478 .client
479 .delete(&url)
480 .header("mcp-session-id", session_id)
481 .header("mcp-protocol-version", &self.config.protocol_version);
482
483 for (key, value) in auth_headers {
485 request_builder = request_builder.header(key, value);
486 }
487
488 let _ = request_builder.send().await;
489 }
490
491 Ok(())
492 }
493
494 fn get_state(&self) -> crate::ConnectionState {
495 if self.session_id.is_some() {
496 crate::ConnectionState::Connected
497 } else {
498 crate::ConnectionState::Disconnected
499 }
500 }
501
502 fn get_health(&self) -> crate::TransportHealth {
503 crate::TransportHealth {
505 state: self.get_state(),
506 last_activity: None,
507 messages_sent: 0,
508 messages_received: 0,
509 connection_duration: None,
510 error_count: 0,
511 last_error: None,
512 }
513 }
514
515 async fn reconnect(&mut self) -> Result<()> {
516 self.reconnect().await
517 }
518
519 async fn reset(&mut self) -> Result<()> {
520 self.reset().await
521 }
522}