Skip to main content

zeph_core/agent/
mcp.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9    /// Dispatch a `/mcp` subcommand, returning the output as a `String`.
10    ///
11    /// All output is collected into the returned string; no channel sends are
12    /// performed.  This makes the future `Send`-compatible for use in
13    /// `AgentAccess::handle_mcp`.
14    pub(super) async fn handle_mcp_command(
15        &mut self,
16        args: &str,
17    ) -> Result<String, super::error::AgentError> {
18        let parts: Vec<&str> = args.split_whitespace().collect();
19        match parts.first().copied() {
20            Some("add") => self.handle_mcp_add(&parts[1..]).await,
21            Some("list") => self.handle_mcp_list().await,
22            Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
23            Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
24            _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
25        }
26    }
27
28    #[allow(clippy::too_many_lines)]
29    async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
30        if args.len() < 2 {
31            return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
32        }
33
34        // Clone the Arc so no borrow of self.mcp.manager is held across .await.
35        let Some(manager) = self.mcp.manager.clone() else {
36            return Ok("MCP is not enabled.".to_owned());
37        };
38
39        let target = args[1];
40        let is_url = target.starts_with("http://") || target.starts_with("https://");
41
42        // SEC-MCP-01: validate command against allowlist (stdio only)
43        if !is_url
44            && !self.mcp.allowed_commands.is_empty()
45            && !self.mcp.allowed_commands.iter().any(|c| c == target)
46        {
47            return Ok(format!(
48                "Command '{target}' is not allowed. Permitted: {}",
49                self.mcp.allowed_commands.join(", ")
50            ));
51        }
52
53        // SEC-MCP-03: enforce server limit
54        let current_count = manager.list_servers().await.len();
55        if current_count >= self.mcp.max_dynamic {
56            return Ok(format!(
57                "Server limit reached ({}/{}).",
58                current_count, self.mcp.max_dynamic
59            ));
60        }
61
62        let transport = if is_url {
63            zeph_mcp::McpTransport::Http {
64                url: target.to_owned(),
65                headers: std::collections::HashMap::new(),
66            }
67        } else {
68            zeph_mcp::McpTransport::Stdio {
69                command: target.to_owned(),
70                args: args[2..].iter().map(|&s| s.to_owned()).collect(),
71                env: std::collections::HashMap::new(),
72            }
73        };
74
75        let entry = zeph_mcp::ServerEntry {
76            id: args[0].to_owned(),
77            transport,
78            timeout: std::time::Duration::from_secs(30),
79            trust_level: zeph_mcp::McpTrustLevel::Untrusted,
80            tool_allowlist: None,
81            expected_tools: Vec::new(),
82            roots: Vec::new(),
83            tool_metadata: std::collections::HashMap::new(),
84            elicitation_enabled: false,
85            elicitation_timeout_secs: 120,
86            env_isolation: false,
87        };
88
89        match manager.add_server(&entry).await {
90            Ok(tools) => {
91                let count = tools.len();
92                self.mcp
93                    .server_outcomes
94                    .push(zeph_mcp::ServerConnectOutcome {
95                        id: entry.id.clone(),
96                        connected: true,
97                        tool_count: count,
98                        error: String::new(),
99                    });
100                self.mcp.tools.extend(tools);
101                self.mcp.sync_executor_tools();
102                self.mcp.pruning_cache.reset();
103                // Defer rebuild to check_tool_refresh (next turn) so this method
104                // stays Send-compatible for use in AgentAccess::handle_mcp.
105                self.mcp.pending_semantic_rebuild = true;
106                let mcp_total = self.mcp.tools.len();
107                let mcp_server_count = self.mcp.server_outcomes.len();
108                let mcp_connected_count = self
109                    .mcp
110                    .server_outcomes
111                    .iter()
112                    .filter(|o| o.connected)
113                    .count();
114                let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
115                    .mcp
116                    .server_outcomes
117                    .iter()
118                    .map(|o| crate::metrics::McpServerStatus {
119                        id: o.id.clone(),
120                        status: if o.connected {
121                            crate::metrics::McpServerConnectionStatus::Connected
122                        } else {
123                            crate::metrics::McpServerConnectionStatus::Failed
124                        },
125                        tool_count: o.tool_count,
126                        error: o.error.clone(),
127                    })
128                    .collect();
129                self.update_metrics(|m| {
130                    m.mcp_tool_count = mcp_total;
131                    m.mcp_server_count = mcp_server_count;
132                    m.mcp_connected_count = mcp_connected_count;
133                    m.mcp_servers = mcp_servers;
134                });
135                Ok(format!(
136                    "Connected MCP server '{}' ({count} tool(s))",
137                    entry.id
138                ))
139            }
140            Err(e) => {
141                tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
142                Ok(format!("Failed to connect server '{}': {e}", entry.id))
143            }
144        }
145    }
146
147    async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
148        use std::fmt::Write;
149
150        let Some(manager) = self.mcp.manager.clone() else {
151            return Ok("MCP is not enabled.".to_owned());
152        };
153
154        let server_ids = manager.list_servers().await;
155        if server_ids.is_empty() {
156            return Ok("No MCP servers connected.".to_owned());
157        }
158
159        let mut output = String::from("Connected MCP servers:\n");
160        let mut total = 0usize;
161        for id in &server_ids {
162            let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
163            total += count;
164            let _ = writeln!(output, "- {id} ({count} tools)");
165        }
166        let _ = write!(output, "Total: {total} tool(s)");
167
168        Ok(output)
169    }
170
171    fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
172        use std::fmt::Write;
173
174        let Some(server_id) = server_id else {
175            return "Usage: /mcp tools <server_id>".to_owned();
176        };
177
178        let tools: Vec<_> = self
179            .mcp
180            .tools
181            .iter()
182            .filter(|t| t.server_id == server_id)
183            .collect();
184
185        if tools.is_empty() {
186            return format!("No tools found for server '{server_id}'.");
187        }
188
189        let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
190        for t in &tools {
191            if t.description.is_empty() {
192                let _ = writeln!(output, "- {}", t.name);
193            } else {
194                let _ = writeln!(output, "- {} — {}", t.name, t.description);
195            }
196        }
197        output
198    }
199
200    async fn handle_mcp_remove(
201        &mut self,
202        server_id: Option<&str>,
203    ) -> Result<String, super::error::AgentError> {
204        let Some(server_id) = server_id else {
205            return Ok("Usage: /mcp remove <id>".to_owned());
206        };
207
208        // Clone the Arc so no borrow of self.mcp.manager is held across .await.
209        let Some(manager) = self.mcp.manager.clone() else {
210            return Ok("MCP is not enabled.".to_owned());
211        };
212
213        match manager.remove_server(server_id).await {
214            Ok(()) => {
215                let before = self.mcp.tools.len();
216                self.mcp.tools.retain(|t| t.server_id != server_id);
217                let removed = before - self.mcp.tools.len();
218                self.mcp.server_outcomes.retain(|o| o.id != server_id);
219                self.mcp.sync_executor_tools();
220                self.mcp.pruning_cache.reset();
221                // Defer rebuild to check_tool_refresh (next turn) so this method
222                // stays Send-compatible for use in AgentAccess::handle_mcp.
223                self.mcp.pending_semantic_rebuild = true;
224                let mcp_total = self.mcp.tools.len();
225                let mcp_server_count = self.mcp.server_outcomes.len();
226                let mcp_connected_count = self
227                    .mcp
228                    .server_outcomes
229                    .iter()
230                    .filter(|o| o.connected)
231                    .count();
232                let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
233                    .mcp
234                    .server_outcomes
235                    .iter()
236                    .map(|o| crate::metrics::McpServerStatus {
237                        id: o.id.clone(),
238                        status: if o.connected {
239                            crate::metrics::McpServerConnectionStatus::Connected
240                        } else {
241                            crate::metrics::McpServerConnectionStatus::Failed
242                        },
243                        tool_count: o.tool_count,
244                        error: o.error.clone(),
245                    })
246                    .collect();
247                self.update_metrics(|m| {
248                    m.mcp_tool_count = mcp_total;
249                    m.mcp_server_count = mcp_server_count;
250                    m.mcp_connected_count = mcp_connected_count;
251                    m.mcp_servers = mcp_servers;
252                    m.active_mcp_tools
253                        .retain(|name| !name.starts_with(&format!("{server_id}:")));
254                });
255                Ok(format!(
256                    "Disconnected MCP server '{server_id}' (removed {removed} tools)"
257                ))
258            }
259            Err(e) => {
260                tracing::warn!(server_id, "MCP remove failed: {e:#}");
261                Ok(format!("Failed to remove server '{server_id}': {e}"))
262            }
263        }
264    }
265
266    pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
267        let matched_tools = self.match_mcp_tools(query).await;
268        let active_mcp: Vec<String> = matched_tools
269            .iter()
270            .map(zeph_mcp::McpTool::qualified_name)
271            .collect();
272        let mcp_total = self.mcp.tools.len();
273        let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
274            let connected = self
275                .mcp
276                .tools
277                .iter()
278                .map(|t| &t.server_id)
279                .collect::<std::collections::HashSet<_>>()
280                .len();
281            (connected, connected)
282        } else {
283            let total = self.mcp.server_outcomes.len();
284            let connected = self
285                .mcp
286                .server_outcomes
287                .iter()
288                .filter(|o| o.connected)
289                .count();
290            (total, connected)
291        };
292        self.update_metrics(|m| {
293            m.active_mcp_tools = active_mcp;
294            m.mcp_tool_count = mcp_total;
295            m.mcp_server_count = mcp_server_count;
296            m.mcp_connected_count = mcp_connected_count;
297        });
298        if let Some(ref manager) = self.mcp.manager {
299            let instructions = manager.all_server_instructions().await;
300            if !instructions.is_empty() {
301                system_prompt.push_str("\n\n");
302                system_prompt.push_str(&instructions);
303            }
304        }
305        if !matched_tools.is_empty() {
306            let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
307            tracing::debug!(
308                skills = ?self.skill_state.active_skill_names,
309                mcp_tools = ?tool_names,
310                "matched items"
311            );
312            let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
313            if !tools_prompt.is_empty() {
314                system_prompt.push_str("\n\n");
315                system_prompt.push_str(&tools_prompt);
316            }
317        }
318    }
319
320    async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
321        let Some(ref registry) = self.mcp.registry else {
322            return self.mcp.tools.clone();
323        };
324        let provider = self.embedding_provider.clone();
325        registry
326            .search(query, self.skill_state.max_active_skills, |text| {
327                let owned = text.to_owned();
328                let p = provider.clone();
329                Box::pin(async move { p.embed(&owned).await })
330            })
331            .await
332    }
333
334    /// Poll the watch receiver for tool list updates from `tools/list_changed` notifications,
335    /// and process any deferred semantic index rebuild requests.
336    ///
337    /// Called once per agent turn, before processing user input.  Two triggers cause a rebuild:
338    /// - A `tools/list_changed` notification from an MCP server (via `tool_rx`).
339    /// - `pending_semantic_rebuild == true`, set by `/mcp add` or `/mcp remove` when dispatched
340    ///   via `AgentAccess::handle_mcp` (which cannot call `rebuild_semantic_index` directly
341    ///   because the future would be `!Send`).
342    ///
343    /// If neither trigger fires, this is a no-op.
344    pub(super) async fn check_tool_refresh(&mut self) {
345        // Handle deferred rebuild from /mcp add|remove via AgentAccess.
346        if self.mcp.pending_semantic_rebuild {
347            self.mcp.pending_semantic_rebuild = false;
348            self.rebuild_semantic_index().await;
349            self.sync_mcp_registry().await;
350            let mcp_total = self.mcp.tools.len();
351            let mcp_servers = self
352                .mcp
353                .tools
354                .iter()
355                .map(|t| &t.server_id)
356                .collect::<std::collections::HashSet<_>>()
357                .len();
358            self.update_metrics(|m| {
359                m.mcp_tool_count = mcp_total;
360                m.mcp_server_count = mcp_servers;
361            });
362        }
363
364        let Some(ref mut rx) = self.mcp.tool_rx else {
365            return;
366        };
367        if !rx.has_changed().unwrap_or(false) {
368            return;
369        }
370        let new_tools = rx.borrow_and_update().clone();
371        if new_tools.is_empty() {
372            // Guard against replacing a non-empty initial tool list with the watch's empty
373            // initial value. The watch is only updated after a real tools/list_changed event.
374            return;
375        }
376        tracing::info!(
377            tools = new_tools.len(),
378            "tools/list_changed: agent tool list refreshed"
379        );
380        self.mcp.tools = new_tools;
381        self.mcp.sync_executor_tools();
382        self.mcp.pruning_cache.reset();
383        self.rebuild_semantic_index().await;
384        self.sync_mcp_registry().await;
385        let mcp_total = self.mcp.tools.len();
386        let mcp_servers = self
387            .mcp
388            .tools
389            .iter()
390            .map(|t| &t.server_id)
391            .collect::<std::collections::HashSet<_>>()
392            .len();
393        self.update_metrics(|m| {
394            m.mcp_tool_count = mcp_total;
395            m.mcp_server_count = mcp_servers;
396        });
397    }
398
399    pub(super) async fn sync_mcp_registry(&mut self) {
400        if self.mcp.registry.is_none() {
401            return;
402        }
403        if !self.embedding_provider.supports_embeddings() {
404            return;
405        }
406        // Clone tools before .await to avoid holding &self.mcp.tools across an await point.
407        let tools = self.mcp.tools.clone();
408        let provider = self.embedding_provider.clone();
409        let embedding_model = self.skill_state.embedding_model.clone();
410        let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
411        let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
412            let owned = text.to_owned();
413            let p = provider.clone();
414            Box::pin(async move {
415                if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
416                    result
417                } else {
418                    tracing::warn!(
419                        timeout_secs = embed_timeout.as_secs(),
420                        "MCP registry: embedding timed out"
421                    );
422                    Err(zeph_llm::LlmError::Timeout)
423                }
424            })
425        };
426        // Take registry out of self to avoid holding &mut self.mcp.registry across .await.
427        // No early returns between take() and put-back — the await is the only yield point here.
428        let Some(mut registry) = self.mcp.registry.take() else {
429            return;
430        };
431        if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
432            tracing::warn!("failed to sync MCP tool registry: {e:#}");
433        }
434        self.mcp.registry = Some(registry);
435    }
436
437    /// Build (or rebuild) the in-memory semantic tool index for embedding-based discovery.
438    /// Build the initial semantic tool index after agent construction.
439    ///
440    /// Must be called once after `with_mcp` and `with_mcp_discovery` are applied,
441    /// before the first user turn.  Subsequent rebuilds happen automatically on
442    /// tool list change events (`check_tool_refresh`, `/mcp add`, `/mcp remove`).
443    pub async fn init_semantic_index(&mut self) {
444        self.rebuild_semantic_index().await;
445    }
446
447    /// Drain and process all pending elicitation requests without blocking.
448    ///
449    /// Call this at the start of each turn and between tool calls to prevent
450    /// elicitation events from accumulating while the agent loop is busy.
451    pub(super) async fn process_pending_elicitations(&mut self) {
452        loop {
453            let Some(ref mut rx) = self.mcp.elicitation_rx else {
454                return;
455            };
456            match rx.try_recv() {
457                Ok(event) => {
458                    self.handle_elicitation_event(event).await;
459                }
460                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
461                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
462                    self.mcp.elicitation_rx = None;
463                    return;
464                }
465            }
466        }
467    }
468
469    /// Handle a single elicitation event by routing it to the active channel.
470    pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
471        use crate::channel::{ElicitationRequest, ElicitationResponse};
472
473        let decline = CreateElicitationResult {
474            action: ElicitationAction::Decline,
475            content: None,
476            meta: None,
477        };
478
479        let channel_request = match &event.request {
480            rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
481                message,
482                requested_schema,
483                ..
484            } => {
485                let fields = build_elicitation_fields(requested_schema);
486                ElicitationRequest {
487                    server_name: event.server_id.clone(),
488                    message: sanitize_elicitation_message(message),
489                    fields,
490                }
491            }
492            rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
493                // URL elicitation not supported in phase 1 — decline.
494                tracing::debug!(
495                    server_id = event.server_id,
496                    "URL elicitation not supported, declining"
497                );
498                let _ = event.response_tx.send(decline);
499                return;
500            }
501        };
502
503        if self.mcp.elicitation_warn_sensitive_fields {
504            let sensitive: Vec<&str> = channel_request
505                .fields
506                .iter()
507                .filter(|f| is_sensitive_field(&f.name))
508                .map(|f| f.name.as_str())
509                .collect();
510            if !sensitive.is_empty() {
511                let fields_list = sensitive.join(", ");
512                let warning = format!(
513                    "Warning: [{}] is requesting sensitive information (field: {}). \
514                     Only proceed if you trust this server.",
515                    channel_request.server_name, fields_list,
516                );
517                tracing::warn!(
518                    server_id = event.server_id,
519                    fields = %fields_list,
520                    "elicitation requests sensitive fields"
521                );
522                let _ = self.channel.send(&warning).await;
523            }
524        }
525
526        let _ = self
527            .channel
528            .send_status("MCP server requesting input…")
529            .await;
530        let response = match self.channel.elicit(channel_request).await {
531            Ok(r) => r,
532            Err(e) => {
533                tracing::warn!(
534                    server_id = event.server_id,
535                    "elicitation channel error: {e:#}"
536                );
537                let _ = self.channel.send_status("").await;
538                let _ = event.response_tx.send(decline);
539                return;
540            }
541        };
542        let _ = self.channel.send_status("").await;
543
544        let result = match response {
545            ElicitationResponse::Accepted(value) => CreateElicitationResult {
546                action: ElicitationAction::Accept,
547                content: Some(value),
548                meta: None,
549            },
550            ElicitationResponse::Declined => CreateElicitationResult {
551                action: ElicitationAction::Decline,
552                content: None,
553                meta: None,
554            },
555            ElicitationResponse::Cancelled => CreateElicitationResult {
556                action: ElicitationAction::Cancel,
557                content: None,
558                meta: None,
559            },
560        };
561
562        if event.response_tx.send(result).is_err() {
563            tracing::warn!(
564                server_id = event.server_id,
565                "elicitation response dropped — handler disconnected"
566            );
567        }
568    }
569
570    /// Rebuild the in-memory semantic tool index.
571    ///
572    /// Only runs when `discovery_strategy == Embedding`.  On failure (all embeddings fail),
573    /// sets `semantic_index = None` and logs at WARN — the caller falls back to all tools.
574    ///
575    /// Called at:
576    /// - initial setup via `init_semantic_index()`
577    /// - `tools/list_changed` notification
578    /// - `/mcp add` and `/mcp remove`
579    pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
580        if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
581            return;
582        }
583
584        if self.mcp.tools.is_empty() {
585            self.mcp.semantic_index = None;
586            return;
587        }
588
589        // Resolve embedding provider: dedicated discovery provider → primary embedding provider.
590        let provider = self
591            .mcp
592            .discovery_provider
593            .clone()
594            .unwrap_or_else(|| self.embedding_provider.clone());
595
596        let inner_embed = provider.embed_fn();
597        let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
598        let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
599            let fut = inner_embed(text);
600            Box::pin(async move {
601                if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
602                    result
603                } else {
604                    tracing::warn!(
605                        timeout_secs = embed_timeout.as_secs(),
606                        "semantic index: embedding probe timed out"
607                    );
608                    Err(zeph_llm::LlmError::Timeout)
609                }
610            })
611        };
612
613        // Clone tools before .await to avoid holding &self.mcp.tools across an await point.
614        let tools = self.mcp.tools.clone();
615        match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
616            Ok(idx) => {
617                tracing::info!(
618                    indexed = idx.len(),
619                    total = self.mcp.tools.len(),
620                    "semantic tool index built"
621                );
622                self.mcp.semantic_index = Some(idx);
623            }
624            Err(e) => {
625                tracing::warn!(
626                    "semantic tool index build failed, falling back to all tools: {e:#}"
627                );
628                self.mcp.semantic_index = None;
629            }
630        }
631    }
632}
633
634/// Convert an rmcp `ElicitationSchema` into channel-agnostic `ElicitationField` list.
635fn build_elicitation_fields(
636    schema: &rmcp::model::ElicitationSchema,
637) -> Vec<crate::channel::ElicitationField> {
638    use crate::channel::{ElicitationField, ElicitationFieldType};
639    use rmcp::model::PrimitiveSchema;
640
641    schema
642        .properties
643        .iter()
644        .map(|(name, prop)| {
645            // Extract field type and description by serializing the PrimitiveSchema to JSON
646            // and reading the discriminator field.  This avoids deep-matching the nested
647            // EnumSchema / StringSchema / … variants of rmcp's type-safe schema hierarchy.
648            let json = serde_json::to_value(prop).unwrap_or_default();
649            let description = json
650                .get("description")
651                .and_then(|v| v.as_str())
652                .map(String::from);
653
654            let field_type = match prop {
655                PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
656                PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
657                PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
658                PrimitiveSchema::String(_) => ElicitationFieldType::String,
659                PrimitiveSchema::Enum(_) => {
660                    // Extract enum values from the serialized form.  All EnumSchema variants
661                    // serialise their allowed values under "enum" or inside "items.enum".
662                    let vals = json
663                        .get("enum")
664                        .and_then(|v| v.as_array())
665                        .map(|arr| {
666                            arr.iter()
667                                .filter_map(|v| v.as_str())
668                                .map(String::from)
669                                .collect::<Vec<_>>()
670                        })
671                        .unwrap_or_default();
672                    ElicitationFieldType::Enum(vals)
673                }
674            };
675            let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
676            ElicitationField {
677                name: name.clone(),
678                description,
679                field_type,
680                required,
681            }
682        })
683        .collect()
684}
685
/// Substrings that mark an elicitation field name as sensitive.
///
/// Matching is a case-insensitive substring test (see [`is_sensitive_field`]).  Broad entries
/// such as `"key"` or `"pin"` therefore also match names like `"keyboard"` or `"spinner"`.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    "password",
    "passwd",
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Reports whether `field_name` looks like a request for sensitive data.
///
/// The name is lowercased, then checked for any entry of [`SENSITIVE_FIELD_PATTERNS`] as a
/// substring.
fn is_sensitive_field(field_name: &str) -> bool {
    let normalized = field_name.to_lowercase();
    for &needle in SENSITIVE_FIELD_PATTERNS {
        if normalized.contains(needle) {
            return true;
        }
    }
    false
}
710
/// Sanitize an elicitation message received from an MCP server before it is shown to the user.
///
/// - Removes control characters (`char::is_control`) except `\n` and `\t`, so a server cannot
///   inject terminal escape sequences.
/// - Removes Unicode bidirectional formatting characters (marks, embeddings, overrides and
///   isolates).  These are not control characters, but they can reorder how the surrounding
///   text is rendered, making the prompt display differently from its logical content.
/// - Caps the result at 500 characters (counted in `char`s, not bytes).  The cap is applied
///   after filtering.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    /// Unicode bidi formatting characters: ALM, LRM, RLM, LRE..RLO, LRI..PDI.
    fn is_bidi_format(c: char) -> bool {
        matches!(
            c,
            '\u{061C}' | '\u{200E}' | '\u{200F}' | '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}'
        )
    }

    message
        .chars()
        .filter(|&c| (!c.is_control() || c == '\n' || c == '\t') && !is_bidi_format(c))
        .take(MAX_CHARS)
        .collect()
}
721
#[cfg(test)]
mod tests {
    // Unit tests for `/mcp` command dispatch, MCP tool refresh, and elicitation helpers.
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let result = agent.handle_mcp_command("unknown").await.unwrap();
        assert!(
            result.contains("Usage: /mcp"),
            "expected usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let result = agent.handle_mcp_command("list").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let result = agent.handle_mcp_command("tools").await.unwrap();
        assert!(
            result.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let result = agent.handle_mcp_command("remove").await.unwrap();
        assert!(
            result.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let result = agent.handle_mcp_command("remove my-server").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        // "add" with only 1 arg (needs at least 2: id + command)
        let result = agent.handle_mcp_command("add server-id").await.unwrap();
        assert!(
            result.contains("Usage: /mcp add"),
            "expected add usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        // mcp.tools is empty, so any server will have no tools
        let result = agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert!(
            result.contains("No tools found"),
            "expected no-tools message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let agent = Agent::new(provider, channel, registry, None, 5, executor);

        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
        // No tool_rx set; check_tool_refresh should be a no-op.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(rx);
        // No changes sent; has_changed() returns false.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 0);
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);
        agent.mcp.tools = vec![zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: "existing_tool".into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        // has_changed() is false for a fresh receiver; tools unchanged.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        let mut agent = Agent::new(provider, channel, registry, None, 5, executor);

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);

        let new_tools = vec![zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: "refreshed_tool".into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }];
        tx.send(new_tools).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_strips_cr_del_and_c1_controls() {
        // CR can overwrite the current line; U+009B is the single-char C1 CSI introducer.
        let input = "a\rb\u{7f}c\u{9b}d";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "abcd");
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        // Build a 600-char ASCII string — no multi-byte boundary issue.
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // "é" is 2 bytes. The leading ASCII "a" shifts every "é" to an odd byte offset,
        // so byte 500 falls inside a code point and a naive `&input[..500]` would panic.
        let input = format!("a{}", "é".repeat(600)); // 601 chars = 1201 bytes
        let output = sanitize_elicitation_message(&input);
        // Should truncate to exactly 500 whole chars without panic.
        assert_eq!(output.chars().count(), 500);
        assert!(output.starts_with('a'));
        assert!(output.ends_with('é'));
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}