vtcode_core/tools/registry/
risk_scorer.rs1use crate::config::constants::tools;
4use crate::utils::ansi_codes::{FG_GREEN, FG_MAGENTA, FG_RED, FG_YELLOW};
5use serde::{Deserialize, Serialize};
6
7#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
9#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
10pub enum RiskLevel {
11 Low,
13
14 Medium,
16
17 High,
19
20 Critical,
22}
23
24impl RiskLevel {
25 pub fn as_str(self) -> &'static str {
26 match self {
27 Self::Low => "low",
28 Self::Medium => "medium",
29 Self::High => "high",
30 Self::Critical => "critical",
31 }
32 }
33
34 pub fn color_code(self) -> &'static str {
35 match self {
36 Self::Low => FG_GREEN,
37 Self::Medium => FG_YELLOW,
38 Self::High => FG_RED,
39 Self::Critical => FG_MAGENTA,
40 }
41 }
42}
43
44impl std::fmt::Display for RiskLevel {
45 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46 f.write_str(self.as_str())
47 }
48}
49
/// Where a tool comes from; each origin scales the tool's risk score by a
/// fixed factor (see [`ToolSource::risk_multiplier`]).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolSource {
    /// Tool provided by this crate.
    Internal,

    /// Tool reached through MCP.
    Mcp,

    /// Tool reached through ACP.
    Acp,

    /// Tool from any other external origin.
    External,
}

impl ToolSource {
    /// Factor applied to a tool's score based on its origin:
    /// `External` 2.0, `Mcp` 1.5, `Acp` 1.2, `Internal` 1.0 (no change).
    pub fn risk_multiplier(self) -> f32 {
        match self {
            ToolSource::External => 2.0,
            ToolSource::Mcp => 1.5,
            ToolSource::Acp => 1.2,
            ToolSource::Internal => 1.0,
        }
    }
}
78
/// How much the current workspace is trusted.
///
/// Declared from least to most trusted; the derived `Ord` follows that order,
/// so `WorkspaceTrust::Untrusted < WorkspaceTrust::FullAuto`.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub enum WorkspaceTrust {
    Untrusted,
    Partial,
    Trusted,
    FullAuto,
}

impl WorkspaceTrust {
    /// Factor applied to a risk score for this trust level: higher trust
    /// yields a smaller factor (`FullAuto` 0.3 up to `Untrusted` 1.0).
    pub fn risk_reduction(self) -> f32 {
        match self {
            WorkspaceTrust::FullAuto => 0.3,
            WorkspaceTrust::Trusted => 0.6,
            WorkspaceTrust::Partial => 0.8,
            WorkspaceTrust::Untrusted => 1.0,
        }
    }
}
99
100#[derive(Debug, Clone)]
102pub struct ToolRiskContext {
103 pub tool_name: String,
105
106 pub source: ToolSource,
108
109 pub workspace_trust: WorkspaceTrust,
111
112 pub recent_approvals: usize,
114
115 pub command_args: Vec<String>,
117
118 pub is_write: bool,
120
121 pub is_destructive: bool,
123
124 pub accesses_network: bool,
126}
127
128impl ToolRiskContext {
129 pub fn new(tool_name: String, source: ToolSource, workspace_trust: WorkspaceTrust) -> Self {
131 Self {
132 tool_name,
133 source,
134 workspace_trust,
135 recent_approvals: 0,
136 command_args: Vec::new(),
137 is_write: false,
138 is_destructive: false,
139 accesses_network: false,
140 }
141 }
142
143 pub fn with_args(mut self, args: Vec<String>) -> Self {
145 self.command_args = args;
146 self
147 }
148
149 pub fn as_write(mut self) -> Self {
151 self.is_write = true;
152 self
153 }
154
155 pub fn as_destructive(mut self) -> Self {
157 self.is_destructive = true;
158 self
159 }
160
161 pub fn accesses_network(mut self) -> Self {
163 self.accesses_network = true;
164 self
165 }
166}
167
/// Stateless namespace for the tool risk-scoring functions.
///
/// It carries no data, so it is cheap to copy and can be default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct ToolRiskScorer;
170
171impl ToolRiskScorer {
172 pub fn calculate_risk(ctx: &ToolRiskContext) -> RiskLevel {
174 let mut base_score = Self::base_risk_for_tool(&ctx.tool_name);
175
176 if ctx.is_destructive {
178 base_score += 30;
179 }
180 if ctx.is_write {
181 base_score += 15;
182 }
183 if ctx.accesses_network {
184 base_score += 10;
185 }
186
187 base_score = (base_score as f32 * ctx.source.risk_multiplier()) as u32;
189
190 base_score = (base_score as f32 * ctx.workspace_trust.risk_reduction()) as u32;
192
193 let approval_reduction = ctx.recent_approvals.min(3) as u32 * 5;
195 base_score = base_score.saturating_sub(approval_reduction);
196
197 match base_score {
199 0..=25 => RiskLevel::Low,
200 26..=50 => RiskLevel::Medium,
201 51..=75 => RiskLevel::High,
202 _ => RiskLevel::Critical,
203 }
204 }
205
206 pub fn requires_justification(risk: RiskLevel, threshold: RiskLevel) -> bool {
208 risk >= threshold
209 }
210
211 fn base_risk_for_tool(tool_name: &str) -> u32 {
213 match tool_name {
214 tools::READ_FILE
216 | tools::UNIFIED_SEARCH
217 | tools::MCP_SEARCH_TOOLS
218 | tools::MCP_GET_TOOL_DETAILS
219 | tools::MCP_LIST_SERVERS => 0,
220
221 "file_info" | "status" | "logs" => 5,
223
224 tools::WRITE_FILE | tools::EDIT_FILE | tools::CREATE_FILE => 20,
226 tools::MCP_CONNECT_SERVER | tools::MCP_DISCONNECT_SERVER => 20,
227
228 tools::APPLY_PATCH | tools::DELETE_FILE => 25,
230
231 tools::CREATE_PTY_SESSION
233 | tools::RUN_PTY_CMD
234 | tools::SEND_PTY_INPUT
235 | tools::UNIFIED_EXEC => 35,
236
237 tools::WEB_SEARCH | tools::FETCH_URL | "unified_search:web" => 40,
239
240 _ if tool_name.starts_with("mcp_") => 30,
242
243 _ => 35,
245 }
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a context from a borrowed tool name.
    fn context(name: &str, source: ToolSource, trust: WorkspaceTrust) -> ToolRiskContext {
        ToolRiskContext::new(name.to_string(), source, trust)
    }

    #[test]
    fn test_risk_level_ordering() {
        let ascending = [
            RiskLevel::Low,
            RiskLevel::Medium,
            RiskLevel::High,
            RiskLevel::Critical,
        ];
        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }

    #[test]
    fn test_risk_calculation() {
        let read_only = context(tools::READ_FILE, ToolSource::Internal, WorkspaceTrust::Trusted);
        assert_eq!(ToolRiskScorer::calculate_risk(&read_only), RiskLevel::Low);

        let external_write =
            context(tools::WRITE_FILE, ToolSource::External, WorkspaceTrust::Untrusted)
                .as_write();
        assert!(ToolRiskScorer::calculate_risk(&external_write) >= RiskLevel::High);
    }

    #[test]
    fn test_approval_history_reduces_risk() {
        let mut pty = context(
            tools::RUN_PTY_CMD,
            ToolSource::Internal,
            WorkspaceTrust::Untrusted,
        );
        let before = ToolRiskScorer::calculate_risk(&pty);

        pty.recent_approvals = 3;
        let after = ToolRiskScorer::calculate_risk(&pty);

        assert!(after <= before);
    }

    #[test]
    fn test_source_multiplier() {
        let risk_from = |source| {
            ToolRiskScorer::calculate_risk(&context("mcp_tool", source, WorkspaceTrust::Trusted))
        };

        let internal = risk_from(ToolSource::Internal);
        let mcp = risk_from(ToolSource::Mcp);

        assert!(mcp > internal || mcp == RiskLevel::Critical);
    }

    #[test]
    fn test_requires_justification() {
        assert!(ToolRiskScorer::requires_justification(RiskLevel::High, RiskLevel::High));
        assert!(!ToolRiskScorer::requires_justification(RiskLevel::Medium, RiskLevel::High));
    }
}