winx_code_agent/utils/
path_prob.rs1use std::collections::HashMap;
13use std::sync::OnceLock;
14use tokenizers::Tokenizer;
15
16static PATHS_MODEL: &[u8] =
17 include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/paths_tokens.model"));
18static PATHS_VOCAB: &[u8] =
19 include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/assets/paths_model.vocab"));
20
21struct PathAnalyzer {
22 tokenizer: Tokenizer,
23 vocab_probs: HashMap<String, f64>,
24}
25
26impl PathAnalyzer {
27 fn load() -> Option<Self> {
28 let tokenizer = match Tokenizer::from_bytes(PATHS_MODEL) {
29 Ok(tokenizer) => tokenizer,
30 Err(error) => {
31 tracing::warn!("Failed to load embedded path-ranking model: {error}");
32 return None;
33 }
34 };
35
36 let text = std::str::from_utf8(PATHS_VOCAB).ok()?;
38 let mut vocab_probs = HashMap::new();
39 for line in text.lines() {
40 let parts: Vec<&str> = line.split_whitespace().collect();
41 if parts.len() == 2 {
42 if let Ok(prob) = parts[1].parse::<f64>() {
43 vocab_probs.insert(parts[0].to_string(), prob);
44 }
45 }
46 }
47
48 Some(Self { tokenizer, vocab_probs })
49 }
50
51 fn sum_log_prob(&self, tokens: &[String]) -> f64 {
52 tokens.iter().filter_map(|token| self.vocab_probs.get(token)).sum()
53 }
54}
55
56fn analyzer() -> Option<&'static PathAnalyzer> {
57 static ANALYZER: OnceLock<Option<PathAnalyzer>> = OnceLock::new();
58 ANALYZER.get_or_init(PathAnalyzer::load).as_ref()
59}
60
61pub fn score_paths(paths: &[String]) -> Option<Vec<f64>> {
66 let analyzer = analyzer()?;
67 let scores = paths
68 .iter()
69 .map(|path| match analyzer.tokenizer.encode(path.as_str(), false) {
70 Ok(encoding) => analyzer.sum_log_prob(encoding.get_tokens()),
71 Err(_) => f64::MIN,
73 })
74 .collect();
75 Some(scores)
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ranks_source_above_noise_when_model_present() {
        let source = String::from("src/main.rs");
        let noise = String::from("a/b/c/d/e/f/zzz_tmp_garbage_9f8a.bin");
        let candidates = [source, noise];

        // `score_paths` yields `None` when the embedded model failed to load;
        // in that case there is nothing to compare and the test checks nothing.
        let Some(scores) = score_paths(&candidates) else {
            return;
        };

        assert_eq!(scores.len(), 2);
        let (source_score, noise_score) = (scores[0], scores[1]);
        assert!(source_score >= noise_score, "expected src/main.rs >= noise path");
    }
}