Skip to main content

zeph_core/agent/
mcp.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9    pub(super) async fn handle_mcp_command(
10        &mut self,
11        args: &str,
12    ) -> Result<(), super::error::AgentError> {
13        let parts: Vec<&str> = args.split_whitespace().collect();
14        match parts.first().copied() {
15            Some("add") => self.handle_mcp_add(&parts[1..]).await,
16            Some("list") => self.handle_mcp_list().await,
17            Some("tools") => self.handle_mcp_tools(parts.get(1).copied()).await,
18            Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
19            _ => {
20                self.channel
21                    .send("Usage: /mcp add|list|tools|remove")
22                    .await?;
23                Ok(())
24            }
25        }
26    }
27
28    #[allow(clippy::too_many_lines)]
29    async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<(), super::error::AgentError> {
30        if args.len() < 2 {
31            self.channel
32                .send("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>")
33                .await?;
34            return Ok(());
35        }
36
37        let Some(ref manager) = self.mcp.manager else {
38            self.channel.send("MCP is not enabled.").await?;
39            return Ok(());
40        };
41
42        let target = args[1];
43        let is_url = target.starts_with("http://") || target.starts_with("https://");
44
45        // SEC-MCP-01: validate command against allowlist (stdio only)
46        if !is_url
47            && !self.mcp.allowed_commands.is_empty()
48            && !self.mcp.allowed_commands.iter().any(|c| c == target)
49        {
50            self.channel
51                .send(&format!(
52                    "Command '{target}' is not allowed. Permitted: {}",
53                    self.mcp.allowed_commands.join(", ")
54                ))
55                .await?;
56            return Ok(());
57        }
58
59        // SEC-MCP-03: enforce server limit
60        let current_count = manager.list_servers().await.len();
61        if current_count >= self.mcp.max_dynamic {
62            self.channel
63                .send(&format!(
64                    "Server limit reached ({}/{}).",
65                    current_count, self.mcp.max_dynamic
66                ))
67                .await?;
68            return Ok(());
69        }
70
71        let transport = if is_url {
72            zeph_mcp::McpTransport::Http {
73                url: target.to_owned(),
74                headers: std::collections::HashMap::new(),
75            }
76        } else {
77            zeph_mcp::McpTransport::Stdio {
78                command: target.to_owned(),
79                args: args[2..].iter().map(|&s| s.to_owned()).collect(),
80                env: std::collections::HashMap::new(),
81            }
82        };
83
84        let entry = zeph_mcp::ServerEntry {
85            id: args[0].to_owned(),
86            transport,
87            timeout: std::time::Duration::from_secs(30),
88            trust_level: zeph_mcp::McpTrustLevel::Untrusted,
89            tool_allowlist: None,
90            expected_tools: Vec::new(),
91            roots: Vec::new(),
92            tool_metadata: std::collections::HashMap::new(),
93            elicitation_enabled: false,
94            elicitation_timeout_secs: 120,
95            env_isolation: false,
96        };
97
98        let _ = self.channel.send_status("connecting to mcp...").await;
99        match manager.add_server(&entry).await {
100            Ok(tools) => {
101                let _ = self.channel.send_status("").await;
102                let count = tools.len();
103                self.mcp
104                    .server_outcomes
105                    .push(zeph_mcp::ServerConnectOutcome {
106                        id: entry.id.clone(),
107                        connected: true,
108                        tool_count: count,
109                        error: String::new(),
110                    });
111                self.mcp.tools.extend(tools);
112                self.sync_mcp_executor_tools();
113                self.mcp.pruning_cache.reset();
114                self.rebuild_semantic_index().await;
115                self.sync_mcp_registry().await;
116                let mcp_total = self.mcp.tools.len();
117                let mcp_server_count = self.mcp.server_outcomes.len();
118                let mcp_connected_count = self
119                    .mcp
120                    .server_outcomes
121                    .iter()
122                    .filter(|o| o.connected)
123                    .count();
124                let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
125                    .mcp
126                    .server_outcomes
127                    .iter()
128                    .map(|o| crate::metrics::McpServerStatus {
129                        id: o.id.clone(),
130                        status: if o.connected {
131                            crate::metrics::McpServerConnectionStatus::Connected
132                        } else {
133                            crate::metrics::McpServerConnectionStatus::Failed
134                        },
135                        tool_count: o.tool_count,
136                        error: o.error.clone(),
137                    })
138                    .collect();
139                self.update_metrics(|m| {
140                    m.mcp_tool_count = mcp_total;
141                    m.mcp_server_count = mcp_server_count;
142                    m.mcp_connected_count = mcp_connected_count;
143                    m.mcp_servers = mcp_servers;
144                });
145                self.channel
146                    .send(&format!(
147                        "Connected MCP server '{}' ({count} tool(s))",
148                        entry.id
149                    ))
150                    .await?;
151                Ok(())
152            }
153            Err(e) => {
154                let _ = self.channel.send_status("").await;
155                tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
156                self.channel
157                    .send(&format!("Failed to connect server '{}': {e}", entry.id))
158                    .await?;
159                Ok(())
160            }
161        }
162    }
163
164    async fn handle_mcp_list(&mut self) -> Result<(), super::error::AgentError> {
165        use std::fmt::Write;
166
167        let Some(ref manager) = self.mcp.manager else {
168            self.channel.send("MCP is not enabled.").await?;
169            return Ok(());
170        };
171
172        let server_ids = manager.list_servers().await;
173        if server_ids.is_empty() {
174            self.channel.send("No MCP servers connected.").await?;
175            return Ok(());
176        }
177
178        let mut output = String::from("Connected MCP servers:\n");
179        let mut total = 0usize;
180        for id in &server_ids {
181            let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
182            total += count;
183            let _ = writeln!(output, "- {id} ({count} tools)");
184        }
185        let _ = write!(output, "Total: {total} tool(s)");
186
187        self.channel.send(&output).await?;
188        Ok(())
189    }
190
191    async fn handle_mcp_tools(
192        &mut self,
193        server_id: Option<&str>,
194    ) -> Result<(), super::error::AgentError> {
195        use std::fmt::Write;
196
197        let Some(server_id) = server_id else {
198            self.channel.send("Usage: /mcp tools <server_id>").await?;
199            return Ok(());
200        };
201
202        let tools: Vec<_> = self
203            .mcp
204            .tools
205            .iter()
206            .filter(|t| t.server_id == server_id)
207            .collect();
208
209        if tools.is_empty() {
210            self.channel
211                .send(&format!("No tools found for server '{server_id}'."))
212                .await?;
213            return Ok(());
214        }
215
216        let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
217        for t in &tools {
218            if t.description.is_empty() {
219                let _ = writeln!(output, "- {}", t.name);
220            } else {
221                let _ = writeln!(output, "- {} — {}", t.name, t.description);
222            }
223        }
224        self.channel.send(&output).await?;
225        Ok(())
226    }
227
228    async fn handle_mcp_remove(
229        &mut self,
230        server_id: Option<&str>,
231    ) -> Result<(), super::error::AgentError> {
232        let Some(server_id) = server_id else {
233            self.channel.send("Usage: /mcp remove <id>").await?;
234            return Ok(());
235        };
236
237        let Some(ref manager) = self.mcp.manager else {
238            self.channel.send("MCP is not enabled.").await?;
239            return Ok(());
240        };
241
242        match manager.remove_server(server_id).await {
243            Ok(()) => {
244                let before = self.mcp.tools.len();
245                self.mcp.tools.retain(|t| t.server_id != server_id);
246                let removed = before - self.mcp.tools.len();
247                self.mcp.server_outcomes.retain(|o| o.id != server_id);
248                self.sync_mcp_executor_tools();
249                self.mcp.pruning_cache.reset();
250                self.rebuild_semantic_index().await;
251                self.sync_mcp_registry().await;
252                let mcp_total = self.mcp.tools.len();
253                let mcp_server_count = self.mcp.server_outcomes.len();
254                let mcp_connected_count = self
255                    .mcp
256                    .server_outcomes
257                    .iter()
258                    .filter(|o| o.connected)
259                    .count();
260                let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
261                    .mcp
262                    .server_outcomes
263                    .iter()
264                    .map(|o| crate::metrics::McpServerStatus {
265                        id: o.id.clone(),
266                        status: if o.connected {
267                            crate::metrics::McpServerConnectionStatus::Connected
268                        } else {
269                            crate::metrics::McpServerConnectionStatus::Failed
270                        },
271                        tool_count: o.tool_count,
272                        error: o.error.clone(),
273                    })
274                    .collect();
275                self.update_metrics(|m| {
276                    m.mcp_tool_count = mcp_total;
277                    m.mcp_server_count = mcp_server_count;
278                    m.mcp_connected_count = mcp_connected_count;
279                    m.mcp_servers = mcp_servers;
280                    m.active_mcp_tools
281                        .retain(|name| !name.starts_with(&format!("{server_id}:")));
282                });
283                self.channel
284                    .send(&format!(
285                        "Disconnected MCP server '{server_id}' (removed {removed} tools)"
286                    ))
287                    .await?;
288                Ok(())
289            }
290            Err(e) => {
291                tracing::warn!(server_id, "MCP remove failed: {e:#}");
292                self.channel
293                    .send(&format!("Failed to remove server '{server_id}': {e}"))
294                    .await?;
295                Ok(())
296            }
297        }
298    }
299
300    pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
301        let matched_tools = self.match_mcp_tools(query).await;
302        let active_mcp: Vec<String> = matched_tools
303            .iter()
304            .map(zeph_mcp::McpTool::qualified_name)
305            .collect();
306        let mcp_total = self.mcp.tools.len();
307        let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
308            let connected = self
309                .mcp
310                .tools
311                .iter()
312                .map(|t| &t.server_id)
313                .collect::<std::collections::HashSet<_>>()
314                .len();
315            (connected, connected)
316        } else {
317            let total = self.mcp.server_outcomes.len();
318            let connected = self
319                .mcp
320                .server_outcomes
321                .iter()
322                .filter(|o| o.connected)
323                .count();
324            (total, connected)
325        };
326        self.update_metrics(|m| {
327            m.active_mcp_tools = active_mcp;
328            m.mcp_tool_count = mcp_total;
329            m.mcp_server_count = mcp_server_count;
330            m.mcp_connected_count = mcp_connected_count;
331        });
332        // When native tool_use is active, MCP tools flow through the executor chain
333        // as ToolDefinitions — skip text prompt injection to avoid duplication.
334        if self.provider.supports_tool_use() {
335            return;
336        }
337        if !matched_tools.is_empty() {
338            let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
339            tracing::debug!(
340                skills = ?self.skill_state.active_skill_names,
341                mcp_tools = ?tool_names,
342                "matched items"
343            );
344            let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
345            if !tools_prompt.is_empty() {
346                system_prompt.push_str("\n\n");
347                system_prompt.push_str(&tools_prompt);
348            }
349        }
350    }
351
352    async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
353        let Some(ref registry) = self.mcp.registry else {
354            return self.mcp.tools.clone();
355        };
356        let provider = self.embedding_provider.clone();
357        registry
358            .search(query, self.skill_state.max_active_skills, |text| {
359                let owned = text.to_owned();
360                let p = provider.clone();
361                Box::pin(async move { p.embed(&owned).await })
362            })
363            .await
364    }
365
366    #[cfg(test)]
367    pub(crate) fn mcp_tool_count(&self) -> usize {
368        self.mcp.tools.len()
369    }
370
371    /// Poll the watch receiver for tool list updates from `tools/list_changed` notifications.
372    ///
373    /// Called once per agent turn, before processing user input. When the tool list has changed,
374    /// updates `mcp.tools`, syncs the executor, and schedules a registry sync.
375    /// If no receiver is set (MCP disabled), or no change has occurred, this is a no-op.
376    pub(super) async fn check_tool_refresh(&mut self) {
377        let Some(ref mut rx) = self.mcp.tool_rx else {
378            return;
379        };
380        if !rx.has_changed().unwrap_or(false) {
381            return;
382        }
383        let new_tools = rx.borrow_and_update().clone();
384        if new_tools.is_empty() {
385            // Guard against replacing a non-empty initial tool list with the watch's empty
386            // initial value. The watch is only updated after a real tools/list_changed event.
387            return;
388        }
389        tracing::info!(
390            tools = new_tools.len(),
391            "tools/list_changed: agent tool list refreshed"
392        );
393        self.mcp.tools = new_tools;
394        self.sync_mcp_executor_tools();
395        self.mcp.pruning_cache.reset();
396        self.rebuild_semantic_index().await;
397        self.sync_mcp_registry().await;
398        let mcp_total = self.mcp.tools.len();
399        let mcp_servers = self
400            .mcp
401            .tools
402            .iter()
403            .map(|t| &t.server_id)
404            .collect::<std::collections::HashSet<_>>()
405            .len();
406        self.update_metrics(|m| {
407            m.mcp_tool_count = mcp_total;
408            m.mcp_server_count = mcp_servers;
409        });
410    }
411
412    /// Write the **full** `self.mcp.tools` set to the shared executor `RwLock`.
413    ///
414    /// This is the first of two writers to `mcp.shared_tools`.  Within a turn
415    /// this method must run **before** `apply_pruned_mcp_tools`, which writes the
416    /// pruned subset.  The normal call order guarantees this: tool-list change
417    /// events (notify, `/mcp add`, `/mcp remove`) call this method, and pruning
418    /// runs later inside `rebuild_system_prompt`.
419    /// See also: `McpState::shared_tools` doc comment.
420    pub(super) fn sync_mcp_executor_tools(&self) {
421        if let Some(ref shared) = self.mcp.shared_tools {
422            let mut guard = shared
423                .write()
424                .unwrap_or_else(std::sync::PoisonError::into_inner);
425            guard.clone_from(&self.mcp.tools);
426        }
427    }
428
429    /// Write the **pruned** tool subset to the shared executor `RwLock`.
430    ///
431    /// This is the second of two writers to `mcp.shared_tools`.  Must only be
432    /// called **after** `sync_mcp_executor_tools` has established the full tool
433    /// set for the current turn (guaranteed by call-site ordering: pruning runs
434    /// inside `rebuild_system_prompt`, after any tool-list change events).
435    ///
436    /// `self.mcp.tools` (the full set) is intentionally **not** modified: it is
437    /// retained for cache key computation and for restoration when the next turn
438    /// triggers a cache reset.
439    ///
440    /// This method must **NOT** call `sync_mcp_executor_tools` internally —
441    /// doing so would overwrite the pruned subset with the full set.
442    /// See also: `McpState::shared_tools` doc comment.
443    pub(in crate::agent) fn apply_pruned_mcp_tools(&self, pruned: Vec<zeph_mcp::McpTool>) {
444        debug_assert!(
445            pruned.iter().all(|p| self
446                .mcp
447                .tools
448                .iter()
449                .any(|t| t.server_id == p.server_id && t.name == p.name)),
450            "pruned set must be a subset of self.mcp.tools"
451        );
452        if let Some(ref shared) = self.mcp.shared_tools {
453            let mut guard = shared
454                .write()
455                .unwrap_or_else(std::sync::PoisonError::into_inner);
456            *guard = pruned;
457        }
458    }
459
460    pub(super) async fn sync_mcp_registry(&mut self) {
461        let Some(ref mut registry) = self.mcp.registry else {
462            return;
463        };
464        if !self.embedding_provider.supports_embeddings() {
465            return;
466        }
467        let provider = self.embedding_provider.clone();
468        let embed_fn = |text: &str| -> zeph_mcp::registry::EmbedFuture {
469            let owned = text.to_owned();
470            let p = provider.clone();
471            Box::pin(async move { p.embed(&owned).await })
472        };
473        if let Err(e) = registry
474            .sync(&self.mcp.tools, &self.skill_state.embedding_model, embed_fn)
475            .await
476        {
477            tracing::warn!("failed to sync MCP tool registry: {e:#}");
478        }
479    }
480
481    /// Build (or rebuild) the in-memory semantic tool index for embedding-based discovery.
482    /// Build the initial semantic tool index after agent construction.
483    ///
484    /// Must be called once after `with_mcp` and `with_mcp_discovery` are applied,
485    /// before the first user turn.  Subsequent rebuilds happen automatically on
486    /// tool list change events (`check_tool_refresh`, `/mcp add`, `/mcp remove`).
487    pub async fn init_semantic_index(&mut self) {
488        self.rebuild_semantic_index().await;
489    }
490
491    /// Drain and process all pending elicitation requests without blocking.
492    ///
493    /// Call this at the start of each turn and between tool calls to prevent
494    /// elicitation events from accumulating while the agent loop is busy.
495    pub(super) async fn process_pending_elicitations(&mut self) {
496        loop {
497            let Some(ref mut rx) = self.mcp.elicitation_rx else {
498                return;
499            };
500            match rx.try_recv() {
501                Ok(event) => {
502                    self.handle_elicitation_event(event).await;
503                }
504                Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
505                Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
506                    self.mcp.elicitation_rx = None;
507                    return;
508                }
509            }
510        }
511    }
512
513    /// Handle a single elicitation event by routing it to the active channel.
514    pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
515        use crate::channel::{ElicitationRequest, ElicitationResponse};
516
517        let decline = CreateElicitationResult {
518            action: ElicitationAction::Decline,
519            content: None,
520        };
521
522        let channel_request = match &event.request {
523            rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
524                message,
525                requested_schema,
526                ..
527            } => {
528                let fields = build_elicitation_fields(requested_schema);
529                ElicitationRequest {
530                    server_name: event.server_id.clone(),
531                    message: sanitize_elicitation_message(message),
532                    fields,
533                }
534            }
535            rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
536                // URL elicitation not supported in phase 1 — decline.
537                tracing::debug!(
538                    server_id = event.server_id,
539                    "URL elicitation not supported, declining"
540                );
541                let _ = event.response_tx.send(decline);
542                return;
543            }
544        };
545
546        if self.mcp.elicitation_warn_sensitive_fields {
547            let sensitive: Vec<&str> = channel_request
548                .fields
549                .iter()
550                .filter(|f| is_sensitive_field(&f.name))
551                .map(|f| f.name.as_str())
552                .collect();
553            if !sensitive.is_empty() {
554                let fields_list = sensitive.join(", ");
555                let warning = format!(
556                    "Warning: [{}] is requesting sensitive information (field: {}). \
557                     Only proceed if you trust this server.",
558                    channel_request.server_name, fields_list,
559                );
560                tracing::warn!(
561                    server_id = event.server_id,
562                    fields = %fields_list,
563                    "elicitation requests sensitive fields"
564                );
565                let _ = self.channel.send(&warning).await;
566            }
567        }
568
569        let _ = self
570            .channel
571            .send_status("MCP server requesting input…")
572            .await;
573        let response = match self.channel.elicit(channel_request).await {
574            Ok(r) => r,
575            Err(e) => {
576                tracing::warn!(
577                    server_id = event.server_id,
578                    "elicitation channel error: {e:#}"
579                );
580                let _ = self.channel.send_status("").await;
581                let _ = event.response_tx.send(decline);
582                return;
583            }
584        };
585        let _ = self.channel.send_status("").await;
586
587        let result = match response {
588            ElicitationResponse::Accepted(value) => CreateElicitationResult {
589                action: ElicitationAction::Accept,
590                content: Some(value),
591            },
592            ElicitationResponse::Declined => CreateElicitationResult {
593                action: ElicitationAction::Decline,
594                content: None,
595            },
596            ElicitationResponse::Cancelled => CreateElicitationResult {
597                action: ElicitationAction::Cancel,
598                content: None,
599            },
600        };
601
602        if event.response_tx.send(result).is_err() {
603            tracing::warn!(
604                server_id = event.server_id,
605                "elicitation response dropped — handler disconnected"
606            );
607        }
608    }
609
610    /// Rebuild the in-memory semantic tool index.
611    ///
612    /// Only runs when `discovery_strategy == Embedding`.  On failure (all embeddings fail),
613    /// sets `semantic_index = None` and logs at WARN — the caller falls back to all tools.
614    ///
615    /// Called at:
616    /// - initial setup via `init_semantic_index()`
617    /// - `tools/list_changed` notification
618    /// - `/mcp add` and `/mcp remove`
619    pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
620        if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
621            return;
622        }
623
624        if self.mcp.tools.is_empty() {
625            self.mcp.semantic_index = None;
626            return;
627        }
628
629        // Resolve embedding provider: dedicated discovery provider → primary embedding provider.
630        let provider = self
631            .mcp
632            .discovery_provider
633            .clone()
634            .unwrap_or_else(|| self.embedding_provider.clone());
635
636        let embed_fn = provider.embed_fn();
637
638        match zeph_mcp::SemanticToolIndex::build(&self.mcp.tools, &embed_fn).await {
639            Ok(idx) => {
640                tracing::info!(
641                    indexed = idx.len(),
642                    total = self.mcp.tools.len(),
643                    "semantic tool index built"
644                );
645                self.mcp.semantic_index = Some(idx);
646            }
647            Err(e) => {
648                tracing::warn!(
649                    "semantic tool index build failed, falling back to all tools: {e:#}"
650                );
651                self.mcp.semantic_index = None;
652            }
653        }
654    }
655}
656
657/// Convert an rmcp `ElicitationSchema` into channel-agnostic `ElicitationField` list.
658fn build_elicitation_fields(
659    schema: &rmcp::model::ElicitationSchema,
660) -> Vec<crate::channel::ElicitationField> {
661    use crate::channel::{ElicitationField, ElicitationFieldType};
662    use rmcp::model::PrimitiveSchema;
663
664    schema
665        .properties
666        .iter()
667        .map(|(name, prop)| {
668            // Extract field type and description by serializing the PrimitiveSchema to JSON
669            // and reading the discriminator field.  This avoids deep-matching the nested
670            // EnumSchema / StringSchema / … variants of rmcp's type-safe schema hierarchy.
671            let json = serde_json::to_value(prop).unwrap_or_default();
672            let description = json
673                .get("description")
674                .and_then(|v| v.as_str())
675                .map(String::from);
676
677            let field_type = match prop {
678                PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
679                PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
680                PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
681                PrimitiveSchema::String(_) => ElicitationFieldType::String,
682                PrimitiveSchema::Enum(_) => {
683                    // Extract enum values from the serialized form.  All EnumSchema variants
684                    // serialise their allowed values under "enum" or inside "items.enum".
685                    let vals = json
686                        .get("enum")
687                        .and_then(|v| v.as_array())
688                        .map(|arr| {
689                            arr.iter()
690                                .filter_map(|v| v.as_str())
691                                .map(String::from)
692                                .collect::<Vec<_>>()
693                        })
694                        .unwrap_or_default();
695                    ElicitationFieldType::Enum(vals)
696                }
697            };
698            let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
699            ElicitationField {
700                name: name.clone(),
701                description,
702                field_type,
703                required,
704            }
705        })
706        .collect()
707}
708
709/// Sensitive field name patterns (case-insensitive substring match).
710const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
711    "password",
712    "passwd",
713    "token",
714    "secret",
715    "key",
716    "credential",
717    "apikey",
718    "api_key",
719    "auth",
720    "authorization",
721    "private",
722    "passphrase",
723    "pin",
724];
725
726/// Returns `true` when `field_name` matches any sensitive pattern (case-insensitive).
727fn is_sensitive_field(field_name: &str) -> bool {
728    let lower = field_name.to_lowercase();
729    SENSITIVE_FIELD_PATTERNS
730        .iter()
731        .any(|pattern| lower.contains(pattern))
732}
733
734/// Sanitize an elicitation message: cap length (in chars, not bytes) and strip control chars.
735fn sanitize_elicitation_message(message: &str) -> String {
736    const MAX_CHARS: usize = 500;
737    // Collect up to MAX_CHARS chars, filtering control characters that could manipulate terminals.
738    message
739        .chars()
740        .filter(|c| !c.is_control() || *c == '\n' || *c == '\t')
741        .take(MAX_CHARS)
742        .collect()
743}
744
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an agent with no scripted LLM responses, no pending channel input, the test
    /// skill registry and a tool executor that exposes no tools.
    ///
    /// The MCP state is left at its defaults; the tests below rely on that (no manager,
    /// no tools).
    fn make_agent() -> Agent<MockChannel> {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        Agent::new(provider, channel, registry, None, 5, executor)
    }

    /// Builds a minimal MCP tool named `name` on server `srv` with an empty schema.
    fn test_tool(name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = make_agent();

        agent.handle_mcp_command("unknown").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp")),
            "expected usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = make_agent();

        agent.handle_mcp_command("list").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("MCP is not enabled")),
            "expected not-enabled message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = make_agent();

        agent.handle_mcp_command("tools").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp tools")),
            "expected tools usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = make_agent();

        agent.handle_mcp_command("remove").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp remove")),
            "expected remove usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = make_agent();

        // "remove server-id" but no manager
        agent.handle_mcp_command("remove my-server").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("MCP is not enabled")),
            "expected not-enabled message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = make_agent();

        // "add" with only 1 arg (needs at least 2)
        agent.handle_mcp_command("add server-id").await.unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("Usage: /mcp add")),
            "expected add usage message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = make_agent();

        // mcp.tools is empty, so any server will have no tools
        agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();

        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains("No tools found")),
            "expected no-tools message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = make_agent();

        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = make_agent();
        // No tool_rx set; check_tool_refresh should be a no-op.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = make_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(rx);
        // No changes sent; has_changed() returns false.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = make_agent();
        agent.mcp.tools = vec![test_tool("existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        // has_changed() is false for a fresh receiver; tools unchanged.
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = make_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);

        tx.send(vec![test_tool("refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        // Only the control characters themselves are removed; the printable remainder of
        // the escape sequences stays (harmless without the ESC introducer).
        assert_eq!(output, "helloworld[31mred[0m");
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        // Build a 600-char ASCII string — no multi-byte boundary issue.
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_cap_counts_only_kept_chars() {
        // 600 × (BEL + 'a'): the BELs are stripped and must not eat into the cap.
        let input: String = std::iter::repeat("\x07a").take(600).collect();
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output, "a".repeat(500));
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // "€" is 3 bytes, so byte offset 500 falls inside a character: a naive byte slice
        // `&input[..500]` would panic on this input.
        let input: String = "€".repeat(600); // 600 chars = 1800 bytes
        assert!(!input.is_char_boundary(500));
        let output = sanitize_elicitation_message(&input);
        // Truncates to exactly 500 whole chars without panicking.
        assert_eq!(output.chars().count(), 500);
        assert_eq!(output, "€".repeat(500));
    }

    #[test]
    fn sanitize_elicitation_message_keeps_short_multibyte_input() {
        // "é" is 2 bytes; 300 chars (600 bytes) is under the 500-char cap, so the input
        // must come back unchanged even though it exceeds 500 bytes.
        let input: String = "é".repeat(300);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 300);
        assert_eq!(output, input);
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}