1use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::{RwLock, oneshot};
9use tracing::{error, info, warn};
10use ultrafast_mcp_core::{
11 config::TimeoutConfig,
12 error::{MCPError, MCPResult, ProtocolError, TransportError},
13 protocol::{
14 InitializeRequest, InitializeResponse, InitializedNotification, ShutdownRequest,
15 jsonrpc::{JsonRpcMessage, JsonRpcRequest},
16 },
17 types::{
18 client::{ClientCapabilities, ClientInfo},
19 completion::{CompleteRequest, CompleteResponse},
20 elicitation::{ElicitationRequest, ElicitationResponse},
21 prompts::{GetPromptRequest, GetPromptResponse, ListPromptsRequest, ListPromptsResponse},
22 resources::{
23 ListResourcesRequest, ListResourcesResponse, ReadResourceRequest, ReadResourceResponse,
24 },
25 sampling::{CreateMessageRequest, CreateMessageResponse},
26 server::{ServerCapabilities, ServerInfo},
27 tools::{ListToolsRequest, ListToolsResponse, ToolCall, ToolResult},
28 },
29};
30use ultrafast_mcp_transport::Transport;
31
/// Handler for server-initiated elicitation requests (`elicitation/create`).
///
/// The client calls it from its background message-receiver task. A successful
/// response is serialized and sent back to the server.
#[async_trait::async_trait]
pub trait ClientElicitationHandler: Send + Sync {
    /// Handles one elicitation request received from the server.
    ///
    /// # Errors
    /// An `Err` is logged by the client and no reply is sent to the server.
    async fn handle_elicitation_request(
        &self,
        request: ElicitationRequest,
    ) -> MCPResult<ElicitationResponse>;
}
42
/// Lifecycle of a `UltraFastClient` connection, from creation through shutdown.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientState {
    /// No handshake has happened yet (also the state after `disconnect`).
    Uninitialized,
    /// The `initialize` request is in flight.
    Initializing,
    /// The server answered `initialize`; the `initialized` notification is pending.
    Initialized,
    /// Handshake complete; ordinary requests may be sent.
    Operating,
    /// `shutdown` has started.
    ShuttingDown,
    /// `shutdown` has finished.
    Shutdown,
}

impl ClientState {
    /// Whether ordinary requests may be issued (`Initialized` or `Operating`).
    pub fn can_operate(&self) -> bool {
        self.is_initialized()
    }

    /// Whether the `initialize` handshake has completed and no shutdown has begun.
    pub fn is_initialized(&self) -> bool {
        matches!(*self, Self::Initialized | Self::Operating)
    }

    /// Whether a shutdown is in progress or already finished.
    pub fn is_shutting_down(&self) -> bool {
        matches!(*self, Self::ShuttingDown | Self::Shutdown)
    }
}
72
/// A request that has been sent and is waiting for its response.
#[derive(Debug)]
struct PendingRequest {
    /// Completed by the message-receiver task when a response carrying this
    /// request's id arrives.
    response_sender: oneshot::Sender<JsonRpcMessage>,
    // Deadline recorded when the request is registered (`now + operation timeout`).
    // Nothing in the visible code reads it yet, hence the `dead_code` allowance.
    #[allow(dead_code)]
    timeout: tokio::time::Instant,
}
80
/// Mutable client bookkeeping.
///
/// Shared behind `Arc<RwLock<_>>` between the client and its background
/// message-receiver task.
struct ClientStateManager {
    /// Current lifecycle state.
    state: ClientState,
    /// Server metadata from the `initialize` response.
    server_info: Option<ServerInfo>,
    /// Server capabilities from the `initialize` response.
    server_capabilities: Option<ServerCapabilities>,
    /// Protocol version reported by the server in the `initialize` response.
    negotiated_version: Option<String>,
    /// Next request id to hand out; starts at 1.
    request_id_counter: u64,
    /// In-flight requests keyed by numeric request id.
    pending_requests: HashMap<u64, PendingRequest>,
    /// Optional handler for server-initiated `elicitation/create` requests.
    elicitation_handler: Option<Arc<dyn ClientElicitationHandler>>,
}
91
92impl ClientStateManager {
93 fn new() -> Self {
94 Self {
95 state: ClientState::Uninitialized,
96 server_info: None,
97 server_capabilities: None,
98 negotiated_version: None,
99 request_id_counter: 1,
100 pending_requests: HashMap::new(),
101 elicitation_handler: None,
102 }
103 }
104
105 fn set_state(&mut self, state: ClientState) {
106 self.state = state;
107 }
108
109 fn set_server_info(&mut self, info: ServerInfo) {
110 self.server_info = Some(info);
111 }
112
113 fn set_server_capabilities(&mut self, capabilities: ServerCapabilities) {
114 self.server_capabilities = Some(capabilities);
115 }
116
117 fn set_negotiated_version(&mut self, version: String) {
118 self.negotiated_version = Some(version);
119 }
120
121 fn set_elicitation_handler(&mut self, handler: Option<Arc<dyn ClientElicitationHandler>>) {
122 self.elicitation_handler = handler;
123 }
124
125 fn next_request_id(&mut self) -> u64 {
126 let id = self.request_id_counter;
127 self.request_id_counter += 1;
128 id
129 }
130
131 fn add_pending_request(&mut self, id: u64, request: PendingRequest) {
132 self.pending_requests.insert(id, request);
133 }
134
135 fn remove_pending_request(&mut self, id: &u64) -> Option<PendingRequest> {
136 self.pending_requests.remove(id)
137 }
138}
139
/// MCP client.
///
/// Owns a transport, a background message-receiver task, and connection state
/// shared with that task.
pub struct UltraFastClient {
    /// Client identity sent in the `initialize` request.
    info: ClientInfo,
    /// Capabilities advertised in the `initialize` request.
    capabilities: ClientCapabilities,
    /// Lifecycle state, server metadata, request ids and pending requests.
    /// Also cloned into the receiver task.
    state_manager: Arc<RwLock<ClientStateManager>>,
    /// Active transport: `None` until `connect` stores one, and taken again by `disconnect`.
    transport: Arc<RwLock<Option<Box<dyn Transport>>>>,
    /// Join handle of the background receiver task started during `connect`.
    message_receiver: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// Timeout placed into the HTTP transport configs built by the `connect_*`
    /// helpers (30 s when built with `new`).
    request_timeout: std::time::Duration,
    /// Per-operation timeouts consulted through `get_operation_timeout`.
    timeout_config: Arc<TimeoutConfig>,
    /// Authentication middleware installed by the `with_*_auth` builders.
    #[cfg(feature = "oauth")]
    auth_middleware: Arc<RwLock<Option<ultrafast_mcp_auth::ClientAuthMiddleware>>>,
    // Same-shaped placeholder so the constructors compile without the `oauth` feature.
    #[cfg(not(feature = "oauth"))]
    auth_middleware: Arc<RwLock<Option<()>>>,
}
156
157impl UltraFastClient {
158 pub fn new(info: ClientInfo, capabilities: ClientCapabilities) -> Self {
160 Self {
161 info,
162 capabilities,
163 state_manager: Arc::new(RwLock::new(ClientStateManager::new())),
164 transport: Arc::new(RwLock::new(None)),
165 message_receiver: Arc::new(RwLock::new(None)),
166 request_timeout: std::time::Duration::from_secs(30),
167 timeout_config: Arc::new(TimeoutConfig::default()),
168 #[cfg(feature = "oauth")]
169 auth_middleware: Arc::new(RwLock::new(None)),
170 #[cfg(not(feature = "oauth"))]
171 auth_middleware: Arc::new(RwLock::new(None)),
172 }
173 }
174
175 pub fn new_with_timeout(
177 info: ClientInfo,
178 capabilities: ClientCapabilities,
179 timeout: std::time::Duration,
180 ) -> Self {
181 Self {
182 info,
183 capabilities,
184 state_manager: Arc::new(RwLock::new(ClientStateManager::new())),
185 transport: Arc::new(RwLock::new(None)),
186 message_receiver: Arc::new(RwLock::new(None)),
187 request_timeout: timeout,
188 timeout_config: Arc::new(TimeoutConfig::default()),
189 #[cfg(feature = "oauth")]
190 auth_middleware: Arc::new(RwLock::new(None)),
191 #[cfg(not(feature = "oauth"))]
192 auth_middleware: Arc::new(RwLock::new(None)),
193 }
194 }
195
196 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
198 self.request_timeout = timeout;
199 self
200 }
201
202 pub fn with_timeout_config(mut self, config: TimeoutConfig) -> Self {
204 self.timeout_config = Arc::new(config);
205 self
206 }
207
208 pub fn get_timeout_config(&self) -> TimeoutConfig {
210 (*self.timeout_config).clone()
211 }
212
213 pub fn with_high_performance_timeouts(mut self) -> Self {
215 self.timeout_config = Arc::new(TimeoutConfig::high_performance());
216 self
217 }
218
219 pub fn with_long_running_timeouts(mut self) -> Self {
221 self.timeout_config = Arc::new(TimeoutConfig::long_running());
222 self
223 }
224
225 pub fn get_operation_timeout(&self, operation: &str) -> std::time::Duration {
227 self.timeout_config.get_timeout_for_operation(operation)
228 }
229
230 #[cfg(feature = "oauth")]
232 pub fn with_auth(self, auth_method: ultrafast_mcp_auth::AuthMethod) -> Self {
233 #[cfg(feature = "oauth")]
234 {
235 let mut auth = self.auth_middleware.blocking_write();
236 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
237 }
238 self
239 }
240
241 #[cfg(feature = "oauth")]
243 pub fn with_bearer_auth(self, token: String) -> Self {
244 #[cfg(feature = "oauth")]
245 {
246 let auth_method = ultrafast_mcp_auth::AuthMethod::bearer(token);
247 let mut auth = self.auth_middleware.blocking_write();
248 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
249 }
250 self
251 }
252
253 #[cfg(feature = "oauth")]
255 pub fn with_bearer_auth_refresh<F, Fut>(self, token: String, refresh_fn: F) -> Self
256 where
257 F: Fn() -> Fut + Send + Sync + 'static,
258 Fut: std::future::Future<Output = Result<String, ultrafast_mcp_auth::AuthError>>
259 + Send
260 + 'static,
261 {
262 #[cfg(feature = "oauth")]
263 {
264 let bearer_auth =
265 ultrafast_mcp_auth::BearerAuth::new(token).with_auto_refresh(refresh_fn);
266 let auth_method = ultrafast_mcp_auth::AuthMethod::Bearer(bearer_auth);
267 let mut auth = self.auth_middleware.blocking_write();
268 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
269 }
270 self
271 }
272
273 #[cfg(feature = "oauth")]
275 pub fn with_oauth_auth(self, config: ultrafast_mcp_auth::OAuthConfig) -> Self {
276 #[cfg(feature = "oauth")]
277 {
278 let auth_method = ultrafast_mcp_auth::AuthMethod::oauth(config);
279 let mut auth = self.auth_middleware.blocking_write();
280 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
281 }
282 self
283 }
284
285 #[cfg(feature = "oauth")]
287 pub fn with_api_key_auth(self, api_key: String) -> Self {
288 #[cfg(feature = "oauth")]
289 {
290 let auth_method = ultrafast_mcp_auth::AuthMethod::api_key(api_key);
291 let mut auth = self.auth_middleware.blocking_write();
292 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
293 }
294 self
295 }
296
297 #[cfg(feature = "oauth")]
299 pub fn with_api_key_auth_custom(self, api_key: String, header_name: String) -> Self {
300 #[cfg(feature = "oauth")]
301 {
302 let api_key_auth =
303 ultrafast_mcp_auth::ApiKeyAuth::new(api_key).with_header_name(header_name);
304 let auth_method = ultrafast_mcp_auth::AuthMethod::ApiKey(api_key_auth);
305 let mut auth = self.auth_middleware.blocking_write();
306 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
307 }
308 self
309 }
310
311 #[cfg(feature = "oauth")]
313 pub fn with_basic_auth(self, username: String, password: String) -> Self {
314 #[cfg(feature = "oauth")]
315 {
316 let auth_method = ultrafast_mcp_auth::AuthMethod::basic(username, password);
317 let mut auth = self.auth_middleware.blocking_write();
318 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
319 }
320 self
321 }
322
323 #[cfg(feature = "oauth")]
325 pub fn with_custom_auth(self) -> Self {
326 #[cfg(feature = "oauth")]
327 {
328 let auth_method = ultrafast_mcp_auth::AuthMethod::custom();
329 let mut auth = self.auth_middleware.blocking_write();
330 *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
331 }
332 self
333 }
334
335 #[cfg(feature = "oauth")]
337 pub async fn get_auth_headers(
338 &self,
339 ) -> Result<std::collections::HashMap<String, String>, ultrafast_mcp_auth::AuthError> {
340 if let Some(auth) = self.auth_middleware.write().await.as_mut() {
341 auth.get_headers().await
342 } else {
343 Ok(std::collections::HashMap::new())
344 }
345 }
346
347 pub fn with_elicitation_handler(self, handler: Arc<dyn ClientElicitationHandler>) -> Self {
349 let state_manager = self.state_manager.clone();
350 tokio::spawn(async move {
351 let mut state = state_manager.write().await;
352 state.set_elicitation_handler(Some(handler));
353 });
354 self
355 }
356
357 pub async fn connect(&self, transport: Box<dyn Transport>) -> MCPResult<()> {
359 info!("Connecting to MCP server");
360
361 {
362 let mut transport_guard = self.transport.write().await;
363 *transport_guard = Some(transport);
364 }
365
366 self.start_message_receiver().await?;
368
369 self.initialize().await?;
371
372 info!("Successfully connected to MCP server");
373 Ok(())
374 }
375
    /// Spawns the background task that reads messages from the transport and
    /// dispatches them. The task's `JoinHandle` is stored in `self.message_receiver`.
    ///
    /// Dispatch by message kind:
    /// - `Response` whose id deserializes as `u64`: the matching pending request is
    ///   removed and the message is forwarded on its oneshot sender. A dropped
    ///   receiver is ignored.
    /// - `Request` without an id: passed to `handle_notification_static` (logging only).
    /// - `Request` with method `elicitation/create`: parsed as an `ElicitationRequest`
    ///   and handed to the configured elicitation handler. The handler's result is sent
    ///   back as an `elicitation/respond` request with no id.
    /// - Any other `Request` with an id: logged as a warning and otherwise ignored.
    /// - `Notification`: passed to `handle_notification_static`.
    ///
    /// The loop ends at the first receive error. Errors whose text contains
    /// "Connection closed" are logged at info level; all others at error level.
    ///
    /// NOTE(review): the task takes `transport.write()` once and keeps the guard for
    /// the whole loop. Any other `self.transport.write()` caller (e.g. the send path
    /// of `send_request`) therefore waits until this task ends or is aborted.
    /// Confirm this cannot stall outgoing requests.
    ///
    /// # Panics
    /// The spawned task panics (via `expect`) if no transport has been stored.
    ///
    /// # Errors
    /// Currently always returns `Ok(())`.
    async fn start_message_receiver(&self) -> MCPResult<()> {
        let transport = self.transport.clone();
        let state_manager = self.state_manager.clone();

        let handle = tokio::spawn(async move {
            // Held for the lifetime of the loop (see NOTE above).
            let mut transport_guard = transport.write().await;
            let transport = transport_guard
                .as_mut()
                .expect("Transport should be available");

            loop {
                match transport.receive_message().await {
                    Ok(message) => {
                        match &message {
                            JsonRpcMessage::Response(response) => {
                                if let Some(id) = &response.id {
                                    // Round-trip the id through JSON: only ids that
                                    // deserialize as `u64` can match a pending request;
                                    // any other id is silently dropped.
                                    if let Ok(id_num) = serde_json::from_value::<u64>(
                                        serde_json::to_value(id).unwrap_or_default(),
                                    ) {
                                        let mut state = state_manager.write().await;
                                        if let Some(pending_req) =
                                            state.remove_pending_request(&id_num)
                                        {
                                            // The waiting caller may have given up already.
                                            let _ = pending_req.response_sender.send(message);
                                        }
                                    }
                                }
                            }
                            JsonRpcMessage::Request(request) if request.id.is_none() => {
                                // An id-less request is treated as a notification.
                                Self::handle_notification_static(request.clone()).await;
                            }
                            JsonRpcMessage::Request(request) => {
                                if request.method == "elicitation/create" {
                                    info!("Processing elicitation request from server");

                                    // Clone the handler out so the state lock is released
                                    // before the handler runs.
                                    let elicitation_handler = {
                                        let state = state_manager.read().await;
                                        state.elicitation_handler.clone()
                                    };

                                    if let Some(handler) = elicitation_handler {
                                        if let Ok(elicitation_request) =
                                            serde_json::from_value::<ElicitationRequest>(
                                                request.params.clone().unwrap_or_default(),
                                            )
                                        {
                                            match handler
                                                .handle_elicitation_request(elicitation_request)
                                                .await
                                            {
                                                Ok(response) => {
                                                    let response_params = match serde_json::to_value(response) {
                                                        Ok(params) => Some(params),
                                                        Err(e) => {
                                                            error!("Failed to serialize elicitation response: {}", e);
                                                            // Skip this message; keep receiving.
                                                            continue;
                                                        }
                                                    };

                                                    // NOTE(review): the reply is a new id-less request,
                                                    // not a JSON-RPC response carrying the incoming
                                                    // request's id. Confirm the server side expects this.
                                                    let response_message = JsonRpcMessage::Request(
                                                        JsonRpcRequest::new(
                                                            "elicitation/respond".to_string(),
                                                            response_params,
                                                            None, // no id
                                                        ),
                                                    );

                                                    if let Err(e) = transport
                                                        .send_message(response_message)
                                                        .await
                                                    {
                                                        error!(
                                                            "Failed to send elicitation response: {}",
                                                            e
                                                        );
                                                    }
                                                }
                                                Err(e) => {
                                                    error!(
                                                        "Failed to handle elicitation request: {}",
                                                        e
                                                    );
                                                }
                                            }
                                        } else {
                                            error!("Failed to parse elicitation request");
                                        }
                                    } else {
                                        warn!(
                                            "No elicitation handler configured, ignoring elicitation request"
                                        );
                                    }
                                } else {
                                    // NOTE(review): id-less requests were matched by the
                                    // previous arm, so every request reaching here HAS an id.
                                    // The log text "without ID" is misleading.
                                    warn!(
                                        "Received unexpected request without ID: {}",
                                        request.method
                                    );
                                }
                            }
                            JsonRpcMessage::Notification(notification) => {
                                Self::handle_notification_static(notification.clone()).await;
                            }
                        }
                    }
                    Err(e) => {
                        // A closed connection is detected by message text, not by error variant.
                        if !e.to_string().contains("Connection closed") {
                            error!("Transport error in message receiver: {}", e);
                        } else {
                            info!("Transport connection closed (normal shutdown)");
                        }
                        break;
                    }
                }
            }
        });

        {
            let mut receiver_guard = self.message_receiver.write().await;
            *receiver_guard = Some(handle);
        }

        Ok(())
    }
509
510 async fn handle_notification_static(notification: JsonRpcRequest) {
511 match notification.method.as_str() {
512 "initialized" => {
513 info!("Received initialized notification");
514 }
515 "notifications/tools/listChanged" => {
516 info!("Received tools list changed notification");
517 }
518 "notifications/resources/listChanged" => {
519 info!("Received resources list changed notification");
520 }
521 "notifications/prompts/listChanged" => {
522 info!("Received prompts list changed notification");
523 }
524 "notifications/roots/listChanged" => {
525 info!("Received roots list changed notification");
526 }
527 "elicitation/create" => {
528 info!("Received elicitation request from server");
529 }
532 _ => {
533 warn!("Unknown notification method: {}", notification.method);
534 }
535 }
536 }
537
538 pub async fn connect_stdio(&self) -> MCPResult<()> {
540 let stdio_transport = ultrafast_mcp_transport::stdio::StdioTransport::new()
541 .await
542 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
543 self.connect(Box::new(stdio_transport)).await
544 }
545
546 #[cfg(feature = "http")]
549 pub async fn connect_streamable_http(&self, url: &str) -> MCPResult<()> {
550 use ultrafast_mcp_transport::streamable_http::client::{
551 StreamableHttpClient, StreamableHttpClientConfig,
552 };
553
554 let mut config = StreamableHttpClientConfig {
555 base_url: url.to_string(),
556 session_id: None,
557 protocol_version: "2025-06-18".to_string(),
558 timeout: self.request_timeout,
559 max_retries: 3,
560 auth_token: None,
561 oauth_config: None,
562 auth_method: None,
563 };
564
565 #[cfg(feature = "oauth")]
567 {
568 if let Some(auth) = self.auth_middleware.read().await.as_ref() {
569 config.auth_method = Some(auth.get_auth_method().clone());
570 }
571 }
572
573 let mut http_transport = StreamableHttpClient::new(config)
574 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
575
576 http_transport
578 .connect()
579 .await
580 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
581
582 self.connect(Box::new(http_transport)).await
583 }
584
585 #[cfg(feature = "http")]
587 pub async fn connect_http_with_auth(&self, url: &str, auth_token: String) -> MCPResult<()> {
588 use ultrafast_mcp_transport::streamable_http::client::{
589 StreamableHttpClient, StreamableHttpClientConfig,
590 };
591
592 let config = StreamableHttpClientConfig {
593 base_url: url.to_string(),
594 session_id: None,
595 protocol_version: "2025-06-18".to_string(),
596 timeout: self.request_timeout,
597 max_retries: 3,
598 auth_token: Some(auth_token),
599 oauth_config: None,
600 auth_method: None,
601 };
602
603 let mut http_transport = StreamableHttpClient::new(config)
604 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
605
606 http_transport
608 .connect()
609 .await
610 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
611
612 self.connect(Box::new(http_transport)).await
613 }
614
615 #[cfg(all(feature = "http", feature = "oauth"))]
617 pub async fn connect_streamable_http_with_bearer(
618 &self,
619 url: &str,
620 token: String,
621 ) -> MCPResult<()> {
622 use ultrafast_mcp_transport::streamable_http::client::{
623 StreamableHttpClient, StreamableHttpClientConfig,
624 };
625
626 let config = StreamableHttpClientConfig {
627 base_url: url.to_string(),
628 session_id: None,
629 protocol_version: "2025-06-18".to_string(),
630 timeout: self.request_timeout,
631 max_retries: 3,
632 auth_token: None,
633 oauth_config: None,
634 auth_method: Some(ultrafast_mcp_auth::AuthMethod::bearer(token)),
635 };
636
637 let mut http_transport = StreamableHttpClient::new(config)
638 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
639
640 http_transport
642 .connect()
643 .await
644 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
645
646 self.connect(Box::new(http_transport)).await
647 }
648
649 #[cfg(all(feature = "http", feature = "oauth"))]
651 pub async fn connect_streamable_http_with_oauth(
652 &self,
653 url: &str,
654 oauth_config: ultrafast_mcp_auth::OAuthConfig,
655 ) -> MCPResult<()> {
656 use ultrafast_mcp_transport::streamable_http::client::{
657 StreamableHttpClient, StreamableHttpClientConfig,
658 };
659
660 let config = StreamableHttpClientConfig {
661 base_url: url.to_string(),
662 session_id: None,
663 protocol_version: "2025-06-18".to_string(),
664 timeout: self.request_timeout,
665 max_retries: 3,
666 auth_token: None,
667 oauth_config: Some(oauth_config.clone()),
668 auth_method: Some(ultrafast_mcp_auth::AuthMethod::oauth(oauth_config)),
669 };
670
671 let mut http_transport = StreamableHttpClient::new(config)
672 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
673
674 http_transport
676 .connect()
677 .await
678 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
679
680 self.connect(Box::new(http_transport)).await
681 }
682
683 #[cfg(all(feature = "http", feature = "oauth"))]
685 pub async fn connect_streamable_http_with_api_key(
686 &self,
687 url: &str,
688 api_key: String,
689 ) -> MCPResult<()> {
690 use ultrafast_mcp_transport::streamable_http::client::{
691 StreamableHttpClient, StreamableHttpClientConfig,
692 };
693
694 let config = StreamableHttpClientConfig {
695 base_url: url.to_string(),
696 session_id: None,
697 protocol_version: "2025-06-18".to_string(),
698 timeout: self.request_timeout,
699 max_retries: 3,
700 auth_token: None,
701 oauth_config: None,
702 auth_method: Some(ultrafast_mcp_auth::AuthMethod::api_key(api_key)),
703 };
704
705 let mut http_transport = StreamableHttpClient::new(config)
706 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
707
708 http_transport
710 .connect()
711 .await
712 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
713
714 self.connect(Box::new(http_transport)).await
715 }
716
717 #[cfg(all(feature = "http", feature = "oauth"))]
719 pub async fn connect_streamable_http_with_api_key_custom(
720 &self,
721 url: &str,
722 api_key: String,
723 header_name: String,
724 ) -> MCPResult<()> {
725 use ultrafast_mcp_transport::streamable_http::client::{
726 StreamableHttpClient, StreamableHttpClientConfig,
727 };
728
729 let api_key_auth =
730 ultrafast_mcp_auth::ApiKeyAuth::new(api_key).with_header_name(header_name);
731 let auth_method = ultrafast_mcp_auth::AuthMethod::ApiKey(api_key_auth);
732
733 let config = StreamableHttpClientConfig {
734 base_url: url.to_string(),
735 session_id: None,
736 protocol_version: "2025-06-18".to_string(),
737 timeout: self.request_timeout,
738 max_retries: 3,
739 auth_token: None,
740 oauth_config: None,
741 auth_method: Some(auth_method),
742 };
743
744 let mut http_transport = StreamableHttpClient::new(config)
745 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
746
747 http_transport
749 .connect()
750 .await
751 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
752
753 self.connect(Box::new(http_transport)).await
754 }
755
756 #[cfg(all(feature = "http", feature = "oauth"))]
758 pub async fn connect_streamable_http_with_basic(
759 &self,
760 url: &str,
761 username: String,
762 password: String,
763 ) -> MCPResult<()> {
764 use ultrafast_mcp_transport::streamable_http::client::{
765 StreamableHttpClient, StreamableHttpClientConfig,
766 };
767
768 let config = StreamableHttpClientConfig {
769 base_url: url.to_string(),
770 session_id: None,
771 protocol_version: "2025-06-18".to_string(),
772 timeout: self.request_timeout,
773 max_retries: 3,
774 auth_token: None,
775 oauth_config: None,
776 auth_method: Some(ultrafast_mcp_auth::AuthMethod::basic(username, password)),
777 };
778
779 let mut http_transport = StreamableHttpClient::new(config)
780 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
781
782 http_transport
784 .connect()
785 .await
786 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
787
788 self.connect(Box::new(http_transport)).await
789 }
790
791 #[cfg(feature = "http")]
793 pub async fn connect_streamable_http_with_config(
794 &self,
795 config: ultrafast_mcp_transport::streamable_http::client::StreamableHttpClientConfig,
796 ) -> MCPResult<()> {
797 use ultrafast_mcp_transport::streamable_http::client::StreamableHttpClient;
798
799 let mut http_transport = StreamableHttpClient::new(config)
800 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
801
802 http_transport
804 .connect()
805 .await
806 .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
807
808 self.connect(Box::new(http_transport)).await
809 }
810
811 pub async fn initialize(&self) -> MCPResult<()> {
815 {
816 let mut state = self.state_manager.write().await;
817 state.set_state(ClientState::Initializing);
818 }
819
820 let init_request = InitializeRequest {
822 protocol_version: ultrafast_mcp_core::protocol::version::PROTOCOL_VERSION.to_string(),
823 capabilities: self.capabilities.clone(),
824 client_info: self.info.clone(),
825 };
826
827 let init_response: InitializeResponse = self
829 .send_request("initialize", Some(serde_json::to_value(init_request)?))
830 .await?;
831
832 if init_response.protocol_version != ultrafast_mcp_core::protocol::version::PROTOCOL_VERSION
834 {
835 return Err(MCPError::Protocol(ProtocolError::InvalidVersion(format!(
836 "Expected protocol version {}, got {}",
837 ultrafast_mcp_core::protocol::version::PROTOCOL_VERSION,
838 init_response.protocol_version
839 ))));
840 }
841
842 {
844 let mut state = self.state_manager.write().await;
845 state.set_server_info(init_response.server_info);
846 state.set_server_capabilities(init_response.capabilities);
847 state.set_negotiated_version(init_response.protocol_version);
848 state.set_state(ClientState::Initialized);
849 }
850
851 let init_notification = InitializedNotification {};
854 if let Err(e) = self
855 .send_notification(
856 "initialized",
857 Some(serde_json::to_value(init_notification)?),
858 )
859 .await
860 {
861 warn!(
863 "Failed to send initialized notification (this is normal for HTTP transport): {}",
864 e
865 );
866 }
867
868 {
869 let mut state = self.state_manager.write().await;
870 state.set_state(ClientState::Operating);
871 }
872
873 info!("Client initialized successfully");
874 Ok(())
875 }
876
877 pub async fn shutdown(&self, reason: Option<String>) -> MCPResult<()> {
879 {
881 let state = self.state_manager.read().await;
882 if state.state == ClientState::ShuttingDown || state.state == ClientState::Shutdown {
883 info!("Client already shutting down or shutdown");
884 return Ok(());
885 }
886 }
887
888 {
889 let mut state = self.state_manager.write().await;
890 state.set_state(ClientState::ShuttingDown);
891 }
892
893 let shutdown_request = ShutdownRequest { reason };
895 if let Err(e) = self
896 .send_request::<serde_json::Value>(
897 "shutdown",
898 Some(serde_json::to_value(shutdown_request)?),
899 )
900 .await
901 {
902 warn!(
903 "Failed to send shutdown request (this is normal for some transports): {}",
904 e
905 );
906 }
907
908 {
909 let mut state = self.state_manager.write().await;
910 state.set_state(ClientState::Shutdown);
911 }
912
913 info!("Client shutdown completed");
914 Ok(())
915 }
916
917 pub async fn disconnect(&self) -> MCPResult<()> {
919 if let Some(handle) = self.message_receiver.write().await.take() {
921 handle.abort();
922 }
923
924 if let Some(mut transport) = self.transport.write().await.take() {
926 transport.close().await.map_err(|e| {
927 MCPError::Transport(TransportError::ConnectionFailed(e.to_string()))
928 })?;
929 }
930
931 {
932 let mut state = self.state_manager.write().await;
933 state.set_state(ClientState::Uninitialized);
934 }
935
936 info!("Client disconnected");
937 Ok(())
938 }
939
940 pub async fn get_state(&self) -> ClientState {
942 self.state_manager.read().await.state.clone()
943 }
944
945 pub async fn can_operate(&self) -> bool {
947 self.get_state().await.can_operate()
948 }
949
950 pub async fn get_server_info(&self) -> Option<ServerInfo> {
952 self.state_manager.read().await.server_info.clone()
953 }
954
955 pub async fn get_server_capabilities(&self) -> Option<ServerCapabilities> {
957 self.state_manager.read().await.server_capabilities.clone()
958 }
959
960 pub async fn get_negotiated_version(&self) -> Option<String> {
962 self.state_manager.read().await.negotiated_version.clone()
963 }
964
965 pub fn info(&self) -> &ClientInfo {
967 &self.info
968 }
969
970 pub async fn check_server_capability(&self, capability: &str) -> MCPResult<bool> {
972 self.ensure_capability_supported(capability).await?;
973
974 if let Some(caps) = self.get_server_capabilities().await {
975 Ok(caps.supports_capability(capability))
976 } else {
977 Ok(false)
978 }
979 }
980
981 pub async fn check_server_feature(&self, capability: &str, feature: &str) -> MCPResult<bool> {
983 self.ensure_capability_supported(capability).await?;
984
985 if let Some(caps) = self.get_server_capabilities().await {
986 Ok(caps.supports_feature(capability, feature))
987 } else {
988 Ok(false)
989 }
990 }
991
992 async fn ensure_capability_supported(&self, _capability: &str) -> MCPResult<()> {
993 if !self.can_operate().await {
994 return Err(MCPError::Protocol(ProtocolError::InternalError(
995 "Client is not in operating state".to_string(),
996 )));
997 }
998 Ok(())
999 }
1000
1001 pub async fn list_tools(&self, request: ListToolsRequest) -> MCPResult<ListToolsResponse> {
1003 self.send_request("tools/list", Some(serde_json::to_value(request)?))
1004 .await
1005 }
1006
1007 pub async fn list_tools_default(&self) -> MCPResult<ListToolsResponse> {
1009 self.list_tools(ListToolsRequest::default()).await
1010 }
1011
1012 pub async fn call_tool(&self, tool_call: ToolCall) -> MCPResult<ToolResult> {
1014 self.send_request("tools/call", Some(serde_json::to_value(tool_call)?))
1015 .await
1016 }
1017
1018 pub async fn list_resources(
1020 &self,
1021 request: ListResourcesRequest,
1022 ) -> MCPResult<ListResourcesResponse> {
1023 self.send_request("resources/list", Some(serde_json::to_value(request)?))
1024 .await
1025 }
1026
1027 pub async fn read_resource(
1029 &self,
1030 request: ReadResourceRequest,
1031 ) -> MCPResult<ReadResourceResponse> {
1032 self.send_request("resources/read", Some(serde_json::to_value(request)?))
1033 .await
1034 }
1035
1036 pub async fn subscribe_resource(&self, uri: String) -> MCPResult<()> {
1038 let request = serde_json::json!({
1039 "uri": uri
1040 });
1041 self.send_notification("resources/subscribe", Some(request))
1042 .await
1043 }
1044
1045 pub async fn list_prompts(
1047 &self,
1048 request: ListPromptsRequest,
1049 ) -> MCPResult<ListPromptsResponse> {
1050 self.send_request("prompts/list", Some(serde_json::to_value(request)?))
1051 .await
1052 }
1053
1054 pub async fn get_prompt(&self, request: GetPromptRequest) -> MCPResult<GetPromptResponse> {
1056 self.send_request("prompts/get", Some(serde_json::to_value(request)?))
1057 .await
1058 }
1059
1060 pub async fn create_message(
1062 &self,
1063 request: CreateMessageRequest,
1064 ) -> MCPResult<CreateMessageResponse> {
1065 self.send_request(
1066 "sampling/createMessage",
1067 Some(serde_json::to_value(request)?),
1068 )
1069 .await
1070 }
1071
1072 pub async fn complete(&self, request: CompleteRequest) -> MCPResult<CompleteResponse> {
1074 self.send_request("completion/complete", Some(serde_json::to_value(request)?))
1075 .await
1076 }
1077
1078 pub async fn respond_to_elicitation(&self, response: ElicitationResponse) -> MCPResult<()> {
1080 self.send_request("elicitation/respond", Some(serde_json::to_value(response)?))
1081 .await
1082 }
1083
1084 pub async fn list_roots(&self) -> MCPResult<Vec<ultrafast_mcp_core::types::roots::Root>> {
1086 self.send_request("roots/list", None).await
1087 }
1088
1089 pub async fn set_log_level(
1091 &self,
1092 level: ultrafast_mcp_core::types::notifications::LogLevel,
1093 ) -> MCPResult<()> {
1094 let request = serde_json::json!({
1095 "level": level
1096 });
1097 self.send_request("logging/setLevel", Some(request)).await
1098 }
1099
1100 pub async fn ping(
1102 &self,
1103 data: Option<serde_json::Value>,
1104 ) -> MCPResult<ultrafast_mcp_core::types::notifications::PingResponse> {
1105 self.send_request("ping", data).await
1106 }
1107
1108 pub async fn start_ping_monitoring(&self, ping_interval: std::time::Duration) -> MCPResult<()> {
1110 info!(
1111 "Starting periodic ping monitoring with interval: {:?}",
1112 ping_interval
1113 );
1114
1115 info!("Ping monitoring enabled (interval: {:?})", ping_interval);
1119
1120 Ok(())
1126 }
1127
1128 pub async fn stop_ping_monitoring(&self) -> MCPResult<()> {
1130 info!("Stopping periodic ping monitoring");
1131 Ok(())
1133 }
1134
1135 pub async fn notify_cancelled(
1137 &self,
1138 request_id: serde_json::Value,
1139 reason: Option<String>,
1140 ) -> MCPResult<()> {
1141 let request = serde_json::json!({
1142 "requestId": request_id,
1143 "reason": reason
1144 });
1145 self.send_notification("$/cancelRequest", Some(request))
1146 .await
1147 }
1148
1149 pub async fn notify_progress(
1151 &self,
1152 progress_token: serde_json::Value,
1153 progress: f64,
1154 total: Option<f64>,
1155 message: Option<String>,
1156 ) -> MCPResult<()> {
1157 let request = serde_json::json!({
1158 "token": progress_token,
1159 "progress": progress,
1160 "total": total,
1161 "message": message
1162 });
1163 self.send_notification("$/progress", Some(request)).await
1164 }
1165
1166 pub fn should_send_progress(&self, last_progress: std::time::Instant) -> bool {
1168 let progress_interval = std::time::Duration::from_secs(5);
1170 last_progress.elapsed() >= progress_interval
1171 }
1172
1173 pub fn get_progress_interval(&self) -> std::time::Duration {
1175 std::time::Duration::from_secs(5)
1177 }
1178
1179 async fn ensure_operational(&self) -> MCPResult<()> {
1180 let state = self.get_state().await;
1181 if !state.can_operate() {
1182 return Err(MCPError::Protocol(ProtocolError::InternalError(format!(
1183 "Client is not in operating state (current state: {state:?})"
1184 ))));
1185 }
1186 Ok(())
1187 }
1188
1189 async fn generate_request_id(&self) -> u64 {
1190 let mut state = self.state_manager.write().await;
1191 state.next_request_id()
1192 }
1193
1194 async fn send_request<T>(&self, method: &str, params: Option<Value>) -> MCPResult<T>
1195 where
1196 T: serde::de::DeserializeOwned,
1197 {
1198 if method != "initialize" {
1200 self.ensure_operational().await?;
1201 }
1202
1203 let request_id = self.generate_request_id().await;
1204 let request = JsonRpcRequest::new(
1205 method.to_string(),
1206 params,
1207 Some(ultrafast_mcp_core::protocol::jsonrpc::RequestId::Number(
1208 request_id as i64,
1209 )),
1210 );
1211
1212 let operation_timeout = self.get_operation_timeout(method);
1214
1215 let (response_sender, response_receiver) = oneshot::channel();
1217
1218 {
1220 let mut state = self.state_manager.write().await;
1221 state.add_pending_request(
1222 request_id,
1223 PendingRequest {
1224 response_sender,
1225 timeout: tokio::time::Instant::now() + operation_timeout,
1226 },
1227 );
1228 }
1229
1230 {
1232 let mut transport_guard = self.transport.write().await;
1233 let transport = transport_guard.as_mut().ok_or_else(|| {
1234 MCPError::Transport(TransportError::ConnectionFailed(
1235 "Transport not available".to_string(),
1236 ))
1237 })?;
1238 transport
1239 .send_message(JsonRpcMessage::Request(request))
1240 .await
1241 .map_err(|e| MCPError::Transport(TransportError::SendFailed(e.to_string())))?;
1242 }
1243
1244 let immediate_response = {
1246 let mut transport_guard = self.transport.write().await;
1247 let transport = transport_guard.as_mut().ok_or_else(|| {
1248 MCPError::Transport(TransportError::ConnectionFailed(
1249 "Transport not available".to_string(),
1250 ))
1251 })?;
1252 transport.receive_message().await.ok()
1253 };
1254
1255 let response = if let Some(immediate) = immediate_response {
1256 immediate
1258 } else {
1259 tokio::time::timeout(operation_timeout, response_receiver)
1261 .await
1262 .map_err(|_| MCPError::Protocol(ProtocolError::RequestTimeout))?
1263 .map_err(|_| {
1264 MCPError::Protocol(ProtocolError::InternalError(
1265 "Response channel closed".to_string(),
1266 ))
1267 })?
1268 };
1269
1270 {
1272 let mut state = self.state_manager.write().await;
1273 state.remove_pending_request(&request_id);
1274 }
1275
1276 match response {
1277 JsonRpcMessage::Response(response) => {
1278 if let Some(error) = response.error {
1279 return Err(MCPError::from(error));
1280 }
1281
1282 if let Some(result) = response.result {
1283 serde_json::from_value(result).map_err(MCPError::Serialization)
1284 } else {
1285 Err(MCPError::Protocol(ProtocolError::InvalidResponse(
1286 "Response has no result or error".to_string(),
1287 )))
1288 }
1289 }
1290 _ => Err(MCPError::Protocol(ProtocolError::InvalidResponse(
1291 "Expected response, got different message type".to_string(),
1292 ))),
1293 }
1294 }
1295
1296 async fn send_notification(&self, method: &str, params: Option<Value>) -> MCPResult<()> {
1297 if method != "initialized" {
1299 self.ensure_operational().await?;
1300 }
1301
1302 let notification = JsonRpcRequest::notification(method.to_string(), params);
1303
1304 let mut transport_guard = self.transport.write().await;
1305 let transport = transport_guard.as_mut().ok_or_else(|| {
1306 MCPError::Transport(TransportError::ConnectionFailed(
1307 "Transport not available".to_string(),
1308 ))
1309 })?;
1310
1311 transport
1312 .send_message(JsonRpcMessage::Notification(notification))
1313 .await
1314 .map_err(|e| MCPError::Transport(TransportError::SendFailed(e.to_string())))
1315 }
1316}
1317
1318impl std::fmt::Debug for UltraFastClient {
1319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1320 f.debug_struct("UltraFastClient")
1321 .field("info", &self.info)
1322 .field("capabilities", &self.capabilities)
1323 .field("state", &"<state_manager>")
1324 .field("transport", &"<transport>")
1325 .field("request_timeout", &self.request_timeout)
1326 .finish()
1327 }
1328}
1329
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the minimal `ClientInfo` shared by every test in this module.
    fn sample_client_info() -> ClientInfo {
        ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
            authors: None,
            description: None,
            homepage: None,
            repository: None,
            license: None,
        }
    }

    #[tokio::test]
    async fn test_client_creation() {
        let client = UltraFastClient::new(sample_client_info(), ClientCapabilities::default());

        assert_eq!(client.get_state().await, ClientState::Uninitialized);
        assert!(!client.can_operate().await);
    }

    #[tokio::test]
    async fn test_client_creation_with_timeout() {
        let sixty_seconds = std::time::Duration::from_secs(60);
        let client = UltraFastClient::new_with_timeout(
            sample_client_info(),
            ClientCapabilities::default(),
            sixty_seconds,
        );

        assert_eq!(client.get_state().await, ClientState::Uninitialized);
    }

    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = UltraFastClient::new(sample_client_info(), ClientCapabilities::default());
        assert_eq!(client.get_state().await, ClientState::Uninitialized);

        // Drive the state directly; the guard is released at the end of the block.
        {
            let mut manager = client.state_manager.write().await;
            manager.set_state(ClientState::Initializing);
        }
        assert_eq!(client.get_state().await, ClientState::Initializing);
    }
}