1use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9 pub(super) async fn handle_mcp_command(
15 &mut self,
16 args: &str,
17 ) -> Result<String, super::error::AgentError> {
18 let parts: Vec<&str> = args.split_whitespace().collect();
19 match parts.first().copied() {
20 Some("add") => self.handle_mcp_add(&parts[1..]).await,
21 Some("list") => self.handle_mcp_list().await,
22 Some("tools") => Ok(self.handle_mcp_tools(parts.get(1).copied())),
23 Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
24 _ => Ok("Usage: /mcp add|list|tools|remove".to_owned()),
25 }
26 }
27
28 #[allow(clippy::too_many_lines)]
29 async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<String, super::error::AgentError> {
30 if args.len() < 2 {
31 return Ok("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>".to_owned());
32 }
33
34 let Some(manager) = self.mcp.manager.clone() else {
36 return Ok("MCP is not enabled.".to_owned());
37 };
38
39 let target = args[1];
40 let is_url = target.starts_with("http://") || target.starts_with("https://");
41
42 if !is_url
44 && !self.mcp.allowed_commands.is_empty()
45 && !self.mcp.allowed_commands.iter().any(|c| c == target)
46 {
47 return Ok(format!(
48 "Command '{target}' is not allowed. Permitted: {}",
49 self.mcp.allowed_commands.join(", ")
50 ));
51 }
52
53 let current_count = manager.list_servers().await.len();
55 if current_count >= self.mcp.max_dynamic {
56 return Ok(format!(
57 "Server limit reached ({}/{}).",
58 current_count, self.mcp.max_dynamic
59 ));
60 }
61
62 let transport = if is_url {
63 zeph_mcp::McpTransport::Http {
64 url: target.to_owned(),
65 headers: std::collections::HashMap::new(),
66 }
67 } else {
68 zeph_mcp::McpTransport::Stdio {
69 command: target.to_owned(),
70 args: args[2..].iter().map(|&s| s.to_owned()).collect(),
71 env: std::collections::HashMap::new(),
72 }
73 };
74
75 let entry = zeph_mcp::ServerEntry {
76 id: args[0].to_owned(),
77 transport,
78 timeout: std::time::Duration::from_secs(30),
79 trust_level: zeph_mcp::McpTrustLevel::Untrusted,
80 tool_allowlist: None,
81 expected_tools: Vec::new(),
82 roots: Vec::new(),
83 tool_metadata: std::collections::HashMap::new(),
84 elicitation_enabled: false,
85 elicitation_timeout_secs: 120,
86 env_isolation: false,
87 };
88
89 match manager.add_server(&entry).await {
90 Ok(tools) => {
91 let count = tools.len();
92 self.mcp
93 .server_outcomes
94 .push(zeph_mcp::ServerConnectOutcome {
95 id: entry.id.clone(),
96 connected: true,
97 tool_count: count,
98 error: String::new(),
99 });
100 self.mcp.tools.extend(tools);
101 self.mcp.sync_executor_tools();
102 self.mcp.pruning_cache.reset();
103 self.mcp.pending_semantic_rebuild = true;
106 let mcp_total = self.mcp.tools.len();
107 let mcp_server_count = self.mcp.server_outcomes.len();
108 let mcp_connected_count = self
109 .mcp
110 .server_outcomes
111 .iter()
112 .filter(|o| o.connected)
113 .count();
114 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
115 .mcp
116 .server_outcomes
117 .iter()
118 .map(|o| crate::metrics::McpServerStatus {
119 id: o.id.clone(),
120 status: if o.connected {
121 crate::metrics::McpServerConnectionStatus::Connected
122 } else {
123 crate::metrics::McpServerConnectionStatus::Failed
124 },
125 tool_count: o.tool_count,
126 error: o.error.clone(),
127 })
128 .collect();
129 self.update_metrics(|m| {
130 m.mcp_tool_count = mcp_total;
131 m.mcp_server_count = mcp_server_count;
132 m.mcp_connected_count = mcp_connected_count;
133 m.mcp_servers = mcp_servers;
134 });
135 Ok(format!(
136 "Connected MCP server '{}' ({count} tool(s))",
137 entry.id
138 ))
139 }
140 Err(e) => {
141 tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
142 Ok(format!("Failed to connect server '{}': {e}", entry.id))
143 }
144 }
145 }
146
147 async fn handle_mcp_list(&mut self) -> Result<String, super::error::AgentError> {
148 use std::fmt::Write;
149
150 let Some(manager) = self.mcp.manager.clone() else {
151 return Ok("MCP is not enabled.".to_owned());
152 };
153
154 let server_ids = manager.list_servers().await;
155 if server_ids.is_empty() {
156 return Ok("No MCP servers connected.".to_owned());
157 }
158
159 let mut output = String::from("Connected MCP servers:\n");
160 let mut total = 0usize;
161 for id in &server_ids {
162 let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
163 total += count;
164 let _ = writeln!(output, "- {id} ({count} tools)");
165 }
166 let _ = write!(output, "Total: {total} tool(s)");
167
168 Ok(output)
169 }
170
171 fn handle_mcp_tools(&mut self, server_id: Option<&str>) -> String {
172 use std::fmt::Write;
173
174 let Some(server_id) = server_id else {
175 return "Usage: /mcp tools <server_id>".to_owned();
176 };
177
178 let tools: Vec<_> = self
179 .mcp
180 .tools
181 .iter()
182 .filter(|t| t.server_id == server_id)
183 .collect();
184
185 if tools.is_empty() {
186 return format!("No tools found for server '{server_id}'.");
187 }
188
189 let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
190 for t in &tools {
191 if t.description.is_empty() {
192 let _ = writeln!(output, "- {}", t.name);
193 } else {
194 let _ = writeln!(output, "- {} — {}", t.name, t.description);
195 }
196 }
197 output
198 }
199
200 async fn handle_mcp_remove(
201 &mut self,
202 server_id: Option<&str>,
203 ) -> Result<String, super::error::AgentError> {
204 let Some(server_id) = server_id else {
205 return Ok("Usage: /mcp remove <id>".to_owned());
206 };
207
208 let Some(manager) = self.mcp.manager.clone() else {
210 return Ok("MCP is not enabled.".to_owned());
211 };
212
213 match manager.remove_server(server_id).await {
214 Ok(()) => {
215 let before = self.mcp.tools.len();
216 self.mcp.tools.retain(|t| t.server_id != server_id);
217 let removed = before - self.mcp.tools.len();
218 self.mcp.server_outcomes.retain(|o| o.id != server_id);
219 self.mcp.sync_executor_tools();
220 self.mcp.pruning_cache.reset();
221 self.mcp.pending_semantic_rebuild = true;
224 let mcp_total = self.mcp.tools.len();
225 let mcp_server_count = self.mcp.server_outcomes.len();
226 let mcp_connected_count = self
227 .mcp
228 .server_outcomes
229 .iter()
230 .filter(|o| o.connected)
231 .count();
232 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
233 .mcp
234 .server_outcomes
235 .iter()
236 .map(|o| crate::metrics::McpServerStatus {
237 id: o.id.clone(),
238 status: if o.connected {
239 crate::metrics::McpServerConnectionStatus::Connected
240 } else {
241 crate::metrics::McpServerConnectionStatus::Failed
242 },
243 tool_count: o.tool_count,
244 error: o.error.clone(),
245 })
246 .collect();
247 self.update_metrics(|m| {
248 m.mcp_tool_count = mcp_total;
249 m.mcp_server_count = mcp_server_count;
250 m.mcp_connected_count = mcp_connected_count;
251 m.mcp_servers = mcp_servers;
252 m.active_mcp_tools
253 .retain(|name| !name.starts_with(&format!("{server_id}:")));
254 });
255 Ok(format!(
256 "Disconnected MCP server '{server_id}' (removed {removed} tools)"
257 ))
258 }
259 Err(e) => {
260 tracing::warn!(server_id, "MCP remove failed: {e:#}");
261 Ok(format!("Failed to remove server '{server_id}': {e}"))
262 }
263 }
264 }
265
266 pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
267 let matched_tools = self.match_mcp_tools(query).await;
268 let active_mcp: Vec<String> = matched_tools
269 .iter()
270 .map(zeph_mcp::McpTool::qualified_name)
271 .collect();
272 let mcp_total = self.mcp.tools.len();
273 let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
274 let connected = self
275 .mcp
276 .tools
277 .iter()
278 .map(|t| &t.server_id)
279 .collect::<std::collections::HashSet<_>>()
280 .len();
281 (connected, connected)
282 } else {
283 let total = self.mcp.server_outcomes.len();
284 let connected = self
285 .mcp
286 .server_outcomes
287 .iter()
288 .filter(|o| o.connected)
289 .count();
290 (total, connected)
291 };
292 self.update_metrics(|m| {
293 m.active_mcp_tools = active_mcp;
294 m.mcp_tool_count = mcp_total;
295 m.mcp_server_count = mcp_server_count;
296 m.mcp_connected_count = mcp_connected_count;
297 });
298 if let Some(ref manager) = self.mcp.manager {
299 let instructions = manager.all_server_instructions().await;
300 if !instructions.is_empty() {
301 system_prompt.push_str("\n\n");
302 system_prompt.push_str(&instructions);
303 }
304 }
305 if !matched_tools.is_empty() {
306 let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
307 tracing::debug!(
308 skills = ?self.skill_state.active_skill_names,
309 mcp_tools = ?tool_names,
310 "matched items"
311 );
312 let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
313 if !tools_prompt.is_empty() {
314 system_prompt.push_str("\n\n");
315 system_prompt.push_str(&tools_prompt);
316 }
317 }
318 }
319
320 async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
321 let Some(ref registry) = self.mcp.registry else {
322 return self.mcp.tools.clone();
323 };
324 let provider = self.embedding_provider.clone();
325 registry
326 .search(query, self.skill_state.max_active_skills, |text| {
327 let owned = text.to_owned();
328 let p = provider.clone();
329 Box::pin(async move { p.embed(&owned).await })
330 })
331 .await
332 }
333
334 pub(super) async fn check_tool_refresh(&mut self) {
345 if self.mcp.pending_semantic_rebuild {
347 self.mcp.pending_semantic_rebuild = false;
348 self.rebuild_semantic_index().await;
349 self.sync_mcp_registry().await;
350 let mcp_total = self.mcp.tools.len();
351 let mcp_servers = self
352 .mcp
353 .tools
354 .iter()
355 .map(|t| &t.server_id)
356 .collect::<std::collections::HashSet<_>>()
357 .len();
358 self.update_metrics(|m| {
359 m.mcp_tool_count = mcp_total;
360 m.mcp_server_count = mcp_servers;
361 });
362 }
363
364 let Some(ref mut rx) = self.mcp.tool_rx else {
365 return;
366 };
367 if !rx.has_changed().unwrap_or(false) {
368 return;
369 }
370 let new_tools = rx.borrow_and_update().clone();
371 if new_tools.is_empty() {
372 return;
375 }
376 tracing::info!(
377 tools = new_tools.len(),
378 "tools/list_changed: agent tool list refreshed"
379 );
380 self.mcp.tools = new_tools;
381 self.mcp.sync_executor_tools();
382 self.mcp.pruning_cache.reset();
383 self.rebuild_semantic_index().await;
384 self.sync_mcp_registry().await;
385 let mcp_total = self.mcp.tools.len();
386 let mcp_servers = self
387 .mcp
388 .tools
389 .iter()
390 .map(|t| &t.server_id)
391 .collect::<std::collections::HashSet<_>>()
392 .len();
393 self.update_metrics(|m| {
394 m.mcp_tool_count = mcp_total;
395 m.mcp_server_count = mcp_servers;
396 });
397 }
398
399 pub(super) async fn sync_mcp_registry(&mut self) {
400 if self.mcp.registry.is_none() {
401 return;
402 }
403 if !self.embedding_provider.supports_embeddings() {
404 return;
405 }
406 let tools = self.mcp.tools.clone();
408 let provider = self.embedding_provider.clone();
409 let embedding_model = self.skill_state.embedding_model.clone();
410 let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
411 let embed_fn = move |text: &str| -> zeph_mcp::registry::EmbedFuture {
412 let owned = text.to_owned();
413 let p = provider.clone();
414 Box::pin(async move {
415 if let Ok(result) = tokio::time::timeout(embed_timeout, p.embed(&owned)).await {
416 result
417 } else {
418 tracing::warn!(
419 timeout_secs = embed_timeout.as_secs(),
420 "MCP registry: embedding timed out"
421 );
422 Err(zeph_llm::LlmError::Timeout)
423 }
424 })
425 };
426 let Some(mut registry) = self.mcp.registry.take() else {
429 return;
430 };
431 if let Err(e) = registry.sync(&tools, &embedding_model, embed_fn).await {
432 tracing::warn!("failed to sync MCP tool registry: {e:#}");
433 }
434 self.mcp.registry = Some(registry);
435 }
436
437 pub async fn init_semantic_index(&mut self) {
444 self.rebuild_semantic_index().await;
445 }
446
447 pub(super) async fn process_pending_elicitations(&mut self) {
452 loop {
453 let Some(ref mut rx) = self.mcp.elicitation_rx else {
454 return;
455 };
456 match rx.try_recv() {
457 Ok(event) => {
458 self.handle_elicitation_event(event).await;
459 }
460 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
461 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
462 self.mcp.elicitation_rx = None;
463 return;
464 }
465 }
466 }
467 }
468
469 pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
471 use crate::channel::{ElicitationRequest, ElicitationResponse};
472
473 let decline = CreateElicitationResult {
474 action: ElicitationAction::Decline,
475 content: None,
476 meta: None,
477 };
478
479 let channel_request = match &event.request {
480 rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
481 message,
482 requested_schema,
483 ..
484 } => {
485 let fields = build_elicitation_fields(requested_schema);
486 ElicitationRequest {
487 server_name: event.server_id.clone(),
488 message: sanitize_elicitation_message(message),
489 fields,
490 }
491 }
492 rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
493 tracing::debug!(
495 server_id = event.server_id,
496 "URL elicitation not supported, declining"
497 );
498 let _ = event.response_tx.send(decline);
499 return;
500 }
501 };
502
503 if self.mcp.elicitation_warn_sensitive_fields {
504 let sensitive: Vec<&str> = channel_request
505 .fields
506 .iter()
507 .filter(|f| is_sensitive_field(&f.name))
508 .map(|f| f.name.as_str())
509 .collect();
510 if !sensitive.is_empty() {
511 let fields_list = sensitive.join(", ");
512 let warning = format!(
513 "Warning: [{}] is requesting sensitive information (field: {}). \
514 Only proceed if you trust this server.",
515 channel_request.server_name, fields_list,
516 );
517 tracing::warn!(
518 server_id = event.server_id,
519 fields = %fields_list,
520 "elicitation requests sensitive fields"
521 );
522 let _ = self.channel.send(&warning).await;
523 }
524 }
525
526 let _ = self
527 .channel
528 .send_status("MCP server requesting input…")
529 .await;
530 let response = match self.channel.elicit(channel_request).await {
531 Ok(r) => r,
532 Err(e) => {
533 tracing::warn!(
534 server_id = event.server_id,
535 "elicitation channel error: {e:#}"
536 );
537 let _ = self.channel.send_status("").await;
538 let _ = event.response_tx.send(decline);
539 return;
540 }
541 };
542 let _ = self.channel.send_status("").await;
543
544 let result = match response {
545 ElicitationResponse::Accepted(value) => CreateElicitationResult {
546 action: ElicitationAction::Accept,
547 content: Some(value),
548 meta: None,
549 },
550 ElicitationResponse::Declined => CreateElicitationResult {
551 action: ElicitationAction::Decline,
552 content: None,
553 meta: None,
554 },
555 ElicitationResponse::Cancelled => CreateElicitationResult {
556 action: ElicitationAction::Cancel,
557 content: None,
558 meta: None,
559 },
560 };
561
562 if event.response_tx.send(result).is_err() {
563 tracing::warn!(
564 server_id = event.server_id,
565 "elicitation response dropped — handler disconnected"
566 );
567 }
568 }
569
570 pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
580 if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
581 return;
582 }
583
584 if self.mcp.tools.is_empty() {
585 self.mcp.semantic_index = None;
586 return;
587 }
588
589 let provider = self
591 .mcp
592 .discovery_provider
593 .clone()
594 .unwrap_or_else(|| self.embedding_provider.clone());
595
596 let inner_embed = provider.embed_fn();
597 let embed_timeout = std::time::Duration::from_secs(self.runtime.timeouts.embedding_seconds);
598 let embed_fn = move |text: &str| -> zeph_llm::provider::EmbedFuture {
599 let fut = inner_embed(text);
600 Box::pin(async move {
601 if let Ok(result) = tokio::time::timeout(embed_timeout, fut).await {
602 result
603 } else {
604 tracing::warn!(
605 timeout_secs = embed_timeout.as_secs(),
606 "semantic index: embedding probe timed out"
607 );
608 Err(zeph_llm::LlmError::Timeout)
609 }
610 })
611 };
612
613 let tools = self.mcp.tools.clone();
615 match zeph_mcp::SemanticToolIndex::build(&tools, &embed_fn).await {
616 Ok(idx) => {
617 tracing::info!(
618 indexed = idx.len(),
619 total = self.mcp.tools.len(),
620 "semantic tool index built"
621 );
622 self.mcp.semantic_index = Some(idx);
623 }
624 Err(e) => {
625 tracing::warn!(
626 "semantic tool index build failed, falling back to all tools: {e:#}"
627 );
628 self.mcp.semantic_index = None;
629 }
630 }
631 }
632}
633
634fn build_elicitation_fields(
636 schema: &rmcp::model::ElicitationSchema,
637) -> Vec<crate::channel::ElicitationField> {
638 use crate::channel::{ElicitationField, ElicitationFieldType};
639 use rmcp::model::PrimitiveSchema;
640
641 schema
642 .properties
643 .iter()
644 .map(|(name, prop)| {
645 let json = serde_json::to_value(prop).unwrap_or_default();
649 let description = json
650 .get("description")
651 .and_then(|v| v.as_str())
652 .map(String::from);
653
654 let field_type = match prop {
655 PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
656 PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
657 PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
658 PrimitiveSchema::String(_) => ElicitationFieldType::String,
659 PrimitiveSchema::Enum(_) => {
660 let vals = json
663 .get("enum")
664 .and_then(|v| v.as_array())
665 .map(|arr| {
666 arr.iter()
667 .filter_map(|v| v.as_str())
668 .map(String::from)
669 .collect::<Vec<_>>()
670 })
671 .unwrap_or_default();
672 ElicitationFieldType::Enum(vals)
673 }
674 };
675 let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
676 ElicitationField {
677 name: name.clone(),
678 description,
679 field_type,
680 required,
681 }
682 })
683 .collect()
684}
685
/// Lower-case substrings that mark an elicitation field name as sensitive.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    "password",
    "passwd",
    "token",
    "secret",
    "key",
    "credential",
    "apikey",
    "api_key",
    "auth",
    "authorization",
    "private",
    "passphrase",
    "pin",
];

/// Returns `true` when `field_name`, compared case-insensitively, contains any
/// of the [`SENSITIVE_FIELD_PATTERNS`].
///
/// Matching is by substring, so names that merely embed a pattern (for example
/// `shipping`, which contains `pin`) are reported as sensitive too.
fn is_sensitive_field(field_name: &str) -> bool {
    let normalized = field_name.to_lowercase();
    SENSITIVE_FIELD_PATTERNS
        .iter()
        .copied()
        .any(|needle| normalized.contains(needle))
}
710
/// Cleans a server-supplied elicitation message before it is shown to the user.
///
/// Control characters are removed, except for newline and tab, and the result
/// is limited to the first 500 remaining characters (characters, not bytes).
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    let is_displayable = |c: &char| matches!(*c, '\n' | '\t') || !c.is_control();
    message
        .chars()
        .filter(is_displayable)
        .take(MAX_CHARS)
        .collect()
}
721
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Builds an agent from the mock provider, channel, registry and executor
    /// used by every test in this module.
    fn bare_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// Builds a minimal tool owned by server `srv`.
    fn tool_named(name: &'static str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: "srv".into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }
    }

    // --- `/mcp` command dispatch -------------------------------------------

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = bare_agent();

        let reply = agent.handle_mcp_command("unknown").await.unwrap();
        assert!(
            reply.contains("Usage: /mcp"),
            "expected usage message, got: {reply:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = bare_agent();

        let reply = agent.handle_mcp_command("list").await.unwrap();
        assert!(
            reply.contains("MCP is not enabled"),
            "expected not-enabled message, got: {reply:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = bare_agent();

        let reply = agent.handle_mcp_command("tools").await.unwrap();
        assert!(
            reply.contains("Usage: /mcp tools"),
            "expected tools usage message, got: {reply:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = bare_agent();

        let reply = agent.handle_mcp_command("remove").await.unwrap();
        assert!(
            reply.contains("Usage: /mcp remove"),
            "expected remove usage message, got: {reply:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = bare_agent();

        let reply = agent.handle_mcp_command("remove my-server").await.unwrap();
        assert!(
            reply.contains("MCP is not enabled"),
            "expected not-enabled message, got: {reply:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = bare_agent();

        // Only the id is supplied; the command/url is missing.
        let reply = agent.handle_mcp_command("add server-id").await.unwrap();
        assert!(
            reply.contains("Usage: /mcp add"),
            "expected add usage message, got: {reply:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = bare_agent();

        let reply = agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert!(
            reply.contains("No tools found"),
            "expected no-tools message, got: {reply:?}"
        );
    }

    // --- tool refresh -------------------------------------------------------

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = bare_agent();

        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = bare_agent();

        agent.check_tool_refresh().await;

        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = bare_agent();

        // The sender stays bound so the channel remains open for the whole test.
        let (_sender, receiver) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(receiver);

        agent.check_tool_refresh().await;

        assert_eq!(agent.mcp.tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = bare_agent();
        agent.mcp.tools = vec![tool_named("existing_tool")];

        let (_sender, receiver) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(receiver);

        agent.check_tool_refresh().await;

        assert_eq!(agent.mcp.tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = bare_agent();

        let (sender, receiver) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(receiver);
        sender.send(vec![tool_named("refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;

        assert_eq!(agent.mcp.tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    // --- message sanitisation -----------------------------------------------

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let cleaned = sanitize_elicitation_message("hello\x01world\x1b[31mred\x1b[0m");

        assert!(!cleaned.contains('\x01'));
        assert!(!cleaned.contains('\x1b'));
        assert!(cleaned.contains("hello"));
        assert!(cleaned.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let cleaned = sanitize_elicitation_message("line1\nline2\ttabbed");

        assert_eq!(cleaned, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        let oversized = "a".repeat(600);

        let cleaned = sanitize_elicitation_message(&oversized);

        assert_eq!(cleaned.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        // 300 two-byte characters: under the character cap although 600 bytes long.
        let accented = "é".repeat(300);

        let cleaned = sanitize_elicitation_message(&accented);

        assert_eq!(cleaned.chars().count(), 300);
    }

    // --- schema conversion --------------------------------------------------

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let props = BTreeMap::from([
            (
                "flag".to_owned(),
                PrimitiveSchema::Boolean(BooleanSchema::new()),
            ),
            (
                "count".to_owned(),
                PrimitiveSchema::Integer(IntegerSchema::new()),
            ),
            (
                "ratio".to_owned(),
                PrimitiveSchema::Number(NumberSchema::new()),
            ),
            (
                "name".to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            ),
        ]);

        let fields = build_elicitation_fields(&ElicitationSchema::new(props));
        let field = |wanted: &str| fields.iter().find(|f| f.name == wanted).unwrap();

        assert!(matches!(
            field("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            field("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            field("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            field("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let props = BTreeMap::from([
            (
                "req".to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            ),
            (
                "opt".to_owned(),
                PrimitiveSchema::String(StringSchema::new()),
            ),
        ]);
        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let is_required = |wanted: &str| {
            fields
                .iter()
                .find(|f| f.name == wanted)
                .unwrap()
                .required
        };

        assert!(is_required("req"));
        assert!(!is_required("opt"));
    }

    // --- sensitive field detection ------------------------------------------

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        assert!(is_sensitive_field("password"));
        assert!(is_sensitive_field("PASSWORD"));
        assert!(is_sensitive_field("user_password"));
        assert!(is_sensitive_field("api_token"));
        assert!(is_sensitive_field("SECRET_KEY"));
        assert!(is_sensitive_field("auth_header"));
        assert!(is_sensitive_field("private_key"));
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        assert!(!is_sensitive_field("username"));
        assert!(!is_sensitive_field("email"));
        assert!(!is_sensitive_field("message"));
        assert!(!is_sensitive_field("description"));
        assert!(!is_sensitive_field("subject"));
    }
}