ultrafast_mcp_client/
lib.rs

1//! UltraFast MCP Client Library
2//!
3//! A high-performance client implementation for the Model Context Protocol (MCP).
4
5use serde_json::Value;
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::{RwLock, oneshot};
9use tracing::{error, info, warn};
10use ultrafast_mcp_core::{
11    config::TimeoutConfig,
12    error::{MCPError, MCPResult, ProtocolError, TransportError},
13    protocol::{
14        InitializeRequest, InitializeResponse, InitializedNotification, ShutdownRequest,
15        jsonrpc::{JsonRpcMessage, JsonRpcRequest},
16    },
17    types::{
18        client::{ClientCapabilities, ClientInfo},
19        completion::{CompleteRequest, CompleteResponse},
20        elicitation::{ElicitationRequest, ElicitationResponse},
21        prompts::{GetPromptRequest, GetPromptResponse, ListPromptsRequest, ListPromptsResponse},
22        resources::{
23            ListResourcesRequest, ListResourcesResponse, ReadResourceRequest, ReadResourceResponse,
24        },
25        sampling::{CreateMessageRequest, CreateMessageResponse},
26        server::{ServerCapabilities, ServerInfo},
27        tools::{ListToolsRequest, ListToolsResponse, ToolCall, ToolResult},
28    },
29};
30use ultrafast_mcp_transport::Transport;
31
32/// Client-side elicitation handler trait
33#[async_trait::async_trait]
34pub trait ClientElicitationHandler: Send + Sync {
35    /// Handle an elicitation request from the server
36    /// This method should present the request to the user and return their response
37    async fn handle_elicitation_request(
38        &self,
39        request: ElicitationRequest,
40    ) -> MCPResult<ElicitationResponse>;
41}
42
/// Lifecycle phase of an MCP client connection.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientState {
    Uninitialized,
    Initializing,
    Initialized,
    Operating,
    ShuttingDown,
    Shutdown,
}

impl ClientState {
    /// Whether requests may be issued in this state.
    ///
    /// Per MCP 2025-06-18, operations are permitted as soon as the initialize
    /// response has been received, i.e. in `Initialized` or `Operating`.
    pub fn can_operate(&self) -> bool {
        match self {
            Self::Initialized | Self::Operating => true,
            Self::Uninitialized | Self::Initializing | Self::ShuttingDown | Self::Shutdown => {
                false
            }
        }
    }

    /// Whether the initialize handshake has completed (`Initialized` or `Operating`).
    pub fn is_initialized(&self) -> bool {
        match self {
            Self::Initialized | Self::Operating => true,
            Self::Uninitialized | Self::Initializing | Self::ShuttingDown | Self::Shutdown => {
                false
            }
        }
    }

    /// Whether shutdown has begun or finished (`ShuttingDown` or `Shutdown`).
    pub fn is_shutting_down(&self) -> bool {
        match self {
            Self::ShuttingDown | Self::Shutdown => true,
            Self::Uninitialized | Self::Initializing | Self::Initialized | Self::Operating => {
                false
            }
        }
    }
}
72
73/// Pending request information
74#[derive(Debug)]
75struct PendingRequest {
76    response_sender: oneshot::Sender<JsonRpcMessage>,
77    #[allow(dead_code)]
78    timeout: tokio::time::Instant,
79}
80
81/// Client state management
82struct ClientStateManager {
83    state: ClientState,
84    server_info: Option<ServerInfo>,
85    server_capabilities: Option<ServerCapabilities>,
86    negotiated_version: Option<String>,
87    request_id_counter: u64,
88    pending_requests: HashMap<u64, PendingRequest>,
89    elicitation_handler: Option<Arc<dyn ClientElicitationHandler>>,
90}
91
92impl ClientStateManager {
93    fn new() -> Self {
94        Self {
95            state: ClientState::Uninitialized,
96            server_info: None,
97            server_capabilities: None,
98            negotiated_version: None,
99            request_id_counter: 1,
100            pending_requests: HashMap::new(),
101            elicitation_handler: None,
102        }
103    }
104
105    fn set_state(&mut self, state: ClientState) {
106        self.state = state;
107    }
108
109    fn set_server_info(&mut self, info: ServerInfo) {
110        self.server_info = Some(info);
111    }
112
113    fn set_server_capabilities(&mut self, capabilities: ServerCapabilities) {
114        self.server_capabilities = Some(capabilities);
115    }
116
117    fn set_negotiated_version(&mut self, version: String) {
118        self.negotiated_version = Some(version);
119    }
120
121    fn set_elicitation_handler(&mut self, handler: Option<Arc<dyn ClientElicitationHandler>>) {
122        self.elicitation_handler = handler;
123    }
124
125    fn next_request_id(&mut self) -> u64 {
126        let id = self.request_id_counter;
127        self.request_id_counter += 1;
128        id
129    }
130
131    fn add_pending_request(&mut self, id: u64, request: PendingRequest) {
132        self.pending_requests.insert(id, request);
133    }
134
135    fn remove_pending_request(&mut self, id: &u64) -> Option<PendingRequest> {
136        self.pending_requests.remove(id)
137    }
138}
139
140/// UltraFast MCP Client
141pub struct UltraFastClient {
142    info: ClientInfo,
143    capabilities: ClientCapabilities,
144    state_manager: Arc<RwLock<ClientStateManager>>,
145    transport: Arc<RwLock<Option<Box<dyn Transport>>>>,
146    message_receiver: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
147    request_timeout: std::time::Duration,
148    // Timeout configuration (MCP 2025-06-18 compliance)
149    timeout_config: Arc<TimeoutConfig>,
150    // Authentication middleware
151    #[cfg(feature = "oauth")]
152    auth_middleware: Arc<RwLock<Option<ultrafast_mcp_auth::ClientAuthMiddleware>>>,
153    #[cfg(not(feature = "oauth"))]
154    auth_middleware: Arc<RwLock<Option<()>>>,
155}
156
157impl UltraFastClient {
158    /// Create a new MCP client
159    pub fn new(info: ClientInfo, capabilities: ClientCapabilities) -> Self {
160        Self {
161            info,
162            capabilities,
163            state_manager: Arc::new(RwLock::new(ClientStateManager::new())),
164            transport: Arc::new(RwLock::new(None)),
165            message_receiver: Arc::new(RwLock::new(None)),
166            request_timeout: std::time::Duration::from_secs(30),
167            timeout_config: Arc::new(TimeoutConfig::default()),
168            #[cfg(feature = "oauth")]
169            auth_middleware: Arc::new(RwLock::new(None)),
170            #[cfg(not(feature = "oauth"))]
171            auth_middleware: Arc::new(RwLock::new(None)),
172        }
173    }
174
175    /// Create a new MCP client with custom timeout
176    pub fn new_with_timeout(
177        info: ClientInfo,
178        capabilities: ClientCapabilities,
179        timeout: std::time::Duration,
180    ) -> Self {
181        Self {
182            info,
183            capabilities,
184            state_manager: Arc::new(RwLock::new(ClientStateManager::new())),
185            transport: Arc::new(RwLock::new(None)),
186            message_receiver: Arc::new(RwLock::new(None)),
187            request_timeout: timeout,
188            timeout_config: Arc::new(TimeoutConfig::default()),
189            #[cfg(feature = "oauth")]
190            auth_middleware: Arc::new(RwLock::new(None)),
191            #[cfg(not(feature = "oauth"))]
192            auth_middleware: Arc::new(RwLock::new(None)),
193        }
194    }
195
196    /// Set request timeout
197    pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
198        self.request_timeout = timeout;
199        self
200    }
201
202    /// Set timeout configuration
203    pub fn with_timeout_config(mut self, config: TimeoutConfig) -> Self {
204        self.timeout_config = Arc::new(config);
205        self
206    }
207
208    /// Get current timeout configuration
209    pub fn get_timeout_config(&self) -> TimeoutConfig {
210        (*self.timeout_config).clone()
211    }
212
213    /// Set timeout configuration for high-performance scenarios
214    pub fn with_high_performance_timeouts(mut self) -> Self {
215        self.timeout_config = Arc::new(TimeoutConfig::high_performance());
216        self
217    }
218
219    /// Set timeout configuration for long-running operations
220    pub fn with_long_running_timeouts(mut self) -> Self {
221        self.timeout_config = Arc::new(TimeoutConfig::long_running());
222        self
223    }
224
225    /// Get operation-specific timeout
226    pub fn get_operation_timeout(&self, operation: &str) -> std::time::Duration {
227        self.timeout_config.get_timeout_for_operation(operation)
228    }
229
230    /// Set authentication method
231    #[cfg(feature = "oauth")]
232    pub fn with_auth(self, auth_method: ultrafast_mcp_auth::AuthMethod) -> Self {
233        #[cfg(feature = "oauth")]
234        {
235            let mut auth = self.auth_middleware.blocking_write();
236            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
237        }
238        self
239    }
240
241    /// Set Bearer token authentication
242    #[cfg(feature = "oauth")]
243    pub fn with_bearer_auth(self, token: String) -> Self {
244        #[cfg(feature = "oauth")]
245        {
246            let auth_method = ultrafast_mcp_auth::AuthMethod::bearer(token);
247            let mut auth = self.auth_middleware.blocking_write();
248            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
249        }
250        self
251    }
252
253    /// Set Bearer token authentication with auto-refresh
254    #[cfg(feature = "oauth")]
255    pub fn with_bearer_auth_refresh<F, Fut>(self, token: String, refresh_fn: F) -> Self
256    where
257        F: Fn() -> Fut + Send + Sync + 'static,
258        Fut: std::future::Future<Output = Result<String, ultrafast_mcp_auth::AuthError>>
259            + Send
260            + 'static,
261    {
262        #[cfg(feature = "oauth")]
263        {
264            let bearer_auth =
265                ultrafast_mcp_auth::BearerAuth::new(token).with_auto_refresh(refresh_fn);
266            let auth_method = ultrafast_mcp_auth::AuthMethod::Bearer(bearer_auth);
267            let mut auth = self.auth_middleware.blocking_write();
268            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
269        }
270        self
271    }
272
273    /// Set OAuth authentication
274    #[cfg(feature = "oauth")]
275    pub fn with_oauth_auth(self, config: ultrafast_mcp_auth::OAuthConfig) -> Self {
276        #[cfg(feature = "oauth")]
277        {
278            let auth_method = ultrafast_mcp_auth::AuthMethod::oauth(config);
279            let mut auth = self.auth_middleware.blocking_write();
280            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
281        }
282        self
283    }
284
285    /// Set API key authentication
286    #[cfg(feature = "oauth")]
287    pub fn with_api_key_auth(self, api_key: String) -> Self {
288        #[cfg(feature = "oauth")]
289        {
290            let auth_method = ultrafast_mcp_auth::AuthMethod::api_key(api_key);
291            let mut auth = self.auth_middleware.blocking_write();
292            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
293        }
294        self
295    }
296
297    /// Set API key authentication with custom header name
298    #[cfg(feature = "oauth")]
299    pub fn with_api_key_auth_custom(self, api_key: String, header_name: String) -> Self {
300        #[cfg(feature = "oauth")]
301        {
302            let api_key_auth =
303                ultrafast_mcp_auth::ApiKeyAuth::new(api_key).with_header_name(header_name);
304            let auth_method = ultrafast_mcp_auth::AuthMethod::ApiKey(api_key_auth);
305            let mut auth = self.auth_middleware.blocking_write();
306            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
307        }
308        self
309    }
310
311    /// Set Basic authentication
312    #[cfg(feature = "oauth")]
313    pub fn with_basic_auth(self, username: String, password: String) -> Self {
314        #[cfg(feature = "oauth")]
315        {
316            let auth_method = ultrafast_mcp_auth::AuthMethod::basic(username, password);
317            let mut auth = self.auth_middleware.blocking_write();
318            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
319        }
320        self
321    }
322
323    /// Set custom header authentication
324    #[cfg(feature = "oauth")]
325    pub fn with_custom_auth(self) -> Self {
326        #[cfg(feature = "oauth")]
327        {
328            let auth_method = ultrafast_mcp_auth::AuthMethod::custom();
329            let mut auth = self.auth_middleware.blocking_write();
330            *auth = Some(ultrafast_mcp_auth::ClientAuthMiddleware::new(auth_method));
331        }
332        self
333    }
334
335    /// Get authentication headers for requests
336    #[cfg(feature = "oauth")]
337    pub async fn get_auth_headers(
338        &self,
339    ) -> Result<std::collections::HashMap<String, String>, ultrafast_mcp_auth::AuthError> {
340        if let Some(auth) = self.auth_middleware.write().await.as_mut() {
341            auth.get_headers().await
342        } else {
343            Ok(std::collections::HashMap::new())
344        }
345    }
346
347    /// Set elicitation handler for handling server-initiated elicitation requests
348    pub fn with_elicitation_handler(self, handler: Arc<dyn ClientElicitationHandler>) -> Self {
349        let state_manager = self.state_manager.clone();
350        tokio::spawn(async move {
351            let mut state = state_manager.write().await;
352            state.set_elicitation_handler(Some(handler));
353        });
354        self
355    }
356
357    /// Connect to a server using the provided transport
358    pub async fn connect(&self, transport: Box<dyn Transport>) -> MCPResult<()> {
359        info!("Connecting to MCP server");
360
361        {
362            let mut transport_guard = self.transport.write().await;
363            *transport_guard = Some(transport);
364        }
365
366        // Start message receiver task
367        self.start_message_receiver().await?;
368
369        // Initialize the connection
370        self.initialize().await?;
371
372        info!("Successfully connected to MCP server");
373        Ok(())
374    }
375
376    /// Start the message receiver task for handling responses
377    async fn start_message_receiver(&self) -> MCPResult<()> {
378        let transport = self.transport.clone();
379        let state_manager = self.state_manager.clone();
380
381        let handle = tokio::spawn(async move {
382            let mut transport_guard = transport.write().await;
383            let transport = transport_guard
384                .as_mut()
385                .expect("Transport should be available");
386
387            loop {
388                match transport.receive_message().await {
389                    Ok(message) => {
390                        match &message {
391                            JsonRpcMessage::Response(response) => {
392                                if let Some(id) = &response.id {
393                                    if let Ok(id_num) = serde_json::from_value::<u64>(
394                                        serde_json::to_value(id).unwrap_or_default(),
395                                    ) {
396                                        let mut state = state_manager.write().await;
397                                        if let Some(pending_req) =
398                                            state.remove_pending_request(&id_num)
399                                        {
400                                            // Send response to waiting request
401                                            let _ = pending_req.response_sender.send(message);
402                                        }
403                                    }
404                                }
405                            }
406                            JsonRpcMessage::Request(request) if request.id.is_none() => {
407                                // This is a notification, handle it
408                                Self::handle_notification_static(request.clone()).await;
409                            }
410                            JsonRpcMessage::Request(request) => {
411                                // This is a request without ID, handle as elicitation
412                                if request.method == "elicitation/create" {
413                                    info!("Processing elicitation request from server");
414
415                                    // Get the elicitation handler from state manager
416                                    let elicitation_handler = {
417                                        let state = state_manager.read().await;
418                                        state.elicitation_handler.clone()
419                                    };
420
421                                    if let Some(handler) = elicitation_handler {
422                                        // Parse the elicitation request
423                                        if let Ok(elicitation_request) =
424                                            serde_json::from_value::<ElicitationRequest>(
425                                                request.params.clone().unwrap_or_default(),
426                                            )
427                                        {
428                                            // Handle the elicitation request
429                                            match handler
430                                                .handle_elicitation_request(elicitation_request)
431                                                .await
432                                            {
433                                                Ok(response) => {
434                                                    // Send the response back to the server
435                                                    let response_params = match serde_json::to_value(response) {
436                                                        Ok(params) => Some(params),
437                                                        Err(e) => {
438                                                            error!("Failed to serialize elicitation response: {}", e);
439                                                            continue;
440                                                        }
441                                                    };
442                                                    
443                                                    let response_message = JsonRpcMessage::Request(
444                                                        JsonRpcRequest::new(
445                                                            "elicitation/respond".to_string(),
446                                                            response_params,
447                                                            None, // No ID for elicitation response
448                                                        ),
449                                                    );
450
451                                                    if let Err(e) = transport
452                                                        .send_message(response_message)
453                                                        .await
454                                                    {
455                                                        error!(
456                                                            "Failed to send elicitation response: {}",
457                                                            e
458                                                        );
459                                                    }
460                                                }
461                                                Err(e) => {
462                                                    error!(
463                                                        "Failed to handle elicitation request: {}",
464                                                        e
465                                                    );
466                                                }
467                                            }
468                                        } else {
469                                            error!("Failed to parse elicitation request");
470                                        }
471                                    } else {
472                                        warn!(
473                                            "No elicitation handler configured, ignoring elicitation request"
474                                        );
475                                    }
476                                } else {
477                                    warn!(
478                                        "Received unexpected request without ID: {}",
479                                        request.method
480                                    );
481                                }
482                            }
483                            JsonRpcMessage::Notification(notification) => {
484                                // Handle notification
485                                Self::handle_notification_static(notification.clone()).await;
486                            }
487                        }
488                    }
489                    Err(e) => {
490                        // Only log as error if it's not a normal connection closure
491                        if !e.to_string().contains("Connection closed") {
492                            error!("Transport error in message receiver: {}", e);
493                        } else {
494                            info!("Transport connection closed (normal shutdown)");
495                        }
496                        break;
497                    }
498                }
499            }
500        });
501
502        {
503            let mut receiver_guard = self.message_receiver.write().await;
504            *receiver_guard = Some(handle);
505        }
506
507        Ok(())
508    }
509
510    async fn handle_notification_static(notification: JsonRpcRequest) {
511        match notification.method.as_str() {
512            "initialized" => {
513                info!("Received initialized notification");
514            }
515            "notifications/tools/listChanged" => {
516                info!("Received tools list changed notification");
517            }
518            "notifications/resources/listChanged" => {
519                info!("Received resources list changed notification");
520            }
521            "notifications/prompts/listChanged" => {
522                info!("Received prompts list changed notification");
523            }
524            "notifications/roots/listChanged" => {
525                info!("Received roots list changed notification");
526            }
527            "elicitation/create" => {
528                info!("Received elicitation request from server");
529                // Note: This should be handled by the client's elicitation handler
530                // The actual handling is done in the message receiver loop
531            }
532            _ => {
533                warn!("Unknown notification method: {}", notification.method);
534            }
535        }
536    }
537
538    /// Connect to a server using STDIO transport
539    pub async fn connect_stdio(&self) -> MCPResult<()> {
540        let stdio_transport = ultrafast_mcp_transport::stdio::StdioTransport::new()
541            .await
542            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
543        self.connect(Box::new(stdio_transport)).await
544    }
545
546    /// Connect to a server using Streamable HTTP transport
547    /// This method will automatically use any client-level authentication configured
548    #[cfg(feature = "http")]
549    pub async fn connect_streamable_http(&self, url: &str) -> MCPResult<()> {
550        use ultrafast_mcp_transport::streamable_http::client::{
551            StreamableHttpClient, StreamableHttpClientConfig,
552        };
553
554        let mut config = StreamableHttpClientConfig {
555            base_url: url.to_string(),
556            session_id: None,
557            protocol_version: "2025-06-18".to_string(),
558            timeout: self.request_timeout,
559            max_retries: 3,
560            auth_token: None,
561            oauth_config: None,
562            auth_method: None,
563        };
564
565        // Integrate with client-level auth middleware if available
566        #[cfg(feature = "oauth")]
567        {
568            if let Some(auth) = self.auth_middleware.read().await.as_ref() {
569                config.auth_method = Some(auth.get_auth_method().clone());
570            }
571        }
572
573        let mut http_transport = StreamableHttpClient::new(config)
574            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
575
576        // Connect the transport first
577        http_transport
578            .connect()
579            .await
580            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
581
582        self.connect(Box::new(http_transport)).await
583    }
584
585    /// Connect to a server using HTTP transport with authentication
586    #[cfg(feature = "http")]
587    pub async fn connect_http_with_auth(&self, url: &str, auth_token: String) -> MCPResult<()> {
588        use ultrafast_mcp_transport::streamable_http::client::{
589            StreamableHttpClient, StreamableHttpClientConfig,
590        };
591
592        let config = StreamableHttpClientConfig {
593            base_url: url.to_string(),
594            session_id: None,
595            protocol_version: "2025-06-18".to_string(),
596            timeout: self.request_timeout,
597            max_retries: 3,
598            auth_token: Some(auth_token),
599            oauth_config: None,
600            auth_method: None,
601        };
602
603        let mut http_transport = StreamableHttpClient::new(config)
604            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
605
606        // Connect the transport first
607        http_transport
608            .connect()
609            .await
610            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
611
612        self.connect(Box::new(http_transport)).await
613    }
614
615    /// Connect to a server using Streamable HTTP transport with Bearer token authentication
616    #[cfg(all(feature = "http", feature = "oauth"))]
617    pub async fn connect_streamable_http_with_bearer(
618        &self,
619        url: &str,
620        token: String,
621    ) -> MCPResult<()> {
622        use ultrafast_mcp_transport::streamable_http::client::{
623            StreamableHttpClient, StreamableHttpClientConfig,
624        };
625
626        let config = StreamableHttpClientConfig {
627            base_url: url.to_string(),
628            session_id: None,
629            protocol_version: "2025-06-18".to_string(),
630            timeout: self.request_timeout,
631            max_retries: 3,
632            auth_token: None,
633            oauth_config: None,
634            auth_method: Some(ultrafast_mcp_auth::AuthMethod::bearer(token)),
635        };
636
637        let mut http_transport = StreamableHttpClient::new(config)
638            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
639
640        // Connect the transport first
641        http_transport
642            .connect()
643            .await
644            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
645
646        self.connect(Box::new(http_transport)).await
647    }
648
649    /// Connect to a server using Streamable HTTP transport with OAuth 2.1 authentication
650    #[cfg(all(feature = "http", feature = "oauth"))]
651    pub async fn connect_streamable_http_with_oauth(
652        &self,
653        url: &str,
654        oauth_config: ultrafast_mcp_auth::OAuthConfig,
655    ) -> MCPResult<()> {
656        use ultrafast_mcp_transport::streamable_http::client::{
657            StreamableHttpClient, StreamableHttpClientConfig,
658        };
659
660        let config = StreamableHttpClientConfig {
661            base_url: url.to_string(),
662            session_id: None,
663            protocol_version: "2025-06-18".to_string(),
664            timeout: self.request_timeout,
665            max_retries: 3,
666            auth_token: None,
667            oauth_config: Some(oauth_config.clone()),
668            auth_method: Some(ultrafast_mcp_auth::AuthMethod::oauth(oauth_config)),
669        };
670
671        let mut http_transport = StreamableHttpClient::new(config)
672            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
673
674        // Connect the transport first
675        http_transport
676            .connect()
677            .await
678            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
679
680        self.connect(Box::new(http_transport)).await
681    }
682
683    /// Connect to a server using Streamable HTTP transport with API key authentication
684    #[cfg(all(feature = "http", feature = "oauth"))]
685    pub async fn connect_streamable_http_with_api_key(
686        &self,
687        url: &str,
688        api_key: String,
689    ) -> MCPResult<()> {
690        use ultrafast_mcp_transport::streamable_http::client::{
691            StreamableHttpClient, StreamableHttpClientConfig,
692        };
693
694        let config = StreamableHttpClientConfig {
695            base_url: url.to_string(),
696            session_id: None,
697            protocol_version: "2025-06-18".to_string(),
698            timeout: self.request_timeout,
699            max_retries: 3,
700            auth_token: None,
701            oauth_config: None,
702            auth_method: Some(ultrafast_mcp_auth::AuthMethod::api_key(api_key)),
703        };
704
705        let mut http_transport = StreamableHttpClient::new(config)
706            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
707
708        // Connect the transport first
709        http_transport
710            .connect()
711            .await
712            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
713
714        self.connect(Box::new(http_transport)).await
715    }
716
717    /// Connect to a server using Streamable HTTP transport with API key authentication (custom header)
718    #[cfg(all(feature = "http", feature = "oauth"))]
719    pub async fn connect_streamable_http_with_api_key_custom(
720        &self,
721        url: &str,
722        api_key: String,
723        header_name: String,
724    ) -> MCPResult<()> {
725        use ultrafast_mcp_transport::streamable_http::client::{
726            StreamableHttpClient, StreamableHttpClientConfig,
727        };
728
729        let api_key_auth =
730            ultrafast_mcp_auth::ApiKeyAuth::new(api_key).with_header_name(header_name);
731        let auth_method = ultrafast_mcp_auth::AuthMethod::ApiKey(api_key_auth);
732
733        let config = StreamableHttpClientConfig {
734            base_url: url.to_string(),
735            session_id: None,
736            protocol_version: "2025-06-18".to_string(),
737            timeout: self.request_timeout,
738            max_retries: 3,
739            auth_token: None,
740            oauth_config: None,
741            auth_method: Some(auth_method),
742        };
743
744        let mut http_transport = StreamableHttpClient::new(config)
745            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
746
747        // Connect the transport first
748        http_transport
749            .connect()
750            .await
751            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
752
753        self.connect(Box::new(http_transport)).await
754    }
755
756    /// Connect to a server using Streamable HTTP transport with Basic authentication
757    #[cfg(all(feature = "http", feature = "oauth"))]
758    pub async fn connect_streamable_http_with_basic(
759        &self,
760        url: &str,
761        username: String,
762        password: String,
763    ) -> MCPResult<()> {
764        use ultrafast_mcp_transport::streamable_http::client::{
765            StreamableHttpClient, StreamableHttpClientConfig,
766        };
767
768        let config = StreamableHttpClientConfig {
769            base_url: url.to_string(),
770            session_id: None,
771            protocol_version: "2025-06-18".to_string(),
772            timeout: self.request_timeout,
773            max_retries: 3,
774            auth_token: None,
775            oauth_config: None,
776            auth_method: Some(ultrafast_mcp_auth::AuthMethod::basic(username, password)),
777        };
778
779        let mut http_transport = StreamableHttpClient::new(config)
780            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
781
782        // Connect the transport first
783        http_transport
784            .connect()
785            .await
786            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
787
788        self.connect(Box::new(http_transport)).await
789    }
790
791    /// Connect to a server using Streamable HTTP transport with custom configuration
792    #[cfg(feature = "http")]
793    pub async fn connect_streamable_http_with_config(
794        &self,
795        config: ultrafast_mcp_transport::streamable_http::client::StreamableHttpClientConfig,
796    ) -> MCPResult<()> {
797        use ultrafast_mcp_transport::streamable_http::client::StreamableHttpClient;
798
799        let mut http_transport = StreamableHttpClient::new(config)
800            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
801
802        // Connect the transport first
803        http_transport
804            .connect()
805            .await
806            .map_err(|e| MCPError::Transport(TransportError::ConnectionFailed(e.to_string())))?;
807
808        self.connect(Box::new(http_transport)).await
809    }
810
811    /// Initialize the connection with the server
812    /// Initialize the client connection
813    /// This method must be called after connecting to establish the MCP protocol
814    pub async fn initialize(&self) -> MCPResult<()> {
815        {
816            let mut state = self.state_manager.write().await;
817            state.set_state(ClientState::Initializing);
818        }
819
820        // Create initialization request
821        let init_request = InitializeRequest {
822            protocol_version: ultrafast_mcp_core::protocol::version::PROTOCOL_VERSION.to_string(),
823            capabilities: self.capabilities.clone(),
824            client_info: self.info.clone(),
825        };
826
827        // Send initialization request
828        let init_response: InitializeResponse = self
829            .send_request("initialize", Some(serde_json::to_value(init_request)?))
830            .await?;
831
832        // Validate protocol version
833        if init_response.protocol_version != ultrafast_mcp_core::protocol::version::PROTOCOL_VERSION
834        {
835            return Err(MCPError::Protocol(ProtocolError::InvalidVersion(format!(
836                "Expected protocol version {}, got {}",
837                ultrafast_mcp_core::protocol::version::PROTOCOL_VERSION,
838                init_response.protocol_version
839            ))));
840        }
841
842        // Store server information
843        {
844            let mut state = self.state_manager.write().await;
845            state.set_server_info(init_response.server_info);
846            state.set_server_capabilities(init_response.capabilities);
847            state.set_negotiated_version(init_response.protocol_version);
848            state.set_state(ClientState::Initialized);
849        }
850
851        // Send initialized notification (skip for HTTP transport)
852        // Note: HTTP transport doesn't require initialized notification
853        let init_notification = InitializedNotification {};
854        if let Err(e) = self
855            .send_notification(
856                "initialized",
857                Some(serde_json::to_value(init_notification)?),
858            )
859            .await
860        {
861            // For HTTP transport, ignore errors on initialized notification
862            warn!(
863                "Failed to send initialized notification (this is normal for HTTP transport): {}",
864                e
865            );
866        }
867
868        {
869            let mut state = self.state_manager.write().await;
870            state.set_state(ClientState::Operating);
871        }
872
873        info!("Client initialized successfully");
874        Ok(())
875    }
876
877    /// Shutdown the client
878    pub async fn shutdown(&self, reason: Option<String>) -> MCPResult<()> {
879        // Check if we're already shutting down or shutdown
880        {
881            let state = self.state_manager.read().await;
882            if state.state == ClientState::ShuttingDown || state.state == ClientState::Shutdown {
883                info!("Client already shutting down or shutdown");
884                return Ok(());
885            }
886        }
887
888        {
889            let mut state = self.state_manager.write().await;
890            state.set_state(ClientState::ShuttingDown);
891        }
892
893        // Try to send shutdown request, but don't fail if it doesn't work
894        let shutdown_request = ShutdownRequest { reason };
895        if let Err(e) = self
896            .send_request::<serde_json::Value>(
897                "shutdown",
898                Some(serde_json::to_value(shutdown_request)?),
899            )
900            .await
901        {
902            warn!(
903                "Failed to send shutdown request (this is normal for some transports): {}",
904                e
905            );
906        }
907
908        {
909            let mut state = self.state_manager.write().await;
910            state.set_state(ClientState::Shutdown);
911        }
912
913        info!("Client shutdown completed");
914        Ok(())
915    }
916
917    /// Disconnect from the server
918    pub async fn disconnect(&self) -> MCPResult<()> {
919        // Stop message receiver
920        if let Some(handle) = self.message_receiver.write().await.take() {
921            handle.abort();
922        }
923
924        // Close transport
925        if let Some(mut transport) = self.transport.write().await.take() {
926            transport.close().await.map_err(|e| {
927                MCPError::Transport(TransportError::ConnectionFailed(e.to_string()))
928            })?;
929        }
930
931        {
932            let mut state = self.state_manager.write().await;
933            state.set_state(ClientState::Uninitialized);
934        }
935
936        info!("Client disconnected");
937        Ok(())
938    }
939
940    /// Get current client state
941    pub async fn get_state(&self) -> ClientState {
942        self.state_manager.read().await.state.clone()
943    }
944
945    /// Check if client can perform operations
946    pub async fn can_operate(&self) -> bool {
947        self.get_state().await.can_operate()
948    }
949
950    /// Get server information
951    pub async fn get_server_info(&self) -> Option<ServerInfo> {
952        self.state_manager.read().await.server_info.clone()
953    }
954
955    /// Get server capabilities
956    pub async fn get_server_capabilities(&self) -> Option<ServerCapabilities> {
957        self.state_manager.read().await.server_capabilities.clone()
958    }
959
960    /// Get negotiated protocol version
961    pub async fn get_negotiated_version(&self) -> Option<String> {
962        self.state_manager.read().await.negotiated_version.clone()
963    }
964
965    /// Get client information
966    pub fn info(&self) -> &ClientInfo {
967        &self.info
968    }
969
970    /// Check if server supports a specific capability
971    pub async fn check_server_capability(&self, capability: &str) -> MCPResult<bool> {
972        self.ensure_capability_supported(capability).await?;
973
974        if let Some(caps) = self.get_server_capabilities().await {
975            Ok(caps.supports_capability(capability))
976        } else {
977            Ok(false)
978        }
979    }
980
981    /// Check if server supports a specific feature within a capability
982    pub async fn check_server_feature(&self, capability: &str, feature: &str) -> MCPResult<bool> {
983        self.ensure_capability_supported(capability).await?;
984
985        if let Some(caps) = self.get_server_capabilities().await {
986            Ok(caps.supports_feature(capability, feature))
987        } else {
988            Ok(false)
989        }
990    }
991
992    async fn ensure_capability_supported(&self, _capability: &str) -> MCPResult<()> {
993        if !self.can_operate().await {
994            return Err(MCPError::Protocol(ProtocolError::InternalError(
995                "Client is not in operating state".to_string(),
996            )));
997        }
998        Ok(())
999    }
1000
1001    /// List available tools
1002    pub async fn list_tools(&self, request: ListToolsRequest) -> MCPResult<ListToolsResponse> {
1003        self.send_request("tools/list", Some(serde_json::to_value(request)?))
1004            .await
1005    }
1006
1007    /// List tools with default parameters
1008    pub async fn list_tools_default(&self) -> MCPResult<ListToolsResponse> {
1009        self.list_tools(ListToolsRequest::default()).await
1010    }
1011
1012    /// Call a tool
1013    pub async fn call_tool(&self, tool_call: ToolCall) -> MCPResult<ToolResult> {
1014        self.send_request("tools/call", Some(serde_json::to_value(tool_call)?))
1015            .await
1016    }
1017
1018    /// List available resources
1019    pub async fn list_resources(
1020        &self,
1021        request: ListResourcesRequest,
1022    ) -> MCPResult<ListResourcesResponse> {
1023        self.send_request("resources/list", Some(serde_json::to_value(request)?))
1024            .await
1025    }
1026
1027    /// Read a resource
1028    pub async fn read_resource(
1029        &self,
1030        request: ReadResourceRequest,
1031    ) -> MCPResult<ReadResourceResponse> {
1032        self.send_request("resources/read", Some(serde_json::to_value(request)?))
1033            .await
1034    }
1035
1036    /// Subscribe to resource changes
1037    pub async fn subscribe_resource(&self, uri: String) -> MCPResult<()> {
1038        let request = serde_json::json!({
1039            "uri": uri
1040        });
1041        self.send_notification("resources/subscribe", Some(request))
1042            .await
1043    }
1044
1045    /// List available prompts
1046    pub async fn list_prompts(
1047        &self,
1048        request: ListPromptsRequest,
1049    ) -> MCPResult<ListPromptsResponse> {
1050        self.send_request("prompts/list", Some(serde_json::to_value(request)?))
1051            .await
1052    }
1053
1054    /// Get a specific prompt
1055    pub async fn get_prompt(&self, request: GetPromptRequest) -> MCPResult<GetPromptResponse> {
1056        self.send_request("prompts/get", Some(serde_json::to_value(request)?))
1057            .await
1058    }
1059
1060    /// Create a message using sampling
1061    pub async fn create_message(
1062        &self,
1063        request: CreateMessageRequest,
1064    ) -> MCPResult<CreateMessageResponse> {
1065        self.send_request(
1066            "sampling/createMessage",
1067            Some(serde_json::to_value(request)?),
1068        )
1069        .await
1070    }
1071
1072    /// Complete a request
1073    pub async fn complete(&self, request: CompleteRequest) -> MCPResult<CompleteResponse> {
1074        self.send_request("completion/complete", Some(serde_json::to_value(request)?))
1075            .await
1076    }
1077
1078    /// Respond to elicitation request (called by client-side elicitation handler)
1079    pub async fn respond_to_elicitation(&self, response: ElicitationResponse) -> MCPResult<()> {
1080        self.send_request("elicitation/respond", Some(serde_json::to_value(response)?))
1081            .await
1082    }
1083
1084    /// List filesystem roots
1085    pub async fn list_roots(&self) -> MCPResult<Vec<ultrafast_mcp_core::types::roots::Root>> {
1086        self.send_request("roots/list", None).await
1087    }
1088
1089    /// Set log level
1090    pub async fn set_log_level(
1091        &self,
1092        level: ultrafast_mcp_core::types::notifications::LogLevel,
1093    ) -> MCPResult<()> {
1094        let request = serde_json::json!({
1095            "level": level
1096        });
1097        self.send_request("logging/setLevel", Some(request)).await
1098    }
1099
1100    /// Send ping
1101    pub async fn ping(
1102        &self,
1103        data: Option<serde_json::Value>,
1104    ) -> MCPResult<ultrafast_mcp_core::types::notifications::PingResponse> {
1105        self.send_request("ping", data).await
1106    }
1107
1108    /// Start periodic ping monitoring (optional, for connection health)
1109    pub async fn start_ping_monitoring(&self, ping_interval: std::time::Duration) -> MCPResult<()> {
1110        info!(
1111            "Starting periodic ping monitoring with interval: {:?}",
1112            ping_interval
1113        );
1114
1115        // Note: This is a placeholder for future implementation
1116        // The actual ping monitoring would need to be integrated with the transport layer
1117        // For now, we log that ping monitoring is enabled
1118        info!("Ping monitoring enabled (interval: {:?})", ping_interval);
1119
1120        // Future implementation would spawn a background task that:
1121        // 1. Clones the client's transport
1122        // 2. Sends periodic ping requests
1123        // 3. Handles timeouts and reconnection logic
1124
1125        Ok(())
1126    }
1127
1128    /// Stop ping monitoring
1129    pub async fn stop_ping_monitoring(&self) -> MCPResult<()> {
1130        info!("Stopping periodic ping monitoring");
1131        // The ping monitoring task will naturally stop when the client is disconnected
1132        Ok(())
1133    }
1134
1135    /// Notify cancellation
1136    pub async fn notify_cancelled(
1137        &self,
1138        request_id: serde_json::Value,
1139        reason: Option<String>,
1140    ) -> MCPResult<()> {
1141        let request = serde_json::json!({
1142            "requestId": request_id,
1143            "reason": reason
1144        });
1145        self.send_notification("$/cancelRequest", Some(request))
1146            .await
1147    }
1148
1149    /// Notify progress
1150    pub async fn notify_progress(
1151        &self,
1152        progress_token: serde_json::Value,
1153        progress: f64,
1154        total: Option<f64>,
1155        message: Option<String>,
1156    ) -> MCPResult<()> {
1157        let request = serde_json::json!({
1158            "token": progress_token,
1159            "progress": progress,
1160            "total": total,
1161            "message": message
1162        });
1163        self.send_notification("$/progress", Some(request)).await
1164    }
1165
1166    /// Check if progress notification should be sent based on timeout configuration
1167    pub fn should_send_progress(&self, last_progress: std::time::Instant) -> bool {
1168        // Use a default interval since progress_interval is not available in the new TimeoutConfig
1169        let progress_interval = std::time::Duration::from_secs(5);
1170        last_progress.elapsed() >= progress_interval
1171    }
1172
1173    /// Get progress interval from timeout configuration
1174    pub fn get_progress_interval(&self) -> std::time::Duration {
1175        // Return a default interval since progress_interval is not available in the new TimeoutConfig
1176        std::time::Duration::from_secs(5)
1177    }
1178
1179    async fn ensure_operational(&self) -> MCPResult<()> {
1180        let state = self.get_state().await;
1181        if !state.can_operate() {
1182            return Err(MCPError::Protocol(ProtocolError::InternalError(format!(
1183                "Client is not in operating state (current state: {state:?})"
1184            ))));
1185        }
1186        Ok(())
1187    }
1188
1189    async fn generate_request_id(&self) -> u64 {
1190        let mut state = self.state_manager.write().await;
1191        state.next_request_id()
1192    }
1193
1194    async fn send_request<T>(&self, method: &str, params: Option<Value>) -> MCPResult<T>
1195    where
1196        T: serde::de::DeserializeOwned,
1197    {
1198        // Allow initialize request even when not operational
1199        if method != "initialize" {
1200            self.ensure_operational().await?;
1201        }
1202
1203        let request_id = self.generate_request_id().await;
1204        let request = JsonRpcRequest::new(
1205            method.to_string(),
1206            params,
1207            Some(ultrafast_mcp_core::protocol::jsonrpc::RequestId::Number(
1208                request_id as i64,
1209            )),
1210        );
1211
1212        // Get operation-specific timeout
1213        let operation_timeout = self.get_operation_timeout(method);
1214
1215        // Create response channel
1216        let (response_sender, response_receiver) = oneshot::channel();
1217
1218        // Add to pending requests
1219        {
1220            let mut state = self.state_manager.write().await;
1221            state.add_pending_request(
1222                request_id,
1223                PendingRequest {
1224                    response_sender,
1225                    timeout: tokio::time::Instant::now() + operation_timeout,
1226                },
1227            );
1228        }
1229
1230        // Send request
1231        {
1232            let mut transport_guard = self.transport.write().await;
1233            let transport = transport_guard.as_mut().ok_or_else(|| {
1234                MCPError::Transport(TransportError::ConnectionFailed(
1235                    "Transport not available".to_string(),
1236                ))
1237            })?;
1238            transport
1239                .send_message(JsonRpcMessage::Request(request))
1240                .await
1241                .map_err(|e| MCPError::Transport(TransportError::SendFailed(e.to_string())))?;
1242        }
1243
1244        // Try to get immediate response from transport (for HTTP transport)
1245        let immediate_response = {
1246            let mut transport_guard = self.transport.write().await;
1247            let transport = transport_guard.as_mut().ok_or_else(|| {
1248                MCPError::Transport(TransportError::ConnectionFailed(
1249                    "Transport not available".to_string(),
1250                ))
1251            })?;
1252            transport.receive_message().await.ok()
1253        };
1254
1255        let response = if let Some(immediate) = immediate_response {
1256            // Got immediate response from transport
1257            immediate
1258        } else {
1259            // Wait for response through message receiver task
1260            tokio::time::timeout(operation_timeout, response_receiver)
1261                .await
1262                .map_err(|_| MCPError::Protocol(ProtocolError::RequestTimeout))?
1263                .map_err(|_| {
1264                    MCPError::Protocol(ProtocolError::InternalError(
1265                        "Response channel closed".to_string(),
1266                    ))
1267                })?
1268        };
1269
1270        // Remove from pending requests
1271        {
1272            let mut state = self.state_manager.write().await;
1273            state.remove_pending_request(&request_id);
1274        }
1275
1276        match response {
1277            JsonRpcMessage::Response(response) => {
1278                if let Some(error) = response.error {
1279                    return Err(MCPError::from(error));
1280                }
1281
1282                if let Some(result) = response.result {
1283                    serde_json::from_value(result).map_err(MCPError::Serialization)
1284                } else {
1285                    Err(MCPError::Protocol(ProtocolError::InvalidResponse(
1286                        "Response has no result or error".to_string(),
1287                    )))
1288                }
1289            }
1290            _ => Err(MCPError::Protocol(ProtocolError::InvalidResponse(
1291                "Expected response, got different message type".to_string(),
1292            ))),
1293        }
1294    }
1295
1296    async fn send_notification(&self, method: &str, params: Option<Value>) -> MCPResult<()> {
1297        // Allow initialized notification even when not operational
1298        if method != "initialized" {
1299            self.ensure_operational().await?;
1300        }
1301
1302        let notification = JsonRpcRequest::notification(method.to_string(), params);
1303
1304        let mut transport_guard = self.transport.write().await;
1305        let transport = transport_guard.as_mut().ok_or_else(|| {
1306            MCPError::Transport(TransportError::ConnectionFailed(
1307                "Transport not available".to_string(),
1308            ))
1309        })?;
1310
1311        transport
1312            .send_message(JsonRpcMessage::Notification(notification))
1313            .await
1314            .map_err(|e| MCPError::Transport(TransportError::SendFailed(e.to_string())))
1315    }
1316}
1317
1318impl std::fmt::Debug for UltraFastClient {
1319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1320        f.debug_struct("UltraFastClient")
1321            .field("info", &self.info)
1322            .field("capabilities", &self.capabilities)
1323            .field("state", &"<state_manager>")
1324            .field("transport", &"<transport>")
1325            .field("request_timeout", &self.request_timeout)
1326            .finish()
1327    }
1328}
1329
#[cfg(test)]
mod tests {
    use super::*;

    /// `ClientInfo` fixture shared by every test in this module.
    fn test_client_info() -> ClientInfo {
        ClientInfo {
            name: "test-client".to_string(),
            version: "1.0.0".to_string(),
            authors: None,
            description: None,
            homepage: None,
            repository: None,
            license: None,
        }
    }

    /// A freshly built client is `Uninitialized` and cannot operate yet.
    #[tokio::test]
    async fn test_client_creation() {
        let client = UltraFastClient::new(test_client_info(), ClientCapabilities::default());

        assert_eq!(client.get_state().await, ClientState::Uninitialized);
        assert!(!client.can_operate().await);
    }

    /// Supplying a custom timeout does not change the initial state.
    #[tokio::test]
    async fn test_client_creation_with_timeout() {
        let client = UltraFastClient::new_with_timeout(
            test_client_info(),
            ClientCapabilities::default(),
            std::time::Duration::from_secs(60),
        );

        assert_eq!(client.get_state().await, ClientState::Uninitialized);
    }

    /// State changes made through the state manager are visible via `get_state`.
    #[tokio::test]
    async fn test_client_state_transitions() {
        let client = UltraFastClient::new(test_client_info(), ClientCapabilities::default());
        assert_eq!(client.get_state().await, ClientState::Uninitialized);

        client
            .state_manager
            .write()
            .await
            .set_state(ClientState::Initializing);

        assert_eq!(client.get_state().await, ClientState::Initializing);
    }
}