Skip to main content

vtcode_core/tools/
improvements_registry_ext.rs

1//! Real integration with ToolRegistry
2//!
3//! Extends ToolRegistry with improvement capabilities:
4//! - Tool effectiveness tracking
5//! - Intelligent tool selection based on patterns
6//! - Result caching and optimization (migrated to UnifiedCache)
7//! - Observability integration
8
9use crate::cache::{CacheKey, DEFAULT_CACHE_TTL, EvictionPolicy, UnifiedCache};
10use crate::tools::{
11    improvements_errors::ObservabilityContext,
12    pattern_engine::{ExecutionEvent, PatternEngine},
13};
14use std::sync::Arc;
15
/// Key identifying one cached tool result.
///
/// Wraps the string form of a tool invocation; the string itself is what the
/// result cache stores entries under.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ToolResultKey(String);
19
20impl CacheKey for ToolResultKey {
21    fn to_cache_key(&self) -> String {
22        self.0.clone()
23    }
24}
25
/// Aggregated effectiveness statistics for a single tool.
#[derive(Debug, Clone)]
pub struct ToolMetrics {
    /// Name of the tool these statistics describe.
    pub name: String,
    /// Number of recorded executions.
    pub total_calls: usize,
    /// Number of recorded executions that succeeded.
    pub successful_calls: usize,
    /// Sum of all recorded execution durations, in milliseconds.
    pub total_duration_ms: u64,
    /// Mean quality score across recorded executions.
    pub avg_quality: f32,
}

impl ToolMetrics {
    /// Returns `successful_calls / total_calls`, or `0.0` when no calls have
    /// been recorded (avoids a 0/0 NaN).
    pub fn success_rate(&self) -> f32 {
        match self.total_calls {
            0 => 0.0,
            calls => self.successful_calls as f32 / calls as f32,
        }
    }

    /// Returns the mean duration per call in milliseconds (integer division, so
    /// truncated), or `0` when no calls have been recorded.
    pub fn avg_duration_ms(&self) -> u64 {
        // `checked_div` yields `None` only for a zero divisor, i.e. no calls yet.
        self.total_duration_ms
            .checked_div(self.total_calls as u64)
            .unwrap_or(0)
    }
}
53
54/// ToolRegistry improvement extension (migrated to UnifiedCache)
55pub struct ToolRegistryImprovement {
56    pattern_engine: Arc<PatternEngine>,
57    tool_metrics: Arc<parking_lot::RwLock<hashbrown::HashMap<String, ToolMetrics>>>,
58    result_cache: Arc<parking_lot::RwLock<UnifiedCache<ToolResultKey, String>>>,
59    obs_context: Arc<ObservabilityContext>,
60}
61
62impl ToolRegistryImprovement {
63    /// Create new registry extension (now using UnifiedCache)
64    pub fn new(obs_context: Arc<ObservabilityContext>) -> Self {
65        Self {
66            pattern_engine: Arc::new(PatternEngine::new(1000, 20)),
67            tool_metrics: Arc::new(parking_lot::RwLock::new(hashbrown::HashMap::new())),
68            result_cache: Arc::new(parking_lot::RwLock::new(UnifiedCache::new(
69                10000,
70                DEFAULT_CACHE_TTL,
71                EvictionPolicy::Lru,
72            ))),
73            obs_context,
74        }
75    }
76
77    /// Record tool execution
78    pub fn record_execution(
79        &self,
80        tool_name: &str,
81        arguments: &str,
82        success: bool,
83        quality_score: f32,
84        duration_ms: u64,
85    ) {
86        // Update metrics
87        {
88            let mut metrics = self.tool_metrics.write();
89            let entry = if let Some(m) = metrics.get_mut(tool_name) {
90                m
91            } else {
92                metrics
93                    .entry(tool_name.to_string())
94                    .or_insert_with(|| ToolMetrics {
95                        name: tool_name.to_string(),
96                        total_calls: 0,
97                        successful_calls: 0,
98                        total_duration_ms: 0,
99                        avg_quality: 0.0,
100                    })
101            };
102
103            entry.total_calls += 1;
104            if success {
105                entry.successful_calls += 1;
106            }
107            entry.total_duration_ms += duration_ms;
108            entry.avg_quality = (entry.avg_quality * (entry.total_calls as f32 - 1.0)
109                + quality_score)
110                / entry.total_calls as f32;
111        }
112
113        // Record pattern event
114        let now = std::time::SystemTime::now()
115            .duration_since(std::time::UNIX_EPOCH)
116            .unwrap_or_default()
117            .as_secs();
118
119        self.pattern_engine.record(ExecutionEvent {
120            tool_name: tool_name.to_string(),
121            arguments: arguments.to_string(),
122            success,
123            quality_score,
124            duration_ms,
125            timestamp: now,
126        });
127
128        // Log metric - avoid format! if possible, but for key names it's needed
129        let metric_key = format!("{}_success_rate", tool_name);
130        self.obs_context.metric("tool_effectiveness", &metric_key, {
131            let metrics = self.tool_metrics.read();
132            metrics
133                .get(tool_name)
134                .map(|m| m.success_rate())
135                .unwrap_or(0.0)
136        });
137    }
138
139    /// Get tool metrics
140    pub fn get_tool_metrics(&self, tool_name: &str) -> Option<ToolMetrics> {
141        self.tool_metrics.read().get(tool_name).cloned()
142    }
143
144    /// Get all tool metrics
145    pub fn get_all_metrics(&self) -> Vec<ToolMetrics> {
146        self.tool_metrics.read().values().cloned().collect()
147    }
148
149    /// Get execution summary
150    pub fn get_summary(&self) -> crate::tools::pattern_engine::ExecutionSummary {
151        self.pattern_engine.summary()
152    }
153
154    /// Predict next tool based on patterns
155    pub fn predict_next_tool(&self) -> Option<String> {
156        self.pattern_engine.predict_next_tool()
157    }
158
159    /// Cache result for tool execution (migrated to UnifiedCache)
160    pub fn cache_result(&self, tool: &str, args: &str, result: &str) {
161        let key = ToolResultKey(format!("{}::{}", tool, args));
162        let size = result.len() as u64;
163        self.result_cache
164            .write()
165            .insert(key, result.to_string(), size);
166    }
167
168    /// Try to get cached result (migrated to UnifiedCache)
169    pub fn get_cached_result(&self, tool: &str, args: &str) -> Option<String> {
170        let key = ToolResultKey(format!("{}::{}", tool, args));
171        self.result_cache.write().get_owned(&key)
172    }
173
174    /// Clear cache (migrated to UnifiedCache)
175    pub fn clear_cache(&self) {
176        self.result_cache.write().clear();
177    }
178
179    /// Get cache stats (migrated to UnifiedCache)
180    pub fn cache_stats(&self) -> crate::cache::CacheStats {
181        self.result_cache.read().stats().clone()
182    }
183
184    /// Rank tools by effectiveness
185    pub fn rank_tools(&self) -> Vec<(String, f32)> {
186        let metrics = self.tool_metrics.read();
187        let mut tools: Vec<_> = metrics
188            .values()
189            .map(|m| {
190                let score = (m.success_rate() * 0.6)
191                    + ((1.0 - (m.avg_duration_ms() as f32 / 5000.0).min(1.0)) * 0.4);
192                (&m.name, score)
193            })
194            .collect();
195
196        tools.sort_unstable_by(|a, b| b.1.total_cmp(&a.1));
197        tools.into_iter().map(|(n, s)| (n.clone(), s)).collect()
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::constants::tools;

    /// Build an extension wired to a no-op observability context.
    fn make_ext() -> ToolRegistryImprovement {
        ToolRegistryImprovement::new(Arc::new(ObservabilityContext::noop()))
    }

    #[test]
    fn test_record_execution() {
        let ext = make_ext();
        ext.record_execution(tools::UNIFIED_SEARCH, "pattern", true, 0.8, 100);

        let recorded = ext
            .get_tool_metrics(tools::UNIFIED_SEARCH)
            .expect("metrics should exist after recording an execution");
        assert_eq!(recorded.success_rate(), 1.0);
    }

    #[test]
    fn test_cache_result() {
        let ext = make_ext();
        ext.cache_result(tools::UNIFIED_SEARCH, "pattern", "result");

        let cached = ext.get_cached_result(tools::UNIFIED_SEARCH, "pattern");
        assert_eq!(cached.as_deref(), Some("result"));
    }

    #[test]
    fn test_rank_tools() {
        let ext = make_ext();
        ext.record_execution("tool1", "arg", true, 0.9, 100);
        ext.record_execution("tool2", "arg", false, 0.3, 50);

        let ranked = ext.rank_tools();
        // tool1's perfect success rate outweighs tool2's slightly lower latency.
        let (top, _) = ranked.first().expect("ranking should not be empty");
        assert_eq!(top, "tool1");
    }
}