Skip to main content

winx_code_agent/utils/
path_prob.rs

//! File-path relevance ranking, ported from wcgw's `FastPathAnalyzer`.
//!
//! wcgw ships a tiny unigram language model trained over repo paths: a
//! Hugging Face tokenizer (`paths_tokens.model`) plus a vocab file mapping each
//! token to its log-probability (`paths_model.vocab`). A path's score is the sum
//! of its tokens' log-probabilities — higher (less negative) means the path
//! looks more like a "real source file worth showing" and less like noise.
//!
//! Both assets are embedded so ranking works offline with zero setup, matching
//! the wcgw package that bundles them alongside `repo_context.py`.
11
12use std::collections::HashMap;
13use std::sync::OnceLock;
14use tokenizers::Tokenizer;
15
16static PATHS_MODEL: &[u8] =
17    include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/paths_tokens.model"));
18static PATHS_VOCAB: &[u8] =
19    include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/paths_model.vocab"));
20
21struct PathAnalyzer {
22    tokenizer: Tokenizer,
23    vocab_probs: HashMap<String, f64>,
24}
25
26impl PathAnalyzer {
27    fn load() -> Option<Self> {
28        let tokenizer = match Tokenizer::from_bytes(PATHS_MODEL) {
29            Ok(tokenizer) => tokenizer,
30            Err(error) => {
31                tracing::warn!("Failed to load embedded path-ranking model: {error}");
32                return None;
33            }
34        };
35
36        // Vocab lines are `<token>\t<log_prob>`; mirror wcgw's `split()` + len==2 check.
37        let text = std::str::from_utf8(PATHS_VOCAB).ok()?;
38        let mut vocab_probs = HashMap::new();
39        for line in text.lines() {
40            let parts: Vec<&str> = line.split_whitespace().collect();
41            if parts.len() == 2 {
42                if let Ok(prob) = parts[1].parse::<f64>() {
43                    vocab_probs.insert(parts[0].to_string(), prob);
44                }
45            }
46        }
47
48        Some(Self { tokenizer, vocab_probs })
49    }
50
51    fn sum_log_prob(&self, tokens: &[String]) -> f64 {
52        tokens.iter().filter_map(|token| self.vocab_probs.get(token)).sum()
53    }
54}
55
56fn analyzer() -> Option<&'static PathAnalyzer> {
57    static ANALYZER: OnceLock<Option<PathAnalyzer>> = OnceLock::new();
58    ANALYZER.get_or_init(PathAnalyzer::load).as_ref()
59}
60
61/// Score each path by summed token log-probability (higher = more relevant).
62///
63/// Returns `None` if the model failed to load, so callers can fall back to a
64/// heuristic ordering instead of silently mis-ranking everything.
65pub fn score_paths(paths: &[String]) -> Option<Vec<f64>> {
66    let analyzer = analyzer()?;
67    let scores = paths
68        .iter()
69        .map(|path| match analyzer.tokenizer.encode(path.as_str(), false) {
70            Ok(encoding) => analyzer.sum_log_prob(encoding.get_tokens()),
71            // Unencodable path sinks to the bottom rather than poisoning the batch.
72            Err(_) => f64::MIN,
73        })
74        .collect();
75    Some(scores)
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ranks_source_above_noise_when_model_present() {
        let source = String::from("src/main.rs");
        let noise = String::from("a/b/c/d/e/f/zzz_tmp_garbage_9f8a.bin");
        // No embedded model means nothing to rank; skip instead of failing.
        let Some(scores) = score_paths(&[source, noise]) else {
            return;
        };
        assert_eq!(scores.len(), 2);
        // A normal source path should not score worse than deep random noise.
        assert!(scores[0] >= scores[1], "expected src/main.rs >= noise path");
    }
}