1use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9 #[tracing::instrument(skip_all, name = "core.agent.handle_mcp_command")]
15 pub(super) async fn handle_mcp_command(
16 &mut self,
17 args: &str,
18 ) -> Result<String, super::error::AgentError> {
19 let parts: Vec<&str> = args.split_whitespace().collect();
20 match parts.first().copied() {
21 Some("add") => self.handle_mcp_add(&parts[1..]).await,
22 Some("list") => self.handle_mcp_list().await,
23 Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
24 Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
25 _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
26 }
27 }
28
29 async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
30 if args.len() < 2 {
31 return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
32 }
33
34 let Some(manager) = self.services.mcp.manager.clone() else {
36 return Ok("MCP is not enabled.".to_owned());
37 };
38
39 let target = args[1];
40 if let Some(err) = validate_mcp_command(target, &self.services.mcp.allowed_commands) {
41 return Ok(err);
42 }
43
44 let current_count = manager.list_servers().await.len();
46 if current_count >= self.services.mcp.max_dynamic {
47 return Ok(format!(
48 "Server limit reached ({}/{}).",
49 current_count, self.services.mcp.max_dynamic
50 ));
51 }
52
53 let entry = build_server_entry(args[0], target, &args[2..]);
54
55 match manager.add_server(&entry).await {
56 Ok(tools) => {
57 let count = tools.len();
58 self.services
59 .mcp
60 .server_outcomes
61 .push(zeph_mcp::ServerConnectOutcome {
62 id: entry.id.clone(),
63 connected: true,
64 tool_count: count,
65 error: String::new(),
66 });
67 self.services.mcp.tools.extend(tools);
68 self.services.mcp.sync_executor_tools();
69 self.services.mcp.pruning_cache.reset();
70 self.services.mcp.pending_semantic_rebuild = true;
73 self.update_mcp_metrics();
74 Ok(format!(
75 "Connected MCP server '{}' ({count} tool(s))",
76 entry.id
77 ))
78 }
79 Err(e) => {
80 tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
81 Ok(format!("Failed to connect server '{}': {e}", entry.id))
82 }
83 }
84 }
85
86 async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
87 use std::fmt::Write;
88
89 let Some(manager) = self.services.mcp.manager.clone() else {
90 return Ok("MCP is not enabled.".to_owned());
91 };
92
93 let server_ids = manager.list_servers().await;
94 if server_ids.is_empty() {
95 return Ok("No MCP servers connected.".to_owned());
96 }
97
98 let mut output = String::from("Connected MCP servers:\n");
99 let mut total = 0usize;
100 for id in &server_ids {
101 let count = self
102 .services
103 .mcp
104 .tools
105 .iter()
106 .filter(|t| t.server_id == *id)
107 .count();
108 total += count;
109 let _ = writeln!(output, "- {id} ({count} tools)");
110 }
111 let _ = write!(output, "Total: {total} tool(s)");
112
113 Ok(output)
114 }
115
116 fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
117 use std::fmt::Write;
118
119 let Some(server_id) = server_id else {
120 return "Usage: /mcp tools <server_id>".to_owned();
121 };
122
123 let tools: Vec<_> = self
124 .services
125 .mcp
126 .tools
127 .iter()
128 .filter(|t| t.server_id == server_id)
129 .collect();
130
131 if tools.is_empty() {
132 return format!("No tools found for server '{server_id}'.");
133 }
134
135 let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
136 for t in &tools {
137 if t.description.is_empty() {
138 let _ = writeln!(output, "- {}", t.name);
139 } else {
140 let _ = writeln!(output, "- {} — {}", t.name, t.description);
141 }
142 }
143 output
144 }
145
146 async fn handle_mcp_remove(
147 &mut self,
148 server_id: Option<&str>,
149 ) -> Result<String, super::error::AgentError> {
150 let Some(server_id) = server_id else {
151 return Ok("Usage: /mcp remove <id>".to_owned());
152 };
153
154 let Some(manager) = self.services.mcp.manager.clone() else {
156 return Ok("MCP is not enabled.".to_owned());
157 };
158
159 match manager.remove_server(server_id).await {
160 Ok(()) => {
161 let before = self.services.mcp.tools.len();
162 self.services.mcp.tools.retain(|t| t.server_id != server_id);
163 let removed = before - self.services.mcp.tools.len();
164 self.services
165 .mcp
166 .server_outcomes
167 .retain(|o| o.id != server_id);
168 self.services.mcp.sync_executor_tools();
169 self.services.mcp.pruning_cache.reset();
170 self.services.mcp.pending_semantic_rebuild = true;
173 self.update_mcp_metrics();
174 let sid = server_id.to_owned();
175 self.update_metrics(|m| {
176 m.active_mcp_tools
177 .retain(|name| !name.starts_with(&format!("{sid}:")));
178 });
179 Ok(format!(
180 "Disconnected MCP server '{server_id}' (removed {removed} tools)"
181 ))
182 }
183 Err(e) => {
184 tracing::warn!(server_id, "MCP remove failed: {e:#}");
185 Ok(format!("Failed to remove server '{server_id}': {e}"))
186 }
187 }
188 }
189
190 pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
191 let matched_tools = self.match_mcp_tools(query).await;
192 let active_mcp: Vec<String> = matched_tools
193 .iter()
194 .map(zeph_mcp::McpTool::qualified_name)
195 .collect();
196 let mcp_total = self.services.mcp.tools.len();
197 let (mcp_server_count, mcp_connected_count) =
198 if self.services.mcp.server_outcomes.is_empty() {
199 let connected = self
200 .services
201 .mcp
202 .tools
203 .iter()
204 .map(|t| &t.server_id)
205 .collect::<std::collections::HashSet<_>>()
206 .len();
207 (connected, connected)
208 } else {
209 let total = self.services.mcp.server_outcomes.len();
210 let connected = self
211 .services
212 .mcp
213 .server_outcomes
214 .iter()
215 .filter(|o| o.connected)
216 .count();
217 (total, connected)
218 };
219 self.update_metrics(|m| {
220 m.active_mcp_tools = active_mcp;
221 m.mcp_tool_count = mcp_total;
222 m.mcp_server_count = mcp_server_count;
223 m.mcp_connected_count = mcp_connected_count;
224 });
225 if let Some(ref manager) = self.services.mcp.manager {
226 let instructions = manager.all_server_instructions().await;
227 if !instructions.is_empty() {
228 system_prompt.push_str("\n\n");
229 system_prompt.push_str(&instructions);
230 }
231 }
232 if !matched_tools.is_empty() {
233 let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
234 tracing::debug!(
235 skills = ?self.services.skill.active_skill_names,
236 mcp_tools = ?tool_names,
237 "matched items"
238 );
239 let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
240 if !tools_prompt.is_empty() {
241 system_prompt.push_str("\n\n");
242 system_prompt.push_str(&tools_prompt);
243 }
244 }
245 }
246
247 async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
248 let Some(ref registry) = self.services.mcp.registry else {
249 return self.services.mcp.tools.clone();
250 };
251 let provider = self.embedding_provider.clone();
252 registry
253 .search(query, self.services.skill.max_active_skills, |text| {
254 let owned = text.to_owned();
255 let p = provider.clone();
256 Box::pin(async move { p.embed(&owned).await })
257 })
258 .await
259 }
260
261 pub(super) async fn check_tool_refresh(&mut self) {
272 if self.services.mcp.pending_semantic_rebuild {
274 self.services.mcp.pending_semantic_rebuild = false;
275 self.rebuild_semantic_index().await;
276 self.sync_mcp_registry().await;
277 let mcp_total = self.services.mcp.tools.len();
278 let mcp_servers = self
279 .services
280 .mcp
281 .tools
282 .iter()
283 .map(|t| &t.server_id)
284 .collect::<std::collections::HashSet<_>>()
285 .len();
286 self.update_metrics(|m| {
287 m.mcp_tool_count = mcp_total;
288 m.mcp_server_count = mcp_servers;
289 });
290 }
291
292 let Some(ref mut rx) = self.services.mcp.tool_rx else {
293 return;
294 };
295 if !rx.has_changed().unwrap_or(false) {
296 return;
297 }
298 let new_tools = rx.borrow_and_update().clone();
299 if new_tools.is_empty() {
300 return;
303 }
304 tracing::info!(
305 tools = new_tools.len(),
306 "tools/list_changed: agent tool list refreshed"
307 );
308 self.services.mcp.tools = new_tools;
309 self.services.mcp.sync_executor_tools();
310 self.services.mcp.pruning_cache.reset();
311 self.rebuild_semantic_index().await;
312 self.sync_mcp_registry().await;
313 let mcp_total = self.services.mcp.tools.len();
314 let mcp_servers = self
315 .services
316 .mcp
317 .tools
318 .iter()
319 .map(|t| &t.server_id)
320 .collect::<std::collections::HashSet<_>>()
321 .len();
322 self.update_metrics(|m| {
323 m.mcp_tool_count = mcp_total;
324 m.mcp_server_count = mcp_servers;
325 });
326 }
327
328 pub(super) async fn sync_mcp_registry(&mut self) {
329 if self.services.mcp.registry.is_none() {
330 return;
331 }
332 if !self.embedding_provider.supports_embeddings() {
333 return;
334 }
335 let tools = self.services.mcp.tools.clone();
337 let provider = self.embedding_provider.clone();
338 let embedding_model = self.services.skill.embedding_model.clone();
339 let embed_timeout =
340 std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
341 let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
342 let owned = text.to_owned();
343 let p = provider.clone();
344 Box::pin(async move {
345 if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
346 result
347 } else {
348 tracing::warn!(
349 timeout_secs = embed_timeout.as_secs(),
350 "MCP registry: embedding timed out"
351 );
352 Err(zeph_llm::LlmError::Timeout)
353 }
354 })
355 };
356 let Some(mut registry) = self.services.mcp.registry.take() else {
359 return;
360 };
361 if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
362 tracing::warn!("failed to sync MCP tool registry: {e:#}");
363 }
364 self.services.mcp.registry = Some(registry);
365 }
366
367 pub async fn init_semantic_index(&mut self) {
374 self.rebuild_semantic_index().await;
375 }
376
377 pub(super) async fn process_pending_elicitations(&mut self) {
382 loop {
383 let Some(ref mut rx) = self.services.mcp.elicitation_rx else {
384 return;
385 };
386 match rx.try_recv() {
387 Ok(event) => {
388 self.handle_elicitation_event(event).await;
389 }
390 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
391 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
392 self.services.mcp.elicitation_rx = None;
393 return;
394 }
395 }
396 }
397 }
398
399 pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
401 use crate::channel::{ElicitationRequest, ElicitationResponse};
402
403 let decline = CreateElicitationResult {
404 action: ElicitationAction::Decline,
405 content: None,
406 meta: None,
407 };
408
409 let channel_request = match &event.request {
410 rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
411 message,
412 requested_schema,
413 ..
414 } => {
415 let fields = build_elicitation_fields(requested_schema);
416 ElicitationRequest {
417 server_name: event.server_id.clone(),
418 message: sanitize_elicitation_message(message),
419 fields,
420 }
421 }
422 rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
423 tracing::debug!(
425 server_id = event.server_id,
426 "URL elicitation not supported, declining"
427 );
428 let _ = event.response_tx.send(decline);
429 return;
430 }
431 };
432
433 if self.services.mcp.elicitation_warn_sensitive_fields {
434 let sensitive: Vec<&str> = channel_request
435 .fields
436 .iter()
437 .filter(|f| is_sensitive_field(&f.name))
438 .map(|f| f.name.as_str())
439 .collect();
440 if !sensitive.is_empty() {
441 let fields_list = sensitive.join(", ");
442 let warning = format!(
443 "Warning: [{}] is requesting sensitive information (field: {}). \
444 Only proceed if you trust this server.",
445 channel_request.server_name, fields_list,
446 );
447 tracing::warn!(
448 server_id = event.server_id,
449 fields = %fields_list,
450 "elicitation requests sensitive fields"
451 );
452 let _ = self.channel.send(&warning).await;
453 }
454 }
455
456 let _ = self
457 .channel
458 .send_status("MCP server requesting input…")
459 .await;
460 let response = match self.channel.elicit(channel_request).await {
461 Ok(r) => r,
462 Err(e) => {
463 tracing::warn!(
464 server_id = event.server_id,
465 "elicitation channel error: {e:#}"
466 );
467 let _ = self.channel.send_status("").await;
468 let _ = event.response_tx.send(decline);
469 return;
470 }
471 };
472 let _ = self.channel.send_status("").await;
473
474 let result = match response {
475 ElicitationResponse::Accepted(value) => CreateElicitationResult {
476 action: ElicitationAction::Accept,
477 content: Some(value),
478 meta: None,
479 },
480 ElicitationResponse::Declined => CreateElicitationResult {
481 action: ElicitationAction::Decline,
482 content: None,
483 meta: None,
484 },
485 ElicitationResponse::Cancelled => CreateElicitationResult {
486 action: ElicitationAction::Cancel,
487 content: None,
488 meta: None,
489 },
490 };
491
492 if event.response_tx.send(result).is_err() {
493 tracing::warn!(
494 server_id = event.server_id,
495 "elicitation response dropped — handler disconnected"
496 );
497 }
498 }
499
500 fn update_mcp_metrics(&mut self) {
501 let mcp_total = self.services.mcp.tools.len();
502 let mcp_server_count = self.services.mcp.server_outcomes.len();
503 let mcp_connected_count = self
504 .services
505 .mcp
506 .server_outcomes
507 .iter()
508 .filter(|o| o.connected)
509 .count();
510 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
511 .services
512 .mcp
513 .server_outcomes
514 .iter()
515 .map(|o| crate::metrics::McpServerStatus {
516 id: o.id.clone(),
517 status: if o.connected {
518 crate::metrics::McpServerConnectionStatus::Connected
519 } else {
520 crate::metrics::McpServerConnectionStatus::Failed
521 },
522 tool_count: o.tool_count,
523 error: o.error.clone(),
524 })
525 .collect();
526 self.update_metrics(|m| {
527 m.mcp_tool_count = mcp_total;
528 m.mcp_server_count = mcp_server_count;
529 m.mcp_connected_count = mcp_connected_count;
530 m.mcp_servers = mcp_servers;
531 });
532 }
533
534 pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
544 if self.services.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
545 return;
546 }
547
548 if self.services.mcp.tools.is_empty() {
549 self.services.mcp.semantic_index = None;
550 return;
551 }
552
553 let provider = self
555 .services
556 .mcp
557 .discovery_provider
558 .clone()
559 .unwrap_or_else(|| self.embedding_provider.clone());
560
561 let inner_embed = provider.embed_fn();
562 let embed_timeout =
563 std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
564 let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
565 let fut = inner_embed(text);
566 Box::pin(async move {
567 if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
568 result
569 } else {
570 tracing::warn!(
571 timeout_secs = embed_timeout.as_secs(),
572 "semantic index: embedding probe timed out"
573 );
574 Err(zeph_llm::LlmError::Timeout)
575 }
576 })
577 };
578
579 let tools = self.services.mcp.tools.clone();
581 match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
582 Ok(idx) => {
583 tracing::info!(
584 indexed = idx.len(),
585 total = self.services.mcp.tools.len(),
586 "semantic tool index built"
587 );
588 self.services.mcp.semantic_index = Some(idx);
589 }
590 Err(e) => {
591 tracing::warn!(
592 "semantic tool index build failed, falling back to all tools: {e:#}"
593 );
594 self.services.mcp.semantic_index = None;
595 }
596 }
597 }
598}
599
/// Checks a `/mcp add` target against the command allowlist.
///
/// HTTP(S) URLs are never restricted, and an empty allowlist permits any command.
/// Returns `Some(message)` explaining the rejection, or `None` when `target` may be used.
fn validate_mcp_command(target: &str, allowed_commands: &[String]) -> Option<String> {
    let remote = ["http://", "https://"]
        .iter()
        .any(|scheme| target.starts_with(*scheme));
    if remote || allowed_commands.is_empty() {
        return None;
    }
    if allowed_commands.iter().any(|allowed| allowed == target) {
        return None;
    }
    Some(format!(
        "Command '{target}' is not allowed. Permitted: {}",
        allowed_commands.join(", ")
    ))
}
614
615fn build_server_entry(id: &str, target: &str, extra_args: &[&str]) -> zeph_mcp::ServerEntry {
617 let is_url = target.starts_with("http://") || target.starts_with("https://");
618 let transport = if is_url {
619 zeph_mcp::McpTransport::Http {
620 url: target.to_owned(),
621 headers: std::collections::HashMap::new(),
622 }
623 } else {
624 zeph_mcp::McpTransport::Stdio {
625 command: target.to_owned(),
626 args: extra_args.iter().map(|&s| s.to_owned()).collect(),
627 env: std::collections::HashMap::new(),
628 }
629 };
630 zeph_mcp::ServerEntry {
631 id: id.to_owned(),
632 transport,
633 timeout: std::time::Duration::from_secs(30),
634 trust_level: zeph_config::McpTrustLevel::Untrusted,
635 tool_allowlist: None,
636 expected_tools: Vec::new(),
637 roots: Vec::new(),
638 tool_metadata: std::collections::HashMap::new(),
639 elicitation_enabled: false,
640 elicitation_timeout_secs: 120,
641 env_isolation: false,
642 }
643}
644
645fn build_elicitation_fields(
647 schema: &rmcp::model::ElicitationSchema,
648) -> Vec<crate::channel::ElicitationField> {
649 use crate::channel::{ElicitationField, ElicitationFieldType};
650 use rmcp::model::PrimitiveSchema;
651
652 schema
653 .properties
654 .iter()
655 .map(|(name, prop)| {
656 let json = serde_json::to_value(prop).unwrap_or_default();
660 let description = json
661 .get("description")
662 .and_then(|v| v.as_str())
663 .map(sanitize_elicitation_message);
664
665 let field_type = match prop {
666 PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
667 PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
668 PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
669 PrimitiveSchema::String(_) => ElicitationFieldType::String,
670 PrimitiveSchema::Enum(_) => {
671 let vals = json
674 .get("enum")
675 .and_then(|v| v.as_array())
676 .map(|arr| {
677 arr.iter()
678 .filter_map(|v| v.as_str())
679 .map(sanitize_elicitation_message)
680 .collect::<Vec<_>>()
681 })
682 .unwrap_or_default();
683 ElicitationFieldType::Enum(vals)
684 }
685 };
686 let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
687 ElicitationField {
688 name: name.clone(),
691 description,
692 field_type,
693 required,
694 }
695 })
696 .collect()
697}
698
/// Lower-case substrings that mark an elicitation field name as likely sensitive.
///
/// Matching is by plain substring, so the list deliberately errs toward false positives
/// (e.g. `author` contains `auth`, `shipping` contains `pin`). That is acceptable here
/// because a match only triggers a user-facing warning.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    "password",
    "passwd",
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Returns `true` when `field_name`, compared case-insensitively, contains any entry of
/// `SENSITIVE_FIELD_PATTERNS`.
fn is_sensitive_field(field_name: &str) -> bool {
    let haystack = field_name.to_lowercase();
    for needle in SENSITIVE_FIELD_PATTERNS {
        if haystack.contains(*needle) {
            return true;
        }
    }
    false
}
723
/// Makes server-supplied elicitation text safe to show to the user.
///
/// MCP servers are untrusted, so this removes:
/// - control characters (`char::is_control`) other than `\n` and `\t`, e.g. the ESC that
///   starts ANSI escape sequences, `\r`, and NUL;
/// - Unicode bidirectional formatting characters (ALM, LRM/RLM, the LRE..RLO embeddings
///   and overrides, and the LRI..PDI isolates). These are not `is_control()`, but they can
///   visually reorder text to disguise what the server is asking for ("Trojan Source"
///   style spoofing).
///
/// The result is capped at 500 `char`s (not bytes), so a multi-byte character is never
/// split.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    /// Bidi control characters listed in Unicode UAX #9.
    fn is_bidi_control(c: char) -> bool {
        matches!(
            c,
            '\u{061C}'
                | '\u{200E}'
                | '\u{200F}'
                | '\u{202A}'..='\u{202E}'
                | '\u{2066}'..='\u{2069}'
        )
    }

    message
        .chars()
        .filter(|&c| (!c.is_control() || c == '\n' || c == '\t') && !is_bidi_control(c))
        .take(MAX_CHARS)
        .collect()
}
734
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an `Agent` with a mock provider, channel, and executor and no MCP manager.
    ///
    /// A macro rather than a helper fn so the concrete `Agent<...>` type is inferred at
    /// each call site instead of being spelled out here.
    macro_rules! test_agent {
        () => {
            Agent::new(
                mock_provider(vec![]),
                MockChannel::new(vec![]),
                create_test_registry(),
                None,
                5,
                MockToolExecutor::no_tools(),
            )
        };
    }

    /// Minimal MCP tool fixture with an empty description and schema.
    fn mcp_tool(server_id: &str, name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            output_schema: None,
            security_meta: zeph_config::mcp_security::ToolSecurityMeta::default(),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("unknown").await.unwrap();
        assert!(
            result.contains("Usage: /mcp"),
            "expected usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_command_empty_args_shows_usage() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("   ").await.unwrap();
        assert_eq!(result, "Usage: /mcp add|list|tools|remove");
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("list").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("tools").await.unwrap();
        assert!(
            result.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("remove").await.unwrap();
        assert!(
            result.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("remove my-server").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("add server-id").await.unwrap();
        assert!(
            result.contains("Usage: /mcp add"),
            "expected add usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_no_manager_shows_disabled() {
        let mut agent = test_agent!();

        let result = agent.handle_mcp_command("add srv npx").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = test_agent!();

        let result = agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert!(
            result.contains("No tools found"),
            "expected no-tools message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_lists_only_requested_server() {
        let mut agent = test_agent!();
        agent.services.mcp.tools = vec![mcp_tool("srv", "alpha"), mcp_tool("other", "beta")];

        let result = agent.handle_mcp_command("tools srv").await.unwrap();
        assert!(
            result.starts_with("Tools for 'srv' (1 total):"),
            "unexpected header, got: {result:?}"
        );
        assert!(result.contains("- alpha"), "missing tool, got: {result:?}");
        assert!(!result.contains("beta"), "leaked other server, got: {result:?}");
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = test_agent!();

        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = test_agent!();

        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = test_agent!();

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.services.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = test_agent!();
        agent.services.mcp.tools = vec![mcp_tool("srv", "existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = test_agent!();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        tx.send(vec![mcp_tool("srv", "refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
        assert_eq!(agent.services.mcp.tools[0].name, "refreshed_tool");
    }

    #[tokio::test]
    async fn check_tool_refresh_clears_pending_rebuild_flag() {
        let mut agent = test_agent!();
        agent.services.mcp.pending_semantic_rebuild = true;

        agent.check_tool_refresh().await;
        assert!(!agent.services.mcp.pending_semantic_rebuild);
    }

    #[test]
    fn validate_mcp_command_respects_allowlist() {
        let allowed = vec!["npx".to_owned(), "uvx".to_owned()];
        assert!(validate_mcp_command("npx", &allowed).is_none());
        assert!(validate_mcp_command("https://example.com/mcp", &allowed).is_none());
        assert!(validate_mcp_command("anything", &[]).is_none());
        let err = validate_mcp_command("bash", &allowed).expect("bash must be rejected");
        assert_eq!(err, "Command 'bash' is not allowed. Permitted: npx, uvx");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        let input: String = "é".repeat(300);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 300);
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}
1067}