vtcode_core/tools/registry/
justification.rs1use crate::tools::registry::risk_scorer::RiskLevel;
6use crate::utils::file_utils::{
7 ensure_dir_exists_sync, read_file_with_context_sync, write_file_with_context_sync,
8};
9use anyhow::Result;
10use hashbrown::HashMap;
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolJustification {
17 pub tool_name: String,
19 pub reason: String,
21 pub expected_outcome: Option<String>,
23 pub risk_level: String,
25 pub timestamp: String,
27}
28
29impl ToolJustification {
30 pub fn new(
32 tool_name: impl Into<String>,
33 reason: impl Into<String>,
34 risk_level: &RiskLevel,
35 ) -> Self {
36 Self {
37 tool_name: tool_name.into(),
38 reason: reason.into(),
39 expected_outcome: None,
40 risk_level: format!("{:?}", risk_level),
41 timestamp: chrono::Local::now().to_rfc3339(),
42 }
43 }
44
45 pub fn with_outcome(mut self, outcome: impl Into<String>) -> Self {
47 self.expected_outcome = Some(outcome.into());
48 self
49 }
50
51 pub fn format_for_dialog(&self) -> Vec<String> {
53 let mut lines = vec![];
54
55 lines.push(String::new());
56 lines.push("Agent Reasoning:".to_owned());
57
58 for line in self.reason.lines() {
60 let wrapped = textwrap::fill(&format!(" {line}"), 78);
61 for wrapped_line in wrapped.lines() {
62 lines.push(wrapped_line.to_owned());
63 }
64 }
65
66 if let Some(outcome) = &self.expected_outcome {
67 lines.push(String::new());
68 lines.push("Expected Outcome:".to_owned());
69 let wrapped = textwrap::fill(&format!(" {outcome}"), 78);
70 for wrapped_line in wrapped.lines() {
71 lines.push(wrapped_line.to_owned());
72 }
73 }
74
75 lines.push(String::new());
76 lines.push(format!("Risk Level: {}", self.risk_level));
77
78 lines
79 }
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize, Default)]
84pub struct ApprovalPattern {
85 pub tool_name: String,
87 #[serde(default)]
89 pub display_name: Option<String>,
90 pub approve_count: u32,
92 pub deny_count: u32,
94 pub last_decision: Option<bool>,
96 pub recent_reason: Option<String>,
98}
99
100impl ApprovalPattern {
101 pub fn approval_rate(&self) -> f32 {
103 let total = self.approve_count + self.deny_count;
104 if total == 0 {
105 0.0
106 } else {
107 self.approve_count as f32 / total as f32
108 }
109 }
110
111 pub fn has_high_approval_rate(&self) -> bool {
113 self.approval_count() >= 3 && self.approval_rate() > 0.8
114 }
115
116 pub fn approval_count(&self) -> u32 {
118 self.approve_count
119 }
120
121 pub fn display_name<'a>(&'a self, fallback: &'a str) -> &'a str {
122 self.display_name.as_deref().unwrap_or(fallback)
123 }
124}
125
126fn merge_pattern_from_disk(local: &mut ApprovalPattern, disk: &ApprovalPattern) {
130 local.approve_count = local.approve_count.max(disk.approve_count);
131 local.deny_count = local.deny_count.max(disk.deny_count);
132 if disk.display_name.is_some() {
133 local.display_name = disk.display_name.clone();
134 }
135 if disk.last_decision.is_some() {
136 local.last_decision = disk.last_decision;
137 }
138 if disk.recent_reason.is_some() {
139 local.recent_reason = disk.recent_reason.clone();
140 }
141}
142
143pub struct JustificationManager {
145 cache_dir: PathBuf,
146 patterns: std::sync::Arc<std::sync::Mutex<HashMap<String, ApprovalPattern>>>,
147}
148
149impl JustificationManager {
150 pub fn new(cache_dir: PathBuf) -> Self {
152 let patterns = std::sync::Arc::new(std::sync::Mutex::new(HashMap::new()));
153 let manager = Self {
154 cache_dir,
155 patterns,
156 };
157
158 let _ = manager.load_patterns();
160
161 manager
162 }
163
164 fn load_patterns(&self) -> Result<()> {
171 let patterns_file = self.cache_dir.join("approval_patterns.json");
172 if !patterns_file.exists() {
173 return Ok(());
174 }
175
176 let content = read_file_with_context_sync(&patterns_file, "approval patterns cache")?;
177 let loaded_patterns: HashMap<String, ApprovalPattern> = serde_json::from_str(&content)?;
178
179 let mut patterns = self
180 .patterns
181 .lock()
182 .map_err(|e| anyhow::anyhow!("Failed to lock patterns: {}", e))?;
183
184 for (key, disk) in loaded_patterns {
185 patterns
186 .entry(key)
187 .and_modify(|local| merge_pattern_from_disk(local, &disk))
188 .or_insert(disk);
189 }
190
191 Ok(())
192 }
193
194 pub fn refresh_patterns(&self) -> Result<()> {
196 self.load_patterns()
197 }
198
199 pub fn get_pattern(&self, approval_key: &str) -> Option<ApprovalPattern> {
201 if let Ok(patterns) = self.patterns.lock() {
202 patterns.get(approval_key).cloned()
203 } else {
204 None
205 }
206 }
207
208 pub fn record_decision(
210 &self,
211 approval_key: &str,
212 display_name: Option<&str>,
213 approved: bool,
214 reason: Option<String>,
215 ) {
216 let should_persist = if let Ok(mut patterns) = self.patterns.lock() {
217 let pattern =
218 patterns
219 .entry(approval_key.to_owned())
220 .or_insert_with(|| ApprovalPattern {
221 tool_name: approval_key.to_owned(),
222 display_name: display_name.map(str::to_owned),
223 approve_count: 0,
224 deny_count: 0,
225 last_decision: None,
226 recent_reason: None,
227 });
228
229 if let Some(display_name) = display_name {
230 pattern.display_name = Some(display_name.to_owned());
231 }
232
233 if approved {
234 pattern.approve_count += 1;
235 } else {
236 pattern.deny_count += 1;
237 }
238
239 pattern.last_decision = Some(approved);
240 pattern.recent_reason = reason;
241 true
242 } else {
243 false
244 };
245
246 if should_persist {
248 let _ = self.persist_patterns();
249 }
250 }
251
252 fn persist_patterns(&self) -> Result<()> {
254 ensure_dir_exists_sync(&self.cache_dir)?;
255 let patterns_file = self.cache_dir.join("approval_patterns.json");
256 let patterns = self
257 .patterns
258 .lock()
259 .map_err(|e| anyhow::anyhow!("Failed to lock patterns: {}", e))?;
260 let content = serde_json::to_string_pretty(&*patterns)?;
261 write_file_with_context_sync(&patterns_file, &content, "approval patterns cache")?;
262 Ok(())
263 }
264
265 pub fn get_learning_summary(&self, approval_key: &str) -> Option<String> {
267 let pattern = self.get_pattern(approval_key)?;
268
269 if pattern.approval_count() == 0 {
270 return None;
271 }
272
273 Some(format!(
274 "Approved {} of {} times ({:.0}%)",
275 pattern.approve_count,
276 pattern.approve_count + pattern.deny_count,
277 pattern.approval_rate() * 100.0
278 ))
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory owned by one test: wiped before use, removed on drop.
    ///
    /// The test harness runs tests in parallel threads of the same process. A
    /// directory keyed only on the process id would therefore be shared, and
    /// one test's persisted cache (or its cleanup) could interfere with another.
    struct TestDir(std::path::PathBuf);

    impl TestDir {
        fn new(label: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "vtcode_justification_{label}_{}",
                std::process::id()
            ));
            // Clear leftovers from an earlier run that happened to reuse this pid.
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }

        fn path(&self) -> std::path::PathBuf {
            self.0.clone()
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            // Also runs when an assertion panics, so failures don't leak dirs.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_tool_justification_creation() {
        let just = ToolJustification::new(
            "read_file",
            "Need to understand code structure",
            &RiskLevel::Low,
        )
        .with_outcome("Will analyze the AST to provide better context");

        assert_eq!(just.tool_name, "read_file");
        assert!(just.reason.contains("understand"));
        assert!(just.expected_outcome.is_some());
    }

    #[test]
    fn test_justification_formatting() {
        let just = ToolJustification::new(
            "run_command",
            "Execute build to check for compilation errors",
            &RiskLevel::High,
        )
        .with_outcome("Will produce build output for analysis");

        let formatted = just.format_for_dialog();
        assert!(formatted.iter().any(|l| l.contains("Agent Reasoning")));
        assert!(formatted.iter().any(|l| l.contains("Expected Outcome")));
        assert!(formatted.iter().any(|l| l.contains("Risk Level")));
    }

    #[test]
    fn test_justification_formatting_respects_width() {
        let reason = "word ".repeat(60);
        let just = ToolJustification::new("run_command", reason.trim_end(), &RiskLevel::Low);

        let formatted = just.format_for_dialog();
        assert!(formatted.iter().all(|l| l.chars().count() <= 78));

        // The long reason must have been split across several body lines.
        let body_lines = formatted
            .iter()
            .skip_while(|l| !l.contains("Agent Reasoning"))
            .skip(1)
            .take_while(|l| !l.is_empty())
            .count();
        assert!(body_lines > 1);
    }

    #[test]
    fn test_approval_pattern_calculation() {
        let mut pattern = ApprovalPattern {
            tool_name: "read_file".to_owned(),
            display_name: None,
            approve_count: 9,
            deny_count: 1,
            last_decision: Some(true),
            recent_reason: None,
        };

        assert_eq!(pattern.approval_rate(), 0.9);
        assert!(pattern.has_high_approval_rate());

        pattern.approve_count = 3;
        pattern.deny_count = 7;
        assert!(!pattern.has_high_approval_rate());

        assert_eq!(ApprovalPattern::default().approval_rate(), 0.0);
    }

    #[test]
    fn test_justification_manager_basic() {
        let dir = TestDir::new("basic");
        let manager = JustificationManager::new(dir.path());

        manager.record_decision("read_file", Some("Read File"), true, None);
        manager.record_decision("read_file", Some("Read File"), true, None);
        manager.record_decision("read_file", Some("Read File"), false, None);

        let pattern = manager.get_pattern("read_file").unwrap();
        assert_eq!(pattern.approve_count, 2);
        assert_eq!(pattern.deny_count, 1);
        assert_eq!(pattern.approval_rate(), 2.0 / 3.0);
        assert_eq!(pattern.display_name.as_deref(), Some("Read File"));
    }

    #[test]
    fn test_justification_manager_preserves_new_display_name() {
        let dir = TestDir::new("display_name");
        let manager = JustificationManager::new(dir.path());

        manager.record_decision("shell:key", Some("command `cargo test`"), true, None);
        manager.record_decision(
            "shell:key",
            Some("commands starting with `cargo`"),
            true,
            None,
        );

        let pattern = manager.get_pattern("shell:key").unwrap();
        assert_eq!(
            pattern.display_name.as_deref(),
            Some("commands starting with `cargo`")
        );
    }

    #[test]
    fn test_justification_manager_persists_across_instances() {
        let dir = TestDir::new("persist");
        {
            let manager = JustificationManager::new(dir.path());
            manager.record_decision("shell:key", Some("cargo"), true, Some("build".to_owned()));
            manager.record_decision("shell:key", Some("cargo"), false, None);
        }

        let reloaded = JustificationManager::new(dir.path());
        let pattern = reloaded
            .get_pattern("shell:key")
            .expect("pattern should be loaded from the on-disk cache");
        assert_eq!(pattern.approve_count, 1);
        assert_eq!(pattern.deny_count, 1);
        assert_eq!(pattern.last_decision, Some(false));
        assert_eq!(pattern.recent_reason, None);
        assert_eq!(pattern.display_name.as_deref(), Some("cargo"));
    }

    #[test]
    fn test_learning_summary() {
        let dir = TestDir::new("summary");
        let manager = JustificationManager::new(dir.path());

        assert_eq!(manager.get_learning_summary("read_file"), None);

        manager.record_decision("read_file", None, true, None);
        manager.record_decision("read_file", None, true, None);
        manager.record_decision("read_file", None, false, None);
        assert_eq!(
            manager.get_learning_summary("read_file").as_deref(),
            Some("Approved 2 of 3 times (67%)")
        );

        // Only denials recorded: no summary is produced.
        manager.record_decision("rm", None, false, None);
        assert_eq!(manager.get_learning_summary("rm"), None);
    }
}