1use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9 pub(super) async fn handle_mcp_command(
15 &mut self,
16 args: &str,
17 ) -> Result<String, super::error::AgentError> {
18 let parts: Vec<&str> = args.split_whitespace().collect();
19 match parts.first().copied() {
20 Some("add") => self.handle_mcp_add(&parts[1..]).await,
21 Some("list") => self.handle_mcp_list().await,
22 Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
23 Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
24 _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
25 }
26 }
27
28 async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
29 if args.len() < 2 {
30 return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
31 }
32
33 let Some(manager) = self.services.mcp.manager.clone() else {
35 return Ok("MCP is not enabled.".to_owned());
36 };
37
38 let target = args[1];
39 if let Some(err) = validate_mcp_command(target, &self.services.mcp.allowed_commands) {
40 return Ok(err);
41 }
42
43 let current_count = manager.list_servers().await.len();
45 if current_count >= self.services.mcp.max_dynamic {
46 return Ok(format!(
47 "Server limit reached ({}/{}).",
48 current_count, self.services.mcp.max_dynamic
49 ));
50 }
51
52 let entry = build_server_entry(args[0], target, &args[2..]);
53
54 match manager.add_server(&entry).await {
55 Ok(tools) => {
56 let count = tools.len();
57 self.services
58 .mcp
59 .server_outcomes
60 .push(zeph_mcp::ServerConnectOutcome {
61 id: entry.id.clone(),
62 connected: true,
63 tool_count: count,
64 error: String::new(),
65 });
66 self.services.mcp.tools.extend(tools);
67 self.services.mcp.sync_executor_tools();
68 self.services.mcp.pruning_cache.reset();
69 self.services.mcp.pending_semantic_rebuild = true;
72 self.update_mcp_metrics();
73 Ok(format!(
74 "Connected MCP server '{}' ({count} tool(s))",
75 entry.id
76 ))
77 }
78 Err(e) => {
79 tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
80 Ok(format!("Failed to connect server '{}': {e}", entry.id))
81 }
82 }
83 }
84
85 async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
86 use std::fmt::Write;
87
88 let Some(manager) = self.services.mcp.manager.clone() else {
89 return Ok("MCP is not enabled.".to_owned());
90 };
91
92 let server_ids = manager.list_servers().await;
93 if server_ids.is_empty() {
94 return Ok("No MCP servers connected.".to_owned());
95 }
96
97 let mut output = String::from("Connected MCP servers:\n");
98 let mut total = 0usize;
99 for id in &server_ids {
100 let count = self
101 .services
102 .mcp
103 .tools
104 .iter()
105 .filter(|t| t.server_id == *id)
106 .count();
107 total += count;
108 let _ = writeln!(output, "- {id} ({count} tools)");
109 }
110 let _ = write!(output, "Total: {total} tool(s)");
111
112 Ok(output)
113 }
114
115 fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
116 use std::fmt::Write;
117
118 let Some(server_id) = server_id else {
119 return "Usage: /mcp tools <server_id>".to_owned();
120 };
121
122 let tools: Vec<_> = self
123 .services
124 .mcp
125 .tools
126 .iter()
127 .filter(|t| t.server_id == server_id)
128 .collect();
129
130 if tools.is_empty() {
131 return format!("No tools found for server '{server_id}'.");
132 }
133
134 let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
135 for t in &tools {
136 if t.description.is_empty() {
137 let _ = writeln!(output, "- {}", t.name);
138 } else {
139 let _ = writeln!(output, "- {} — {}", t.name, t.description);
140 }
141 }
142 output
143 }
144
145 async fn handle_mcp_remove(
146 &mut self,
147 server_id: Option<&str>,
148 ) -> Result<String, super::error::AgentError> {
149 let Some(server_id) = server_id else {
150 return Ok("Usage: /mcp remove <id>".to_owned());
151 };
152
153 let Some(manager) = self.services.mcp.manager.clone() else {
155 return Ok("MCP is not enabled.".to_owned());
156 };
157
158 match manager.remove_server(server_id).await {
159 Ok(()) => {
160 let before = self.services.mcp.tools.len();
161 self.services.mcp.tools.retain(|t| t.server_id != server_id);
162 let removed = before - self.services.mcp.tools.len();
163 self.services
164 .mcp
165 .server_outcomes
166 .retain(|o| o.id != server_id);
167 self.services.mcp.sync_executor_tools();
168 self.services.mcp.pruning_cache.reset();
169 self.services.mcp.pending_semantic_rebuild = true;
172 self.update_mcp_metrics();
173 let sid = server_id.to_owned();
174 self.update_metrics(|m| {
175 m.active_mcp_tools
176 .retain(|name| !name.starts_with(&format!("{sid}:")));
177 });
178 Ok(format!(
179 "Disconnected MCP server '{server_id}' (removed {removed} tools)"
180 ))
181 }
182 Err(e) => {
183 tracing::warn!(server_id, "MCP remove failed: {e:#}");
184 Ok(format!("Failed to remove server '{server_id}': {e}"))
185 }
186 }
187 }
188
189 pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
190 let matched_tools = self.match_mcp_tools(query).await;
191 let active_mcp: Vec<String> = matched_tools
192 .iter()
193 .map(zeph_mcp::McpTool::qualified_name)
194 .collect();
195 let mcp_total = self.services.mcp.tools.len();
196 let (mcp_server_count, mcp_connected_count) =
197 if self.services.mcp.server_outcomes.is_empty() {
198 let connected = self
199 .services
200 .mcp
201 .tools
202 .iter()
203 .map(|t| &t.server_id)
204 .collect::<std::collections::HashSet<_>>()
205 .len();
206 (connected, connected)
207 } else {
208 let total = self.services.mcp.server_outcomes.len();
209 let connected = self
210 .services
211 .mcp
212 .server_outcomes
213 .iter()
214 .filter(|o| o.connected)
215 .count();
216 (total, connected)
217 };
218 self.update_metrics(|m| {
219 m.active_mcp_tools = active_mcp;
220 m.mcp_tool_count = mcp_total;
221 m.mcp_server_count = mcp_server_count;
222 m.mcp_connected_count = mcp_connected_count;
223 });
224 if let Some(ref manager) = self.services.mcp.manager {
225 let instructions = manager.all_server_instructions().await;
226 if !instructions.is_empty() {
227 system_prompt.push_str("\n\n");
228 system_prompt.push_str(&instructions);
229 }
230 }
231 if !matched_tools.is_empty() {
232 let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
233 tracing::debug!(
234 skills = ?self.services.skill.active_skill_names,
235 mcp_tools = ?tool_names,
236 "matched items"
237 );
238 let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
239 if !tools_prompt.is_empty() {
240 system_prompt.push_str("\n\n");
241 system_prompt.push_str(&tools_prompt);
242 }
243 }
244 }
245
246 async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
247 let Some(ref registry) = self.services.mcp.registry else {
248 return self.services.mcp.tools.clone();
249 };
250 let provider = self.embedding_provider.clone();
251 registry
252 .search(query, self.services.skill.max_active_skills, |text| {
253 let owned = text.to_owned();
254 let p = provider.clone();
255 Box::pin(async move { p.embed(&owned).await })
256 })
257 .await
258 }
259
260 pub(super) async fn check_tool_refresh(&mut self) {
271 if self.services.mcp.pending_semantic_rebuild {
273 self.services.mcp.pending_semantic_rebuild = false;
274 self.rebuild_semantic_index().await;
275 self.sync_mcp_registry().await;
276 let mcp_total = self.services.mcp.tools.len();
277 let mcp_servers = self
278 .services
279 .mcp
280 .tools
281 .iter()
282 .map(|t| &t.server_id)
283 .collect::<std::collections::HashSet<_>>()
284 .len();
285 self.update_metrics(|m| {
286 m.mcp_tool_count = mcp_total;
287 m.mcp_server_count = mcp_servers;
288 });
289 }
290
291 let Some(ref mut rx) = self.services.mcp.tool_rx else {
292 return;
293 };
294 if !rx.has_changed().unwrap_or(false) {
295 return;
296 }
297 let new_tools = rx.borrow_and_update().clone();
298 if new_tools.is_empty() {
299 return;
302 }
303 tracing::info!(
304 tools = new_tools.len(),
305 "tools/list_changed: agent tool list refreshed"
306 );
307 self.services.mcp.tools = new_tools;
308 self.services.mcp.sync_executor_tools();
309 self.services.mcp.pruning_cache.reset();
310 self.rebuild_semantic_index().await;
311 self.sync_mcp_registry().await;
312 let mcp_total = self.services.mcp.tools.len();
313 let mcp_servers = self
314 .services
315 .mcp
316 .tools
317 .iter()
318 .map(|t| &t.server_id)
319 .collect::<std::collections::HashSet<_>>()
320 .len();
321 self.update_metrics(|m| {
322 m.mcp_tool_count = mcp_total;
323 m.mcp_server_count = mcp_servers;
324 });
325 }
326
327 pub(super) async fn sync_mcp_registry(&mut self) {
328 if self.services.mcp.registry.is_none() {
329 return;
330 }
331 if !self.embedding_provider.supports_embeddings() {
332 return;
333 }
334 let tools = self.services.mcp.tools.clone();
336 let provider = self.embedding_provider.clone();
337 let embedding_model = self.services.skill.embedding_model.clone();
338 let embed_timeout =
339 std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
340 let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
341 let owned = text.to_owned();
342 let p = provider.clone();
343 Box::pin(async move {
344 if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
345 result
346 } else {
347 tracing::warn!(
348 timeout_secs = embed_timeout.as_secs(),
349 "MCP registry: embedding timed out"
350 );
351 Err(zeph_llm::LlmError::Timeout)
352 }
353 })
354 };
355 let Some(mut registry) = self.services.mcp.registry.take() else {
358 return;
359 };
360 if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
361 tracing::warn!("failed to sync MCP tool registry: {e:#}");
362 }
363 self.services.mcp.registry = Some(registry);
364 }
365
366 pub async fn init_semantic_index(&mut self) {
373 self.rebuild_semantic_index().await;
374 }
375
376 pub(super) async fn process_pending_elicitations(&mut self) {
381 loop {
382 let Some(ref mut rx) = self.services.mcp.elicitation_rx else {
383 return;
384 };
385 match rx.try_recv() {
386 Ok(event) => {
387 self.handle_elicitation_event(event).await;
388 }
389 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
390 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
391 self.services.mcp.elicitation_rx = None;
392 return;
393 }
394 }
395 }
396 }
397
398 pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
400 use crate::channel::{ElicitationRequest, ElicitationResponse};
401
402 let decline = CreateElicitationResult {
403 action: ElicitationAction::Decline,
404 content: None,
405 meta: None,
406 };
407
408 let channel_request = match &event.request {
409 rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
410 message,
411 requested_schema,
412 ..
413 } => {
414 let fields = build_elicitation_fields(requested_schema);
415 ElicitationRequest {
416 server_name: event.server_id.clone(),
417 message: sanitize_elicitation_message(message),
418 fields,
419 }
420 }
421 rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
422 tracing::debug!(
424 server_id = event.server_id,
425 "URL elicitation not supported, declining"
426 );
427 let _ = event.response_tx.send(decline);
428 return;
429 }
430 };
431
432 if self.services.mcp.elicitation_warn_sensitive_fields {
433 let sensitive: Vec<&str> = channel_request
434 .fields
435 .iter()
436 .filter(|f| is_sensitive_field(&f.name))
437 .map(|f| f.name.as_str())
438 .collect();
439 if !sensitive.is_empty() {
440 let fields_list = sensitive.join(", ");
441 let warning = format!(
442 "Warning: [{}] is requesting sensitive information (field: {}). \
443 Only proceed if you trust this server.",
444 channel_request.server_name, fields_list,
445 );
446 tracing::warn!(
447 server_id = event.server_id,
448 fields = %fields_list,
449 "elicitation requests sensitive fields"
450 );
451 let _ = self.channel.send(&warning).await;
452 }
453 }
454
455 let _ = self
456 .channel
457 .send_status("MCP server requesting input…")
458 .await;
459 let response = match self.channel.elicit(channel_request).await {
460 Ok(r) => r,
461 Err(e) => {
462 tracing::warn!(
463 server_id = event.server_id,
464 "elicitation channel error: {e:#}"
465 );
466 let _ = self.channel.send_status("").await;
467 let _ = event.response_tx.send(decline);
468 return;
469 }
470 };
471 let _ = self.channel.send_status("").await;
472
473 let result = match response {
474 ElicitationResponse::Accepted(value) => CreateElicitationResult {
475 action: ElicitationAction::Accept,
476 content: Some(value),
477 meta: None,
478 },
479 ElicitationResponse::Declined => CreateElicitationResult {
480 action: ElicitationAction::Decline,
481 content: None,
482 meta: None,
483 },
484 ElicitationResponse::Cancelled => CreateElicitationResult {
485 action: ElicitationAction::Cancel,
486 content: None,
487 meta: None,
488 },
489 };
490
491 if event.response_tx.send(result).is_err() {
492 tracing::warn!(
493 server_id = event.server_id,
494 "elicitation response dropped — handler disconnected"
495 );
496 }
497 }
498
499 fn update_mcp_metrics(&mut self) {
500 let mcp_total = self.services.mcp.tools.len();
501 let mcp_server_count = self.services.mcp.server_outcomes.len();
502 let mcp_connected_count = self
503 .services
504 .mcp
505 .server_outcomes
506 .iter()
507 .filter(|o| o.connected)
508 .count();
509 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
510 .services
511 .mcp
512 .server_outcomes
513 .iter()
514 .map(|o| crate::metrics::McpServerStatus {
515 id: o.id.clone(),
516 status: if o.connected {
517 crate::metrics::McpServerConnectionStatus::Connected
518 } else {
519 crate::metrics::McpServerConnectionStatus::Failed
520 },
521 tool_count: o.tool_count,
522 error: o.error.clone(),
523 })
524 .collect();
525 self.update_metrics(|m| {
526 m.mcp_tool_count = mcp_total;
527 m.mcp_server_count = mcp_server_count;
528 m.mcp_connected_count = mcp_connected_count;
529 m.mcp_servers = mcp_servers;
530 });
531 }
532
533 pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
543 if self.services.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
544 return;
545 }
546
547 if self.services.mcp.tools.is_empty() {
548 self.services.mcp.semantic_index = None;
549 return;
550 }
551
552 let provider = self
554 .services
555 .mcp
556 .discovery_provider
557 .clone()
558 .unwrap_or_else(|| self.embedding_provider.clone());
559
560 let inner_embed = provider.embed_fn();
561 let embed_timeout =
562 std::time::Duration::from_secs(self.runtime.config.timeouts.embedding_seconds);
563 let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
564 let fut = inner_embed(text);
565 Box::pin(async move {
566 if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
567 result
568 } else {
569 tracing::warn!(
570 timeout_secs = embed_timeout.as_secs(),
571 "semantic index: embedding probe timed out"
572 );
573 Err(zeph_llm::LlmError::Timeout)
574 }
575 })
576 };
577
578 let tools = self.services.mcp.tools.clone();
580 match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
581 Ok(idx) => {
582 tracing::info!(
583 indexed = idx.len(),
584 total = self.services.mcp.tools.len(),
585 "semantic tool index built"
586 );
587 self.services.mcp.semantic_index = Some(idx);
588 }
589 Err(e) => {
590 tracing::warn!(
591 "semantic tool index build failed, falling back to all tools: {e:#}"
592 );
593 self.services.mcp.semantic_index = None;
594 }
595 }
596 }
597}
598
/// Checks a `/mcp add` target against the command allowlist.
///
/// Returns `None` in any of these cases:
/// - the target is an `http://`/`https://` URL (URLs are not subject to the allowlist);
/// - the allowlist is empty (no restriction);
/// - the target exactly matches an allowed command.
///
/// Otherwise returns a user-facing rejection message that lists the permitted commands.
fn validate_mcp_command(target: &str, allowed_commands: &[String]) -> Option<String> {
    let targets_url = ["http://", "https://"]
        .iter()
        .any(|scheme| target.starts_with(scheme));
    if targets_url || allowed_commands.is_empty() {
        return None;
    }
    if allowed_commands.iter().any(|cmd| cmd.as_str() == target) {
        return None;
    }
    let permitted = allowed_commands.join(", ");
    Some(format!(
        "Command '{target}' is not allowed. Permitted: {permitted}"
    ))
}
613
614fn build_server_entry(id: &str, target: &str, extra_args: &[&str]) -> zeph_mcp::ServerEntry {
616 let is_url = target.starts_with("http://") || target.starts_with("https://");
617 let transport = if is_url {
618 zeph_mcp::McpTransport::Http {
619 url: target.to_owned(),
620 headers: std::collections::HashMap::new(),
621 }
622 } else {
623 zeph_mcp::McpTransport::Stdio {
624 command: target.to_owned(),
625 args: extra_args.iter().map(|&s| s.to_owned()).collect(),
626 env: std::collections::HashMap::new(),
627 }
628 };
629 zeph_mcp::ServerEntry {
630 id: id.to_owned(),
631 transport,
632 timeout: std::time::Duration::from_secs(30),
633 trust_level: zeph_config::McpTrustLevel::Untrusted,
634 tool_allowlist: None,
635 expected_tools: Vec::new(),
636 roots: Vec::new(),
637 tool_metadata: std::collections::HashMap::new(),
638 elicitation_enabled: false,
639 elicitation_timeout_secs: 120,
640 env_isolation: false,
641 }
642}
643
644fn build_elicitation_fields(
646 schema: &rmcp::model::ElicitationSchema,
647) -> Vec<crate::channel::ElicitationField> {
648 use crate::channel::{ElicitationField, ElicitationFieldType};
649 use rmcp::model::PrimitiveSchema;
650
651 schema
652 .properties
653 .iter()
654 .map(|(name, prop)| {
655 let json = serde_json::to_value(prop).unwrap_or_default();
659 let description = json
660 .get("description")
661 .and_then(|v| v.as_str())
662 .map(sanitize_elicitation_message);
663
664 let field_type = match prop {
665 PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
666 PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
667 PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
668 PrimitiveSchema::String(_) => ElicitationFieldType::String,
669 PrimitiveSchema::Enum(_) => {
670 let vals = json
673 .get("enum")
674 .and_then(|v| v.as_array())
675 .map(|arr| {
676 arr.iter()
677 .filter_map(|v| v.as_str())
678 .map(sanitize_elicitation_message)
679 .collect::<Vec<_>>()
680 })
681 .unwrap_or_default();
682 ElicitationFieldType::Enum(vals)
683 }
684 };
685 let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
686 ElicitationField {
687 name: name.clone(),
690 description,
691 field_type,
692 required,
693 }
694 })
695 .collect()
696}
697
/// Lower-case substrings that mark an elicitation field name as likely to hold a secret.
///
/// Matching is plain substring containment, so short entries such as `"key"`, `"pin"` and
/// `"auth"` also match longer, unrelated names (for example `"author"`). The check only
/// drives a warning; it is not an access control.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    "password",
    "passwd",
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Returns `true` when `field_name`, lower-cased, contains any entry of
/// `SENSITIVE_FIELD_PATTERNS`.
fn is_sensitive_field(field_name: &str) -> bool {
    let haystack = field_name.to_lowercase();
    for needle in SENSITIVE_FIELD_PATTERNS {
        if haystack.contains(needle) {
            return true;
        }
    }
    false
}
722
/// Makes server-supplied elicitation text safe to show to the user.
///
/// The text comes from an untrusted MCP server. The function does the following:
/// - removes control characters, keeping `\n` and `\t`;
/// - removes Unicode bidirectional formatting characters;
/// - caps the result at 500 characters (not bytes), so multi-byte text is never split
///   inside a character.
///
/// Bidi controls are removed because they reorder how text is rendered. A server could
/// otherwise make a prompt display differently from what it actually says (the
/// "Trojan Source" technique).
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    /// Unicode bidirectional formatting characters (UAX #9):
    /// - ALM (U+061C);
    /// - LRM/RLM (U+200E/U+200F);
    /// - embeddings and overrides LRE..RLO, including PDF (U+202A..U+202E);
    /// - isolates LRI..PDI (U+2066..U+2069).
    fn is_bidi_control(c: char) -> bool {
        matches!(
            c,
            '\u{061C}' | '\u{200E}' | '\u{200F}' | '\u{202A}'..='\u{202E}' | '\u{2066}'..='\u{2069}'
        )
    }

    message
        .chars()
        .filter(|&c| (!c.is_control() || c == '\n' || c == '\t') && !is_bidi_control(c))
        .take(MAX_CHARS)
        .collect()
}
733
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an agent with mock collaborators: no MCP manager, no tool registry, no tools.
    fn test_agent() -> Agent<MockChannel> {
        let provider = mock_provider(vec![]);
        let channel = MockChannel::new(vec![]);
        let registry = create_test_registry();
        let executor = MockToolExecutor::no_tools();
        Agent::new(provider, channel, registry, None, 5, executor)
    }

    /// Builds a minimal MCP tool named `name` owned by `server_id`.
    fn test_tool(server_id: &str, name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            output_schema: None,
            security_meta: zeph_config::mcp_security::ToolSecurityMeta::default(),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("unknown").await.unwrap();
        assert!(
            result.contains("Usage: /mcp"),
            "expected usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("list").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("tools").await.unwrap();
        assert!(
            result.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("remove").await.unwrap();
        assert!(
            result.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("remove my-server").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("add server-id").await.unwrap();
        assert!(
            result.contains("Usage: /mcp add"),
            "expected add usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_no_manager_shows_disabled() {
        let mut agent = test_agent();

        let result = agent.handle_mcp_command("add srv npx").await.unwrap();
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = test_agent();

        let result = agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert!(
            result.contains("No tools found"),
            "expected no-tools message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_lists_only_matching_server() {
        let mut agent = test_agent();
        agent.services.mcp.tools = vec![test_tool("srv", "alpha"), test_tool("other", "beta")];

        let result = agent.handle_mcp_command("tools srv").await.unwrap();
        assert!(
            result.starts_with("Tools for 'srv' (1 total):"),
            "unexpected header, got: {result:?}"
        );
        assert!(result.contains("- alpha"), "missing tool, got: {result:?}");
        assert!(!result.contains("beta"), "foreign tool listed, got: {result:?}");
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = test_agent();

        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = test_agent();

        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.services.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 0);
        // Keep the sender alive until the check has run.
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = test_agent();
        agent.services.mcp.tools = vec![test_tool("srv", "existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.services.mcp.tool_rx = Some(rx);
        tx.send(vec![test_tool("srv", "refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.services.mcp.tool_count(), 1);
        assert_eq!(agent.services.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let input = "hello\x01world\x1b[31mred\x1b[0m";
        let output = sanitize_elicitation_message(input);
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let input = "line1\nline2\ttabbed";
        let output = sanitize_elicitation_message(input);
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // 600 two-byte characters: the 500-character cap must cut between characters,
        // never inside one.
        let input: String = "\u{e9}".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
        assert_eq!(output.len(), 1000);
        assert!(output.chars().all(|c| c == '\u{e9}'));
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "req".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );
        props.insert(
            "opt".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }

    #[test]
    fn validate_mcp_command_enforces_allowlist_for_commands_only() {
        let allowed = vec!["npx".to_owned(), "uvx".to_owned()];
        assert_eq!(validate_mcp_command("npx", &allowed), None);
        assert_eq!(validate_mcp_command("https://example.com/mcp", &allowed), None);
        assert_eq!(
            validate_mcp_command("bash", &allowed).as_deref(),
            Some("Command 'bash' is not allowed. Permitted: npx, uvx")
        );
        assert_eq!(validate_mcp_command("bash", &[]), None);
    }

    #[test]
    fn build_server_entry_selects_transport_from_target() {
        let http = build_server_entry("remote", "https://example.com/mcp", &["ignored"]);
        assert_eq!(http.id, "remote");
        assert!(matches!(http.transport, zeph_mcp::McpTransport::Http { .. }));

        let stdio = build_server_entry("local", "npx", &["-y", "pkg"]);
        let zeph_mcp::McpTransport::Stdio { command, args, .. } = stdio.transport else {
            panic!("expected stdio transport for a non-URL target");
        };
        assert_eq!(command, "npx");
        assert_eq!(args, vec!["-y".to_owned(), "pkg".to_owned()]);
    }
}