1use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9 pub(super) async fn handle_mcp_command(
10 &mut self,
11 args: &str,
12 ) -> Result<(), super::error::AgentError> {
13 let parts: Vec<&str> = args.split_whitespace().collect();
14 match parts.first().copied() {
15 Some("add") => self.handle_mcp_add(&parts[1..]).await,
16 Some("list") => self.handle_mcp_list().await,
17 Some("tools") => self.handle_mcp_tools(parts.get(1).copied()).await,
18 Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
19 _ => {
20 self.channel
21 .send("Usage: /mcp add|list|tools|remove")
22 .await?;
23 Ok(())
24 }
25 }
26 }
27
28 #[allow(clippy::too_many_lines)]
29 async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<(), super::error::AgentError> {
30 if args.len() < 2 {
31 self.channel
32 .send("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>")
33 .await?;
34 return Ok(());
35 }
36
37 let Some(ref manager) = self.mcp.manager else {
38 self.channel.send("MCP is not enabled.").await?;
39 return Ok(());
40 };
41
42 let target = args[1];
43 let is_url = target.starts_with("http://") || target.starts_with("https://");
44
45 if !is_url
47 && !self.mcp.allowed_commands.is_empty()
48 && !self.mcp.allowed_commands.iter().any(|c| c == target)
49 {
50 self.channel
51 .send(&format!(
52 "Command '{target}' is not allowed. Permitted: {}",
53 self.mcp.allowed_commands.join(", ")
54 ))
55 .await?;
56 return Ok(());
57 }
58
59 let current_count = manager.list_servers().await.len();
61 if current_count >= self.mcp.max_dynamic {
62 self.channel
63 .send(&format!(
64 "Server limit reached ({}/{}).",
65 current_count, self.mcp.max_dynamic
66 ))
67 .await?;
68 return Ok(());
69 }
70
71 let transport = if is_url {
72 zeph_mcp::McpTransport::Http {
73 url: target.to_owned(),
74 headers: std::collections::HashMap::new(),
75 }
76 } else {
77 zeph_mcp::McpTransport::Stdio {
78 command: target.to_owned(),
79 args: args[2..].iter().map(|&s| s.to_owned()).collect(),
80 env: std::collections::HashMap::new(),
81 }
82 };
83
84 let entry = zeph_mcp::ServerEntry {
85 id: args[0].to_owned(),
86 transport,
87 timeout: std::time::Duration::from_secs(30),
88 trust_level: zeph_mcp::McpTrustLevel::Untrusted,
89 tool_allowlist: None,
90 expected_tools: Vec::new(),
91 roots: Vec::new(),
92 tool_metadata: std::collections::HashMap::new(),
93 elicitation_enabled: false,
94 elicitation_timeout_secs: 120,
95 env_isolation: false,
96 };
97
98 let _ = self.channel.send_status("connecting to mcp...").await;
99 match manager.add_server(&entry).await {
100 Ok(tools) => {
101 let _ = self.channel.send_status("").await;
102 let count = tools.len();
103 self.mcp
104 .server_outcomes
105 .push(zeph_mcp::ServerConnectOutcome {
106 id: entry.id.clone(),
107 connected: true,
108 tool_count: count,
109 error: String::new(),
110 });
111 self.mcp.tools.extend(tools);
112 self.sync_mcp_executor_tools();
113 self.mcp.pruning_cache.reset();
114 self.rebuild_semantic_index().await;
115 self.sync_mcp_registry().await;
116 let mcp_total = self.mcp.tools.len();
117 let mcp_server_count = self.mcp.server_outcomes.len();
118 let mcp_connected_count = self
119 .mcp
120 .server_outcomes
121 .iter()
122 .filter(|o| o.connected)
123 .count();
124 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
125 .mcp
126 .server_outcomes
127 .iter()
128 .map(|o| crate::metrics::McpServerStatus {
129 id: o.id.clone(),
130 status: if o.connected {
131 crate::metrics::McpServerConnectionStatus::Connected
132 } else {
133 crate::metrics::McpServerConnectionStatus::Failed
134 },
135 tool_count: o.tool_count,
136 error: o.error.clone(),
137 })
138 .collect();
139 self.update_metrics(|m| {
140 m.mcp_tool_count = mcp_total;
141 m.mcp_server_count = mcp_server_count;
142 m.mcp_connected_count = mcp_connected_count;
143 m.mcp_servers = mcp_servers;
144 });
145 self.channel
146 .send(&format!(
147 "Connected MCP server '{}' ({count} tool(s))",
148 entry.id
149 ))
150 .await?;
151 Ok(())
152 }
153 Err(e) => {
154 let _ = self.channel.send_status("").await;
155 tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
156 self.channel
157 .send(&format!("Failed to connect server '{}': {e}", entry.id))
158 .await?;
159 Ok(())
160 }
161 }
162 }
163
164 async fn handle_mcp_list(&mut self) -> Result<(), super::error::AgentError> {
165 use std::fmt::Write;
166
167 let Some(ref manager) = self.mcp.manager else {
168 self.channel.send("MCP is not enabled.").await?;
169 return Ok(());
170 };
171
172 let server_ids = manager.list_servers().await;
173 if server_ids.is_empty() {
174 self.channel.send("No MCP servers connected.").await?;
175 return Ok(());
176 }
177
178 let mut output = String::from("Connected MCP servers:\n");
179 let mut total = 0usize;
180 for id in &server_ids {
181 let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
182 total += count;
183 let _ = writeln!(output, "- {id} ({count} tools)");
184 }
185 let _ = write!(output, "Total: {total} tool(s)");
186
187 self.channel.send(&output).await?;
188 Ok(())
189 }
190
191 async fn handle_mcp_tools(
192 &mut self,
193 server_id: Option<&str>,
194 ) -> Result<(), super::error::AgentError> {
195 use std::fmt::Write;
196
197 let Some(server_id) = server_id else {
198 self.channel.send("Usage: /mcp tools <server_id>").await?;
199 return Ok(());
200 };
201
202 let tools: Vec<_> = self
203 .mcp
204 .tools
205 .iter()
206 .filter(|t| t.server_id == server_id)
207 .collect();
208
209 if tools.is_empty() {
210 self.channel
211 .send(&format!("No tools found for server '{server_id}'."))
212 .await?;
213 return Ok(());
214 }
215
216 let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
217 for t in &tools {
218 if t.description.is_empty() {
219 let _ = writeln!(output, "- {}", t.name);
220 } else {
221 let _ = writeln!(output, "- {} — {}", t.name, t.description);
222 }
223 }
224 self.channel.send(&output).await?;
225 Ok(())
226 }
227
228 async fn handle_mcp_remove(
229 &mut self,
230 server_id: Option<&str>,
231 ) -> Result<(), super::error::AgentError> {
232 let Some(server_id) = server_id else {
233 self.channel.send("Usage: /mcp remove <id>").await?;
234 return Ok(());
235 };
236
237 let Some(ref manager) = self.mcp.manager else {
238 self.channel.send("MCP is not enabled.").await?;
239 return Ok(());
240 };
241
242 match manager.remove_server(server_id).await {
243 Ok(()) => {
244 let before = self.mcp.tools.len();
245 self.mcp.tools.retain(|t| t.server_id != server_id);
246 let removed = before - self.mcp.tools.len();
247 self.mcp.server_outcomes.retain(|o| o.id != server_id);
248 self.sync_mcp_executor_tools();
249 self.mcp.pruning_cache.reset();
250 self.rebuild_semantic_index().await;
251 self.sync_mcp_registry().await;
252 let mcp_total = self.mcp.tools.len();
253 let mcp_server_count = self.mcp.server_outcomes.len();
254 let mcp_connected_count = self
255 .mcp
256 .server_outcomes
257 .iter()
258 .filter(|o| o.connected)
259 .count();
260 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
261 .mcp
262 .server_outcomes
263 .iter()
264 .map(|o| crate::metrics::McpServerStatus {
265 id: o.id.clone(),
266 status: if o.connected {
267 crate::metrics::McpServerConnectionStatus::Connected
268 } else {
269 crate::metrics::McpServerConnectionStatus::Failed
270 },
271 tool_count: o.tool_count,
272 error: o.error.clone(),
273 })
274 .collect();
275 self.update_metrics(|m| {
276 m.mcp_tool_count = mcp_total;
277 m.mcp_server_count = mcp_server_count;
278 m.mcp_connected_count = mcp_connected_count;
279 m.mcp_servers = mcp_servers;
280 m.active_mcp_tools
281 .retain(|name| !name.starts_with(&format!("{server_id}:")));
282 });
283 self.channel
284 .send(&format!(
285 "Disconnected MCP server '{server_id}' (removed {removed} tools)"
286 ))
287 .await?;
288 Ok(())
289 }
290 Err(e) => {
291 tracing::warn!(server_id, "MCP remove failed: {e:#}");
292 self.channel
293 .send(&format!("Failed to remove server '{server_id}': {e}"))
294 .await?;
295 Ok(())
296 }
297 }
298 }
299
300 pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
301 let matched_tools = self.match_mcp_tools(query).await;
302 let active_mcp: Vec<String> = matched_tools
303 .iter()
304 .map(zeph_mcp::McpTool::qualified_name)
305 .collect();
306 let mcp_total = self.mcp.tools.len();
307 let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
308 let connected = self
309 .mcp
310 .tools
311 .iter()
312 .map(|t| &t.server_id)
313 .collect::<std::collections::HashSet<_>>()
314 .len();
315 (connected, connected)
316 } else {
317 let total = self.mcp.server_outcomes.len();
318 let connected = self
319 .mcp
320 .server_outcomes
321 .iter()
322 .filter(|o| o.connected)
323 .count();
324 (total, connected)
325 };
326 self.update_metrics(|m| {
327 m.active_mcp_tools = active_mcp;
328 m.mcp_tool_count = mcp_total;
329 m.mcp_server_count = mcp_server_count;
330 m.mcp_connected_count = mcp_connected_count;
331 });
332 if self.provider.supports_tool_use() {
335 return;
336 }
337 if !matched_tools.is_empty() {
338 let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
339 tracing::debug!(
340 skills = ?self.skill_state.active_skill_names,
341 mcp_tools = ?tool_names,
342 "matched items"
343 );
344 let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
345 if !tools_prompt.is_empty() {
346 system_prompt.push_str("\n\n");
347 system_prompt.push_str(&tools_prompt);
348 }
349 }
350 }
351
352 async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
353 let Some(ref registry) = self.mcp.registry else {
354 return self.mcp.tools.clone();
355 };
356 let provider = self.embedding_provider.clone();
357 registry
358 .search(query, self.skill_state.max_active_skills, |text| {
359 let owned = text.to_owned();
360 let p = provider.clone();
361 Box::pin(async move { p.embed(&owned).await })
362 })
363 .await
364 }
365
366 #[cfg(test)]
367 pub(crate) fn mcp_tool_count(&self) -> usize {
368 self.mcp.tools.len()
369 }
370
371 pub(super) async fn check_tool_refresh(&mut self) {
377 let Some(ref mut rx) = self.mcp.tool_rx else {
378 return;
379 };
380 if !rx.has_changed().unwrap_or(false) {
381 return;
382 }
383 let new_tools = rx.borrow_and_update().clone();
384 if new_tools.is_empty() {
385 return;
388 }
389 tracing::info!(
390 tools = new_tools.len(),
391 "tools/list_changed: agent tool list refreshed"
392 );
393 self.mcp.tools = new_tools;
394 self.sync_mcp_executor_tools();
395 self.mcp.pruning_cache.reset();
396 self.rebuild_semantic_index().await;
397 self.sync_mcp_registry().await;
398 let mcp_total = self.mcp.tools.len();
399 let mcp_servers = self
400 .mcp
401 .tools
402 .iter()
403 .map(|t| &t.server_id)
404 .collect::<std::collections::HashSet<_>>()
405 .len();
406 self.update_metrics(|m| {
407 m.mcp_tool_count = mcp_total;
408 m.mcp_server_count = mcp_servers;
409 });
410 }
411
412 pub(super) fn sync_mcp_executor_tools(&self) {
421 if let Some(ref shared) = self.mcp.shared_tools {
422 let mut guard = shared
423 .write()
424 .unwrap_or_else(std::sync::PoisonError::into_inner);
425 guard.clone_from(&self.mcp.tools);
426 }
427 }
428
429 pub(in crate::agent) fn apply_pruned_mcp_tools(&self, pruned: Vec<zeph_mcp::McpTool>) {
444 debug_assert!(
445 pruned.iter().all(|p| self
446 .mcp
447 .tools
448 .iter()
449 .any(|t| t.server_id == p.server_id && t.name == p.name)),
450 "pruned set must be a subset of self.mcp.tools"
451 );
452 if let Some(ref shared) = self.mcp.shared_tools {
453 let mut guard = shared
454 .write()
455 .unwrap_or_else(std::sync::PoisonError::into_inner);
456 *guard = pruned;
457 }
458 }
459
460 pub(super) async fn sync_mcp_registry(&mut self) {
461 let Some(ref mut registry) = self.mcp.registry else {
462 return;
463 };
464 if !self.embedding_provider.supports_embeddings() {
465 return;
466 }
467 let provider = self.embedding_provider.clone();
468 let embed_fn = |text: &str| -> zeph_mcp::registry::EmbedFuture {
469 let owned = text.to_owned();
470 let p = provider.clone();
471 Box::pin(async move { p.embed(&owned).await })
472 };
473 if let Err(e) = registry
474 .sync(&self.mcp.tools, &self.skill_state.embedding_model, embed_fn)
475 .await
476 {
477 tracing::warn!("failed to sync MCP tool registry: {e:#}");
478 }
479 }
480
481 pub async fn init_semantic_index(&mut self) {
488 self.rebuild_semantic_index().await;
489 }
490
491 pub(super) async fn process_pending_elicitations(&mut self) {
496 loop {
497 let Some(ref mut rx) = self.mcp.elicitation_rx else {
498 return;
499 };
500 match rx.try_recv() {
501 Ok(event) => {
502 self.handle_elicitation_event(event).await;
503 }
504 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
505 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
506 self.mcp.elicitation_rx = None;
507 return;
508 }
509 }
510 }
511 }
512
513 pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
515 use crate::channel::{ElicitationRequest, ElicitationResponse};
516
517 let decline = CreateElicitationResult {
518 action: ElicitationAction::Decline,
519 content: None,
520 };
521
522 let channel_request = match &event.request {
523 rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
524 message,
525 requested_schema,
526 ..
527 } => {
528 let fields = build_elicitation_fields(requested_schema);
529 ElicitationRequest {
530 server_name: event.server_id.clone(),
531 message: sanitize_elicitation_message(message),
532 fields,
533 }
534 }
535 rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
536 tracing::debug!(
538 server_id = event.server_id,
539 "URL elicitation not supported, declining"
540 );
541 let _ = event.response_tx.send(decline);
542 return;
543 }
544 };
545
546 if self.mcp.elicitation_warn_sensitive_fields {
547 let sensitive: Vec<&str> = channel_request
548 .fields
549 .iter()
550 .filter(|f| is_sensitive_field(&f.name))
551 .map(|f| f.name.as_str())
552 .collect();
553 if !sensitive.is_empty() {
554 let fields_list = sensitive.join(", ");
555 let warning = format!(
556 "Warning: [{}] is requesting sensitive information (field: {}). \
557 Only proceed if you trust this server.",
558 channel_request.server_name, fields_list,
559 );
560 tracing::warn!(
561 server_id = event.server_id,
562 fields = %fields_list,
563 "elicitation requests sensitive fields"
564 );
565 let _ = self.channel.send(&warning).await;
566 }
567 }
568
569 let _ = self
570 .channel
571 .send_status("MCP server requesting input…")
572 .await;
573 let response = match self.channel.elicit(channel_request).await {
574 Ok(r) => r,
575 Err(e) => {
576 tracing::warn!(
577 server_id = event.server_id,
578 "elicitation channel error: {e:#}"
579 );
580 let _ = self.channel.send_status("").await;
581 let _ = event.response_tx.send(decline);
582 return;
583 }
584 };
585 let _ = self.channel.send_status("").await;
586
587 let result = match response {
588 ElicitationResponse::Accepted(value) => CreateElicitationResult {
589 action: ElicitationAction::Accept,
590 content: Some(value),
591 },
592 ElicitationResponse::Declined => CreateElicitationResult {
593 action: ElicitationAction::Decline,
594 content: None,
595 },
596 ElicitationResponse::Cancelled => CreateElicitationResult {
597 action: ElicitationAction::Cancel,
598 content: None,
599 },
600 };
601
602 if event.response_tx.send(result).is_err() {
603 tracing::warn!(
604 server_id = event.server_id,
605 "elicitation response dropped — handler disconnected"
606 );
607 }
608 }
609
610 pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
620 if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
621 return;
622 }
623
624 if self.mcp.tools.is_empty() {
625 self.mcp.semantic_index = None;
626 return;
627 }
628
629 let provider = self
631 .mcp
632 .discovery_provider
633 .clone()
634 .unwrap_or_else(|| self.embedding_provider.clone());
635
636 let embed_fn = provider.embed_fn();
637
638 match zeph_mcp::SemanticToolIndex::build(&self.mcp.tools, &embed_fn).await {
639 Ok(idx) => {
640 tracing::info!(
641 indexed = idx.len(),
642 total = self.mcp.tools.len(),
643 "semantic tool index built"
644 );
645 self.mcp.semantic_index = Some(idx);
646 }
647 Err(e) => {
648 tracing::warn!(
649 "semantic tool index build failed, falling back to all tools: {e:#}"
650 );
651 self.mcp.semantic_index = None;
652 }
653 }
654 }
655}
656
657fn build_elicitation_fields(
659 schema: &rmcp::model::ElicitationSchema,
660) -> Vec<crate::channel::ElicitationField> {
661 use crate::channel::{ElicitationField, ElicitationFieldType};
662 use rmcp::model::PrimitiveSchema;
663
664 schema
665 .properties
666 .iter()
667 .map(|(name, prop)| {
668 let json = serde_json::to_value(prop).unwrap_or_default();
672 let description = json
673 .get("description")
674 .and_then(|v| v.as_str())
675 .map(String::from);
676
677 let field_type = match prop {
678 PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
679 PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
680 PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
681 PrimitiveSchema::String(_) => ElicitationFieldType::String,
682 PrimitiveSchema::Enum(_) => {
683 let vals = json
686 .get("enum")
687 .and_then(|v| v.as_array())
688 .map(|arr| {
689 arr.iter()
690 .filter_map(|v| v.as_str())
691 .map(String::from)
692 .collect::<Vec<_>>()
693 })
694 .unwrap_or_default();
695 ElicitationFieldType::Enum(vals)
696 }
697 };
698 let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
699 ElicitationField {
700 name: name.clone(),
701 description,
702 field_type,
703 required,
704 }
705 })
706 .collect()
707}
708
/// Lowercase substrings that mark an elicitation field name as possibly
/// sensitive.
///
/// Matching is by substring, so short entries such as `"key"`, `"auth"` and
/// `"pin"` also match broader names (e.g. `"author"`, `"shipping"`).
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    "password",
    "passwd",
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Returns `true` when the Unicode-lowercased `field_name` contains any entry
/// of `SENSITIVE_FIELD_PATTERNS`.
fn is_sensitive_field(field_name: &str) -> bool {
    let haystack = field_name.to_lowercase();
    for needle in SENSITIVE_FIELD_PATTERNS {
        if haystack.contains(*needle) {
            return true;
        }
    }
    false
}
733
/// Makes server-supplied elicitation text safe to display to the user.
///
/// It removes two kinds of character:
/// - Control characters, except `\n` and `\t`, which are kept so layout
///   survives.
/// - Unicode bidirectional formatting characters (embeddings, overrides,
///   isolates and directional marks). These are not `char::is_control`, yet
///   they can visually reorder the displayed text, for example to disguise
///   what is being asked for.
///
/// The result is then capped at 500 `char`s. The cap counts Unicode scalar
/// values, so it never splits a multi-byte character.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    /// Bidi controls from UAX #9: ALM, LRM, RLM, LRE..RLO and LRI..PDI.
    fn is_bidi_control(c: char) -> bool {
        matches!(
            c,
            '\u{061C}'
                | '\u{200E}'
                | '\u{200F}'
                | '\u{202A}'..='\u{202E}'
                | '\u{2066}'..='\u{2069}'
        )
    }

    message
        .chars()
        .filter(|&c| (!c.is_control() || c == '\n' || c == '\t') && !is_bidi_control(c))
        .take(MAX_CHARS)
        .collect()
}
744
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds a fresh agent wired entirely to test mocks, with default MCP state.
    fn test_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Builds a minimal MCP tool with an empty description and schema.
    fn tool(server_id: &str, name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }
    }

    /// Asserts that some message sent over the mock channel contains `needle`.
    fn assert_sent(agent: &mut Agent<MockChannel>, needle: &str, what: &str) {
        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains(needle)),
            "expected {what}, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = test_agent();
        agent.handle_mcp_command("unknown").await.unwrap();
        assert_sent(&mut agent, "Usage: /mcp", "usage message");
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = test_agent();
        agent.handle_mcp_command("list").await.unwrap();
        assert_sent(&mut agent, "MCP is not enabled", "not-enabled message");
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = test_agent();
        agent.handle_mcp_command("tools").await.unwrap();
        assert_sent(&mut agent, "Usage: /mcp tools", "tools usage message");
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = test_agent();
        agent.handle_mcp_command("remove").await.unwrap();
        assert_sent(&mut agent, "Usage: /mcp remove", "remove usage message");
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = test_agent();
        agent.handle_mcp_command("remove my-server").await.unwrap();
        assert_sent(&mut agent, "MCP is not enabled", "not-enabled message");
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = test_agent();
        agent.handle_mcp_command("add server-id").await.unwrap();
        assert_sent(&mut agent, "Usage: /mcp add", "add usage message");
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = test_agent();
        agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert_sent(&mut agent, "No tools found", "no-tools message");
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = test_agent();
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = test_agent();
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = test_agent();
        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = test_agent();
        agent.mcp.tools = vec![tool("srv", "existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = test_agent();
        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);

        tx.send(vec![tool("srv", "refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let cleaned = sanitize_elicitation_message("hello\x01world\x1b[31mred\x1b[0m");
        for forbidden in ['\x01', '\x1b'] {
            assert!(!cleaned.contains(forbidden));
        }
        for kept in ["hello", "world"] {
            assert!(cleaned.contains(kept));
        }
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let text = "line1\nline2\ttabbed";
        assert_eq!(sanitize_elicitation_message(text), text);
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        let long = "a".repeat(600);
        assert_eq!(sanitize_elicitation_message(&long).chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        let accented = "é".repeat(300);
        assert_eq!(sanitize_elicitation_message(&accented).chars().count(), 300);
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let schema = ElicitationSchema::new(BTreeMap::from([
            (
                "flag".to_owned(),
                PrimitiveSchema::Boolean(BooleanSchema::new()),
            ),
            (
                "count".to_owned(),
                PrimitiveSchema::Integer(IntegerSchema::new()),
            ),
            (
                "ratio".to_owned(),
                PrimitiveSchema::Number(NumberSchema::new()),
            ),
            (
                "name".to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            ),
        ]));
        let fields = build_elicitation_fields(&schema);

        let kind_of = |n: &str| &fields.iter().find(|f| f.name == n).unwrap().field_type;
        assert!(matches!(kind_of("flag"), ElicitationFieldType::Boolean));
        assert!(matches!(kind_of("count"), ElicitationFieldType::Integer));
        assert!(matches!(kind_of("ratio"), ElicitationFieldType::Number));
        assert!(matches!(kind_of("name"), ElicitationFieldType::String));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut schema = ElicitationSchema::new(BTreeMap::from([
            (
                "req".to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            ),
            (
                "opt".to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            ),
        ]));
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let is_required = |n: &str| fields.iter().find(|f| f.name == n).unwrap().required;
        assert!(is_required("req"));
        assert!(!is_required("opt"));
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        for name in [
            "password",
            "PASSWORD",
            "user_password",
            "api_token",
            "SECRET_KEY",
            "auth_header",
            "private_key",
        ] {
            assert!(is_sensitive_field(name), "{name} should be flagged");
        }
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        for name in ["username", "email", "message", "description", "subject"] {
            assert!(!is_sensitive_field(name), "{name} should not be flagged");
        }
    }
}