Skip to main content

symbi_runtime/integrations/
tool_invocation.rs

1//! Tool Invocation Enforcement
2//!
3//! Provides verification enforcement for tool invocations: each tool's
4//! verification status is checked against a configurable policy before it runs.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::{Duration, Instant};
10use thiserror::Error;
11
12use crate::integrations::mcp::{McpClient, McpTool, VerificationStatus};
13use crate::logging::{ModelInteractionType, ModelLogger, RequestData, ResponseData};
14use crate::routing::{error::TaskType, ModelRequest, RoutingContext, RoutingEngine};
15use crate::types::{AgentId, RuntimeError};
16use dashmap::DashMap;
17use std::sync::Arc;
18
19/// Tool invocation enforcement errors
20#[derive(Error, Debug, Clone)]
21pub enum ToolInvocationError {
22    #[error("Tool invocation blocked: {tool_name} - {reason}")]
23    InvocationBlocked { tool_name: String, reason: String },
24
25    #[error("Tool not found: {tool_name}")]
26    ToolNotFound { tool_name: String },
27
28    #[error("Verification required but tool is not verified: {tool_name} (status: {status})")]
29    VerificationRequired { tool_name: String, status: String },
30
31    #[error("Tool verification failed: {tool_name} - {reason}")]
32    VerificationFailed { tool_name: String, reason: String },
33
34    #[error("Enforcement policy violation: {policy} - {reason}")]
35    PolicyViolation { policy: String, reason: String },
36
37    #[error("Tool invocation timeout: {tool_name}")]
38    Timeout { tool_name: String },
39
40    #[error("Tool execution failed: {tool_name} - {reason}")]
41    ExecutionFailed { tool_name: String, reason: String },
42
43    #[error("No MCP client configured for tool execution: {reason}")]
44    NoMcpClient { reason: String },
45
46    #[error("Runtime error during tool invocation: {source}")]
47    Runtime {
48        #[from]
49        source: RuntimeError,
50    },
51}
52
53/// Tool invocation enforcement policies
54#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
55pub enum EnforcementPolicy {
56    /// Strict mode - only verified tools can be executed
57    #[default]
58    Strict,
59    /// Permissive mode - unverified tools are allowed with warnings
60    Permissive,
61    /// Development mode - allows unverified tools and logs warnings
62    Development,
63    /// Disabled - no verification enforcement
64    Disabled,
65}
66
67/// Configuration for tool invocation enforcement
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct InvocationEnforcementConfig {
70    /// Primary enforcement policy
71    pub policy: EnforcementPolicy,
72    /// Whether to block tools with failed verification
73    pub block_failed_verification: bool,
74    /// Whether to block tools with pending verification
75    pub block_pending_verification: bool,
76    /// Whether to allow skipped verification in development
77    pub allow_skipped_in_dev: bool,
78    /// Timeout for tool invocation verification checks
79    pub verification_timeout: Duration,
80    /// Maximum number of warning logs before escalation
81    pub max_warnings_before_escalation: usize,
82}
83
84impl Default for InvocationEnforcementConfig {
85    fn default() -> Self {
86        Self {
87            policy: EnforcementPolicy::Strict,
88            block_failed_verification: true,
89            block_pending_verification: true,
90            allow_skipped_in_dev: false,
91            verification_timeout: Duration::from_secs(5),
92            max_warnings_before_escalation: 10,
93        }
94    }
95}
96
97/// Tool invocation context
98#[derive(Debug, Clone)]
99pub struct InvocationContext {
100    /// Agent requesting the invocation
101    pub agent_id: AgentId,
102    /// Tool name being invoked
103    pub tool_name: String,
104    /// Arguments for the tool invocation
105    pub arguments: serde_json::Value,
106    /// Timestamp of invocation request
107    pub timestamp: chrono::DateTime<chrono::Utc>,
108    /// Additional metadata
109    pub metadata: HashMap<String, String>,
110    /// Optional AgentPin credential verification result
111    pub agent_credential: Option<crate::integrations::agentpin::AgentVerificationResult>,
112}
113
114/// Tool invocation result
115#[derive(Debug, Clone)]
116pub struct InvocationResult {
117    /// Whether the invocation was successful
118    pub success: bool,
119    /// Result data from tool execution
120    pub result: serde_json::Value,
121    /// Execution time
122    pub execution_time: Duration,
123    /// Any warnings generated during invocation
124    pub warnings: Vec<String>,
125    /// Metadata about the invocation
126    pub metadata: HashMap<String, String>,
127}
128
/// Verdict an enforcer reaches before a tool is run.
#[derive(Clone, Debug)]
pub enum EnforcementDecision {
    /// Run the tool with no caveats.
    Allow,
    /// Refuse to run the tool; `reason` explains why.
    Block { reason: String },
    /// Run the tool but surface the listed warnings.
    AllowWithWarnings { warnings: Vec<String> },
}
139
140/// Trait for tool invocation enforcement
141#[async_trait]
142pub trait ToolInvocationEnforcer: Send + Sync {
143    /// Check if a tool invocation should be allowed based on verification status
144    async fn check_invocation_allowed(
145        &self,
146        tool: &McpTool,
147        context: &InvocationContext,
148    ) -> Result<EnforcementDecision, ToolInvocationError>;
149
150    /// Execute a tool invocation with enforcement checks
151    async fn execute_tool_with_enforcement(
152        &self,
153        tool: &McpTool,
154        context: InvocationContext,
155    ) -> Result<InvocationResult, ToolInvocationError>;
156
157    /// Get the current enforcement configuration
158    fn get_enforcement_config(&self) -> &InvocationEnforcementConfig;
159
160    /// Update the enforcement configuration
161    fn update_enforcement_config(&mut self, config: InvocationEnforcementConfig);
162}
163
164/// Recursively mask sensitive argument values in a JSON object.
165///
166/// Keys matching any entry in `sensitive_params` (case-sensitive) are replaced
167/// with `[REDACTED:sensitive_param]`. Nested objects are traversed recursively.
168pub fn mask_sensitive_arguments(
169    arguments: &serde_json::Value,
170    sensitive_params: &[String],
171) -> serde_json::Value {
172    if sensitive_params.is_empty() {
173        return arguments.clone();
174    }
175
176    match arguments {
177        serde_json::Value::Object(map) => {
178            let mut masked = serde_json::Map::new();
179            for (key, value) in map {
180                if sensitive_params.iter().any(|p| p == key) {
181                    masked.insert(
182                        key.clone(),
183                        serde_json::Value::String(format!("[REDACTED:{}]", key)),
184                    );
185                } else {
186                    masked.insert(
187                        key.clone(),
188                        mask_sensitive_arguments(value, sensitive_params),
189                    );
190                }
191            }
192            serde_json::Value::Object(masked)
193        }
194        serde_json::Value::Array(arr) => serde_json::Value::Array(
195            arr.iter()
196                .map(|v| mask_sensitive_arguments(v, sensitive_params))
197                .collect(),
198        ),
199        other => other.clone(),
200    }
201}
202
203/// Default implementation of tool invocation enforcement
204pub struct DefaultToolInvocationEnforcer {
205    config: InvocationEnforcementConfig,
206    warning_counts: DashMap<String, usize>,
207    model_logger: Option<Arc<ModelLogger>>,
208    routing_engine: Option<Arc<dyn RoutingEngine>>,
209    mcp_client: Option<Arc<dyn McpClient>>,
210}
211
212impl DefaultToolInvocationEnforcer {
213    /// Create a new tool invocation enforcer with default configuration
214    pub fn new() -> Self {
215        Self {
216            config: InvocationEnforcementConfig::default(),
217            warning_counts: DashMap::new(),
218            model_logger: None,
219            routing_engine: None,
220            mcp_client: None,
221        }
222    }
223
224    /// Create a new tool invocation enforcer with custom configuration
225    pub fn with_config(config: InvocationEnforcementConfig) -> Self {
226        Self {
227            config,
228            warning_counts: DashMap::new(),
229            model_logger: None,
230            routing_engine: None,
231            mcp_client: None,
232        }
233    }
234
235    /// Create a new tool invocation enforcer with model logging
236    pub fn with_logger(config: InvocationEnforcementConfig, logger: Arc<ModelLogger>) -> Self {
237        Self {
238            config,
239            warning_counts: DashMap::new(),
240            model_logger: Some(logger),
241            routing_engine: None,
242            mcp_client: None,
243        }
244    }
245
246    /// Create a new tool invocation enforcer with routing engine
247    pub fn with_routing(
248        config: InvocationEnforcementConfig,
249        logger: Option<Arc<ModelLogger>>,
250        routing_engine: Arc<dyn RoutingEngine>,
251    ) -> Self {
252        Self {
253            config,
254            warning_counts: DashMap::new(),
255            model_logger: logger,
256            routing_engine: Some(routing_engine),
257            mcp_client: None,
258        }
259    }
260
261    /// Create a new tool invocation enforcer with an MCP client for real tool execution
262    pub fn with_mcp_client(
263        config: InvocationEnforcementConfig,
264        mcp_client: Arc<dyn McpClient>,
265    ) -> Self {
266        Self {
267            config,
268            warning_counts: DashMap::new(),
269            model_logger: None,
270            routing_engine: None,
271            mcp_client: Some(mcp_client),
272        }
273    }
274
275    /// Check verification status and determine if tool should be allowed
276    fn check_verification_status(&self, tool: &McpTool) -> EnforcementDecision {
277        match &self.config.policy {
278            EnforcementPolicy::Disabled => EnforcementDecision::Allow,
279            EnforcementPolicy::Development => {
280                match &tool.verification_status {
281                    VerificationStatus::Verified { .. } => EnforcementDecision::Allow,
282                    VerificationStatus::Failed { reason, .. } => {
283                        if self.config.block_failed_verification {
284                            EnforcementDecision::Block {
285                                reason: format!("Tool verification failed: {}", reason),
286                            }
287                        } else {
288                            EnforcementDecision::AllowWithWarnings {
289                                warnings: vec![format!("Tool '{}' has failed verification: {}", tool.name, reason)],
290                            }
291                        }
292                    }
293                    VerificationStatus::Pending => {
294                        EnforcementDecision::AllowWithWarnings {
295                            warnings: vec![format!("Tool '{}' verification is pending", tool.name)],
296                        }
297                    }
298                    VerificationStatus::Skipped { reason } => {
299                        if self.config.allow_skipped_in_dev {
300                            EnforcementDecision::AllowWithWarnings {
301                                warnings: vec![format!("Tool '{}' verification was skipped: {}", tool.name, reason)],
302                            }
303                        } else {
304                            EnforcementDecision::Block {
305                                reason: format!("Tool verification was skipped: {}", reason),
306                            }
307                        }
308                    }
309                }
310            }
311            EnforcementPolicy::Permissive => {
312                match &tool.verification_status {
313                    VerificationStatus::Verified { .. } => EnforcementDecision::Allow,
314                    VerificationStatus::Failed { reason, .. } => {
315                        if self.config.block_failed_verification {
316                            EnforcementDecision::Block {
317                                reason: format!("Tool verification failed: {}", reason),
318                            }
319                        } else {
320                            EnforcementDecision::AllowWithWarnings {
321                                warnings: vec![format!("Tool '{}' has failed verification: {}", tool.name, reason)],
322                            }
323                        }
324                    }
325                    VerificationStatus::Pending => {
326                        if self.config.block_pending_verification {
327                            EnforcementDecision::AllowWithWarnings {
328                                warnings: vec![format!("Tool '{}' verification is pending", tool.name)],
329                            }
330                        } else {
331                            EnforcementDecision::Allow
332                        }
333                    }
334                    VerificationStatus::Skipped { reason } => {
335                        EnforcementDecision::AllowWithWarnings {
336                            warnings: vec![format!("Tool '{}' verification was skipped: {}", tool.name, reason)],
337                        }
338                    }
339                }
340            }
341            EnforcementPolicy::Strict => {
342                match &tool.verification_status {
343                    VerificationStatus::Verified { .. } => EnforcementDecision::Allow,
344                    VerificationStatus::Failed { reason, .. } => {
345                        EnforcementDecision::Block {
346                            reason: format!("Tool verification failed: {}", reason),
347                        }
348                    }
349                    VerificationStatus::Pending => {
350                        EnforcementDecision::Block {
351                            reason: "Tool verification is pending - only verified tools are allowed in strict mode".to_string(),
352                        }
353                    }
354                    VerificationStatus::Skipped { reason } => {
355                        EnforcementDecision::Block {
356                            reason: format!("Tool verification was skipped: {} - only verified tools are allowed in strict mode", reason),
357                        }
358                    }
359                }
360            }
361        }
362    }
363
364    /// Increment warning count for a tool and check if escalation is needed
365    fn handle_warning(&self, tool_name: &str, warning: &str) -> bool {
366        let mut count = self
367            .warning_counts
368            .entry(tool_name.to_string())
369            .or_insert(0);
370        *count += 1;
371
372        if *count >= self.config.max_warnings_before_escalation {
373            tracing::warn!(
374                "Tool '{}' has exceeded warning threshold ({} warnings): {}",
375                tool_name,
376                *count,
377                warning
378            );
379            // Reset count after escalation
380            *count = 0;
381            true
382        } else {
383            tracing::warn!(
384                "Tool '{}' warning (count: {}): {}",
385                tool_name,
386                *count,
387                warning
388            );
389            false
390        }
391    }
392
393    /// Use routing engine to determine best model for tool execution
394    #[allow(dead_code)]
395    async fn route_tool_execution(
396        &self,
397        tool: &McpTool,
398        context: &InvocationContext,
399    ) -> Result<Option<String>, ToolInvocationError> {
400        if let Some(ref routing_engine) = self.routing_engine {
401            // Classify the tool task type based on tool description and arguments
402            let task_type = self.classify_tool_task(tool, context);
403
404            // Create routing context
405            let routing_context = RoutingContext::new(
406                context.agent_id,
407                task_type,
408                format!("Tool: {} - {}", tool.name, tool.description),
409            );
410
411            // Create model request
412            let _model_request = ModelRequest::from_task(format!(
413                "Execute tool '{}' with arguments: {}",
414                tool.name, context.arguments
415            ));
416
417            // Get routing decision
418            match routing_engine.route_request(&routing_context).await {
419                Ok(decision) => {
420                    tracing::debug!("Routing decision for tool '{}': {:?}", tool.name, decision);
421                    // Return the routing decision info for logging/metadata
422                    Ok(Some(format!("{:?}", decision)))
423                }
424                Err(e) => {
425                    tracing::warn!("Routing failed for tool '{}': {}", tool.name, e);
426                    Ok(None)
427                }
428            }
429        } else {
430            Ok(None)
431        }
432    }
433
434    /// Classify tool execution into routing task types
435    #[allow(dead_code)]
436    fn classify_tool_task(&self, tool: &McpTool, context: &InvocationContext) -> TaskType {
437        let tool_name_lower = tool.name.to_lowercase();
438        let description_lower = tool.description.to_lowercase();
439
440        // Analyze tool name and description to determine task type
441        if tool_name_lower.contains("code")
442            || description_lower.contains("code")
443            || tool_name_lower.contains("program")
444            || description_lower.contains("program")
445        {
446            TaskType::CodeGeneration
447        } else if tool_name_lower.contains("analyze")
448            || description_lower.contains("analy")
449            || tool_name_lower.contains("inspect")
450            || description_lower.contains("inspect")
451        {
452            TaskType::Analysis
453        } else if tool_name_lower.contains("extract")
454            || description_lower.contains("extract")
455            || tool_name_lower.contains("parse")
456            || description_lower.contains("parse")
457        {
458            TaskType::Extract
459        } else if tool_name_lower.contains("summarize") || description_lower.contains("summar") {
460            TaskType::Summarization
461        } else if tool_name_lower.contains("translate") || description_lower.contains("translat") {
462            TaskType::Translation
463        } else if tool_name_lower.contains("reason")
464            || description_lower.contains("reason")
465            || tool_name_lower.contains("logic")
466            || description_lower.contains("logic")
467        {
468            TaskType::Reasoning
469        } else if tool_name_lower.contains("template") || description_lower.contains("template") {
470            TaskType::Template
471        } else if context.arguments.to_string().len() < 100 {
472            // Simple tools with minimal arguments
473            TaskType::Intent
474        } else {
475            // Default to QA for general tools
476            TaskType::QA
477        }
478    }
479    /// Execute the actual tool via the configured MCP client.
480    /// Returns an error if no MCP client is configured.
481    async fn execute_actual_tool(
482        &self,
483        tool: &McpTool,
484        context: &InvocationContext,
485        _execution_time: Duration,
486    ) -> Result<InvocationResult, ToolInvocationError> {
487        let mcp_client =
488            self.mcp_client
489                .as_ref()
490                .ok_or_else(|| ToolInvocationError::NoMcpClient {
491                    reason: format!(
492                        "No MCP client configured for tool execution of '{}'",
493                        tool.name
494                    ),
495                })?;
496
497        tracing::info!(
498            "Executing tool '{}' for agent {} via MCP (provider: '{}')",
499            tool.name,
500            context.agent_id,
501            tool.provider.identifier
502        );
503
504        let result = mcp_client
505            .invoke_tool(&tool.name, context.arguments.clone(), context.clone())
506            .await
507            .map_err(|e| ToolInvocationError::ExecutionFailed {
508                tool_name: tool.name.clone(),
509                reason: format!("MCP invocation failed: {}", e),
510            })?;
511
512        Ok(result)
513    }
514}
515
516impl Default for DefaultToolInvocationEnforcer {
517    fn default() -> Self {
518        Self::new()
519    }
520}
521
522#[async_trait]
523impl ToolInvocationEnforcer for DefaultToolInvocationEnforcer {
524    async fn check_invocation_allowed(
525        &self,
526        tool: &McpTool,
527        _context: &InvocationContext,
528    ) -> Result<EnforcementDecision, ToolInvocationError> {
529        Ok(self.check_verification_status(tool))
530    }
531
532    async fn execute_tool_with_enforcement(
533        &self,
534        tool: &McpTool,
535        context: InvocationContext,
536    ) -> Result<InvocationResult, ToolInvocationError> {
537        let start_time = Instant::now();
538
539        // Check if invocation is allowed
540        let decision = self.check_invocation_allowed(tool, &context).await?;
541
542        // Mask sensitive arguments before logging
543        let redacted_arguments =
544            mask_sensitive_arguments(&context.arguments, &tool.sensitive_params);
545
546        // Prepare request data for logging
547        let request_data = RequestData {
548            prompt: format!("Tool invocation: {}", tool.name),
549            tool_name: Some(tool.name.clone()),
550            tool_arguments: Some(redacted_arguments),
551            parameters: {
552                let mut params = HashMap::new();
553                params.insert(
554                    "verification_status".to_string(),
555                    serde_json::Value::String(format!("{:?}", tool.verification_status)),
556                );
557                params.insert(
558                    "enforcement_policy".to_string(),
559                    serde_json::Value::String(format!("{:?}", self.config.policy)),
560                );
561                params
562            },
563        };
564
565        match decision {
566            EnforcementDecision::Allow => {
567                let execution_time = start_time.elapsed();
568
569                // Prepare successful response data
570                let response_data = ResponseData {
571                    content: "Tool invocation allowed and executed".to_string(),
572                    tool_result: Some(
573                        serde_json::json!({"status": "success", "message": "Tool invocation allowed"}),
574                    ),
575                    confidence: Some(1.0),
576                    metadata: HashMap::new(),
577                };
578
579                // Log the tool invocation if logger is available
580                if let Some(ref logger) = self.model_logger {
581                    let metadata = {
582                        let mut meta = HashMap::new();
583                        meta.insert(
584                            "tool_provider".to_string(),
585                            tool.provider.identifier.clone(),
586                        );
587                        meta.insert("enforcement_decision".to_string(), "allow".to_string());
588                        meta.insert("agent_id".to_string(), context.agent_id.to_string());
589                        meta
590                    };
591
592                    if let Err(e) = logger
593                        .log_interaction(
594                            context.agent_id,
595                            ModelInteractionType::ToolCall,
596                            &tool.name,
597                            request_data,
598                            response_data,
599                            execution_time,
600                            metadata,
601                            None, // No token usage for tool calls
602                            None,
603                        )
604                        .await
605                    {
606                        tracing::warn!("Failed to log tool invocation: {}", e);
607                    }
608                }
609
610                // Execute the actual tool
611                self.execute_actual_tool(tool, &context, execution_time)
612                    .await
613            }
614            EnforcementDecision::Block { reason } => {
615                let execution_time = start_time.elapsed();
616
617                // Log the blocked invocation if logger is available
618                if let Some(ref logger) = self.model_logger {
619                    let response_data = ResponseData {
620                        content: "Tool invocation blocked".to_string(),
621                        tool_result: Some(
622                            serde_json::json!({"status": "blocked", "reason": &reason}),
623                        ),
624                        confidence: Some(1.0),
625                        metadata: HashMap::new(),
626                    };
627
628                    let metadata = {
629                        let mut meta = HashMap::new();
630                        meta.insert(
631                            "tool_provider".to_string(),
632                            tool.provider.identifier.clone(),
633                        );
634                        meta.insert("enforcement_decision".to_string(), "block".to_string());
635                        meta.insert("agent_id".to_string(), context.agent_id.to_string());
636                        meta
637                    };
638
639                    if let Err(e) = logger
640                        .log_interaction(
641                            context.agent_id,
642                            ModelInteractionType::ToolCall,
643                            &tool.name,
644                            request_data,
645                            response_data,
646                            execution_time,
647                            metadata,
648                            None,
649                            Some(reason.clone()),
650                        )
651                        .await
652                    {
653                        tracing::warn!("Failed to log blocked tool invocation: {}", e);
654                    }
655                }
656
657                Err(ToolInvocationError::InvocationBlocked {
658                    tool_name: tool.name.clone(),
659                    reason,
660                })
661            }
662            EnforcementDecision::AllowWithWarnings { warnings } => {
663                let execution_time = start_time.elapsed();
664
665                // Handle warnings
666                let mut escalated = false;
667                for warning in &warnings {
668                    if self.handle_warning(&tool.name, warning) {
669                        escalated = true;
670                    }
671                }
672
673                // Prepare response data with warnings
674                let response_data = ResponseData {
675                    content: "Tool invocation allowed with warnings".to_string(),
676                    tool_result: Some(serde_json::json!({
677                        "status": "success",
678                        "message": "Tool invocation allowed with warnings",
679                        "warnings": &warnings
680                    })),
681                    confidence: Some(0.8), // Lower confidence due to warnings
682                    metadata: HashMap::new(),
683                };
684
685                // Log the tool invocation with warnings if logger is available
686                if let Some(ref logger) = self.model_logger {
687                    let metadata = {
688                        let mut meta = HashMap::new();
689                        meta.insert(
690                            "tool_provider".to_string(),
691                            tool.provider.identifier.clone(),
692                        );
693                        meta.insert(
694                            "enforcement_decision".to_string(),
695                            "allow_with_warnings".to_string(),
696                        );
697                        meta.insert("agent_id".to_string(), context.agent_id.to_string());
698                        meta.insert("warnings_count".to_string(), warnings.len().to_string());
699                        if escalated {
700                            meta.insert("escalated".to_string(), "true".to_string());
701                        }
702                        meta
703                    };
704
705                    if let Err(e) = logger
706                        .log_interaction(
707                            context.agent_id,
708                            ModelInteractionType::ToolCall,
709                            &tool.name,
710                            request_data,
711                            response_data,
712                            execution_time,
713                            metadata,
714                            None,
715                            None,
716                        )
717                        .await
718                    {
719                        tracing::warn!("Failed to log tool invocation with warnings: {}", e);
720                    }
721                }
722
723                // Execute the actual tool with warnings
724                let mut result = self
725                    .execute_actual_tool(tool, &context, execution_time)
726                    .await?;
727                result.warnings.extend(warnings);
728                if escalated {
729                    result
730                        .metadata
731                        .insert("escalated".to_string(), "true".to_string());
732                }
733                Ok(result)
734            }
735        }
736    }
737
738    fn get_enforcement_config(&self) -> &InvocationEnforcementConfig {
739        &self.config
740    }
741
742    fn update_enforcement_config(&mut self, config: InvocationEnforcementConfig) {
743        self.config = config;
744    }
745}
746
#[cfg(test)]
mod tests {
    use super::*;
    use crate::integrations::mcp::{McpTool, MockMcpClient, ToolProvider, VerificationStatus};
    use crate::integrations::schemapin::VerificationResult;

    /// Builds a minimal `McpTool` named `test_tool` from a fixed test provider,
    /// carrying the given verification status.
    fn create_test_tool(verification_status: VerificationStatus) -> McpTool {
        McpTool {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            schema: serde_json::json!({"type": "object"}),
            provider: ToolProvider {
                identifier: "test.example.com".to_string(),
                name: "Test Provider".to_string(),
                public_key_url: "https://test.example.com/pubkey".to_string(),
                version: Some("1.0.0".to_string()),
            },
            verification_status,
            metadata: None,
            sensitive_params: vec![],
        }
    }

    /// Builds an invocation context targeting `test_tool` with a new `AgentId`,
    /// simple JSON arguments, empty metadata and no agent credential.
    fn create_test_context() -> InvocationContext {
        InvocationContext {
            agent_id: AgentId::new(),
            tool_name: "test_tool".to_string(),
            arguments: serde_json::json!({"test": "value"}),
            timestamp: chrono::Utc::now(),
            metadata: HashMap::new(),
            agent_credential: None,
        }
    }

    /// Builds an enforcer with `config`, backed by a `MockMcpClient::new_success()`
    /// that already has `tool` registered, and no logger or routing engine.
    ///
    /// # Panics
    /// Panics if the mock client rejects the registration. Failing here keeps a
    /// broken fixture from surfacing later as an unrelated tool-execution error.
    async fn create_enforcer_with_mock_mcp(
        config: InvocationEnforcementConfig,
        tool: &McpTool,
    ) -> DefaultToolInvocationEnforcer {
        let mock_client = Arc::new(MockMcpClient::new_success());
        // Register the tool so invoke_tool can find it.
        mock_client
            .discover_tool(tool.clone())
            .await
            .expect("mock MCP client should accept the test tool registration");
        DefaultToolInvocationEnforcer {
            config,
            warning_counts: DashMap::new(),
            model_logger: None,
            routing_engine: None,
            mcp_client: Some(mock_client),
        }
    }

    #[tokio::test]
    async fn test_strict_mode_allows_verified_tools() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Strict,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Verified {
            result: Box::new(VerificationResult {
                success: true,
                message: "Test verification".to_string(),
                schema_hash: Some("hash123".to_string()),
                public_key_url: Some("https://test.example.com/pubkey".to_string()),
                signature: None,
                metadata: None,
                timestamp: Some("2024-01-01T00:00:00Z".to_string()),
            }),
            verified_at: "2024-01-01T00:00:00Z".to_string(),
        });

        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(decision, EnforcementDecision::Allow));
    }

    #[tokio::test]
    async fn test_strict_mode_blocks_unverified_tools() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Strict,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(decision, EnforcementDecision::Block { .. }));
    }

    #[tokio::test]
    async fn test_permissive_mode_allows_with_warnings() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Permissive,
            block_pending_verification: true,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(
            decision,
            EnforcementDecision::AllowWithWarnings { .. }
        ));
    }

    #[tokio::test]
    async fn test_disabled_mode_allows_all_tools() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Disabled,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Failed {
            reason: "Test failure".to_string(),
            failed_at: "2024-01-01T00:00:00Z".to_string(),
        });
        let context = create_test_context();
        let decision = enforcer
            .check_invocation_allowed(&tool, &context)
            .await
            .unwrap();

        assert!(matches!(decision, EnforcementDecision::Allow));
    }

    #[tokio::test]
    async fn test_execute_tool_blocks_unverified_in_strict_mode() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Strict,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let result = enforcer.execute_tool_with_enforcement(&tool, context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolInvocationError::InvocationBlocked { .. }
        ));
    }

    #[tokio::test]
    async fn test_execute_tool_succeeds_with_warnings() {
        let tool = create_test_tool(VerificationStatus::Pending);
        let enforcer = create_enforcer_with_mock_mcp(
            InvocationEnforcementConfig {
                policy: EnforcementPolicy::Permissive,
                block_pending_verification: true,
                ..Default::default()
            },
            &tool,
        )
        .await;

        let context = create_test_context();
        let result = enforcer
            .execute_tool_with_enforcement(&tool, context)
            .await
            .unwrap();

        assert!(result.success);
        assert!(!result.warnings.is_empty());
    }

    #[tokio::test]
    async fn test_execute_tool_fails_without_mcp_client() {
        let enforcer = DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Disabled,
            ..Default::default()
        });

        let tool = create_test_tool(VerificationStatus::Pending);
        let context = create_test_context();
        let result = enforcer.execute_tool_with_enforcement(&tool, context).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            ToolInvocationError::NoMcpClient { .. }
        ));
    }

    /// The config getter reflects the constructor's config, and the setter
    /// replaces it wholesale.
    #[test]
    fn test_enforcement_config_round_trip() {
        let mut enforcer =
            DefaultToolInvocationEnforcer::with_config(InvocationEnforcementConfig {
                policy: EnforcementPolicy::Strict,
                ..Default::default()
            });
        assert!(matches!(
            enforcer.get_enforcement_config().policy,
            EnforcementPolicy::Strict
        ));

        enforcer.update_enforcement_config(InvocationEnforcementConfig {
            policy: EnforcementPolicy::Disabled,
            block_pending_verification: true,
            ..Default::default()
        });
        let updated = enforcer.get_enforcement_config();
        assert!(matches!(updated.policy, EnforcementPolicy::Disabled));
        assert!(updated.block_pending_verification);
    }

    #[test]
    fn test_mask_sensitive_arguments_empty_list() {
        let args = serde_json::json!({"user": "alice", "password": "s3cret"});
        let masked = mask_sensitive_arguments(&args, &[]);
        assert_eq!(masked, args);
    }

    #[test]
    fn test_mask_sensitive_arguments_flat() {
        let args = serde_json::json!({"user": "alice", "password": "s3cret", "token": "abc"});
        let masked =
            mask_sensitive_arguments(&args, &["password".to_string(), "token".to_string()]);
        assert_eq!(masked["user"], "alice");
        assert_eq!(masked["password"], "[REDACTED:password]");
        assert_eq!(masked["token"], "[REDACTED:token]");
    }

    #[test]
    fn test_mask_sensitive_arguments_nested() {
        let args = serde_json::json!({
            "config": {
                "api_key": "sk-123",
                "endpoint": "https://api.example.com"
            },
            "name": "test"
        });
        let masked = mask_sensitive_arguments(&args, &["api_key".to_string()]);
        assert_eq!(masked["config"]["api_key"], "[REDACTED:api_key]");
        assert_eq!(masked["config"]["endpoint"], "https://api.example.com");
        assert_eq!(masked["name"], "test");
    }
}