Skip to main content

zeph_core/agent/
mcp.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9    /// Dispatch a `/mcp` subcommand, returning the output as a `String`.
10    ///
11    /// All output is collected into the returned string; no channel sends are
12    /// performed.  This makes the future `Send`-compatible for use in
13    /// `AgentAccess::handle_mcp`.
14    pub(super) async fn handle_mcp_command(
15        &mut self,
16        args: &str,
17    ) -> Result<String, super::error::AgentError> {
18        let parts: Vec<&str> = args.split_whitespace().collect();
19        match parts.first().copied() {
20            Some("add") => self.handle_mcp_add(&parts[1..]).await,
21            Some("list") => self.handle_mcp_list().await,
22            Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
23            Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
24            _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
25        }
26    }
27
28    async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
29        if args.len() < 2 {
30            return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
31        }
32
33        // Clone the Arc so no borrow of self.services.mcp.manager is held across .await.
34        let Some(manager) = self.services.mcp.manager.clone() else {
35            return Ok("MCP is not enabled.".to_owned());
36        };
37
38        let target = args[1];
39        if let Some(err) = validate_mcp_command(target, &self.services.mcp.allowed_commands) {
40            return Ok(err);
41        }
42
43        // SEC-MCP-03: enforce server limit
44        let current_count = manager.list_servers().await.len();
45        if current_count >= self.services.mcp.max_dynamic {
46            return Ok(format!(
47                "Server limit reached ({}/{}).",
48                current_count, self.services.mcp.max_dynamic
49            ));
50        }
51
52        let entry = build_server_entry(args[0], target, &args[2..]);
53
54        match manager.add_server(&entry).await {
55            Ok(tools) => {
56                let count = tools.len();
57                self.services
58                    .mcp
59                    .server_outcomes
60                    .push(zeph_mcp::ServerConnectOutcome {
61                        id: entry.id.clone(),
62                        connected: true,
63                        tool_count: count,
64                        error: String::new(),
65                    });
66                self.services.mcp.tools.extend(tools);
67                self.services.mcp.sync_executor_tools();
68                self.services.mcp.pruning_cache.reset();
69                // Defer rebuild to check_tool_refresh (next turn) so this method
70                // stays Send-compatible for use in AgentAccess::handle_mcp.
71                self.services.mcp.pending_semantic_rebuild = true;
72                self.update_mcp_metrics();
73                Ok(format!(
74                    "Connected MCP server '{}' ({count} tool(s))",
75                    entry.id
76                ))
77            }
78            Err(e) => {
79                tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
80                Ok(format!("Failed to connect server '{}': {e}", entry.id))
81            }
82        }
83    }
84
85    async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
86        use std::fmt::Write;
87
88        let Some(manager) = self.services.mcp.manager.clone() else {
89            return Ok("MCP is not enabled.".to_owned());
90        };
91
92        let server_ids = manager.list_servers().await;
93        if server_ids.is_empty() {
94            return Ok("No MCP servers connected.".to_owned());
95        }
96
97        let mut output = String::from("Connected MCP servers:\n");
98        let mut total = 0usize;
99        for id in &server_ids {
100            let count = self
101                .services
102                .mcp
103                .tools
104                .iter()
105                .filter(|t| t.server_id == *id)
106                .count();
107            total += count;
108            let _ = writeln!(output, "- {id} ({count} tools)");
109        }
110        let _ = write!(output, "Total: {total} tool(s)");
111
112        Ok(output)
113    }
114
115    fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
116        use std::fmt::Write;
117
118        let Some(server_id) = server_id else {
119            return "Usage: /mcp tools <server_id>".to_owned();
120        };
121
122        let tools: Vec<_> = self
123            .services
124            .mcp
125            .tools
126            .iter()
127            .filter(|t| t.server_id == server_id)
128            .collect();
129
130        if tools.is_empty() {
131            return format!("No tools found for server '{server_id}'.");
132        }
133
134        let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
135        for t in &tools {
136            if t.description.is_empty() {
137                let _ = writeln!(output, "- {}", t.name);
138            } else {
139                let _ = writeln!(output, "- {} — {}", t.name, t.description);
140            }
141        }
142        output
143    }
144
145    async fn handle_mcp_remove(
146        &mut self,
147        server_id: Option<&str>,
148    ) -> Result<String, super::error::AgentError> {
149        let Some(server_id) = server_id else {
150            return Ok("Usage: /mcp remove <id>".to_owned());
151        };
152
153        // Clone the Arc so no borrow of self.services.mcp.manager is held across .await.
154        let Some(manager) = self.services.mcp.manager.clone() else {
155            return Ok("MCP is not enabled.".to_owned());
156        };
157
158        match manager.remove_server(server_id).await {
159            Ok(()) => {
160                let before = self.services.mcp.tools.len();
161                self.services.mcp.tools.retain(|t| t.server_id != server_id);
162                let removed = before - self.services.mcp.tools.len();
163                self.services
164                    .mcp
165                    .server_outcomes
166                    .retain(|o| o.id != server_id);
167                self.services.mcp.sync_executor_tools();
168                self.services.mcp.pruning_cache.reset();
169                // Defer rebuild to check_tool_refresh (next turn) so this method
170                // stays Send-compatible for use in AgentAccess::handle_mcp.
171                self.services.mcp.pending_semantic_rebuild = true;
172                self.update_mcp_metrics();
173                let sid = server_id.to_owned();
174                self.update_metrics(|m| {
175                    m.active_mcp_tools
176                        .retain(|name| !name.starts_with(&format!("{sid}:")));
177                });
178                Ok(format!(
179                    "Disconnected MCP server '{server_id}' (removed {removed} tools)"
180                ))
181            }
182            Err(e) => {
183                tracing::warn!(server_id, "MCP remove failed: {e:#}");
184                Ok(format!("Failed to remove server '{server_id}': {e}"))
185            }
186        }
187    }
188
189    pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
190        let matched_tools = self.match_mcp_tools(query).await;
191        let active_mcp: Vec<String> = matched_tools
192            .iter()
193            .map(zeph_mcp::McpTool::qualified_name)
194            .collect();
195        let mcp_total = self.services.mcp.tools.len();
196        let (mcp_server_count, mcp_connected_count) =
197            if self.services.mcp.server_outcomes.is_empty() {
198                let connected = self
199                    .services
200                    .mcp
201                    .tools
202                    .iter()
203                    .map(|t| &t.server_id)
204                    .collect::<std::collections::HashSet<_>>()
205                    .len();
206                (connected, connected)
207            } else {
208                let total = self.services.mcp.server_outcomes.len();
209                let connected = self
210                    .services
211                    .mcp
212                    .server_outcomes
213                    .iter()
214                    .filter(|o| o.connected)
215                    .count();
216                (total, connected)
217            };
218        self.update_metrics(|m| {
219            m.active_mcp_tools = active_mcp;
220            m.mcp_tool_count = mcp_total;
221            m.mcp_server_count = mcp_server_count;
222            m.mcp_connected_count = mcp_connected_count;
223        });
224        if let Some(ref manager) = self.services.mcp.manager {
225            let instructions = manager.all_server_instructions().await;
226            if !instructions.is_empty() {
227                system_prompt.push_str("\n\n");
228                system_prompt.push_str(&instructions);
229            }
230        }
231        if !matched_tools.is_empty() {
232            let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
233            tracing::debug!(
234                skills = ?self.services.skill.active_skill_names,
235                mcp_tools = ?tool_names,
236                "matched items"
237            );
238            let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
239            if !tools_prompt.is_empty() {
240                system_prompt.push_str("\n\n");
241                system_prompt.push_str(&tools_prompt);
242            }
243        }
244    }
245
246    async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
247        let Some(ref registry) = self.services.mcp.registry else {
248            return self.services.mcp.tools.clone();
249        };
250        let provider = self.embedding_provider.clone();
251        registry
252            .search(query, self.services.skill.max_active_skills, |text| {
253                let owned = text.to_owned();
254                let p = provider.clone();
255                Box::pin(async move { p.embed(&owned).await })
256            })
257            .await
258    }
259
260    /// Poll the watch receiver for tool list updates from `tools/list_changed` notifications,
261    /// and process any deferred semantic index rebuild requests.
262    ///
263    /// Called once per agent turn, before processing user input.  Two triggers cause a rebuild:
264    /// - A `tools/list_changed` notification from an MCP server (via `tool_rx`).
265    /// - `pending_semantic_rebuild == true`, set by `/mcp add` or `/mcp remove` when dispatched
266    ///   via `AgentAccess::handle_mcp` (which cannot call `rebuild_semantic_index` directly
267    ///   because the future would be `!Send`).
268    ///
269    /// If neither trigger fires, this is a no-op.
270    pub(super) async fn check_tool_refresh(&mut self) {
271        // Handle deferred rebuild from /mcp add|remove via AgentAccess.
272        if self.services.mcp.pending_semantic_rebuild {
273            self.services.mcp.pending_semantic_rebuild = false;
274            self.rebuild_semantic_index().await;
275            self.sync_mcp_registry().await;
276            let mcp_total = self.services.mcp.tools.len();
277            let mcp_servers = self
278                .services
279                .mcp
280                .tools
281                .iter()
282                .map(|t| &t.server_id)
283                .collect::<std::collections::HashSet<_>>()
284                .len();
285            self.update_metrics(|m| {
286                m.mcp_tool_count = mcp_total;
287                m.mcp_server_count = mcp_servers;
288            });
289        }
290
291        let Some(ref mut rx) = self.services.mcp.tool_rx else {
292            return;
293        };
294        if !rx.has_changed().unwrap_or(false) {
295            return;
296        }
297        let new_tools = rx.borrow_and_update().clone();
298        if new_tools.is_empty() {
299            // Guard against replacing a non-empty initial tool list with the watch's empty
300            // initial value. The watch is only updated after a real tools/list_changed event.
301            return;
302        }
303        tracing::info!(
304            tools = new_tools.len(),
305            "tools/list_changed: agent tool list refreshed"
306        );
307        self.services.mcp.tools = new_tools;
308        self.services.mcp.sync_executor_tools();
309        self.services.mcp.pruning_cache.reset();
310        self.rebuild_semantic_index().await;
311        self.sync_mcp_registry().await;
312        let mcp_total = self.services.mcp.tools.len();
313        let mcp_servers = self
314            .services
315            .mcp
316            .tools
317            .iter()
318            .map(|t| &t.server_id)
319            .collect::<std::collections::HashSet<_>>()
320            .len();
321        self.update_metrics(|m| {
322            m.mcp_tool_count = mcp_total;
323            m.mcp_server_count = mcp_servers;
324        });
325    }
326
327    pub(super) async fn sync_mcp_registry(&mut self) {
328        if self.services.mcp.registry.is_none() {
329            return;
330        }
331        if !self.embedding_provider.supports_embeddings() {
332            return;
333        }
334        // Clone tools before .await to avoid holding &self.services.mcp.tools across an await point.
335        let tools = self.services.mcp.tools.clone();
336        let provider = self.embedding_provider.clone();
337        let embedding_model = self.services.skill.embedding_model.clone();
338        let embed_timeout =
339            std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
340        let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
341            let owned = text.to_owned();
342            let p = provider.clone();
343            Box::pin(async move {
344                if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
345                    result
346                } else {
347                    tracing::warn!(
348                        timeout_secs = embed_timeout.as_secs(),
349                        "MCP registry: embedding timed out"
350                    );
351                    Err(zeph_llm::LlmError::Timeout)
352                }
353            })
354        };
355        // Take registry out of self to avoid holding &mut self.services.mcp.registry across .await.
356        // No early returns between take() and put-back — the await is the only yield point here.
357        let Some(mut registry) = self.services.mcp.registry.take() else {
358            return;
359        };
360        if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
361            tracing::warn!("failed to sync MCP tool registry: {e:#}");
362        }
363        self.services.mcp.registry = Some(registry);
364    }
365
366    /// Build (or rebuild) the in-memory semantic tool index for embedding-based discovery.
367    /// Build the initial semantic tool index after agent construction.
368    ///
369    /// Must be called once after `with_mcp` and `with_mcp_discovery` are applied,
370    /// before the first user turn.  Subsequent rebuilds happen automatically on
371    /// tool list change events (`check_tool_refresh`, `/mcp add`, `/mcp remove`).
372    pub async fn init_semantic_index(&mut self) {
373        self.rebuild_semantic_index().await;
374    }
375
376    /// Drain and process all pending elicitation requests without blocking.
377    ///
378    /// Call this at the start of each turn and between tool calls to prevent
379    /// elicitation events from accumulating while the agent loop is busy.
380    pub(super) async fn process_pending_elicitations(&mut self) {
381        loop {
382            let Some(ref mut rx) = self.services.mcp.elicitation_rx else {
383                return;
384            };
385            match rx.try_recv() {
386                Ok(event) => {
387                    self.handle_elicitation_event(event).await;
388                }
389                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
390                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
391                    self.services.mcp.elicitation_rx = None;
392                    return;
393                }
394            }
395        }
396    }
397
398    /// Handle a single elicitation event by routing it to the active channel.
399    pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
400        use crate::channel::{ElicitationRequest, ElicitationResponse};
401
402        let decline = CreateElicitationResult {
403            action: ElicitationAction::Decline,
404            content: None,
405            meta: None,
406        };
407
408        let channel_request = match &event.request {
409            rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
410                message,
411                requested_schema,
412                ..
413            } => {
414                let fields = build_elicitation_fields(requested_schema);
415                ElicitationRequest {
416                    server_name: event.server_id.clone(),
417                    message: sanitize_elicitation_message(message),
418                    fields,
419                }
420            }
421            rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
422                // URL elicitation not supported in phase 1 — decline.
423                tracing::debug!(
424                    server_id = event.server_id,
425                    "URL elicitation not supported, declining"
426                );
427                let _ = event.response_tx.send(decline);
428                return;
429            }
430        };
431
432        if self.services.mcp.elicitation_warn_sensitive_fields {
433            let sensitive: Vec<&str> = channel_request
434                .fields
435                .iter()
436                .filter(|f| is_sensitive_field(&f.name))
437                .map(|f| f.name.as_str())
438                .collect();
439            if !sensitive.is_empty() {
440                let fields_list = sensitive.join(", ");
441                let warning = format!(
442                    "Warning: [{}] is requesting sensitive information (field: {}). \
443                     Only proceed if you trust this server.",
444                    channel_request.server_name, fields_list,
445                );
446                tracing::warn!(
447                    server_id = event.server_id,
448                    fields = %fields_list,
449                    "elicitation requests sensitive fields"
450                );
451                let _ = self.channel.send(&warning).await;
452            }
453        }
454
455        let _ = self
456            .channel
457            .send_status("MCP server requesting input…")
458            .await;
459        let response = match self.channel.elicit(channel_request).await {
460            Ok(r) => r,
461            Err(e) => {
462                tracing::warn!(
463                    server_id = event.server_id,
464                    "elicitation channel error: {e:#}"
465                );
466                let _ = self.channel.send_status("").await;
467                let _ = event.response_tx.send(decline);
468                return;
469            }
470        };
471        let _ = self.channel.send_status("").await;
472
473        let result = match response {
474            ElicitationResponse::Accepted(value) => CreateElicitationResult {
475                action: ElicitationAction::Accept,
476                content: Some(value),
477                meta: None,
478            },
479            ElicitationResponse::Declined => CreateElicitationResult {
480                action: ElicitationAction::Decline,
481                content: None,
482                meta: None,
483            },
484            ElicitationResponse::Cancelled => CreateElicitationResult {
485                action: ElicitationAction::Cancel,
486                content: None,
487                meta: None,
488            },
489        };
490
491        if event.response_tx.send(result).is_err() {
492            tracing::warn!(
493                server_id = event.server_id,
494                "elicitation response dropped — handler disconnected"
495            );
496        }
497    }
498
499    fn update_mcp_metrics(&mut self) {
500        let mcp_total = self.services.mcp.tools.len();
501        let mcp_server_count = self.services.mcp.server_outcomes.len();
502        let mcp_connected_count = self
503            .services
504            .mcp
505            .server_outcomes
506            .iter()
507            .filter(|o| o.connected)
508            .count();
509        let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
510            .services
511            .mcp
512            .server_outcomes
513            .iter()
514            .map(|o| crate::metrics::McpServerStatus {
515                id: o.id.clone(),
516                status: if o.connected {
517                    crate::metrics::McpServerConnectionStatus::Connected
518                } else {
519                    crate::metrics::McpServerConnectionStatus::Failed
520                },
521                tool_count: o.tool_count,
522                error: o.error.clone(),
523            })
524            .collect();
525        self.update_metrics(|m| {
526            m.mcp_tool_count = mcp_total;
527            m.mcp_server_count = mcp_server_count;
528            m.mcp_connected_count = mcp_connected_count;
529            m.mcp_servers = mcp_servers;
530        });
531    }
532
533    /// Rebuild the in-memory semantic tool index.
534    ///
535    /// Only runs when `discovery_strategy == Embedding`.  On failure (all embeddings fail),
536    /// sets `semantic_index = None` and logs at WARN — the caller falls back to all tools.
537    ///
538    /// Called at:
539    /// - initial setup via `init_semantic_index()`
540    /// - `tools/list_changed` notification
541    /// - `/mcp add` and `/mcp remove`
542    pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
543        if self.services.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
544            return;
545        }
546
547        if self.services.mcp.tools.is_empty() {
548            self.services.mcp.semantic_index = None;
549            return;
550        }
551
552        // Resolve embedding provider: dedicated discovery provider → primary embedding provider.
553        let provider = self
554            .services
555            .mcp
556            .discovery_provider
557            .clone()
558            .unwrap_or_else(|| self.embedding_provider.clone());
559
560        let inner_embed = provider.embed_fn();
561        let embed_timeout =
562            std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
563        let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
564            let fut = inner_embed(text);
565            Box::pin(async move {
566                if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
567                    result
568                } else {
569                    tracing::warn!(
570                        timeout_secs = embed_timeout.as_secs(),
571                        "semantic index: embedding probe timed out"
572                    );
573                    Err(zeph_llm::LlmError::Timeout)
574                }
575            })
576        };
577
578        // Clone tools before .await to avoid holding &self.services.mcp.tools across an await point.
579        let tools = self.services.mcp.tools.clone();
580        match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
581            Ok(idx) => {
582                tracing::info!(
583                    indexed = idx.len(),
584                    total = self.services.mcp.tools.len(),
585                    "semantic tool index built"
586                );
587                self.services.mcp.semantic_index = Some(idx);
588            }
589            Err(e) => {
590                tracing::warn!(
591                    "semantic tool index build failed, falling back to all tools: {e:#}"
592                );
593                self.services.mcp.semantic_index = None;
594            }
595        }
596    }
597}
598
/// SEC-MCP-01: check a `/mcp add` target against the stdio command allowlist.
///
/// A target is allowed when any of the following holds:
/// - it starts with `http://` or `https://` (URL targets are not commands);
/// - the allowlist is empty (no restriction configured);
/// - it equals one of the allowlisted commands exactly.
///
/// Returns `None` when allowed, otherwise `Some(message)` naming the rejected
/// command and the permitted ones.
fn validate_mcp_command(target: &str, allowed_commands: &[String]) -> Option<String> {
    let looks_like_url = ["http://", "https://"]
        .iter()
        .any(|scheme| target.starts_with(*scheme));
    if looks_like_url || allowed_commands.is_empty() {
        return None;
    }

    if allowed_commands.iter().any(|permitted| permitted == target) {
        return None;
    }

    let permitted = allowed_commands.join(", ");
    Some(format!(
        "Command '{target}' is not allowed. Permitted: {permitted}"
    ))
}
613
614/// Build a `ServerEntry` for a newly added MCP server from parsed `/mcp add` arguments.
615fn build_server_entry(id: &str, target: &str, extra_args: &[&str]) -> zeph_mcp::ServerEntry {
616    let is_url = target.starts_with("http://") || target.starts_with("https://");
617    let transport = if is_url {
618        zeph_mcp::McpTransport::Http {
619            url: target.to_owned(),
620            headers: std::collections::HashMap::new(),
621        }
622    } else {
623        zeph_mcp::McpTransport::Stdio {
624            command: target.to_owned(),
625            args: extra_args.iter().map(|&s| s.to_owned()).collect(),
626            env: std::collections::HashMap::new(),
627        }
628    };
629    zeph_mcp::ServerEntry {
630        id: id.to_owned(),
631        transport,
632        timeout: std::time::Duration::from_secs(30),
633        trust_level: zeph_config::McpTrustLevel::Untrusted,
634        tool_allowlist: None,
635        expected_tools: Vec::new(),
636        roots: Vec::new(),
637        tool_metadata: std::collections::HashMap::new(),
638        elicitation_enabled: false,
639        elicitation_timeout_secs: 120,
640        env_isolation: false,
641    }
642}
643
644/// Convert an rmcp `ElicitationSchema` into channel-agnostic `ElicitationField` list.
645fn build_elicitation_fields(
646    schema: &rmcp::model::ElicitationSchema,
647) -> Vec<crate::channel::ElicitationField> {
648    use crate::channel::{ElicitationField, ElicitationFieldType};
649    use rmcp::model::PrimitiveSchema;
650
651    schema
652        .properties
653        .iter()
654        .map(|(name, prop)| {
655            // Extract field type and description by serializing the PrimitiveSchema to JSON
656            // and reading the discriminator field.  This avoids deep-matching the nested
657            // EnumSchema / StringSchema / … variants of rmcp's type-safe schema hierarchy.
658            let json = serde_json::to_value(prop).unwrap_or_default();
659            let description = json
660                .get("description")
661                .and_then(|v| v.as_str())
662                .map(sanitize_elicitation_message);
663
664            let field_type = match prop {
665                PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
666                PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
667                PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
668                PrimitiveSchema::String(_) => ElicitationFieldType::String,
669                PrimitiveSchema::Enum(_) => {
670                    // Extract enum values from the serialized form.  All EnumSchema variants
671                    // serialise their allowed values under "enum" or inside "items.enum".
672                    let vals = json
673                        .get("enum")
674                        .and_then(|v| v.as_array())
675                        .map(|arr| {
676                            arr.iter()
677                                .filter_map(|v| v.as_str())
678                                .map(sanitize_elicitation_message)
679                                .collect::<Vec<_>>()
680                        })
681                        .unwrap_or_default();
682                    ElicitationFieldType::Enum(vals)
683                }
684            };
685            let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
686            ElicitationField {
687                // Keep the raw schema key intact — it is used verbatim as the response map key.
688                // Channels sanitize it at display time only (see build_field_prompt / build_telegram_field_prompt).
689                name: name.clone(),
690                description,
691                field_type,
692                required,
693            }
694        })
695        .collect()
696}
697
/// Substrings that mark a field name as sensitive.
///
/// Matching is done against the lower-cased field name, so entries must be
/// lower-case.  Some entries are substrings of others (e.g. `key` / `api_key`);
/// that is harmless because any single hit is enough.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    // Passwords and passphrases.
    "password",
    "passwd",
    // Tokens, secrets and keys.
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    // Authentication material.
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Reports whether `field_name`, compared case-insensitively, contains any entry
/// of `SENSITIVE_FIELD_PATTERNS` as a substring.
fn is_sensitive_field(field_name: &str) -> bool {
    let normalized = field_name.to_lowercase();
    for pattern in SENSITIVE_FIELD_PATTERNS {
        if normalized.contains(*pattern) {
            return true;
        }
    }
    false
}
722
/// Clean server-supplied text before it is shown to the user.
///
/// Control characters are dropped, except `\n` and `\t`, and the result is cut
/// off after 500 characters.  The cap counts `char`s that survive filtering,
/// not bytes.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in message.chars() {
        if kept == MAX_CHARS {
            break;
        }
        // Newline and tab are harmless layout characters; other control
        // characters could manipulate a terminal and are removed.
        let allowed = matches!(ch, '\n' | '\t') || !ch.is_control();
        if allowed {
            cleaned.push(ch);
            kept += 1;
        }
    }
    cleaned
}
733
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an agent wired to mocks only: an empty mock provider, an empty
    /// mock channel, the test registry and an executor exposing no tools.
    ///
    /// No MCP manager, tool list or refresh receiver is configured here; tests
    /// that need one set it on `agent.services.mcp` themselves.
    fn test_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Runs `/mcp <args>` against a fresh [`test_agent`] and returns its output.
    ///
    /// Panics if the command returns an error, which fails the calling test.
    async fn run_mcp(args: &str) -> String {
        let mut agent = test_agent();
        agent.handle_mcp_command(args).await.unwrap()
    }

    /// Builds a minimal MCP tool named `name` on server `"srv"` with an empty
    /// description, an empty JSON input schema and default security metadata.
    fn test_tool(name: &'static str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            output_schema: None,
            security_meta: zeph_config::mcp_security::ToolSecurityMeta::default(),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let result = run_mcp("unknown").await;
        assert!(
            result.contains("Usage: /mcp"),
            "expected usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let result = run_mcp("list").await;
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let result = run_mcp("tools").await;
        assert!(
            result.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let result = run_mcp("remove").await;
        assert!(
            result.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let result = run_mcp("remove my-server").await;
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        // "add" with only 1 arg (needs at least 2: id + command)
        let result = run_mcp("add server-id").await;
        assert!(
            result.contains("Usage: /mcp add"),
            "expected add usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        // mcp.tools is empty, so any server will have no tools
        let result = run_mcp("tools nonexistent-server").await;
        assert!(
            result.contains("No tools found"),
            "expected no-tools message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = test_agent();

        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = test_agent();
        // No tool_rx set; check_tool_refresh should be a no-op.
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = test_agent();

        // `_tx` is kept alive until the end of the test so the channel stays open.
        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        // No changes sent; has_changed() returns false.
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = test_agent();
        agent.services.mcp.tools = vec![test_tool("existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        // has_changed() is false for a fresh receiver; tools unchanged.
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);

        tx.send(vec![test_tool("refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
        assert_eq!(agent.services.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        // Build a 600-char ASCII string — no multi-byte boundary issue.
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // "é" is 2 bytes. The 1-byte prefix shifts every "é" to an odd offset,
        // so byte 500 falls inside a character and a naive `&input[..500]`
        // would panic. 601 chars also exceeds the cap, so truncation really runs.
        let input = format!("a{}", "é".repeat(600));
        assert!(
            !input.is_char_boundary(500),
            "test input must place byte 500 inside a multi-byte char"
        );

        let output = sanitize_elicitation_message(&input);
        // Truncated to exactly 500 chars, on a character boundary, without panic.
        assert_eq!(output.chars().count(), 500);
        assert!(input.starts_with(&output));
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        // One property per primitive schema kind.
        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        // Look fields up by name so the test does not depend on output order.
        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        // Only "req" is listed as required; "opt" must come back optional.
        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}