Skip to main content

vtcode_core/tools/registry/
risk_scorer.rs

1//! Risk scoring system for tool execution
2
3use crate::config::constants::tools;
4use crate::utils::ansi_codes::{FG_GREEN, FG_MAGENTA, FG_RED, FG_YELLOW};
5use serde::{Deserialize, Serialize};
6
7/// Risk level classification for tools
8#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
10pub enum RiskLevel {
11    /// Read-only operations with no side effects
12    Low,
13
14    /// Operations that create/modify data but within trusted boundaries
15    Medium,
16
17    /// Operations with potentially destructive effects or external access
18    High,
19
20    /// Operations that could compromise system security
21    Critical,
22}
23
24impl RiskLevel {
25    pub fn as_str(self) -> &'static str {
26        match self {
27            Self::Low => "low",
28            Self::Medium => "medium",
29            Self::High => "high",
30            Self::Critical => "critical",
31        }
32    }
33
34    pub fn color_code(self) -> &'static str {
35        match self {
36            Self::Low => FG_GREEN,
37            Self::Medium => FG_YELLOW,
38            Self::High => FG_RED,
39            Self::Critical => FG_MAGENTA,
40        }
41    }
42}
43
44impl std::fmt::Display for RiskLevel {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.write_str(self.as_str())
47    }
48}
49
/// Origin of a tool implementation (built-in, MCP, ACP, or other external).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolSource {
    /// Tools built into this application.
    Internal,

    /// Tools exposed through the Model Context Protocol (external).
    Mcp,

    /// Tools provided through the Agent Client Protocol (IDE integration).
    Acp,

    /// Any other external origin.
    External,
}
65
66impl ToolSource {
67    /// Get the risk multiplier for this source
68    /// MCP/external tools are considered higher risk
69    pub fn risk_multiplier(self) -> f32 {
70        match self {
71            Self::Internal => 1.0,
72            Self::Mcp => 1.5,
73            Self::Acp => 1.2,
74            Self::External => 2.0,
75        }
76    }
77}
78
/// How much the current workspace is trusted.
///
/// Variants run from least to most trusted, and the derived `Ord` follows that
/// declaration order: `Untrusted < Partial < Trusted < FullAuto`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum WorkspaceTrust {
    /// No trust granted.
    Untrusted,
    /// Some, but not full, trust.
    Partial,
    /// Trusted workspace.
    Trusted,
    /// Full-auto mode; the highest trust level.
    FullAuto,
}
87
88impl WorkspaceTrust {
89    /// Get the risk reduction multiplier for trusted workspaces
90    pub fn risk_reduction(self) -> f32 {
91        match self {
92            Self::Untrusted => 1.0,
93            Self::Partial => 0.8,
94            Self::Trusted => 0.6,
95            Self::FullAuto => 0.3,
96        }
97    }
98}
99
100/// Context for risk assessment
101#[derive(Debug, Clone)]
102pub struct ToolRiskContext {
103    /// Tool name
104    pub tool_name: String,
105
106    /// Source of the tool
107    pub source: ToolSource,
108
109    /// Workspace trust level
110    pub workspace_trust: WorkspaceTrust,
111
112    /// Number of times this tool has been approved recently
113    pub recent_approvals: usize,
114
115    /// Command arguments (if applicable)
116    pub command_args: Vec<String>,
117
118    /// Whether this is a write operation
119    pub is_write: bool,
120
121    /// Whether this is a potentially destructive operation
122    pub is_destructive: bool,
123
124    /// Whether this accesses external network
125    pub accesses_network: bool,
126}
127
128impl ToolRiskContext {
129    /// Create a new risk context
130    pub fn new(tool_name: String, source: ToolSource, workspace_trust: WorkspaceTrust) -> Self {
131        Self {
132            tool_name,
133            source,
134            workspace_trust,
135            recent_approvals: 0,
136            command_args: Vec::new(),
137            is_write: false,
138            is_destructive: false,
139            accesses_network: false,
140        }
141    }
142
143    /// Set command arguments
144    pub fn with_args(mut self, args: Vec<String>) -> Self {
145        self.command_args = args;
146        self
147    }
148
149    /// Mark as write operation
150    pub fn as_write(mut self) -> Self {
151        self.is_write = true;
152        self
153    }
154
155    /// Mark as potentially destructive
156    pub fn as_destructive(mut self) -> Self {
157        self.is_destructive = true;
158        self
159    }
160
161    /// Mark as network-accessing
162    pub fn accesses_network(mut self) -> Self {
163        self.accesses_network = true;
164        self
165    }
166}
167
/// Risk scorer for tool execution.
///
/// The type carries no fields; scoring is done through associated functions
/// such as `ToolRiskScorer::calculate_risk`. The derives give this public type
/// the common traits (`Debug`, `Clone`, `Copy`, `Default`) callers expect.
#[derive(Debug, Clone, Copy, Default)]
pub struct ToolRiskScorer;
170
171impl ToolRiskScorer {
172    /// Calculate risk level for a tool
173    pub fn calculate_risk(ctx: &ToolRiskContext) -> RiskLevel {
174        let mut base_score = Self::base_risk_for_tool(&ctx.tool_name);
175
176        // Apply modifiers
177        if ctx.is_destructive {
178            base_score += 30;
179        }
180        if ctx.is_write {
181            base_score += 15;
182        }
183        if ctx.accesses_network {
184            base_score += 10;
185        }
186
187        // Apply source multiplier
188        base_score = (base_score as f32 * ctx.source.risk_multiplier()) as u32;
189
190        // Apply trust reduction
191        base_score = (base_score as f32 * ctx.workspace_trust.risk_reduction()) as u32;
192
193        // Approval history reduces risk (diminishing returns)
194        let approval_reduction = ctx.recent_approvals.min(3) as u32 * 5;
195        base_score = base_score.saturating_sub(approval_reduction);
196
197        // Convert to risk level
198        match base_score {
199            0..=25 => RiskLevel::Low,
200            26..=50 => RiskLevel::Medium,
201            51..=75 => RiskLevel::High,
202            _ => RiskLevel::Critical,
203        }
204    }
205
206    /// Determine if justification is required
207    pub fn requires_justification(risk: RiskLevel, threshold: RiskLevel) -> bool {
208        risk >= threshold
209    }
210
211    /// Base risk score for common tools
212    fn base_risk_for_tool(tool_name: &str) -> u32 {
213        match tool_name {
214            // Read-only tools (base: 0)
215            tools::READ_FILE
216            | tools::UNIFIED_SEARCH
217            | tools::MCP_SEARCH_TOOLS
218            | tools::MCP_GET_TOOL_DETAILS
219            | tools::MCP_LIST_SERVERS => 0,
220
221            // Safe metadata tools (base: 5)
222            "file_info" | "status" | "logs" => 5,
223
224            // Write tools (base: 20)
225            tools::WRITE_FILE | tools::EDIT_FILE | tools::CREATE_FILE => 20,
226            tools::MCP_CONNECT_SERVER | tools::MCP_DISCONNECT_SERVER => 20,
227
228            // Potentially risky write operations (base: 25)
229            tools::APPLY_PATCH | tools::DELETE_FILE => 25,
230
231            // PTY/interactive commands (base: 35)
232            tools::CREATE_PTY_SESSION
233            | tools::RUN_PTY_CMD
234            | tools::SEND_PTY_INPUT
235            | tools::UNIFIED_EXEC => 35,
236
237            // Network operations (base: 40)
238            tools::WEB_SEARCH | tools::FETCH_URL | "unified_search:web" => 40,
239
240            // MCP tools (default to medium risk)
241            _ if tool_name.starts_with("mcp_") => 30,
242
243            // Unknown tools default to medium-high risk
244            _ => 35,
245        }
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// A name matching no known tool and lacking the `mcp_` prefix, so it
    /// falls through to the unknown-tool base score of 35.
    const UNKNOWN_TOOL: &str = "totally_unknown_tool_for_risk_tests";

    /// Context for `tool` from an internal source in an untrusted workspace,
    /// where both the source multiplier and the trust reduction are 1.0.
    fn internal_untrusted(tool: &str) -> ToolRiskContext {
        ToolRiskContext::new(
            tool.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Untrusted,
        )
    }

    #[test]
    fn test_risk_level_ordering() {
        assert!(RiskLevel::Low < RiskLevel::Medium);
        assert!(RiskLevel::Medium < RiskLevel::High);
        assert!(RiskLevel::High < RiskLevel::Critical);
    }

    #[test]
    fn test_display_matches_as_str() {
        assert_eq!(RiskLevel::Low.to_string(), "low");
        assert_eq!(RiskLevel::Critical.to_string(), RiskLevel::Critical.as_str());
    }

    #[test]
    fn test_risk_calculation() {
        // Read-only operation in trusted workspace
        let ctx = ToolRiskContext::new(
            tools::READ_FILE.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Trusted,
        );
        let risk = ToolRiskScorer::calculate_risk(&ctx);
        assert_eq!(risk, RiskLevel::Low);

        // Write operation in untrusted workspace
        let ctx = ToolRiskContext::new(
            tools::WRITE_FILE.to_string(),
            ToolSource::External,
            WorkspaceTrust::Untrusted,
        )
        .as_write();
        let risk = ToolRiskScorer::calculate_risk(&ctx);
        assert!(risk >= RiskLevel::High);
    }

    #[test]
    fn test_approval_history_reduces_risk() {
        let mut ctx = ToolRiskContext::new(
            tools::RUN_PTY_CMD.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Untrusted,
        );

        let risk_before = ToolRiskScorer::calculate_risk(&ctx);

        ctx.recent_approvals = 3;
        let risk_after = ToolRiskScorer::calculate_risk(&ctx);

        // Strict: 35 -> 20 crosses from Medium to Low. A `<=` check would also
        // pass if approvals were ignored entirely.
        assert!(risk_after < risk_before);
    }

    #[test]
    fn test_approval_credit_is_capped() {
        // 35 + 30 = 65 -> High.
        let mut ctx = internal_untrusted(UNKNOWN_TOOL).as_destructive();
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), RiskLevel::High);

        // Three approvals remove 15 points: 50 -> Medium.
        ctx.recent_approvals = 3;
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), RiskLevel::Medium);

        // Further approvals earn no extra credit.
        ctx.recent_approvals = 100;
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), RiskLevel::Medium);
    }

    #[test]
    fn test_score_thresholds() {
        // "status" base is 5.
        assert_eq!(
            ToolRiskScorer::calculate_risk(&internal_untrusted("status")),
            RiskLevel::Low
        );
        // Unknown tool base is 35.
        assert_eq!(
            ToolRiskScorer::calculate_risk(&internal_untrusted(UNKNOWN_TOOL)),
            RiskLevel::Medium
        );
        // 35 + 30 + 10 = 75 is the top of the High band.
        let ctx = internal_untrusted(UNKNOWN_TOOL)
            .as_destructive()
            .accesses_network();
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), RiskLevel::High);
        // 35 + 30 + 15 + 10 = 90 -> Critical.
        let ctx = internal_untrusted(UNKNOWN_TOOL)
            .as_destructive()
            .as_write()
            .accesses_network();
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), RiskLevel::Critical);
    }

    #[test]
    fn test_full_auto_trust_scales_down() {
        // 90 * 0.3 = 27 -> Medium, versus Critical when untrusted.
        let ctx = ToolRiskContext::new(
            UNKNOWN_TOOL.to_string(),
            ToolSource::Internal,
            WorkspaceTrust::FullAuto,
        )
        .as_destructive()
        .as_write()
        .accesses_network();
        assert_eq!(ToolRiskScorer::calculate_risk(&ctx), RiskLevel::Medium);
    }

    #[test]
    fn test_source_multiplier() {
        let base = ToolRiskContext::new(
            "mcp_tool".to_string(),
            ToolSource::Internal,
            WorkspaceTrust::Trusted,
        );
        let base_risk = ToolRiskScorer::calculate_risk(&base);

        let mcp = ToolRiskContext::new(
            "mcp_tool".to_string(),
            ToolSource::Mcp,
            WorkspaceTrust::Trusted,
        );
        let mcp_risk = ToolRiskScorer::calculate_risk(&mcp);

        // MCP tool should have higher risk
        assert!(mcp_risk > base_risk || mcp_risk == RiskLevel::Critical);
    }

    #[test]
    fn test_requires_justification() {
        assert!(ToolRiskScorer::requires_justification(
            RiskLevel::High,
            RiskLevel::High
        ));
        assert!(!ToolRiskScorer::requires_justification(
            RiskLevel::Medium,
            RiskLevel::High
        ));
    }
}