turul_mcp_client/
client.rs

1//! Main MCP client implementation
2
3use serde_json::{Value, json};
4use std::sync::Arc;
5use tokio::time::timeout;
6use tracing::{debug, info, warn};
7
8use crate::config::ClientConfig;
9use crate::error::{McpClientError, McpClientResult, SessionError};
10use crate::session::{SessionManager, SessionState};
11use crate::streaming::StreamHandler;
12use crate::transport::{BoxedTransport, TransportFactory};
13
// Protocol types used by the client API (plain imports; nothing is re-exported here)
15use turul_mcp_protocol::meta::Cursor;
16use turul_mcp_protocol::{
17    CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult,
18    ListToolsResult, Prompt, ReadResourceResult, Resource, Tool, ToolResult,
19};
20
21/// Main MCP client
22pub struct McpClient {
23    /// Transport layer
24    transport: Arc<tokio::sync::Mutex<BoxedTransport>>,
25    /// Session manager
26    session: Arc<SessionManager>,
27    /// Configuration
28    config: ClientConfig,
29    /// Stream handler for server events
30    stream_handler: Arc<tokio::sync::Mutex<StreamHandler>>,
31    /// Request ID counter
32    request_counter: Arc<std::sync::atomic::AtomicU64>,
33}
34
35impl Drop for McpClient {
36    /// Automatic cleanup when client is dropped
37    ///
38    /// This ensures that if the client is dropped without explicit disconnect,
39    /// we still attempt to send a DELETE request to clean up the session on the server.
40    fn drop(&mut self) {
41        let session_id = self.session.clone();
42        let transport = self.transport.clone();
43
44        // Spawn a background task to handle cleanup
45        // We can't await in Drop, so we spawn a task that will complete cleanup
46        tokio::spawn(async move {
47            // Only send DELETE if we have a session ID
48            if let Some(session_id_str) = session_id.session_id_optional().await {
49                let mut transport_guard = transport.lock().await;
50
51                info!(
52                    session_id = session_id_str,
53                    "McpClient dropped - attempting session cleanup via DELETE request"
54                );
55
56                if let Err(e) = transport_guard.send_delete(&session_id_str).await {
57                    warn!(
58                        session_id = session_id_str,
59                        error = %e,
60                        "Failed to send DELETE request during Drop cleanup"
61                    );
62                } else {
63                    info!(
64                        session_id = session_id_str,
65                        "Successfully sent DELETE request during Drop cleanup"
66                    );
67                }
68            } else {
69                debug!("No session ID available, skipping DELETE request during Drop");
70            }
71
72            // Also terminate the session locally
73            session_id
74                .terminate(Some("Client dropped".to_string()))
75                .await;
76        });
77    }
78}
79
80impl McpClient {
81    /// Create a new MCP client with the given transport
82    pub fn new(transport: BoxedTransport, config: ClientConfig) -> Self {
83        let session = Arc::new(SessionManager::new(config.clone()));
84
85        Self {
86            transport: Arc::new(tokio::sync::Mutex::new(transport)),
87            session,
88            config,
89            stream_handler: Arc::new(tokio::sync::Mutex::new(StreamHandler::new())),
90            request_counter: Arc::new(std::sync::atomic::AtomicU64::new(0)),
91        }
92    }
93
94    /// Connect to the MCP server
95    pub async fn connect(&self) -> McpClientResult<()> {
96        info!("Connecting to MCP server");
97
98        // Connect transport
99        {
100            let mut transport = self.transport.lock().await;
101            transport.connect().await?;
102
103            // Start event listener if supported
104            if transport.capabilities().server_events {
105                let receiver = transport.start_event_listener().await?;
106                let mut stream_handler = self.stream_handler.lock().await;
107                stream_handler.set_receiver(receiver);
108                stream_handler.start().await?;
109            }
110        }
111
112        // Initialize session
113        self.initialize_session().await?;
114
115        info!("Successfully connected to MCP server");
116        Ok(())
117    }
118
119    /// Disconnect from the MCP server
120    pub async fn disconnect(&self) -> McpClientResult<()> {
121        info!("Disconnecting from MCP server");
122
123        // Send DELETE request for session cleanup if we have a session ID
124        if let Some(session_id) = self.session.session_id_optional().await {
125            let mut transport = self.transport.lock().await;
126            if let Err(e) = transport.send_delete(&session_id).await {
127                warn!(
128                    session_id = session_id,
129                    error = %e,
130                    "Failed to send DELETE request during disconnect - continuing with cleanup"
131                );
132            }
133        } else {
134            debug!("No session ID available, skipping DELETE request during disconnect");
135        }
136
137        // Terminate session locally
138        self.session
139            .terminate(Some("Client disconnect".to_string()))
140            .await;
141
142        // Disconnect transport
143        let mut transport = self.transport.lock().await;
144        transport.disconnect().await?;
145
146        info!("Disconnected from MCP server");
147        Ok(())
148    }
149
150    /// Check if client is connected and ready
151    pub async fn is_ready(&self) -> bool {
152        let transport_connected = {
153            let transport = self.transport.lock().await;
154            transport.is_connected()
155        };
156
157        let session_ready = self.session.is_ready().await;
158
159        transport_connected && session_ready
160    }
161
162    /// Get client connection status
163    pub async fn connection_status(&self) -> ConnectionStatus {
164        let transport_info = {
165            let transport = self.transport.lock().await;
166            transport.connection_info()
167        };
168
169        let session_stats = self.session.statistics().await;
170
171        ConnectionStatus {
172            transport_connected: transport_info.connected,
173            session_state: session_stats.state,
174            transport_type: transport_info.transport_type,
175            endpoint: transport_info.endpoint,
176            session_id: session_stats.session_id,
177            protocol_version: session_stats.protocol_version,
178        }
179    }
180
181    /// Initialize session with server
182    async fn initialize_session(&self) -> McpClientResult<()> {
183        info!("Initializing MCP session");
184
185        self.session.mark_initializing().await?;
186
187        let init_request = self.session.create_initialize_request().await;
188        let request_json = serde_json::to_value(&init_request).map_err(|e| {
189            McpClientError::generic(format!("Failed to serialize initialize request: {}", e))
190        })?;
191
192        let json_rpc_request = json!({
193            "jsonrpc": "2.0",
194            "method": "initialize",
195            "id": self.next_request_id(),
196            "params": request_json
197        });
198
199        // Send initialize request with timeout (need headers for session ID)
200        let response = timeout(
201            self.config.timeouts.initialization,
202            self.send_request_with_headers_internal(json_rpc_request),
203        )
204        .await
205        .map_err(|_| McpClientError::Timeout)?;
206
207        let transport_response = response?;
208
209        // Extract session ID from headers (MCP protocol) - case insensitive lookup
210        let session_id = transport_response
211            .headers
212            .iter()
213            .find(|(key, _)| key.to_lowercase() == "mcp-session-id")
214            .map(|(_, value)| value.clone());
215
216        if let Some(session_id) = session_id {
217            info!("Server provided session ID: {}", session_id);
218
219            // Store in session manager
220            self.session.set_session_id(session_id.clone()).await?;
221
222            // Tell transport to include session ID in all subsequent requests
223            let mut transport = self.transport.lock().await;
224            transport.set_session_id(session_id);
225        } else {
226            return Err(McpClientError::generic(
227                "Server did not provide Mcp-Session-Id header during initialization",
228            ));
229        }
230
231        // Parse initialize response
232        let init_response: InitializeResult = serde_json::from_value(
233            transport_response
234                .body
235                .get("result")
236                .cloned()
237                .unwrap_or(Value::Null),
238        )
239        .map_err(|e| {
240            McpClientError::generic(format!("Failed to parse initialize response: {}", e))
241        })?;
242
243        // Validate server capabilities
244        self.session
245            .validate_server_capabilities(&init_response.capabilities)
246            .await?;
247
248        // Complete session initialization
249        self.session
250            .initialize(
251                init_request.capabilities,
252                init_response.capabilities,
253                init_response.protocol_version,
254            )
255            .await?;
256
257        // Send initialized notification per MCP 2025-06-18 spec
258        let initialized_notification = json!({
259            "jsonrpc": "2.0",
260            "method": "notifications/initialized",
261            "params": {}
262        });
263
264        self.send_notification_internal(initialized_notification)
265            .await?;
266
267        info!("MCP session initialized successfully");
268        Ok(())
269    }
270
271    /// Generate next request ID
272    fn next_request_id(&self) -> String {
273        let counter = self
274            .request_counter
275            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
276        format!("req_{}", counter)
277    }
278
279    /// Send request and handle retries
280    async fn send_request_internal(&self, request: Value) -> McpClientResult<Value> {
281        let mut last_error = None;
282
283        for attempt in 0..self.config.retry.max_attempts {
284            if attempt > 0 {
285                let delay = self.config.retry.delay_for_attempt(attempt);
286                debug!(
287                    attempt = attempt,
288                    delay_ms = delay.as_millis(),
289                    "Retrying request"
290                );
291                tokio::time::sleep(delay).await;
292            }
293
294            match self.send_request_raw(request.clone()).await {
295                Ok(response) => {
296                    self.session.update_activity().await;
297                    return Ok(response);
298                }
299                Err(e) => {
300                    warn!(attempt = attempt, error = %e, "Request failed");
301
302                    if !e.is_retryable() || !self.config.retry.should_retry(attempt + 1) {
303                        return Err(e);
304                    }
305
306                    last_error = Some(e);
307                }
308            }
309        }
310
311        Err(last_error.unwrap_or_else(|| McpClientError::generic("All retry attempts failed")))
312    }
313
314    /// Send request with headers and handle retries (used for initialization)
315    async fn send_request_with_headers_internal(
316        &self,
317        request: Value,
318    ) -> McpClientResult<crate::transport::TransportResponse> {
319        let mut last_error = None;
320
321        for attempt in 0..self.config.retry.max_attempts {
322            if attempt > 0 {
323                let delay = self.config.retry.delay_for_attempt(attempt);
324                debug!(
325                    attempt = attempt,
326                    delay_ms = delay.as_millis(),
327                    "Retrying request with headers"
328                );
329                tokio::time::sleep(delay).await;
330            }
331
332            match self.send_request_with_headers_raw(request.clone()).await {
333                Ok(response) => return Ok(response),
334                Err(e) => {
335                    warn!(
336                        attempt = attempt,
337                        max_attempts = self.config.retry.max_attempts,
338                        error = %e,
339                        "Request with headers failed"
340                    );
341
342                    if !e.is_retryable() {
343                        return Err(e);
344                    }
345
346                    last_error = Some(e);
347                }
348            }
349        }
350
351        Err(last_error.unwrap_or_else(|| McpClientError::generic("All retry attempts failed")))
352    }
353
354    /// Send raw request with headers without retries
355    async fn send_request_with_headers_raw(
356        &self,
357        request: Value,
358    ) -> McpClientResult<crate::transport::TransportResponse> {
359        let mut transport = self.transport.lock().await;
360
361        timeout(
362            self.config.timeouts.request,
363            transport.send_request_with_headers(request),
364        )
365        .await
366        .map_err(|_| McpClientError::Timeout)?
367    }
368
369    /// Send raw request without retries
370    async fn send_request_raw(&self, request: Value) -> McpClientResult<Value> {
371        if !self.session.is_ready().await {
372            return Err(SessionError::NotInitialized.into());
373        }
374
375        let mut transport = self.transport.lock().await;
376
377        let response = timeout(
378            self.config.timeouts.request,
379            transport.send_request(request),
380        )
381        .await
382        .map_err(|_| McpClientError::Timeout)??;
383
384        // Check for JSON-RPC error
385        if let Some(error) = response.get("error") {
386            let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1) as i32;
387            let message = error
388                .get("message")
389                .and_then(|m| m.as_str())
390                .unwrap_or("Unknown error");
391            let data = error.get("data").cloned();
392
393            return Err(McpClientError::server_error(code, message, data));
394        }
395
396        Ok(response)
397    }
398
399    /// Send notification
400    async fn send_notification_internal(&self, notification: Value) -> McpClientResult<()> {
401        let mut transport = self.transport.lock().await;
402        transport.send_notification(notification).await?;
403        self.session.update_activity().await;
404        Ok(())
405    }
406
407    /// List available tools
408    pub async fn list_tools(&self) -> McpClientResult<Vec<Tool>> {
409        debug!("Listing tools");
410
411        let request = json!({
412            "jsonrpc": "2.0",
413            "method": "tools/list",
414            "id": self.next_request_id(),
415            "params": {}
416        });
417
418        let response = self.send_request_internal(request).await?;
419        let tools_response: ListToolsResult =
420            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
421
422        debug!(count = tools_response.tools.len(), "Retrieved tools");
423        Ok(tools_response.tools)
424    }
425
426    /// List available tools with pagination support
427    pub async fn list_tools_paginated(
428        &self,
429        cursor: Option<Cursor>,
430    ) -> McpClientResult<ListToolsResult> {
431        debug!("Listing tools with pagination");
432
433        let request_params = if let Some(cursor) = cursor {
434            json!({ "cursor": cursor.as_str() })
435        } else {
436            json!({})
437        };
438
439        let request = json!({
440            "jsonrpc": "2.0",
441            "method": "tools/list",
442            "id": self.next_request_id(),
443            "params": request_params
444        });
445
446        let response = self.send_request_internal(request).await?;
447        let tools_response: ListToolsResult =
448            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
449
450        debug!(
451            count = tools_response.tools.len(),
452            has_cursor = tools_response.next_cursor.is_some(),
453            "Retrieved tools with pagination"
454        );
455        Ok(tools_response)
456    }
457
458    /// Call a tool
459    pub async fn call_tool(
460        &self,
461        name: &str,
462        arguments: Value,
463    ) -> McpClientResult<Vec<ToolResult>> {
464        debug!(tool = name, "Calling tool");
465
466        let request = json!({
467            "jsonrpc": "2.0",
468            "method": "tools/call",
469            "id": self.next_request_id(),
470            "params": {
471                "name": name,
472                "arguments": arguments
473            }
474        });
475
476        let response = self.send_request_internal(request).await?;
477        let call_response: CallToolResult =
478            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
479
480        debug!(
481            tool = name,
482            is_error = call_response.is_error,
483            "Tool call completed"
484        );
485        Ok(call_response.content)
486    }
487
488    /// List available resources
489    pub async fn list_resources(&self) -> McpClientResult<Vec<Resource>> {
490        debug!("Listing resources");
491
492        let request = json!({
493            "jsonrpc": "2.0",
494            "method": "resources/list",
495            "id": self.next_request_id(),
496            "params": {}
497        });
498
499        let response = self.send_request_internal(request).await?;
500        let resources_response: ListResourcesResult =
501            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
502
503        debug!(
504            count = resources_response.resources.len(),
505            "Retrieved resources"
506        );
507        Ok(resources_response.resources)
508    }
509
510    /// List available resources with pagination support
511    pub async fn list_resources_paginated(
512        &self,
513        cursor: Option<Cursor>,
514    ) -> McpClientResult<ListResourcesResult> {
515        debug!("Listing resources with pagination");
516
517        let request_params = if let Some(cursor) = cursor {
518            json!({ "cursor": cursor.as_str() })
519        } else {
520            json!({})
521        };
522
523        let request = json!({
524            "jsonrpc": "2.0",
525            "method": "resources/list",
526            "id": self.next_request_id(),
527            "params": request_params
528        });
529
530        let response = self.send_request_internal(request).await?;
531        let resources_response: ListResourcesResult =
532            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
533
534        debug!(
535            count = resources_response.resources.len(),
536            has_cursor = resources_response.next_cursor.is_some(),
537            "Retrieved resources with pagination"
538        );
539        Ok(resources_response)
540    }
541
542    /// Read a resource
543    pub async fn read_resource(
544        &self,
545        uri: &str,
546    ) -> McpClientResult<Vec<turul_mcp_protocol_2025_06_18::ResourceContent>> {
547        debug!(uri = uri, "Reading resource");
548
549        let request = json!({
550            "jsonrpc": "2.0",
551            "method": "resources/read",
552            "id": self.next_request_id(),
553            "params": {
554                "uri": uri
555            }
556        });
557
558        let response = self.send_request_internal(request).await?;
559        let read_response: ReadResourceResult =
560            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
561
562        debug!(
563            uri = uri,
564            content_count = read_response.contents.len(),
565            "Resource read completed"
566        );
567        Ok(read_response.contents)
568    }
569
570    /// List available prompts
571    pub async fn list_prompts(&self) -> McpClientResult<Vec<Prompt>> {
572        debug!("Listing prompts");
573
574        let request = json!({
575            "jsonrpc": "2.0",
576            "method": "prompts/list",
577            "id": self.next_request_id(),
578            "params": {}
579        });
580
581        let response = self.send_request_internal(request).await?;
582        let prompts_response: ListPromptsResult =
583            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
584
585        debug!(count = prompts_response.prompts.len(), "Retrieved prompts");
586        Ok(prompts_response.prompts)
587    }
588
589    /// List available prompts with pagination support
590    pub async fn list_prompts_paginated(
591        &self,
592        cursor: Option<Cursor>,
593    ) -> McpClientResult<ListPromptsResult> {
594        debug!("Listing prompts with pagination");
595
596        let request_params = if let Some(cursor) = cursor {
597            json!({ "cursor": cursor.as_str() })
598        } else {
599            json!({})
600        };
601
602        let request = json!({
603            "jsonrpc": "2.0",
604            "method": "prompts/list",
605            "id": self.next_request_id(),
606            "params": request_params
607        });
608
609        let response = self.send_request_internal(request).await?;
610        let prompts_response: ListPromptsResult =
611            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
612
613        debug!(
614            count = prompts_response.prompts.len(),
615            has_cursor = prompts_response.next_cursor.is_some(),
616            "Retrieved prompts with pagination"
617        );
618        Ok(prompts_response)
619    }
620
621    /// Get a prompt
622    pub async fn get_prompt(
623        &self,
624        name: &str,
625        arguments: Option<Value>,
626    ) -> McpClientResult<Vec<turul_mcp_protocol_2025_06_18::PromptMessage>> {
627        debug!(prompt = name, "Getting prompt");
628
629        let mut params = json!({
630            "name": name
631        });
632
633        if let Some(args) = arguments {
634            params["arguments"] = args;
635        }
636
637        let request = json!({
638            "jsonrpc": "2.0",
639            "method": "prompts/get",
640            "id": self.next_request_id(),
641            "params": params
642        });
643
644        let response = self.send_request_internal(request).await?;
645        let prompt_response: GetPromptResult =
646            serde_json::from_value(response.get("result").cloned().unwrap_or(Value::Null))?;
647
648        debug!(
649            prompt = name,
650            message_count = prompt_response.messages.len(),
651            "Prompt retrieved"
652        );
653        Ok(prompt_response.messages)
654    }
655
656    /// Send a ping to test connectivity
657    pub async fn ping(&self) -> McpClientResult<()> {
658        debug!("Sending ping");
659
660        let request = json!({
661            "jsonrpc": "2.0",
662            "method": "ping",
663            "id": self.next_request_id(),
664            "params": {}
665        });
666
667        self.send_request_internal(request).await?;
668        debug!("Ping successful");
669        Ok(())
670    }
671
672    /// Get stream handler for event callbacks
673    pub async fn stream_handler(&self) -> tokio::sync::MutexGuard<'_, StreamHandler> {
674        self.stream_handler.lock().await
675    }
676
677    /// Get session information
678    pub async fn session_info(&self) -> crate::session::SessionInfo {
679        self.session.session_info().await
680    }
681
682    /// Get transport statistics
683    pub async fn transport_stats(&self) -> crate::transport::TransportStatistics {
684        let transport = self.transport.lock().await;
685        transport.statistics()
686    }
687}
688
689/// Connection status information
690#[derive(Debug, Clone)]
691pub struct ConnectionStatus {
692    pub transport_connected: bool,
693    pub session_state: SessionState,
694    pub transport_type: crate::transport::TransportType,
695    pub endpoint: String,
696    pub session_id: Option<String>,
697    pub protocol_version: Option<String>,
698}
699
700impl ConnectionStatus {
701    /// Check if fully connected and ready
702    pub fn is_ready(&self) -> bool {
703        self.transport_connected && matches!(self.session_state, SessionState::Active)
704    }
705
706    /// Get status summary
707    pub fn summary(&self) -> String {
708        let session_display = match &self.session_id {
709            Some(id) => &id[..id.len().min(8)],
710            None => "None",
711        };
712        format!(
713            "{} transport to {} - Session {} ({})",
714            self.transport_type, self.endpoint, session_display, self.session_state
715        )
716    }
717}
718
719/// Builder for creating MCP clients
720pub struct McpClientBuilder {
721    transport: Option<BoxedTransport>,
722    config: Option<ClientConfig>,
723}
724
725impl McpClientBuilder {
726    /// Create a new client builder
727    pub fn new() -> Self {
728        Self {
729            transport: None,
730            config: None,
731        }
732    }
733
734    /// Set transport
735    pub fn with_transport(mut self, transport: BoxedTransport) -> Self {
736        self.transport = Some(transport);
737        self
738    }
739
740    /// Set transport from URL
741    pub fn with_url(mut self, url: &str) -> McpClientResult<Self> {
742        let transport = TransportFactory::from_url(url)?;
743        self.transport = Some(transport);
744        Ok(self)
745    }
746
747    /// Set configuration
748    pub fn with_config(mut self, config: ClientConfig) -> Self {
749        self.config = Some(config);
750        self
751    }
752
753    /// Build the client
754    pub fn build(self) -> McpClient {
755        let transport = self
756            .transport
757            .expect("Transport must be set before building client");
758        let config = self.config.unwrap_or_default();
759
760        McpClient::new(transport, config)
761    }
762}
763
764impl Default for McpClientBuilder {
765    fn default() -> Self {
766        Self::new()
767    }
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::http::HttpTransport;

    const TEST_ENDPOINT: &str = "http://localhost:8080/mcp";

    /// A client that has never connected must not report itself ready.
    #[tokio::test]
    async fn test_client_builder() {
        let http = HttpTransport::new(TEST_ENDPOINT).unwrap();
        let client = McpClientBuilder::default()
            .with_transport(Box::new(http))
            .build();

        assert!(!client.is_ready().await);
    }

    /// An active session over a connected transport is ready, and the
    /// summary names the transport type.
    #[test]
    fn test_connection_status() {
        let status = ConnectionStatus {
            transport_connected: true,
            session_state: SessionState::Active,
            transport_type: crate::transport::TransportType::Http,
            endpoint: TEST_ENDPOINT.to_string(),
            session_id: Some("session123".to_string()),
            protocol_version: Some("2025-06-18".to_string()),
        };

        assert!(status.is_ready());
        let summary = status.summary();
        assert!(summary.contains("HTTP transport"), "unexpected summary: {summary}");
    }
}