Skip to main content

tl_mcp/
client.rs

1//! MCP Client — blocking wrapper around rmcp's async MCP client.
2//!
3//! [`McpClient`] connects to MCP servers as subprocesses over stdio,
4//! performing the 3-step handshake (initialize -> response -> initialized)
5//! automatically via rmcp's [`ServiceExt::serve()`].
6//!
7//! All public methods are blocking — they use an internal tokio runtime
8//! with [`Runtime::block_on()`] to bridge async rmcp calls into sync TL land.
9
10use std::sync::Arc;
11
12use rmcp::model::{
13    CallToolRequestParams, CallToolResult, ClientCapabilities, ClientInfo,
14    CreateMessageRequestParams, CreateMessageResult, ErrorData, GetPromptRequestParams,
15    GetPromptResult, Implementation, ReadResourceRequestParams, ReadResourceResult, Role,
16    SamplingCapability, SamplingMessage, SamplingMessageContent, ServerInfo, Tool,
17};
18use rmcp::service::{RequestContext, RoleClient, RunningService};
19use rmcp::transport::TokioChildProcess;
20use rmcp::{ClientHandler, ServiceExt};
21use tl_errors::security::SecurityPolicy;
22
23use crate::error::McpError;
24
25// ---------------------------------------------------------------------------
26// Timeout constants
27// ---------------------------------------------------------------------------
28
/// Upper bound on the connect phase: setting up the transport and completing
/// the MCP handshake.
const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

/// Upper bound on a single tool call; generous because a tool may do
/// substantial work before it replies.
const TOOL_CALL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);

/// Upper bound on lightweight requests (ping, list, read, get-prompt).
const METADATA_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
37
38// ---------------------------------------------------------------------------
39// Sampling types
40// ---------------------------------------------------------------------------
41
/// Request for an LLM completion, issued by an MCP server.
///
/// Built exclusively from primitive/std types so that tl-mcp does not need to
/// depend on tl-ai. Derives `PartialEq` so requests can be compared directly
/// (e.g. with `assert_eq!` in tests); `Eq` is not derivable because of the
/// floating-point `temperature` field.
#[derive(Debug, Clone, PartialEq)]
pub struct SamplingRequest {
    /// Conversation so far, as `(role, content)` pairs in order.
    pub messages: Vec<(String, String)>,
    /// Optional system prompt to guide model behavior.
    pub system_prompt: Option<String>,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    /// Temperature for controlling randomness (nominally 0.0 to 1.0), if given.
    pub temperature: Option<f64>,
    /// Hint for which model to use (e.g. "claude-sonnet-4-20250514").
    pub model_hint: Option<String>,
    /// Sequences that should stop generation, if any were requested.
    pub stop_sequences: Option<Vec<String>>,
}
60
/// Result of an LLM completion, returned by a sampling callback.
///
/// Derives `PartialEq`/`Eq` so responses can be compared directly (e.g. with
/// `assert_eq!` in tests); all fields are plain strings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SamplingResponse {
    /// Name of the model that produced this response.
    pub model: String,
    /// The generated text.
    pub content: String,
    /// Why generation stopped (e.g. "endTurn", "maxTokens"), when known.
    pub stop_reason: Option<String>,
}
71
72/// Callback type for handling sampling requests.
73///
74/// MCP servers can request LLM completions from the client via the
75/// `sampling/createMessage` method. This callback bridges that request
76/// to whatever LLM backend the host provides (e.g. tl-ai).
77pub type SamplingCallback =
78    Arc<dyn Fn(SamplingRequest) -> Result<SamplingResponse, String> + Send + Sync>;
79
80// ---------------------------------------------------------------------------
81// Client handler
82// ---------------------------------------------------------------------------
83
84/// MCP client handler for TL.
85///
86/// Provides client identification via `get_info()` and optionally handles
87/// `sampling/createMessage` requests from the server when a [`SamplingCallback`]
88/// is configured.
89pub struct TlClientHandler {
90    /// Optional callback for handling sampling requests from the server.
91    pub(crate) sampling_callback: Option<SamplingCallback>,
92}
93
94impl std::fmt::Debug for TlClientHandler {
95    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
96        f.debug_struct("TlClientHandler")
97            .field("has_sampling", &self.sampling_callback.is_some())
98            .finish()
99    }
100}
101
102impl TlClientHandler {
103    /// Create a new handler with no sampling support.
104    pub fn new() -> Self {
105        Self {
106            sampling_callback: None,
107        }
108    }
109
110    /// Configure a sampling callback for handling `sampling/createMessage`.
111    pub fn with_sampling(mut self, cb: SamplingCallback) -> Self {
112        self.sampling_callback = Some(cb);
113        self
114    }
115}
116
117impl Default for TlClientHandler {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl ClientHandler for TlClientHandler {
124    fn get_info(&self) -> ClientInfo {
125        let mut caps = ClientCapabilities::default();
126        if self.sampling_callback.is_some() {
127            caps.sampling = Some(SamplingCapability::default());
128        }
129        ClientInfo::new(
130            caps,
131            Implementation::new("tl", env!("CARGO_PKG_VERSION"))
132                .with_title("ThinkingLanguage MCP Client"),
133        )
134    }
135
136    fn create_message(
137        &self,
138        params: CreateMessageRequestParams,
139        _context: RequestContext<RoleClient>,
140    ) -> impl Future<Output = Result<CreateMessageResult, ErrorData>> + Send + '_ {
141        let result = match &self.sampling_callback {
142            Some(cb) => {
143                // Convert SamplingMessage list to (role, content) pairs
144                let messages: Vec<(String, String)> = params
145                    .messages
146                    .iter()
147                    .map(|m| {
148                        let role = match m.role {
149                            Role::User => "user".to_string(),
150                            Role::Assistant => "assistant".to_string(),
151                        };
152                        // Extract text from content (may be Single or Multiple)
153                        let content: String = m
154                            .content
155                            .iter()
156                            .filter_map(|c| c.as_text().map(|t| t.text.as_str()))
157                            .collect::<Vec<_>>()
158                            .join("");
159                        (role, content)
160                    })
161                    .collect();
162
163                // Extract model hint from model_preferences
164                let model_hint = params
165                    .model_preferences
166                    .as_ref()
167                    .and_then(|p| p.hints.as_ref())
168                    .and_then(|h| h.first())
169                    .and_then(|h| h.name.clone());
170
171                let req = SamplingRequest {
172                    messages,
173                    system_prompt: params.system_prompt.clone(),
174                    max_tokens: params.max_tokens,
175                    temperature: params.temperature.map(|t| t as f64),
176                    model_hint,
177                    stop_sequences: params.stop_sequences.clone(),
178                };
179
180                match cb(req) {
181                    Ok(resp) => {
182                        let mut result = CreateMessageResult::new(
183                            SamplingMessage::new(
184                                Role::Assistant,
185                                SamplingMessageContent::text(resp.content),
186                            ),
187                            resp.model,
188                        );
189                        if let Some(reason) = resp.stop_reason {
190                            result = result.with_stop_reason(reason);
191                        }
192                        Ok(result)
193                    }
194                    Err(e) => Err(ErrorData::internal_error(e, None)),
195                }
196            }
197            None => Err(ErrorData::method_not_found::<
198                rmcp::model::CreateMessageRequestMethod,
199            >()),
200        };
201        std::future::ready(result)
202    }
203}
204
205// ---------------------------------------------------------------------------
206// McpClient
207// ---------------------------------------------------------------------------
208
209/// A blocking MCP client that connects to servers over stdio subprocess.
210///
211/// Wraps rmcp's async [`RunningService`] with a tokio runtime so all
212/// operations can be called from synchronous TL code.
213///
214/// # Example (conceptual — requires a real MCP server binary)
215/// ```ignore
216/// let client = McpClient::connect("npx", &["-y".into(), "@modelcontextprotocol/server-filesystem".into(), "/tmp".into()], None)?;
217/// let tools = client.list_tools()?;
218/// println!("Available tools: {}", tools.len());
219/// ```
220pub struct McpClient {
221    /// Shared tokio runtime for async operations.
222    runtime: Arc<tokio::runtime::Runtime>,
223    /// The running rmcp service (handles message routing internally).
224    service: Option<RunningService<RoleClient, TlClientHandler>>,
225    /// Cached server info from the handshake.
226    server_info: Option<ServerInfo>,
227}
228
229impl std::fmt::Debug for McpClient {
230    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231        f.debug_struct("McpClient")
232            .field("connected", &self.is_connected())
233            .field("server_info", &self.server_info)
234            .finish()
235    }
236}
237
238impl McpClient {
239    /// Connect to an MCP server by spawning a subprocess.
240    ///
241    /// 1. Checks [`SecurityPolicy`] (if provided) — denies if subprocess
242    ///    execution or the specific command is blocked.
243    /// 2. Spawns the command as a child process with stdio piped.
244    /// 3. Performs the MCP 3-step handshake via rmcp's `ServiceExt::serve()`.
245    /// 4. Caches the server info from the handshake response.
246    ///
247    /// # Arguments
248    /// * `command` — The executable to spawn (e.g. `"npx"`, `"node"`).
249    /// * `args` — Arguments to pass to the executable.
250    /// * `security_policy` — Optional policy to enforce subprocess restrictions.
251    ///
252    /// # Errors
253    /// * [`McpError::PermissionDenied`] — SecurityPolicy blocked the command.
254    /// * [`McpError::ConnectionFailed`] — Could not spawn or handshake.
255    /// * [`McpError::RuntimeError`] — Could not create tokio runtime.
256    pub fn connect(
257        command: &str,
258        args: &[String],
259        security_policy: Option<&SecurityPolicy>,
260    ) -> Result<Self, McpError> {
261        Self::connect_with_sampling(command, args, security_policy, None)
262    }
263
264    /// Connect to an MCP server by spawning a subprocess, with optional
265    /// sampling callback for handling `sampling/createMessage` requests.
266    ///
267    /// 1. Checks [`SecurityPolicy`] (if provided) — denies if subprocess
268    ///    execution or the specific command is blocked.
269    /// 2. Spawns the command as a child process with stdio piped.
270    /// 3. Performs the MCP 3-step handshake via rmcp's `ServiceExt::serve()`.
271    /// 4. Caches the server info from the handshake response.
272    ///
273    /// # Arguments
274    /// * `command` — The executable to spawn (e.g. `"npx"`, `"node"`).
275    /// * `args` — Arguments to pass to the executable.
276    /// * `security_policy` — Optional policy to enforce subprocess restrictions.
277    /// * `sampling_cb` — Optional callback for LLM sampling requests from the server.
278    pub fn connect_with_sampling(
279        command: &str,
280        args: &[String],
281        security_policy: Option<&SecurityPolicy>,
282        sampling_cb: Option<SamplingCallback>,
283    ) -> Result<Self, McpError> {
284        // --- Security check ---
285        if let Some(policy) = security_policy
286            && !policy.check_command(command)
287        {
288            return Err(McpError::PermissionDenied(format!(
289                "Command '{}' is not allowed by security policy",
290                command
291            )));
292        }
293
294        // --- Create tokio runtime ---
295        let runtime = tokio::runtime::Builder::new_multi_thread()
296            .enable_all()
297            .build()
298            .map_err(|e| McpError::RuntimeError(e.to_string()))?;
299        let runtime = Arc::new(runtime);
300
301        // --- Build handler ---
302        let handler = match sampling_cb {
303            Some(cb) => TlClientHandler::new().with_sampling(cb),
304            None => TlClientHandler::new(),
305        };
306
307        // --- Spawn subprocess and perform handshake ---
308        let (service, server_info) = runtime.block_on(async {
309            // Build the tokio Command
310            let mut cmd = tokio::process::Command::new(command);
311            cmd.args(args);
312
313            // Spawn via TokioChildProcess (handles piped stdio + framing)
314            let transport = TokioChildProcess::new(cmd).map_err(|e| {
315                McpError::ConnectionFailed(format!("Failed to spawn '{}': {}", command, e))
316            })?;
317
318            // Perform 3-step handshake with timeout
319            match tokio::time::timeout(CONNECT_TIMEOUT, handler.serve(transport)).await {
320                Ok(Ok(service)) => {
321                    let server_info = service.peer().peer_info().cloned();
322                    Ok::<_, McpError>((service, server_info))
323                }
324                Ok(Err(e)) => Err(McpError::ConnectionFailed(format!(
325                    "Handshake failed: {}",
326                    e
327                ))),
328                Err(_) => Err(McpError::Timeout),
329            }
330        })?;
331
332        Ok(McpClient {
333            runtime,
334            service: Some(service),
335            server_info,
336        })
337    }
338
339    /// Connect to an MCP server using an existing tokio runtime.
340    ///
341    /// Same as [`connect()`](Self::connect) but shares a runtime with the caller
342    /// (e.g. the VM's async runtime).
343    pub fn connect_with_runtime(
344        command: &str,
345        args: &[String],
346        security_policy: Option<&SecurityPolicy>,
347        runtime: Arc<tokio::runtime::Runtime>,
348    ) -> Result<Self, McpError> {
349        Self::connect_with_runtime_and_sampling(command, args, security_policy, runtime, None)
350    }
351
352    /// Connect to an MCP server using an existing tokio runtime and optional sampling.
353    pub fn connect_with_runtime_and_sampling(
354        command: &str,
355        args: &[String],
356        security_policy: Option<&SecurityPolicy>,
357        runtime: Arc<tokio::runtime::Runtime>,
358        sampling_cb: Option<SamplingCallback>,
359    ) -> Result<Self, McpError> {
360        // --- Security check ---
361        if let Some(policy) = security_policy
362            && !policy.check_command(command)
363        {
364            return Err(McpError::PermissionDenied(format!(
365                "Command '{}' is not allowed by security policy",
366                command
367            )));
368        }
369
370        // --- Build handler ---
371        let handler = match sampling_cb {
372            Some(cb) => TlClientHandler::new().with_sampling(cb),
373            None => TlClientHandler::new(),
374        };
375
376        // --- Spawn subprocess and perform handshake ---
377        let (service, server_info) = runtime.block_on(async {
378            let mut cmd = tokio::process::Command::new(command);
379            cmd.args(args);
380
381            let transport = TokioChildProcess::new(cmd).map_err(|e| {
382                McpError::ConnectionFailed(format!("Failed to spawn '{}': {}", command, e))
383            })?;
384
385            // Perform 3-step handshake with timeout
386            match tokio::time::timeout(CONNECT_TIMEOUT, handler.serve(transport)).await {
387                Ok(Ok(service)) => {
388                    let server_info = service.peer().peer_info().cloned();
389                    Ok::<_, McpError>((service, server_info))
390                }
391                Ok(Err(e)) => Err(McpError::ConnectionFailed(format!(
392                    "Handshake failed: {}",
393                    e
394                ))),
395                Err(_) => Err(McpError::Timeout),
396            }
397        })?;
398
399        Ok(McpClient {
400            runtime,
401            service: Some(service),
402            server_info,
403        })
404    }
405
406    /// Connect to a remote MCP server over HTTP (Streamable HTTP transport).
407    ///
408    /// Creates a new tokio runtime internally. For sharing an existing runtime,
409    /// use [`connect_http_with_runtime()`](Self::connect_http_with_runtime).
410    ///
411    /// # Arguments
412    /// * `url` — The HTTP(S) URL of the MCP server endpoint (e.g. `"http://localhost:8080/mcp"`).
413    ///
414    /// # Errors
415    /// * [`McpError::RuntimeError`] — Could not create tokio runtime.
416    /// * [`McpError::ConnectionFailed`] — HTTP connection or MCP handshake failed.
417    pub fn connect_http(url: &str) -> Result<Self, McpError> {
418        Self::connect_http_with_sampling(url, None)
419    }
420
421    /// Connect to a remote MCP server over HTTP with optional sampling callback.
422    pub fn connect_http_with_sampling(
423        url: &str,
424        sampling_cb: Option<SamplingCallback>,
425    ) -> Result<Self, McpError> {
426        let rt = Arc::new(
427            tokio::runtime::Builder::new_multi_thread()
428                .enable_all()
429                .build()
430                .map_err(|e| McpError::RuntimeError(format!("Failed to create runtime: {e}")))?,
431        );
432        Self::connect_http_with_runtime_and_sampling(url, rt, sampling_cb)
433    }
434
435    /// Connect to a remote MCP server over HTTP using an existing tokio runtime.
436    ///
437    /// # Arguments
438    /// * `url` — The HTTP(S) URL of the MCP server endpoint.
439    /// * `runtime` — A shared tokio runtime to use for async operations.
440    pub fn connect_http_with_runtime(
441        url: &str,
442        runtime: Arc<tokio::runtime::Runtime>,
443    ) -> Result<Self, McpError> {
444        Self::connect_http_with_runtime_and_sampling(url, runtime, None)
445    }
446
447    /// Connect to a remote MCP server over HTTP with runtime and optional sampling.
448    pub fn connect_http_with_runtime_and_sampling(
449        url: &str,
450        runtime: Arc<tokio::runtime::Runtime>,
451        sampling_cb: Option<SamplingCallback>,
452    ) -> Result<Self, McpError> {
453        let url_str = url.to_string();
454        let handler = match sampling_cb {
455            Some(cb) => TlClientHandler::new().with_sampling(cb),
456            None => TlClientHandler::new(),
457        };
458        let (service, server_info) = runtime.block_on(async {
459            use rmcp::transport::StreamableHttpClientTransport;
460
461            let transport = StreamableHttpClientTransport::from_uri(url_str);
462            match tokio::time::timeout(CONNECT_TIMEOUT, handler.serve(transport)).await {
463                Ok(Ok(service)) => {
464                    let info = service.peer_info().cloned();
465                    Ok::<_, McpError>((service, info))
466                }
467                Ok(Err(e)) => Err(McpError::ConnectionFailed(format!(
468                    "HTTP connect failed: {e}"
469                ))),
470                Err(_) => Err(McpError::Timeout),
471            }
472        })?;
473
474        Ok(McpClient {
475            runtime,
476            service: Some(service),
477            server_info,
478        })
479    }
480
481    // -----------------------------------------------------------------------
482    // Operations
483    // -----------------------------------------------------------------------
484
485    /// List all tools exposed by the connected MCP server.
486    ///
487    /// Uses `list_all_tools()` which automatically handles pagination.
488    /// Times out after [`METADATA_TIMEOUT`] (10 seconds).
489    pub fn list_tools(&self) -> Result<Vec<Tool>, McpError> {
490        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
491        self.runtime.block_on(async {
492            match tokio::time::timeout(METADATA_TIMEOUT, service.peer().list_all_tools()).await {
493                Ok(Ok(tools)) => Ok(tools),
494                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
495                Err(_) => Err(McpError::Timeout),
496            }
497        })
498    }
499
500    /// Call a tool on the connected MCP server.
501    ///
502    /// # Arguments
503    /// * `name` — The tool name (must match one from `list_tools()`).
504    /// * `arguments` — JSON value with the tool arguments. Must be a JSON object
505    ///   or null/None. Non-object values are rejected.
506    ///
507    /// # Returns
508    /// The [`CallToolResult`] from the server. If the server sets `is_error`,
509    /// this method returns `Err(McpError::ToolError)` with the content text.
510    pub fn call_tool(
511        &self,
512        name: &str,
513        arguments: serde_json::Value,
514    ) -> Result<CallToolResult, McpError> {
515        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
516
517        // Convert Value to Option<JsonObject> (rmcp expects Map, not arbitrary Value)
518        let args_map = match arguments {
519            serde_json::Value::Object(map) => Some(map),
520            serde_json::Value::Null => None,
521            other => {
522                return Err(McpError::ProtocolError(format!(
523                    "Tool arguments must be a JSON object, got: {}",
524                    other
525                )));
526            }
527        };
528
529        let mut params = CallToolRequestParams::new(name.to_string());
530        if let Some(map) = args_map {
531            params = params.with_arguments(map);
532        }
533
534        let result = self.runtime.block_on(async {
535            match tokio::time::timeout(TOOL_CALL_TIMEOUT, service.peer().call_tool(params)).await {
536                Ok(Ok(r)) => Ok(r),
537                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
538                Err(_) => Err(McpError::Timeout),
539            }
540        })?;
541
542        // Check is_error flag
543        if result.is_error == Some(true) {
544            // Extract text content for the error message
545            let error_text: String = result
546                .content
547                .iter()
548                .filter_map(|c| c.raw.as_text().map(|t| t.text.as_str()))
549                .collect::<Vec<_>>()
550                .join("\n");
551            return Err(McpError::ToolError(if error_text.is_empty() {
552                "Tool returned an error".to_string()
553            } else {
554                error_text
555            }));
556        }
557
558        Ok(result)
559    }
560
561    /// Ping the connected MCP server.
562    ///
563    /// Sends a ping request and waits for a response. Useful for health checks.
564    /// Times out after [`METADATA_TIMEOUT`] (10 seconds).
565    pub fn ping(&self) -> Result<(), McpError> {
566        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
567        self.runtime.block_on(async {
568            let ping_fut = service
569                .peer()
570                .send_request(rmcp::model::ClientRequest::PingRequest(
571                    rmcp::model::PingRequest {
572                        method: Default::default(),
573                        extensions: Default::default(),
574                    },
575                ));
576            match tokio::time::timeout(METADATA_TIMEOUT, ping_fut).await {
577                Ok(Ok(_)) => Ok(()),
578                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
579                Err(_) => Err(McpError::Timeout),
580            }
581        })
582    }
583
584    /// List all resources exposed by the connected MCP server.
585    ///
586    /// Uses `list_all_resources()` which automatically handles pagination.
587    /// Times out after [`METADATA_TIMEOUT`] (10 seconds).
588    pub fn list_resources(&self) -> Result<Vec<rmcp::model::Resource>, McpError> {
589        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
590        self.runtime.block_on(async {
591            match tokio::time::timeout(METADATA_TIMEOUT, service.peer().list_all_resources()).await
592            {
593                Ok(Ok(resources)) => Ok(resources),
594                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
595                Err(_) => Err(McpError::Timeout),
596            }
597        })
598    }
599
600    /// Read a resource by URI.
601    ///
602    /// Returns the resource contents (text or blob).
603    /// Times out after [`METADATA_TIMEOUT`] (10 seconds).
604    pub fn read_resource(&self, uri: &str) -> Result<ReadResourceResult, McpError> {
605        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
606        let params = ReadResourceRequestParams::new(uri);
607        self.runtime.block_on(async {
608            match tokio::time::timeout(METADATA_TIMEOUT, service.peer().read_resource(params)).await
609            {
610                Ok(Ok(result)) => Ok(result),
611                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
612                Err(_) => Err(McpError::Timeout),
613            }
614        })
615    }
616
617    /// List all prompts exposed by the connected MCP server.
618    ///
619    /// Uses `list_all_prompts()` which automatically handles pagination.
620    /// Times out after [`METADATA_TIMEOUT`] (10 seconds).
621    pub fn list_prompts(&self) -> Result<Vec<rmcp::model::Prompt>, McpError> {
622        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
623        self.runtime.block_on(async {
624            match tokio::time::timeout(METADATA_TIMEOUT, service.peer().list_all_prompts()).await {
625                Ok(Ok(prompts)) => Ok(prompts),
626                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
627                Err(_) => Err(McpError::Timeout),
628            }
629        })
630    }
631
632    /// Get a prompt by name with optional arguments.
633    ///
634    /// Returns the prompt result containing description and messages.
635    /// Times out after [`METADATA_TIMEOUT`] (10 seconds).
636    pub fn get_prompt(
637        &self,
638        name: &str,
639        arguments: Option<serde_json::Map<String, serde_json::Value>>,
640    ) -> Result<GetPromptResult, McpError> {
641        let service = self.service.as_ref().ok_or(McpError::TransportClosed)?;
642        let mut params = GetPromptRequestParams::new(name);
643        if let Some(args) = arguments {
644            params.arguments = Some(args);
645        }
646        self.runtime.block_on(async {
647            match tokio::time::timeout(METADATA_TIMEOUT, service.peer().get_prompt(params)).await {
648                Ok(Ok(result)) => Ok(result),
649                Ok(Err(e)) => Err(McpError::ProtocolError(e.to_string())),
650                Err(_) => Err(McpError::Timeout),
651            }
652        })
653    }
654
655    /// Return cached server info from the handshake.
656    ///
657    /// Contains the server's name, version, capabilities, and protocol version.
658    pub fn server_info(&self) -> Option<&ServerInfo> {
659        self.server_info.as_ref()
660    }
661
662    /// Gracefully disconnect from the MCP server.
663    ///
664    /// Cancels the rmcp service which triggers transport close and child
665    /// process cleanup.
666    pub fn disconnect(&mut self) -> Result<(), McpError> {
667        if let Some(service) = self.service.take() {
668            self.runtime.block_on(async {
669                // cancel() consumes the service and triggers graceful shutdown
670                let _ = service.cancel().await;
671            });
672        }
673        Ok(())
674    }
675
676    /// Check whether the MCP connection is still alive.
677    pub fn is_connected(&self) -> bool {
678        self.service
679            .as_ref()
680            .map(|s| !s.is_closed())
681            .unwrap_or(false)
682    }
683}
684
685impl Drop for McpClient {
686    fn drop(&mut self) {
687        // Best-effort cleanup: cancel the service if still running.
688        // RunningService's own DropGuard will also cancel via CancellationToken,
689        // but we do it explicitly to ensure the runtime processes the shutdown.
690        if let Some(service) = self.service.take() {
691            // We cannot block_on inside Drop if the runtime is being dropped too,
692            // so we spawn a fire-and-forget task.
693            let rt = self.runtime.clone();
694            // Use a separate thread to avoid panic if we're already in an async context.
695            std::thread::spawn(move || {
696                rt.block_on(async {
697                    let _ = service.cancel().await;
698                });
699            });
700        }
701    }
702}
703
704// ---------------------------------------------------------------------------
705// Tests
706// ---------------------------------------------------------------------------
707
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that a connect attempt got past the security check and then
    /// failed while connecting.
    fn assert_connection_failed(result: Result<McpClient, McpError>) {
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, McpError::ConnectionFailed(_)),
            "Expected ConnectionFailed, got: {:?}",
            err
        );
    }

    /// Asserts that a connect attempt was rejected by the security policy.
    fn assert_permission_denied(result: Result<McpClient, McpError>) {
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, McpError::PermissionDenied(_)));
    }

    #[test]
    fn test_mcp_error_display() {
        // Every variant paired with its expected `Display` rendering.
        let cases = [
            (
                McpError::PermissionDenied("npx not allowed".to_string()),
                "Permission denied: npx not allowed",
            ),
            (
                McpError::ConnectionFailed("spawn failed".to_string()),
                "Connection failed: spawn failed",
            ),
            (
                McpError::ProtocolError("invalid response".to_string()),
                "Protocol error: invalid response",
            ),
            (
                McpError::ToolError("division by zero".to_string()),
                "Tool error: division by zero",
            ),
            (McpError::TransportClosed, "Transport closed"),
            (McpError::Timeout, "Timeout"),
            (
                McpError::RuntimeError("thread pool exhausted".to_string()),
                "Runtime error: thread pool exhausted",
            ),
        ];

        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    #[test]
    fn test_client_handler_info_no_sampling() {
        let info = TlClientHandler::new().get_info();

        assert_eq!(info.client_info.name, "tl");
        assert_eq!(info.client_info.version, env!("CARGO_PKG_VERSION"));
        assert_eq!(
            info.client_info.title,
            Some("ThinkingLanguage MCP Client".to_string())
        );
        // Without a callback the sampling capability must not be advertised.
        assert!(info.capabilities.sampling.is_none());
    }

    #[test]
    fn test_client_handler_info_with_sampling() {
        let cb: SamplingCallback = Arc::new(|_req| {
            Ok(SamplingResponse {
                model: "test".to_string(),
                content: "hello".to_string(),
                stop_reason: None,
            })
        });
        let info = TlClientHandler::new().with_sampling(cb).get_info();

        assert_eq!(info.client_info.name, "tl");
        // With a callback the sampling capability is advertised.
        assert!(info.capabilities.sampling.is_some());
    }

    #[test]
    fn test_sampling_callback_construction() {
        let cb: SamplingCallback = Arc::new(|req| {
            let last = req.messages.last().map(|(_, c)| c.as_str()).unwrap_or("");
            Ok(SamplingResponse {
                model: "test-model".to_string(),
                content: format!("Echo: {}", last),
                stop_reason: Some("endTurn".to_string()),
            })
        });
        let handler = TlClientHandler::new().with_sampling(cb);
        assert!(handler.sampling_callback.is_some());
    }

    #[test]
    fn test_no_sampling_callback() {
        assert!(TlClientHandler::new().sampling_callback.is_none());
    }

    #[test]
    fn test_security_policy_denies_command() {
        // The sandbox policy is expected to disallow subprocesses by default.
        let mut policy = SecurityPolicy::sandbox();
        assert_permission_denied(McpClient::connect("npx", &[], Some(&policy)));

        // Subprocesses allowed, but the command is not on the whitelist.
        policy.allow_subprocess = true;
        policy.allowed_commands = vec!["node".to_string()];
        assert_permission_denied(McpClient::connect("npx", &[], Some(&policy)));
    }

    #[test]
    fn test_security_policy_allows_command() {
        let mut policy = SecurityPolicy::sandbox();
        policy.allow_subprocess = true;
        policy.allowed_commands = vec!["echo".to_string()];

        // `echo` is not an MCP server, so connecting fails — but it must fail
        // as a connection error, not at the security check.
        let result = McpClient::connect("echo", &["hello".to_string()], Some(&policy));
        assert_connection_failed(result);
    }

    #[test]
    fn test_no_security_policy_allows_anything() {
        // No policy means the security check is skipped; the nonexistent
        // binary then fails at connection time.
        let result = McpClient::connect("__nonexistent_mcp_server__", &[], None);
        assert_connection_failed(result);
    }

    #[test]
    fn test_permissive_policy_allows_anything() {
        let policy = SecurityPolicy::permissive();
        let result = McpClient::connect("__nonexistent_mcp_server__", &[], Some(&policy));
        assert_connection_failed(result);
    }
}
853}