Skip to main content

vtcode_core/tools/registry/
justification.rs

//! Tool Justification System
//!
//! Captures agent reasoning before high-risk tool execution to improve the approval UX
//! and to enable learning of approval patterns.
5use crate::tools::registry::risk_scorer::RiskLevel;
6use crate::utils::file_utils::{
7    ensure_dir_exists_sync, read_file_with_context_sync, write_file_with_context_sync,
8};
9use anyhow::Result;
10use hashbrown::HashMap;
11use serde::{Deserialize, Serialize};
12use std::path::PathBuf;
13
14/// Justification provided by the agent for executing a high-risk tool
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct ToolJustification {
17    /// Tool being justified
18    pub tool_name: String,
19    /// Brief explanation from the agent
20    pub reason: String,
21    /// Expected outcome of tool execution
22    pub expected_outcome: Option<String>,
23    /// Risk level that triggered justification
24    pub risk_level: String,
25    /// Timestamp when justification was provided
26    pub timestamp: String,
27}
28
29impl ToolJustification {
30    /// Create a new tool justification
31    pub fn new(
32        tool_name: impl Into<String>,
33        reason: impl Into<String>,
34        risk_level: &RiskLevel,
35    ) -> Self {
36        Self {
37            tool_name: tool_name.into(),
38            reason: reason.into(),
39            expected_outcome: None,
40            risk_level: format!("{:?}", risk_level),
41            timestamp: chrono::Local::now().to_rfc3339(),
42        }
43    }
44
45    /// Add expected outcome to justification
46    pub fn with_outcome(mut self, outcome: impl Into<String>) -> Self {
47        self.expected_outcome = Some(outcome.into());
48        self
49    }
50
51    /// Format justification for display in approval dialog
52    pub fn format_for_dialog(&self) -> Vec<String> {
53        let mut lines = vec![];
54
55        lines.push(String::new());
56        lines.push("Agent Reasoning:".to_owned());
57
58        // Wrap reason text if needed - iterate directly without collecting
59        for line in self.reason.lines() {
60            let wrapped = textwrap::fill(&format!("  {line}"), 78);
61            for wrapped_line in wrapped.lines() {
62                lines.push(wrapped_line.to_owned());
63            }
64        }
65
66        if let Some(outcome) = &self.expected_outcome {
67            lines.push(String::new());
68            lines.push("Expected Outcome:".to_owned());
69            let wrapped = textwrap::fill(&format!("  {outcome}"), 78);
70            for wrapped_line in wrapped.lines() {
71                lines.push(wrapped_line.to_owned());
72            }
73        }
74
75        lines.push(String::new());
76        lines.push(format!("Risk Level: {}", self.risk_level));
77
78        lines
79    }
80}
81
82/// Tracks approval patterns to learn from user decisions
83#[derive(Debug, Clone, Serialize, Deserialize, Default)]
84pub struct ApprovalPattern {
85    /// Stable approval key used for lookup and persistence
86    pub tool_name: String,
87    /// Human-readable label for prompts and summaries
88    #[serde(default)]
89    pub display_name: Option<String>,
90    /// Number of times user approved
91    pub approve_count: u32,
92    /// Number of times user denied
93    pub deny_count: u32,
94    /// Last decision (true = approve, false = deny)
95    pub last_decision: Option<bool>,
96    /// Most recent reason (if available)
97    pub recent_reason: Option<String>,
98}
99
100impl ApprovalPattern {
101    /// Compute approval rate (0.0 to 1.0)
102    pub fn approval_rate(&self) -> f32 {
103        let total = self.approve_count + self.deny_count;
104        if total == 0 {
105            0.0
106        } else {
107            self.approve_count as f32 / total as f32
108        }
109    }
110
111    /// Check if this tool has high approval rate (>80%)
112    pub fn has_high_approval_rate(&self) -> bool {
113        self.approval_count() >= 3 && self.approval_rate() > 0.8
114    }
115
116    /// Return approval count
117    pub fn approval_count(&self) -> u32 {
118        self.approve_count
119    }
120
121    pub fn display_name<'a>(&'a self, fallback: &'a str) -> &'a str {
122        self.display_name.as_deref().unwrap_or(fallback)
123    }
124}
125
126/// Merge an on-disk pattern into the in-memory entry by taking the max of
127/// counters and preferring any non-`None` metadata from disk. Conservative:
128/// undercount is safer (more prompts) than overcount (fewer prompts → risk).
129fn merge_pattern_from_disk(local: &mut ApprovalPattern, disk: &ApprovalPattern) {
130    local.approve_count = local.approve_count.max(disk.approve_count);
131    local.deny_count = local.deny_count.max(disk.deny_count);
132    if disk.display_name.is_some() {
133        local.display_name = disk.display_name.clone();
134    }
135    if disk.last_decision.is_some() {
136        local.last_decision = disk.last_decision;
137    }
138    if disk.recent_reason.is_some() {
139        local.recent_reason = disk.recent_reason.clone();
140    }
141}
142
143/// Manager for approval pattern learning and justifications
144pub struct JustificationManager {
145    cache_dir: PathBuf,
146    patterns: std::sync::Arc<std::sync::Mutex<HashMap<String, ApprovalPattern>>>,
147}
148
149impl JustificationManager {
150    /// Create a new justification manager
151    pub fn new(cache_dir: PathBuf) -> Self {
152        let patterns = std::sync::Arc::new(std::sync::Mutex::new(HashMap::new()));
153        let manager = Self {
154            cache_dir,
155            patterns,
156        };
157
158        // Try to load existing patterns
159        let _ = manager.load_patterns();
160
161        manager
162    }
163
164    /// Load approval patterns from disk and merge into the in-memory map.
165    ///
166    /// Merging (rather than replacing) keeps in-memory increments that have not
167    /// yet been flushed to disk — important when refreshing right before an
168    /// auto-approval check while a concurrent vtcode session may have written
169    /// newer counts to the same file.
170    fn load_patterns(&self) -> Result<()> {
171        let patterns_file = self.cache_dir.join("approval_patterns.json");
172        if !patterns_file.exists() {
173            return Ok(());
174        }
175
176        let content = read_file_with_context_sync(&patterns_file, "approval patterns cache")?;
177        let loaded_patterns: HashMap<String, ApprovalPattern> = serde_json::from_str(&content)?;
178
179        let mut patterns = self
180            .patterns
181            .lock()
182            .map_err(|e| anyhow::anyhow!("Failed to lock patterns: {}", e))?;
183
184        for (key, disk) in loaded_patterns {
185            patterns
186                .entry(key)
187                .and_modify(|local| merge_pattern_from_disk(local, &disk))
188                .or_insert(disk);
189        }
190
191        Ok(())
192    }
193
194    /// Re-read patterns from disk, merging with any in-memory state.
195    pub fn refresh_patterns(&self) -> Result<()> {
196        self.load_patterns()
197    }
198
199    /// Get approval pattern for a key
200    pub fn get_pattern(&self, approval_key: &str) -> Option<ApprovalPattern> {
201        if let Ok(patterns) = self.patterns.lock() {
202            patterns.get(approval_key).cloned()
203        } else {
204            None
205        }
206    }
207
208    /// Record user approval decision
209    pub fn record_decision(
210        &self,
211        approval_key: &str,
212        display_name: Option<&str>,
213        approved: bool,
214        reason: Option<String>,
215    ) {
216        let should_persist = if let Ok(mut patterns) = self.patterns.lock() {
217            let pattern =
218                patterns
219                    .entry(approval_key.to_owned())
220                    .or_insert_with(|| ApprovalPattern {
221                        tool_name: approval_key.to_owned(),
222                        display_name: display_name.map(str::to_owned),
223                        approve_count: 0,
224                        deny_count: 0,
225                        last_decision: None,
226                        recent_reason: None,
227                    });
228
229            if let Some(display_name) = display_name {
230                pattern.display_name = Some(display_name.to_owned());
231            }
232
233            if approved {
234                pattern.approve_count += 1;
235            } else {
236                pattern.deny_count += 1;
237            }
238
239            pattern.last_decision = Some(approved);
240            pattern.recent_reason = reason;
241            true
242        } else {
243            false
244        };
245
246        // Persist to disk after releasing the lock.
247        if should_persist {
248            let _ = self.persist_patterns();
249        }
250    }
251
252    /// Persist patterns to disk
253    fn persist_patterns(&self) -> Result<()> {
254        ensure_dir_exists_sync(&self.cache_dir)?;
255        let patterns_file = self.cache_dir.join("approval_patterns.json");
256        let patterns = self
257            .patterns
258            .lock()
259            .map_err(|e| anyhow::anyhow!("Failed to lock patterns: {}", e))?;
260        let content = serde_json::to_string_pretty(&*patterns)?;
261        write_file_with_context_sync(&patterns_file, &content, "approval patterns cache")?;
262        Ok(())
263    }
264
265    /// Get learning summary for a key
266    pub fn get_learning_summary(&self, approval_key: &str) -> Option<String> {
267        let pattern = self.get_pattern(approval_key)?;
268
269        if pattern.approval_count() == 0 {
270            return None;
271        }
272
273        Some(format!(
274            "Approved {} of {} times ({:.0}%)",
275            pattern.approve_count,
276            pattern.approve_count + pattern.deny_count,
277            pattern.approval_rate() * 100.0
278        ))
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a scratch directory path that belongs to exactly one test.
    ///
    /// The test name is part of the path because tests in this module can run
    /// in parallel inside one process; a path derived from the PID alone would
    /// make them share (and delete) the same `approval_patterns.json`. Any
    /// leftover directory from an earlier run is removed first so stale counts
    /// are never loaded.
    fn scratch_dir(test_name: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "vtcode_test_{}_{}",
            test_name,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    /// Builder calls populate name, reason and expected outcome.
    #[test]
    fn test_tool_justification_creation() {
        let just = ToolJustification::new(
            "read_file",
            "Need to understand code structure",
            &RiskLevel::Low,
        )
        .with_outcome("Will analyze the AST to provide better context");

        assert_eq!(just.tool_name, "read_file");
        assert!(just.reason.contains("understand"));
        assert!(just.expected_outcome.is_some());
    }

    /// The dialog output contains all three section markers.
    #[test]
    fn test_justification_formatting() {
        let just = ToolJustification::new(
            "run_command",
            "Execute build to check for compilation errors",
            &RiskLevel::High,
        )
        .with_outcome("Will produce build output for analysis");

        let formatted = just.format_for_dialog();
        assert!(formatted.iter().any(|l| l.contains("Agent Reasoning")));
        assert!(formatted.iter().any(|l| l.contains("Expected Outcome")));
        assert!(formatted.iter().any(|l| l.contains("Risk Level")));
    }

    /// Rate and high-approval threshold follow the counters.
    #[test]
    fn test_approval_pattern_calculation() {
        let mut pattern = ApprovalPattern {
            tool_name: "read_file".to_owned(),
            display_name: None,
            approve_count: 9,
            deny_count: 1,
            last_decision: Some(true),
            recent_reason: None,
        };

        assert_eq!(pattern.approval_rate(), 0.9);
        assert!(pattern.has_high_approval_rate());

        pattern.approve_count = 3;
        pattern.deny_count = 7;
        assert!(!pattern.has_high_approval_rate()); // < 0.8 rate
    }

    /// Recorded decisions are tallied per key.
    #[test]
    fn test_justification_manager_basic() {
        let temp_dir = scratch_dir("manager_basic");
        let manager = JustificationManager::new(temp_dir.clone());

        manager.record_decision("read_file", Some("Read File"), true, None);
        manager.record_decision("read_file", Some("Read File"), true, None);
        manager.record_decision("read_file", Some("Read File"), false, None);

        let pattern = manager.get_pattern("read_file").unwrap();
        assert_eq!(pattern.approve_count, 2);
        assert_eq!(pattern.deny_count, 1);
        assert_eq!(pattern.approval_rate(), 2.0 / 3.0);
        assert_eq!(pattern.display_name.as_deref(), Some("Read File"));

        // Cleanup
        let _ = std::fs::remove_dir_all(&temp_dir);
    }

    /// A later decision's display name replaces the earlier one.
    #[test]
    fn test_justification_manager_preserves_new_display_name() {
        let temp_dir = scratch_dir("manager_display_name");
        let manager = JustificationManager::new(temp_dir.clone());

        manager.record_decision("shell:key", Some("command `cargo test`"), true, None);
        manager.record_decision(
            "shell:key",
            Some("commands starting with `cargo`"),
            true,
            None,
        );

        let pattern = manager.get_pattern("shell:key").unwrap();
        assert_eq!(
            pattern.display_name.as_deref(),
            Some("commands starting with `cargo`")
        );

        // Cleanup
        let _ = std::fs::remove_dir_all(&temp_dir);
    }
}
375}