1use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9 pub(super) async fn handle_mcp_command(
15 &mut self,
16 args: &str,
17 ) -> Result<String, super::error::AgentError> {
18 let parts: Vec<&str> = args.split_whitespace().collect();
19 match parts.first().copied() {
20 Some("add") => self.handle_mcp_add(&parts[1..]).await,
21 Some("list") => self.handle_mcp_list().await,
22 Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
23 Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
24 _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
25 }
26 }
27
28 #[allow(clippy::too_many_lines)]
29 async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
30 if args.len() < 2 {
31 return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
32 }
33
34 let Some(manager) = self.mcp.manager.clone() else {
36 return Ok("MCP is not enabled.".to_owned());
37 };
38
39 let target = args[1];
40 let is_url = target.starts_with("http://") || target.starts_with("https://");
41
42 if !is_url
44 && !self.mcp.allowed_commands.is_empty()
45 && !self.mcp.allowed_commands.iter().any(|c| c == target)
46 {
47 return Ok(format!(
48 "Command '{target}' is not allowed. Permitted: {}",
49 self.mcp.allowed_commands.join(", ")
50 ));
51 }
52
53 let current_count = manager.list_servers().await.len();
55 if current_count >= self.mcp.max_dynamic {
56 return Ok(format!(
57 "Server limit reached ({}/{}).",
58 current_count, self.mcp.max_dynamic
59 ));
60 }
61
62 let transport = if is_url {
63 zeph_mcp::McpTransport::Http {
64 url: target.to_owned(),
65 headers: std::collections::HashMap::new(),
66 }
67 } else {
68 zeph_mcp::McpTransport::Stdio {
69 command: target.to_owned(),
70 args: args[2..].iter().map(|&s| s.to_owned()).collect(),
71 env: std::collections::HashMap::new(),
72 }
73 };
74
75 let entry = zeph_mcp::ServerEntry {
76 id: args[0].to_owned(),
77 transport,
78 timeout: std::time::Duration::from_secs(30),
79 trust_level: zeph_mcp::McpTrustLevel::Untrusted,
80 tool_allowlist: None,
81 expected_tools: Vec::new(),
82 roots: Vec::new(),
83 tool_metadata: std::collections::HashMap::new(),
84 elicitation_enabled: false,
85 elicitation_timeout_secs: 120,
86 env_isolation: false,
87 };
88
89 match manager.add_server(&entry).await {
90 Ok(tools) => {
91 let count = tools.len();
92 self.mcp
93 .server_outcomes
94 .push(zeph_mcp::ServerConnectOutcome {
95 id: entry.id.clone(),
96 connected: true,
97 tool_count: count,
98 error: String::new(),
99 });
100 self.mcp.tools.extend(tools);
101 self.mcp.sync_executor_tools();
102 self.mcp.pruning_cache.reset();
103 self.mcp.pending_semantic_rebuild = true;
106 let mcp_total = self.mcp.tools.len();
107 let mcp_server_count = self.mcp.server_outcomes.len();
108 let mcp_connected_count = self
109 .mcp
110 .server_outcomes
111 .iter()
112 .filter(|o| o.connected)
113 .count();
114 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
115 .mcp
116 .server_outcomes
117 .iter()
118 .map(|o| crate::metrics::McpServerStatus {
119 id: o.id.clone(),
120 status: if o.connected {
121 crate::metrics::McpServerConnectionStatus::Connected
122 } else {
123 crate::metrics::McpServerConnectionStatus::Failed
124 },
125 tool_count: o.tool_count,
126 error: o.error.clone(),
127 })
128 .collect();
129 self.update_metrics(|m| {
130 m.mcp_tool_count = mcp_total;
131 m.mcp_server_count = mcp_server_count;
132 m.mcp_connected_count = mcp_connected_count;
133 m.mcp_servers = mcp_servers;
134 });
135 Ok(format!(
136 "Connected MCP server '{}' ({count} tool(s))",
137 entry.id
138 ))
139 }
140 Err(e) => {
141 tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
142 Ok(format!("Failed to connect server '{}': {e}", entry.id))
143 }
144 }
145 }
146
147 async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
148 use std::fmt::Write;
149
150 let Some(manager) = self.mcp.manager.clone() else {
151 return Ok("MCP is not enabled.".to_owned());
152 };
153
154 let server_ids = manager.list_servers().await;
155 if server_ids.is_empty() {
156 return Ok("No MCP servers connected.".to_owned());
157 }
158
159 let mut output = String::from("Connected MCP servers:\n");
160 let mut total = 0usize;
161 for id in &server_ids {
162 let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
163 total += count;
164 let _ = writeln!(output, "- {id} ({count} tools)");
165 }
166 let _ = write!(output, "Total: {total} tool(s)");
167
168 Ok(output)
169 }
170
171 fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
172 use std::fmt::Write;
173
174 let Some(server_id) = server_id else {
175 return "Usage: /mcp tools <server_id>".to_owned();
176 };
177
178 let tools: Vec<_> = self
179 .mcp
180 .tools
181 .iter()
182 .filter(|t| t.server_id == server_id)
183 .collect();
184
185 if tools.is_empty() {
186 return format!("No tools found for server '{server_id}'.");
187 }
188
189 let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
190 for t in &tools {
191 if t.description.is_empty() {
192 let _ = writeln!(output, "- {}", t.name);
193 } else {
194 let _ = writeln!(output, "- {} — {}", t.name, t.description);
195 }
196 }
197 output
198 }
199
200 async fn handle_mcp_remove(
201 &mut self,
202 server_id: Option<&str>,
203 ) -> Result<String, super::error::AgentError> {
204 let Some(server_id) = server_id else {
205 return Ok("Usage: /mcp remove <id>".to_owned());
206 };
207
208 let Some(manager) = self.mcp.manager.clone() else {
210 return Ok("MCP is not enabled.".to_owned());
211 };
212
213 match manager.remove_server(server_id).await {
214 Ok(()) => {
215 let before = self.mcp.tools.len();
216 self.mcp.tools.retain(|t| t.server_id != server_id);
217 let removed = before - self.mcp.tools.len();
218 self.mcp.server_outcomes.retain(|o| o.id != server_id);
219 self.mcp.sync_executor_tools();
220 self.mcp.pruning_cache.reset();
221 self.mcp.pending_semantic_rebuild = true;
224 let mcp_total = self.mcp.tools.len();
225 let mcp_server_count = self.mcp.server_outcomes.len();
226 let mcp_connected_count = self
227 .mcp
228 .server_outcomes
229 .iter()
230 .filter(|o| o.connected)
231 .count();
232 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
233 .mcp
234 .server_outcomes
235 .iter()
236 .map(|o| crate::metrics::McpServerStatus {
237 id: o.id.clone(),
238 status: if o.connected {
239 crate::metrics::McpServerConnectionStatus::Connected
240 } else {
241 crate::metrics::McpServerConnectionStatus::Failed
242 },
243 tool_count: o.tool_count,
244 error: o.error.clone(),
245 })
246 .collect();
247 self.update_metrics(|m| {
248 m.mcp_tool_count = mcp_total;
249 m.mcp_server_count = mcp_server_count;
250 m.mcp_connected_count = mcp_connected_count;
251 m.mcp_servers = mcp_servers;
252 m.active_mcp_tools
253 .retain(|name| !name.starts_with(&format!("{server_id}:")));
254 });
255 Ok(format!(
256 "Disconnected MCP server '{server_id}' (removed {removed} tools)"
257 ))
258 }
259 Err(e) => {
260 tracing::warn!(server_id, "MCP remove failed: {e:#}");
261 Ok(format!("Failed to remove server '{server_id}': {e}"))
262 }
263 }
264 }
265
266 pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
267 let matched_tools = self.match_mcp_tools(query).await;
268 let active_mcp: Vec<String> = matched_tools
269 .iter()
270 .map(zeph_mcp::McpTool::qualified_name)
271 .collect();
272 let mcp_total = self.mcp.tools.len();
273 let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
274 let connected = self
275 .mcp
276 .tools
277 .iter()
278 .map(|t| &t.server_id)
279 .collect::<std::collections::HashSet<_>>()
280 .len();
281 (connected, connected)
282 } else {
283 let total = self.mcp.server_outcomes.len();
284 let connected = self
285 .mcp
286 .server_outcomes
287 .iter()
288 .filter(|o| o.connected)
289 .count();
290 (total, connected)
291 };
292 self.update_metrics(|m| {
293 m.active_mcp_tools = active_mcp;
294 m.mcp_tool_count = mcp_total;
295 m.mcp_server_count = mcp_server_count;
296 m.mcp_connected_count = mcp_connected_count;
297 });
298 if let Some(ref manager) = self.mcp.manager {
299 let instructions = manager.all_server_instructions().await;
300 if !instructions.is_empty() {
301 system_prompt.push_str("\n\n");
302 system_prompt.push_str(&instructions);
303 }
304 }
305 if !matched_tools.is_empty() {
306 let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
307 tracing::debug!(
308 skills = ?self.skill_state.active_skill_names,
309 mcp_tools = ?tool_names,
310 "matched items"
311 );
312 let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
313 if !tools_prompt.is_empty() {
314 system_prompt.push_str("\n\n");
315 system_prompt.push_str(&tools_prompt);
316 }
317 }
318 }
319
320 async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
321 let Some(ref registry) = self.mcp.registry else {
322 return self.mcp.tools.clone();
323 };
324 let provider = self.embedding_provider.clone();
325 registry
326 .search(query, self.skill_state.max_active_skills, |text| {
327 let owned = text.to_owned();
328 let p = provider.clone();
329 Box::pin(async move { p.embed(&owned).await })
330 })
331 .await
332 }
333
334 pub(super) async fn check_tool_refresh(&mut self) {
345 if self.mcp.pending_semantic_rebuild {
347 self.mcp.pending_semantic_rebuild = false;
348 self.rebuild_semantic_index().await;
349 self.sync_mcp_registry().await;
350 let mcp_total = self.mcp.tools.len();
351 let mcp_servers = self
352 .mcp
353 .tools
354 .iter()
355 .map(|t| &t.server_id)
356 .collect::<std::collections::HashSet<_>>()
357 .len();
358 self.update_metrics(|m| {
359 m.mcp_tool_count = mcp_total;
360 m.mcp_server_count = mcp_servers;
361 });
362 }
363
364 let Some(ref mut rx) = self.mcp.tool_rx else {
365 return;
366 };
367 if !rx.has_changed().unwrap_or(false) {
368 return;
369 }
370 let new_tools = rx.borrow_and_update().clone();
371 if new_tools.is_empty() {
372 return;
375 }
376 tracing::info!(
377 tools = new_tools.len(),
378 "tools/list_changed: agent tool list refreshed"
379 );
380 self.mcp.tools = new_tools;
381 self.mcp.sync_executor_tools();
382 self.mcp.pruning_cache.reset();
383 self.rebuild_semantic_index().await;
384 self.sync_mcp_registry().await;
385 let mcp_total = self.mcp.tools.len();
386 let mcp_servers = self
387 .mcp
388 .tools
389 .iter()
390 .map(|t| &t.server_id)
391 .collect::<std::collections::HashSet<_>>()
392 .len();
393 self.update_metrics(|m| {
394 m.mcp_tool_count = mcp_total;
395 m.mcp_server_count = mcp_servers;
396 });
397 }
398
399 pub(super) async fn sync_mcp_registry(&mut self) {
400 if self.mcp.registry.is_none() {
401 return;
402 }
403 if !self.embedding_provider.supports_embeddings() {
404 return;
405 }
406 let tools = self.mcp.tools.clone();
408 let provider = self.embedding_provider.clone();
409 let embedding_model = self.skill_state.embedding_model.clone();
410 let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
411 let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
412 let owned = text.to_owned();
413 let p = provider.clone();
414 Box::pin(async move {
415 if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
416 result
417 } else {
418 tracing::warn!(
419 timeout_secs = embed_timeout.as_secs(),
420 "MCP registry: embedding timed out"
421 );
422 Err(zeph_llm::LlmError::Timeout)
423 }
424 })
425 };
426 let Some(mut registry) = self.mcp.registry.take() else {
429 return;
430 };
431 if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
432 tracing::warn!("failed to sync MCP tool registry: {e:#}");
433 }
434 self.mcp.registry = Some(registry);
435 }
436
437 pub async fn init_semantic_index(&mut self) {
444 self.rebuild_semantic_index().await;
445 }
446
447 pub(super) async fn process_pending_elicitations(&mut self) {
452 loop {
453 let Some(ref mut rx) = self.mcp.elicitation_rx else {
454 return;
455 };
456 match rx.try_recv() {
457 Ok(event) => {
458 self.handle_elicitation_event(event).await;
459 }
460 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
461 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
462 self.mcp.elicitation_rx = None;
463 return;
464 }
465 }
466 }
467 }
468
469 pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
471 use crate::channel::{ElicitationRequest, ElicitationResponse};
472
473 let decline = CreateElicitationResult {
474 action: ElicitationAction::Decline,
475 content: None,
476 meta: None,
477 };
478
479 let channel_request = match &event.request {
480 rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
481 message,
482 requested_schema,
483 ..
484 } => {
485 let fields = build_elicitation_fields(requested_schema);
486 ElicitationRequest {
487 server_name: event.server_id.clone(),
488 message: sanitize_elicitation_message(message),
489 fields,
490 }
491 }
492 rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
493 tracing::debug!(
495 server_id = event.server_id,
496 "URL elicitation not supported, declining"
497 );
498 let _ = event.response_tx.send(decline);
499 return;
500 }
501 };
502
503 if self.mcp.elicitation_warn_sensitive_fields {
504 let sensitive: Vec<&str> = channel_request
505 .fields
506 .iter()
507 .filter(|f| is_sensitive_field(&f.name))
508 .map(|f| f.name.as_str())
509 .collect();
510 if !sensitive.is_empty() {
511 let fields_list = sensitive.join(", ");
512 let warning = format!(
513 "Warning: [{}] is requesting sensitive information (field: {}). \
514 Only proceed if you trust this server.",
515 channel_request.server_name, fields_list,
516 );
517 tracing::warn!(
518 server_id = event.server_id,
519 fields = %fields_list,
520 "elicitation requests sensitive fields"
521 );
522 let _ = self.channel.send(&warning).await;
523 }
524 }
525
526 let _ = self
527 .channel
528 .send_status("MCP server requesting input…")
529 .await;
530 let response = match self.channel.elicit(channel_request).await {
531 Ok(r) => r,
532 Err(e) => {
533 tracing::warn!(
534 server_id = event.server_id,
535 "elicitation channel error: {e:#}"
536 );
537 let _ = self.channel.send_status("").await;
538 let _ = event.response_tx.send(decline);
539 return;
540 }
541 };
542 let _ = self.channel.send_status("").await;
543
544 let result = match response {
545 ElicitationResponse::Accepted(value) => CreateElicitationResult {
546 action: ElicitationAction::Accept,
547 content: Some(value),
548 meta: None,
549 },
550 ElicitationResponse::Declined => CreateElicitationResult {
551 action: ElicitationAction::Decline,
552 content: None,
553 meta: None,
554 },
555 ElicitationResponse::Cancelled => CreateElicitationResult {
556 action: ElicitationAction::Cancel,
557 content: None,
558 meta: None,
559 },
560 };
561
562 if event.response_tx.send(result).is_err() {
563 tracing::warn!(
564 server_id = event.server_id,
565 "elicitation response dropped — handler disconnected"
566 );
567 }
568 }
569
570 pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
580 if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
581 return;
582 }
583
584 if self.mcp.tools.is_empty() {
585 self.mcp.semantic_index = None;
586 return;
587 }
588
589 let provider = self
591 .mcp
592 .discovery_provider
593 .clone()
594 .unwrap_or_else(|| self.embedding_provider.clone());
595
596 let inner_embed = provider.embed_fn();
597 let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
598 let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
599 let fut = inner_embed(text);
600 Box::pin(async move {
601 if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
602 result
603 } else {
604 tracing::warn!(
605 timeout_secs = embed_timeout.as_secs(),
606 "semantic index: embedding probe timed out"
607 );
608 Err(zeph_llm::LlmError::Timeout)
609 }
610 })
611 };
612
613 let tools = self.mcp.tools.clone();
615 match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
616 Ok(idx) => {
617 tracing::info!(
618 indexed = idx.len(),
619 total = self.mcp.tools.len(),
620 "semantic tool index built"
621 );
622 self.mcp.semantic_index = Some(idx);
623 }
624 Err(e) => {
625 tracing::warn!(
626 "semantic tool index build failed, falling back to all tools: {e:#}"
627 );
628 self.mcp.semantic_index = None;
629 }
630 }
631 }
632}
633
634fn build_elicitation_fields(
636 schema: &rmcp::model::ElicitationSchema,
637) -> Vec<crate::channel::ElicitationField> {
638 use crate::channel::{ElicitationField, ElicitationFieldType};
639 use rmcp::model::PrimitiveSchema;
640
641 schema
642 .properties
643 .iter()
644 .map(|(name, prop)| {
645 let json = serde_json::to_value(prop).unwrap_or_default();
649 let description = json
650 .get("description")
651 .and_then(|v| v.as_str())
652 .map(sanitize_elicitation_message);
653
654 let field_type = match prop {
655 PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
656 PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
657 PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
658 PrimitiveSchema::String(_) => ElicitationFieldType::String,
659 PrimitiveSchema::Enum(_) => {
660 let vals = json
663 .get("enum")
664 .and_then(|v| v.as_array())
665 .map(|arr| {
666 arr.iter()
667 .filter_map(|v| v.as_str())
668 .map(sanitize_elicitation_message)
669 .collect::<Vec<_>>()
670 })
671 .unwrap_or_default();
672 ElicitationFieldType::Enum(vals)
673 }
674 };
675 let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
676 ElicitationField {
677 name: name.clone(),
680 description,
681 field_type,
682 required,
683 }
684 })
685 .collect()
686}
687
/// Lowercase substrings that mark an elicitation field name as sensitive.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    "password",
    "passwd",
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Returns `true` when the lowercased `field_name` contains any entry of
/// `SENSITIVE_FIELD_PATTERNS` as a substring.
///
/// This is a plain substring match, so names that merely embed a pattern
/// (for example `shipping`, which contains `pin`) are reported as well.
fn is_sensitive_field(field_name: &str) -> bool {
    let normalized = field_name.to_lowercase();
    for needle in SENSITIVE_FIELD_PATTERNS {
        if normalized.contains(*needle) {
            return true;
        }
    }
    false
}
712
/// Returns `message` with control characters removed and the result capped
/// at 500 characters.
///
/// Newline and tab are the only control characters kept. The cap counts
/// characters that survive filtering, not bytes, so multi-byte text is never
/// cut inside a character.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    let mut cleaned = String::new();
    let mut kept = 0usize;
    for ch in message.chars() {
        if kept == MAX_CHARS {
            break;
        }
        let allowed = matches!(ch, '\n' | '\t') || !ch.is_control();
        if allowed {
            cleaned.push(ch);
            kept += 1;
        }
    }
    cleaned
}
723
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an agent from the mock provider, channel, registry and executor,
    /// using the same constructor arguments in every test.
    fn test_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Runs one `/mcp` command line on a fresh agent and returns its reply.
    async fn run_mcp(command: &str) -> String {
        test_agent().handle_mcp_command(command).await.unwrap()
    }

    /// Builds a tool on server `srv` with the given name and empty metadata.
    fn sample_tool(name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            output_schema: None,
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let result = run_mcp("unknown").await;
        assert!(
            result.contains("Usage: /mcp"),
            "expected usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let result = run_mcp("list").await;
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let result = run_mcp("tools").await;
        assert!(
            result.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let result = run_mcp("remove").await;
        assert!(
            result.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let result = run_mcp("remove my-server").await;
        assert!(
            result.contains("MCP is not enabled"),
            "expected not-enabled message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let result = run_mcp("add server-id").await;
        assert!(
            result.contains("Usage: /mcp add"),
            "expected add usage message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let result = run_mcp("tools nonexistent-server").await;
        assert!(
            result.contains("No tools found"),
            "expected no-tools message, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = test_agent();
        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = test_agent();
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 0);
        // The sender is kept alive until after the refresh has run.
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = test_agent();
        agent.mcp.tools = vec![sample_tool("existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = test_agent();

        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        tx.send(vec![sample_tool("refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp.tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let output = sanitize_elicitation_message("hello\x01world\x1b[31mred\x1b[0m");
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let output = sanitize_elicitation_message("line1\nline2\ttabbed");
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        let input: String = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // 300 two-byte characters stay under the 500-character cap.
        let input: String = "é".repeat(300);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 300);
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        props.insert(
            "flag".to_owned(),
            PrimitiveSchema::Boolean(BooleanSchema::new()),
        );
        props.insert(
            "count".to_owned(),
            PrimitiveSchema::Integer(IntegerSchema::new()),
        );
        props.insert(
            "ratio".to_owned(),
            PrimitiveSchema::Number(NumberSchema::new()),
        );
        props.insert(
            "name".to_owned(),
            PrimitiveSchema::String(StringSchema::new()),
        );

        let fields = build_elicitation_fields(&ElicitationSchema::new(props));
        let field_named = |wanted: &str| fields.iter().find(|f| f.name == wanted).unwrap();

        assert!(matches!(
            field_named("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            field_named("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            field_named("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            field_named("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let mut props = BTreeMap::new();
        for key in ["req", "opt"] {
            props.insert(
                key.to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            );
        }

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let is_required = |wanted: &str| {
            fields
                .iter()
                .find(|f| f.name == wanted)
                .unwrap()
                .required
        };
        assert!(is_required("req"));
        assert!(!is_required("opt"));
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        for name in [
            "password",
            "PASSWORD",
            "user_password",
            "api_token",
            "SECRET_KEY",
            "auth_header",
            "private_key",
        ] {
            assert!(is_sensitive_field(name));
        }
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        for name in ["username", "email", "message", "description", "subject"] {
            assert!(!is_sensitive_field(name));
        }
    }
}