1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
12use crate::integrations::mcp::{McpClient, McpTool, VerificationStatus};
13use crate::logging::{ModelInteractionType, ModelLogger, RequestData, ResponseData};
14use crate::routing::{error::TaskType, ModelRequest, RoutingContext, RoutingEngine};
15use crate::types::{AgentId, RuntimeError};
16use dashmap::DashMap;
17use std::sync::Arc;
18
/// Errors reported while deciding whether a tool may run, or while running it.
///
/// Of these variants, the code in this file constructs `InvocationBlocked`,
/// `NoMcpClient` and `ExecutionFailed`; the remaining variants are presumably
/// produced by other modules (not visible here).
#[derive(Error, Debug, Clone)]
pub enum ToolInvocationError {
    /// The enforcement decision for the tool was `Block`; `reason` carries the
    /// explanation produced by the enforcer.
    #[error("Tool invocation blocked: {tool_name} - {reason}")]
    InvocationBlocked { tool_name: String, reason: String },

    /// The named tool could not be located.
    #[error("Tool not found: {tool_name}")]
    ToolNotFound { tool_name: String },

    /// The tool must be verified before use; `status` is a textual rendering of
    /// its current verification status.
    #[error("Verification required but tool is not verified: {tool_name} (status: {status})")]
    VerificationRequired { tool_name: String, status: String },

    /// Verification of the tool was attempted and failed.
    #[error("Tool verification failed: {tool_name} - {reason}")]
    VerificationFailed { tool_name: String, reason: String },

    /// An enforcement policy was violated; `policy` names the policy.
    #[error("Enforcement policy violation: {policy} - {reason}")]
    PolicyViolation { policy: String, reason: String },

    /// The invocation did not complete in time.
    #[error("Tool invocation timeout: {tool_name}")]
    Timeout { tool_name: String },

    /// The underlying tool call returned an error; `reason` embeds its message.
    #[error("Tool execution failed: {tool_name} - {reason}")]
    ExecutionFailed { tool_name: String, reason: String },

    /// Execution was requested but the enforcer has no MCP client to run it with.
    #[error("No MCP client configured for tool execution: {reason}")]
    NoMcpClient { reason: String },

    /// A runtime error surfaced during invocation; `#[from]` lets `?` convert a
    /// `RuntimeError` into this variant.
    #[error("Runtime error during tool invocation: {source}")]
    Runtime {
        #[from]
        source: RuntimeError,
    },
}
52
/// How strictly tool verification status is enforced before an invocation.
///
/// The per-variant notes describe how `DefaultToolInvocationEnforcer` in this
/// file interprets each policy.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub enum EnforcementPolicy {
    /// Default policy: only tools whose status is `Verified` are allowed;
    /// failed, pending and skipped verification are all blocked.
    #[default]
    Strict,
    /// Failed verification is blocked only when `block_failed_verification` is
    /// set; pending and skipped verification are allowed (with warnings, except
    /// that pending is a plain allow when `block_pending_verification` is off).
    Permissive,
    /// Like `Permissive` for failed verification; pending is allowed with a
    /// warning; skipped is allowed with a warning only when
    /// `allow_skipped_in_dev` is set, otherwise blocked.
    Development,
    /// No enforcement: every tool is allowed regardless of verification status.
    Disabled,
}
66
/// Tunable settings for tool invocation enforcement.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InvocationEnforcementConfig {
    /// Overall enforcement mode.
    pub policy: EnforcementPolicy,
    /// In `Development` and `Permissive` modes, block tools whose verification
    /// failed (otherwise they are allowed with a warning). `Strict` always
    /// blocks them and `Disabled` never does.
    pub block_failed_verification: bool,
    /// Read by the default enforcer only in `Permissive` mode, where `true`
    /// turns a pending tool into an allow-with-warning decision and `false`
    /// into a plain allow. NOTE(review): despite the name it does not produce a
    /// block in the visible code.
    pub block_pending_verification: bool,
    /// In `Development` mode, allow (with a warning) tools whose verification
    /// was skipped; when `false` such tools are blocked.
    pub allow_skipped_in_dev: bool,
    /// Not read anywhere in the visible code; presumably the time budget for a
    /// verification attempt — confirm against callers.
    pub verification_timeout: Duration,
    /// Number of warnings recorded for one tool name at which the enforcer
    /// escalates and resets that tool's counter.
    pub max_warnings_before_escalation: usize,
}
83
84impl Default for InvocationEnforcementConfig {
85 fn default() -> Self {
86 Self {
87 policy: EnforcementPolicy::Strict,
88 block_failed_verification: true,
89 block_pending_verification: true,
90 allow_skipped_in_dev: false,
91 verification_timeout: Duration::from_secs(5),
92 max_warnings_before_escalation: 10,
93 }
94 }
95}
96
/// Everything known about a single tool invocation request.
#[derive(Debug, Clone)]
pub struct InvocationContext {
    /// Agent on whose behalf the tool is invoked; used for logging and routing.
    pub agent_id: AgentId,
    /// Name of the tool being requested.
    pub tool_name: String,
    /// JSON arguments passed to the tool. Redacted via
    /// `mask_sensitive_arguments` before being written to the audit log.
    pub arguments: serde_json::Value,
    /// Timestamp attached to the request (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Free-form string key/value annotations supplied by the caller.
    pub metadata: HashMap<String, String>,
    /// Optional agent verification result; not inspected by the code in this
    /// file — presumably consumed by the MCP client or other enforcers.
    pub agent_credential: Option<crate::integrations::agentpin::AgentVerificationResult>,
}
113
/// Outcome of a tool invocation that was allowed to run.
#[derive(Debug, Clone)]
pub struct InvocationResult {
    /// Whether the tool reported success.
    pub success: bool,
    /// JSON payload returned by the tool.
    pub result: serde_json::Value,
    /// Time attributed to the invocation.
    pub execution_time: Duration,
    /// Warnings attached to the result; the default enforcer appends its own
    /// enforcement warnings here for allow-with-warnings decisions.
    pub warnings: Vec<String>,
    /// String key/value annotations; the default enforcer sets
    /// `"escalated" = "true"` when the warning threshold was reached.
    pub metadata: HashMap<String, String>,
}
128
/// Verdict of an enforcement check for one tool invocation.
#[derive(Debug, Clone)]
pub enum EnforcementDecision {
    /// The invocation may proceed.
    Allow,
    /// The invocation must not proceed; `reason` explains why.
    Block { reason: String },
    /// The invocation may proceed, but the listed warnings should be recorded
    /// and surfaced to the caller.
    AllowWithWarnings { warnings: Vec<String> },
}
139
/// Gatekeeper for tool invocations: decides whether a tool may run and, if so,
/// runs it under the configured enforcement policy.
#[async_trait]
pub trait ToolInvocationEnforcer: Send + Sync {
    /// Decides whether `tool` may be invoked in the given `context`, without
    /// executing it.
    async fn check_invocation_allowed(
        &self,
        tool: &McpTool,
        context: &InvocationContext,
    ) -> Result<EnforcementDecision, ToolInvocationError>;

    /// Checks the invocation and, when permitted, executes the tool.
    ///
    /// Takes ownership of `context`. Implementations return an error when the
    /// invocation is rejected or execution fails.
    async fn execute_tool_with_enforcement(
        &self,
        tool: &McpTool,
        context: InvocationContext,
    ) -> Result<InvocationResult, ToolInvocationError>;

    /// Returns the configuration currently in effect.
    fn get_enforcement_config(&self) -> &InvocationEnforcementConfig;

    /// Replaces the configuration currently in effect.
    fn update_enforcement_config(&mut self, config: InvocationEnforcementConfig);
}
163
164pub fn mask_sensitive_arguments(
169 arguments: &serde_json::Value,
170 sensitive_params: &[String],
171) -> serde_json::Value {
172 if sensitive_params.is_empty() {
173 return arguments.clone();
174 }
175
176 match arguments {
177 serde_json::Value::Object(map) => {
178 let mut masked = serde_json::Map::new();
179 for (key, value) in map {
180 if sensitive_params.iter().any(|p| p == key) {
181 masked.insert(
182 key.clone(),
183 serde_json::Value::String(format!("[REDACTED:{}]", key)),
184 );
185 } else {
186 masked.insert(
187 key.clone(),
188 mask_sensitive_arguments(value, sensitive_params),
189 );
190 }
191 }
192 serde_json::Value::Object(masked)
193 }
194 serde_json::Value::Array(arr) => serde_json::Value::Array(
195 arr.iter()
196 .map(|v| mask_sensitive_arguments(v, sensitive_params))
197 .collect(),
198 ),
199 other => other.clone(),
200 }
201}
202
/// Default `ToolInvocationEnforcer`: gates tools on their verification status
/// and executes allowed tools through an MCP client.
pub struct DefaultToolInvocationEnforcer {
    /// Active enforcement settings.
    config: InvocationEnforcementConfig,
    /// Warnings recorded per tool name since the last escalation; a concurrent
    /// map so it can be updated through `&self`.
    warning_counts: DashMap<String, usize>,
    /// Optional audit logger; when absent, enforcement decisions are not logged.
    model_logger: Option<Arc<ModelLogger>>,
    /// Optional routing engine, used only by `route_tool_execution`.
    routing_engine: Option<Arc<dyn RoutingEngine>>,
    /// Client used to actually run tools; execution fails with `NoMcpClient`
    /// when this is `None`.
    mcp_client: Option<Arc<dyn McpClient>>,
}
211
212impl DefaultToolInvocationEnforcer {
213 pub fn new() -> Self {
215 Self {
216 config: InvocationEnforcementConfig::default(),
217 warning_counts: DashMap::new(),
218 model_logger: None,
219 routing_engine: None,
220 mcp_client: None,
221 }
222 }
223
224 pub fn with_config(config: InvocationEnforcementConfig) -> Self {
226 Self {
227 config,
228 warning_counts: DashMap::new(),
229 model_logger: None,
230 routing_engine: None,
231 mcp_client: None,
232 }
233 }
234
235 pub fn with_logger(config: InvocationEnforcementConfig, logger: Arc<ModelLogger>) -> Self {
237 Self {
238 config,
239 warning_counts: DashMap::new(),
240 model_logger: Some(logger),
241 routing_engine: None,
242 mcp_client: None,
243 }
244 }
245
246 pub fn with_routing(
248 config: InvocationEnforcementConfig,
249 logger: Option<Arc<ModelLogger>>,
250 routing_engine: Arc<dyn RoutingEngine>,
251 ) -> Self {
252 Self {
253 config,
254 warning_counts: DashMap::new(),
255 model_logger: logger,
256 routing_engine: Some(routing_engine),
257 mcp_client: None,
258 }
259 }
260
261 pub fn with_mcp_client(
263 config: InvocationEnforcementConfig,
264 mcp_client: Arc<dyn McpClient>,
265 ) -> Self {
266 Self {
267 config,
268 warning_counts: DashMap::new(),
269 model_logger: None,
270 routing_engine: None,
271 mcp_client: Some(mcp_client),
272 }
273 }
274
275 fn check_verification_status(&self, tool: &McpTool) -> EnforcementDecision {
277 match &self.config.policy {
278 EnforcementPolicy::Disabled => EnforcementDecision::Allow,
279 EnforcementPolicy::Development => {
280 match &tool.verification_status {
281 VerificationStatus::Verified { .. } => EnforcementDecision::Allow,
282 VerificationStatus::Failed { reason, .. } => {
283 if self.config.block_failed_verification {
284 EnforcementDecision::Block {
285 reason: format!("Tool verification failed: {}", reason),
286 }
287 } else {
288 EnforcementDecision::AllowWithWarnings {
289 warnings: vec![format!("Tool '{}' has failed verification: {}", tool.name, reason)],
290 }
291 }
292 }
293 VerificationStatus::Pending => {
294 EnforcementDecision::AllowWithWarnings {
295 warnings: vec![format!("Tool '{}' verification is pending", tool.name)],
296 }
297 }
298 VerificationStatus::Skipped { reason } => {
299 if self.config.allow_skipped_in_dev {
300 EnforcementDecision::AllowWithWarnings {
301 warnings: vec![format!("Tool '{}' verification was skipped: {}", tool.name, reason)],
302 }
303 } else {
304 EnforcementDecision::Block {
305 reason: format!("Tool verification was skipped: {}", reason),
306 }
307 }
308 }
309 }
310 }
311 EnforcementPolicy::Permissive => {
312 match &tool.verification_status {
313 VerificationStatus::Verified { .. } => EnforcementDecision::Allow,
314 VerificationStatus::Failed { reason, .. } => {
315 if self.config.block_failed_verification {
316 EnforcementDecision::Block {
317 reason: format!("Tool verification failed: {}", reason),
318 }
319 } else {
320 EnforcementDecision::AllowWithWarnings {
321 warnings: vec![format!("Tool '{}' has failed verification: {}", tool.name, reason)],
322 }
323 }
324 }
325 VerificationStatus::Pending => {
326 if self.config.block_pending_verification {
327 EnforcementDecision::AllowWithWarnings {
328 warnings: vec![format!("Tool '{}' verification is pending", tool.name)],
329 }
330 } else {
331 EnforcementDecision::Allow
332 }
333 }
334 VerificationStatus::Skipped { reason } => {
335 EnforcementDecision::AllowWithWarnings {
336 warnings: vec![format!("Tool '{}' verification was skipped: {}", tool.name, reason)],
337 }
338 }
339 }
340 }
341 EnforcementPolicy::Strict => {
342 match &tool.verification_status {
343 VerificationStatus::Verified { .. } => EnforcementDecision::Allow,
344 VerificationStatus::Failed { reason, .. } => {
345 EnforcementDecision::Block {
346 reason: format!("Tool verification failed: {}", reason),
347 }
348 }
349 VerificationStatus::Pending => {
350 EnforcementDecision::Block {
351 reason: "Tool verification is pending - only verified tools are allowed in strict mode".to_string(),
352 }
353 }
354 VerificationStatus::Skipped { reason } => {
355 EnforcementDecision::Block {
356 reason: format!("Tool verification was skipped: {} - only verified tools are allowed in strict mode", reason),
357 }
358 }
359 }
360 }
361 }
362 }
363
364 fn handle_warning(&self, tool_name: &str, warning: &str) -> bool {
366 let mut count = self
367 .warning_counts
368 .entry(tool_name.to_string())
369 .or_insert(0);
370 *count += 1;
371
372 if *count >= self.config.max_warnings_before_escalation {
373 tracing::warn!(
374 "Tool '{}' has exceeded warning threshold ({} warnings): {}",
375 tool_name,
376 *count,
377 warning
378 );
379 *count = 0;
381 true
382 } else {
383 tracing::warn!(
384 "Tool '{}' warning (count: {}): {}",
385 tool_name,
386 *count,
387 warning
388 );
389 false
390 }
391 }
392
393 #[allow(dead_code)]
395 async fn route_tool_execution(
396 &self,
397 tool: &McpTool,
398 context: &InvocationContext,
399 ) -> Result<Option<String>, ToolInvocationError> {
400 if let Some(ref routing_engine) = self.routing_engine {
401 let task_type = self.classify_tool_task(tool, context);
403
404 let routing_context = RoutingContext::new(
406 context.agent_id,
407 task_type,
408 format!("Tool: {} - {}", tool.name, tool.description),
409 );
410
411 let _model_request = ModelRequest::from_task(format!(
413 "Execute tool '{}' with arguments: {}",
414 tool.name, context.arguments
415 ));
416
417 match routing_engine.route_request(&routing_context).await {
419 Ok(decision) => {
420 tracing::debug!("Routing decision for tool '{}': {:?}", tool.name, decision);
421 Ok(Some(format!("{:?}", decision)))
423 }
424 Err(e) => {
425 tracing::warn!("Routing failed for tool '{}': {}", tool.name, e);
426 Ok(None)
427 }
428 }
429 } else {
430 Ok(None)
431 }
432 }
433
434 #[allow(dead_code)]
436 fn classify_tool_task(&self, tool: &McpTool, context: &InvocationContext) -> TaskType {
437 let tool_name_lower = tool.name.to_lowercase();
438 let description_lower = tool.description.to_lowercase();
439
440 if tool_name_lower.contains("code")
442 || description_lower.contains("code")
443 || tool_name_lower.contains("program")
444 || description_lower.contains("program")
445 {
446 TaskType::CodeGeneration
447 } else if tool_name_lower.contains("analyze")
448 || description_lower.contains("analy")
449 || tool_name_lower.contains("inspect")
450 || description_lower.contains("inspect")
451 {
452 TaskType::Analysis
453 } else if tool_name_lower.contains("extract")
454 || description_lower.contains("extract")
455 || tool_name_lower.contains("parse")
456 || description_lower.contains("parse")
457 {
458 TaskType::Extract
459 } else if tool_name_lower.contains("summarize") || description_lower.contains("summar") {
460 TaskType::Summarization
461 } else if tool_name_lower.contains("translate") || description_lower.contains("translat") {
462 TaskType::Translation
463 } else if tool_name_lower.contains("reason")
464 || description_lower.contains("reason")
465 || tool_name_lower.contains("logic")
466 || description_lower.contains("logic")
467 {
468 TaskType::Reasoning
469 } else if tool_name_lower.contains("template") || description_lower.contains("template") {
470 TaskType::Template
471 } else if context.arguments.to_string().len() < 100 {
472 TaskType::Intent
474 } else {
475 TaskType::QA
477 }
478 }
479 async fn execute_actual_tool(
482 &self,
483 tool: &McpTool,
484 context: &InvocationContext,
485 _execution_time: Duration,
486 ) -> Result<InvocationResult, ToolInvocationError> {
487 let mcp_client =
488 self.mcp_client
489 .as_ref()
490 .ok_or_else(|| ToolInvocationError::NoMcpClient {
491 reason: format!(
492 "No MCP client configured for tool execution of '{}'",
493 tool.name
494 ),
495 })?;
496
497 tracing::info!(
498 "Executing tool '{}' for agent {} via MCP (provider: '{}')",
499 tool.name,
500 context.agent_id,
501 tool.provider.identifier
502 );
503
504 let result = mcp_client
505 .invoke_tool(&tool.name, context.arguments.clone(), context.clone())
506 .await
507 .map_err(|e| ToolInvocationError::ExecutionFailed {
508 tool_name: tool.name.clone(),
509 reason: format!("MCP invocation failed: {}", e),
510 })?;
511
512 Ok(result)
513 }
514}
515
516impl Default for DefaultToolInvocationEnforcer {
517 fn default() -> Self {
518 Self::new()
519 }
520}
521
522#[async_trait]
523impl ToolInvocationEnforcer for DefaultToolInvocationEnforcer {
524 async fn check_invocation_allowed(
525 &self,
526 tool: &McpTool,
527 _context: &InvocationContext,
528 ) -> Result<EnforcementDecision, ToolInvocationError> {
529 Ok(self.check_verification_status(tool))
530 }
531
532 async fn execute_tool_with_enforcement(
533 &self,
534 tool: &McpTool,
535 context: InvocationContext,
536 ) -> Result<InvocationResult, ToolInvocationError> {
537 let start_time = Instant::now();
538
539 let decision = self.check_invocation_allowed(tool, &context).await?;
541
542 let redacted_arguments =
544 mask_sensitive_arguments(&context.arguments, &tool.sensitive_params);
545
546 let request_data = RequestData {
548 prompt: format!("Tool invocation: {}", tool.name),
549 tool_name: Some(tool.name.clone()),
550 tool_arguments: Some(redacted_arguments),
551 parameters: {
552 let mut params = HashMap::new();
553 params.insert(
554 "verification_status".to_string(),
555 serde_json::Value::String(format!("{:?}", tool.verification_status)),
556 );
557 params.insert(
558 "enforcement_policy".to_string(),
559 serde_json::Value::String(format!("{:?}", self.config.policy)),
560 );
561 params
562 },
563 };
564
565 match decision {
566 EnforcementDecision::Allow => {
567 let execution_time = start_time.elapsed();
568
569 let response_data = ResponseData {
571 content: "Tool invocation allowed and executed".to_string(),
572 tool_result: Some(
573 serde_json::json!({"status": "success", "message": "Tool invocation allowed"}),
574 ),
575 confidence: Some(1.0),
576 metadata: HashMap::new(),
577 };
578
579 if let Some(ref logger) = self.model_logger {
581 let metadata = {
582 let mut meta = HashMap::new();
583 meta.insert(
584 "tool_provider".to_string(),
585 tool.provider.identifier.clone(),
586 );
587 meta.insert("enforcement_decision".to_string(), "allow".to_string());
588 meta.insert("agent_id".to_string(), context.agent_id.to_string());
589 meta
590 };
591
592 if let Err(e) = logger
593 .log_interaction(
594 context.agent_id,
595 ModelInteractionType::ToolCall,
596 &tool.name,
597 request_data,
598 response_data,
599 execution_time,
600 metadata,
601 None, None,
603 )
604 .await
605 {
606 tracing::warn!("Failed to log tool invocation: {}", e);
607 }
608 }
609
610 self.execute_actual_tool(tool, &context, execution_time)
612 .await
613 }
614 EnforcementDecision::Block { reason } => {
615 let execution_time = start_time.elapsed();
616
617 if let Some(ref logger) = self.model_logger {
619 let response_data = ResponseData {
620 content: "Tool invocation blocked".to_string(),
621 tool_result: Some(
622 serde_json::json!({"status": "blocked", "reason": &reason}),
623 ),
624 confidence: Some(1.0),
625 metadata: HashMap::new(),
626 };
627
628 let metadata = {
629 let mut meta = HashMap::new();
630 meta.insert(
631 "tool_provider".to_string(),
632 tool.provider.identifier.clone(),
633 );
634 meta.insert("enforcement_decision".to_string(), "block".to_string());
635 meta.insert("agent_id".to_string(), context.agent_id.to_string());
636 meta
637 };
638
639 if let Err(e) = logger
640 .log_interaction(
641 context.agent_id,
642 ModelInteractionType::ToolCall,
643 &tool.name,
644 request_data,
645 response_data,
646 execution_time,
647 metadata,
648 None,
649 Some(reason.clone()),
650 )
651 .await
652 {
653 tracing::warn!("Failed to log blocked tool invocation: {}", e);
654 }
655 }
656
657 Err(ToolInvocationError::InvocationBlocked {
658 tool_name: tool.name.clone(),
659 reason,
660 })
661 }
662 EnforcementDecision::AllowWithWarnings { warnings } => {
663 let execution_time = start_time.elapsed();
664
665 let mut escalated = false;
667 for warning in &warnings {
668 if self.handle_warning(&tool.name, warning) {
669 escalated = true;
670 }
671 }
672
673 let response_data = ResponseData {
675 content: "Tool invocation allowed with warnings".to_string(),
676 tool_result: Some(serde_json::json!({
677 "status": "success",
678 "message": "Tool invocation allowed with warnings",
679 "warnings": &warnings
680 })),
681 confidence: Some(0.8), metadata: HashMap::new(),
683 };
684
685 if let Some(ref logger) = self.model_logger {
687 let metadata = {
688 let mut meta = HashMap::new();
689 meta.insert(
690 "tool_provider".to_string(),
691 tool.provider.identifier.clone(),
692 );
693 meta.insert(
694 "enforcement_decision".to_string(),
695 "allow_with_warnings".to_string(),
696 );
697 meta.insert("agent_id".to_string(), context.agent_id.to_string());
698 meta.insert("warnings_count".to_string(), warnings.len().to_string());
699 if escalated {
700 meta.insert("escalated".to_string(), "true".to_string());
701 }
702 meta
703 };
704
705 if let Err(e) = logger
706 .log_interaction(
707 context.agent_id,
708 ModelInteractionType::ToolCall,
709 &tool.name,
710 request_data,
711 response_data,
712 execution_time,
713 metadata,
714 None,
715 None,
716 )
717 .await
718 {
719 tracing::warn!("Failed to log tool invocation with warnings: {}", e);
720 }
721 }
722
723 let mut result = self
725 .execute_actual_tool(tool, &context, execution_time)
726 .await?;
727 result.warnings.extend(warnings);
728 if escalated {
729 result
730 .metadata
731 .insert("escalated".to_string(), "true".to_string());
732 }
733 Ok(result)
734 }
735 }
736 }
737
738 fn get_enforcement_config(&self) -> &InvocationEnforcementConfig {
739 &self.config
740 }
741
742 fn update_enforcement_config(&mut self, config: InvocationEnforcementConfig) {
743 self.config = config;
744 }
745}
746
// Unit tests for the enforcement decision table, the end-to-end execution path
// and argument redaction.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::integrations::mcp::{McpTool, MockMcpClient, ToolProvider, VerificationStatus};
    use crate::integrations::schemapin::VerificationResult;

    /// Builds a minimal tool named `test_tool` with the given verification
    /// status and no sensitive parameters.
    fn create_test_tool(verification_status: VerificationStatus) -> McpTool {
        McpTool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            schema: serde_json::json!({"type": "object"}),
            provider: ToolProvider {
                identifier: "test.example.com".to_string(),
                name: "Test Provider".to_string(),
                public_key_url: "https://test.example.com/pubkey".to_string(),
                version: Some("1.0.0".to_string()),
            },
            verification_status,
            metadata: None,
            sensitive_params: vec![],
        }
    }

    /// Builds an invocation context for `test_tool` with a fresh agent id and
    /// no credential.
    fn create_test_context() -> InvocationContext {
        InvocationContext {
            agent_id: AgentId::new(),
            tool_name: "test_tool".to_string(),
            arguments: serde_json::json!({"test": "value"}),
            timestamp: chrono::Utc::now(),
            metadata: HashMap::new(),
            agent_credential: None,
        }
    }

    /// Builds an enforcer backed by a mock MCP client (created via
    /// `MockMcpClient::new_success`) on which `tool` has been registered.
    async fn create_enforcer_with_mock_mcp(
        config: InvocationEnforcementConfig,
        tool: &McpTool,
    ) -> DefaultToolInvocationEnforcer {
        let mock_client = Arc::new(MockMcpClient::new_success());
        // The result of registering the tool is deliberately ignored.
        let _ = mock_client.discover_tool(tool.clone()).await;
        DefaultToolInvocationEnforcer {
            config,
            warning_counts: DashMap::new(),
            model_logger: None,
            routing_engine: None,
            mcp_client: Some(mock_client),
        }
    }

    // Strict + Verified => Allow.
    #[tokio::test]
    async fn test_strict_mode_allows_verified_tools() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Strict,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Verified {
            result: Box::new(VerificationResult {
                success: true,
                message: "Test verification".to_string(),
                schema_hash: Some("hash123".to_string()),
                public_key_url: Some("https://test.example.com/pubkey".to_string()),
                signature: None,
                metadata: None,
                timestamp: Some("2024-01-01T00:00:00Z".to_string()),
            }),
            verified_at: "2024-01-01T00:00:00Z".to_string(),
        });

        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(decision, EnforcementDecision::Allow));
    }

    // Strict + Pending => Block.
    #[tokio::test]
    async fn test_strict_mode_blocks_unverified_tools() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Strict,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(decision, EnforcementDecision::Block { .. }));
    }

    // Permissive + Pending with `block_pending_verification` set => warning,
    // not a block.
    #[tokio::test]
    async fn test_permissive_mode_allows_with_warnings() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Permissive,
            block_pending_verification: true,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(
            decision,
            EnforcementDecision::AllowWithWarnings { .. }
        ));
    }

    // Disabled policy allows even a tool whose verification failed.
    #[tokio::test]
    async fn test_disabled_mode_allows_all_tools() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Disabled,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Failed {
            reason: "Test failure".to_string(),
            failed_at: "2024-01-01T00:00:00Z".to_string(),
        });
        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(decision, EnforcementDecision::Allow));
    }

    // End to end: a blocked decision surfaces as `InvocationBlocked`.
    #[tokio::test]
    async fn test_execute_tool_blocks_unverified_in_strict_mode() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Strict,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let result = enforcer.execute_tool_with_enforcement(&tool, context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolInvocationError::InvocationBlocked { .. }
        ));
    }

    // End to end: allow-with-warnings executes the tool and carries the
    // enforcement warnings on the result.
    #[tokio::test]
    async fn test_execute_tool_succeeds_with_warnings() {
        let tool = create_test_tool(VerificationStatus::Pending);
        let enforcer = create_enforcer_with_mock_mcp(
            InvocationEnforcementConfig {
                policy: EnforcementPolicy::Permissive,
                block_pending_verification: true,
                ..Default::default()
            },
            &tool,
        )
        .await;

        let context = create_test_context();
        let result = enforcer
            .execute_tool_with_enforcement(&tool, context)
            .await
            .unwrap();

        assert!(result.success);
        assert!(!result.warnings.is_empty());
    }

    // An allowed invocation still fails when no MCP client is configured.
    #[tokio::test]
    async fn test_execute_tool_fails_without_mcp_client() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Disabled,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let result = enforcer.execute_tool_with_enforcement(&tool, context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolInvocationError::NoMcpClient { .. }
        ));
    }

    // With no sensitive parameter names, arguments pass through unchanged.
    #[test]
    fn test_mask_sensitive_arguments_empty_list() {
        let args = serde_json::json!({"user": "alice", "password": "s3cret"});
        let masked = mask_sensitive_arguments(&args, &[]);
        assert_eq!(masked, args);
    }

    // Top-level sensitive keys are replaced; other keys are untouched.
    #[test]
    fn test_mask_sensitive_arguments_flat() {
        let args = serde_json::json!({"user": "alice", "password": "s3cret", "token": "abc"});
        let masked =
            mask_sensitive_arguments(&args, &["password".to_string(), "token".to_string()]);
        assert_eq!(masked["user"], "alice");
        assert_eq!(masked["password"], "[REDACTED:password]");
        assert_eq!(masked["token"], "[REDACTED:token]");
    }

    // Sensitive keys inside nested objects are redacted as well.
    #[test]
    fn test_mask_sensitive_arguments_nested() {
        let args = serde_json::json!({
            "config": {
                "api_key": "sk-123",
                "endpoint": "https://api.example.com"
            },
            "name": "test"
        });
        let masked = mask_sensitive_arguments(&args, &["api_key".to_string()]);
        assert_eq!(masked["config"]["api_key"], "[REDACTED:api_key]");
        assert_eq!(masked["config"]["endpoint"], "https://api.example.com");
        assert_eq!(masked["name"], "test");
    }
}