Skip to main content

zeph_core/agent/
mcp.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9    pub(super) async fn handle_mcp_command(
10        &mut self,
11        args: &str,
12    ) -> Result<(), super::error::AgentError> {
13        let parts: Vec<&str> = args.split_whitespace().collect();
14        match parts.first().copied() {
15            Some("add") => self.handle_mcp_add(&parts[1..]).await,
16            Some("list") => self.handle_mcp_list().await,
17            Some("tools") => self.handle_mcp_tools(parts.get(1).copied()).await,
18            Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
19            _ => {
20                self.channel
21                    .send("Usage: /mcp add|list|tools|remove")
22                    .await?;
23                Ok(())
24            }
25        }
26    }
27
28    #[allow(clippy::too_many_lines)]
29    async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<(), super::error::AgentError> {
30        if args.len() < 2 {
31            self.channel
32                .send("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>")
33                .await?;
34            return Ok(());
35        }
36
37        let Some(ref manager) = self.mcp.manager else {
38            self.channel.send("MCP is not enabled.").await?;
39            return Ok(());
40        };
41
42        let target = args[1];
43        let is_url = target.starts_with("http://") || target.starts_with("https://");
44
45        // SEC-MCP-01: validate command against allowlist (stdio only)
46        if !is_url
47            && !self.mcp.allowed_commands.is_empty()
48            && !self.mcp.allowed_commands.iter().any(|c| c == target)
49        {
50            self.channel
51                .send(&format!(
52                    "Command '{target}' is not allowed. Permitted: {}",
53                    self.mcp.allowed_commands.join(", ")
54                ))
55                .await?;
56            return Ok(());
57        }
58
59        // SEC-MCP-03: enforce server limit
60        let current_count = manager.list_servers().await.len();
61        if current_count >= self.mcp.max_dynamic {
62            self.channel
63                .send(&format!(
64                    "Server limit reached ({}/{}).",
65                    current_count, self.mcp.max_dynamic
66                ))
67                .await?;
68            return Ok(());
69        }
70
71        let transport = if is_url {
72            zeph_mcp::McpTransport::Http {
73                url: target.to_owned(),
74                headers: std::collections::HashMap::new(),
75            }
76        } else {
77            zeph_mcp::McpTransport::Stdio {
78                command: target.to_owned(),
79                args: args[2..].iter().map(|&s| s.to_owned()).collect(),
80                env: std::collections::HashMap::new(),
81            }
82        };
83
84        let entry = zeph_mcp::ServerEntry {
85            id: args[0].to_owned(),
86            transport,
87            timeout: std::time::Duration::from_secs(30),
88            trust_level: zeph_mcp::McpTrustLevel::Untrusted,
89            tool_allowlist: None,
90            expected_tools: Vec::new(),
91            roots: Vec::new(),
92            tool_metadata: std::collections::HashMap::new(),
93            elicitation_enabled: false,
94            elicitation_timeout_secs: 120,
95            env_isolation: false,
96        };
97
98        let _ = self.channel.send_status("connecting to mcp...").await;
99        match manager.add_server(&entry).await {
100            Ok(tools) => {
101                let _ = self.channel.send_status("").await;
102                let count = tools.len();
103                self.mcp
104                    .server_outcomes
105                    .push(zeph_mcp::ServerConnectOutcome {
106                        id: entry.id.clone(),
107                        connected: true,
108                        tool_count: count,
109                        error: String::new(),
110                    });
111                self.mcp.tools.extend(tools);
112                self.sync_mcp_executor_tools();
113                self.mcp.pruning_cache.reset();
114                self.rebuild_semantic_index().await;
115                self.sync_mcp_registry().await;
116                let mcp_total = self.mcp.tools.len();
117                let mcp_server_count = self.mcp.server_outcomes.len();
118                let mcp_connected_count = self
119                    .mcp
120                    .server_outcomes
121                    .iter()
122                    .filter(|o| o.connected)
123                    .count();
124                let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
125                    .mcp
126                    .server_outcomes
127                    .iter()
128                    .map(|o| crate::metrics::McpServerStatus {
129                        id: o.id.clone(),
130                        status: if o.connected {
131                            crate::metrics::McpServerConnectionStatus::Connected
132                        } else {
133                            crate::metrics::McpServerConnectionStatus::Failed
134                        },
135                        tool_count: o.tool_count,
136                        error: o.error.clone(),
137                    })
138                    .collect();
139                self.update_metrics(|m| {
140                    m.mcp_tool_count = mcp_total;
141                    m.mcp_server_count = mcp_server_count;
142                    m.mcp_connected_count = mcp_connected_count;
143                    m.mcp_servers = mcp_servers;
144                });
145                self.channel
146                    .send(&format!(
147                        "Connected MCP server '{}' ({count} tool(s))",
148                        entry.id
149                    ))
150                    .await?;
151                Ok(())
152            }
153            Err(e) => {
154                let _ = self.channel.send_status("").await;
155                tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
156                self.channel
157                    .send(&format!("Failed to connect server '{}': {e}", entry.id))
158                    .await?;
159                Ok(())
160            }
161        }
162    }
163
164    async fn handle_mcp_list(&mut self) -> Result<(), super::error::AgentError> {
165        use std::fmt::Write;
166
167        let Some(ref manager) = self.mcp.manager else {
168            self.channel.send("MCP is not enabled.").await?;
169            return Ok(());
170        };
171
172        let server_ids = manager.list_servers().await;
173        if server_ids.is_empty() {
174            self.channel.send("No MCP servers connected.").await?;
175            return Ok(());
176        }
177
178        let mut output = String::from("Connected MCP servers:\n");
179        let mut total = 0usize;
180        for id in &server_ids {
181            let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
182            total += count;
183            let _ = writeln!(output, "- {id} ({count} tools)");
184        }
185        let _ = write!(output, "Total: {total} tool(s)");
186
187        self.channel.send(&output).await?;
188        Ok(())
189    }
190
191    async fn handle_mcp_tools(
192        &mut self,
193        server_id: Option<&str>,
194    ) -> Result<(), super::error::AgentError> {
195        use std::fmt::Write;
196
197        let Some(server_id) = server_id else {
198            self.channel.send("Usage: /mcp tools <server_id>").await?;
199            return Ok(());
200        };
201
202        let tools: Vec<_> = self
203            .mcp
204            .tools
205            .iter()
206            .filter(|t| t.server_id == server_id)
207            .collect();
208
209        if tools.is_empty() {
210            self.channel
211                .send(&format!("No tools found for server '{server_id}'."))
212                .await?;
213            return Ok(());
214        }
215
216        let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
217        for t in &tools {
218            if t.description.is_empty() {
219                let _ = writeln!(output, "- {}", t.name);
220            } else {
221                let _ = writeln!(output, "- {} — {}", t.name, t.description);
222            }
223        }
224        self.channel.send(&output).await?;
225        Ok(())
226    }
227
228    async fn handle_mcp_remove(
229        &mut self,
230        server_id: Option<&str>,
231    ) -> Result<(), super::error::AgentError> {
232        let Some(server_id) = server_id else {
233            self.channel.send("Usage: /mcp remove <id>").await?;
234            return Ok(());
235        };
236
237        let Some(ref manager) = self.mcp.manager else {
238            self.channel.send("MCP is not enabled.").await?;
239            return Ok(());
240        };
241
242        match manager.remove_server(server_id).await {
243            Ok(()) => {
244                let before = self.mcp.tools.len();
245                self.mcp.tools.retain(|t| t.server_id != server_id);
246                let removed = before - self.mcp.tools.len();
247                self.mcp.server_outcomes.retain(|o| o.id != server_id);
248                self.sync_mcp_executor_tools();
249                self.mcp.pruning_cache.reset();
250                self.rebuild_semantic_index().await;
251                self.sync_mcp_registry().await;
252                let mcp_total = self.mcp.tools.len();
253                let mcp_server_count = self.mcp.server_outcomes.len();
254                let mcp_connected_count = self
255                    .mcp
256                    .server_outcomes
257                    .iter()
258                    .filter(|o| o.connected)
259                    .count();
260                let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
261                    .mcp
262                    .server_outcomes
263                    .iter()
264                    .map(|o| crate::metrics::McpServerStatus {
265                        id: o.id.clone(),
266                        status: if o.connected {
267                            crate::metrics::McpServerConnectionStatus::Connected
268                        } else {
269                            crate::metrics::McpServerConnectionStatus::Failed
270                        },
271                        tool_count: o.tool_count,
272                        error: o.error.clone(),
273                    })
274                    .collect();
275                self.update_metrics(|m| {
276                    m.mcp_tool_count = mcp_total;
277                    m.mcp_server_count = mcp_server_count;
278                    m.mcp_connected_count = mcp_connected_count;
279                    m.mcp_servers = mcp_servers;
280                    m.active_mcp_tools
281                        .retain(|name| !name.starts_with(&format!("{server_id}:")));
282                });
283                self.channel
284                    .send(&format!(
285                        "Disconnected MCP server '{server_id}' (removed {removed} tools)"
286                    ))
287                    .await?;
288                Ok(())
289            }
290            Err(e) => {
291                tracing::warn!(server_id, "MCP remove failed: {e:#}");
292                self.channel
293                    .send(&format!("Failed to remove server '{server_id}': {e}"))
294                    .await?;
295                Ok(())
296            }
297        }
298    }
299
300    pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
301        let matched_tools = self.match_mcp_tools(query).await;
302        let active_mcp: Vec<String> = matched_tools
303            .iter()
304            .map(zeph_mcp::McpTool::qualified_name)
305            .collect();
306        let mcp_total = self.mcp.tools.len();
307        let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
308            let connected = self
309                .mcp
310                .tools
311                .iter()
312                .map(|t| &t.server_id)
313                .collect::<std::collections::HashSet<_>>()
314                .len();
315            (connected, connected)
316        } else {
317            let total = self.mcp.server_outcomes.len();
318            let connected = self
319                .mcp
320                .server_outcomes
321                .iter()
322                .filter(|o| o.connected)
323                .count();
324            (total, connected)
325        };
326        self.update_metrics(|m| {
327            m.active_mcp_tools = active_mcp;
328            m.mcp_tool_count = mcp_total;
329            m.mcp_server_count = mcp_server_count;
330            m.mcp_connected_count = mcp_connected_count;
331        });
332        if let Some(ref manager) = self.mcp.manager {
333            let instructions = manager.all_server_instructions().await;
334            if !instructions.is_empty() {
335                system_prompt.push_str("\n\n");
336                system_prompt.push_str(&instructions);
337            }
338        }
339        // When native tool_use is active, MCP tools flow through the executor chain
340        // as ToolDefinitions — skip text prompt injection to avoid duplication.
341        if self.provider.supports_tool_use() {
342            return;
343        }
344        if !matched_tools.is_empty() {
345            let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
346            tracing::debug!(
347                skills = ?self.skill_state.active_skill_names,
348                mcp_tools = ?tool_names,
349                "matched items"
350            );
351            let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
352            if !tools_prompt.is_empty() {
353                system_prompt.push_str("\n\n");
354                system_prompt.push_str(&tools_prompt);
355            }
356        }
357    }
358
359    async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
360        let Some(ref registry) = self.mcp.registry else {
361            return self.mcp.tools.clone();
362        };
363        let provider = self.embedding_provider.clone();
364        registry
365            .search(query, self.skill_state.max_active_skills, |text| {
366                let owned = text.to_owned();
367                let p = provider.clone();
368                Box::pin(async move { p.embed(&owned).await })
369            })
370            .await
371    }
372
373    #[cfg(test)]
374    pub(crate) fn mcp_tool_count(&self) -> usize {
375        self.mcp.tools.len()
376    }
377
378    /// Poll the watch receiver for tool list updates from `tools/list_changed` notifications.
379    ///
380    /// Called once per agent turn, before processing user input. When the tool list has changed,
381    /// updates `mcp.tools`, syncs the executor, and schedules a registry sync.
382    /// If no receiver is set (MCP disabled), or no change has occurred, this is a no-op.
383    pub(super) async fn check_tool_refresh(&mut self) {
384        let Some(ref mut rx) = self.mcp.tool_rx else {
385            return;
386        };
387        if !rx.has_changed().unwrap_or(false) {
388            return;
389        }
390        let new_tools = rx.borrow_and_update().clone();
391        if new_tools.is_empty() {
392            // Guard against replacing a non-empty initial tool list with the watch's empty
393            // initial value. The watch is only updated after a real tools/list_changed event.
394            return;
395        }
396        tracing::info!(
397            tools = new_tools.len(),
398            "tools/list_changed: agent tool list refreshed"
399        );
400        self.mcp.tools = new_tools;
401        self.sync_mcp_executor_tools();
402        self.mcp.pruning_cache.reset();
403        self.rebuild_semantic_index().await;
404        self.sync_mcp_registry().await;
405        let mcp_total = self.mcp.tools.len();
406        let mcp_servers = self
407            .mcp
408            .tools
409            .iter()
410            .map(|t| &t.server_id)
411            .collect::<std::collections::HashSet<_>>()
412            .len();
413        self.update_metrics(|m| {
414            m.mcp_tool_count = mcp_total;
415            m.mcp_server_count = mcp_servers;
416        });
417    }
418
419    /// Write the **full** `self.mcp.tools` set to the shared executor `RwLock`.
420    ///
421    /// This is the first of two writers to `mcp.shared_tools`.  Within a turn
422    /// this method must run **before** `apply_pruned_mcp_tools`, which writes the
423    /// pruned subset.  The normal call order guarantees this: tool-list change
424    /// events (notify, `/mcp add`, `/mcp remove`) call this method, and pruning
425    /// runs later inside `rebuild_system_prompt`.
426    /// See also: `McpState::shared_tools` doc comment.
427    pub(super) fn sync_mcp_executor_tools(&self) {
428        if let Some(ref shared) = self.mcp.shared_tools {
429            let mut guard = shared
430                .write()
431                .unwrap_or_else(std::sync::PoisonError::into_inner);
432            guard.clone_from(&self.mcp.tools);
433        }
434    }
435
436    /// Write the **pruned** tool subset to the shared executor `RwLock`.
437    ///
438    /// This is the second of two writers to `mcp.shared_tools`.  Must only be
439    /// called **after** `sync_mcp_executor_tools` has established the full tool
440    /// set for the current turn (guaranteed by call-site ordering: pruning runs
441    /// inside `rebuild_system_prompt`, after any tool-list change events).
442    ///
443    /// `self.mcp.tools` (the full set) is intentionally **not** modified: it is
444    /// retained for cache key computation and for restoration when the next turn
445    /// triggers a cache reset.
446    ///
447    /// This method must **NOT** call `sync_mcp_executor_tools` internally —
448    /// doing so would overwrite the pruned subset with the full set.
449    /// See also: `McpState::shared_tools` doc comment.
450    pub(in crate::agent) fn apply_pruned_mcp_tools(&self, pruned: Vec<zeph_mcp::McpTool>) {
451        debug_assert!(
452            pruned.iter().all(|p| self
453                .mcp
454                .tools
455                .iter()
456                .any(|t| t.server_id == p.server_id && t.name == p.name)),
457            "pruned set must be a subset of self.mcp.tools"
458        );
459        if let Some(ref shared) = self.mcp.shared_tools {
460            let mut guard = shared
461                .write()
462                .unwrap_or_else(std::sync::PoisonError::into_inner);
463            *guard = pruned;
464        }
465    }
466
467    pub(super) async fn sync_mcp_registry(&mut self) {
468        let Some(ref mut registry) = self.mcp.registry else {
469            return;
470        };
471        if !self.embedding_provider.supports_embeddings() {
472            return;
473        }
474        let provider = self.embedding_provider.clone();
475        let embed_fn = |text: &str| -> zeph_mcp::registry::EmbedFuture {
476            let owned = text.to_owned();
477            let p = provider.clone();
478            Box::pin(async move { p.embed(&owned).await })
479        };
480        if let Err(e) = registry
481            .sync(&self.mcp.tools, &self.skill_state.embedding_model, embed_fn)
482            .await
483        {
484            tracing::warn!("failed to sync MCP tool registry: {e:#}");
485        }
486    }
487
488    /// Build (or rebuild) the in-memory semantic tool index for embedding-based discovery.
489    /// Build the initial semantic tool index after agent construction.
490    ///
491    /// Must be called once after `with_mcp` and `with_mcp_discovery` are applied,
492    /// before the first user turn.  Subsequent rebuilds happen automatically on
493    /// tool list change events (`check_tool_refresh`, `/mcp add`, `/mcp remove`).
494    pub async fn init_semantic_index(&mut self) {
495        self.rebuild_semantic_index().await;
496    }
497
498    /// Drain and process all pending elicitation requests without blocking.
499    ///
500    /// Call this at the start of each turn and between tool calls to prevent
501    /// elicitation events from accumulating while the agent loop is busy.
502    pub(super) async fn process_pending_elicitations(&mut self) {
503        loop {
504            let Some(ref mut rx) = self.mcp.elicitation_rx else {
505                return;
506            };
507            match rx.try_recv() {
508                Ok(event) => {
509                    self.handle_elicitation_event(event).await;
510                }
511                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
512                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
513                    self.mcp.elicitation_rx = None;
514                    return;
515                }
516            }
517        }
518    }
519
520    /// Handle a single elicitation event by routing it to the active channel.
521    pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
522        use crate::channel::{ElicitationRequest, ElicitationResponse};
523
524        let decline = CreateElicitationResult {
525            action: ElicitationAction::Decline,
526            content: None,
527        };
528
529        let channel_request = match &event.request {
530            rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
531                message,
532                requested_schema,
533                ..
534            } => {
535                let fields = build_elicitation_fields(requested_schema);
536                ElicitationRequest {
537                    server_name: event.server_id.clone(),
538                    message: sanitize_elicitation_message(message),
539                    fields,
540                }
541            }
542            rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
543                // URL elicitation not supported in phase 1 — decline.
544                tracing::debug!(
545                    server_id = event.server_id,
546                    "URL elicitation not supported, declining"
547                );
548                let _ = event.response_tx.send(decline);
549                return;
550            }
551        };
552
553        if self.mcp.elicitation_warn_sensitive_fields {
554            let sensitive: Vec<&str> = channel_request
555                .fields
556                .iter()
557                .filter(|f| is_sensitive_field(&f.name))
558                .map(|f| f.name.as_str())
559                .collect();
560            if !sensitive.is_empty() {
561                let fields_list = sensitive.join(", ");
562                let warning = format!(
563                    "Warning: [{}] is requesting sensitive information (field: {}). \
564                     Only proceed if you trust this server.",
565                    channel_request.server_name, fields_list,
566                );
567                tracing::warn!(
568                    server_id = event.server_id,
569                    fields = %fields_list,
570                    "elicitation requests sensitive fields"
571                );
572                let _ = self.channel.send(&warning).await;
573            }
574        }
575
576        let _ = self
577            .channel
578            .send_status("MCP server requesting input…")
579            .await;
580        let response = match self.channel.elicit(channel_request).await {
581            Ok(r) => r,
582            Err(e) => {
583                tracing::warn!(
584                    server_id = event.server_id,
585                    "elicitation channel error: {e:#}"
586                );
587                let _ = self.channel.send_status("").await;
588                let _ = event.response_tx.send(decline);
589                return;
590            }
591        };
592        let _ = self.channel.send_status("").await;
593
594        let result = match response {
595            ElicitationResponse::Accepted(value) => CreateElicitationResult {
596                action: ElicitationAction::Accept,
597                content: Some(value),
598            },
599            ElicitationResponse::Declined => CreateElicitationResult {
600                action: ElicitationAction::Decline,
601                content: None,
602            },
603            ElicitationResponse::Cancelled => CreateElicitationResult {
604                action: ElicitationAction::Cancel,
605                content: None,
606            },
607        };
608
609        if event.response_tx.send(result).is_err() {
610            tracing::warn!(
611                server_id = event.server_id,
612                "elicitation response dropped — handler disconnected"
613            );
614        }
615    }
616
617    /// Rebuild the in-memory semantic tool index.
618    ///
619    /// Only runs when `discovery_strategy == Embedding`.  On failure (all embeddings fail),
620    /// sets `semantic_index = None` and logs at WARN — the caller falls back to all tools.
621    ///
622    /// Called at:
623    /// - initial setup via `init_semantic_index()`
624    /// - `tools/list_changed` notification
625    /// - `/mcp add` and `/mcp remove`
626    pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
627        if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
628            return;
629        }
630
631        if self.mcp.tools.is_empty() {
632            self.mcp.semantic_index = None;
633            return;
634        }
635
636        // Resolve embedding provider: dedicated discovery provider → primary embedding provider.
637        let provider = self
638            .mcp
639            .discovery_provider
640            .clone()
641            .unwrap_or_else(|| self.embedding_provider.clone());
642
643        let embed_fn = provider.embed_fn();
644
645        match zeph_mcp::SemanticToolIndex::build(&self.mcp.tools, &embed_fn).await {
646            Ok(idx) => {
647                tracing::info!(
648                    indexed = idx.len(),
649                    total = self.mcp.tools.len(),
650                    "semantic tool index built"
651                );
652                self.mcp.semantic_index = Some(idx);
653            }
654            Err(e) => {
655                tracing::warn!(
656                    "semantic tool index build failed, falling back to all tools: {e:#}"
657                );
658                self.mcp.semantic_index = None;
659            }
660        }
661    }
662}
663
664/// Convert an rmcp `ElicitationSchema` into channel-agnostic `ElicitationField` list.
665fn build_elicitation_fields(
666    schema: &rmcp::model::ElicitationSchema,
667) -> Vec<crate::channel::ElicitationField> {
668    use crate::channel::{ElicitationField, ElicitationFieldType};
669    use rmcp::model::PrimitiveSchema;
670
671    schema
672        .properties
673        .iter()
674        .map(|(name, prop)| {
675            // Extract field type and description by serializing the PrimitiveSchema to JSON
676            // and reading the discriminator field.  This avoids deep-matching the nested
677            // EnumSchema / StringSchema / … variants of rmcp's type-safe schema hierarchy.
678            let json = serde_json::to_value(prop).unwrap_or_default();
679            let description = json
680                .get("description")
681                .and_then(|v| v.as_str())
682                .map(String::from);
683
684            let field_type = match prop {
685                PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
686                PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
687                PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
688                PrimitiveSchema::String(_) => ElicitationFieldType::String,
689                PrimitiveSchema::Enum(_) => {
690                    // Extract enum values from the serialized form.  All EnumSchema variants
691                    // serialise their allowed values under "enum" or inside "items.enum".
692                    let vals = json
693                        .get("enum")
694                        .and_then(|v| v.as_array())
695                        .map(|arr| {
696                            arr.iter()
697                                .filter_map(|v| v.as_str())
698                                .map(String::from)
699                                .collect::<Vec<_>>()
700                        })
701                        .unwrap_or_default();
702                    ElicitationFieldType::Enum(vals)
703                }
704            };
705            let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
706            ElicitationField {
707                name: name.clone(),
708                description,
709                field_type,
710                required,
711            }
712        })
713        .collect()
714}
715
/// Lower-case substrings that flag an elicitation field name as sensitive.
///
/// Matching is a plain substring test on the lower-cased field name, so short entries
/// such as `"key"`, `"pin"`, or `"auth"` also hit unrelated names (`"keyboard"`,
/// `"spinner"`, `"author"`). In this module a match only triggers a warning message.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    // Passwords and shared secrets.
    "password",
    "passwd",
    "token",
    "secret",
    // Keys and credentials.
    "key",
    "credential",
    "apikey",
    "api_key",
    // Authentication headers and related names.
    "auth",
    "authorization",
    // Other private material.
    "private",
    "passphrase",
    "pin",
];

/// Report whether `field_name` contains any entry of `SENSITIVE_FIELD_PATTERNS`,
/// ignoring case.
fn is_sensitive_field(field_name: &str) -> bool {
    let normalized = field_name.to_lowercase();
    for &needle in SENSITIVE_FIELD_PATTERNS {
        if normalized.contains(needle) {
            return true;
        }
    }
    false
}
740
/// Sanitize an elicitation message received from an MCP server before showing it.
///
/// - Removes control characters (`char::is_control`: C0, DEL and C1, including `\r`
///   and ESC) except `\n` and `\t`, so the text cannot drive terminal escape sequences
///   or overwrite the current line.
/// - Removes Unicode bidirectional formatting characters. They are not control
///   characters, but they can visually reorder the displayed text (bidi spoofing).
/// - Caps the result at 500 characters, counted in `char`s after filtering, so a
///   multi-byte UTF-8 sequence is never split.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    /// ALM, LRM/RLM, LRE/RLE/PDF/LRO/RLO and LRI/RLI/FSI/PDI.
    fn is_bidi_control(c: char) -> bool {
        matches!(
            c,
            '\u{061C}' | '\u{200E}' | '\u{200F}' | '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}'
        )
    }

    message
        .chars()
        .filter(|&c| (!c.is_control() || c == '\n' || c == '\t') && !is_bidi_control(c))
        .take(MAX_CHARS)
        .collect()
}
751
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an agent from empty test doubles. `mcp.manager` is left unset, so MCP
    /// commands that need a manager report "MCP is not enabled.".
    fn make_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = make_agent();

        agent.handle_mcp_command("unknown").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp")),
            "expected usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = make_agent();

        agent.handle_mcp_command("list").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("MCP is not enabled")),
            "expected not-enabled message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = make_agent();

        agent.handle_mcp_command("tools").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp tools")),
            "expected tools usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = make_agent();

        agent.handle_mcp_command("remove").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp remove")),
            "expected remove usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = make_agent();

        // "remove server-id" but no manager
        agent.handle_mcp_command("remove my-server").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("MCP is not enabled")),
            "expected not-enabled message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = make_agent();

        // "add" with only 1 arg (needs at least 2)
        agent.handle_mcp_command("add server-id").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp add")),
            "expected add usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = make_agent();

        // mcp.tools is empty, so any server will have no tools
        agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("No tools found")),
            "expected no-tools message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = make_agent();

        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = make_agent();
        // No tool_rx set; check_tool_refresh should be a no-op.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = make_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(rx);
        // No changes sent; has_changed() returns false.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = make_agent();
        agent.mcp.tools = vec![zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: "existing_tool".into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        // has_changed() is false for a fresh receiver; tools unchanged.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = make_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);

        let new_tools = vec![zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: "refreshed_tool".into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }];
        tx.send(new_tools).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        // Only the control characters themselves are removed; the printable remainder of an
        // escape sequence is inert text once its ESC byte is gone.
        assert_eq!(output, "helloworld[31mred[0m");
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        // Build a 600-char ASCII string — no multi-byte boundary issue.
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_cap_counts_only_kept_chars() {
        // Stripped control characters must not eat into the 500-char budget.
        let input = format!("{}{}", "\x01".repeat(10), "a".repeat(600));
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output, "a".repeat(500));
    }

    #[test]
    fn sanitize_elicitation_message_keeps_short_multibyte_untouched() {
        // 300 two-byte chars (600 bytes) are below the 500-char cap and must pass unchanged.
        let input: String = "é".repeat(300);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output, input);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // One ASCII byte followed by 2-byte "é" chars puts every later char boundary at an odd
        // byte offset, so a naive byte slice `&input[..500]` would panic here. The char-based
        // cap must instead keep exactly 500 chars: the "a" plus 499 "é".
        let input = format!("a{}", "é".repeat(600));
        assert!(!input.is_char_boundary(500));
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
        assert_eq!(output, format!("a{}", "é".repeat(499)));
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
        assert!(is_sensitive_field("Passphrase"));
        assert!(is_sensitive_field("PIN"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}