Skip to main content

vtcode_core/tools/
improvement_algorithms.rs

1//! Production-grade algorithms for tool improvements.
2//!
3//! Provides: Jaro-Winkler string similarity, time-decay effectiveness scoring,
4//! pattern detection over tool execution history, and ML-ready feature vectors.
5
6use crate::utils::current_timestamp;
7use serde::{Deserialize, Serialize};
8use smallvec::SmallVec;
9
10// ── String similarity ─────────────────────────────────────────────────────────
11
12/// Jaro-Winkler string similarity in [0.0, 1.0].
13///
14/// Preferred over Levenshtein for short strings (tool arguments) because it
15/// rewards matching prefixes, which is common in tool argument patterns.
16pub fn jaro_winkler_similarity(s1: &str, s2: &str) -> f32 {
17    if s1 == s2 {
18        return 1.0;
19    }
20    if s1.is_empty() || s2.is_empty() {
21        return 0.0;
22    }
23
24    let jaro = jaro_similarity(s1, s2);
25
26    // Common prefix length, capped at 4 (standard Winkler constant).
27    let prefix_len = s1
28        .chars()
29        .zip(s2.chars())
30        .take_while(|(a, b)| a == b)
31        .take(4)
32        .count();
33
34    // Winkler boost: p = 0.1 (standard scaling factor).
35    jaro + (prefix_len as f32 * 0.1 * (1.0 - jaro))
36}
37
/// First Jaro pass: greedily pair each character of `s1c` with an equal, not
/// yet claimed character of `s2c` that lies at most `window` positions away.
///
/// Every paired position is flagged in `s1_matched` / `s2_matched`, and the
/// number of pairs is returned. Each `s1c[i]` claims the leftmost eligible
/// candidate. The flag slices are expected to be as long as their char slices.
fn jaro_collect_matches(
    s1c: &[char],
    s2c: &[char],
    s1_matched: &mut [bool],
    s2_matched: &mut [bool],
    window: usize,
) -> usize {
    let mut found = 0usize;

    for (i, &wanted) in s1c.iter().enumerate() {
        let start = i.saturating_sub(window);
        let end = (i + window + 1).min(s2c.len());
        // `start >= end` exactly when `start` has run past the end of `s2c`:
        // there is nothing left to search for this position.
        if start >= end {
            continue;
        }

        // Leftmost unclaimed occurrence of `wanted` in the window. Searching a
        // reslice keeps the bounds explicit for the optimiser.
        let hit = s2c[start..end]
            .iter()
            .zip(&s2_matched[start..end])
            .position(|(&candidate, &claimed)| !claimed && candidate == wanted);

        if let Some(offset) = hit {
            s1_matched[i] = true;
            s2_matched[start + offset] = true;
            found += 1;
        }
    }

    found
}
69
/// Second Jaro pass: walk the matched characters of both strings in order and
/// return half the number of positions where they disagree (integer division).
///
/// Pairing stops as soon as either sequence of matched characters runs out.
fn jaro_count_transpositions(
    s1c: &[char],
    s2c: &[char],
    s1_matched: &[bool],
    s2_matched: &[bool],
) -> usize {
    // The matched characters of each string, in their original order.
    let left = s1c.iter().zip(s1_matched).filter(|&(_, &m)| m).map(|(&c, _)| c);
    let right = s2c.iter().zip(s2_matched).filter(|&(_, &m)| m).map(|(&c, _)| c);

    // Pair them position by position and count the pairs that differ; Jaro
    // uses half of that count.
    let disagreements = left.zip(right).filter(|(a, b)| a != b).count();
    disagreements / 2
}
93
94/// Jaro similarity in [0.0, 1.0].
95fn jaro_similarity(s1: &str, s2: &str) -> f32 {
96    // Use SmallVec to avoid heap allocation for common short strings.
97    let s1c: SmallVec<[char; 64]> = s1.chars().collect();
98    let s2c: SmallVec<[char; 64]> = s2.chars().collect();
99    let len1 = s1c.len();
100    let len2 = s2c.len();
101
102    if len1 == 0 && len2 == 0 {
103        return 1.0;
104    }
105    if len1 == 0 || len2 == 0 {
106        return 0.0;
107    }
108
109    let window = (len1.max(len2) >> 1).saturating_sub(1);
110
111    let mut s1_matched = SmallVec::<[bool; 64]>::from_elem(false, len1);
112    let mut s2_matched = SmallVec::<[bool; 64]>::from_elem(false, len2);
113    let matches = jaro_collect_matches(&s1c, &s2c, &mut s1_matched, &mut s2_matched, window);
114
115    if matches == 0 {
116        return 0.0;
117    }
118
119    let transpositions = jaro_count_transpositions(&s1c, &s2c, &s1_matched, &s2_matched);
120
121    let m = matches as f32;
122    (m / len1 as f32 + m / len2 as f32 + (m - transpositions as f32) / m) / 3.0
123}
124
125// ── Time-decay scoring ────────────────────────────────────────────────────────
126
/// Effectiveness score attenuated by how long ago it was recorded.
///
/// Recent results weigh more. The decayed value is
/// `clamp(base_score × exp(−λ × age_hours), 0.0, 1.0)`, with λ = 0.1 per hour
/// when built through [`TimeDecayedScore::calculate`];
/// [`TimeDecayedScore::with_decay`] substitutes a different λ.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimeDecayedScore {
    /// Undecayed input score (expected in 0.0–1.0; not validated).
    pub base_score: f32,
    /// Seconds between the recorded timestamp and the time of calculation
    /// (0 when the timestamp lies in the future).
    pub age_seconds: u64,
    /// Decay constant λ, applied per hour of age.
    pub decay_lambda: f32,
    /// `base_score` after decay, clamped to 0.0–1.0.
    pub decayed_score: f32,
}
142
143impl TimeDecayedScore {
144    /// Calculate a time-decayed score for `base_score` recorded at `timestamp`.
145    pub fn calculate(base_score: f32, timestamp: u64) -> Self {
146        const DEFAULT_LAMBDA: f32 = 0.1;
147        let now = current_timestamp();
148        let age_seconds = now.saturating_sub(timestamp);
149        let age_hours = age_seconds as f32 / 3600.0;
150        let decayed_score = (base_score * (-DEFAULT_LAMBDA * age_hours).exp()).clamp(0.0, 1.0);
151
152        Self {
153            base_score,
154            age_seconds,
155            decay_lambda: DEFAULT_LAMBDA,
156            decayed_score,
157        }
158    }
159
160    /// Return a copy with a custom decay constant applied.
161    pub fn with_decay(mut self, lambda: f32) -> Self {
162        let age_hours = self.age_seconds as f32 / 3600.0;
163        self.decayed_score = (self.base_score * (-lambda * age_hours).exp()).clamp(0.0, 1.0);
164        self.decay_lambda = lambda;
165        self
166    }
167}
168
169// ── Pattern detection ─────────────────────────────────────────────────────────
170
/// Pattern recognised by [`detect_pattern`] in the most recent slice of a
/// tool-execution history.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum PatternState {
    /// No pattern: fewer than two entries were examined, or no other rule matched.
    Single,
    /// Exactly two examined entries, both with the same tool and args.
    Duplicate,
    /// Three or more examined entries, all with the same tool and args.
    Loop,
    /// Same tool on all (three or more) entries, with each pair of consecutive
    /// args similar (Jaro-Winkler similarity above 0.85) but not all identical.
    NearLoop,
    /// Quality rises by more than 0.05 at every step across three or more entries.
    RefinementChain,
    /// More than one tool, with every quality within 0.1 of the window's mean.
    Convergence,
    /// Quality falls by more than 0.05 at every step across three or more entries.
    Degradation,
}
189
190/// Detect a pattern in a history of `(tool, args_hash, quality)` triples.
191///
192/// Only the last `window_size` entries are examined.
193/// Returns [`PatternState::Single`] for an empty or single-entry history.
194///
195/// # Why a free function?
196/// The logic is stateless — `window_size` is the only parameter. A wrapper
197/// struct added no encapsulation and hurt discoverability (KISS).
198pub fn detect_pattern(
199    history: &[(String, String, f32)], // (tool, args_hash, quality)
200    window_size: usize,
201) -> PatternState {
202    if history.is_empty() {
203        return PatternState::Single;
204    }
205
206    // Work on the most recent `window_size` entries; borrow the slice directly
207    // (no intermediate Vec allocation).
208    let start = history.len().saturating_sub(window_size);
209    let recent = &history[start..];
210
211    if recent.len() < 2 {
212        return PatternState::Single;
213    }
214
215    let first = &recent[0];
216
217    // --- Exact duplicates ---
218    if recent.iter().all(|r| r.0 == first.0 && r.1 == first.1) {
219        return if recent.len() >= 3 {
220            PatternState::Loop
221        } else {
222            PatternState::Duplicate
223        };
224    }
225
226    // --- Single combined scan -------------------------------------------------
227    // Collect qualities + same-tool flag + multi-tool flag in one pass.
228    let mut qualities = SmallVec::<[f32; 32]>::with_capacity(recent.len());
229    let mut same_tool = true;
230    let mut multi_tool = false;
231
232    for r in recent {
233        qualities.push(r.2);
234        if r.0 != first.0 {
235            same_tool = false;
236            multi_tool = true;
237        }
238    }
239
240    // --- Quality trends (≥3 points) ---
241    if qualities.len() >= 3 {
242        if qualities.windows(2).all(|w| w[1] > w[0] + 0.05) {
243            return PatternState::RefinementChain;
244        }
245        if qualities.windows(2).all(|w| w[1] < w[0] - 0.05) {
246            return PatternState::Degradation;
247        }
248    }
249
250    // --- Near-loop: same tool, fuzzy args ---
251    // Use `.all()` directly — no intermediate Vec<f32> needed.
252    if same_tool
253        && recent.len() >= 3
254        && recent
255            .windows(2)
256            .all(|w| jaro_winkler_similarity(&w[0].1, &w[1].1) > 0.85)
257    {
258        return PatternState::NearLoop;
259    }
260
261    // --- Convergence: different tools, similar quality ---
262    if multi_tool {
263        let n = qualities.len() as f32;
264        let avg = qualities.iter().sum::<f32>() / n;
265        if qualities.iter().all(|&q| (q - avg).abs() < 0.1) {
266            return PatternState::Convergence;
267        }
268    }
269
270    PatternState::Single
271}
272
273// ── ML scoring ────────────────────────────────────────────────────────────────
274
/// Measured effectiveness signals for a tool.
///
/// The fields are plain numbers so they can double as model features; see
/// [`MLScoreComponents::to_feature_vector`] for the flattened form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLScoreComponents {
    /// Success rate (0–1).
    pub success_rate: f32,
    /// Average execution time, in milliseconds. `raw_score` gives no speed
    /// credit at or above 10 000 ms.
    pub avg_execution_time: f32,
    /// Result quality (0–1).
    pub result_quality: f32,
    /// Number of failure modes observed (capped at 10 in the feature vector).
    pub failure_count: usize,
    /// Time since last use, in hours; used by `with_age_decay`.
    pub age_hours: f32,
    /// Usage frequency, in calls per hour (capped at 1.0 inside `raw_score`).
    pub frequency: f32,
    /// Confidence in the measurement (0–1); adjusted by `with_age_decay`.
    pub confidence: f32,
}
295
296impl MLScoreComponents {
297    /// Combined ML score before time decay.
298    ///
299    /// Weights: success 40% + quality 30% + speed 15% + frequency 15%.
300    pub fn raw_score(&self) -> f32 {
301        (self.success_rate * 0.40)
302            + (self.result_quality * 0.30)
303            + ((10_000.0 - self.avg_execution_time).max(0.0) / 10_000.0 * 0.15)
304            + (self.frequency.min(1.0) * 0.15)
305    }
306
307    /// Apply confidence decay for older measurements (1-week half-life).
308    pub fn with_age_decay(mut self) -> Self {
309        // Older measurements are less reliable; decay confidence over time.
310        self.confidence = (self.confidence * (-self.age_hours / 168.0).exp()).max(0.1);
311        self
312    }
313
314    /// Return a 7-element feature vector, normalised for ML consumption.
315    pub fn to_feature_vector(&self) -> [f32; 7] {
316        [
317            self.success_rate,
318            self.avg_execution_time / 10_000.0,
319            self.result_quality,
320            (self.failure_count as f32).min(10.0) / 10.0,
321            self.age_hours / 168.0,
322            self.frequency,
323            self.confidence,
324        ]
325    }
326}
327
328// ── Tests ─────────────────────────────────────────────────────────────────────
329
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `(tool, args_hash, quality)` history entry.
    fn entry(tool: &str, args: &str, quality: f32) -> (String, String, f32) {
        (tool.to_string(), args.to_string(), quality)
    }

    /// Representative, in-range measurement shared by the ML tests.
    fn sample_components() -> MLScoreComponents {
        MLScoreComponents {
            success_rate: 0.9,
            avg_execution_time: 100.0,
            result_quality: 0.85,
            failure_count: 1,
            age_hours: 2.0,
            frequency: 0.5,
            confidence: 0.9,
        }
    }

    // ── String similarity ──

    #[test]
    fn test_jaro_winkler_exact() {
        assert_eq!(jaro_winkler_similarity("hello", "hello"), 1.0);
    }

    #[test]
    fn test_jaro_winkler_empty_inputs() {
        assert_eq!(jaro_winkler_similarity("", ""), 1.0);
        assert_eq!(jaro_winkler_similarity("", "abc"), 0.0);
        assert_eq!(jaro_winkler_similarity("abc", ""), 0.0);
    }

    #[test]
    fn test_jaro_winkler_disjoint() {
        assert_eq!(jaro_winkler_similarity("abc", "xyz"), 0.0);
    }

    #[test]
    fn test_jaro_winkler_reference_value() {
        // Textbook pair: Jaro ≈ 0.944, Jaro-Winkler ≈ 0.961, in either order.
        for &(a, b) in &[("martha", "marhta"), ("marhta", "martha")] {
            let sim = jaro_winkler_similarity(a, b);
            assert!((sim - 0.9611).abs() < 1e-3, "{a}/{b}: sim={sim}");
        }
    }

    #[test]
    fn test_jaro_winkler_partial() {
        let sim = jaro_winkler_similarity("pattern", "pattern_file");
        assert!(sim > 0.85 && sim < 1.0, "sim={sim}");
    }

    #[test]
    fn test_jaro_winkler_prefix_boost() {
        let with_prefix = jaro_winkler_similarity("test_one", "test_two");
        let without = jaro_winkler_similarity("one_test", "two_test");
        assert!(with_prefix > without, "prefix boost should be applied");
    }

    // ── Time decay ──

    #[test]
    fn test_time_decay_ordering() {
        let now = current_timestamp();
        let recent = TimeDecayedScore::calculate(0.9, now);
        let old = TimeDecayedScore::calculate(0.9, now.saturating_sub(7 * 24 * 3600));
        assert!(old.decayed_score < recent.decayed_score);
    }

    #[test]
    fn test_time_decay_custom_lambda_and_clamp() {
        let now = current_timestamp();

        // λ = 0 disables decay entirely.
        let flat = TimeDecayedScore::calculate(0.9, now.saturating_sub(10 * 3600)).with_decay(0.0);
        assert_eq!(flat.decay_lambda, 0.0);
        assert_eq!(flat.decayed_score, 0.9);

        // Out-of-range base scores are clamped; future timestamps have age 0.
        assert_eq!(TimeDecayedScore::calculate(1.5, now).decayed_score, 1.0);
        assert_eq!(TimeDecayedScore::calculate(0.5, now + 3600).age_seconds, 0);
    }

    // ── Pattern detection ──

    #[test]
    fn test_detect_pattern_too_few_entries() {
        let empty: Vec<(String, String, f32)> = Vec::new();
        assert_eq!(detect_pattern(&empty, 10), PatternState::Single);
        assert_eq!(
            detect_pattern(&[entry("grep", "p", 0.5)], 10),
            PatternState::Single
        );

        // A window smaller than two entries never reports a pattern.
        let repeated = vec![entry("grep", "p", 0.5); 3];
        assert_eq!(detect_pattern(&repeated, 1), PatternState::Single);
        assert_eq!(detect_pattern(&repeated, 0), PatternState::Single);
    }

    #[test]
    fn test_detect_pattern_duplicate() {
        // Quality is ignored when comparing for exact repeats.
        let history = vec![entry("grep", "p", 0.5), entry("grep", "p", 0.7)];
        assert_eq!(detect_pattern(&history, 10), PatternState::Duplicate);
    }

    #[test]
    fn test_detect_pattern_loop() {
        let history = vec![
            entry("grep", "pattern1", 0.5),
            entry("grep", "pattern1", 0.5),
            entry("grep", "pattern1", 0.5),
        ];
        assert_eq!(detect_pattern(&history, 10), PatternState::Loop);
    }

    #[test]
    fn test_detect_pattern_only_recent_window() {
        let mut history = vec![entry("ls", "x", 0.1)];
        history.extend(vec![entry("grep", "p", 0.5); 3]);
        assert_eq!(detect_pattern(&history, 3), PatternState::Loop);
        assert_eq!(detect_pattern(&history, 4), PatternState::Single);
    }

    #[test]
    fn test_detect_pattern_refinement() {
        let history = vec![
            entry("grep", "pat1", 0.3),
            entry("grep", "pat2", 0.5),
            entry("grep", "pat3", 0.8),
        ];
        assert_eq!(detect_pattern(&history, 10), PatternState::RefinementChain);
    }

    #[test]
    fn test_detect_pattern_degradation() {
        let history = vec![
            entry("grep", "pat1", 0.9),
            entry("grep", "pat2", 0.6),
            entry("grep", "pat3", 0.3),
        ];
        assert_eq!(detect_pattern(&history, 10), PatternState::Degradation);
    }

    #[test]
    fn test_detect_pattern_near_loop_requires_three_entries() {
        let two_entries = vec![
            entry("grep", "pattern-one", 0.4),
            entry("grep", "pattern-two", 0.45),
        ];
        assert_eq!(detect_pattern(&two_entries, 10), PatternState::Single);

        let three_entries = vec![
            entry("grep", "pattern-one", 0.4),
            entry("grep", "pattern-two", 0.45),
            entry("grep", "pattern-three", 0.5),
        ];
        assert_eq!(detect_pattern(&three_entries, 10), PatternState::NearLoop);
    }

    #[test]
    fn test_detect_pattern_convergence() {
        let history = vec![
            entry("grep", "a", 0.70),
            entry("find", "b", 0.72),
            entry("rg", "c", 0.68),
        ];
        assert_eq!(detect_pattern(&history, 10), PatternState::Convergence);

        // Different tools whose qualities are far apart do not converge.
        let divergent = vec![entry("grep", "a", 0.1), entry("find", "b", 0.9)];
        assert_eq!(detect_pattern(&divergent, 10), PatternState::Single);
    }

    // ── ML scoring ──

    #[test]
    fn test_ml_raw_score() {
        let score = sample_components().raw_score();
        assert!(score > 0.7 && score < 1.0, "score={score}");
    }

    #[test]
    fn test_ml_feature_vector_length() {
        assert_eq!(sample_components().to_feature_vector().len(), 7);
    }

    #[test]
    fn test_ml_age_decay() {
        let recent = sample_components().with_age_decay();
        assert!(recent.confidence < 0.9 && recent.confidence > 0.1);

        let stale = MLScoreComponents {
            age_hours: 10_000.0,
            ..sample_components()
        }
        .with_age_decay();
        assert_eq!(stale.confidence, 0.1);
    }
}