1use rmcp::model::{CreateElicitationResult, ElicitationAction};
5
6use super::{Agent, Channel, LlmProvider};
7
8impl<C: Channel> Agent<C> {
9 pub(super) async fn handle_mcp_command(
10 &mut self,
11 args: &str,
12 ) -> Result<(), super::error::AgentError> {
13 let parts: Vec<&str> = args.split_whitespace().collect();
14 match parts.first().copied() {
15 Some("add") => self.handle_mcp_add(&parts[1..]).await,
16 Some("list") => self.handle_mcp_list().await,
17 Some("tools") => self.handle_mcp_tools(parts.get(1).copied()).await,
18 Some("remove") => self.handle_mcp_remove(parts.get(1).copied()).await,
19 _ => {
20 self.channel
21 .send("Usage: /mcp add|list|tools|remove")
22 .await?;
23 Ok(())
24 }
25 }
26 }
27
28 #[allow(clippy::too_many_lines)]
29 async fn handle_mcp_add(&mut self, args: &[&str]) -> Result<(), super::error::AgentError> {
30 if args.len() < 2 {
31 self.channel
32 .send("Usage: /mcp add <id> <command> [args...] | /mcp add <id> <url>")
33 .await?;
34 return Ok(());
35 }
36
37 let Some(ref manager) = self.mcp.manager else {
38 self.channel.send("MCP is not enabled.").await?;
39 return Ok(());
40 };
41
42 let target = args[1];
43 let is_url = target.starts_with("http://") || target.starts_with("https://");
44
45 if !is_url
47 && !self.mcp.allowed_commands.is_empty()
48 && !self.mcp.allowed_commands.iter().any(|c| c == target)
49 {
50 self.channel
51 .send(&format!(
52 "Command '{target}' is not allowed. Permitted: {}",
53 self.mcp.allowed_commands.join(", ")
54 ))
55 .await?;
56 return Ok(());
57 }
58
59 let current_count = manager.list_servers().await.len();
61 if current_count >= self.mcp.max_dynamic {
62 self.channel
63 .send(&format!(
64 "Server limit reached ({}/{}).",
65 current_count, self.mcp.max_dynamic
66 ))
67 .await?;
68 return Ok(());
69 }
70
71 let transport = if is_url {
72 zeph_mcp::McpTransport::Http {
73 url: target.to_owned(),
74 headers: std::collections::HashMap::new(),
75 }
76 } else {
77 zeph_mcp::McpTransport::Stdio {
78 command: target.to_owned(),
79 args: args[2..].iter().map(|&s| s.to_owned()).collect(),
80 env: std::collections::HashMap::new(),
81 }
82 };
83
84 let entry = zeph_mcp::ServerEntry {
85 id: args[0].to_owned(),
86 transport,
87 timeout: std::time::Duration::from_secs(30),
88 trust_level: zeph_mcp::McpTrustLevel::Untrusted,
89 tool_allowlist: None,
90 expected_tools: Vec::new(),
91 roots: Vec::new(),
92 tool_metadata: std::collections::HashMap::new(),
93 elicitation_enabled: false,
94 elicitation_timeout_secs: 120,
95 env_isolation: false,
96 };
97
98 let _ = self.channel.send_status("connecting to mcp...").await;
99 match manager.add_server(&entry).await {
100 Ok(tools) => {
101 let _ = self.channel.send_status("").await;
102 let count = tools.len();
103 self.mcp
104 .server_outcomes
105 .push(zeph_mcp::ServerConnectOutcome {
106 id: entry.id.clone(),
107 connected: true,
108 tool_count: count,
109 error: String::new(),
110 });
111 self.mcp.tools.extend(tools);
112 self.sync_mcp_executor_tools();
113 self.mcp.pruning_cache.reset();
114 self.rebuild_semantic_index().await;
115 self.sync_mcp_registry().await;
116 let mcp_total = self.mcp.tools.len();
117 let mcp_server_count = self.mcp.server_outcomes.len();
118 let mcp_connected_count = self
119 .mcp
120 .server_outcomes
121 .iter()
122 .filter(|o| o.connected)
123 .count();
124 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
125 .mcp
126 .server_outcomes
127 .iter()
128 .map(|o| crate::metrics::McpServerStatus {
129 id: o.id.clone(),
130 status: if o.connected {
131 crate::metrics::McpServerConnectionStatus::Connected
132 } else {
133 crate::metrics::McpServerConnectionStatus::Failed
134 },
135 tool_count: o.tool_count,
136 error: o.error.clone(),
137 })
138 .collect();
139 self.update_metrics(|m| {
140 m.mcp_tool_count = mcp_total;
141 m.mcp_server_count = mcp_server_count;
142 m.mcp_connected_count = mcp_connected_count;
143 m.mcp_servers = mcp_servers;
144 });
145 self.channel
146 .send(&format!(
147 "Connected MCP server '{}' ({count} tool(s))",
148 entry.id
149 ))
150 .await?;
151 Ok(())
152 }
153 Err(e) => {
154 let _ = self.channel.send_status("").await;
155 tracing::warn!(server_id = entry.id, "MCP add failed: {e:#}");
156 self.channel
157 .send(&format!("Failed to connect server '{}': {e}", entry.id))
158 .await?;
159 Ok(())
160 }
161 }
162 }
163
164 async fn handle_mcp_list(&mut self) -> Result<(), super::error::AgentError> {
165 use std::fmt::Write;
166
167 let Some(ref manager) = self.mcp.manager else {
168 self.channel.send("MCP is not enabled.").await?;
169 return Ok(());
170 };
171
172 let server_ids = manager.list_servers().await;
173 if server_ids.is_empty() {
174 self.channel.send("No MCP servers connected.").await?;
175 return Ok(());
176 }
177
178 let mut output = String::from("Connected MCP servers:\n");
179 let mut total = 0usize;
180 for id in &server_ids {
181 let count = self.mcp.tools.iter().filter(|t| t.server_id == *id).count();
182 total += count;
183 let _ = writeln!(output, "- {id} ({count} tools)");
184 }
185 let _ = write!(output, "Total: {total} tool(s)");
186
187 self.channel.send(&output).await?;
188 Ok(())
189 }
190
191 async fn handle_mcp_tools(
192 &mut self,
193 server_id: Option<&str>,
194 ) -> Result<(), super::error::AgentError> {
195 use std::fmt::Write;
196
197 let Some(server_id) = server_id else {
198 self.channel.send("Usage: /mcp tools <server_id>").await?;
199 return Ok(());
200 };
201
202 let tools: Vec<_> = self
203 .mcp
204 .tools
205 .iter()
206 .filter(|t| t.server_id == server_id)
207 .collect();
208
209 if tools.is_empty() {
210 self.channel
211 .send(&format!("No tools found for server '{server_id}'."))
212 .await?;
213 return Ok(());
214 }
215
216 let mut output = format!("Tools for '{server_id}' ({} total):\n", tools.len());
217 for t in &tools {
218 if t.description.is_empty() {
219 let _ = writeln!(output, "- {}", t.name);
220 } else {
221 let _ = writeln!(output, "- {} — {}", t.name, t.description);
222 }
223 }
224 self.channel.send(&output).await?;
225 Ok(())
226 }
227
228 async fn handle_mcp_remove(
229 &mut self,
230 server_id: Option<&str>,
231 ) -> Result<(), super::error::AgentError> {
232 let Some(server_id) = server_id else {
233 self.channel.send("Usage: /mcp remove <id>").await?;
234 return Ok(());
235 };
236
237 let Some(ref manager) = self.mcp.manager else {
238 self.channel.send("MCP is not enabled.").await?;
239 return Ok(());
240 };
241
242 match manager.remove_server(server_id).await {
243 Ok(()) => {
244 let before = self.mcp.tools.len();
245 self.mcp.tools.retain(|t| t.server_id != server_id);
246 let removed = before - self.mcp.tools.len();
247 self.mcp.server_outcomes.retain(|o| o.id != server_id);
248 self.sync_mcp_executor_tools();
249 self.mcp.pruning_cache.reset();
250 self.rebuild_semantic_index().await;
251 self.sync_mcp_registry().await;
252 let mcp_total = self.mcp.tools.len();
253 let mcp_server_count = self.mcp.server_outcomes.len();
254 let mcp_connected_count = self
255 .mcp
256 .server_outcomes
257 .iter()
258 .filter(|o| o.connected)
259 .count();
260 let mcp_servers: Vec<crate::metrics::McpServerStatus> = self
261 .mcp
262 .server_outcomes
263 .iter()
264 .map(|o| crate::metrics::McpServerStatus {
265 id: o.id.clone(),
266 status: if o.connected {
267 crate::metrics::McpServerConnectionStatus::Connected
268 } else {
269 crate::metrics::McpServerConnectionStatus::Failed
270 },
271 tool_count: o.tool_count,
272 error: o.error.clone(),
273 })
274 .collect();
275 self.update_metrics(|m| {
276 m.mcp_tool_count = mcp_total;
277 m.mcp_server_count = mcp_server_count;
278 m.mcp_connected_count = mcp_connected_count;
279 m.mcp_servers = mcp_servers;
280 m.active_mcp_tools
281 .retain(|name| !name.starts_with(&format!("{server_id}:")));
282 });
283 self.channel
284 .send(&format!(
285 "Disconnected MCP server '{server_id}' (removed {removed} tools)"
286 ))
287 .await?;
288 Ok(())
289 }
290 Err(e) => {
291 tracing::warn!(server_id, "MCP remove failed: {e:#}");
292 self.channel
293 .send(&format!("Failed to remove server '{server_id}': {e}"))
294 .await?;
295 Ok(())
296 }
297 }
298 }
299
300 pub(super) async fn append_mcp_prompt(&mut self, query: &str, system_prompt: &mut String) {
301 let matched_tools = self.match_mcp_tools(query).await;
302 let active_mcp: Vec<String> = matched_tools
303 .iter()
304 .map(zeph_mcp::McpTool::qualified_name)
305 .collect();
306 let mcp_total = self.mcp.tools.len();
307 let (mcp_server_count, mcp_connected_count) = if self.mcp.server_outcomes.is_empty() {
308 let connected = self
309 .mcp
310 .tools
311 .iter()
312 .map(|t| &t.server_id)
313 .collect::<std::collections::HashSet<_>>()
314 .len();
315 (connected, connected)
316 } else {
317 let total = self.mcp.server_outcomes.len();
318 let connected = self
319 .mcp
320 .server_outcomes
321 .iter()
322 .filter(|o| o.connected)
323 .count();
324 (total, connected)
325 };
326 self.update_metrics(|m| {
327 m.active_mcp_tools = active_mcp;
328 m.mcp_tool_count = mcp_total;
329 m.mcp_server_count = mcp_server_count;
330 m.mcp_connected_count = mcp_connected_count;
331 });
332 if let Some(ref manager) = self.mcp.manager {
333 let instructions = manager.all_server_instructions().await;
334 if !instructions.is_empty() {
335 system_prompt.push_str("\n\n");
336 system_prompt.push_str(&instructions);
337 }
338 }
339 if self.provider.supports_tool_use() {
342 return;
343 }
344 if !matched_tools.is_empty() {
345 let tool_names: Vec<&str> = matched_tools.iter().map(|t| t.name.as_str()).collect();
346 tracing::debug!(
347 skills = ?self.skill_state.active_skill_names,
348 mcp_tools = ?tool_names,
349 "matched items"
350 );
351 let tools_prompt = zeph_mcp::format_mcp_tools_prompt(&matched_tools);
352 if !tools_prompt.is_empty() {
353 system_prompt.push_str("\n\n");
354 system_prompt.push_str(&tools_prompt);
355 }
356 }
357 }
358
359 async fn match_mcp_tools(&self, query: &str) -> Vec<zeph_mcp::McpTool> {
360 let Some(ref registry) = self.mcp.registry else {
361 return self.mcp.tools.clone();
362 };
363 let provider = self.embedding_provider.clone();
364 registry
365 .search(query, self.skill_state.max_active_skills, |text| {
366 let owned = text.to_owned();
367 let p = provider.clone();
368 Box::pin(async move { p.embed(&owned).await })
369 })
370 .await
371 }
372
373 #[cfg(test)]
374 pub(crate) fn mcp_tool_count(&self) -> usize {
375 self.mcp.tools.len()
376 }
377
378 pub(super) async fn check_tool_refresh(&mut self) {
384 let Some(ref mut rx) = self.mcp.tool_rx else {
385 return;
386 };
387 if !rx.has_changed().unwrap_or(false) {
388 return;
389 }
390 let new_tools = rx.borrow_and_update().clone();
391 if new_tools.is_empty() {
392 return;
395 }
396 tracing::info!(
397 tools = new_tools.len(),
398 "tools/list_changed: agent tool list refreshed"
399 );
400 self.mcp.tools = new_tools;
401 self.sync_mcp_executor_tools();
402 self.mcp.pruning_cache.reset();
403 self.rebuild_semantic_index().await;
404 self.sync_mcp_registry().await;
405 let mcp_total = self.mcp.tools.len();
406 let mcp_servers = self
407 .mcp
408 .tools
409 .iter()
410 .map(|t| &t.server_id)
411 .collect::<std::collections::HashSet<_>>()
412 .len();
413 self.update_metrics(|m| {
414 m.mcp_tool_count = mcp_total;
415 m.mcp_server_count = mcp_servers;
416 });
417 }
418
419 pub(super) fn sync_mcp_executor_tools(&self) {
428 if let Some(ref shared) = self.mcp.shared_tools {
429 let mut guard = shared
430 .write()
431 .unwrap_or_else(std::sync::PoisonError::into_inner);
432 guard.clone_from(&self.mcp.tools);
433 }
434 }
435
436 pub(in crate::agent) fn apply_pruned_mcp_tools(&self, pruned: Vec<zeph_mcp::McpTool>) {
451 debug_assert!(
452 pruned.iter().all(|p| self
453 .mcp
454 .tools
455 .iter()
456 .any(|t| t.server_id == p.server_id && t.name == p.name)),
457 "pruned set must be a subset of self.mcp.tools"
458 );
459 if let Some(ref shared) = self.mcp.shared_tools {
460 let mut guard = shared
461 .write()
462 .unwrap_or_else(std::sync::PoisonError::into_inner);
463 *guard = pruned;
464 }
465 }
466
467 pub(super) async fn sync_mcp_registry(&mut self) {
468 let Some(ref mut registry) = self.mcp.registry else {
469 return;
470 };
471 if !self.embedding_provider.supports_embeddings() {
472 return;
473 }
474 let provider = self.embedding_provider.clone();
475 let embed_fn = |text: &str| -> zeph_mcp::registry::EmbedFuture {
476 let owned = text.to_owned();
477 let p = provider.clone();
478 Box::pin(async move { p.embed(&owned).await })
479 };
480 if let Err(e) = registry
481 .sync(&self.mcp.tools, &self.skill_state.embedding_model, embed_fn)
482 .await
483 {
484 tracing::warn!("failed to sync MCP tool registry: {e:#}");
485 }
486 }
487
488 pub async fn init_semantic_index(&mut self) {
495 self.rebuild_semantic_index().await;
496 }
497
498 pub(super) async fn process_pending_elicitations(&mut self) {
503 loop {
504 let Some(ref mut rx) = self.mcp.elicitation_rx else {
505 return;
506 };
507 match rx.try_recv() {
508 Ok(event) => {
509 self.handle_elicitation_event(event).await;
510 }
511 Err(tokio::sync::mpsc::error::TryRecvError::Empty) => return,
512 Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
513 self.mcp.elicitation_rx = None;
514 return;
515 }
516 }
517 }
518 }
519
520 pub(super) async fn handle_elicitation_event(&mut self, event: zeph_mcp::ElicitationEvent) {
522 use crate::channel::{ElicitationRequest, ElicitationResponse};
523
524 let decline = CreateElicitationResult {
525 action: ElicitationAction::Decline,
526 content: None,
527 };
528
529 let channel_request = match &event.request {
530 rmcp::model::CreateElicitationRequestParams::FormElicitationParams {
531 message,
532 requested_schema,
533 ..
534 } => {
535 let fields = build_elicitation_fields(requested_schema);
536 ElicitationRequest {
537 server_name: event.server_id.clone(),
538 message: sanitize_elicitation_message(message),
539 fields,
540 }
541 }
542 rmcp::model::CreateElicitationRequestParams::UrlElicitationParams { .. } => {
543 tracing::debug!(
545 server_id = event.server_id,
546 "URL elicitation not supported, declining"
547 );
548 let _ = event.response_tx.send(decline);
549 return;
550 }
551 };
552
553 if self.mcp.elicitation_warn_sensitive_fields {
554 let sensitive: Vec<&str> = channel_request
555 .fields
556 .iter()
557 .filter(|f| is_sensitive_field(&f.name))
558 .map(|f| f.name.as_str())
559 .collect();
560 if !sensitive.is_empty() {
561 let fields_list = sensitive.join(", ");
562 let warning = format!(
563 "Warning: [{}] is requesting sensitive information (field: {}). \
564 Only proceed if you trust this server.",
565 channel_request.server_name, fields_list,
566 );
567 tracing::warn!(
568 server_id = event.server_id,
569 fields = %fields_list,
570 "elicitation requests sensitive fields"
571 );
572 let _ = self.channel.send(&warning).await;
573 }
574 }
575
576 let _ = self
577 .channel
578 .send_status("MCP server requesting input…")
579 .await;
580 let response = match self.channel.elicit(channel_request).await {
581 Ok(r) => r,
582 Err(e) => {
583 tracing::warn!(
584 server_id = event.server_id,
585 "elicitation channel error: {e:#}"
586 );
587 let _ = self.channel.send_status("").await;
588 let _ = event.response_tx.send(decline);
589 return;
590 }
591 };
592 let _ = self.channel.send_status("").await;
593
594 let result = match response {
595 ElicitationResponse::Accepted(value) => CreateElicitationResult {
596 action: ElicitationAction::Accept,
597 content: Some(value),
598 },
599 ElicitationResponse::Declined => CreateElicitationResult {
600 action: ElicitationAction::Decline,
601 content: None,
602 },
603 ElicitationResponse::Cancelled => CreateElicitationResult {
604 action: ElicitationAction::Cancel,
605 content: None,
606 },
607 };
608
609 if event.response_tx.send(result).is_err() {
610 tracing::warn!(
611 server_id = event.server_id,
612 "elicitation response dropped — handler disconnected"
613 );
614 }
615 }
616
617 pub(in crate::agent) async fn rebuild_semantic_index(&mut self) {
627 if self.mcp.discovery_strategy != zeph_mcp::ToolDiscoveryStrategy::Embedding {
628 return;
629 }
630
631 if self.mcp.tools.is_empty() {
632 self.mcp.semantic_index = None;
633 return;
634 }
635
636 let provider = self
638 .mcp
639 .discovery_provider
640 .clone()
641 .unwrap_or_else(|| self.embedding_provider.clone());
642
643 let embed_fn = provider.embed_fn();
644
645 match zeph_mcp::SemanticToolIndex::build(&self.mcp.tools, &embed_fn).await {
646 Ok(idx) => {
647 tracing::info!(
648 indexed = idx.len(),
649 total = self.mcp.tools.len(),
650 "semantic tool index built"
651 );
652 self.mcp.semantic_index = Some(idx);
653 }
654 Err(e) => {
655 tracing::warn!(
656 "semantic tool index build failed, falling back to all tools: {e:#}"
657 );
658 self.mcp.semantic_index = None;
659 }
660 }
661 }
662}
663
664fn build_elicitation_fields(
666 schema: &rmcp::model::ElicitationSchema,
667) -> Vec<crate::channel::ElicitationField> {
668 use crate::channel::{ElicitationField, ElicitationFieldType};
669 use rmcp::model::PrimitiveSchema;
670
671 schema
672 .properties
673 .iter()
674 .map(|(name, prop)| {
675 let json = serde_json::to_value(prop).unwrap_or_default();
679 let description = json
680 .get("description")
681 .and_then(|v| v.as_str())
682 .map(String::from);
683
684 let field_type = match prop {
685 PrimitiveSchema::Boolean(_) => ElicitationFieldType::Boolean,
686 PrimitiveSchema::Integer(_) => ElicitationFieldType::Integer,
687 PrimitiveSchema::Number(_) => ElicitationFieldType::Number,
688 PrimitiveSchema::String(_) => ElicitationFieldType::String,
689 PrimitiveSchema::Enum(_) => {
690 let vals = json
693 .get("enum")
694 .and_then(|v| v.as_array())
695 .map(|arr| {
696 arr.iter()
697 .filter_map(|v| v.as_str())
698 .map(String::from)
699 .collect::<Vec<_>>()
700 })
701 .unwrap_or_default();
702 ElicitationFieldType::Enum(vals)
703 }
704 };
705 let required = schema.required.as_deref().is_some_and(|r| r.contains(name));
706 ElicitationField {
707 name: name.clone(),
708 description,
709 field_type,
710 required,
711 }
712 })
713 .collect()
714}
715
/// Lower-case substrings that mark an elicitation field name as likely to carry
/// a secret.
///
/// Matching is by substring, so some entries are already implied by shorter
/// ones (`auth` covers `authorization`, `key` covers `apikey`/`api_key`). They
/// are kept explicit for readability.
const SENSITIVE_FIELD_PATTERNS: &[&str] = &[
    // Passwords, passphrases and PINs.
    "password",
    "passwd",
    "passphrase",
    "pin",
    // Tokens, keys and other secrets.
    "token",
    "secret",
    "key",
    "apikey",
    "api_key",
    "private",
    // Credentials and authentication data.
    "credential",
    "auth",
    "authorization",
];

/// Reports whether `field_name` looks like it requests sensitive data.
///
/// The name is lower-cased and searched for any entry of
/// [`SENSITIVE_FIELD_PATTERNS`] as a substring. This deliberately errs toward
/// warning: names that merely contain a pattern, such as `author`, `mapping`
/// or `keyboard`, are reported as sensitive too.
fn is_sensitive_field(field_name: &str) -> bool {
    let normalized = field_name.to_lowercase();
    SENSITIVE_FIELD_PATTERNS
        .iter()
        .copied()
        .any(|needle| normalized.contains(needle))
}
740
/// Makes an untrusted elicitation message from an MCP server safe to display.
///
/// Control characters are removed, except `\n` and `\t`. Unicode bidirectional
/// formatting characters are removed as well. At most 500 characters are kept,
/// counted as `char`s, so multi-byte text is never split mid-character.
/// Filtering happens before truncation, so removed characters do not count
/// toward the limit.
fn sanitize_elicitation_message(message: &str) -> String {
    const MAX_CHARS: usize = 500;

    // Bidi embeddings/overrides (U+202A..U+202E), isolates (U+2066..U+2069) and
    // directional marks (U+061C, U+200E, U+200F). They are Unicode category Cf,
    // not Cc, so `char::is_control` misses them. They can reorder how the rest
    // of the text is rendered ("Trojan Source", CVE-2021-42574) and let a server
    // disguise what it is asking for.
    const BIDI_FORMATTING: [char; 12] = [
        '\u{061C}', '\u{200E}', '\u{200F}', '\u{202A}', '\u{202B}', '\u{202C}', '\u{202D}',
        '\u{202E}', '\u{2066}', '\u{2067}', '\u{2068}', '\u{2069}',
    ];

    message
        .chars()
        .filter(|c| (!c.is_control() || *c == '\n' || *c == '\t') && !BIDI_FORMATTING.contains(c))
        .take(MAX_CHARS)
        .collect()
}
751
#[cfg(test)]
mod tests {
    use super::super::agent_tests::{
        MockChannel, MockToolExecutor, create_test_registry, mock_provider,
    };
    use super::*;

    /// Agent wired to mocks: no MCP manager, no tools, no scripted channel input.
    fn make_agent() -> Agent<MockChannel> {
        Agent::new(
            mock_provider(vec![]),
            MockChannel::new(vec![]),
            create_test_registry(),
            None,
            5,
            MockToolExecutor::no_tools(),
        )
    }

    /// MCP tool fixture with an empty description and an empty JSON schema.
    fn fixture_tool(server_id: &str, name: &str) -> zeph_mcp::McpTool {
        zeph_mcp::McpTool {
            server_id: server_id.into(),
            name: name.into(),
            description: String::new(),
            input_schema: serde_json::json!({}),
            security_meta: zeph_mcp::tool::ToolSecurityMeta::default(),
        }
    }

    /// Asserts that at least one message sent on the mock channel contains
    /// `needle`; `what` names the expected message in the failure text.
    #[track_caller]
    fn assert_sent_contains(agent: &mut Agent<MockChannel>, needle: &str, what: &str) {
        let sent = agent.channel.sent_messages();
        assert!(
            sent.iter().any(|s| s.contains(needle)),
            "expected {what} message, got: {sent:?}"
        );
    }

    #[tokio::test]
    async fn handle_mcp_command_unknown_subcommand_shows_usage() {
        let mut agent = make_agent();
        agent.handle_mcp_command("unknown").await.unwrap();
        assert_sent_contains(&mut agent, "Usage: /mcp", "usage");
    }

    #[tokio::test]
    async fn handle_mcp_list_no_manager_shows_disabled() {
        let mut agent = make_agent();
        agent.handle_mcp_command("list").await.unwrap();
        assert_sent_contains(&mut agent, "MCP is not enabled", "not-enabled");
    }

    #[tokio::test]
    async fn handle_mcp_tools_no_server_id_shows_usage() {
        let mut agent = make_agent();
        agent.handle_mcp_command("tools").await.unwrap();
        assert_sent_contains(&mut agent, "Usage: /mcp tools", "tools usage");
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_server_id_shows_usage() {
        let mut agent = make_agent();
        agent.handle_mcp_command("remove").await.unwrap();
        assert_sent_contains(&mut agent, "Usage: /mcp remove", "remove usage");
    }

    #[tokio::test]
    async fn handle_mcp_remove_no_manager_shows_disabled() {
        let mut agent = make_agent();
        agent.handle_mcp_command("remove my-server").await.unwrap();
        assert_sent_contains(&mut agent, "MCP is not enabled", "not-enabled");
    }

    #[tokio::test]
    async fn handle_mcp_add_insufficient_args_shows_usage() {
        let mut agent = make_agent();
        agent.handle_mcp_command("add server-id").await.unwrap();
        assert_sent_contains(&mut agent, "Usage: /mcp add", "add usage");
    }

    #[tokio::test]
    async fn handle_mcp_tools_with_unknown_server_shows_no_tools() {
        let mut agent = make_agent();
        agent
            .handle_mcp_command("tools nonexistent-server")
            .await
            .unwrap();
        assert_sent_contains(&mut agent, "No tools found", "no-tools");
    }

    #[tokio::test]
    async fn mcp_tool_count_starts_at_zero() {
        let agent = make_agent();
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_rx_is_noop() {
        let mut agent = make_agent();
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
    }

    #[tokio::test]
    async fn check_tool_refresh_no_change_is_noop() {
        let mut agent = make_agent();
        let (tx, rx) = tokio::sync::watch::channel(Vec::new());
        agent.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 0);
        // Keep the sender alive until after the refresh check.
        drop(tx);
    }

    #[tokio::test]
    async fn check_tool_refresh_with_empty_initial_value_does_not_replace_tools() {
        let mut agent = make_agent();
        agent.mcp.tools = vec![fixture_tool("srv", "existing_tool")];

        let (_tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);
        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
    }

    #[tokio::test]
    async fn check_tool_refresh_applies_update() {
        let mut agent = make_agent();
        let (tx, rx) = tokio::sync::watch::channel(Vec::<zeph_mcp::McpTool>::new());
        agent.mcp.tool_rx = Some(rx);

        tx.send(vec![fixture_tool("srv", "refreshed_tool")]).unwrap();

        agent.check_tool_refresh().await;
        assert_eq!(agent.mcp_tool_count(), 1);
        assert_eq!(agent.mcp.tools[0].name, "refreshed_tool");
    }

    #[test]
    fn sanitize_elicitation_message_strips_control_chars() {
        let output = sanitize_elicitation_message("hello\x01world\x1b[31mred\x1b[0m");
        assert!(!output.contains('\x01'));
        assert!(!output.contains('\x1b'));
        assert!(output.contains("hello"));
        assert!(output.contains("world"));
    }

    #[test]
    fn sanitize_elicitation_message_preserves_newline_and_tab() {
        let output = sanitize_elicitation_message("line1\nline2\ttabbed");
        assert_eq!(output, "line1\nline2\ttabbed");
    }

    #[test]
    fn sanitize_elicitation_message_caps_at_500_chars() {
        let input = "a".repeat(600);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 500);
    }

    #[test]
    fn sanitize_elicitation_message_handles_multibyte_boundary() {
        let input = "é".repeat(300);
        let output = sanitize_elicitation_message(&input);
        assert_eq!(output.chars().count(), 300);
    }

    #[test]
    fn build_elicitation_fields_maps_primitive_types() {
        use crate::channel::ElicitationFieldType;
        use rmcp::model::{
            BooleanSchema, ElicitationSchema, IntegerSchema, NumberSchema, PrimitiveSchema,
            StringSchema,
        };
        use std::collections::BTreeMap;

        let props: BTreeMap<String, PrimitiveSchema> = [
            ("flag", PrimitiveSchema::Boolean(BooleanSchema::new())),
            ("count", PrimitiveSchema::Integer(IntegerSchema::new())),
            ("ratio", PrimitiveSchema::Number(NumberSchema::new())),
            ("name", PrimitiveSchema::String(StringSchema::new())),
        ]
        .into_iter()
        .map(|(key, prop)| (key.to_owned(), prop))
        .collect();

        let schema = ElicitationSchema::new(props);
        let fields = build_elicitation_fields(&schema);

        let get = |n: &str| fields.iter().find(|f| f.name == n).unwrap();
        assert!(matches!(
            get("flag").field_type,
            ElicitationFieldType::Boolean
        ));
        assert!(matches!(
            get("count").field_type,
            ElicitationFieldType::Integer
        ));
        assert!(matches!(
            get("ratio").field_type,
            ElicitationFieldType::Number
        ));
        assert!(matches!(
            get("name").field_type,
            ElicitationFieldType::String
        ));
    }

    #[test]
    fn build_elicitation_fields_required_flag() {
        use rmcp::model::{ElicitationSchema, PrimitiveSchema, StringSchema};
        use std::collections::BTreeMap;

        let props: BTreeMap<String, PrimitiveSchema> = ["req", "opt"]
            .into_iter()
            .map(|key| (key.to_owned(), PrimitiveSchema::String(StringSchema::new())))
            .collect();

        let mut schema = ElicitationSchema::new(props);
        schema.required = Some(vec!["req".to_owned()]);

        let fields = build_elicitation_fields(&schema);
        let req = fields.iter().find(|f| f.name == "req").unwrap();
        let opt = fields.iter().find(|f| f.name == "opt").unwrap();
        assert!(req.required);
        assert!(!opt.required);
    }

    #[test]
    fn is_sensitive_field_detects_common_patterns() {
        for name in [
            "password",
            "PASSWORD",
            "user_password",
            "api_token",
            "SECRET_KEY",
            "auth_header",
            "private_key",
        ] {
            assert!(is_sensitive_field(name));
        }
    }

    #[test]
    fn is_sensitive_field_allows_non_sensitive_names() {
        for name in ["username", "email", "message", "description", "subject"] {
            assert!(!is_sensitive_field(name));
        }
    }
}