1use std::collections::HashMap;
2
3pub struct FileScorer;
6
7impl FileScorer {
8 pub fn score_files(
11 task: &str,
12 files: &[crate::memory::FileEntry],
13 symbols: &[crate::memory::SymbolEntry],
14 ) -> Vec<(String, f64)> {
15 let task_words: Vec<String> = tokenize(task);
16 if task_words.is_empty() {
17 return files.iter().map(|f| (f.path.clone(), 0.0)).collect();
18 }
19
20 let mut scores: HashMap<String, f64> = HashMap::new();
21
22 for file in files {
24 let path_words = tokenize(&file.path);
25 let score = cosine_similarity_words(&task_words, &path_words);
26 *scores.entry(file.path.clone()).or_insert(0.0) += score * 2.0;
27 }
28
29 for sym in symbols {
31 let sym_words = tokenize(&sym.name);
32 let score = cosine_similarity_words(&task_words, &sym_words);
33 *scores.entry(sym.file.clone()).or_insert(0.0) += score * 3.0;
34 }
35
36 let mut result: Vec<(String, f64)> = scores.into_iter().collect();
37 result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
38 result
39 }
40
41 pub fn top_files(
43 task: &str,
44 files: &[crate::memory::FileEntry],
45 symbols: &[crate::memory::SymbolEntry],
46 max: usize,
47 ) -> Vec<String> {
48 let mut scored = Self::score_files(task, files, symbols);
49
50 let task_lower = task.to_lowercase();
52 for file in files {
53 if task_lower.contains(&file.path.to_lowercase()) {
54 if !scored.iter().any(|(p, _)| p == &file.path) {
55 scored.push((file.path.clone(), 10.0));
56 }
57 }
58 }
59
60 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
61 scored.truncate(max);
62 scored.into_iter().map(|(p, _)| p).collect()
63 }
64}
65
/// Splits `text` into lower-cased word tokens.
///
/// A token is a maximal run of alphanumeric characters, `_`, `-` and `/`, so
/// `src/main.rs` yields `src/main` and `rs`. Tokens shorter than two *bytes* are
/// dropped: the check uses UTF-8 byte length, so a lone `a` is discarded while a
/// lone multi-byte letter such as `É` is kept (as `é`).
fn tokenize(text: &str) -> Vec<String> {
    let is_separator = |ch: char| !(ch.is_alphanumeric() || matches!(ch, '_' | '-' | '/'));
    text.split(is_separator)
        .filter_map(|word| (word.len() >= 2).then(|| word.to_lowercase()))
        .collect()
}
72
/// Cosine similarity between the word-frequency vectors of `a` and `b`.
///
/// Each slice is treated as a bag of words, so repeated words count multiple
/// times. The result lies in `[0.0, 1.0]` up to floating-point rounding, and is
/// exactly `0.0` when either slice is empty or the slices share no word.
fn cosine_similarity_words(a: &[String], b: &[String]) -> f64 {
    /// Word -> occurrence count, kept as `f64` for the arithmetic below.
    fn counts(words: &[String]) -> HashMap<&str, f64> {
        let mut bag = HashMap::new();
        for word in words {
            *bag.entry(word.as_str()).or_insert(0.0) += 1.0;
        }
        bag
    }

    /// Sum of squared counts, i.e. the squared Euclidean norm of the bag.
    fn squared_norm(bag: &HashMap<&str, f64>) -> f64 {
        bag.values().map(|count| count * count).sum()
    }

    let left = counts(a);
    let right = counts(b);

    let left_sq = squared_norm(&left);
    let right_sq = squared_norm(&right);
    if left_sq == 0.0 || right_sq == 0.0 {
        return 0.0;
    }

    // Words missing from either side contribute 0 to the dot product.
    let dot: f64 = left
        .iter()
        .filter_map(|(word, &lc)| right.get(word).map(|&rc| lc * rc))
        .sum();

    dot / (left_sq.sqrt() * right_sq.sqrt())
}
105
/// Source of a small Python REPL loop driven over stdin/stdout.
///
/// The script reads one JSON request per line from stdin. Each request has the
/// fields `id`, `code` and an optional `mode`: when `mode == "eval"` the code is
/// `eval`ed, otherwise it is `exec`ed. All requests share one persistent
/// `namespace` dict. The script writes one JSON line per request to stdout:
/// `{"id", "result", "stdout"}` on success or `{"id", "error"}` on failure. A
/// request that is not valid JSON, or that lacks an `id`, gets an error reply
/// with `"id": null` instead of crashing the loop.
///
/// SECURITY: the script executes arbitrary code via `exec`/`eval` with no
/// sandboxing. Only feed it trusted input.
pub const PYTHON_KERNEL: &str = r#"
import sys, json, traceback, io
namespace = {}
while True:
    line = sys.stdin.readline()
    if not line:
        break
    # Reset per request: a malformed line (or one without an id) must not be
    # answered with the previous request's id, and the error handler itself
    # must never raise (NameError/KeyError) and kill the kernel.
    req = {}
    captured = io.StringIO()
    try:
        req = json.loads(line)
        sys.stdout = captured
        try:
            if req.get('mode') == 'eval':
                result = eval(req['code'], namespace)
            else:
                exec(req['code'], namespace)
                result = None
        finally:
            # Always restore the real stdout so the JSON reply is not captured.
            sys.stdout = sys.__stdout__
        response = {"id": req['id'], "result": str(result) if result is not None else None, "stdout": captured.getvalue()}
    except Exception:
        req_id = req.get('id') if isinstance(req, dict) else None
        response = {"id": req_id, "error": traceback.format_exc()}
    print(json.dumps(response), flush=True)
"#;