Skip to main content

zeph_core/agent/
mcp.rs

// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9    /// Dispatch a `/mcp` subcommand, returning the output as a `String`.
10    ///
11    /// All output is collected into the returned string; no channel sends are
12    /// performed.  This makes the future `Send`-compatible for use in
13    /// `AgentAccess::handle_mcp`.
14    #[tracing::instrument(skip_all, name = "core.agent.handle_mcp_command")]
15    pub(super) async fn handle_mcp_command(
16        &mut self,
17        args: &str,
18    ) -> Result<String, super::error::AgentError> {
19        let parts: Vec<&str> = args.split_whitespace().collect();
20        match parts.first().copied() {
21            Some("add") => self.handle_mcp_add(&parts[1..]).await,
22            Some("list") => self.handle_mcp_list().await,
23            Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
24            Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
25            _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
26        }
27    }
28
29    async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
30        if args.len() < 2 {
31            return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
32        }
33
34        // Clone the Arc so no borrow of self.services.mcp.manager is held across .await.
35        let Some(manager) = self.services.mcp.manager.clone() else {
36            return Ok("MCP is not enabled.".to_owned());
37        };
38
39        let target = args[1];
40        if let Some(err) = validate_mcp_command(target, &self.services.mcp.allowed_commands) {
41            return Ok(err);
42        }
43
44        // SEC-MCP-03: enforce server limit
45        let current_count = manager.list_servers().await.len();
46        if current_count >= self.services.mcp.max_dynamic {
47            return Ok(format!(
48                "Server limit reached ({}/{}).",
49                current_count, self.services.mcp.max_dynamic
50            ));
51        }
52
53        let entry = build_server_entry(args[0], target, &args[2..]);
54
55        match manager.add_server(&entry).await {
56            Ok(tools) => {
57                let count = tools.len();
58                self.services
59                    .mcp
60                    .server_outcomes
61                    .push(zeph_mcp::ServerConnectOutcome {
62                        id: entry.id.clone(),
63                        connected: true,
64                        tool_count: count,
65                        error: String::new(),
66                    });
67                self.services.mcp.tools.extend(tools);
68                self.services.mcp.sync_executor_tools();
69                self.services.mcp.pruning_cache.reset();
70                // Defer rebuild to check_tool_refresh (next turn) so this method
71                // stays Send-compatible for use in AgentAccess::handle_mcp.
72                self.services.mcp.pending_semantic_rebuild = true;
73                self.update_mcp_metrics();
74                Ok(format!(
75                    "Connected MCP server '{}' ({count} tool(s))",
76                    entry.id
77                ))
78            }
79            Err(e) => {
80                tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
81                Ok(format!("Failed to connect server '{}': {e}", entry.id))
82            }
83        }
84    }
85
86    async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
87        use std::fmt::Write;
88
89        let Some(manager) = self.services.mcp.manager.clone() else {
90            return Ok("MCP is not enabled.".to_owned());
91        };
92
93        let server_ids = manager.list_servers().await;
94        if server_ids.is_empty() {
95            return Ok("No MCP servers connected.".to_owned());
96        }
97
98        let mut output = String::from("Connected MCP servers:\n");
99        let mut total = 0usize;
100        for id in &server_ids {
101            let count = self
102                .services
103                .mcp
104                .tools
105                .iter()
106                .filter(|t| t.server_id == *id)
107                .count();
108            total += count;
109            let _ = writeln!(output, "- {id} ({count} tools)");
110        }
111        let _ = write!(output, "Total: {total} tool(s)");
112
113        Ok(output)
114    }
115
116    fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
117        use std::fmt::Write;
118
119        let Some(server_id) = server_id else {
120            return "Usage: /mcp tools <server_id>".to_owned();
121        };
122
123        let tools: Vec<_> = self
124            .services
125            .mcp
126            .tools
127            .iter()
128            .filter(|t| t.server_id == server_id)
129            .collect();
130
131        if tools.is_empty() {
132            return format!("No tools found for server '{server_id}'.");
133        }
134
135        let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
136        for t in &tools {
137            if t.description.is_empty() {
138                let _ = writeln!(output, "- {}", t.name);
139            } else {
140                let _ = writeln!(output, "- {} — {}", t.name, t.description);
141            }
142        }
143        output
144    }
145
146    async fn handle_mcp_remove(
147        &mut self,
148        server_id: Option<&str>,
149    ) -> Result<String, super::error::AgentError> {
150        let Some(server_id) = server_id else {
151            return Ok("Usage: /mcp remove <id>".to_owned());
152        };
153
154        // Clone the Arc so no borrow of self.services.mcp.manager is held across .await.
155        let Some(manager) = self.services.mcp.manager.clone() else {
156            return Ok("MCP is not enabled.".to_owned());
157        };
158
159        match manager.remove_server(server_id).await {
160            Ok(()) => {
161                let before = self.services.mcp.tools.len();
162                self.services.mcp.tools.retain(|t| t.server_id != server_id);
163                let removed = before - self.services.mcp.tools.len();
164                self.services
165                    .mcp
166                    .server_outcomes
167                    .retain(|o| o.id != server_id);
168                self.services.mcp.sync_executor_tools();
169                self.services.mcp.pruning_cache.reset();
170                // Defer rebuild to check_tool_refresh (next turn) so this method
171                // stays Send-compatible for use in AgentAccess::handle_mcp.
172                self.services.mcp.pending_semantic_rebuild = true;
173                self.update_mcp_metrics();
174                let sid = server_id.to_owned();
175                self.update_metrics(|m| {
176                    m.active_mcp_tools
177                        .retain(|name| !name.starts_with(&format!("{sid}:")));
178                });
179                Ok(format!(
180                    "Disconnected MCP server '{server_id}' (removed {removed} tools)"
181                ))
182            }
183            Err(e) => {
184                tracing::warn!(server_id, "MCP remove failed: {e:#}");
185                Ok(format!("Failed to remove server '{server_id}': {e}"))
186            }
187        }
188    }
189
190    pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
191        let matched_tools = self.match_mcp_tools(query).await;
192        let active_mcp: Vec<String> = matched_tools
193            .iter()
194            .map(zeph_mcp::McpTool::qualified_name)
195            .collect();
196        let mcp_total = self.services.mcp.tools.len();
197        let (mcp_server_count, mcp_connected_count) =
198            if self.services.mcp.server_outcomes.is_empty() {
199                let connected = self
200                    .services
201                    .mcp
202                    .tools
203                    .iter()
204                    .map(|t| &t.server_id)
205                    .collect::<std::collections::HashSet<_>>()
206                    .len();
207                (connected, connected)
208            } else {
209                let total = self.services.mcp.server_outcomes.len();
210                let connected = self
211                    .services
212                    .mcp
213                    .server_outcomes
214                    .iter()
215                    .filter(|o| o.connected)
216                    .count();
217                (total, connected)
218            };
219        self.update_metrics(|m| {
220            m.active_mcp_tools = active_mcp;
221            m.mcp_tool_count = mcp_total;
222            m.mcp_server_count = mcp_server_count;
223            m.mcp_connected_count = mcp_connected_count;
224        });
225        if let Some(ref manager) = self.services.mcp.manager {
226            let instructions = manager.all_server_instructions().await;
227            if !instructions.is_empty() {
228                system_prompt.push_str("\n\n");
229                system_prompt.push_str(&instructions);
230            }
231        }
232        if !matched_tools.is_empty() {
233            let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
234            tracing::debug!(
235                skills = ?self.services.skill.active_skill_names,
236                mcp_tools = ?tool_names,
237                "matched items"
238            );
239            let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
240            if !tools_prompt.is_empty() {
241                system_prompt.push_str("\n\n");
242                system_prompt.push_str(&tools_prompt);
243            }
244        }
245    }
246
247    async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
248        let Some(ref registry) = self.services.mcp.registry else {
249            return self.services.mcp.tools.clone();
250        };
251        let provider = self.embedding_provider.clone();
252        registry
253            .search(query, self.services.skill.max_active_skills, |text| {
254                let owned = text.to_owned();
255                let p = provider.clone();
256                Box::pin(async move { p.embed(&owned).await })
257            })
258            .await
259    }
260
261    /// Poll the watch receiver for tool list updates from `tools/list_changed` notifications,
262    /// and process any deferred semantic index rebuild requests.
263    ///
264    /// Called once per agent turn, before processing user input.  Two triggers cause a rebuild:
265    /// - A `tools/list_changed` notification from an MCP server (via `tool_rx`).
266    /// - `pending_semantic_rebuild == true`, set by `/mcp add` or `/mcp remove` when dispatched
267    ///   via `AgentAccess::handle_mcp` (which cannot call `rebuild_semantic_index` directly
268    ///   because the future would be `!Send`).
269    ///
270    /// If neither trigger fires, this is a no-op.
271    pub(super) async fn check_tool_refresh(&mut self) {
272        // Handle deferred rebuild from /mcp add|remove via AgentAccess.
273        if self.services.mcp.pending_semantic_rebuild {
274            self.services.mcp.pending_semantic_rebuild = false;
275            self.rebuild_semantic_index().await;
276            self.sync_mcp_registry().await;
277            let mcp_total = self.services.mcp.tools.len();
278            let mcp_servers = self
279                .services
280                .mcp
281                .tools
282                .iter()
283                .map(|t| &t.server_id)
284                .collect::<std::collections::HashSet<_>>()
285                .len();
286            self.update_metrics(|m| {
287                m.mcp_tool_count = mcp_total;
288                m.mcp_server_count = mcp_servers;
289            });
290        }
291
292        let Some(ref mut rx) = self.services.mcp.tool_rx else {
293            return;
294        };
295        if !rx.has_changed().unwrap_or(false) {
296            return;
297        }
298        let new_tools = rx.borrow_and_update().clone();
299        if new_tools.is_empty() {
300            // Guard against replacing a non-empty initial tool list with the watch's empty
301            // initial value. The watch is only updated after a real tools/list_changed event.
302            return;
303        }
304        tracing::info!(
305            tools = new_tools.len(),
306            "tools/list_changed: agent tool list refreshed"
307        );
308        self.services.mcp.tools = new_tools;
309        self.services.mcp.sync_executor_tools();
310        self.services.mcp.pruning_cache.reset();
311        self.rebuild_semantic_index().await;
312        self.sync_mcp_registry().await;
313        let mcp_total = self.services.mcp.tools.len();
314        let mcp_servers = self
315            .services
316            .mcp
317            .tools
318            .iter()
319            .map(|t| &t.server_id)
320            .collect::<std::collections::HashSet<_>>()
321            .len();
322        self.update_metrics(|m| {
323            m.mcp_tool_count = mcp_total;
324            m.mcp_server_count = mcp_servers;
325        });
326    }
327
328    pub(super) async fn sync_mcp_registry(&mut self) {
329        if self.services.mcp.registry.is_none() {
330            return;
331        }
332        if !self.embedding_provider.supports_embeddings() {
333            return;
334        }
335        // Clone tools before .await to avoid holding &self.services.mcp.tools across an await point.
336        let tools = self.services.mcp.tools.clone();
337        let provider = self.embedding_provider.clone();
338        let embedding_model = self.services.skill.embedding_model.clone();
339        let embed_timeout =
340            std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
341        let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
342            let owned = text.to_owned();
343            let p = provider.clone();
344            Box::pin(async move {
345                if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
346                    result
347                } else {
348                    tracing::warn!(
349                        timeout_secs = embed_timeout.as_secs(),
350                        "MCP registry: embedding timed out"
351                    );
352                    Err(zeph_llm::LlmError::Timeout)
353                }
354            })
355        };
356        // Take registry out of self to avoid holding &mut self.services.mcp.registry across .await.
357        // No early returns between take() and put-back — the await is the only yield point here.
358        let Some(mut registry) = self.services.mcp.registry.take() else {
359            return;
360        };
361        if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
362            tracing::warn!("failed to sync MCP tool registry: {e:#}");
363        }
364        self.services.mcp.registry = Some(registry);
365    }
366
367    /// Build (or rebuild) the in-memory semantic tool index for embedding-based discovery.
368    /// Build the initial semantic tool index after agent construction.
369    ///
370    /// Must be called once after `with_mcp` and `with_mcp_discovery` are applied,
371    /// before the first user turn.  Subsequent rebuilds happen automatically on
372    /// tool list change events (`check_tool_refresh`, `/mcp add`, `/mcp remove`).
373    pub async fn init_semantic_index(&mut self) {
374        self.rebuild_semantic_index().await;
375    }
376
377    /// Drain and process all pending elicitation requests without blocking.
378    ///
379    /// Call this at the start of each turn and between tool calls to prevent
380    /// elicitation events from accumulating while the agent loop is busy.
381    pub(super) async fn process_pending_elicitations(&mut self) {
382        loop {
383            let Some(ref mut rx) = self.services.mcp.elicitation_rx else {
384                return;
385            };
386            match rx.try_recv() {
387                Ok(event) => {
388                    self.handle_elicitation_event(event).await;
389                }
390                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
391                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
392                    self.services.mcp.elicitation_rx = None;
393                    return;
394                }
395            }
396        }
397    }
398
399    /// Handle a single elicitation event by routing it to the active channel.
400    pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
401        use crate::channel::{ElicitationRequest, ElicitationResponse};
402
403        let decline = CreateElicitationResult {
404            action: ElicitationAction::Decline,
405            content: None,
406            meta: None,
407        };
408
409        let channel_request = match &event.request {
410            rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
411                message,
412                requested_schema,
413                ..
414            } => {
415                let fields = build_elicitation_fields(requested_schema);
416                ElicitationRequest {
417                    server_name: event.server_id.clone(),
418                    message: sanitize_elicitation_message(message),
419                    fields,
420                }
421            }
422            rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
423                // URL elicitation not supported in phase 1 — decline.
424                tracing::debug!(
425                    server_id = event.server_id,
426                    "URL elicitation not supported, declining"
427                );
428                let _ = event.response_tx.send(decline);
429                return;
430            }
431        };
432
433        if self.services.mcp.elicitation_warn_sensitive_fields {
434            let sensitive: Vec<&str> = channel_request
435                .fields
436                .iter()
437                .filter(|f| is_sensitive_field(&f.name))
438                .map(|f| f.name.as_str())
439                .collect();
440            if !sensitive.is_empty() {
441                let fields_list = sensitive.join(", ");
442                let warning = format!(
443                    "Warning: [{}] is requesting sensitive information (field: {}). \
444                     Only proceed if you trust this server.",
445                    channel_request.server_name, fields_list,
446                );
447                tracing::warn!(
448                    server_id = event.server_id,
449                    fields = %fields_list,
450                    "elicitation requests sensitive fields"
451                );
452                let _ = self.channel.send(&warning).await;
453            }
454        }
455
456        let _ = self
457            .channel
458            .send_status("MCP server requesting input…")
459            .await;
460        let response = match self.channel.elicit(channel_request).await {
461            Ok(r) => r,
462            Err(e) => {
463                tracing::warn!(
464                    server_id = event.server_id,
465                    "elicitation channel error: {e:#}"
466                );
467                let _ = self.channel.send_status("").await;
468                let _ = event.response_tx.send(decline);
469                return;
470            }
471        };
472        let _ = self.channel.send_status("").await;
473
474        let result = match response {
475            ElicitationResponse::Accepted(value) => CreateElicitationResult {
476                action: ElicitationAction::Accept,
477                content: Some(value),
478                meta: None,
479            },
480            ElicitationResponse::Declined => CreateElicitationResult {
481                action: ElicitationAction::Decline,
482                content: None,
483                meta: None,
484            },
485            ElicitationResponse::Cancelled => CreateElicitationResult {
486                action: ElicitationAction::Cancel,
487                content: None,
488                meta: None,
489            },
490        };
491
492        if event.response_tx.send(result).is_err() {
493            tracing::warn!(
494                server_id = event.server_id,
495                "elicitation response dropped — handler disconnected"
496            );
497        }
498    }
499
500    fn update_mcp_metrics(&mut self) {
501        let mcp_total = self.services.mcp.tools.len();
502        let mcp_server_count = self.services.mcp.server_outcomes.len();
503        let mcp_connected_count = self
504            .services
505            .mcp
506            .server_outcomes
507            .iter()
508            .filter(|o| o.connected)
509            .count();
510        let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
511            .services
512            .mcp
513            .server_outcomes
514            .iter()
515            .map(|o| crate::metrics::McpServerStatus {
516                id: o.id.clone(),
517                status: if o.connected {
518                    crate::metrics::McpServerConnectionStatus::Connected
519                } else {
520                    crate::metrics::McpServerConnectionStatus::Failed
521                },
522                tool_count: o.tool_count,
523                error: o.error.clone(),
524            })
525            .collect();
526        self.update_metrics(|m| {
527            m.mcp_tool_count = mcp_total;
528            m.mcp_server_count = mcp_server_count;
529            m.mcp_connected_count = mcp_connected_count;
530            m.mcp_servers = mcp_servers;
531        });
532    }
533
534    /// Rebuild the in-memory semantic tool index.
535    ///
536    /// Only runs when `discovery_strategy == Embedding`.  On failure (all embeddings fail),
537    /// sets `semantic_index = None` and logs at WARN — the caller falls back to all tools.
538    ///
539    /// Called at:
540    /// - initial setup via `init_semantic_index()`
541    /// - `tools/list_changed` notification
542    /// - `/mcp add` and `/mcp remove`
543    pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
544        if self.services.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
545            return;
546        }
547
548        if self.services.mcp.tools.is_empty() {
549            self.services.mcp.semantic_index = None;
550            return;
551        }
552
553        // Resolve embedding provider: dedicated discovery provider → primary embedding provider.
554        let provider = self
555            .services
556            .mcp
557            .discovery_provider
558            .clone()
559            .unwrap_or_else(|| self.embedding_provider.clone());
560
561        let inner_embed = provider.embed_fn();
562        let embed_timeout =
563            std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
564        let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
565            let fut = inner_embed(text);
566            Box::pin(async move {
567                if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
568                    result
569                } else {
570                    tracing::warn!(
571                        timeout_secs = embed_timeout.as_secs(),
572                        "semantic index: embedding probe timed out"
573                    );
574                    Err(zeph_llm::LlmError::Timeout)
575                }
576            })
577        };
578
579        // Clone tools before .await to avoid holding &self.services.mcp.tools across an await point.
580        let tools = self.services.mcp.tools.clone();
581        match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
582            Ok(idx) => {
583                tracing::info!(
584                    indexed = idx.len(),
585                    total = self.services.mcp.tools.len(),
586                    "semantic tool index built"
587                );
588                self.services.mcp.semantic_index = Some(idx);
589            }
590            Err(e) => {
591                tracing::warn!(
592                    "semantic tool index build failed, falling back to all tools: {e:#}"
593                );
594                self.services.mcp.semantic_index = None;
595            }
596        }
597    }
598}
599
/// SEC-MCP-01: check a `/mcp add` target against the stdio command allowlist.
///
/// `http://` and `https://` targets are never checked, and an empty allowlist permits
/// every command.  Otherwise the target must equal one of `allowed_commands` exactly.
///
/// Returns `Some(error_message)` when the command is blocked, `None` when it is allowed.
fn validate_mcp_command(target: &str, allowed_commands: &[String]) -> Option<String> {
    let targets_url = ["http://", "https://"]
        .iter()
        .any(|&scheme| target.starts_with(scheme));
    if targets_url || allowed_commands.is_empty() {
        return None;
    }
    if allowed_commands.iter().any(|allowed| allowed == target) {
        return None;
    }
    Some(format!(
        "Command '{target}' is not allowed. Permitted: {}",
        allowed_commands.join(", ")
    ))
}
614
615/// Build a `ServerEntry` for a newly added MCP server from parsed `/mcp add` arguments.
616fn build_server_entry(id: &str, target: &str, extra_args: &[&str]) -> zeph_mcp::ServerEntry {
617    let is_url = target.starts_with("http://") || target.starts_with("https://");
618    let transport = if is_url {
619        zeph_mcp::McpTransport::Http {
620            url: target.to_owned(),
621            headers: std::collections::HashMap::new(),
622        }
623    } else {
624        zeph_mcp::McpTransport::Stdio {
625            command: target.to_owned(),
626            args: extra_args.iter().map(|&s| s.to_owned()).collect(),
627            env: std::collections::HashMap::new(),
628        }
629    };
630    zeph_mcp::ServerEntry {
631        id: id.to_owned(),
632        transport,
633        timeout: std::time::Duration::from_secs(30),
634        trust_level: zeph_config::McpTrustLevel::Untrusted,
635        tool_allowlist: None,
636        expected_tools: Vec::new(),
637        roots: Vec::new(),
638        tool_metadata: std::collections::HashMap::new(),
639        elicitation_enabled: false,
640        elicitation_timeout_secs: 120,
641        env_isolation: false,
642    }
643}
644
645/// Convert an rmcp `ElicitationSchema` into channel-agnostic `ElicitationField` list.
646fn build_elicitation_fields(
647    schema: &rmcp::model::ElicitationSchema,
648) -> Vec<crate::channel::ElicitationField> {
649    use crate::channel::{ElicitationField, ElicitationFieldType};
650    use rmcp::model::PrimitiveSchema;
651
652    schema
653        .properties
654        .iter()
655        .map(|(name, prop)| {
656            // Extract field type and description by serializing the PrimitiveSchema to JSON
657            // and reading the discriminator field.  This avoids deep-matching the nested
658            // EnumSchema / StringSchema / … variants of rmcp's type-safe schema hierarchy.
659            let json = serde_json::to_value(prop).unwrap_or_default();
660            let description = json
661                .get("description")
662                .and_then(|v| v.as_str())
663                .map(sanitize_elicitation_message);
664
665            let field_type = match prop {
666                PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
667                PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
668                PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
669                PrimitiveSchema::String(_) => ElicitationFieldType::String,
670                PrimitiveSchema::Enum(_) => {
671                    // Extract enum values from the serialized form.  All EnumSchema variants
672                    // serialise their allowed values under "enum" or inside "items.enum".
673                    let vals = json
674                        .get("enum")
675                        .and_then(|v| v.as_array())
676                        .map(|arr| {
677                            arr.iter()
678                                .filter_map(|v| v.as_str())
679                                .map(sanitize_elicitation_message)
680                                .collect::<Vec<_>>()
681                        })
682                        .unwrap_or_default();
683                    ElicitationFieldType::Enum(vals)
684                }
685            };
686            let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
687            ElicitationField {
688                // Keep the raw schema key intact — it is used verbatim as the response map key.
689                // Channels sanitize it at display time only (see build_field_prompt / build_telegram_field_prompt).
690                name: name.clone(),
691                description,
692                field_type,
693                required,
694            }
695        })
696        .collect()
697}
698
/// Lowercase name fragments that mark an elicitation field as sensitive
/// (case-insensitive substring match).
///
/// Because matching is by substring, short entries such as `"key"`, `"auth"` and `"pin"`
/// also hit names like `"keyboard"`, `"author"` or `"mapping"`, and `"apikey"`,
/// `"api_key"` and `"authorization"` are already implied by `"key"` / `"auth"`.
/// NOTE(review): the false positives only add an extra warning; confirm whether that
/// trade-off is intended before narrowing the list.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    // Passwords
    "password", "passwd",
    // Tokens, secrets, keys and credentials
    "token", "secret", "key", "credential", "apikey", "api_key",
    // Authentication
    "auth", "authorization",
    // Private keys, passphrases and PINs
    "private", "passphrase", "pin",
];
715
716/// Returns `true` when `field_name` matches any sensitive pattern (case-insensitive).
717fn is_sensitive_field(field_name: &str) -> bool {
718    let lower = field_name.to_lowercase();
719    SENSITIVE_FIELD_PATTERNS
720        .iter()
721        .any(|pattern| lower.contains(pattern))
722}
723
/// Sanitize untrusted MCP server text (elicitation messages, field descriptions and enum
/// options) before it is displayed.
///
/// - Control characters are dropped, except `\n` and `\t`, so a server cannot inject
///   terminal escape sequences (`ESC [ …`) or carriage-return overwrites.
/// - Unicode bidirectional embedding/override (U+202A–U+202E) and isolate
///   (U+2066–U+2069) characters are dropped so a server cannot visually reorder the text
///   the user reads ("Trojan Source"-style spoofing).  These are format (`Cf`)
///   characters, which `char::is_control` does not cover.
/// - The result is capped at 500 `char`s (not bytes), so the cut never splits a UTF-8
///   sequence.  No truncation marker is appended.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    /// Explicit bidi formatting characters that can reorder the surrounding text.
    fn is_bidi_override(c: char) -> bool {
        matches!(c, '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}')
    }

    message
        .chars()
        .filter(|&c| (!c.is_control() || c == '\n' || c == '\t') && !is_bidi_override(c))
        .take(MAX_CHARS)
        .collect()
}
734
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds a fresh agent wired to the shared mocks: an empty provider script, an empty
    /// channel, the default test registry and an executor exposing no tools.
    fn test_agent() -> Agent<MockChannel> {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        Agent::new(provider, channel, registry, None, 5, executor)
    }

    /// Minimal MCP tool descriptor on server `"srv"` with the given name.
    fn mcp_tool(name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            output_schema: None,
            security_meta: zeph_config::mcp_security::ToolSecurityMeta::default(),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("unknown").await.unwrap();
        assert!(
            result.contains("Usage: /mcp"),
            "expected usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("list").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("tools").await.unwrap();
        assert!(
            result.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("remove").await.unwrap();
        assert!(
            result.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("remove my-server").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = test_agent();

        // "add" with only 1 arg (needs at least 2: id + command)
        let result = agent.handle_mcp_command("add server-id").await.unwrap();
        assert!(
            result.contains("Usage: /mcp add"),
            "expected add usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = test_agent();

        // mcp.tools is empty, so any server will have no tools
        let result = agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert!(
            result.contains("No tools found"),
            "expected no-tools message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = test_agent();

        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = test_agent();
        // No tool_rx set; check_tool_refresh should be a no-op.
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.services.mcp.tool_rx = Some(rx);
        // No changes sent; has_changed() returns false.
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = test_agent();
        agent.services.mcp.tools = vec![mcp_tool("existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        // has_changed() is false for a fresh receiver; tools unchanged.
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);

        tx.send(vec![mcp_tool("refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
        assert_eq!(agent.services.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        // Build a 600-char ASCII string — no multi-byte boundary issue.
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_keeps_short_multibyte_input() {
        // 300 chars (600 bytes) is under the 500-char cap, so nothing is truncated.
        let input: String = "é".repeat(300);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output, input);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // "é" is 2 bytes. The leading ASCII byte puts every "é" at an odd byte offset, so a
        // naive byte slice `&input[..500]` would land inside a code point and panic.
        let input = format!("a{}", "é".repeat(600)); // 601 chars, 1201 bytes
        assert!(!input.is_char_boundary(500));
        let output = sanitize_elicitation_message(&input);
        // Truncation counts chars: "a" followed by 499 "é".
        assert_eq!(output.chars().count(), 500);
        assert_eq!(output, format!("a{}", "é".repeat(499)));
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}